BZOJ 3697

http://www.lydsy.com/JudgeOnline/problem.php?id=3697

点分治

休息站在起点到根的路径上或根到终点的路径上。

dfs时记录下路径的树上前缀和x,并判断路径的前缀和为x的节点。

枚举根的每个子树。

用g[i][0/1],f[i][0/1]分别表示已访问过的子树以及现在的子树和为i的路径数目,

0和1用于区分路径上是否存在前缀和为i的节点.

当前子树对答案的贡献为f[0][0]*g[0][0]+Σ(f[i][0]*g[-i][1]+f[i][1]*g[-i][0]+f[i][1]*g[-i][1]).

 
#include<cstdio>
#define FOR(i,s,t) for(register int i=s;i<=t;++i) 
#define gc getchar()
inline int max(int a,int b){return a>b?a:b;}
typedef long long ll;
ll ans;
inline int read(){
    char c;while(c=gc,c==' '||c=='
');int data=c-48;
    while(c=gc,c>='0'&&c<='9')data=(data<<1)+(data<<3)+c-48;return data;
}
const int N=800011;
struct edge{
    int to,w;
    edge *nxt;
    #define to(it) it->to
    #define w(it) it->w 
    #define add(x,y,z) (*++et=(edge){y,z,las[x]},las[x]=et)
    #define VIS(now) for(edge *it=las[now];it;it=it->nxt)
}e[N>>1],*las[N>>2],*et=e;  
int f[N][2],g[N][2],t[N];
int sz[N],p[N];
bool vis[N];
int n,x,y,z,G,sum,maxdeep;
inline void getG(int now,int fa){
    sz[now]=1,p[now]=0;
    VIS(now)
        if(to(it)!=fa&&!vis[to(it)]){
            getG(to(it),now);
            sz[now]+=sz[to(it)];
            p[now]=max(p[now],sz[to(it)]);
        }
    p[now]=max(p[now],sum-sz[now]);
    if(p[now]<p[G])G=now;
}
inline void dfs(int now,int fa,int len,int dep){
    maxdeep=max(maxdeep,dep);
    t[len+n]?++g[len+n][1]:++g[len+n][0];
    ++t[len+n];
    VIS(now)
        if(to(it)!=fa&&!vis[to(it)])
            dfs(to(it),now,len+w(it),dep+1); 
    --t[len+n];
}
inline void solve(int now){
    vis[now]=1;f[n][0]=1;
    int S=sum,mx=0;
    VIS(now)
        if(!vis[to(it)]){
            maxdeep=1;
            dfs(to(it),0,w(it),1);
            ans+=1ll*g[n][0]*(f[n][0]-1);
            mx=max(mx,maxdeep);
            FOR(i,-maxdeep,maxdeep)
                ans+=1ll*g[n-i][1]*f[n+i][1]+1ll*g[n-i][0]*f[n+i][1]+1ll*g[n-i][1]*f[n+i][0];
            FOR(i,-maxdeep,maxdeep){
                f[i+n][0]+=g[i+n][0];
                f[i+n][1]+=g[i+n][1];
                g[i+n][0]=g[i+n][1]=0;
            }
        }
    FOR(i,n-mx,n+mx)
        f[i][0]=f[i][1]=0;
    VIS(now)
        if(!vis[to(it)]){
            G=0;
            sum=sz[to(it)];
            getG(to(it),0);
            solve(G);
        }
}
int main(){
    n=read();
    FOR(i,2,n){
        x=read();y=read();z=read();
        if(!z)z=-1;
        add(x,y,z);add(y,x,z);
    }
    sum=p[0]=n;
    getG(1,0);
    solve(G);
    printf("%lld
",ans);
    return 0;
}

  

原文地址:https://www.cnblogs.com/Stump/p/7994619.html