IT recording...

[BOJ] 2042 구간 합 구하기 - python 본문

Algorithm

[BOJ] 2042 구간 합 구하기 - python

I-one 2021. 12. 21. 15:55
  • Segment Tree(세그먼트 트리)

부분합을 선형 계산을 통해 구하려면 O(N)의 시간이 걸린다. 그 후 특정 부분합을 구하려면 O(1)

But, 노드의 값이 빈번히 update된다면 어떨까?

→ 원소의 개수가 N개인 배열이 있을 때 첫번째 값이 M번 변경된다면 O(MN) 번의 시간복잡도가 필요하다.

⇒ Segment Tree를 이용하면 O(logN)의 시간 복잡도 내에 부분합을 구할 수 있다.

     : 이진 트리를 바탕으로 트리를 구성하며, update를 해야 할 부분만 업데이트 할 수 있다.

  • leaf 노드는 배열이 가지는 값을 가지게 된다.
  • 이외의 노드들은 부분합 등 특정 조건을 만족하는 조건으로 이루어진다.

 

  • 트리의 최대 노드 개수
def segmentation_size(node):
    height = int(math.ceil(math.log10(node) / math.log10(2))) + 1  # 세그먼트 트리의 높이 구하기 (log노드수/log2 + 1)
    size = int(math.pow(2, height)) #트리 최대 노드 개수
    return size

→ 트리의 높이를 구한 후 2의 지수승을 통해 최대 노드 개수를 구한다.

  • 초기화
def segmentation_init(start, end, index):
    if start == end:
        tree[index] = num[start-1]  # leaf 노드에 도달했으면 배열 값 삽입
        return tree[index]

    mid = (start + end) // 2  # leaf노드가 아니면 좌,우 가지치기 진행
    tree[index] = segmentation_init(start, mid, index * 2) + segmentation_init(mid + 1, end, index * 2 + 1)
    return tree[index]

→ leaf 노드에 도달하면 (start==end) 배열의 값을 트리에 삽입하고, 아니라면 가지치기를 계속 진행한다.

→ 여기서는 가지치기를 통해 리프노드가 아닌 노드들에 부분합을 저장하므로, 재귀를 이용하여

tree[index] = init(start,mid,index*2) + init(mid+1,end,index*2+1) 

를 통해 tree를 초기화 한다.

 

 

  • 부분합 구하기
def segmentation_target(start, end, index, target_start, target_end):
    if target_start > end or target_end < start:  # target 범위가 겹치지 않을 때
        return 0

    if target_start <= start and end <= target_end:  # tree특정 노드의 범위가 target 범위에 포함될 때
        return tree[index]

    mid = (start + end) // 2  # target의 범위가 일부만 걸쳐있으면 다 찾아내야 하니까 가지치기 진행
    return segmentation_target(start, mid, index * 2, target_start, target_end) 
						+ segmentation_target(mid + 1, end,index * 2 + 1,target_start,target_end)

→ 구하고자 하는 target범위가 노드의 범위와 겹치지 않으면 재귀를 중지한다.

→ 구하고자 하는 target범위가 노드의 범위를 완전히 포함하면 해당 값을 리턴한다. (재귀여서 최종적으로 더하게 됨)

→ 구하고자 하는 tartet범위가 노드의 범위를 일부 걸치면 가지 치기를 통해 완전히 포함하는 범위를 찾아낼 때까지 재귀를 시행한다.

 

 

  • 트리 업데이트
def segmentation_update(start, end, index, updateIdx, diff):
    if updateIdx < start or end < updateIdx:  # 범위에서 벗어나면 할 필요 없음
        return
    #print("update : (", start, end, ")", index)
    tree[index] += diff  # 범위에 해당하는 거니까 업데이트

    if start == end: #가지 끝까지 갔으면 가지치기X
        return

    mid = (start + end) // 2  # 가지치기 해서 updateIdx에 해당하는 부분의 노드 업데이트 진행
    segmentation_update(start, mid, index * 2, updateIdx, diff)
    segmentation_update(mid + 1, end, index * 2 + 1, updateIdx, diff)
    return

→ udpate하려는 index가 재귀되는 노드의 start,end에 포함되는지를 확인한 후, (포함안되면 할 필요 없으므로 바로 return)→ 리프 노드를 방문하게 되면 return한다.

