bzoj 3611: [Heoi2014]大工程

  1 #include<iostream>
  2 #include<cstdio>
  3 #include<cstring>
  4 #include<algorithm>
  5 #define M 2000009
  6 #define inf 0x7ffffff
  7 #define ll long long
  8 using namespace std;
  9 int n,head[M],next[M],u[M],cnt,head1[M],next1[M],u1[M],fa[M][21],deep[M],m,dfn[M],T,v[M];
 10 int h[M],st[M],mn[M],mx[M],size[M],mx1,mi1,v1[M];
 11 ll cn1,sum[M];
 12 void jia(int a1,int a2)
 13 {
 14     cnt++;
 15     next[cnt]=head[a1];
 16     head[a1]=cnt;
 17     u[cnt]=a2;
 18     return;
 19 }
 20 void jia2(int a1,int a2)
 21 {
 22     if(a1==a2)
 23       return;
 24     cnt++;
 25     next1[cnt]=head1[a1];
 26     head1[a1]=cnt;
 27     u1[cnt]=a2;
 28     v1[cnt]=deep[a2]-deep[a1];
 29     return;
 30 }
 31 bool cmp(int a1,int a2)
 32 {
 33     return dfn[a1]<dfn[a2];
 34 }
 35 void dfs(int a1)
 36 {
 37     dfn[a1]=++T;
 38     for(int i=1;(1<<i)<=deep[a1];i++)
 39       fa[a1][i]=fa[fa[a1][i-1]][i-1];
 40     for(int i=head[a1];i;i=next[i])
 41       if(fa[a1][0]!=u[i])
 42         {
 43             deep[u[i]]=deep[a1]+1;
 44             fa[u[i]][0]=a1;
 45             dfs(u[i]);
 46         }
 47     return;
 48 }
 49 int lca(int a1,int a2)
 50 {
 51     if(deep[a1]<deep[a2])
 52       swap(a1,a2);
 53     int a3=deep[a1]-deep[a2];
 54     for(int i=0;i<=20;i++)
 55       if(a3&(1<<i))
 56         a1=fa[a1][i];
 57     for(int i=20;i>=0;i--)
 58       if(fa[a1][i]!=fa[a2][i])
 59         {
 60             a1=fa[a1][i];
 61             a2=fa[a2][i];
 62         }
 63     if(a1==a2)
 64       return a1;
 65     return fa[a1][0];
 66 }
 67 void dp(int x){
 68     sum[x]=0;
 69     mx[x]=v[x]?0:-inf;
 70     mn[x]=v[x]?0:inf;
 71     size[x]=v[x];
 72     for(int i=head1[x];i;i=next1[i]){
 73         int v=u1[i];
 74         dp(v);
 75         cn1+=(sum[x]+size[x]*v1[i])*size[v]+size[x]*sum[v];
 76         size[x]+=size[v];
 77         sum[x]+=sum[v]+(ll)size[v]*v1[i];
 78         mi1=min(mi1,mn[v]+mn[x]+v1[i]);
 79         mx1=max(mx1,mx[v]+mx[x]+v1[i]);
 80         mn[x]=min(mn[x],mn[v]+v1[i]);
 81         mx[x]=max(mx[x],mx[v]+v1[i]);
 82     }
 83     head1[x]=0;
 84 }
 85 void solve(){
 86     cnt=cn1=0;mi1=inf;mx1=-inf;
 87     int K;
 88     scanf("%d",&K);
 89     for(int i=1;i<=K;i++) scanf("%d",&h[i]),v[h[i]]=1;
 90     sort(h+1,h+K+1,cmp);int top=1;st[1]=1;
 91     for(int i=1;i<=K;i++){
 92         int now=h[i],f=lca(st[top],now);
 93         if(dfn[f]==dfn[st[top]]) st[++top]=now;
 94         else{
 95             while(top){
 96                 int q=st[top-1];
 97                 if(dfn[q]>dfn[f]) jia2(st[top-1],st[top]),top--;
 98                 else if(dfn[q]==dfn[f]){
 99                     jia2(q,st[top]);top--;break;
100                 }
101                 else {
102                     jia2(f,st[top]);st[top]=f;break;
103                 }   
104             }
105             if(st[top]!=now) st[++top]=now;
106         }
107     }
108     while(--top)jia2(st[top],st[top+1]);
109     dp(1);
110     printf("%lld ",cn1);
111     printf("%d %d
",mi1,mx1);
112     for(int i=1;i<=K;i++) v[h[i]]=0;
113 }
114 int main()
115 {
116     scanf("%d",&n);
117     for(int i=1;i<n;i++)
118       {
119         int a1,a2;
120         scanf("%d%d",&a1,&a2);
121         jia(a1,a2);
122         jia(a2,a1);
123       }
124     dfs(1);
125     scanf("%d",&m);
126     for(int i=1;i<=m;i++)
127       solve();
128     return 0;
129 }

虚树,树形DP

原文地址:https://www.cnblogs.com/xydddd/p/5309513.html