点分治_学习笔记+题目清单

1.模板题 洛谷P3806

注意对limit加限制<=1e7,不然会RE

#include<bits/stdc++.h>
#define ll long long
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define per(i,n,a) for(int i=n;i>=a;i--)
#define endl '
'
#define mem(a,b) memset(a,b,sizeof(a))
#define IO ios::sync_with_stdio(false);cin.tie(0);
using namespace std;
const int INF=0x3f3f3f3f;
const ll inf=0x3f3f3f3f3f3f3f3f;
const int mod=1e9+7;
const int maxn=1e5+5;
const int maxk=1e7+5;
const int limit=1e7;
int tot,head[maxn];
struct E{
    int to,next,w;
}edge[maxn<<1];
void add(int u,int v,int w){
    edge[tot].to=v;
    edge[tot].w=w;
    edge[tot].next=head[u];
    head[u]=tot++;
}
int n,m,rt,sum,cnt,q[maxn];
int tmp[maxn],siz[maxn],dis[maxn],maxp[maxn];
bool judge[maxk],ans[maxn],vis[maxn];
void getrt(int u,int f){
    siz[u]=1,maxp[u]=0;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        getrt(v,u);
        siz[u]+=siz[v];
        if(siz[v]>maxp[u]) maxp[u]=siz[v];
    }
    maxp[u]=max(maxp[u],sum-siz[u]);
    if(maxp[u]<maxp[rt]) rt=u;
}
void getdis(int u,int f){
    if(dis[u]<=limit) tmp[cnt++]=dis[u];
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        dis[v]=dis[u]+edge[i].w;
        getdis(v,u);
    }
}
void solve(int u){
    queue<int> que;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        cnt=0;
        dis[v]=edge[i].w;
        getdis(v,u);
        for(int j=0;j<cnt;j++)
            for(int k=0;k<m;k++)
                if(q[k]>=tmp[j])
                    ans[k]|=judge[q[k]-tmp[j]];
        for(int j=0;j<cnt;j++){
            que.push(tmp[j]);
            judge[tmp[j]]=true;
        }
    }
    while(!que.empty()){
        judge[que.front()]=false;
        que.pop();
    }
}
void divide(int u){
    vis[u]=judge[0]=true;
    solve(u);
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        maxp[rt=0]=sum=siz[v];
        getrt(v,0);
        getrt(rt,0);
        divide(rt);
    }
}
int main(){
    scanf("%d%d",&n,&m);mem(head,-1);
    for(int i=1;i<n;i++){
        int u,v,w;scanf("%d%d%d",&u,&v,&w);
        add(u,v,w);add(v,u,w);
    }
    for(int i=0;i<m;i++) scanf("%d",&q[i]);
    maxp[0]=sum=n;
    getrt(1,0);
    getrt(rt,0);
    divide(rt);
    for(int i=0;i<m;i++){
        if(ans[i]) puts("AYE");
        else puts("NAY");
    }
}
View Code

2.P4178 Tree

遵循点分治思想,在solve函数中用双指针维护和<=k,另外在求距离的getdis函数中,如果当前距离>k了可以直接return;因为距离是累加的,大于了就无意义.暴力会TLE

