본문 바로가기
TIL - 프로그래밍/Python 알고리즘

[백준] 2042. 구간 합 구하기 - Python (세그먼트 트리 개념 설명)

by chaemj97 2023. 7. 4.
728x90

여러 개의 데이터가 연속적으로 존재할 때 특정한 범위의 데이터의 합 구하기

배열에서 특정한 범위의 데이터 합을 가장 빠르게 구하는 방법은 무엇인가?
data = [1,2,3,4,5]

방법 1. 단순 배열을 이용해 선형적으로 구하기

인덱스 i부터 j까지 데이터 더하기

print(sum(data[i:j+1]))

앞에서 하나씩 더해가므로 데이터의 개수가 n이면 시간 복잡도 O(N), n이 매우 커지면 구간의 합을 구하는 속도가 너무 느리기 때문에 더 좋은 알고리즘이 필요하다.


방법2. 트리 구조 이용하여 구하기

세그먼트 트리

배열의 특정 구간에 대한 정보를 추가로 담고 있다. 

트리 구조의 특성상 합을 구할 때 시간 복잡도 O(logN)

 

1️⃣ 구간 합 트리 생성하기

가장 최상단의 노드에는 전체 원소의 합이 들어간다.

이후 두번째, 세번째 노드를 구한다. 두번째 노드는 인덱스 0~2번 구간의 합, 세번째 노드는 인덱스 3~4번 구간의 합이 들어간다. 즉, 부모 노드의 데이터의 범위를 반씩 나누어 그 구간의 합들을 저장한다. 

빨간색 : 세그먼트 트리 인덱스

초록색 : 배열의 구간

원 안 숫자 : 특정 구간의 배열의 합

 

자식 노드의 인덱스를 쉽게 구하기 위해 인덱스 시작은 1부터 시작한다. 

tree = [0, 15, 6, 9, 3, 3, 4, 5, 1, 2]

트리의 길이는 포화 이진트리의 길이로 만들면 된다. 하지만 주로 데이터의 개수에 4를 곱한 크기만큼 미리 트리 공간을 할당한다.

# n : 데이터의 크기
seg_tree = [0 for _ in range(4*n)]

# 세그먼트 트리 만들기
# seg_tree[x] 값 구하기
def build_tree(x,left,right):
    # 1. 구간에 데이터 1개
    if left == right:
        seg_tree[x] = num[left]
        return seg_tree[x]
        
    # 2. 구간에 데이터 여러개
    # 부모 노드의 구간을 둘로 나눈다.
    mid = (left + right)//2
    # 왼쪽 자식 노드
    left_value = build_tree(2*x,left,mid)
    # 오른쪽 자식 노드
    right_value = build_tree(2*x+1,mid+1,right)
    # 부모 노드는 자식노드들의 합
    seg_tree[x] = left_value + right_value
    return seg_tree[x]

 

2️⃣ 구간 합 트리 생성하기

인덱스가 1~4인 구간의 데이터 합 구하기 위해선 아래 색칠된 세 노드의 합만 구하면 된다.

구하고자 하는 답은 2+3+9 = 14 가 된다.구간의 합은 '범위 안에 있는 경우'에 한해서만 더해주면 된다.

# 세그먼트 트리로 구간 합 구하기
# b~c구간합 구하기
# 트리의 구간 left~right
# 현재 노드 x
def find_tree(b,c,x,left,right):
    # 1. 구하고 싶은 구간(b~c)이 현재 트리 구간에 포함 X
    if c < left or right < b:
        return 0
        
    # 2. 구하고 싶은 구간(b~c) 안에 현재 트리 포함
    if b <= left and right <=c:
        return seg_tree[x]
        
    # 3. 구간이 겹치는 경우
    mid = (left + right)//2
    left_value = find_tree(b,c,x*2,left,mid)
    right_value = find_tree(b,c,x*2+1,mid+1,right)
    return left_value + right_value

 

3️⃣ 구간 합 트리 값 바꾸기

특정 인덱스의 값을 수정할 때는 해당 인덱스를 포함하고 있는 모든 구간의 합 노드들을 갱신해줘야 한다. 예를 들어 인덱스 2의 값을 6으로 수정한다고 하면 아래 색칠된 3개의 구간 합 노드를 모두 수정하면 된다.

