【HDOJ5977】Garden of Eden(点分治)

题意:给定一棵n个点的树,每个节点上有一种颜色a[i],一共有k种颜色,问包含所有颜色的路径条数

n<=5e4,k<=10

思路:点分治求方案数

集合并卷积的时候暴力枚举状态即可O(n^logn*2^k)

75e的复杂度 只跑了1.7s 我也是醉了

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<iostream>
  4 #include<algorithm>
  5 #include<cmath>
  6 typedef long long ll;
  7 using namespace std;
  8 #define N   110000
  9 #define oo  10000000
 10 #define MOD 1000000007
 11 
 12 ll ans;
 13 int head[N],vet[N],nxt[N],a[N],flag[N],son[N],f[N],s[1100],dep[N],d[N],
 14     n,k,all,tot,sum,root;
 15 
 16 void add(int a,int b)
 17 {
 18     nxt[++tot]=head[a];
 19     vet[tot]=b;
 20     head[a]=tot;
 21 }
 22 
 23 void getroot(int u,int fa)
 24 {
 25     son[u]=1; f[u]=0;
 26     int e=head[u];
 27     while(e)
 28     {
 29         int v=vet[e];
 30         if(v!=fa&&!flag[v])
 31         {
 32             getroot(v,u);
 33             son[u]+=son[v];
 34             f[u]=max(f[u],son[v]);
 35         }
 36         e=nxt[e];
 37     }
 38     f[u]=max(f[u],sum-f[u]);
 39     if(f[u]<f[root]) root=u;
 40 }
 41 
 42 void getdep(int u,int fa)
 43 {
 44     dep[++dep[0]]=d[u];
 45     int e=head[u];
 46     while(e)
 47     {
 48         int v=vet[e];
 49         if(v!=fa&&!flag[v])
 50         {
 51             d[v]=d[u]|(1<<a[v]);
 52             getdep(v,u);
 53         }
 54         e=nxt[e];
 55     }
 56 }
 57 
 58 ll calc(int u,int now)
 59 {
 60     d[u]=now; dep[0]=0;
 61     getdep(u,0);
 62     memset(s,0,sizeof(s));
 63     ll ans=0;
 64     for(int i=1;i<=dep[0];i++)
 65     {
 66          for(int j=0;j<=all;j++)
 67            if((dep[i]|j)==all) ans+=s[j]; 
 68           s[dep[i]]++;
 69     }
 70     return ans;
 71 }
 72 
 73 void solve(int u)
 74 {
 75     ans+=calc(u,1<<a[u]);
 76     flag[u]=1;
 77     int e=head[u];
 78     while(e)
 79     {
 80         int v=vet[e];
 81         if(!flag[v])
 82         {
 83             ans-=calc(v,(1<<a[u])|(1<<a[v]));
 84             sum=son[v];
 85             root=0;
 86             getroot(v,root);
 87             solve(root);
 88         }
 89         e=nxt[e];
 90     }
 91 }
 92 
 93 int main()
 94 { 
 95     while(scanf("%d%d",&n,&k)!=EOF)
 96     {     
 97         all=(1<<k)-1;
 98         for(int i=1;i<=n;i++) head[i]=flag[i]=0;     
 99         for(int i=1;i<=n;i++) 
100         {
101             scanf("%d",&a[i]);
102             a[i]--;
103         }
104         tot=0;
105         for(int i=1;i<=n-1;i++)
106         {
107             int x,y;
108             scanf("%d%d",&x,&y);
109             add(x,y);
110             add(y,x);
111         }
112         if(k==1)
113         {
114             ll ans=n*n;
115             printf("%I64d
",ans);
116             continue;
117         }
118         sum=n; f[0]=oo; ans=0; root=0;
119         getroot(1,0);
120         solve(root);
121         ans*=2;
122         printf("%I64d
",ans);
123     }
124     return 0;
125 }
126     
原文地址:https://www.cnblogs.com/myx12345/p/9983827.html