2887번: 행성 터널
첫째 줄에 행성의 개수 N이 주어진다. (1 ≤ N ≤ 100,000) 다음 N개 줄에는 각 행성의 x, y, z좌표가 주어진다. 좌표는 -109보다 크거나 같고, 109보다 작거나 같은 정수이다. 한 위치에 행성이 두 개 이
www.acmicpc.net
최소 신장 트리를 구하면 되는 문제이다.
하지만 이전 문제들보다 어려운 점이 있다면, 노드(행성)들 사이에 연결된 간선에 대한 정보가 주어져 있지 않다.
각 노드들의 좌표들만 주어져 있고, 두 노드의 좌표를 사용해 두 노드 간의 간선 비용을 구할 수 있다.
행성의 개수 N이 최대 10만개이므로 모든 노드의 거리를 구해서 크루스칼 알고리즘을 사용하려면 O(N(N+1) / 2)의 시간이 걸리고, 이는
무조건 시간 초과가 발생한다.
밑 코드가 이렇게 구현할 경우 시간초과가 나는 코드이다.
첫번째 트라이 -> 시간 초과
# 2887번: 행성 터널
import sys
import heapq
input = sys.stdin.readline
def find_parent(parent, x):
if parent[x] != x:
parent[x] = find_parent(parent, parent[x])
return parent[x]
def union_parent(parent, a, b):
a = find_parent(parent, a)
b = find_parent(parent, b)
if a < b:
parent[b] = a
else:
parent[a] = b
n = int(input()) # 행성 수
parent = [i for i in range(n+1)] # 특정 노드의 루트 노드를 저장하는 리스트
graph = [[0]] # 행성의 3차원 좌표를 저장하는 리스트
for _ in range(n):
data = list(map(int, input().split()))
graph.append(data)
queue = [] # (거리, 시작점, 출발점)을 저장하는 우선순위 큐
# 우선순위 큐에 모든 점들 사이의 거리를 추가 -> 여기서 메모리초과 발생
for i in range(1, n+1):
for j in range(i+1, n+1):
a = graph[i]
b = graph[j]
dist = min(abs(a[0] - b[0]), abs(a[1] - b[1]),abs(a[2] - b[2]))
heapq.heappush(queue, (dist, i, j))
answer = 0
while queue:
dist, a, b = heapq.heappop(queue)
if find_parent(parent, a) != find_parent(parent, b):
union_parent(parent, a, b)
answer += dist
print(answer)
그렇다면 도대체 어떻게 간선 사이의 정보를 추려서, 크루스칼 알고리즘을 사용해야 할까?
이 부분이 떠오르지 않아 결국 해설을 봤다.
행성 A와 행성 B 사이 터널의 비용은
min(|xa-xb|, |ya-yb|, |za-zb|)이다.
이러한 특징을 이용해 고려할 간선의 개수를 줄일 수 있다.
어떠한 간선들이 추려질 후보가 될까?
당연히 x,y,z 관련 없이 비용이 가장 작은 간선이 선택될 확률이 높을 것이다.
그렇다면 x,y,z 기준으로 각각 정렬한 뒤, 인접한 행성들의 비용을 구해 간선을 구성하면 된다.
또한 정렬할 때, 정점을 알아야 하므로 행성 번호도 주어 구분한다.
x, y, z 좌표들이 각각 들어있는 리스트를 정렬한 후, 인접한 노드들에 대한 거리를 구한다.
예를 들어 x축만 고려한다고 했을, 때, 그 인접한 노드 사이의 거리들이 최소 신장 트리를 만드는 경우들이다.
x축, y축, z축 모두 인접한 노드 사이의 거리를 구한 후,
마지막으로 edge라는 리스트에 (거리, 연결된 노드a, 연결된 노드b) 튜플을 추가한다.
edges 리스트에는 총 3 * (N-1)개의 간선 정보가 들어가므로 시간복잡도는 충분하다.
그리고 edge리스트를 크루스칼 알고리즘을 수행하면서 최소 신장 트리를 만드는데 드는 비용을 구할 수 있다.
위 아이디어를 생각한 풀이.
# 2887번: 행성 터널
import sys
input = sys.stdin.readline
def find_parent(parent, x):
if parent[x] != x:
parent[x] = find_parent(parent, parent[x])
return parent[x]
def union_parent(parent, a, b):
a = find_parent(parent, a)
b = find_parent(parent, b)
if a < b:
parent[b] = a
else:
parent[a] = b
n = int(input()) # 행성 수
parent = [i for i in range(n)] # 특정 노드의 루트 노드를 저장하는 리스트
INF = int(1e9)
# 행성의 3차원 좌표를 x, y, z 각각 저장하는 리스트
x_coord = []
y_coord = []
z_coord = []
for i in range(n):
x, y, z = map(int, input().split())
x_coord.append((x, i))
y_coord.append((y, i))
z_coord.append((z, i))
x_coord.sort()
y_coord.sort()
z_coord.sort()
# (거리, 도시a, 도시a와 연결된 도시b) 튜플을 저장하는 리스트들
edges = []
for i in range(n-1):
edges.append((x_coord[i+1][0] - x_coord[i][0], x_coord[i][1], x_coord[i+1][1]))
edges.append((y_coord[i+1][0] - y_coord[i][0], y_coord[i][1], y_coord[i+1][1]))
edges.append((z_coord[i+1][0] - z_coord[i][0], z_coord[i][1], z_coord[i+1][1]))
edges.sort() # 거리를 기준으로 오름차순 정렬
answer = 0 # 모든 행성을 터널로 연결하는데 필요한 최소 비용
for edge in edges:
dist, a, b = edge
if find_parent(parent, a) != find_parent(parent, b):
union_parent(parent, a, b)
answer += dist
print(answer)