51Nod 1766 树上的最远点对(欧拉序、lca、线段树区间合并)

http://www.51nod.com/Challenge/Problem.html#!#problemId=1766

题解

首先要知道一个结论:两个区间的最远点对一定由各自区间的最远点对里的点组成。

然后就好做了,dfs建序然后求出欧拉序,然后打st表,通过lca求树上两点距离,然后在欧拉序上建线段树维护区间最远点对就做完了。

代码参考自Leviaton.

  1 #define dbg(x) cout<<#x<<" = "<< (x)<< endl
  2 #define IO std::ios::sync_with_stdio(0);
  3 #include <bits/stdc++.h>
  4 #define iter ::iterator
  5 using namespace  std;
  6 typedef long long ll;
  7 typedef pair<ll,ll>P;
  8 #define pb push_back
  9 #define se second
 10 #define fi first
 11 #define rs o<<1|1
 12 #define ls o<<1
 13 const ll inf=0x7fffffff;
 14 const int N=1e6+10;
 15 int n,q,tim,tot;
 16 vector<int>g[N],gl[N];
 17 int dfn[N],line[N*2],chain[N*2],d[N],sum[N];
 18 int st[N][25];
 19 int lg[N],pw[30],first[N];
 20 struct edge{ int from,u,v,len; }e[N];
 21 void insert(int u,int v,int len){ tot++; e[tot].from = first[u], e[tot].u = u , e[tot].v = v , e[tot].len = len, first[u] = tot; }
 22 void dfs(int u,int fa){
 23     //printf("u=%d
",u);
 24     dfn[u]=++tim;
 25     line[tim]=d[u];
 26     chain[tim]=u;
 27     /*for(int i = first[u];i;i = e[i].from){
 28         int v = e[i].v;
 29         if(v != fa){
 30             sum[v] = sum[u]+e[i].len;
 31             d[v] = d[u]+1;
 32             dfs(v,u);
 33             line[++tim] = d[u];
 34             chain[tim] = u;
 35         }
 36     }*/
 37     for(int i=0;i<g[u].size();i++){
 38         int v=g[u][i];
 39         int len=gl[u][i];
 40         if(v==fa)continue;
 41         d[v]=d[u]+1;
 42         sum[v]=sum[u]+len;
 43         dfs(v,u);
 44         line[++tim]=d[u];
 45         chain[tim]=u;
 46     }
 47 }
 48 void init(){
 49     for(int i=1;i<=tim;i++){
 50         st[i][0]=i;
 51     }
 52     pw[0]=1;
 53     for(int i=1;i<=20;i++){
 54         pw[i]=2*pw[i-1];
 55     }
 56     lg[0]=-1;
 57     for(int i=1;i<=tim;i++){
 58         lg[i]=lg[i>>1]+1;
 59         //printf("%d
",lg[i]);
 60     }
 61     for(int j=1;j<=20;j++){
 62         for(int i=1;i<=tim;i++){
 63             int res=i+pw[j]-1;
 64             if(res<=tim){
 65                 int x=st[i][j-1];
 66                 int y=st[i+pw[j-1]][j-1];
 67                 if(line[x]<line[y])st[i][j]=x;
 68                 else st[i][j]=y;
 69             }
 70         }
 71     }
 72 }
 73 int qust(int l,int r){
 74     l=dfn[l],r=dfn[r];
 75     if(l>r)swap(l,r);
 76     if(l==r)return chain[l];
 77     int x=lg[r-l+1];
 78     int x1=st[l][x];
 79     int x2=st[r-pw[x]+1][x];
 80     return  line[x1]<line[x2]?chain[x1]:chain[x2];
 81 }
 82 int getdis(int u,int v){
 83     if(d[u]<d[v])swap(u,v);
 84     int res=qust(u,v);
 85     if(res==v)return sum[u]-sum[v];
 86     else return sum[u]+sum[v]-sum[res]*2;
 87 }
 88 struct node{
 89     int A,B;
 90     node(){
 91         A=B=-1;
 92     }
 93 }tree[N*4];
 94 node merge(node a,node b,int flag){
 95     node tmp;
 96     int dis=-1;
 97     if(a.A>0&&a.B>0&&flag){
 98         int res=getdis(a.A,a.B);
 99         if(res>dis){
100             dis=res;
101             tmp.A=a.A;
102             tmp.B=a.B;
103         }
104     }
105     if(a.A>0&&b.A>0){
106         int res=getdis(a.A,b.A);
107         if(res>dis){
108             dis=res;
109             tmp.A=a.A;
110             tmp.B=b.A;
111         }
112     }
113     if(a.A>0&&b.B>0){
114         int res=getdis(a.A,b.B);
115         if(res>dis){
116             dis=res;
117             tmp.A=a.A;
118             tmp.B=b.B;
119         }
120     }
121     if(a.B>0&&b.A>0){
122         int res=getdis(a.B,b.A);
123         if(res>dis){
124             dis=res;
125             tmp.A=a.B;
126             tmp.B=b.A;
127         }
128     }
129     if(a.B>0&&b.B>0){
130         int res=getdis(a.B,b.B);
131         if(res>dis){
132             dis=res;
133             tmp.A=a.B;
134             tmp.B=b.B;
135         }
136     }
137     if(b.A>0&&b.B>0&&flag){
138         int res=getdis(b.A,b.B);
139         if(res>dis){
140             dis=res;
141             tmp.A=b.A;
142             tmp.B=b.B;
143         }
144     }
145     return tmp;
146 }
147 void build(int o,int l,int r){
148     if(l==r){
149         tree[o].A=l,tree[o].B=l;
150         return;
151     }
152     int m=(l+r)/2;
153     build(ls,l,m);
154     build(rs,m+1,r);
155     tree[o]=merge(tree[ls],tree[rs],1);
156 }
157 node qu(int o,int l,int r,int ql,int qr){
158     if(l>=ql&&r<=qr){
159         return tree[o];
160     }
161     node tmp;
162     int m=(l+r)/2;
163     if(ql<=m)tmp=merge(tmp,qu(ls,l,m,ql,qr),1);
164     if(qr>m)tmp=merge(tmp,qu(rs,m+1,r,ql,qr),1);
165     return tmp;
166 }
167 int read(){
168     int ret = 0; char ctr = getchar();
169     while(ctr < '0' || ctr > '9') ctr = getchar();
170     while(ctr >= '0' && ctr <= '9') ret = ret*10+ctr-'0',ctr = getchar();
171     return ret;
172 }
173 int main(){
174     n=read();
175     for(int i=1;i<n;i++){
176         int x,y,z;
177         x=read();y=read();z=read();
178         //insert(x,y,z); insert(y,x,z);
179         g[x].pb(y);
180         g[y].pb(x);
181         gl[x].pb(z);
182         gl[y].pb(z);
183     }
184     d[1]=1;
185     dfs(1,-1);
186     init();
187     build(1,1,n);
188     //printf("%d
",tim);
189     /*for(int i=1;i<=5;i++){
190         printf("i=%d: %d %d %d
",i,d[i],sum[i],dfn[i]);
191         //printf("%d
")
192     }*/
193     //cin>>q;
194     q=read();
195     while(q--){
196         int a,b,c,d;
197         a=read();
198         b=read();
199         c=read();
200         d=read();
201         node ans=merge(qu(1,1,n,a,b),qu(1,1,n,c,d),0);
202         printf("%d
",getdis(ans.A,ans.B));
203     }
204 }
原文地址:https://www.cnblogs.com/ccsu-kid/p/10752468.html