본문 바로가기
알고리즘!/Graph

[C/C++] 백준 1197: 최소 스패닝 트리 (Kruskal 알고리즘)

by soeayun 2023. 10. 10.

1197: 최소스패닝 트리

https://www.acmicpc.net/problem/1197

 

1197번: 최소 스패닝 트리

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이

www.acmicpc.net

 

최소 스패닝 트리는 크게 프림의 알고리즘과 크루스칼 알고리즘을 통해 구현이 가능하다. 이 문제에서는 크루스칼 알고리즘을 이용하여 최소 스패닝 문제를 해결했다!

 

크루스칼(Kruskal 알고리즘)?

 크루스칼 알고리즘의 가장 기본적인 아이디어는 사이클이 만들어지지 않는 한에서 간선을 가중치의 오름차순으로 정렬한 뒤, 스패닝 트리에 하나씩 추가하는 것이다. 여기서 중요한 것은 사이클이 만들어지면 안되기 때문에 사이클이 만드는 간선들은 고려하지 않아야한다. 그리고 사이클이 생기는지 확인하기 위해서는 (A.)간선의 양 끝 점이 같은 컴포넌트에 속해 있는지 확인해야하고 (B.)만약 같은 컴포넌트에 속해 있지 않다면 두 집합을 합쳐여한다. 이를 위해서는 먼저 상호 배타적 집합 자료 구조에 대해서 알아야한다.

 

상호배타적 집합?

상호배타적 집합이란?

 상호배타적 집합은 크게 초기화,합치기,찾기 연산으로 이루어져있고 (A.)합치기(union)(B.)찾기 (find) 두 연산을 지원하기 때문에 유니언-파인드(union-find) 자료구조라고 불리기도 한다. 이때 찾기(find) 연산은 어떤 원소 x가 주어졌을 때 이 원소가 속한 집합을 반환하는 것이고 합치기(union) 연산은 두 원소 x,y가 주어졌을 때 이들이 속한 두 집합을 하나로 합치는 역할을 한다.

 

상호배타적 집합의 표현

 상호배타적 집합을 표현할 때 가장 기본이 되는 것은 한 집합에 속하는 원소들은 하나의 트리로 묶어주어 결국 트리들의 집합으로 나타내는 것이다. 이때 임의의 두 원소 x,y가 같은 트리에 속해있는지 확인하는 방법은 각 원소 x,y의 부모노드, 즉 루트를 찾아 같은지 확인하면 된다. 만약 루트가 같다면 두 원소는 같은 트리에 있는 것이고 루트가 다르다면 두 원소는 다른 트리에 존재하는 것이 된다. 따라서 (B.) 찾기 연산을 할 때에는 주어진 원소가 포함된 트리의 루트를 찾는 것으로 구현이 된다. 

 (A.) 합치기 연산각 트리의 루트를 찾은 뒤, 하나의 루트를 다른 하나의 루트의 자손으로 넣으면 된다. 이를 이용해 구현한 코드는 아래와 같다

 

struct DisjointSet{
  vector<int> parent; 
  DisjointSet(int n):parent(n+1){ 
    for(int i=0;i<=n;i++)
      parent[i]=i;
  }

  int find(int u) { //현재 트리에서 부모노드(root)를 찾는 과정
    if(u==parent[u]) //root는 자기 자신을 가르키고 있음
      return u;
    return find(parent[u]);
  }

  void merge(int u,int v){
    u=find(u),v=find(v); //각 노드의 부모를 찾음
    if(u==v) return; //이미 u와 v가 같은 트리에 속하는 경우 merge 하지 않음  
    parent[u]=v; //한쪽의 노드가 다른쪽 노드의 자식이 됨!! => 하나의 트리로 들어가게 됨   
  }
};

 

상호배타적 집합의 최적화

 하지만 위의 코드에는 단점이 존재한다. 바로 무작위로 두 트리를 합치는 과정에서 트리의 모양이 한쪽으로 지나치게 치우칠 수 있다는 것이다. 가장 극단적으로는 n개의 원소에 대하여 높이가 n-1인 트리가 만들어질 수 있고 이 경우 합치기 연산과 찾기 연산 모두 O(n)의 시간복잡도를 갖게 된다. 따라서 합치기 연산과 찾기 연산에서의 시간을 줄이기 위해 2가지 최적화가 필요하다.

 

