2017.10.2 国庆清北 D2T2 树上抢男主

  1 /*
  2 我只看懂了求LCA
  3 */ 
  4 
  5 #include<iostream>
  6 #include<cstring>
  7 #include<cstdio>
  8 #include<cmath>
  9 #include<algorithm>
 10 #define N 100005
 11 using namespace std;
 12 
 13 int n,m,head,tot;
 14 int first[N],fa[N][20],deep[N],z[N*3],que[N],sum[N*3][2],fd[N],start[N],endd[N],value[N]; 
 15 
 16 struct edge
 17 {
 18     int u,v,w,next;
 19 }edge[N<<1];
 20 
 21 inline void add_edge(int u,int v,int w)
 22 {
 23     ++head;
 24     edge[head].u=u;
 25     edge[head].v=v;
 26     edge[head].w=w;
 27     edge[head].next=first[u];
 28     first[u]=head;
 29 }
 30 
 31 inline int get(int p,int d)
 32 {
 33     if(d==-1) return p;
 34     int x=0;
 35     while(d)
 36     {
 37         if(d&1) p=fa[p][x];
 38         d>>=1;
 39         x++;
 40     }
 41     return p;
 42 }
 43 
 44 inline int get_lca(int a,int b)
 45 {
 46     if(deep[a]<deep[b]) swap(a,b);
 47     a=get(a,deep[a]-deep[b]);
 48     int x=0;
 49     while(a!=b)
 50     {
 51         if(!x||fa[a][x]!=fa[b][x])
 52         {
 53             a=fa[a][x];
 54             b=fa[b][x];
 55             x++;
 56         }
 57         else x--;
 58     }
 59     return a;
 60 }
 61 
 62 inline int calc(int a,int b)
 63 {
 64     if(a==fa[b][0]) return value[1]-value[b];
 65     return value[a]+fd[a];
 66 }
 67 
 68 inline int calcp(int p,int v)
 69 {
 70     int l=start[p]-1,r=endd[p];
 71     while(l+1<r)
 72     {
 73         int mid=(l+r)>>1;
 74         if(v>z[mid]) l=mid;
 75         else r=mid;
 76     }
 77     return r;
 78 }
 79 
 80 int main()
 81 {
 82     scanf("%d%d",&n,&m);
 83     for(int i=1;i<n;i++)
 84     {
 85         int u,v,w;
 86         scanf("%d%d%d",&u,&v,&w);
 87         tot+=w;
 88         add_edge(u,v,w);
 89         add_edge(v,u,w);
 90     }
 91     deep[1]=1;
 92     int front=1,tail=1;
 93     que[1]=1;
 94     while(front<=tail)        //预处理每个点倍增值 
 95     {
 96         int now=que[front++];
 97         for(int i=first[now];i;i=edge[i].next)
 98         {
 99             int to=edge[i].v;
100             if(!deep[to])
101             {
102                 deep[to]=deep[now]+1;
103                 fd[to]=edge[i].w;
104                 fa[to][0]=now;
105                 int pre=now,x=0;
106                 while(fa[pre][x])
107                 {
108                     fa[to][x+1]=fa[pre][x];
109                     pre=fa[pre][x];
110                     x++;
111                 }
112                 que[++tail]=to;
113             }
114         }
115     }
116     int cnt=0;
117     for(int i=n;i;i--)
118     {
119         int now=que[i];
120         start[now]=cnt+1;
121         for(int i=first[now];i;i=edge[i].next)
122         {
123             int to=edge[i].v;
124             if(deep[to]==deep[now]+1)
125             {
126                 z[++cnt]=value[to]+edge[i].w;
127                 value[now]+=value[to]+edge[i].w;
128             }
129         }
130         z[++cnt]=tot-value[now];
131         endd[now]=cnt;
132         sort(z+start[now],z+endd[now]+1);
133         sum[endd[now]][0]=z[endd[now]];
134         sum[endd[now]][1]=0;
135         for(int i=endd[now]-1;i>=start[now];i--)
136         {
137             sum[i][0]=sum[i+1][0];
138             sum[i][1]=sum[i+1][1];
139             if((i&1)==(endd[now]&1)) sum[i][0]+=z[i];
140             else sum[i][1]+=z[i];
141         }
142         cnt++;
143     }
144     for(int i=1;i<=m;i++)
145     {
146         int p1,p2;
147         scanf("%d%d",&p1,&p2);
148         int lca=get_lca(p1,p2);
149         int dis=deep[p1]+deep[p2]-2*deep[lca];
150         int delta=dis/2+(dis&1);
151         int px,px1,px2;
152         if(deep[p1]-deep[lca]<delta) px=get(p2,dis-delta);
153         else px=get(p1,delta);
154         if(deep[p1]-deep[lca]<delta-1) px1=get(p2,dis-delta+1);
155         else px1=get(p1,delta-1);
156         if(deep[p2]-deep[lca]<dis-delta-1) px2=get(p1,delta+1);
157         else px2=get(p2,dis-delta-1);
158         int ans=0;
159         if(p1==px)
160         {
161             if(p2==px) ans=sum[start[px]][0];
162             else
163             {
164                 int v2=calc(px2,px);
165                 int p=calcp(px,v2);
166                 ans=sum[p+1][0]+sum[start[px]][1]-sum[p][1];
167             }
168         }
169         else
170         {
171             if(p2==px)
172             {
173                 int v1=calc(px1,px);
174                 int p=calcp(px,v1);
175                 ans=v1+sum[p+1][1]+sum[start[px]][0]-sum[p][0];
176             }
177             else
178             {
179                 int v1=calc(px1,px);
180                 int pp1=calcp(px,v1);
181                 int v2=calc(px2,px);
182                 int pp2=calcp(px,v2);
183                 if(pp2==pp1) pp2++;
184                 if(pp1>pp2) swap(pp1,pp2);
185                 ans=v1+sum[pp2+1][dis&1]+sum[pp1+1][1-(dis&1)]-sum[pp2][1-(dis&1)]+sum[start[px]][dis&1]-sum[pp1][dis&1];
186             }
187         }
188         printf("%d
",ans);
189     }
190     return 0;
191 }
View Code
原文地址:https://www.cnblogs.com/lovewhy/p/7653089.html