P2634 [国家集训队]聪聪可可 (点分治)

题意:给定一棵带边权的树,有序地任选两个点(两点可以相同),求这两个点之间的距离是 3 的倍数的概率,结果以最简分数形式输出。

题解:点分治模板题。对每个分治中心,统计各子树内到中心的距离模 3 余 0、1、2 的点数,再把余数之和为 3 的倍数的点对计入答案。

#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex array (vertices are indexed from 1).
const int MAXN = 2e4 + 5;

// n   : number of vertices, read in main.
// cnt : number of edge records stored in E so far (records are 1-based).
// sum : total size that getrt subtracts sz[x] from to get the part "above" x.
// ans : running count accumulated by calc; main prints (ans + n) / (n * n).
// rt  : best root candidate found by getrt; index 0 is used as a sentinel
//       whose maxs value is set to n + 5 before each search.
int n, cnt, sum, ans, rt;
// One directed adjacency-list record.
//   to  : endpoint the record points at.
//   nex : index of the next record in the same list; 0 ends the list.
//   val : edge weight (main stores it already reduced modulo 3).
struct node {
    int to, nex, val;
}E[MAXN << 1];  // two records per undirected edge
// head[x] : index of the first record in x's adjacency list, 0 if empty.
int head[MAXN];
// vis[x]  : set to 1 by solve once x has been used as a root; traversals skip it.
// sz[x]   : subtree size bookkeeping written by getrt.
// maxs[x] : size of the largest part adjacent to x, written by getrt.
// va[r]   : count of already-merged vertices at distance residue r from the
//           current root (only indices 0..2 are used).
// tmp[r]  : same count for the subtree currently being scanned by getdis.
int vis[MAXN], sz[MAXN], maxs[MAXN], va[5], tmp[5];
// dis[x] : distance from the current root to x, taken modulo 3.
int dis[MAXN];

// Greatest common divisor by the Euclidean algorithm.
// Returns x unchanged when y is 0 (so gcd(0, 0) is 0).
int gcd(int x, int y) {
    while (y != 0) {
        const int remainder = x % y;
        x = y;
        y = remainder;
    }
    return x;
}

// Finds the centroid of the unvisited component containing x.
//
// x  : vertex currently being explored.
// fa : vertex we arrived from (-1 for the starting vertex); never revisited.
//
// Before calling, the caller must set `sum` to the size of the component and
// reset the sentinel (rt = 0, maxs[0] larger than any possible part size).
// On return, sz[] holds subtree sizes as seen from the starting vertex,
// maxs[] holds the largest part left after removing each vertex, and rt is
// the vertex with the smallest maxs value.
void getrt(int x, int fa) {
    maxs[x] = 0;
    sz[x] = 1;
    for (int i = head[x]; i; i = E[i].nex) {
        int v = E[i].to;
        if (v == fa || vis[v]) continue;
        // Recurse first: sz[v] is only valid once v's subtree has been
        // measured. Reading it beforehand would use a stale (or zero) value.
        getrt(v, x);
        sz[x] += sz[v];
        maxs[x] = std::max(maxs[x], sz[v]);
    }
    // The remainder of the component hangs off x on the parent side.
    maxs[x] = std::max(maxs[x], sum - sz[x]);
    if (maxs[x] < maxs[rt]) rt = x;
}

// Depth-first walk that tallies, in tmp[], how many unvisited vertices
// reachable from x (without going back through fa) sit at each distance
// residue modulo 3. dis[x] must already be set by the caller.
void getdis(int x, int fa) {
    ++tmp[dis[x]];
    for (int e = head[x]; e != 0; e = E[e].nex) {
        const int child = E[e].to;
        if (child != fa && !vis[child]) {
            // Extend the residue along this edge before descending.
            dis[child] = (dis[x] + E[e].val) % 3;
            getdis(child, x);
        }
    }
}

// Counts the ordered pairs of distinct vertices whose path passes through x
// and whose length is a multiple of 3, adding them to ans.
//
// Subtrees of x are processed one at a time: tmp[] holds the residue counts
// of the subtree being scanned, va[] the counts of everything merged so far
// (x itself counts as residue 0).
void calc(int x) {
    va[0] = 1;
    va[1] = 0;
    va[2] = 0;
    for (int e = head[x]; e != 0; e = E[e].nex) {
        const int child = E[e].to;
        if (vis[child]) continue;

        for (int r = 0; r < 3; ++r) tmp[r] = 0;
        dis[child] = E[e].val % 3;
        getdis(child, x);

        // Residue r pairs with residue (3 - r) % 3; the factor 2 counts both
        // orders of each pair.
        for (int r = 0; r < 3; ++r) {
            ans += 2 * tmp[r] * va[(3 - r) % 3];
        }
        // Merge only after counting, so pairs inside one subtree are excluded.
        for (int r = 0; r < 3; ++r) {
            va[r] += tmp[r];
        }
    }
}

// Divide-and-conquer step: marks x as processed, counts the pairs whose path
// goes through x, then recurses into each remaining component using the root
// picked by getrt.
void solve(int x) {
    vis[x] = 1;
    calc(x);
    for (int e = head[x]; e != 0; e = E[e].nex) {
        const int next = E[e].to;
        if (!vis[next]) {
            // Reset the sentinel so any real vertex beats it.
            rt = 0;
            maxs[rt] = n + 5;
            // NOTE(review): sz[next] is whatever the previous getrt pass left
            // behind; if next was on the parent side of x in that pass, this
            // may not be the component's true size -- confirm it is acceptable.
            sum = sz[next];
            getrt(next, -1);
            solve(rt);
        }
    }
}

int main() {
    ans = 0;
    cnt = 0;
    scanf("%d", &n);
    for(int i = 1; i < n; i++) {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        E[++cnt].to = b; E[cnt].nex = head[a]; head[a] = cnt; E[cnt].val = c % 3;
        E[++cnt].to = a; E[cnt].nex = head[b]; head[b] = cnt; E[cnt].val = c % 3;
    }
    rt = 0; maxs[rt] = n + 5;
    sum = n;
    getrt(1, -1);
    solve(rt);
    int gg = gcd(ans + n, n * n);
    printf("%d/%d
", (ans + n) / gg, n * n / gg);
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/lwqq3/p/11631391.html