BZOJ 3572 世界树(虚树)

http://www.lydsy.com/JudgeOnline/problem.php?id=3572

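思路：对每次询问的点集建立虚树，先自下而上、再自上而下做两遍 DP，求出虚树上每个点最近的询问点（距离相同时取编号较小者）。然后可以发现，虚树上的每条边（对应原树上的一条路径以及挂在这条路径上的子树）要么整体归属于同一个询问点，要么在中间某处切开：靠上的一段归上端点所属的询问点，靠下的一段归下端点所属的询问点。虚树点子树中剩下的、不含其他虚树点的部分，归该点所属的询问点。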

  1 #include<algorithm>
  2 #include<cstdio>
  3 #include<cmath>
  4 #include<cstring>
  5 #include<iostream>
  6 #define N 300005
  7 int tot,go[N*2],next[N*2],first[N],sz,deep[N],tmp[N],tree[N];
  8 int son[N],dfn[N],fa[N][20],bin[20],ask[N],In[N],st[N],n,m,ans[N];
  9 int father[N],val[N];
 10 std::pair<int,int>near[N];
 11 int read(){
 12     int t=0,f=1;char ch=getchar();
 13     while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
 14     while ('0'<=ch&&ch<='9'){t=t*10+ch-'0';ch=getchar();}
 15     return t*f;
 16 }
 17 void insert(int x,int y){
 18     tot++;
 19     go[tot]=y;
 20     next[tot]=first[x];
 21     first[x]=tot;
 22 }
 23 void add(int x,int y){
 24     insert(x,y);insert(y,x);
 25 }
 26 void dfs(int x){
 27     son[x]=1;
 28     dfn[x]=++sz;
 29     for (int i=1;i<=19;i++)
 30      fa[x][i]=fa[fa[x][i-1]][i-1];
 31     for (int i=first[x];i;i=next[i]){
 32         int pur=go[i];
 33         if (pur==fa[x][0]) continue;
 34         deep[pur]=deep[x]+1;
 35         fa[pur][0]=x;
 36         dfs(pur);
 37         son[x]+=son[pur];
 38     } 
 39 }
 40 int find(int x,int dep){
 41     for (int i=19;i>=0;i--)
 42      if (deep[fa[x][i]]>=dep) x=fa[x][i];
 43     return x; 
 44 }
 45 int lca(int x,int y){
 46     if (deep[x]<deep[y]) std::swap(x,y);
 47     int t=deep[x]-deep[y];
 48     for (int i=0;i<=19;i++)
 49      if (t&bin[i])
 50       x=fa[x][i];
 51     if (x==y) return x;
 52     for (int i=19;i>=0;i--)
 53      if (fa[x][i]!=fa[y][i])
 54       x=fa[x][i],y=fa[y][i];
 55     return fa[x][0];    
 56 }
 57 bool cmp(int a,int b){
 58     return dfn[a]<dfn[b];
 59 }
// Answers one query: reads m queried nodes and, for each of them in input
// order, prints how many tree nodes have it as their closest queried node.
// Equal distances go to the smaller node id; this follows from comparing the
// (distance, id) pairs in near[] lexicographically with std::min.
//
// Outline:
//   1. Build the virtual tree of the queried nodes.
//   2. Find the closest queried node of every virtual-tree node with a
//      bottom-up pass followed by a top-down pass.
//   3. Distribute the original tree's nodes edge by edge.
void solve(){
    m=read();
    // tmp[] remembers the input order. A queried node is its own closest
    // queried node (distance 0). tree[1..m] starts out as the queried nodes.
    for (int i=1;i<=m;i++){
        ask[i]=read(),tmp[i]=tree[i]=ask[i];
        near[ask[i]]=std::make_pair(0,ask[i]);
        ans[ask[i]]=0;
    }
    // Build the virtual tree. Queried nodes are processed in DFS order, and
    // st[] holds the current root-to-node chain of virtual-tree nodes.
    std::sort(ask+1,ask+1+m,cmp);
    int top=0,all=m;
    for (int i=1;i<=m;i++){
        int p=ask[i];
        if (!top) father[p]=0,st[++top]=p;
        else{
            int x=lca(st[top],p);
            father[p]=x;
            // Pop chain nodes deeper than x; the last one popped hangs directly
            // below x. When top==1 this reads st[0]. st[0] is never written,
            // so it is node 0 with deep[0]==0 and acts as a sentinel.
            while (top&&deep[st[top]]>deep[x]){
                if (deep[st[top-1]]<=deep[x]){
                    father[st[top]]=x;
                }
                top--;
            }
            // x is a new branching point: add it to the virtual tree with no
            // closest queried node yet (1<<30 acts as "infinitely far").
            if (st[top]!=x){
                father[x]=st[top];tree[++all]=x;
                st[++top]=x;near[x]=std::make_pair(1<<30,0);
            }
            st[++top]=p;
        }
    }
    // After sorting by DFS order, parents precede children, so tree[1] is the
    // virtual root. val[p] starts as p's whole subtree. In[p] is the length of
    // the original-tree path from p up to its virtual parent.
    std::sort(tree+1,tree+1+all,cmp);
    for (int i=1;i<=all;i++){
        int p=tree[i],f=father[p];
        val[p]=son[p];
        if (i>1) In[p]=deep[p]-deep[f];
    }
    // Bottom-up pass (children before parents): each node offers its
    // candidate, moved up by In[p], to its virtual parent.
    for (int i=all;i>1;i--){
        int p=tree[i],f=father[p];
        near[f]=std::min(near[f],std::make_pair(near[p].first+In[p],near[p].second));
    }
    // Top-down pass (parents before children): each parent offers its
    // candidate, which is final by this point, to its children.
    for (int i=2;i<=all;i++){
        int p=tree[i],f=father[p];
        near[p]=std::min(near[p],std::make_pair(near[f].first+In[p],near[f].second));
    }
    // For each virtual edge f=father[p] -> p, distribute the "segment": nodes
    // strictly between f and p, plus the subtrees hanging off that path.
    // sum = size of the subtree of f's child on the way to p, minus son[p].
    for (int i=1;i<=all;i++){
        int p=tree[i],f=father[p],sum=son[find(p,deep[f]+1)]-son[p];
        // Virtual root: every node outside its subtree goes to its owner.
        if (f==0) ans[near[p].second]+=n-son[p];
        else{
            // That child subtree of f is accounted for here, not via val[f].
            val[f]-=sum+son[p];
            // Both ends have the same owner, so the whole segment goes to it.
            if (near[p].second==near[f].second) ans[near[p].second]+=sum;
            else{
                // A node d steps above p is d+near[p].first away from p's owner
                // and (deep[p]-deep[f]-d)+near[f].first away from f's owner.
                // dis is the largest d that still favours p's owner. On an exact
                // tie at dis the smaller id wins, so step back one when f's
                // owner has the smaller id.
                int dis=(deep[p]-deep[f]-near[p].first+near[f].first)/2;
                if (dis+near[p].first==near[f].first+deep[p]-deep[f]-dis&&near[f].second<near[p].second) dis--;
                // x is the ancestor of p that is dis steps up. Segment nodes
                // inside x's subtree go to p's owner; the rest go to f's owner.
                int x=find(p,deep[p]-dis);
                ans[near[p].second]+=son[x]-son[p];
                ans[near[f].second]+=sum+son[p]-son[x];
            }
        }
    }
    // What remains in val[p] is p itself plus the subtrees below p that
    // contain no other virtual-tree node; it all belongs to p's owner.
    for (int i=1;i<=all;i++){
        ans[near[tree[i]].second]+=val[tree[i]];
    }
    // Print the answers in the original input order.
    for (int i=1;i<=m;i++)
     printf("%d ",ans[tmp[i]]);
    puts(""); 
}
124 int main(){
125     n=read();
126     bin[0]=1;
127     for (int i=1;i<=19;i++) bin[i]=bin[i-1]*2;
128     for (int i=1;i<n;i++){
129         int x=read(),y=read();
130         add(x,y); 
131     }
132     dfs(1);
133     int T=read();
134     while (T--) solve();
135 }
原文地址:https://www.cnblogs.com/qzqzgfy/p/5585545.html