bzoj1036: [ZJOI2008]树的统计Count(树链剖分)

1036: [ZJOI2008]树的统计Count

题目:传送门 

题解:

   数剖的模板题...就来水个经验

代码:

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<cstdlib>
  4 #include<cmath>
  5 #include<algorithm>
  6 using namespace std;
  7 int b[210000];
  8 struct node
  9 {
 10     int x,y,next;
 11 }a[410000];int len,last[210000];
 12 void ins(int x,int y)
 13 {
 14     len++;
 15     a[len].x=x;a[len].y=y;
 16     a[len].next=last[x];last[x]=len;
 17 }
 18 struct trnode
 19 {
 20     int l,r,c,lc,rc,sum;
 21 }tr[410000];int trlen;
 22 void bt(int l,int r)
 23 {
 24     int now=++trlen;
 25     tr[now].l=l;tr[now].r=r;tr[now].c=-999999999;tr[now].sum=0;
 26     tr[now].lc=tr[now].rc=-1;
 27     if(l<r)
 28     {
 29         int mid=(l+r)/2;
 30         tr[now].lc=trlen+1;bt(l,mid);
 31         tr[now].rc=trlen+1;bt(mid+1,r);
 32     }
 33 }
 34 int n,fa[210000],dep[210000],son[210000],tot[210000];
 35 void pre_tree_node(int x)
 36 {
 37     tot[x]=1;son[x]=0;
 38     for(int k=last[x];k;k=a[k].next)
 39     {
 40         int y=a[k].y;
 41         if(y!=fa[x])
 42         {
 43             fa[y]=x;
 44             dep[y]=dep[x]+1;
 45             pre_tree_node(y);
 46             if(tot[son[x]]<tot[y])son[x]=y;
 47             tot[x]+=tot[y];
 48         }
 49     }
 50 }
 51 int z,ys[210000],top[210000];
 52 void pre_tree_edge(int x,int tp)
 53 {
 54     ys[x]=++z;top[x]=tp;
 55     if(son[x]!=0)pre_tree_edge(son[x],tp);
 56     for(int k=last[x];k;k=a[k].next)
 57     {
 58         int y=a[k].y;
 59         if(y!=fa[x] && y!=son[x])
 60             pre_tree_edge(y,y);
 61     }
 62 }
 63 void change(int now,int p,int c)
 64 {
 65     if(tr[now].l==tr[now].r){tr[now].c=c;tr[now].sum=c;return ;}
 66     int lc=tr[now].lc,rc=tr[now].rc,mid=(tr[now].l+tr[now].r)/2;
 67     if(p<=mid)change(lc,p,c);
 68     else change(rc,p,c);
 69     tr[now].c=max(tr[lc].c,tr[rc].c);
 70     tr[now].sum=tr[lc].sum+tr[rc].sum;
 71 }
 72 int findmax(int now,int l,int r)
 73 {
 74     if(tr[now].l==l && r==tr[now].r)return tr[now].c;
 75     int lc=tr[now].lc,rc=tr[now].rc,mid=(tr[now].l+tr[now].r)/2;
 76     if(mid+1<=l)return findmax(rc,l,r);
 77     else if(r<=mid)return findmax(lc,l,r);
 78     else return max(findmax(lc,l,mid),findmax(rc,mid+1,r));
 79 }
 80 int getsum(int now,int l,int r)
 81 {
 82     if(tr[now].l==l && r==tr[now].r)return tr[now].sum;
 83     int lc=tr[now].lc,rc=tr[now].rc,mid=(tr[now].l+tr[now].r)/2;
 84     if(mid+1<=l)return getsum(rc,l,r);
 85     else if(r<=mid)return getsum(lc,l,r);
 86     else return getsum(lc,l,mid)+getsum(rc,mid+1,r);
 87 }
 88 int solve1(int x,int y)
 89 {
 90      int tx=top[x],ty=top[y],ans=-999999999;
 91      while(tx!=ty)
 92      {
 93         if(dep[tx]>dep[ty])
 94         {
 95             swap(x,y);
 96             swap(tx,ty);
 97         }
 98         ans=max(ans,findmax(1,ys[ty],ys[y]));
 99         y=fa[ty];ty=top[y];
100      }
101      if(x==y)return max(ans,findmax(1,ys[x],ys[x]));
102      else
103      {
104         if(dep[x]>dep[y])swap(x,y);
105         return max(ans,findmax(1,ys[x],ys[y]));
106      }
107 }
108 int solve2(int x,int y)
109 {
110     int tx=top[x],ty=top[y],ans=0;
111      while(tx!=ty)
112      {
113         if(dep[tx]>dep[ty])
114         {
115             swap(x,y);
116             swap(tx,ty);
117         }
118         ans+=getsum(1,ys[ty],ys[y]);
119         y=fa[ty];ty=top[y];
120      }
121      if(x==y)return b[x]+ans;
122      else
123      {
124         if(dep[x]>dep[y])swap(x,y);
125         return ans+getsum(1,ys[x],ys[y]);
126      }
127 }
128 int main()
129 {
130     int n,m;
131     scanf("%d",&n);
132     len=0;memset(last,0,sizeof(last));
133     for(int i=1;i<n;i++)
134     {
135         int x,y;
136         scanf("%d%d",&x,&y);
137         ins(x,y);
138         ins(y,x);
139     }
140     for(int i=1;i<=n;i++)
141     dep[1]=1;fa[1]=0;pre_tree_node(1);
142     z=0;pre_tree_edge(1,1);
143     trlen=0;bt(1,z);
144     for(int i=1;i<=n;i++){scanf("%d",&b[i]);change(1,ys[i],b[i]);}
145     char s[11];
146     scanf("%d",&m);
147     while(m--)
148     {
149         int x,y;
150         scanf("%s%d%d",s+1,&x,&y);
151         if(s[2]=='M')printf("%d
",solve1(x,y));
152         else if(s[2]=='S')printf("%d
",solve2(x,y));
153         else {b[x]=y;change(1,ys[x],y);}
154     }
155     return 0;
156 }
原文地址:https://www.cnblogs.com/CHerish_OI/p/8425285.html