1. 기울어진 트리 해결하기

  한쪽으로 치우친 트리를 해결하기 위해서는 항상 높이가 더 낮은 트리를 더 높은 트리 밑에 집어넣어야 한다. 이러기 위해서는 각 트리의 루트 노드의 높이가 얼마인지를 저장하는 새로운 vector rank가 필요하다. 두 노드를 합칠 때 높이를 비교해서 낮은 쪽을 트리의 서브트리로 포함시키는 반면에(합쳐진 루트의 높이는 변함 없음) 두 트리의 높이가 같은 경우에는 트리의 높이가 1 커지므로 트리의 높이를 1 늘려줘야 한다. 

 

 2. 중복된 find 연산

 위 코드 find() 부분을 자세히 보게 된다면 연산이 중복된 계산을 여러번 하고 있다는 것을 알 수 있다. find(u)를 통해 u가 속한 트리의 루트를 찾아냈다고 할 때 parent[u]를 찾아낸 루트로 아예 바꿔버리면 재귀함수가 반복되면서 그 과정 속에 지나갔던 모든 노드들의 parent값이 루트가 될 것이다. 이렇게 된다면 다음번에 find(u)를 호출했을 때 경로를 따라 올라갈 것 없이 바로 루트를 찾을 수 있다.

 이 두가지 최적화의 수행을 통해 시간복잡도를 O(logn)이하로 만들 수 있다. 아래는 최적화된 상호배타적 집합 코드이다.

struct DisjointSet{
  vector<int> parent,rank; //항상 높이가 더 낮은 트리를 더 높은 트리 밑에 집어 넣어 높이가 높아지는 것을 방지=> O(n)이 되는 경우를 방지!!
  DisjointSet(int n):parent(n+1),rank(n+1,1){ //처음 초기화 해준건가??
    for(int i=0;i<=n;i++)
      parent[i]=i;
  }

  int find(int u) { //현재 트리에서 부모노드(root)를 찾는 과정
    if(u==parent[u]) //root는 자기 자신을 가르키고 있음
      return u;
    return parent[u]=find(parent[u]);
  }

  void merge(int u,int v){
    u=find(u),v=find(v); //각 노드의 부모를 찾음
    if(u==v) return; //이미 u와 v가 같은 트리에 속하는 경우 merge 하지 않음
    if(rank[u]>rank[v]) //u가 더 heigh가 낮게 되어 v와 합쳐졌을 때 검색에 시간이 오래 걸리지 않도록
      swap(u,v); //두개 바꾸는게 있었네..     
    parent[u]=v; //한쪽의 노드가 다른쪽 노드의 자식이 됨!! => 하나의 트리로 들어가게 됨
    if(rank[u]==rank[v]) ++rank[v]; //두 트리의 높이가 같을 때에는 두 트리의 높이를 1늘려줘야함
    //나머지 경우에는 v의 높이에 맞게 저장됨
  }
};

 

크루스칼 알고리즘 구현!

 먼저 main 함수에서 인접리스트를 저장한다. 이후 kruskal() 함수에서 새로운 edge_sort 함수를  vector<pair<int,pair<int,int>>> 로 정의하는데 이 함수는 저장한 인접리스트를 받아 간선의 가중치가 작은 순서대로 정렬하기 위해 가장 처음 변수에 가중치를 받고 두번째와 세번째 변수에 노드를 저장한다. 이후 sort()함수를 이용하여 오름차순으로 정렬한다.

vector<pair<int,pair<int,int>>> edge_sort;  //kruskal이  cycle이 되지 않는 가중치 최소 간선들을 구하는거이기 때문에 sort를 해줘야함 
  for(int u=1;u<=V;u++)
    for(int j=0;j<adj[u].size();j++){
      int v=adj[u][j].first, value=adj[u][j].second;
      edge_sort.push_back(make_pair(value,make_pair(u,v)));
    }
    sort(edge_sort.begin(),edge_sort.end());

 

