Educational Codeforces Round 64 Div.2 D

并查集

合法的简单路径只要三种情况,要么全是0边,要么全是1边,或者是先0后1的边。

于是我们可以把合法路径分成两种类型,一种是只过0的边或者只过1的边,一种是先过1再过0的边。

对于第一种情况,我们可以把某个点所在的0的联通块或者1的联通块大小统计出来,合法的第一种路径为联通块大小-1。

对于第二种情况,一定存在一个点他的一个邻边是0一个邻边是1,也就是中转点。那么根据乘法原理,合法的路径数就是联通块1大小x联通块2大小-1。

再综合起来看两种情况,对于第一种情况来说,一个点如果不能做中转点,那么它肯定是没有入度或者出度的点,那么他所在的0/1联通块必定有一个只有他本身。也就是1,所以第一种情况可以看成特殊的第二种情况。

综上来说,我们用某个点的两种联通块大小相乘再减去1,所得到的答案的意义为:从某个点出发经过改点中转(0->1)的路径数与从改点出发到达某个点且路径上全为0边或者全为1边的路径数之和。

最后所有点统计答案即可。

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define full(a, b) memset(a, b, sizeof a)
using namespace std;
typedef long long ll;
inline int lowbit(int x){ return x & (-x); }
inline int read(){
    int X = 0, w = 0; char ch = 0;
    while(!isdigit(ch)) { w |= ch == '-'; ch = getchar(); }
    while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
    return w ? -X : X;
}
inline int gcd(int a, int b){ return a % b ? gcd(b, a % b) : b; }
inline int lcm(int a, int b){ return a / gcd(a, b) * b; }
template<typename T>
inline T max(T x, T y, T z){ return max(max(x, y), z); }
template<typename T>
inline T min(T x, T y, T z){ return min(min(x, y), z); }
template<typename A, typename B, typename C>
inline A fpow(A x, B p, C lyd){
    A ans = 1;
    for(; p; p >>= 1, x = 1LL * x * x % lyd)if(p & 1)ans = 1LL * x * ans % lyd;
    return ans;
}
const int N = 200005;
int parent[2][N], size[2][N];

int find(int i, int p){
    while(p != parent[i][p]) parent[i][p] = parent[i][parent[i][p]], p = parent[i][p];
    return p;
}

bool isConnect(int i, int p, int q){
    return find(i, p) == find(i, q);
}

void merge(int i, int p, int q){
    int pRoot = find(i, p), qRoot = find(i, q);
    if(pRoot == qRoot) return;
    if(size[i][qRoot] < size[i][pRoot]) swap(qRoot, pRoot);
    parent[i][pRoot] = qRoot;
    size[i][qRoot] += size[i][pRoot];
}

int main(){

    int n = read();
    for(int i = 0; i <= n; i ++){
        parent[0][i] = parent[1][i] = i;
        size[0][i] = size[1][i] = 1;
    }
    for(int i = 0; i < n - 1; i ++){
        int x = read(), y = read(), c = read();
        if(isConnect(c, x, y)) continue;
        merge(c, x, y);
    }
    ll ans = 0;
    for(int i = 1; i <= n; i ++){
        ans += size[0][find(0, i)] * 1LL *size[1][find(1, i)] - 1;
    }
    cout << ans << endl;
    return 0;
}
原文地址:https://www.cnblogs.com/onionQAQ/p/10814746.html