[LUOGU] P2634 [国家集训队]聪聪可可

点分治裸题,甚至不需要栈回撤。

尝试用容斥写了一波,就是把所有子树混一块计算,最后减去子树内路径条数。

#include<iostream>
#include<cstring>
#include<cstdio>

using namespace std;

inline int rd(){
    int ret=0,f=1;char c;
    while(c=getchar(),!isdigit(c))f=c=='-'?-1:1;
    while(isdigit(c))ret=ret*10+c-'0',c=getchar();
    return ret*f;
}

const int MAXN=20005;

struct Edge{
    int next,to,w;
}e[MAXN<<1];
int ecnt,head[MAXN];
inline void add(int x,int y,int w){
    e[++ecnt].next = head[x];
    e[ecnt].to = y;
    e[ecnt].w = w;
    head[x] = ecnt;
}

int n,m;
bool vis[MAXN];
int siz[MAXN];
void getsiz(int x,int pre){
    siz[x]=1;
    for(int i=head[x];i;i=e[i].next){
        int v=e[i].to;
        if(vis[v]||v==pre) continue;
        getsiz(v,x);
        siz[x]+=siz[v];
    }
}
int root,mn;
void getroot(int x,int pre,int tot){
    int mx=0;
    for(int i=head[x];i;i=e[i].next){
        int v=e[i].to;
        if(vis[v]||v==pre) continue;
        mx=max(mx,siz[v]);
        getroot(v,x,tot);
    }
    mx=max(mx,tot-siz[x]);
    if(mx<mn) mn=mx,root=x;     
}
int f[8];
int s[MAXN];

void dfs(int x,int pre,int dis){
    s[++s[0]]=dis%3;
    for(int i=head[x];i;i=e[i].next){
        int v=e[i].to;
        if(vis[v]||v==pre) continue;
        dfs(v,x,(dis+e[i].w)%3);
    }
}

long long ans=0;

void dac(int x){
    mn=n;f[0]=1;
    getsiz(x,-1);
    getroot(x,-1,siz[x]);
    int u=root;vis[u]=1;
    int offset=0;
    for(int i=head[u];i;i=e[i].next){
        int v=e[i].to;
        if(vis[v]) continue;
        s[0]=0;dfs(v,u,e[i].w%3);
        int t0=0,t1=0,t2=0;
        for(int j=s[0];j>=1;j--){
            if(s[j]==0) t0++;
            if(s[j]==1) t1++;
            if(s[j]==2) t2++;
        }
        offset+=t0*t0+2*t1*t2;
        for(int j=s[0];j>=1;j--){
            f[s[j]]++;
        }
    }
    ans+=f[0]*f[0]+2*f[1]*f[2]-offset;
    memset(f,0,sizeof(f));
    for(int i=head[u];i;i=e[i].next){
        int v=e[i].to;
        if(!vis[v]) dac(v);
    }
}

long long gcd(long long x,long long y){
    return !y?x:gcd(y,x%y);
}

int main(){
    n=rd();
    int x,y,w;
    for(int i=1;i<=n-1;i++){
        x=rd();y=rd();w=rd();
        add(x,y,w%3);add(y,x,w%3);
    }
    dac(1);
    long long tmp=1ll*n*n;
    long long G=gcd(tmp,ans);
    printf("%lld/%lld",ans/G,tmp/G); 
    return 0;
}

本文来自博客园,作者:GhostCai,转载请注明原文链接:https://www.cnblogs.com/ghostcai/p/9477856.html

原文地址:https://www.cnblogs.com/ghostcai/p/9477856.html