๊ทธ๋ํ G์ spanning tree ์ค ๊ฐ์ค์น์ ํฉ์ด ์ต์์ธ spanning tree๋ฅผ ์๋ฏธํ๋ค.
- ๊ฐ์ ์ ๊ฐ์ค์น์ ํฉ์ด ์ต์์ฌ์ผ ํ๋ค.
- n๊ฐ์ ์ ์ ์ ๊ฐ์ง๋ ๊ทธ๋ํ์ ๋ํด ๋ฐ๋์ (n-1)๊ฐ์ ๊ฐ์ ๋ง์ ์ฌ์ฉํด์ผ ํ๋ค.
- ์ฌ์ดํด์ด ํฌํจ๋์ด์๋ ์๋๋ค.
spanning tree
: ๊ทธ๋ํ ๋ด์ ์๋ ๋ชจ๋ ์ ์ ์ ์ฐ๊ฒฐํ๊ณ ์ฌ์ดํด์ด ์๋ ๊ทธ๋ํ๋ฅผ ์๋ฏธํ๋ค.
ํ์์ ์ธ ๋ฐฉ๋ฒ์ผ๋ก ๊ฐ์ค์น๋ฅผ ์ ํํ์ฌ ๋ชจ๋ ์ ์ ์ ์ต์ ๋น์ฉ์ผ๋ก ์ฐ๊ฒฐํ๋ ์ต์ ์ ํด๋ต์ ์ฐพ๋ ๋ฐฉ์์ด๋ค.
- ์ต์ ๋น์ฉ์ ๊ฐ์ ์ผ๋ก ๊ตฌ์ฑ๋๋ฉฐ ์ฌ์ดํด์ ํฌํจํ์ง ์๊ณ ๊ฐ ๋จ๊ณ์์ ์ฌ์ดํด์ ์ด๋ฃจ์ง ์๋ ์ต์ ๋น์ฉ ๊ฐ์ ์ ํ
- ๊ฐ์ ์ ํ์ ๊ธฐ๋ฐ์ผ๋ก ํ๋ ์๊ณ ๋ฆฌ์ฆ์ด๋ค.
- ์ด์ ๋จ๊ณ์์ ๋ง๋ค์ด์ง ์ ์ฅ ํธ๋ฆฌ์๋ ๊ด๋ จ์์ด ๋ฌด์กฐ๊ฑด ์ต์ ๊ฐ์ ๋ง์ ์ ํ
- union-find ์๊ณ ๋ฆฌ์ฆ์ ์ด์ฉํ๋ฉฐ ์๊ฐ ๋ณต์ก๋๋ ๊ฐ์ ์ ์ ๋ ฌํ๋ ์๊ฐ์ ์ข์ฐ๋๋ค.
- O(ElogE + E)
ํฌ๋ฃจ์ค์นผ ์๊ณ ๋ฆฌ์ฆ ์ฝ๋
import sys
v, e = map(int, input().split())
# ๋ถ๋ชจ ํ
์ด๋ธ ์ด๊ธฐํ
parent = [0] * (v+1)
for i in range(1, v+1):
parent[i] = i
# ๋ถ๋ชจ๋ฅผ ์ฐพ๋ ์ฐ์ฐ
def find_parent(parent, x):
if parent[x] != x:
parent[x] = find_parent(parent, parent[x])
return parent[x]
# union ์ฐ์ฐ
def union_parent(parent, a, b):
a = find_parent(parent, a)
b = find_parent(parent, b)
if a < b:
parent[b] = a
else:
parent[a] = b
# ๊ฐ์ ์ ๋ณด ๋ด์ ๋ฆฌ์คํธ์ ์ต์ ์ ์ฅ ํธ๋ฆฌ ๊ณ์ฐ ๋ณ์ ์ ์
edges = []
total_cost = 0
# ๊ฐ์ ์ ๋ณด ์ฃผ์ด์ง๊ณ ๋น์ฉ์ ๊ธฐ์ค์ผ๋ก ์ ๋ ฌ
for _ in range(e):
a, b, cost = map(int, input().split())
edges.append((cost, a, b))
# ๊ฐ์ ์ ๋ณด ๋น์ฉ ๊ธฐ์ค์ผ๋ก ์ค๋ฆ์ฐจ์ ์ ๋ ฌ
edges.sort()
# ๊ฐ์ ์ ๋ณด ํ๋์ฉ ํ์ธํ๋ฉด์ ํฌ๋ฃจ์ค์นผ ์๊ณ ๋ฆฌ์ฆ ์ํ
for i in range(e):
cost, a, b = edges[i]
# find ์ฐ์ฐ ํ, ๋ถ๋ชจ๋
ธ๋ ๋ค๋ฅด๋ฉด ์ฌ์ดํด ๋ฐ์ X์ผ๋ฏ๋ก union ์ฐ์ฐ ์ํ -> ์ต์ ์ ์ฅ ํธ๋ฆฌ์ ํฌํจ!
if find_parent(parent, a) != find_parent(parent, b):
union_parent(parent, a, b)
total_cost += cost
print(total_cost)
์์ ์ ์ ์์ ๋ถํฐ ์ถ๋ฐํ์ฌ ์ ์ฅํธ๋ฆฌ ์งํฉ์ ๋จ๊ณ์ ์ผ๋ก ํ์ฅํด๋๊ฐ๋ ๋ฐฉ๋ฒ์ด๋ค.
- ์ ์ ์ ํ์ ๊ธฐ๋ฐ์ผ๋ก ํ๋ค.
- ์ด์ ๋จ๊ณ์์ ๋ง๋๋ ์ง ์ ์ฅํธ๋ฆฌ๋ฅผ ์ ์ ํ์ฅํ๋ ๋ฐฉ์์ด๋ค. (ํฌ๋ฃจ์ค์นผ์ ์ด์ ์ ์ฅํธ๋ฆฌ์๋ ์ฐ๊ด์ด ์์.)
- ์๊ฐ ๋ณต์ก๋์ ๊ฐ์ฅ ํฐ ์ํฅ์ ๋ฏธ์น๋ ๊ฒ์ ๊ฐ์ค์น๊ฐ ๊ฐ์ฅ ์์ ์ ์ ์ ์ฐพ์๋ด๋ ๊ฒ๊ณผ ์ธ์ ํ ์ ์ ์ ํ์์ด๋ค.
- ๋ชจ๋ ๋ ธ๋์ ๋ํด ํ์ O(V)์ด๊ณ ์ฐ์ ์์ํ ์ฌ์ฉ์ผ๋ก O(logV) ์ต์ข ์๊ฐ ๋ณต์ก๋๋ O(ElogV)์ด๋ค.
ํ๋ฆผ ์๊ณ ๋ฆฌ์ฆ ์ฝ๋
import heapq
import collections
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n, m = map(int,input().split()) # ๋
ธ๋ ์, ๊ฐ์ ์
graph = collections.defaultdict(list) # ๋น ๊ทธ๋ํ ์์ฑ
visited = [0] * (n+1) # ๋
ธ๋์ ๋ฐฉ๋ฌธ ์ ๋ณด ์ด๊ธฐํ
# ๋ฌด๋ฐฉํฅ ๊ทธ๋ํ ์์ฑ
for i in range(m): # ๊ฐ์ฑ ์ ๋ณด ์
๋ ฅ ๋ฐ๊ธฐ
u, v, weight = map(int,input().split())
graph[u].append([weight, u, v])
graph[v].append([weight, v, u])
# ํ๋ฆผ ์๊ณ ๋ฆฌ์ฆ
def prim(graph, start_node):
visited[start_node] = 1 # ๋ฐฉ๋ฌธ ๊ฐฑ์
candidate = graph[start_node] # ์ธ์ ๊ฐ์ ์ถ์ถ
heapq.heapify(candidate) # ์ฐ์ ์์ ํ ์์ฑ
mst = [] # mst
total_weight = 0 # ์ ์ฒด ๊ฐ์ค์น
while candidate:
weight, u, v = heapq.heappop(candidate) # ๊ฐ์ค์น๊ฐ ๊ฐ์ฅ ์ ์ ๊ฐ์ ์ถ์ถ
if visited[v] == 0: # ๋ฐฉ๋ฌธํ์ง ์์๋ค๋ฉด
visited[v] = 1 # ๋ฐฉ๋ฌธ ๊ฐฑ์
mst.append((u,v)) # mst ์ฝ์
total_weight += weight # ์ ์ฒด ๊ฐ์ค์น ๊ฐฑ์
for edge in graph[v]: # ๋ค์ ์ธ์ ๊ฐ์ ํ์
if visited[edge[2]] == 0: # ๋ฐฉ๋ฌธํ ๋
ธ๋๊ฐ ์๋๋ผ๋ฉด, (์ํ ๋ฐฉ์ง)
heapq.heappush(candidate, edge) # ์ฐ์ ์์ ํ์ edge ์ฝ์
return total_weight
print(prim(graph,1))