개념정리

[트리] MST, Union Find, Kruskal Algorithm, Prim Algorithm

프로버티기 2025. 2. 8. 20:54

MST

Spanning Tree : 최소한의 간선을 사용하여 그래프 내 모든 정점을 이어주기, N-1개의 간선

Union Find 개념으로 간선을 선택했을 때 사이클이 일어나는지 확인

 

MST는 가중치의 합을 최소로 하는 Spanning Tree


Kruskal Algorithm

전체에서 가중치가 작은 간선부터 고르며, 선택한 간선으로 인해 사이클이 발생하지는 않는지 확인

최종적으로 선택된 간선의 수 N-1개, N-1개의 간선이 MST를 이룸

 

- 간선 정렬 : O(ElogE); 그래프의 간선은 최대 V(V-1)/2 개이므로 O(ElogV)로 표현하기도 한다 

- 각 간선에 대해 Union - Find : O(logN)

 

따라서, 시간 복잡도 : O(ElogE), 희소그래프의경우 E와 V는 비슷하여 O(VlogV)로 단순화될 수 있다 

 

1. 간선을 가중치 기준으로 오름차순 정렬한다

2. 각각의 간선에 대해 간선을 이루고 있는 두 노드 u, v를 보며

u, v의 루트 노드가 다른 경우에만 mst에 간선을 넣어주고 

u, v를 같은 루트 노드를 갖도록 만들어준다 

 

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

public class Main {
	static int N, E;
	static int[] parent;
	static PriorityQueue<Node> pq;

	static class Node implements Comparable<Node> {
		int a, b, cost;

		Node(int a, int b, int cost) {
			this.a = a;
			this.b = b;
			this.cost = cost;
		}

		@Override
		public int compareTo(Node node) {
			return this.cost - node.cost;
		}
	}

	public static void main(String[] args) throws IOException {
		System.setIn(new FileInputStream("src/input.txt"));
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		N = Integer.parseInt(st.nextToken());
		E = Integer.parseInt(st.nextToken());
		pq = new PriorityQueue<>();
		for (int e = 0; e < E; e++) {
			st = new StringTokenizer(br.readLine());
			int a = Integer.parseInt(st.nextToken());
			int b = Integer.parseInt(st.nextToken());
			int cost = Integer.parseInt(st.nextToken());
			pq.add(new Node(a, b, cost));
		}

		parent = new int[8];
		for (int i = 1; i <= 7; i++) {
			parent[i] = i;
		}
		int sum = 0;
		while (!pq.isEmpty()) {
			Node now = pq.poll();
			int parentA = find(now.a);
			int parentB = find(now.b);
			if (parentA != parentB) {
				union(now.a, now.b);
				sum += now.cost;
			}
		}
		System.out.println(sum);
	}

	static int find(int x) {
		if (parent[x] == x) {
			return x;
		}
		return parent[x] = find(parent[x]);
	}

	static void union(int x, int y) {
		x = find(x);
		y = find(y);
		parent[x] = y;
	}

}

Prime Algorithm

한 지점에서 시작하여 점점 확장을 진행하는 방법이다 

아무 정점에서나 시작하면 된다 

현재까지의 비용과 새로운 간선의 비용을 비교해서 갱신하면 된다 

정점 V개와 간선 E개의 그래프에서 정점 삽입/삭제에 O(logV), 인접 간선 탐색에 O(E)가 필요하다 

따라서, 시간 복잡도가 O((V+E)logV)가 나온다 

 

1. 비용 배열을 INF로 초기화하여 출발지의 값만 0으로 설정한다

처음 해당 노드가 선택되어야만 MST를 만드는 것을 시작할 수 있기 때문이다 

2.우선 순위 큐에 0으로부터 갈 수 있는 모든 노드를 구해 넣는다 

cost[v]와 length(u, v)를 비교하여 갱신한다 

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

public class Main {

	static int V, E; // 정점의 개수, 간선의 개수
	static List<Node>[] adj;

	public static void main(String[] args) throws IOException {
		System.setIn(new FileInputStream("src/input.txt"));
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		V = Integer.parseInt(st.nextToken());
		E = Integer.parseInt(st.nextToken());

		// 인접 리스트 초기화
		adj = new ArrayList[V + 1];
		for (int i = 0; i <= V; i++) {
			adj[i] = new ArrayList<>();
		}

		// 간선 정보 입력
		for (int e = 0; e < E; e++) {
			st = new StringTokenizer(br.readLine());
			int a = Integer.parseInt(st.nextToken());
			int b = Integer.parseInt(st.nextToken());
			int c = Integer.parseInt(st.nextToken());
			adj[a].add(new Node(b, c));
			adj[b].add(new Node(a, c));
		}

		// 최소 스패닝 트리 계산
		System.out.println(prim());
	}

	static int prim() {
		int[] dist = new int[V + 1]; // 각 정점까지의 최소 비용
		boolean[] visited = new boolean[V + 1]; // 방문 여부
		Arrays.fill(dist, Integer.MAX_VALUE);

		PriorityQueue<Node> pq = new PriorityQueue<>();
		pq.add(new Node(1, 0)); // 시작 정점 (1번 정점)
		dist[1] = 0;

		int cost = 0;

		while (!pq.isEmpty()) {
			Node now = pq.poll();

			// 이미 MST에 포함된 정점은 스킵
			if (visited[now.to])
				continue;

			visited[now.to] = true; // 현재 정점 방문
			cost += now.cost;

			// 인접한 정점들 갱신
			for (Node next : adj[now.to]) {
				if (!visited[next.to] && dist[next.to] > next.cost) {
					dist[next.to] = next.cost;
					pq.add(new Node(next.to, next.cost));
				}
			}
		}

		return cost; // MST의 총 비용 반환
	}

	static class Node implements Comparable<Node> {
		int to, cost;

		Node(int to, int cost) {
			this.to = to;
			this.cost = cost;
		}

		@Override
		public int compareTo(Node o) {
			return this.cost - o.cost;
		}
	}
}