마찬가지로 수정할 노드로는 '범위 안에 있는 경우'에 한해서만 수정하면 된다.

# 세그먼트 트리 값 업데이트
# 인덱스 idx의 값을 val로 바꾸기
def update_tree(x,left,right,idx,val):
    # 구간에 데이터 1개, 그 데이터가 idx에 해당
    if left == right == idx:
        seg_tree[x] = val
        return
        
    # 현재 구간에 idx가 포함 X
    if idx < left or right < idx:
        return
    
    # 자식 노드에 idx가 포함된다면 부모 노드도 변한다
    mid = (left + right)//2
    # 왼쪽 자식 업데이트
    update_tree(x*2,left,mid,idx,val)
    # 오른쪽 자식 업데이트
    update_tree(x*2+1,mid+1,right,idx,val)
    
    # 업데이트 된 자식 노드를 통해 현재 노드 업데이트
    seg_tree[x] = seg_tree[x*2] + seg_tree[x*2+1]

세그먼트 트리를 이용하면 구간 합을 계산할 때 기존의 방법보다 훨씬 속도가 빨라진다.

 

세그먼트 트리를 이용하면 구간 합을 구하거나 수정할 때
시간 복잡도는 O(logN)이다. 

 

관련 문제 풀기

https://www.acmicpc.net/problem/2042

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

import sys
input = sys.stdin.readline

# 수의 개수, 수 변경 횟수, 구간의 합 횟수
n,m,k = map(int,input().split())
num = [int(input()) for _ in range(n)]

# 세그먼트 트리
# seg_tree[1] : 모든 노드의 합
# seg_tree[2] : 0~n//2번 노드의 합
# seg_tree[3] : n//2+1~n번 노드의 합
seg_tree = [0 for _ in range(4*n)]

# 1. 세그먼트 트리 만들기
# seg_tree[x] 값 구하기
def build_tree(x,left,right):
    if left == right:
        seg_tree[x] = num[left]
        return seg_tree[x]
    mid = (left + right)//2
    left_value = build_tree(2*x,left,mid)
    right_value = build_tree(2*x+1,mid+1,right)
    seg_tree[x] = left_value + right_value
    return seg_tree[x]

build_tree(1,0,n-1)

# 2. 세그먼트 트리로 구간 합 구하기
# b~c구간합 구하기
# 트리의 구간 left~right
# 현재 노드 x
def find_tree(b,c,x,left,right):
    # 구하고 싶은 구간(b~c)가 현재 트리 구간에 포함 X
    if c < left or right < b:
        return 0
    # 구하고 싶은 구간(b~c) 안에 현재 트리 포함
    if b <= left and right <=c:
        return seg_tree[x]
    # 구간이 겹치는 경우
    mid = (left + right)//2
    left_value = find_tree(b,c,x*2,left,mid)
    right_value = find_tree(b,c,x*2+1,mid+1,right)
    return left_value + right_value

# 3. 세그먼트 트리 값 업데이트
# 인덱스 idx의 값을 val로 바꾸기
def update_tree(x,left,right,idx,val):
    # 길이 1인 구간
    if left == right == idx:
        seg_tree[x] = val
        return
    # 현재 구간에 idx가 포함 X
    if idx < left or right < idx:
        return
    
    mid = (left + right)//2
    # 왼쪽 자식 업데이트
    update_tree(x*2,left,mid,idx,val)
    # 오른쪽 자식 업데이트
    update_tree(x*2+1,mid+1,right,idx,val)
    
    # 업데이트 된 자식 노드를 통해 현재 노드 업데이트
    seg_tree[x] = seg_tree[x*2] + seg_tree[x*2+1]
    
for _ in range(m+k):
    a,b,c = map(int,input().split())
    # b번째 수를 c로 바꾸기
    if a == 1:
        update_tree(1,0,n-1,b-1,c)
    # b번째 수부터 c번째 수까지 합 구하기
    else:
        s = find_tree(b-1,c-1,1,0,n-1)
        print(s)
728x90
반응형

댓글