[프로그래머스] 섬 연결하기

2 minute read

문제 설명

n개의 섬 사이에 다리를 건설하는 비용(costs)이 주어질 때, 최소의 비용으로 모든 섬이 서로 통행 가능하도록 만들 때 필요한 최소 비용을 return 하도록 solution을 완성하세요.

다리를 여러 번 건너더라도, 도달할 수만 있으면 통행 가능하다고 봅니다. 예를 들어 A 섬과 B 섬 사이에 다리가 있고, B 섬과 C 섬 사이에 다리가 있으면 A 섬과 C 섬은 서로 통행 가능합니다.

제한사항

  • 섬의 개수 n은 1 이상 100 이하입니다.
  • costs의 길이는 ((n-1) * n) / 2이하입니다.
  • 임의의 i에 대해, costs[i][0] 와 costs[i][1]에는 다리가 연결되는 두 섬의 번호가 들어있고, costs[i][2]에는 이 두 섬을 연결하는 다리를 건설할 때 드는 비용입니다.
  • 같은 연결은 두 번 주어지지 않습니다. 또한 순서가 바뀌더라도 같은 연결로 봅니다. 즉 0과 1 사이를 연결하는 비용이 주어졌을 때, 1과 0의 비용이 주어지지 않습니다.
  • 모든 섬 사이의 다리 건설 비용이 주어지지 않습니다. 이 경우, 두 섬 사이의 건설이 불가능한 것으로 봅니다.
  • 연결할 수 없는 섬은 주어지지 않습니다.

풀이

40분
최소신장트리(MST)문제이다. 간선을 가중치의 오름차순으로 정렬해두고, 사이클을 형성하지 않으면 간선을 선택한다. 이 때, 사이클 형성 여부는 Union-find 방식을 사용하여 검사하였다. 문제를 풀면서는 효율성에 대한 고민을 그다지 하지 않았지만, 문제를 풀고 나서 union-find를 조금 더 고쳐보았다. find함수에서 경로 압축을하고, uni함수에서 트리의 크기에 따라 달리 합쳐서 최적화했다.

코드

// 04:15 ~ 04:53
#include <string>
#include <vector>
#include <algorithm>

using namespace std;

int root[100];

int find(int a){
    if(root[a] == a) return a;
    return find(root[a]);
}

bool uni(int a, int b){
    a = find(a);
    b = find(b);
    
    if(a == b) return false;
    
    root[b] = a;
    return true;
}

bool cmp(vector<int> a, vector<int> b){
    return a[2] < b[2];
}

int solution(int n, vector<vector<int>> costs) {
    int cnt = 0;
    int sum = 0;
    
    for(int i = 0; i < n; ++i) root[i] = i;
    sort(costs.begin(), costs.end(), cmp);
    
    for(vector<int> v : costs){
        if(uni(v[0],v[1])){
            ++cnt;
            sum += v[2];
        }
        if(cnt == n - 1) break;
    }
    
    return sum;
}

개선 코드

// 04:15 ~ 04:53
#include <string>
#include <vector>
#include <algorithm>

using namespace std;

int root[100];
int level[100];

int find(int a){
    if(root[a] == a) return a;
    return root[a] = find(root[a]);
}

bool uni(int a, int b){
    a = find(a);
    b = find(b);
    
    if(a == b) return false;
    
    if(level[a] == level[b]){
        root[a] = b;
        level[a]++;
    }
    else if(level[a] < level[b]){
        root[a] = b;
    }
    else{
        root[b] = a;
    }
    
    return true;
}

bool cmp(vector<int> a, vector<int> b){
    return a[2] < b[2];
}

int solution(int n, vector<vector<int>> costs) {
    int cnt = 0;
    int sum = 0;
    
    for(int i = 0; i < n; ++i){
        root[i] = i;
        level[i] = 0;
    } 
    sort(costs.begin(), costs.end(), cmp);
    
    for(vector<int> v : costs){
        if(uni(v[0],v[1])){
            ++cnt;
            sum += v[2];
        }
        if(cnt == n - 1) break;
    }
    
    return sum;
}