#include<bits/stdc++.h>
#define ll long long
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define per(i,n,a) for(int i=n;i>=a;i--)
#define endl '
'
#define mem(a,b) memset(a,b,sizeof(a))
#define IO ios::sync_with_stdio(false);cin.tie(0);
using namespace std;
const int INF=0x3f3f3f3f;
const ll inf=0x3f3f3f3f3f3f3f3f;
const int mod=1e9+7;
const int maxn=4e4+5;
int tot,head[maxn];
struct E{
    int to,next,w;
}edge[maxn<<1];
void add(int u,int v,int w){
    edge[tot].to=v;
    edge[tot].w=w;
    edge[tot].next=head[u];
    head[u]=tot++;
}
int n,rt,ans=0,cnt,sum,lim,dis[maxn],tmp[maxn],siz[maxn],maxp[maxn];
bool vis[maxn];
void getrt(int u,int f){
    siz[u]=1,maxp[u]=0;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        getrt(v,u);
        siz[u]+=siz[v];
        if(siz[v]>maxp[u]) maxp[u]=siz[v];
    }
    maxp[u]=max(maxp[u],sum-siz[u]);
    if(maxp[u]<maxp[rt]) rt=u;
}
void getdis(int u,int f){
    if(dis[u]>lim) return ;
    tmp[cnt++]=dis[u];
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        dis[v]=dis[u]+edge[i].w;
        getdis(v,u);
    }
}
void solve(int u){
    vector<int> vec;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        cnt=0;
        dis[v]=edge[i].w;
        getdis(v,u);
        sort(tmp,tmp+cnt);
        sort(vec.begin(),vec.end());
        int r=vec.size()-1;
        for(int j=0;j<cnt&&r>=0;j++){
            while(tmp[j]+vec[r]>lim){
                --r;
                if(r<=-1) break;
            }
            if(r>=0) ans+=(r+1);
            else break;
        }
        for(int j=0;j<cnt;j++){
            if(tmp[j]<=lim) ++ans;
            vec.push_back(tmp[j]);
        }            
    }
    vec.clear();
}
void divide(int u){
    vis[u]=1;
    solve(u);
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        maxp[rt=0]=sum=siz[v];
        getrt(v,0);
        getrt(rt,0);
        divide(rt);
    }
}
int main(){
    scanf("%d",&n);mem(head,-1);
    rep(i,1,n-1){
        int u,v,w;scanf("%d%d%d",&u,&v,&w);
        add(u,v,w);add(v,u,w);
    }
    scanf("%d",&lim);
    maxp[0]=sum=n;
    getrt(1,0);
    getrt(rt,0);
    divide(rt);
    printf("%d
",ans);
}    
View Code

3.P2634 [国家集训队]聪聪可可

统计一共有几条边,并对边分析,看%3是否==0

#include<bits/stdc++.h>
#define ll long long
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define per(i,n,a) for(int i=n;i>=a;i--)
#define endl '
'
#define mem(a,b) memset(a,b,sizeof(a))
#define IO ios::sync_with_stdio(false);cin.tie(0);
using namespace std;
const int INF=0x3f3f3f3f;
const ll inf=0x3f3f3f3f3f3f3f3f;
const int mod=1e9+7;
const int maxn=1e5+5;
int tot,head[maxn];
struct E{
    int to,next,w;
}edge[maxn<<1];
void add(int u,int v,int w){
    edge[tot].to=v;
    edge[tot].w=w;
    edge[tot].next=head[u];
    head[u]=tot++;
}
int n,rt,cnt,sum,cnt1=0,cnt2=0,dis[maxn],tmp[maxn],siz[maxn],maxp[maxn];
bool vis[maxn];
void getrt(int u,int f){
    siz[u]=1,maxp[u]=0;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        getrt(v,u);
        siz[u]+=siz[v];
        if(siz[v]>maxp[u]) maxp[u]=siz[v];
    }
    maxp[u]=max(maxp[u],sum-siz[u]);
    if(maxp[u]<maxp[rt]) rt=u;
}
void getdis(int u,int f){
    tmp[cnt++]=dis[u];
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        dis[v]=dis[u]+edge[i].w;
        getdis(v,u);
    }
}
void solve(int u){
    vector<int> vec;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        cnt=0;
        dis[v]=edge[i].w;
        getdis(v,u);
        for(int j=0;j<cnt;j++){
            for(auto it:vec){
                ++cnt2;
                if((it+tmp[j])%3==0) ++cnt1;
            }
        }
        for(int j=0;j<cnt;j++){
            ++cnt2;
            if(tmp[j]%3==0) ++cnt1;
            vec.push_back(tmp[j]);
        }            
    }
    vec.clear();
}
void divide(int u){
    vis[u]=1;
    solve(u);
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        maxp[rt=0]=sum=siz[v];
        getrt(v,0);
        getrt(rt,0);
        divide(rt);
    }
}
int main(){
    scanf("%d",&n);mem(head,-1);
    rep(i,1,n-1){
        int u,v,w;scanf("%d%d%d",&u,&v,&w);
        add(u,v,w);add(v,u,w);
    }
    maxp[0]=sum=n;
    getrt(1,0);
    getrt(rt,0);
    divide(rt);
    cnt1*=2;cnt2*=2;
    cnt1+=n;cnt2+=n;
    int c=__gcd(cnt1,cnt2);
    cnt1/=c;cnt2/=c;
    printf("%d/%d",cnt1,cnt2);
}    
View Code

 4.Codeforces.161D Distance in Tree

