【Foreign】采蘑菇 [点分治]

采蘑菇

Time Limit: 20 Sec  Memory Limit: 256 MB

Description

  

Input

  

Output

  

Sample Input

  5
  1 2 3 2 3
  1 2
  1 3
  2 4
  2 5

Sample Output

  10
  9
  12
  9
  11

HINT

  

Main idea

  询问从以每个点为起始点时,各条路径上的颜色种类的和。

Solution

  我们看到题目,立马想到了O(n^2)的做法,然后从这个做法研究一下本质,我们确定了可以以点分治作为框架。

  我们先用点分治来确定一个center(重心)。然后计算跟这个center有关的路径。设现在要统计的是经过center,对x提供贡献的路径。

  我们先记录一个记录Sum[x]表示1~i-1子树中 颜色x 第一次出现的位置的那个点 的子树和,然后我们就利用这个Sum来解题。

  我们显然可以分两种情况来讨论:

  (1)统计center->x出现颜色的贡献
    显然,这时候,对于center->x这一段,直接像O(n^2)做法那样记录一个color表示到目前为止出现的颜色个数,然后加一下即可。再记录一个record表示当前可有的贡献和,一旦出现过一个颜色,那么这个颜色在1~i-1子树上出现第一次以下的点,对于x就不再提供贡献了,record减去Sum[这个颜色],然后这样深搜往下计算即可。

  (2)统计center->x没出现过的颜色的贡献
    显然,对于center->x上没出现过的颜色,直接往下深搜,一开始为record为(All - Sum[center]),一旦出现了一个颜色,record则减去这个Sum。同样表示不再提供贡献即可。

  我们这样做就可以求出每个子树前缀对于其的贡献了,倒着再做一边即可求出全部的贡献。统计x的时候,顺便统计一下center。可以满足效率,成功AC这道题。

Code

  1 #include<iostream>  
  2 #include<algorithm>  
  3 #include<cstdio>  
  4 #include<cstring>  
  5 #include<cstdlib>  
  6 #include<cmath>  
  7 using namespace std;
  8 
  9 const int ONE = 600005;
 10 const int INF = 214783640;
 11 const int MOD = 1e9+7;
 12 
 13 int n,x,y;
 14 int Val[ONE];
 15 int next[ONE],first[ONE],go[ONE],tot;
 16 int vis[ONE];
 17 int Ans[ONE],Sum[ONE];
 18 int All;
 19 
 20 
 21 int get()
 22 { 
 23         int res,Q=1;    char c;
 24         while( (c=getchar())<48 || c>57)
 25         if(c=='-')Q=-1;
 26         if(Q) res=c-48; 
 27         while((c=getchar())>=48 && c<=57) 
 28         res=res*10+c-48; 
 29         return res*Q; 
 30 }
 31 
 32 void Add(int u,int v)
 33 {
 34         next[++tot]=first[u];    first[u]=tot;    go[tot]=v;
 35         next[++tot]=first[v];    first[v]=tot;    go[tot]=u;
 36 }
 37 
 38 namespace Point
 39 {
 40         int center;
 41         int Stack[ONE],top;
 42         int total,Max,center_vis[ONE];
 43         int num,V[ONE];
 44         
 45         struct power
 46         {
 47             int size,maxx;
 48         }S[ONE];
 49         
 50         void Getsize(int u,int father)
 51         {
 52             S[u].size=1;
 53             S[u].maxx=0;
 54             for(int e=first[u];e;e=next[e])
 55             {
 56                 int v=go[e];
 57                 if(v==father || center_vis[v]) continue;
 58                 Getsize(v,u);
 59                 S[u].size += S[v].size;
 60                 S[u].maxx = max(S[u].maxx,S[v].size);
 61             }
 62         }
 63              
 64         void Getcenter(int u,int father,int total)
 65         {
 66             S[u].maxx = max(S[u].maxx,total-S[u].size);
 67             if(S[u].maxx < Max)
 68             {
 69                 Max = S[u].maxx;
 70                 center = u;
 71             }
 72                
 73             for(int e=first[u];e;e=next[e])
 74             {
 75                 int v=go[e];
 76                 if(v==father || center_vis[v]) continue;
 77                 Getcenter(v,u,total);
 78             }
 79         }
 80         
 81         void Ad_sum(int u,int father)
 82         {
 83             if(!vis[Val[u]])
 84             {
 85                 Stack[++top] = Val[u];
 86                 All += S[u].size;    Sum[Val[u]] += S[u].size;
 87             }
 88             vis[Val[u]]++;
 89             for(int e=first[u];e;e=next[e])
 90             {
 91                 int v=go[e];
 92                 if(v==father || center_vis[v]) continue;
 93                 Ad_sum(v,u);
 94             }
 95             vis[Val[u]]--;
 96         }
 97 
 98         void Calc_in(int u,int father,int center,int Size,int f_time,int record)
 99         {
100             if(!vis[Val[u]]) f_time++, record += Size, record -= Sum[Val[u]];
101             Ans[u] += record;    Ans[center]+=f_time;
102             Ans[u] += f_time;    vis[Val[u]] ++;
103             for(int e=first[u];e;e=next[e])
104             {
105                 int v=go[e];
106                 if(v==father || center_vis[v]) continue;
107                 Calc_in(v,u,center,Size,f_time,record);
108             }
109             vis[Val[u]] --;
110         }
111         
112         void Calc_not(int u,int father,int record)
113         {
114             if(!vis[Val[u]]) record -= Sum[ Val[u] ];
115             Ans[u] += record;    vis[Val[u]] ++;
116             for(int e=first[u];e;e=next[e])
117             {
118                 int v=go[e];
119                 if(v==father || center_vis[v]) continue;
120                 Calc_not(v,u,record);
121             }
122             vis[Val[u]] --;
123         }
124         
125         void Dfs(int u)
126         {
127             Max = n;
128             Getsize(u,0);
129             Getcenter(u,0,S[u].size);
130             Getsize(center,0);
131             center_vis[center] = 1;
132             
133             int num=0; for(int e=first[center];e;e=next[e]) if(!center_vis[go[e]]) V[++num]=go[e];
134             
135             for(int i=1;i<=num;i++)
136             {
137                 int v=V[i];
138                 int Size = S[center].size - S[v].size - 1;
139                 vis[Val[center]] = 1;
140                 Calc_in(v,center,center, Size,1,All - Sum[Val[center]] + Size);
141                 vis[Val[center]] = 0;
142                 Ad_sum(v,center);
143             }
144             while(top) Sum[Stack[top--]]=0;    All=0;
145             
146             for(int i=num;i>=1;i--)
147             {
148                 int v=V[i];
149                 vis[Val[center]] = 1;
150                 Calc_not(v,center, All-Sum[Val[center]]);
151                 vis[Val[center]] = 0;
152                 Ad_sum(v,center);
153             }
154             
155             while(top) Sum[Stack[top--]]=0;    All=0;
156             for(int e=first[center];e;e=next[e])
157             {
158                 int v=go[e];
159                 if(center_vis[v]) continue;
160                 Dfs(v);
161             }
162         }
163         
164 }
165 
166 int main()
167 {      
168         n=get();
169         for(int i=1;i<=n;i++)    Val[i]=get();
170 
171         for(int i=1;i< n;i++)
172         {
173             x=get();    y=get();
174             Add(x,y);
175         }
176         
177         Point:: Dfs(1);
178         for(int i=1;i<=n;i++)
179             printf("%d
",Ans[i]+1);
180 }
View Code
原文地址:https://www.cnblogs.com/BearChild/p/6517325.html