IT recording...
[BOJ] 2042 구간 합 구하기 - python 본문
- Segment Tree(세그먼트 트리)
부분합을 선형 계산을 통해 구하려면 O(N)의 시간이 걸린다. 그 후 특정 부분합을 구하려면 O(1)
But, 노드의 값이 빈번히 update된다면 어떨까?
→ 원소의 개수가 N개인 배열이 있을 때 첫번째 값이 M번 변경된다면 O(MN) 번의 시간복잡도가 필요하다.
⇒ Segment Tree를 이용하면 O(logN)의 시간 복잡도 내에 부분합을 구할 수 있다.
: 이진 트리를 바탕으로 트리를 구성하며, update를 해야 할 부분만 업데이트 할 수 있다.
- leaf 노드는 배열이 가지는 값을 가지게 된다.
- 이외의 노드들은 부분합 등 특정 조건을 만족하는 조건으로 이루어진다.
- 트리의 최대 노드 개수
def segmentation_size(node):
height = int(math.ceil(math.log10(node) / math.log10(2))) + 1 # 세그먼트 트리의 높이 구하기 (log노드수/log2 + 1)
size = int(math.pow(2, height)) #트리 최대 노드 개수
return size
→ 트리의 높이를 구한 후 2의 지수승을 통해 최대 노드 개수를 구한다.
- 초기화
def segmentation_init(start, end, index):
if start == end:
tree[index] = num[start-1] # leaf 노드에 도달했으면 배열 값 삽입
return tree[index]
mid = (start + end) // 2 # leaf노드가 아니면 좌,우 가지치기 진행
tree[index] = segmentation_init(start, mid, index * 2) + segmentation_init(mid + 1, end, index * 2 + 1)
return tree[index]
→ leaf 노드에 도달하면 (start==end) 배열의 값을 트리에 삽입하고, 아니라면 가지치기를 계속 진행한다.
→ 여기서는 가지치기를 통해 리프노드가 아닌 노드들에 부분합을 저장하므로, 재귀를 이용하여
tree[index] = init(start,mid,index*2) + init(mid+1,end,index*2+1)
를 통해 tree를 초기화 한다.
- 부분합 구하기
def segmentation_target(start, end, index, target_start, target_end):
if target_start > end or target_end < start: # target 범위가 겹치지 않을 때
return 0
if target_start <= start and end <= target_end: # tree특정 노드의 범위가 target 범위에 포함될 때
return tree[index]
mid = (start + end) // 2 # target의 범위가 일부만 걸쳐있으면 다 찾아내야 하니까 가지치기 진행
return segmentation_target(start, mid, index * 2, target_start, target_end)
+ segmentation_target(mid + 1, end,index * 2 + 1,target_start,target_end)
→ 구하고자 하는 target범위가 노드의 범위와 겹치지 않으면 재귀를 중지한다.
→ 구하고자 하는 target범위가 노드의 범위를 완전히 포함하면 해당 값을 리턴한다. (재귀여서 최종적으로 더하게 됨)
→ 구하고자 하는 tartet범위가 노드의 범위를 일부 걸치면 가지 치기를 통해 완전히 포함하는 범위를 찾아낼 때까지 재귀를 시행한다.
- 트리 업데이트
def segmentation_update(start, end, index, updateIdx, diff):
if updateIdx < start or end < updateIdx: # 범위에서 벗어나면 할 필요 없음
return
#print("update : (", start, end, ")", index)
tree[index] += diff # 범위에 해당하는 거니까 업데이트
if start == end: #가지 끝까지 갔으면 가지치기X
return
mid = (start + end) // 2 # 가지치기 해서 updateIdx에 해당하는 부분의 노드 업데이트 진행
segmentation_update(start, mid, index * 2, updateIdx, diff)
segmentation_update(mid + 1, end, index * 2 + 1, updateIdx, diff)
return
→ udpate하려는 index가 재귀되는 노드의 start,end에 포함되는지를 확인한 후, (포함안되면 할 필요 없으므로 바로 return)→ 리프 노드를 방문하게 되면 return한다.
→ 리프 노드가 아니라면 더 가지치기를 해야하므로 양쪽으로 update를 재귀수행 한다.
→ 포함되는 모든 노드들에 diff를 더하여 업데이트한다.
- 참고[세그먼트 트리 업데이트, 트리 크기 구하기] https://pangtrue.tistory.com/299
- [세그먼트 트리 init,query] https://imksh.com/23
풀이 코드
import sys
import math
def input():
return sys.stdin.readline().rstrip()
# 세그멘테이션 트리 사용
# 부분합을 구해야 하고, 특정 부분에 대한 수의 변경이 빈번하게 일어날 때, segmentation tree를 사용한다.
def segmentation_size(node):
height = int(math.ceil(math.log10(node) / math.log10(2))) + 1 # 세그먼트 트리의 높이 구하기 (log노드수/log2 + 1)
size = int(math.pow(2, height)) #트리 최대 노드 개수
return size
def segmentation_init(start, end, index):
if start == end:
tree[index] = num[start-1] # leaf 노드에 도달했으면 배열 값 삽입
return tree[index]
mid = (start + end) // 2 # leaf노드가 아니면 좌,우 가지치기 진행
tree[index] = segmentation_init(start, mid, index * 2) + segmentation_init(mid + 1, end, index * 2 + 1)
return tree[index]
def segmentation_target(start, end, index, target_start, target_end):
if target_start > end or target_end < start: # target 범위가 겹치지 않을 때
return 0
if target_start <= start and end <= target_end: # tree특정 노드의 범위가 target 범위에 포함될 때
return tree[index]
mid = (start + end) // 2 # target의 범위가 일부만 걸쳐있으면 다 찾아내야 하니까 가지치기 진행
return segmentation_target(start, mid, index * 2, target_start, target_end) + segmentation_target(mid + 1, end,
index * 2 + 1, target_e
def segmentation_update(start, end, index, updateIdx, diff):
if updateIdx < start or end < updateIdx: # 범위에서 벗어나면 할 필요 없음
return
tree[index] += diff # 범위에 해당하는 거니까 업데이트
if start == end: #가지 끝까지 갔으면 가지치기X
return
mid = (start + end) // 2 # 가지치기 해서 updateIdx에 해당하는 부분의 노드 업데이트 진행
segmentation_update(start, mid, index * 2, updateIdx, diff)
segmentation_update(mid + 1, end, index * 2 + 1, updateIdx, diff)
return
# 입력
N, M, K = map(int, input().split(" "))
num = [int(input()) for _ in range(N)]
command = [list(map(int, input().split(" "))) for _ in range(M + K)]
#tree 초기화
tree = [0] * segmentation_size(len(num))
segmentation_init(1, len(num), 1) # start,end,index
for co in command:
if co[0] == 1: # co[1]자리를 co[2]으로 변경(업데이트)
diff = co[2] - num[co[1]-1]
num[co[1]-1] = co[2] #num도 업데이트 해줘야 함
segmentation_update(1,len(num),1,co[1],diff)
elif co[0] == 2: # co[1]부터 co[2]까지 더해서 출력(부분합)
print(segmentation_target(1,len(num),1,co[1],co[2]))
느낀점
"211221"
소요 시간 : 2D 이상
- 세그먼트트리라는 새로운 자료구조를 접할 수 있는 문제였다.
- 앞으로 세그먼트트리 사용하는 문제들 좀 풀어봐야겠다
'Algorithm' 카테고리의 다른 글
[BOJ] 2014 소수의 곱 - Java (0) | 2022.02.04 |
---|---|
[BOJ] 2904 수학은 너무 쉬워 - Java (0) | 2022.02.04 |
[BOJ] 2504 괄호의 값 - python (0) | 2021.12.15 |
[BOJ] 6416 트리인가? - python (0) | 2021.12.07 |
[BOJ] 1991 트리 순회 - python (0) | 2021.12.07 |