思想和P3806很贴近,我一开始用了双指针,但是双指针会处理过剩,所以会wa(也有可能我写的菜吧).所以我用了judge来存每个边权出现的次数,lim-tmp[j],表示在u结点的别的子树上是否存有当前值的边权,若有,ans+其数量即可。对于每个v所在自己的子树也要处理.

#include<bits/stdc++.h>
#define ll long long
#define int long long
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define per(i,n,a) for(int i=n;i>=a;i--)
#define endl '
'
#define mem(a,b) memset(a,b,sizeof(a))
#define IO ios::sync_with_stdio(false);cin.tie(0);
using namespace std;
const int INF=0x3f3f3f3f;
const ll inf=0x3f3f3f3f3f3f3f3f;
const int mod=1e9+7;
const int maxn=5e4+5;
int tot,head[maxn];
struct E{
    int to,next,w;
}edge[maxn<<1];
void add(int u,int v,int w){
    edge[tot].to=v;
    edge[tot].w=w;
    edge[tot].next=head[u];
    head[u]=tot++;
}
int n,m,rt,ans=0,cnt,sum,lim,dis[maxn],tmp[maxn],siz[maxn],maxp[maxn];
bool vis[maxn];
void getrt(int u,int f){
    siz[u]=1,maxp[u]=0;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        getrt(v,u);
        siz[u]+=siz[v];
        if(siz[v]>maxp[u]) maxp[u]=siz[v];
    }
    maxp[u]=max(maxp[u],sum-siz[u]);
    if(maxp[u]<maxp[rt]) rt=u;
}
void getdis(int u,int f){
    if(dis[u]>lim) return ;
    tmp[cnt++]=dis[u];
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        dis[v]=dis[u]+edge[i].w;
        getdis(v,u);
    }
}
void solve(int u){
    int judge[510]={0};
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        cnt=0;
        dis[v]=edge[i].w;
        getdis(v,u);
        for(int j=0;j<cnt;j++){
            ans+=judge[lim-tmp[j]];
        }
        for(int j=0;j<cnt;j++){
            if(tmp[j]==lim) ++ans;
            judge[tmp[j]]++;
        }            
    }
}
void divide(int u){
    vis[u]=1;
    solve(u);
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        maxp[rt=0]=sum=siz[v];
        getrt(v,0);
        getrt(rt,0);
        divide(rt);
    }
}
signed main(){
    scanf("%lld%lld",&n,&lim);mem(head,-1);
    rep(i,1,n-1){
        int u,v,w;scanf("%lld%lld",&u,&v);
        add(u,v,1);add(v,u,1);
    }
    maxp[0]=sum=n;
    getrt(1,0);
    getrt(rt,0);
    divide(rt);
    printf("%lld
",ans);
}    
View Code

 5.P4149 [IOI2011]Race

我觉得这个题还是模板题的延伸,我用tmp2数组记录当前边权值和所运用的边的数量,judge[tmp[j]]表示当前tmp[j]边权下所运用的最少边数量,记得judge一开始初始化为INF,同样如果暴力会TLE,所以我们需要用队列queue来维护已经运用过的judge情况,最后更新删去队首初始化judge为INF,这样会很省时间

#include<bits/stdc++.h>
#define ll long long
#define int long long
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define per(i,n,a) for(int i=n;i>=a;i--)
#define endl '
'
#define mem(a,b) memset(a,b,sizeof(a))
#define IO ios::sync_with_stdio(false);cin.tie(0);
using namespace std;
const int INF=0x3f3f3f3f;
const ll inf=0x3f3f3f3f3f3f3f3f;
const int mod=1e9+7;
const int maxn=2e5+5;
const int maxk=1e6+5;
int tot,head[maxn];
struct E{
    int to,next,w;
}edge[maxn<<1];
void add(int u,int v,int w){
    edge[tot].to=v;
    edge[tot].w=w;
    edge[tot].next=head[u];
    head[u]=tot++;
}
int n,m,rt,ans=0,cnt,sum,lim,dis[maxn],tmp[maxn],siz[maxn],maxp[maxn];
bool vis[maxn];
void getrt(int u,int f){
    siz[u]=1,maxp[u]=0;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        getrt(v,u);
        siz[u]+=siz[v];
        if(siz[v]>maxp[u]) maxp[u]=siz[v];
    }
    maxp[u]=max(maxp[u],sum-siz[u]);
    if(maxp[u]<maxp[rt]) rt=u;
}
int tmp2[maxn],ct,dep[maxn];
void getdis(int u,int f){
    if(dis[u]>lim) return ;
    tmp2[cnt]=dep[u];    
    tmp[cnt++]=dis[u];
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        dis[v]=dis[u]+edge[i].w;
        dep[v]=dep[u]+1;
        getdis(v,u);
    }
}
int judge[maxk];
void solve(int u){
    queue<int> que;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        cnt=0;
        dis[v]=edge[i].w;
        dep[v]=1;
        getdis(v,u);
        for(int j=0;j<cnt;j++){
            if(judge[lim-tmp[j]]!=INF){
                ans=min(ans,(tmp2[j]+judge[lim-tmp[j]]));
            }
        }
        for(int j=0;j<cnt;j++){
            if(tmp[j]==lim) ans=min(ans,tmp2[j]);
            que.push(tmp[j]);
            judge[tmp[j]]=min(judge[tmp[j]],tmp2[j]);
        }            
    }
    while(!que.empty()){
        judge[que.front()]=INF;
        que.pop();
    }
}
void divide(int u){
    vis[u]=1;
    solve(u);
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        maxp[rt=0]=sum=siz[v];
        getrt(v,0);
        getrt(rt,0);
        divide(rt);
    }
}
signed main(){
    scanf("%lld%lld",&n,&lim);
    mem(judge,INF);mem(head,-1);ans=INF;
    rep(i,1,n-1){
        int u,v,w;scanf("%lld%lld%lld",&u,&v,&w);
        u+=1,v+=1;
        add(u,v,w);add(v,u,w);
    }
    maxp[0]=sum=n;
    getrt(1,0);
    getrt(rt,0);
    divide(rt);
    if(ans!=INF) printf("%lld
",ans);
    else puts("-1");
}    
View Code

 6.hdu-4812 D-Tree

这个题出的挺好,但我不知道为什么我一直RE(栈溢出),看了Hzwer的代码感觉差不多但是他能过。。。不过这题想法很好,因为要求两点val的积取模之后等于k,那么就用线性预处理逆元(这里我才学会,不会数论),然后点分治,一开始solve的时候,不带根节点u,每次都是子树上的,用mp查看其逆元是否存在,若存在则更新2个ans,然后再用一次getdis,把自己的子树的dis乘上根节点u的val,这样继续更新。最后的形态是所有子树的距离都带上根节点了,这个时候我们再来一次getdis来把所有情况给去掉,这题很有想法啊

#include<algorithm>
#include<cstdio>
#include<map>
#include<cmath>
#include<cstring>
#pragma comment(linker,"/STACK:102400000,102400000")
#define ll long long
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define per(i,n,a) for(int i=n;i>=a;i--)
#define endl '
'
#define mod 1000003
#define INF 1e9
#define mem(a,b) memset(a,b,sizeof(a))
using namespace std;
int tot,head[100005];
struct E{
    int to,next;
}edge[200005];
void add(int u,int v){
    edge[tot].to=v;
    edge[tot].next=head[u];
    head[u]=tot++;
}
int n,k,sum,cnt,rt,id[100005],siz[100005],maxp[100005];
ll val[100005],dis[100005],tmp[100005];
ll mp[1000005],ine[1000005];
int ans1,ans2;
bool vis[100005];
void getrt(int u,int f){
    siz[u]=1,maxp[u]=0;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        getrt(v,u);
        siz[u]+=siz[v];
        if(siz[v]>maxp[u]) maxp[u]=siz[v];
    }
    maxp[u]=max(maxp[u],sum-siz[u]);
    if(maxp[u]<maxp[rt]) rt=u;
}
void getdis(int u,int f){
    tmp[++cnt]=dis[u];
    id[cnt]=u;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        dis[v]=(dis[u]*val[v])%mod;
        getdis(v,u);
    }
}
void query(int x,int id){
    x=ine[x]*k%mod;
    int y=mp[x];
    if(y==0)return;
    if(y>id)swap(y,id);
    if(y<ans1||(y==ans1&&id<ans2))
        ans1=y,ans2=id;
}
void divide(int u){
    vis[u]=1;
    mp[val[u]]=u;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        cnt=0;
        dis[v]=val[v];
        getdis(v,u);
        for(int j=1;j<=cnt;j++){
            query(tmp[j],id[j]);
        }
        cnt=0;
        dis[v]=(val[u]*val[v])%mod;
        getdis(v,u);
        for(int j=1;j<=cnt;j++){
            int now=mp[tmp[j]];
            if(!now||id[j]<now) mp[tmp[j]]=id[j];
        }        
    }
    mp[val[u]]=0;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        cnt=0;
        dis[v]=(val[u]*val[v])%mod;
        getdis(v,u);
        for(int j=1;j<=cnt;j++){
            mp[tmp[j]]=0;
        }
    }
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        rt=0;sum=siz[v];
        getrt(v,0);
        getrt(rt,0);
        divide(rt);
    }
}
int main(){
    ine[1]=1;
    for(int i=2;i<mod;i++){
        int a=mod/i,b=mod%i;
        ine[i]=(ine[b]*(-a)%mod+mod)%mod;    
    }
    while(~scanf("%d%d",&n,&k)){
        mem(head,-1);mem(vis,0);
        cnt=0;ans1=ans2=INF;
        rep(i,1,n) scanf("%d",&val[i]);
        rep(i,1,n-1){
            int u,v;scanf("%d%d",&u,&v);
            add(u,v);add(v,u);
        }
        rt=0;maxp[0]=n+1;sum=n;
        getrt(1,0);
        getrt(rt,0);
        divide(rt);
        if(ans1==INF) puts("No solution");
        else printf("%d %d
",ans1,ans2);
    }
    return 0;
}
View Code

 7.P2664 树上游戏※

这个难度还是和之前6个题来说有很大提升。做了很久没有做出来。思路是:对于一个p作为根节点的子树中某个节点u而言,如果颜色c[u]在u到根节点p的路径上是第一次出现,那么对于根节点p及其不在u所在子树上的任一节点,这些节点均会产生siz[u]大小的新的贡献。同理对于根节点p也一样,所以一开始我们dfs1把根节点u的贡献全部算出来,并统计以p为点分治树根的情况下所有颜色col[c[u]]的贡献情况。

接下来就要讨论一些经过p节点跨根的贡献情况,一个节点u在p的子树v为根的子树上,那么u到p这段链中产生的贡献就是每一个在u之前有几个不同颜色num,然后乘(siz[u]-siz[v]),记得一开始的时候要把v树置0噢,这个是跨根贡献,显然这些还是不够的,因为非v子树上还有不同的点,那怎么办?一开始的sum就有用了,sum说是p产生的贡献(其中包括v树,所以要减掉v树的情况),然后是对已用颜色的去重,若在该链上且在u之前被用过了,那么把sum-col 这样就能去重了。说不清楚太难了

#include<bits/stdc++.h>
#define ll long long
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define per(i,n,a) for(int i=n;i>=a;i--)
#define endl '
'
#define mem(a,b) memset(a,b,sizeof(a))
#define IO ios::sync_with_stdio(false);cin.tie(0);
using namespace std;
const int INF=0x3f3f3f3f;
const ll inf=0x3f3f3f3f3f3f3f3f;
const int mod=1e9+7;
const int maxn=1e5+5;
int tot,head[maxn];
struct E{
    int to,next;
}edge[maxn<<1];
void add(int u,int v){
    edge[tot].to=v;
    edge[tot].next=head[u];
    head[u]=tot++;
}
int n,rt,SIZE;
int siz[maxn],c[maxn],maxp[maxn],tmp[maxn],dis[maxn];
bool vis[maxn];
void getrt(int u,int f){
    siz[u]=1,maxp[u]=0;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==f||vis[v]) continue;
        getrt(v,u);
        siz[u]+=siz[v];
        if(siz[v]>maxp[u]) maxp[u]=siz[v];
    }
    maxp[u]=max(maxp[u],SIZE-siz[u]);
    if(maxp[u]<maxp[rt]) rt=u;
}
ll ans[maxn],cnt[maxn],col[maxn],sum,num,S;
void dfs1(int u,int f){
    siz[u]=1;cnt[c[u]]++;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]||v==f) continue;
        dfs1(v,u);
        siz[u]+=siz[v];
    }
    if(cnt[c[u]]==1){
        sum+=siz[u];
        col[c[u]]+=siz[u];
    }
    cnt[c[u]]--;
}
void change(int u,int f,int k){
    cnt[c[u]]++;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]||v==f) continue;
        change(v,u,k);
    }
    if(cnt[c[u]]==1){
        sum+=k*siz[u];
        col[c[u]]+=k*siz[u];
    }
    cnt[c[u]]--;
}
void dfs2(int u,int f){
    cnt[c[u]]++;
    if(cnt[c[u]]==1){
        sum-=col[c[u]];num++;
    }
    ans[u]+=sum+num*S;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]||v==f) continue;
        dfs2(v,u);
    }
    if(cnt[c[u]]==1){
        sum+=col[c[u]];num--;
    }
    cnt[c[u]]--;
}
void clear(int u,int f){
    cnt[c[u]]=col[c[u]]=0;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]||v==f) continue;
        clear(v,u);
    }
}
void solve(int u){
    dfs1(u,0);ans[u]+=sum;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        cnt[c[u]]++;sum-=siz[v];col[c[u]]-=siz[v];
        change(v,u,-1);cnt[c[u]]--;
        S=siz[u]-siz[v];dfs2(v,u);
        cnt[c[u]]++;sum+=siz[v];col[c[u]]+=siz[v];
        change(v,u,1);cnt[c[u]]--;        
    }    
    sum=0,num=0,clear(u,0);
}
void divide(int u){
    vis[u]=true;
    solve(u);
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(vis[v]) continue;
        maxp[rt=0]=SIZE=siz[v];
        getrt(v,0);
        divide(rt);
    }
}
int main(){
    scanf("%d",&n);mem(head,-1);
    rep(i,1,n) scanf("%d",&c[i]);
    rep(i,1,n-1){
        int u,v;scanf("%d%d",&u,&v);
        add(u,v);add(v,u);
    }
    maxp[0]=SIZE=n;rt=0;
    getrt(1,0);
    divide(rt);    
    rep(i,1,n) cout<<ans[i]<<endl;
}
View Code
原文地址:https://www.cnblogs.com/Anonytt/p/13053962.html