개발하는SM

[트리] 세그먼트 트리 - 부분 합 효율적으로 구하기 본문

Algorithm - 이론

[트리] 세그먼트 트리 - 부분 합 효율적으로 구하기

개발하는SM 2021. 8. 22. 22:41
여러 개의 데이터가 연속적으로 존재할 때,
특정한 범위의 데이터 합을 가장 빠르고 간단하게 구할 수 있는 자료구조

예시 데이터 : A[] = {1,9,3,8,4,5,5,9,10,3,4,5};

 

위와 같이 단순 배열을 사용해 특정 구간의 합을 선형적으로 구할 경우,

(1~10) 범위의 데이터 합을 구하려면 원소 하나씩 접근하여 더해줘야 함. 따라서 O(10).

특정 구간에 포함되는 데이터의 개수가 N개일 경우 O(N) 이 된다.

 

하지만, 세그먼트 트리 구조를 이용해 구한다면, O(logN) 의 시간복잡도로 구할 수 있다.

세그먼트 트리는 아래와 같은 절차대로 활용할 수 있다.

 

1. 구간 합 트리(Segment Tree) 생성

구간 합 트리의 각 Node 에는 각각 위 예시 데이터 배열의 구간 합을 저장한다.

구간 합 트리를 저장할 배열을 tree 라고 했을 때,

tree[1] = (A[0] + A[1] + ..... + A[11] );

tree[2] = (A[0] + A[1] + ... + A[5] );

tree[3] = (A[6] + A[7] + .... + A[11] );

 

위와 같은 형태로 자식 노드들이 부모 노드의 데이터 범위를 반씩 분할하여 그 구간의 합들을 저장하도록 초기화한다.

출처 : https://m.blog.naver.com/ndb796/221282210534

구간 합 트리를 초기화 할 때는 아래와 같이 재귀적으로 구현하면 간단히 구현할 수 있다.

 

