본문 바로가기

프로그래밍

[DP] 도시의 집

DP 연습 문제 : 도시의 집

문제

한 나라에 $n $개의 도시가 있다. 각 도시들은 $1, 2, 3, \dotsc n$ 의 번호가 주어진다. 또한 각 도시들은 $a[i]$개의 집이 있다. $q$ 개의 쿼리가 주어진다면, 도시 $l$ 과 도시 $r$ 사이에 얼마나 많은 집이 있는지 알려주어야 한다.

입력값 :

  • 테스트 케이스 ,$t$
  • 도시의 개수 , $n$
  • 각 도시의 집의 개수 $a[i]$
  • 쿼리의 개수, $q$
  • 도시 $l$ 과 $r$

제약:

  • $1<=t<=10$
  • $1<=n<=100000$
  • $1<=a[i]<=10^9$
  • $1<=q<=100000$
  • $1<=l<=r<=n$

예시

입력값

1
5
4 3 2 1 5
3
1 2
2 4
1 4

출력값

7
6
10

설명

쿼리(q) 가 총 3개이고 ,첫번째 쿼리는 $l = 1, r = 2$ 이므로 첫번째 도시에 4개의 집 , 두번째 도시의 3개의 집이 있으므로, $3+4 = 7$. 두번째 쿼리는 $l = 2, r= 4$ 이므로 두번째 도시에 3개 , 세번째 도시에 2개 , 네번째 도시에 1개의 집이 있으므로, $3 +2 + 1 = 6$. 세번째 쿼리는 $l=1, r=4$ 이므로 $4+3+2+1=10$.

주의 : 정말 큰 숫자가 입력값으로 오기때문에 빠른 입출력이 필요합니다.
문제를 먼저 여기서 풀어보세요.

정답

처음 문제를 봤을때는 ''뭐야 이문제..'' 하면서 DP 고 뭐고 loop 를 사용해서 구했습니다. 쿼리를 입력받으면 for loop 를 사용해서 l 부터 r 까지 돌렸죠. 역시나 이렇게 되면 매번 쿼리를 만날때마다 l-r+1번의 계산이 필요해서 시간복잡도가 $O(n)$ 만큼 필요하죠. 결국 TLE(Time Limit Exceeded) 이 났습니다.

그래서 어떻게 할까 하다가, 조금더 빠른 세그먼트 트리를 이용한 방법을 생각해 냈습니다. 이 계산 방법은 $O(log(n))$ 의 시간복잡도를 가집니다. 세그먼트 트리에 대한 자세한 내용은 여기를 참조해주세요. 아래 코드는 세그먼트 트리를 이용한 python 코드의 일부분 입니다.

T = int(sys.stdin.readline())

for _ in range(T):
    numCity = int(sys.stdin.readline())
    house = [0] * (numCity + 1)
    ans = []

    house = list(map(int, sys.stdin.readline().split(" ")))
    treeSize = 2 * pow(2,math.ceil(math.log(numCity, 2))) - 1
    st = [0] * treeSize
    
    #세그먼트 트리를 만듭니다.
    construction(house, st, 0, numCity - 1 , 0)

    query = int(sys.stdin.readline())
    for __ in range(query):
        start, end = list(map(int, sys.stdin.readline().split(" ")))
        #세그먼트 트리를 이용해서 주어진 쿼리에 적힌 범위의 합을 구합니다.
        ans.append(str(getSum(st, start-1, end-1, 0, numCity-1, 0)))
    sys.stdout.write("\n".join(ans))

세그먼트 트리를 만드는 방법이나 주어진 쿼리에 적힌 범위의 합을 구하는 방법은 앞서 말씀드린 포스팅에 적힌 내용과 다른점이 없습니다.

하지만 이것도 TLE 를 받았습니다. 더 빠른 계산을 원하는것 같습니다.$O(log(n))$ 도 느리다고 생각하는것 같네요. 이 코드는 이제 쓸모없어졌지만 그래도 주목할만한 점은 있습니다.
바로 입력값을 읽기 위한 sys.stdin.readline() 함수와 출력값을 적기 위한 sys.stdout.write() 함수 입니다.
python 3 에서 일반적으로 input() 을 이용해 한줄씩 입력을 받는것이 좀더 편하지만, 위 같이 적게되면 훨씬 크고 많은 값들을 빠르게 입력받는것이 가능합니다.

그러면 이제부터 시간복잡도가 $O(1)$ 인 방법을 찾아보겠습니다. 사실 세그먼트 트리를 이용한 방법보다 훨씬 간단합니다. 말 그대로 DP 를 이용하는 방법인데요. 주어진 도시 갯수 $n$ 보다 $1$ 개 더 많은 크기의 DP 배열을 만들고, 그곳에다가 범위의 합을 저장하는것입니다. 예를 들면, $i$ 번째 도시가 가진 집의 갯수를 $a[i]$ 라 한다면, $DP[i]$ 는 $1$ 번째 도시부터 $i$ 번째까지 도시의 집 갯수를 전부 합한 값을 저장하게 되는것입니다.

조금 더 구체적으로 예를 들어보겠습니다. 앞서 문제에 적힌 예제를 통해서 말이죠. 총 도시의 개수는 $n = 5$ 이고, 각 도시가 가진 집의 개수는 4 3 2 1 5 으로 입력이 주어졌으므로 a = [4, 3, 2, 1, 5] 로 설정해보겠습니다. 그렇다면, base case로 DP[0]= 0 이라 설정해놓고, $i \geq1$ 인 경우DP[i] = a[i] + DP[i-1] 로 계산을 하게됩니다. 즉, $i-1$ 번째 도시까지 도시들이 가진 모든 집을 더한 값에 현재 $i$ 번째 도시가 가진 집을 더하면 $DP[i]$ 가 완성 되면서, 이 값은 $i$ 번째 도시까지 도시들이 가진 모든 집을 더한 값이 되는것입니다. 완성된 DP 배열은 위 예제에 한에서 다음과 같을것입니다. DP = [0, 4, 7, 9, 10, 15]

이제 $DP$ 가 완성되었으니 어떻게 하면 한번에 합을 구할까요? 어렵게 생각하실 필요 없습니다. $r$ 번째 도시까지 도시들이 가진 모든 집을 더한 값에 $l-1$ 번째 도시까지 도시들이 가진 모든 집을 더한 값을 빼주면 $l$ 번째 도시부터 $r$ 번째 도시까지 도시들이 가진 모든 집을 더한 값이 나타나게 됩니다. 다음은 python 으로 작성한 시간 복잡도가 $O(1)$ 인 코드 입니다.

T = int(sys.stdin.readline())

# use DP
for _ in range(T):
    numCity = int(sys.stdin.readline())
    house = [0] * (numCity + 1)
    ans = []
    house = list(map(int, sys.stdin.readline().split(" ")))

    dp = [0] * (numCity + 1)
    dp[0] = 0
    for i in range(1, numCity + 1 ) :
       dp[i] = dp[i-1] + house[i-1]

    query = int(sys.stdin.readline())
    for __ in range(query):
        start, end = list(map(int, sys.stdin.readline().split(" ")))
        sys.stdout.write(str(dp[end]-dp[start-1]) +'\n')

이전의 코드와 달리 출력하는 방식을 조금 바꿨습니다. 나머지 코드는 위의 설명과 동일하게 작성되었습니다.


이 문제 풀겠다고 반나절을 꼬박 투자했네요. 실력을 좀더 키워야겠습니다.