본문 바로가기

프로그래밍

segment tree (세그먼트 트리)

세그먼트 트리 (Segment Tree)

세그먼트 트리란 무엇일까요? 다음의 문제를 통해 알아봅시다.

 

문제

arr[0 ... n-1] 인 배열이 주어졌을때,

  1. $ 0 \leq l \leq r \leq n-1$ 인 $l$ 부터 $r$ 까지 배열의 원소들 합을 구하기.
  2. 특정 배열의 원소를 새로운 값 $x$ 로 바꾸기. 다시 말하면 $ 0 \leq i \leq n-1$ 인 상황에서 arr[i] = x 을 해야 합니다.

정답

가장 간단한 솔루션은 $l$ 부터 $r$ 까지 loop 를 만들어서 주어진 범위 내의 배열 원소들의 합을 계산하는 것입니다. 배열의 값을 업데이트 하기 위해서는, arr[i] = x 를 하면 됩니다. 주어진 범위의 합을 계산하는 과정은 $O(n)$ 의 시간복잡도가 걸리고, 값을 업데이트 하기 위해서는 $O(1)$ 의 시간복잡도가 걸립니다.

다른 솔루션은 또다른 배열을 만들고 그 배열의 $i$ 번째 위치에 처음 (index $0$) 부터 $i$ 까지 의 기존 배열 원소들의 합을 저장하는것 입니다. 저장을 마친뒤 범위가 주어지면 그 범위 내의 합은 $O(1)$ 의 시간 복잡도 만에 계산되겠죠? 그냥 바로 꺼내 쓰기만 하면 되니까요. 하지만 배열의 값을 업데이트 한다면, 또다른 배열에 저장된 원소들의 합의 수정이 필요하므로 $O(1)$ 이 아닌 $O(n)$ 의 시간복잡도가 필요하게 됩니다. 이런 방식은 숫자들의 계산이 많고 적은 업데이트가 필요한 경우에만 적합한 솔루션이 되겠네요.

만약 배열의 범위의 합을 계산하는 방법과 값들을 업데이트 하는 방법들이 시간적으로 동일하다면 어떨까요? 배열이 주어지면 업데이트와 범위들의 합 모두 $O(log (n))$ 의 시간복잡도에 해낼 수 있을까요? 세그먼트 트리를 이용하면 두 연산(업데이트 및 합) 모두 $O(log(n))$ 만에 해낼 수 있습니다.

세그먼트 트리의 표현

  1. 잎사귀 노드들은 입력한 배열의 원소들 입니다.
  2. 각 내부 노드들은 잎사귀 노드들의 합병을 의미합니다. 합병은 주어진 문제마다 합병하는 방법이 달라질 수 있습니다. 이 문제에서는 합병이라는 것은 내부 노드 자식들인 잎사귀 노드들의 합 입니다.

세그먼트 트리를 표현하기 위해서 주로 배열을 사용합니다. 인덱스 $i$ 를 기준으로 했을때, 각 노드의 왼쪽 자식들의 인덱스는 $2 * i +1 $ 이고, 오른쪽 자식들의 인덱스는 $2 *i +2 $ 이고, 부모 노드는 $(i-1)/2$ 에 있습니다.

주어진 배열을 이용해 세그먼트 트리 만들기

$n$ 의 크기의 배열이 주어진다면, 배열 arr[0 . . . n-1] 부터 시작하겠습니다. 영어 segment(세그먼트)는 '부분' 이라는 뜻으로 현재 배열의 전체 크기 $0$ 부터 $n-1$ 까지의 길이 역시 세그먼트로 볼 수 있겠습니다.
세그먼트를 만드는 함수를 호출할때 매번 배열이 길이가 $1$ 이 될때까지 함수 호출로 넘겨진 배열을 절반으로 쪼개게 됩니다. (이 부분에서 segment tree 의 유래가 온듯 싶습니다.) 그리고 쪼개어진 배열들은 재귀적으로 다시 세그먼트를 만드는 함수에 양쪽 모두 넘겨지게되고, 각각의 세그먼트 배열은 배열들의 합을 구한뒤 적절한 노드에 값을 저장하게 됩니다.

주어진 배열이 [1, 3, 5, 7, 9 ,11] 이라 가정하고 세그먼트 트리를 만드는 과정을 천천히 살펴보겠습니다.
다음의 규칙을 따릅니다.

  1. 각 노드는 아래에 배열의 범위를 표시합니다.
  2. 루트(root) 노드는 배열의 전체 범위인 [0:n-1] 을 나타냅니다.
  3. 각 노드는 그 노드 아래의 모든 잎사귀 노드들의 합을 나타냅니다.

