HDU 5923 Prediction

这题是2016 CCPC 东北四省赛的B题, 其实很简单. 现场想到的就是正解, 只是在合并两个并查集这个问题上没想清楚.

做法

并查集合并 + 归并

  1. 对每个节点 $u$, 将 $u$ 到根的那些边添到一个初始为空的并查集中, 得到的并查集记作 $a_u$.
  2. 询问相当于将 $k$ 个并查集合并. 采用二路归并, 合并次数是 $O(n cdot log(n))$.
    $ n/2 + n/4 + n/8 + dots + 1 = O(n cdot log(n)) $

合并两个并查集

详细讨论将并查集 $B$ 合并到并查集 $A$ 中这一问题.
这个问题与

给定两无向图 $A, B, V_B subset V_A; quad A(E_A, V_A) o A'( E_A, E_A cup E_B) $.

等价.

做法

$ forall u in E_B, quad A.mathrm{unite}(u, B.mathrm{root}(u)) $

正确性

只要验证

在$B$中连通的任意两点 $u, v$, 在$ A'$中也连通.

是否满足.

Implementation

#include <bits/stdc++.h>
using namespace std;

const int N{1<<9};
const int M=1e4+5;

int n, m;

struct DSU{
    int par[N];
    int cnt;

    int find(int x){
        return par[x]==x?x: par[x]=find(par[x]);
    }

    void unite(int x, int y){
        x=find(x);
        y=find(y);
        if(x!=y){
            par[x]=y;
            --cnt;
        }
    }

    void unite(DSU &a){
        for(int i=1; i<=n; i++){
            unite(find(i), a.find(i));  // ?
        }
    }

    void init(){
        for(int i=1; i<=n; i++){
            par[i]=i;
        }
        cnt=n;
    }

    void copy(const DSU &a){
        for(int i=1; i<=n; i++){
            par[i]=a.par[i];
        }
        cnt=a.cnt;
    }
};

DSU a[M], b[M];

vector<int> g[M];

struct Edge{
    int u, v;
    void read(){
        scanf("%d%d", &u, &v);
    }
}E[M];

void dfs(int u, int f){
    a[u].copy(a[f]);
    a[u].unite(E[u].u, E[u].v);

    for(auto v: g[u]){
        dfs(v, u);
    }
}



void solve(int n){
    for(int i=1; i<n; i<<=1){   // error-prone
        for(int j=0; j+i<n; j+=i<<1){
            b[j].unite(b[j+i]);
        }
    }
    printf("%d
", b[0].cnt);
}

// int par[M];

int main(){

    int T, cas{};
    for(cin>>T; T--; ){
        printf("Case #%d:
", ++cas);
        // int n, m;
        cin>>n>>m;

        for(int i=1; i<=m; ++i){
            g[i].clear();
        }

        for(int i=2; i<=m; i++){
            // scanf("%d", par+i);
            int fa;
            scanf("%d", &fa);
            g[fa].push_back(i);
        }

        for(int i=1; i<=m; ++i){
            E[i].read();
        }

        a[0].init();
        dfs(1, 0);

        int q;
        cin>>q;
        for(; q--; ){
            int k;
            scanf("%d", &k);
            for(int i=0; i<k; i++){
                int x;
                scanf("%d", &x);
                b[i].copy(a[x]);
            }
            solve(k);
        }
    }
    return 0;
}

Pitfalls

归并

for(int i=1; i<n; i<<=1){   // error-prone
    for(int j=0; j+i<n; j+=i<<1){
        b[j].unite(b[j+i]);
    }
}

容易写错.

我第一发是这样写的

for(int i=2; i<=n; i<<=1){
    for(int j=0; j+i/2<n; j+=i){
        b[j].unite(b[j+i/2]);
    }
}

n==3时, 只做了1轮归并.

应采纳第一种写法, 很清楚.


UPD
太SB了.

  1. 根本不用归并, 直接逐个合并就好了.
  2. 根本不用 b[i].copy(a[x]); , 只要从一个边集为空的图 (以下简称"空图") 开始, 不断把$k$个并查集合并进去就好了.
  3. 不从空图开始, 而从某个并查集开始, 会快很多.
#include <bits/stdc++.h>
using namespace std;

const int N{1<<9};
const int M=1e4+5;

int n, m;

struct DSU{
    int par[N];
    int cnt;

    int find(int x){
        return par[x]==x?x: par[x]=find(par[x]);
    }

    void unite(int x, int y){
        x=find(x);
        y=find(y);
        if(x!=y){
            par[x]=y;
            --cnt;
        }
    }

    void unite(DSU &a){
        for(int i=1; i<=n; i++){
            unite(find(i), a.find(i));  // ?
        }
    }

    void init(){
        for(int i=1; i<=n; i++){
            par[i]=i;
        }
        cnt=n;
    }

    void copy(const DSU &a){
        for(int i=1; i<=n; i++){
            par[i]=a.par[i];
        }
        cnt=a.cnt;
    }
};

DSU a[M], b[M];

vector<int> g[M];

struct Edge{
    int u, v;
    void read(){
        scanf("%d%d", &u, &v);
    }
}E[M];

void dfs(int u, int f){
    a[u].copy(a[f]);
    a[u].unite(E[u].u, E[u].v);

    for(auto v: g[u]){
        dfs(v, u);
    }
}



int solve(int n){
    if(k==0){
        return n;
    }
    int x;
    scanf("%d", &x);
    a[0].copy(a[x]);
    for(int i=1; i<n; i++){
        scanf("%d", &x);
        a[0].unite(a[x]);
    }
    return a[0].cnt;
}

int main(){

    int T, cas{};
    for(cin>>T; T--; ){
        printf("Case #%d:
", ++cas);

        cin>>n>>m;

        for(int i=1; i<=m; ++i){
            g[i].clear();
        }

        for(int i=2; i<=m; i++){
            // scanf("%d", par+i);
            int fa;
            scanf("%d", &fa);
            g[fa].push_back(i);
        }

        for(int i=1; i<=m; ++i){
            E[i].read();
        }

        a[0].init();
        dfs(1, 0);

        int q;
        cin>>q;
        for(; q--; ){
            int k;
            scanf("%d", &k);        
            printf("%d
", solve(k));
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Patt/p/5971439.html