(프로그래머스) 합승택시요금.

문제

문제는 문제 사이트에 가서 확인 하시길 바랍니다.

나의 코드

import collections
def bfs(n, start, graph):
    priceArr = [0 for _ in range(n+1)]
    que = collections.deque([start])
    while que:
        cur = que.popleft()
        for nextNode in range(1, n+1):
            if graph[cur][nextNode] == 0 or nextNode == start:
                continue
            if priceArr[nextNode] == 0 or priceArr[nextNode] > priceArr[cur] + graph[cur][nextNode]:
                priceArr[nextNode] = priceArr[cur] + graph[cur][nextNode]
                que.append(nextNode)     
    return priceArr
    

def solution(n, s, a, b, fares):
    answer = 0
    graph = [[0 for _ in range(n+1)] for _ in range(n+1)]
    for x, y, price in fares:
        graph[x][y] = price
        graph[y][x] = price
    
    sPriceArr = bfs(n, s, graph)
    
    aSolo = sPriceArr[a]
    bSolo = sPriceArr[b]
    
    aWithB = float('inf')
    for mid in range(1, n+1):
        if s == mid:
            continue
        firstCourse = sPriceArr[mid]
        midPriceArr = bfs(n, mid, graph)
        secondCourse, thirdCourse = midPriceArr[a], midPriceArr[b] 
        
        if mid != a and secondCourse == 0:
            continue
        if mid != b and thirdCourse == 0:
            continue
        
        aWithB = min(aWithB, firstCourse + secondCourse + thirdCourse)

    return min(aSolo + bSolo, aWithB)

풀이 해설

  1. bfs를 이용하고자 하였습니다.
  2. 시작점 부터 모든 노드까지의 최소요금을 계산합니다.
  3. 그리고 환승 구간을 모두 확인하여 가장 적은 요금을 계산하였습니다.
  4. bfs를 무조건 n번 시행해야합니다.
  5. 통과는 하지만 너무 오래 걸립니다.

다른 사람의 더 좋은 코드

def solution(n, s, a, b, fares):
    answer = 0
    priceArr = [[20000001 for _ in range(n+1)] for _ in range(n+1)]
    
    for x, y, price in fares:
        priceArr[x][y] = price
        priceArr[y][x] = price
    
    for x in range(n+1):
        priceArr[x][x] = 0
    
    for i in range(1, n+1):
        for j in range(1, n+1):
            for k in range(1, n+1):
                if priceArr[j][k] > priceArr[j][i] + priceArr[i][k]:
                    priceArr[j][k] = priceArr[j][i] + priceArr[i][k]
                    
    answer = 20000001 * 2
    for mid in range(1, n +1):
        answer = min(answer, priceArr[s][mid] + priceArr[mid][a] + priceArr[mid][b])
    
    return answer

풀이 해설

나의 코드보다 압도적으로 좋은 코드이며 빠른 속도를 자랑한다. 스스로 생각해내기에는 솔직히 어려운 풀이 였다고 생각한다. 그래서 왜 이게 성립될 수 있는지에 대해 알아보자.

제일 중요한건 삼중 for문으로써 모든 점 사이의 최소 택시값이 정해진다는 것이다. 왜 저게 되는걸까?

삼중 for문에서 i는 지나는 점을 의미하고 j, k는 시작점과 도착점이다.

그럼 최소 하나의 ij-i의 길이도 가지고있고, i-k의 길이도 가지고 있어야 한다. 이것이 가능할까?

j, k 사이에 n개의 있다고 하자. 그 중에서 for문에서 가장 나중에 나오는 조합은 무엇일까?

j , k, n개의 노드의 값중 가장 큰 값이 i가 되는것이다.

예를 들어 1- 10-20-30-100-60-200 이런식으로 연결되어잇다고 생각하면

i= 200, j = 100, k = 1일때 가장 마지막으로 나오는 조합이다.

그럼 우리는 이때 200-100, 100-1의 거리를 알고 있어야한다.

이제 한번 시뮬레이션을 해보자.

  1. 200, 100, 1 => 지나가는 점을 100으로 잡고,
  2. 200, 60, 100 => 여기는 60이 200과 100과 인접해있기 때문에 바로 계산 가능
  3. 100, 30, 1 => 30과 100은 인접해 있기 때문에 해결
  4. 30, 20, 1 => 30은 20과 인접
  5. 20, 10, 1 => 10은 1과 20이 인접해있기때문에 해결

이런 식으로 해결이 가능합니다.

여기서 중요한점!!! 아래에서 위의 순서로 i, j, k 조합이 for문에 등장할 것입니다.

  1. i = 100, j = 200, k = 1
  2. i = 60, j = 200, k = 100
  3. i = 30, j = 100, k = 1
  4. i = 20, j = 30, k = 1
  5. i = 10, j = 20, k = 1

이렇게 되면 5 -> 1의 순서로 조합이 등장하기 때문에 무조건 1번의 값이 정확하게 정해지게 됩니다.

이것이 되는 이유는 중간 점을 가장 큰 점으로 잡았기 때문에 아래 단계에서 발생하는 모든 경우의 수보다 늦게 등장하게 되기 때문이다.