P2634 树上路径长度为3的倍数的点对数 点分治

计算答案时维护数组 num：num[i] 表示当前连通块中到根（当前重心）的距离模 3 等于 i 的点数。

则当前块中路径长度为 3 的倍数的有序点对数为 num[0]*num[0] + 2*num[1]*num[2]：两端距离余数都为 0，或一端余 1、另一端余 2（两种顺序，故乘 2）。

#include<bits/stdc++.h>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity limits: MAXN bounds vertex ids, MAXM bounds the number of
// undirected edges (each one is stored twice, once per direction).
const int MAXN = 100005;
const int MAXM = 100005;

// Forward-star (linked edge list) adjacency storage. Because ed starts at 1
// and is pre-incremented, real edge ids start at 2; id 0 terminates a list.
int to[2 * MAXM];    // target vertex of each directed edge
int nxt[2 * MAXM];   // next edge id leaving the same source vertex
int Head[MAXN];      // first edge id leaving each vertex (0 = none)
int ed = 1;          // id of the most recently added edge
int cost[2 * MAXM];  // weight of each directed edge

// Prepends the directed edge u -> v with weight c to u's adjacency list.
inline void addedge(int u, int v, int c) {
        const int id = ++ed;
        to[id] = v;
        cost[id] = c;
        nxt[id] = Head[u];
        Head[u] = id;
}

// Inserts an undirected edge as the two directed halves u -> v and v -> u.
// With ed starting at 1 the halves get ids (2k, 2k+1), so id ^ 1 is the twin.
inline void ADD(int u, int v, int c) {
        addedge(u, v, c);
        addedge(v, u, c);
}

// Centroid-decomposition state shared by getroot/getdeep/calc/solve.
int n;          // number of vertices read from input
int anser;      // running count of ordered pairs with path length % 3 == 0
int sz[MAXN];   // component sizes from the most recent getroot() traversal
int f[MAXN];    // largest piece left if the vertex were removed
int dep[MAXN];  // depth values written by calc()/getdeep()
int sumsz;      // size of the component currently being split
int root;       // best centroid candidate found by getroot()
bool vis[MAXN]; // vertices already used as centroids (treated as removed)
int num[5];     // num[r]: tallied vertices whose depth % 3 == r (0..2 used)
// Post-order DFS over the component containing x; vertices with vis set are
// treated as removed. Recomputes sz[] for the component and sets f[x] to the
// largest piece that would remain if x were deleted: the biggest child
// subtree, or the part "above" x (sumsz - sz[x]). root moves to x whenever
// f[x] is strictly smaller than f[root]; callers in this file reset root to
// 0 first and keep f[0] large so any real vertex can replace it.
void getroot(int x, int fa) {
        sz[x] = 1;
        int worst = 0;
        for (int e = Head[x]; e != 0; e = nxt[e]) {
                const int y = to[e];
                if (y != fa && !vis[y]) {
                        getroot(y, x);
                        sz[x] += sz[y];
                        worst = std::max(worst, sz[y]);
                }
        }
        f[x] = std::max(worst, sumsz - sz[x]);
        if (f[x] < f[root]) {
                root = x;
        }
}
// DFS from x over vertices not yet removed (vis == false) that tallies each
// visited vertex's depth residue into num[0..2]. dep[x] must already hold x's
// depth; each child v receives its depth reduced mod 3 in dep[v].
//
// Only residues are propagated. Storing full path lengths can overflow int on
// long, heavy paths. The resulting negative value would make `dep % 3`
// negative and index num[] out of bounds.
void getdeep(int x, int fa) {
        // Normalise to 0..2 even if the caller passed a negative depth.
        const int r = (dep[x] % 3 + 3) % 3;
        num[r]++;
        for (int i = Head[x]; i; i = nxt[i]) {
                int v = to[i];
                if (v == fa || vis[v]) {
                        continue;
                }
                // r is in [0,2] and cost[i] % 3 is in [-2,2], so no overflow here.
                dep[v] = (r + cost[i] % 3 + 3) % 3;
                getdeep(v, x);
        }
}
// Starting from x with depth d, counts ordered pairs (a, b) of vertices
// reachable through non-removed vertices (a == b allowed) such that
// (depth(a) + depth(b)) % 3 == 0.
int calc(int x, int d) {
        for (int r = 0; r < 3; ++r) {
                num[r] = 0;
        }
        dep[x] = d;
        getdeep(x, 0);
        // Residues summing to 0 mod 3: (0,0), (1,2) and (2,1); the last two are
        // ordered mirror images, hence the factor of 2.
        const int zeroPairs = num[0] * num[0];
        const int mixedPairs = 2 * num[1] * num[2];
        return zeroPairs + mixedPairs;
}
// Centroid-decomposition step at centroid x.
// 1. Adds every qualifying ordered pair of x's component, measured through x.
// 2. For each child component, subtracts the pairs lying entirely inside it.
//    Their lengths were measured through x (depth offset by the connecting
//    edge's cost), but their real path avoids x.
// 3. Marks x removed (vis) and recurses into the centroid of each child
//    component.
void solve(int x) {
        anser += calc(x, 0);
        vis[x] = true;
        const int whole = sumsz;  // sumsz is overwritten by the recursion below
        for (int e = Head[x]; e; e = nxt[e]) {
                const int child = to[e];
                if (vis[child]) {
                        continue;
                }
                anser -= calc(child, cost[e]);
                // sz[] came from a getroot() DFS that may have started elsewhere in
                // the component. The neighbour that was x's DFS parent has
                // sz[child] > sz[x]; its side holds the whole component minus x's
                // DFS subtree.
                int part = sz[child];
                if (sz[child] > sz[x]) {
                        part = whole - sz[x];
                }
                sumsz = part;
                root = 0;
                getroot(child, 0);
                solve(root);
        }
}
// Greatest common divisor via Euclid's algorithm, in recursive form.
// Returns a unchanged when b == 0 (e.g. gcd(7, 0) == 7).
int gcd(int a, int b) {
        return b == 0 ? a : gcd(b, a % b);
}
int main() {
        scanf("%d", &n);
        memset(Head, 0, sizeof(Head));
        memset(vis, 0, sizeof(vis));
        ed = 1;
        int u, v, c;
        for (int i = 1; i < n; i++) {
                scanf("%d %d %d", &u, &v, &c);
                ADD(u, v, c);
        }
        root = 0, sumsz = f[0] = n;
        getroot(1, 0);
        solve(root);
        int chu = gcd(anser, n * n);
        printf("%d/%d
", anser / chu, n * n / chu);
        return 0;
}
View Code
原文地址:https://www.cnblogs.com/Aragaki/p/10478906.html