이후 구조체 DisjointSet sets(V+1)을 선언하여 상호배타적 집합 sets를 정의한다.

그 후 edge_sort vector을 모두 탐색하며 만약 노드 u와 노드 v가 같은 트리에 있는지, 같은 트리에 없다면 두 노드가 속한 트리를 합쳐 최소스패닝 트리를 계속해서 만들면 된다. 최소 스패닝 트리가 될 때만 가중치들을 더하게 된다면 가중치의 합이 최소인 트리를 만들 수 있다. 

 

소스코드!

 

#include <iostream>
#include <algorithm>
#include<bits/stdc++.h>
#include<queue>
#include<limits.h>


using namespace std; //모든 식별자가 고유하도록 보장하는 코드 영역
int V;
vector<pair<int,int>>adj[10001]; //그래프의 인접리스트
vector<pair<int,int>> selected; //최소 스패닝 트리에 포함된 간선의 목록을 저장

struct DisjointSet{
  vector<int> parent,rank; //항상 높이가 더 낮은 트리를 더 높은 트리 밑에 집어 넣어 높이가 높아지는 것을 방지=> O(n)이 되는 경우를 방지!!
  DisjointSet(int n):parent(n+1),rank(n+1,1){ //처음 초기화 해준건가??
    for(int i=0;i<=n;i++)
      parent[i]=i;
  }

  int find(int u) { //현재 트리에서 부모노드(root)를 찾는 과정
    if(u==parent[u]) //root는 자기 자신을 가르키고 있음
      return u;
    return parent[u]=find(parent[u]);
  }

  void merge(int u,int v){
    u=find(u),v=find(v); //각 노드의 부모를 찾음
    if(u==v) return; //이미 u와 v가 같은 트리에 속하는 경우 merge 하지 않음
    if(rank[u]>rank[v]) //u가 더 heigh가 낮게 되어 v와 합쳐졌을 때 검색에 시간이 오래 걸리지 않도록
      swap(u,v); //두개 바꾸는게 있었네..     
    parent[u]=v; //한쪽의 노드가 다른쪽 노드의 자식이 됨!! => 하나의 트리로 들어가게 됨
    if(rank[u]==rank[v]) ++rank[v]; //두 트리의 높이가 같을 때에는 두 트리의 높이를 1늘려줘야함
    //나머지 경우에는 v의 높이에 맞게 저장됨
  }
};

int kruskal(){
  int result=0;
  vector<pair<int,pair<int,int>>> edge_sort;  //kruskal이  cycle이 되지 않는 가중치 최소 간선들을 구하는거이기 때문에 sort를 해줘야함 
  for(int u=1;u<=V;u++)
    for(int j=0;j<adj[u].size();j++){
      int v=adj[u][j].first, value=adj[u][j].second;
      edge_sort.push_back(make_pair(value,make_pair(u,v)));
    }
  sort(edge_sort.begin(),edge_sort.end());
  DisjointSet sets(V+1); //노드의 개수가 V개이므로 구조체도 V만큼 초기화
  for(int i=0;i<edge_sort.size();i++){
    int value=edge_sort[i].first; 
    int u=edge_sort[i].second.first, v=edge_sort[i].second.second; //pair을 여러번 했을 때 first와 second를 여러번 써줄 수 있음
    if(sets.find(u)==sets.find(v)) //이미 두 노드가 같은 트리에 있다면 무시한다
      continue;
    //cout<<"U & V are"<<u<<" "<<v<<"\n";
    sets.merge(u,v); //같은 트리에 없다면 합친다
    //selected.push_back(make_pair(u,v)); //지나간 노드 저장
    result+=value;
  }
  return result;
  
}

int main() {    
  ios::sync_with_stdio(0);
  cin.tie(0);
 int k;
  cin>>V>>k;
  for(int i=0;i<k;i++){
    int tmp1,tmp2,tmp3;
    cin>>tmp1>>tmp2>>tmp3;
    adj[tmp1].push_back(make_pair(tmp2,tmp3));
    adj[tmp2].push_back(make_pair(tmp1,tmp3));
  }
 
  int result=kruskal();
  cout<<result;
}