2017 Multi-University Training Contest

题解:

（用并查集处理向上跳的过程时，一定要先把 u、v 跳到它们在并查集中的祖先，再继续往上跳，否则会 WA。）

代码如下:

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>
#include <cstring>
using namespace std;
const int maxn = 1e5 + 100;   // array capacity; node ids are used directly as indices (1..n)
typedef long long LL;
// f    : "jump" union-find used by Merge to skip tree edges that were already merged.
//        Merge links a segment's top node to the representative of that node's tree
//        parent, so ffind(x) is always x itself or an ancestor of x.
// g    : connectivity union-find over nodes.
// p    : parent of each node in the rooted tree (filled by dfs).
// deep : depth of each node in the rooted tree (filled by dfs).
int f[maxn], g[maxn], p[maxn], deep[maxn];
// W[r] : total cost accumulated by the g-component whose root is r.
LL W[maxn];

// Returns the root of x in the union-find forest `parent`, pointing every node on
// the visited path directly at that root (full path compression).
// Written iteratively on purpose: Merge can build f-chains of length O(n) (e.g. a
// path-shaped tree merged edge by edge from the bottom), and a recursive find
// would then recurse ~1e5 levels deep and risk overflowing the call stack.
static int findRoot(int* parent, int x){
    int root = x;
    while(parent[root] != root) root = parent[root];
    while(parent[x] != root){
        int next = parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}
// Root of x in the jump structure f (x or one of its ancestors).
int ffind(int x) { return findRoot(f, x); }
// Root of x in the connectivity structure g.
int gfind(int x) { return findRoot(g, x); }
// One input operation: merges every tree node on path u1..v1, every node on
// path u2..v2, and then the two resulting groups, charging `cost` per union.
struct Line{
    int u1, v1;   // endpoints of the first tree path
    int u2, v2;   // endpoints of the second tree path
    int cost;     // price charged for each union this operation performs
    // Ascending by cost, so that std::sort produces Kruskal order.
    bool operator <(const Line& other) const{
        return other.cost > cost;
    }
};
// Adjacency lists of the input tree, indexed by node id.
std::vector<int> G[maxn];
// All operations of the current test case; sorted by cost before they are applied.
std::vector<Line> V;

// Roots the tree stored in G at x: x gets parent `fa` and depth `d`; every other
// node reachable from x gets its tree parent in p[] and depth (parent's + 1) in deep[].
// Assumes G describes a tree (main reads exactly n-1 edges); a cycle would make the
// walk revisit nodes indefinitely.
// Uses an explicit stack rather than recursion: on a path-shaped tree with ~1e5
// nodes, a recursive walk would nest ~1e5 calls and risk overflowing the stack.
void dfs(int x, int fa, int d){
    deep[x] = d;
    p[x] = fa;
    vector<int> pending(1, x);
    while(!pending.empty()){
        int cur = pending.back();
        pending.pop_back();
        for(int to : G[cur]){
            if(to == p[cur]) continue;   // do not walk back up to the parent
            p[to] = cur;
            deep[to] = deep[cur] + 1;
            pending.push_back(to);
        }
    }
}

void Merge(int u, int v, LL w){
    u = ffind(u); v = ffind(v);
    while(ffind(u) != ffind(v)){
        if(deep[u] < deep[v]) swap(u, v);
        int fa = ffind(u);
        u = p[fa];
        f[fa] = ffind(u);
        u = ffind(u);
        if(gfind(fa) != gfind(u)){
            W[gfind(u)] += (W[gfind(fa)] + w);
            g[gfind(fa)] = gfind(u);
        }
    }
}

int main()
{
    int T, n, m, x, y;
    cin>>T;
    while(T--){
        cin>>n>>m;
        memset(W, 0, sizeof(W));
        for(int i = 1; i <= n; i++) g[i] = f[i] = i;
        for(int i = 1; i <= n; i++) G[i].clear();
        V.clear();
        V.resize(m);
        for(int i = 1; i < n; i++){
            scanf("%d %d", &x, &y);
            G[x].push_back(y);
            G[y].push_back(x);
        }
        dfs(1, 1, 1);
        for(int i = 0; i < m; i++){
            scanf("%d %d %d %d %d", &V[i].u1, &V[i].v1, &V[i].u2, &V[i].v2, &V[i].cost);
        }
        sort(V.begin(), V.end());
        for(int i = 0; i < V.size(); i++){
            Line line = V[i];
            int u = line.u1, v = line.v1, lca1, lca2;
            Merge(u, v, line.cost);
            lca1 = ffind(u);
            u = line.u2, v = line.v2;
            Merge(u, v, line.cost);
            lca2 = ffind(u);
            if(gfind(lca1) != gfind(lca2)) {
                W[gfind(lca2)] += (W[gfind(lca1)] + line.cost);
                g[gfind(lca1)] = gfind(lca2);
            }
        }
        int num = 0;
        for(int i = 1; i <= n; i++) if(gfind(i) == gfind(1)) num++;
        cout<<num<<" "<<W[gfind(1)]<<endl;
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Saurus/p/7286439.html