SPOJ 913 Query on a tree II 树链剖分

对于询问dist,树链剖分搞之,把边权转化到点上,然后注意细节就好(我在代码里标出来了,为了这个细节,wa了一屏)

对于询问kth,可以先求出两点(x和y)的lca,然后判断第k个数字是在x到lca的路径上还是y到lca的路径上,确定之后,倍增的寻找就好了~

View Code
  1 #include <iostream>
  2 #include <cstring>
  3 #include <cstdlib>
  4 #include <algorithm>
  5 #include <cstdio>
  6 
  7 #define N 50000
  8 #define M 100000
  9 
 10 using namespace std;
 11 
 12 int head[N],next[M],to[M],len[M];
 13 int n,tot,cnt;
 14 int fa[N],son[M],top[N],dat[N],sum[N<<2],dep[N],sz[N],pre[N],bh[N];
 15 int f[N][22],bit[22];
 16 int q[M];
 17 
 18 inline void init()
 19 {
 20     memset(head,-1,sizeof head); cnt=2; tot=0;
 21     memset(son,0,sizeof son);
 22     memset(fa,0,sizeof fa);
 23     memset(f,0,sizeof f);
 24     memset(sum,0,sizeof sum);
 25     bit[0]=1;
 26     for(int i=1;i<=20;i++) bit[i]=bit[i-1]<<1;
 27 }
 28 
 29 inline void prep()
 30 {
 31     int h=1,t=2,sta;
 32     q[1]=1; dep[1]=1;
 33     while(h<t)
 34     {
 35         sta=q[h++]; sz[sta]=1;
 36         for(int i=head[sta];~i;i=next[i])
 37             if(fa[sta]!=to[i])
 38             {
 39                 fa[to[i]]=sta;
 40                 f[to[i]][0]=sta;
 41                 pre[to[i]]=i^1;
 42                 dep[to[i]]=dep[sta]+1;
 43                 q[t++]=to[i];
 44             }
 45     }
 46     for(int j=t-1;j>=1;j--)
 47     {
 48         sta=q[j];
 49         for(int i=head[sta];~i;i=next[i])
 50             if(fa[sta]!=to[i])
 51             {
 52                 sz[sta]+=sz[to[i]];
 53                 if(sz[to[i]]>sz[son[sta]]) son[sta]=to[i];
 54             }
 55     }
 56     for(int i=1;i<t;i++)
 57     {
 58         sta=q[i];
 59         if(son[fa[sta]]==sta) top[sta]=top[fa[sta]];
 60         else top[sta]=sta;
 61     }
 62 }
 63 
 64 inline void rewrite()
 65 {
 66     for(int i=1;i<=n;i++)
 67         if(top[i]==i)
 68             for(int j=i;j;j=son[j])
 69             {
 70                 bh[j]=++tot;
 71                 dat[tot]=len[pre[j]];
 72             }
 73 }
 74 
 75 inline void lcainit()
 76 {
 77     for(int j=1;j<=20;j++)
 78         for(int i=1;i<=n;i++)
 79             f[i][j]=f[f[i][j-1]][j-1];
 80 }
 81 
 82 inline void pushup(int x)
 83 {
 84     sum[x]=sum[x<<1]+sum[x<<1|1];
 85 }
 86 
 87 inline void build(int u,int L,int R)
 88 {
 89     if(L==R) {sum[u]=dat[L];return;}
 90     int MID=(L+R)>>1;
 91     build(u<<1,L,MID); build(u<<1|1,MID+1,R);
 92     pushup(u);
 93 }
 94 
 95 inline void add(int u,int v,int w)
 96 {
 97     to[cnt]=v; len[cnt]=w; next[cnt]=head[u]; head[u]=cnt++;
 98 }
 99 
100 inline void read()
101 {
102     init();
103     scanf("%d",&n);
104     for(int i=1,a,b,c;i<n;i++)
105     {
106         scanf("%d%d%d",&a,&b,&c);
107         add(a,b,c); add(b,a,c);
108     }
109     prep();
110     rewrite();
111     build(1,1,tot);
112     lcainit();
113 }
114 
115 inline int querysum(int u,int L,int R,int l,int r)
116 {
117     if(l<=L&&R<=r) return sum[u];
118     int MID=(L+R)>>1,res=0;
119     if(l<=MID) res+=querysum(u<<1,L,MID,l,r);
120     if(MID<r) res+=querysum(u<<1|1,MID+1,R,l,r);
121     return res;
122 }
123 
124 inline int getsum(int x,int y)
125 {
126     int res=0;
127     while(top[x]!=top[y])
128     {
129         if(dep[top[x]]<dep[top[y]]) swap(x,y);
130         res+=querysum(1,1,tot,bh[top[x]],bh[x]);
131         x=fa[top[x]];
132     }
133     if(x==y) return res;//这句话好坑啊!把边权转移到点权上时会出现这个问题! 
134     if(bh[x]>bh[y]) swap(x,y);
135     res+=querysum(1,1,tot,bh[son[x]],bh[y]);//细节 
136     return res;
137 }
138 
139 inline int getlca(int x,int y)
140 {
141     if(dep[x]<dep[y]) swap(x,y);
142     for(int i=20;i>=0;i--)
143         if(dep[f[x][i]]>=dep[y]) x=f[x][i];
144     if(x==y) return x;
145     for(int i=20;i>=0;i--)
146         if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
147     return f[x][0];
148 }
149 
150 inline int getlen(int x,int lca)
151 {
152     int res=0;
153     for(int i=20;i>=0;i--)
154         if(dep[f[x][i]]>=dep[lca]) x=f[x][i],res+=bit[i];
155     return res;
156 }
157 
158 inline int getnum(int x,int p)
159 {
160     int res=0;
161     for(int i=20;i>=0;i--)
162         if(res+bit[i]<=p) x=f[x][i],res+=bit[i];
163     return x;
164 }
165 
166 inline int getkth(int x,int y,int p)
167 {
168     int lca=getlca(x,y);
169     int lx=getlen(x,lca)+1;
170     int ly=getlen(y,lca)+1;
171     if(lx>=p) return getnum(x,p-1);
172     return getnum(y,lx+ly-p-1);
173 }
174 
175 inline void go()
176 {
177     char str[10];int a,b,c;
178     while(scanf("%s",str))
179     {
180         if(str[1]=='O') break;
181         if(str[0]=='K')
182         {
183             scanf("%d%d%d",&a,&b,&c);
184             printf("%d\n",getkth(a,b,c));
185         }
186         else
187         {
188             scanf("%d%d",&a,&b);
189             printf("%d\n",getsum(a,b));
190         }
191     }
192     puts("");
193 }
194 
195 int main()
196 {
197     int cas;scanf("%d",&cas);
198     while(cas--) read(),go();
199     return 0;
200 }
原文地址:https://www.cnblogs.com/proverbs/p/2868299.html