시작은 n = 6 으로부터 범위가 [0:5] 로 정해집니다.

범위가 절반으로 나눠집니다. 이 때 기준은 n // 2 가 됩니다.

같은 과정을 반복하던중 lr 의 범위가 똑같은 노드를 만들어 냈습니다.
이것은 잎사귀 노드로 더 이상 분리될 수 없는 노드입니다.

그래서 여기에는 주어진 배열에 해당하는 index를 참고해서 값을 넣었습니다. (arr = [1, 3, 5, 7, 9 ,11])

나머지도 같은 방법을 반복합니다. (재귀로 호출)

이제 전부 잎사귀 노드의 값이 채워졌다면, 그 말은 이 노드들의 값을 부모 노드로 반환할 차례라는 것입니다.
반환된 잎사귀 노드의 값들을 합쳐서 해당 부모 노드의 값을 채우기 시작합니다. 함수로 따지면, 재귀 호출의 레벨 정도가 점점 낮아지는것을 뜻합니다.

1과 3을 더하면 4고, 7과 9를 더하면 16이 되죠?

같은 방식으로 모두 채워줍니다. 세그먼트 트리가 완성됬습니다.

다음은 세그먼트 트리의 관련된 여러 함수중 세그먼트 트리를 만드는 함수의 코드입니다.

// 배열 (array[ss..se]) 를 세그먼트 트리로 만드는 재귀함수.
// si 는 세그먼트 트리 st 의 현재 index 를 말합니다.
int constructSTUtil(int arr[], int ss, int se, int *st, int si)
{
    // 만약 주어진 범위 내에서 배열이 하나의 원소밖에 가지고 있지 않다면, 
    // 그것을 현재 세그먼트 트리의 현재 노드에 저장하고 반환한다.
    if (ss == se)
    {
        st[si] = arr[ss];
        return arr[ss];
    }
 
    // 만약 주어진 범위 내에서 배열이 하나 이상의 원소를 가지고 있다면, 
    // 왼쪽과 오른쪽 서브 트리들을 나누어 재귀적으로 호출하고 
    // 후에 반환된 값들의 합을 현재 노드에 저장한다.
    int mid = getMid(ss, se);
    st[si] =  constructSTUtil(arr, ss, mid, st, si*2+1) +
              constructSTUtil(arr, mid+1, se, st, si*2+2);
    return st[si];
}

getMid(int s, int e) 함수는 주어진 두 index의 가운데 index를 반환하는 함수입니다.
이 함수의 호출은 다음과 같이 이루어집니다. constructSTUtil(arr, 0, n-1, st, 0);
constructSTUtil 함수의 마지막 매개변수 si 는 root node 부터 차례로 왼쪽부터 0, 1, 2, 3 ... 으로 이루어 집니다. 그림으로 표현하면 다음과 같습니다.

완성된 세그먼트 트리의 모든 레벨은 마지막 레벨 (잎사귀 노드) 부분을 제외하고 완벽히 채워지게 됩니다. 다시 말씀드리면, 세그먼트 트리는 완전 이진 트리 가 되는것이죠. 왜냐하면 세그먼트 트리를 만드는 함수를 호출할때 마다 배열을 절반으로 나누기 때문입니다. 완성된 트리는 언제나 $n$ 개의 잎사귀 노드를 가지는 완전 이진 트리이기 때문에 (주어진 배열의 원소의 갯수가 $n$개), 내부 노드들은 $n-1$ 개가 될것입니다. 그렇기 때문에 총 노드들의 개수는 $n + (n - 1 ) = 2n- 1$개 입니다.

세그먼트 트리의 높이(level)는 $\left\lceil log_2(n) \right\rceil$ 입니다. 배열을 이용해 표현된 트리, 그리고 부모 및 자식간의 노드 관계는 완전 이진 트리 형태로 반드시 유지되어야 하므로, 세그먼트 트리를 만들기 위해서 할당되어질 메모리 크기는 $2 *2^{(\left\lceil log_2(n) \right\rceil)} -1$ 이 됩니다. (메모리 크기 = 완전 이진 트리의 최대 노드 갯수)

주어진 범위의 합