주의사항

  • 구간 합 트리의 크기는 원래 배열의 원소 개수(N) * 4 이다
  • 구간 합 트리의 인덱스는 1부터 시작하는 것이 좋다 ( 1로 시작했을 때, 자식노드를 찾아가는 것이 더 효율적이다 
    // 1. 구간 합 트리 (Segment Tree) 초기화
    public static int init(int start, int end, int node){
        // 트리의 최상단에는 모든 구간 원소의 합이 들어감 ( 0 ~ 12 )
        // Segment Tree 의 최상단 index = 1
        if(start == end){
            tree[node] = A[start];
            return tree[node];
        } 

        int mid = (start + end) / 2;

        return tree[node] = init(start, mid, node*2) + init(mid+1, end, node*2+1);
    }

 

2. 구간 합 트리를 활용한 구간 합 구하기

1번에서 만든 구간 합 트리를 활용해서 구간 합을 구하는 과정은 아래와 같다.

구간 합을 구하는 함수의 인자는 아래와 같이 5개 이다.

 

int start : 현재 tree node 가 포함하는 구간의 시작 index

int end : 현재 tree node 가 포함하는 구간의 끝 index

int left : 구하고자 하는 범위의 시작 index

int right : 구하고자 하는 범위의 끝 index

int node : 현재 tree의 어떤 노드를 탐색중인지

 

현재 세그먼트 트리 노드가 포함하는 범위를 X ( A[start] + A[start+1] + .... + A[end] ) 라고 하고, 

내가 지금 구하고자 하는 범위를 Y ( A[left] + A[left +1] + .... + A[right] ) 라고 할 때..

아래와 같이 경우의 수를 나눌 수 있다.

 

1. Y 가 X 전체 범위를 포함하는 경우 - return tree[node]

2. X 가 Y 전체 범위를 포함하는 경우 - 왼쪽 자식과 오른쪽 자식을 루트로 하는 트리에서 재탐색

3. X 와 Y 가 일부만 겹치는 경우 - 왼쪽 자식과 오른쪽 자식을 루트로 하는 트리에서 재탐색

4. X 와 Y 가 전혀 겹치지 않는 경우 - return 0

 

이것도 재귀로 아래와 같이 구현할 수 있다.

    // 2. Segment Tree 를 활용한 구간 합 구하기
    // start / end 는 현재 node 가 포함하는 합 범위
    // left / right 은 구해야 하는 범위
    public static int getSum(int start, int end, int left, int right, int node){

        // 현재 노드의 범위를 완전 벗어난 경우
        if(end < left || right < start) return 0;

        // 구하고자 하는 범위보다 현재 노드가 포함하는 범위가 좁을 경우
        if(left <= start && end <= right) return tree[node];

        int mid = (start + end) / 2;

        return getSum(start, mid, left, right, node*2) + getSum(mid+1, end, left, right, node*2+1);
    }

 

3. 배열의 일부 값을 변경하는 경우

배열 중간에 특정 수를 변경하는 경우,

해당 수를 포함하는 모든 구간 합 트리 노드들을 변경해 주어야 합니다.

변경된 idx 가 현재 노드의 범위 내에 포함 되는 경우 / 포함되지 않는 경우를 나누어 생각해주면 됩니다.

 

    // 3. 배열의 특정 값을 update
    // 수정해야 할 idx 를 포함하고 있는 경우만 update 함
    public static void update(int start, int end, int idx, int diff, int node){
        
        // 현재 노드의 범위에 포함되지 않음
        if(idx < start || end < idx) return;

        // 포함되는 경우
        int mid = (start + end) / 2;
        tree[node]-=diff;
        if(start == end) return;
        update(start, mid, idx, diff, node*2);
        update(mid+1, end, idx, diff, node*2+1);
    }

 

Segment Tree 활용 전체 코드

import java.util.*;

public class SegmentTree {

    // A.length = 12
    public static int A[] = {1,9,3,8,4,5,5,9,10,3,4,5};
    public static int tree[];

    // 1. 구간 합 트리 (Segment Tree) 초기화
    public static int init(int start, int end, int node){
        // 트리의 최상단에는 모든 구간 원소의 합이 들어감 ( 0 ~ 12 )
        // Segment Tree 의 최상단 index = 1
        if(start == end){
            tree[node] = A[start];
            return tree[node];
        } 

        int mid = (start + end) / 2;

        return tree[node] = init(start, mid, node*2) + init(mid+1, end, node*2+1);
    }

    // 2. Segment Tree 를 활용한 구간 합 구하기
    // start / end 는 현재 node 가 포함하는 합 범위
    // left / right 은 구해야 하는 범위
    public static int getSum(int start, int end, int left, int right, int node){

        // 현재 노드의 범위를 완전 벗어난 경우
        if(end < left || right < start) return 0;

        // 구하고자 하는 범위보다 현재 노드가 포함하는 범위가 좁을 경우
        if(left <= start && end <= right) return tree[node];

        int mid = (start + end) / 2;

        return getSum(start, mid, left, right, node*2) + getSum(mid+1, end, left, right, node*2+1);
    }

    // 3. 배열의 특정 값을 update
    // 수정해야 할 idx 를 포함하고 있는 경우만 update 함
    public static void update(int start, int end, int idx, int diff, int node){
        
        // 현재 노드의 범위에 포함되지 않음
        if(idx < start || end < idx) return;

        // 포함되는 경우
        int mid = (start + end) / 2;
        tree[node]-=diff;
        if(start == end) return;
        update(start, mid, idx, diff, node*2);
        update(mid+1, end, idx, diff, node*2+1);
    }

    public static void main(String[] args) throws Exception {
        
        // 배열을 활용하여 특정 구간의 합을 가장 빠르게 구하는 방법
        // Tree 형태를 활용한 구현
        tree = new int[A.length * 4];
        init(0, A.length -1, 1);    

        //printTree();

        printA();

        // 6 번째 원소를 100 으로 update
        // 1. update 함수 이용
        int idx = 6;
        int diff = A[6] - 100;
        A[6] = 100;
        update(0, A.length-1, idx, diff, 1);

        // printA();

        // System.out.println(getSum(0,A.length-1, 6,6,1));

        // 2. A 배열을 직접 업데이트하고 다시 init
        // A[6] = 100;
        // init(0,A.length-1, 1);
        // printA();

        System.out.println(getSum(0, A.length-1, 6,6,1));
        
    }


    public static void printTree(){
        for(int a : tree){
            System.out.print(a + " ");
        }
    }

    public static void printA(){
        System.out.print("A : ");
        for(int a : A){
            System.out.print(a + " ");
        }
        System.out.println();
    }
}

 

참고 사이트

https://m.blog.naver.com/ndb796/221282210534

 

41. 세그먼트 트리(Segment Tree)

이번 시간에 다룰 내용은 여러 개의 데이터가 연속적으로 존재할 때 특정한 범위의 데이터의 합을 구하는 ...

blog.naver.com

https://www.acmicpc.net/blog/view/9

 

세그먼트 트리 (Segment Tree)

문제 배열 A가 있고, 여기서 다음과 같은 두 연산을 수행해야하는 문제를 생각해봅시다. 구간 l, r (l ≤ r)이 주어졌을 때, A[l] + A[l+1] + ... + A[r-1] + A[r]을 구해서 출력하기 i번째 수를 v로 바꾸기. A[i

www.acmicpc.net

관련 문제

https://www.acmicpc.net/problem/2042