→ 리프 노드가 아니라면 더 가지치기를 해야하므로 양쪽으로 update를 재귀수행 한다.

→ 포함되는 모든 노드들에 diff를 더하여 업데이트한다.

 

 

세그먼트 트리(Segment Tree) 알고리즘

이게 뭐야? 트리 종류 중에 하나이며, 연속된 구간(특정 범위)의 합(최솟값, 최댓값, 곱 등)을 구하는데 많이 쓰인다. 아래에서 선형구현과 비교하며 왜 쓰는지, 어떻게 사용하는지 gif를 준비해

imksh.com

 

풀이 코드

import sys
import math

def input():
    return sys.stdin.readline().rstrip()

# 세그멘테이션 트리 사용
# 부분합을 구해야 하고, 특정 부분에 대한 수의 변경이 빈번하게 일어날 때, segmentation tree를 사용한다.

def segmentation_size(node):
    height = int(math.ceil(math.log10(node) / math.log10(2))) + 1  # 세그먼트 트리의 높이 구하기 (log노드수/log2 + 1)
    size = int(math.pow(2, height)) #트리 최대 노드 개수
    return size

def segmentation_init(start, end, index):
    if start == end:
        tree[index] = num[start-1]  # leaf 노드에 도달했으면 배열 값 삽입
        return tree[index]

    mid = (start + end) // 2  # leaf노드가 아니면 좌,우 가지치기 진행
    tree[index] = segmentation_init(start, mid, index * 2) + segmentation_init(mid + 1, end, index * 2 + 1)
    return tree[index]

def segmentation_target(start, end, index, target_start, target_end):
    if target_start > end or target_end < start:  # target 범위가 겹치지 않을 때
        return 0

    if target_start <= start and end <= target_end:  # tree특정 노드의 범위가 target 범위에 포함될 때
        return tree[index]

    mid = (start + end) // 2  # target의 범위가 일부만 걸쳐있으면 다 찾아내야 하니까 가지치기 진행
    return segmentation_target(start, mid, index * 2, target_start, target_end) + segmentation_target(mid + 1, end,
                                                                                                      index * 2 + 1,                                                                                             target_e
def segmentation_update(start, end, index, updateIdx, diff):
    if updateIdx < start or end < updateIdx:  # 범위에서 벗어나면 할 필요 없음
        return

    tree[index] += diff  # 범위에 해당하는 거니까 업데이트

    if start == end: #가지 끝까지 갔으면 가지치기X
        return

    mid = (start + end) // 2  # 가지치기 해서 updateIdx에 해당하는 부분의 노드 업데이트 진행
    segmentation_update(start, mid, index * 2, updateIdx, diff)
    segmentation_update(mid + 1, end, index * 2 + 1, updateIdx, diff)
    return


# 입력
N, M, K = map(int, input().split(" "))
num = [int(input()) for _ in range(N)]
command = [list(map(int, input().split(" "))) for _ in range(M + K)]

#tree 초기화
tree = [0] * segmentation_size(len(num))
segmentation_init(1, len(num), 1)  # start,end,index 

for co in command:
    if co[0] == 1:  # co[1]자리를 co[2]으로 변경(업데이트)
        diff = co[2] - num[co[1]-1]
        num[co[1]-1] = co[2] #num도 업데이트 해줘야 함
        segmentation_update(1,len(num),1,co[1],diff)

    elif co[0] == 2:  # co[1]부터 co[2]까지 더해서 출력(부분합)
        print(segmentation_target(1,len(num),1,co[1],co[2]))

 

 

느낀점

"211221"

소요 시간 : 2D 이상

  • 세그먼트트리라는 새로운 자료구조를 접할 수 있는 문제였다.
  • 앞으로 세그먼트트리 사용하는 문제들 좀 풀어봐야겠다

'Algorithm' 카테고리의 다른 글

[BOJ] 2014 소수의 곱 - Java  (0) 2022.02.04
[BOJ] 2904 수학은 너무 쉬워 - Java  (0) 2022.02.04
[BOJ] 2504 괄호의 값 - python  (0) 2021.12.15
[BOJ] 6416 트리인가? - python  (0) 2021.12.07
[BOJ] 1991 트리 순회 - python  (0) 2021.12.07
Comments