트리가 완성되었으니 이 완성된 세그먼트 트리를 활용해 어떻게 주어진 범위의 합을 구할지 알아보겠습니다.
아래와 같은 알고리즘을 사용합니다.

int getSum(node, l, r) 
{
   if 노드의 범위가 l 과 r 사이에 있다면
        return 값 in 노드
   else if 노드의 범위가 '완벽하게' l 과 r 사이에 없다면
        return 0
   else 노드의 범위가 '부분적으로' l 과 r 사이에 없다면
    return getSum(node's left child, l, r) + 
           getSum(node's right child, l, r)

그림으로 한번 알아볼까요?
다음과 같이 3개의 케이스를 나눠보겠습니다.

  1. 노드의 범위가 lr 사이에 있는 경우
  2. 노드의 범위가 완벽하게 lr 사이에 없는 경우
  3. 노드의 범위가 부분적으로 lr 사이에 없는 경우

범위는 l = 0 , r = 4[0:4] 의 합을 구하는것으로 하겠습니다.

첫번째 루트 노드는 [0:5] 입니다. [0:4] 까지는 필요하지만 [5:5] 는 필요 없습니다. 그래서 부분적으로 겹칩니다. 3번째 경우네요. 정해진 알고리즘에 따라 왼쪽과 오른쪽의 자식 노드의 합을 구하기로 하겠습니다.

이제 왼쪽의 자식노드는 1번째 경우네요. [0:2][0:4] 에 완벽하게 들어맞습니다.
오른쪽 자식노드의 경우는 부분적으로 들어맞네요. 3번째 경우입니다.
왼쪽 노드는 9 를 반환하고, 오른쪽 노드는 다시 왼쪽과 오른쪽 자식노드의 값을 반환 하기로 합니다.

[3:4] 범위의 노드는 첫번째 경우로 16 을 반환하게 됩니다.
[5:5] 범위의 노드는 완벽하게 들어맞지 않으므로 두번째 경우로 0 을 반환하게 됩니다.

[0:2]범위의 노드는 아까 언급했다시피 9를 반환했고, [3:5] 범위의 노드는 160의 값을 합쳐 총 16을 다시 반환하게 됩니다. 결과적으로 총 합이 25가 세그먼트 트리에서 [0:4] 범위의 합이 되는것입니다.

값을 업데이트 하기

지금까지 주어진 배열이 [1, 3, 5, 7, 9 ,11] 라고 가정하면서 트리를 만들어 왔습니다. 하지만 arr[2] 부분을 5 가 아니라 10 으로 바꾸면 어떨까요? 배열의 값이 바뀌었기 때문에 세그먼트 트리의 잎사귀 노드가 우선적으로 바뀔것이고, 그에 따른 부모 노드들은 주어진 범위의 합을 나타내기 때문에 부모 노드들의 값 역시 바뀔것입니다. 우리는 이 트리를 업데이트할 필요가 있습니다.

트리를 만드는것이나 주어진 범위의 합을 구하는것처럼, 값을 업데이트 하는것 역시 재귀적으로 가능합니다. 일단 우리에게 주어진것은 바뀔 배열의 index와 그 위치에 넣을 diff 라고 하는 값, 이 두가지 입니다. diff 는 바뀔 배열의 index 에 들어있는 현재의 값에서 바꿀 값의 차이를 말합니다. 지금은 5 에서 10으로 바꾸려 한것이니까 diff 값은 10 - 5 = 5 , 즉, diff = 5가 되겠습니다.

트리의 루트 노드로 부터 시작해서, 주어진 index 가 해당 노드 범위 내에 포함된다면 diff 라는 값을 더해줍니다. 이 작업은 세그먼트 트리에 있는 모든 노드에 빠짐없이 적용됩니다. 만약 해당 노드의 범위에 주어진 index 가 포함되지 않는다면, 어떠한 작업도 하지 않습니다.

이 역시 그림으로 설명해드리겠습니다.

arr = [1, 3, 5, 7, 9 ,11]를 기준으로 하여 만들었던 세그먼트 트리에 arr[2] = 10 이라는 작업을 하였습니다. diff = 5 입니다. 루트 노드부터 시작하면, [0:5] 이므로 index 2 가 그 사이에 포함됩니다.

36+5 = 41 이므로 루트 노드의 값은 41로 업데이트 됩니다. 이제, 다시 index 2 를 포함하는 노드를 찾아서 똑같이 반복해 줍니다.

[0:2] 범위의 노드와 [0:1] 범위를 가진 노드의 값을 바꿔줍니다.

이행(Implementation)된 세그먼트 트리 코드

아래의 코드는 위에서 언급했던 총 3가지의 작업을 수행할 수 있습니다. 그 세가지는 생성, 합 구하기, 업데이트 입니다. 위에서 보여드렸던 일부 코드도 아래에 포함되어 있습니다.

// 세그먼트 트리의 생성, 쿼리, 업데이트를 위한 C 프로그램 
#include <stdio.h>
#include <math.h>
 
// 주어진 두 인덱스의 가운데를 구하는 유틸리티 함수
int getMid(int s, int e) {  return s + (e -s)/2;  }
 
/*  주어진 범위에 해당하는 배열으 총 합을 구하는 재귀함수
    이 함수의 매개변수는 다음과 같습니다.
 
    st    --> 세그먼트 트리의 포인터
    si    --> 세그먼트 트리에서 현재 노드의 인덱스. 
    		  루트 노드의 인덱스는 무조건 0 이기 때문에 시작을 0 으로 맞춰준다.
    ss & se  --> 현재 노드에 적용되는 시작과 끝의 인덱스 (st[si])
    qs & qe  --> 쿼리 범위의 시작과 끝 (합 구하기) */
int getSumUtil(int *st, int ss, int se, int qs, int qe, int si)
{
    // 첫번째 케아스 : 쿼리 범위가 현재 노드가 가진 범위 안에 있다면 
    // 현재 노드의 값을 반환한다. 
    if (qs <= ss && qe >= se)
        return st[si];
 
    // 두번째 케이스 : 쿼리 범위가 현재 노드가 가진 범위의 완벽히 밖이라면
    // 0을 반환한다. 
    if (se < qs || ss > qe)
        return 0;
 
    // 세번째 케이스 : 쿼리 범위가 부분적으로 현재 노드가 가진 범위의 내에 있다면
    // 현재 노드의 자식 노드들이 가진 값의 합을 반환한다.
    int mid = getMid(ss, se);
    return getSumUtil(st, ss, mid, qs, qe, 2*si+1) +
           getSumUtil(st, mid+1, se, qs, qe, 2*si+2);
}
 
/* 범위내에 주어진 index 에 해당하는 트리를 업데이트 하는 재귀함수.
   주어진 매개면수 st, si, ss 그리고 se 는 getSumUtil() 와 같습니다.
    i    --> 업데이트할 원소의 index. 이 index는 입력된 배열의 index 입니다. 
   diff --> i 범위 내에 있는 모든 노드에 추가될 값 */
void updateValueUtil(int *st, int ss, int se, int i, int diff, int si)
{
    // Base Case: 만약 주어진 index 가 현재 노드가 가진의 범위에 존재하지 않다면 무시합니다. 
    if (i < ss || i > se)
        return;
 
    // 만약 주어진 index 가 현재 노드가 가진 범위 내에 존재한다면,
    // 노드가 가진 값과 그 노드의 자식 노드들의 값을 업데이트 합니다. 
    st[si] = st[si] + diff;
    if (se != ss)
    {
        int mid = getMid(ss, se);
        updateValueUtil(st, ss, mid, i, diff, 2*si + 1);
        updateValueUtil(st, mid+1, se, i, diff, 2*si + 2);
    }
}
 
// 세그먼트 트리와 입력된 배열을 업데이트할 함수. 
// 이 함수는 세그먼트 트리를 업데이트 하기 위해 updateValueUtil() 함수를 사용합니다. 
void updateValue(int arr[], int *st, int n, int i, int new_val)
{
    // 주어진 index 의 오류를 체크합니다.
    if (i < 0 || i > n-1)
    {
        printf("Invalid Input");
        return;
    }
 
    // 현재 값과 기존의 값의 차이 (diff) 를 계산하기
    int diff = new_val - arr[i];
 
    // 입력된 배열의 값을 업데이트하기
    arr[i] = new_val;
 
    // 세그먼트 트리에 있는 노드들의 값들 업데이트 하기 
    updateValueUtil(st, 0, n-1, i, diff, 0);
}
 
// index qs(query start) 부터 qe (query end) 까지의
// 범위의 모든 값들을 더해 반환하는 함수. 이 함수는 getSumUtil() 함수의 사용이 메인입니다.
int getSum(int *st, int n, int qs, int qe)
{
    // 입력된 값의 오류를 체크합니다.
    if (qs < 0 || qe > n-1 || qs > qe)
    {
        printf("Invalid Input");
        return -1;
    }
 
    return getSumUtil(st, 0, n-1, qs, qe, 0);
}
 
// 배열 (array[ss..se]) 를 세그먼트 트리로 만드는 재귀함수.
// si 는 세그먼트 트리 st 의 현재 index 를 말합니다.
int constructSTUtil(int arr[], int ss, int se, int *st, int si)
{
    // 만약 주어진 범위 내에서 배열이 하나의 원소밖에 가지고 있지 않다면, 
    // 그것을 현재 세그먼트 트리의 현재 노드에 저장하고 반환한다.
    if (ss == se)
    {
        st[si] = arr[ss];
        return arr[ss];
    }
 
    // 만약 주어진 범위 내에서 배열이 하나 이상의 원소를 가지고 있다면, 
    // 왼쪽과 오른쪽 서브 트리들을 나누어 재귀적으로 호출하고 
    // 후에 반환된 값들의 합을 현재 노드에 저장한다.
    int mid = getMid(ss, se);
    st[si] =  constructSTUtil(arr, ss, mid, st, si*2+1) +
              constructSTUtil(arr, mid+1, se, st, si*2+2);
    return st[si];
}
 
/* 주어진 배열을 통해 세그먼트 트리를 만드는 함수.
   이 함수는 세그먼트 트리를 위해 메모리를 할당하고 
   할당된 메모리를 채우기 위해 constructSTUtil() 함수를 호출합니다.
 */
int *constructST(int arr[], int n)
{
    // 세그먼트 트리의 메모리를 할당합니다.
 
    //세그먼트 트리의 높이 x 를 찾습니다.
    int x = (int)(ceil(log2(n))); 
 
    //x 를 이용해 세그먼트 트리의 최대 사이즈를 구합니다. 
    int max_size = 2*(int)pow(2, x) - 1; 
 
    // 메모리를 할당합니다. 
    int *st = new int[max_size];
 
    // 할당된 메모리 st 를 채웁니다.
    constructSTUtil(arr, 0, n-1, st, 0);
 
    // 완성된 세그먼트 트리를 반환합니다.
    return st;
}
 
// 테스트를 위한 드라이버 프로그램
int main()
{
    int arr[] = {1, 3, 5, 7, 9, 11};
    int n = sizeof(arr)/sizeof(arr[0]);
 
    // 주어진 배열을 통해 세그먼트 트리를 만듭니다.
    int *st = constructST(arr, n);
 
    // 배열의 index 1 부터 3 까지 원소들의 합을 출력합니다.
    printf("Sum of values in given range = %dn", 
            getSum(st, n, 1, 3));
 
    // 업데이트: arr[1] = 10 으로 설정하고 
    // 이에 따른 세그먼트 트리의 값들을 업데이트 합니다. 
    updateValue(arr, st, n, 1, 10);
 
    // 값이 업데이트 된후 다시 index 1 부터 3까지의 합을 구합니다. 
    printf("Updated sum of values in given range = %dn",
             getSum(st, n, 1, 3));
    return 0;
}

출력값

Sum of values in given range = 15
Updated sum of values in given range = 22

시간 복잡도

트리 생성에 필요한 시간복잡도는 $O(n)$ 입니다. 총 $2n-1$ 개의 노드들이 있고 트리 생성 과정에서 각 노드들의 값은 오직 한번만 계산됩니다.

쿼리(합 구하기) 작업의 시간복잡도는 $O(log(n))$ 입니다. 업데이트를 위한 시간복잡도 역시 $O(log(n))$ 입니다.
포스팅 처음에 언급했던 방법의 시간을 생각하면 훨씬 절약된 시간입니다.


문제를 풀어야 하는데 자꾸 개념을 보고있네요..ㅎㅎ

그래도 개념을 모르면 문제를 못푸니 열심히 글을 쓰겠습니다.

이 글은 여기를 기반으로 하여 작성되었습니다.

'프로그래밍' 카테고리의 다른 글

포함 배제의 원리 (Inclusion-Exclusion)  (0) 2018.02.20
[DP] 도시의 집  (0) 2018.02.20
행렬 멱법을 이용한 피보나치 값 구하기  (0) 2018.02.18
Matrix Exponentiation (행렬 멱법)  (1) 2018.02.17
[DP] 게임 이기기  (0) 2018.02.17