树链剖分讲解

题目:Aragorn's Story

链接:http://acm.hdu.edu.cn/showproblem.php?pid=3966

题意:给一棵树,每个结点都有初始的权值,有m个操作,分两种:一是从x 结点到y 结点路上所有的结点权值+z或-z,二是问x结点的权值。

思路:

  树链剖分。

  这是我学树剖的第一题,建议还没接触过的伙伴,第一次学习的时候不要一直纠结理论,直接找一道模板题,然后找一篇AC代码,直接理解,做完一题后,你就会发现理论其实也挺好理解的,树剖也挺好学的。

  树剖中的新概念:重孩子、轻孩子、重链、轻链,后面解释

  fa[x]:x 结点的父结点

  dep[x]:x 结点的深度

  siz[x]:以x 结点为根的子树的结点个数

  son[x]:x 结点的重孩子,即x 所有孩子中siz 最大的那个(相对的,其他的为轻孩子)

  结点x 和他重孩子的连边叫作重边,重边组成的一条链叫作重链。

  top[x]:x 所属的重链的头结点

  树链剖分就是为了合理地安排每个结点在线段树中的位置,这里,属于同一条重链的结点将分配在一起。

  pos[x]:x 结点在线段树中的位置

  xd[x]:和pos相反,表示线段树中x 位置的结点是哪个


  上面的数组定义给出后,有过基础的完全可以自己求出来了(下面AC代码的dfs1、dfs2),至于求出来什么用,再看下面。

  比如现在要求:x 结点到y 结点路上所有结点的权值+1。

  分两种:1. 如果x、y 两结点在同一条重链上,那么他们路上的结点其实就是线段树上连续的一个区间[x, y],那么像普通的线段树区间更新似的便可以解决。2. 如果x、y 不在同一个重链上,找到他们的链头,也就是top[x]、top[y],判断哪个深度大,选择深度大的那个,假设为x,现在我们可以更新区间[top[x],x],然后x指向top[x]的父结点,再次判断x、y是否同一重链。

  单点查询就不说了,和线段树单点查询一样。


  实现弄完了,我们来研究一下为什么他可以快速解决该类问题,从区间更新那里,我们可以看到,如果属于同一条重链,那么接下来的更新操作是线段树操作,这个时间复杂度是logn大家都学过了。也就是说如果有可能慢,那就慢在属于不同的重链,而且慢在必须一直跳(也就是说始终跳不到同一条重链),慢在重链的长度很短(一次只能跳一点)。如果始终没跳到一条重链上,那么跳的次数最多就是树的高度,那么会不会树的高度很大而一次跳很短呢,答案是否定的,因为结点x 的重孩子是其所有孩子结点中siz 最大的那个,如果要跳的y 在重孩子那棵子树,那么边x-son[x]是重边,是可以跳过的,如果要跳的y 在轻孩子z那棵子树,虽然x-z不是重边,是不能跳过去的,但重孩子至少分去了一半的结点,这样层层计算下来,最终时间复杂度也是logn,并不会出现跳的次数过多的情况。

AC代码:

  1 #include<stdio.h>
  2 #include<vector>
  3 #include<algorithm>
  4 using namespace std;
  5 #define N 100010
  6 #define lson rt<<1
  7 #define rson rt<<1|1
  8 int fa[N],dep[N],siz[N];
  9 int son[N];
 10 vector<int> e[N];
 11 void dfs1(int rt,int f,int h)
 12 {
 13   dep[rt]=h; fa[rt]=f; siz[rt]=1;
 14   for(int i=0;i<e[rt].size();i++)
 15   {
 16     int ad=e[rt][i];
 17     if(ad!=f)
 18     {
 19       dfs1(ad,rt,h+1);
 20       siz[rt]+=siz[ad];
 21       if(son[rt]==-1 || siz[ad]>siz[son[rt]])
 22         son[rt]=ad;
 23     }
 24   }
 25 }
 26 int top[N],pos[N],xd[N],po;
 27 void dfs2(int rt,int org)
 28 {
 29   top[rt]=org; pos[rt]=po++;
 30   xd[pos[rt]]=rt;
 31   if(son[rt]==-1) return;
 32   dfs2(son[rt],org);
 33   for(int i=0;i<e[rt].size();i++)
 34   {
 35     int ad=e[rt][i];
 36     if(ad!=son[rt] && ad!=fa[rt])
 37       dfs2(ad,ad);
 38   }
 39 }
 40 struct Node
 41 {
 42   int w,c;
 43   int l,r;
 44   int mid()
 45   {
 46     return (l+r)/2;
 47   }
 48 };
 49 Node v[N<<2];
 50 int num[N];
 51 
 52 void build(int l,int r,int rt)
 53 {
 54   v[rt].c=0;
 55   v[rt].l=l;
 56   v[rt].r=r;
 57   if(l==r)
 58   {
 59     v[rt].w=num[xd[l]];
 60     return;
 61   }
 62   build(l,v[rt].mid(),lson);
 63   build(v[rt].mid()+1,r,rson);
 64   v[rt].w = v[lson].w+v[rson].w;
 65 }
 66 void update(int val,int l,int r,int rt)
 67 {
 68   if(l<=v[rt].l && v[rt].r<=r)
 69   {
 70       v[rt].c+=val;
 71       v[rt].w+=val*(v[rt].r-v[rt].l+1);
 72       return;
 73   }
 74   if(v[rt].c)
 75   {
 76     v[lson].c += v[rt].c;
 77     v[rson].c += v[rt].c;
 78     v[lson].w += (v[lson].r-v[lson].l+1)*v[rt].c;
 79     v[rson].w += (v[rson].r-v[rson].l+1)*v[rt].c;
 80     v[rt].c=0;
 81   }
 82   int mid=v[rt].mid();
 83   if(l<=mid) update(val,l,r,lson);
 84   if(r>mid) update(val,l,r,rson);
 85   v[rt].w = v[lson].w+v[rson].w;
 86 }
 87 void change(int x,int y,int val)
 88 {
 89   while(top[x]!=top[y])
 90   {
 91     if(dep[top[x]]<dep[top[y]]) swap(x,y);
 92     update(val,pos[top[x]],pos[x],1);
 93     x=fa[top[x]];
 94   }
 95   if(dep[x]>dep[y]) swap(x,y);
 96   update(val,pos[x],pos[y],1);
 97 }
 98 int query(int rt,int val)
 99 {
100   if(v[rt].l==v[rt].r) return v[rt].w;
101 
102   if(v[rt].c)
103   {
104     v[lson].c += v[rt].c;
105     v[rson].c += v[rt].c;
106     v[lson].w += (v[lson].r-v[lson].l+1)*v[rt].c;
107     v[rson].w += (v[rson].r-v[rson].l+1)*v[rt].c;
108     v[rt].c=0;
109   }
110   int mid=v[rt].mid();
111   int ret=0;
112   if(val<=mid) ret=query(lson,val);
113   else ret=query(rson,val);
114   v[rt].w = v[lson].w+v[rson].w;
115   return ret;
116 }
117 int main()
118 {
119   int n,m,x,y,z;
120   while(~scanf("%d%*d%d",&n,&m))
121   {
122     po=1;
123     for(int i=1;i<=n;i++)
124     {
125       e[i].clear();
126       son[i]=-1;
127       scanf("%d",&num[i]);
128     }
129 
130     for(int i=1;i<n;i++)
131     {
132       scanf("%d%d",&x,&y);
133       e[x].push_back(y);
134       e[y].push_back(x);
135     }
136     dfs1(1,0,0);
137     dfs2(1,1);
138     build(1,n,1);
139     while(m--)
140     {
141       char s[10];
142       scanf("%s",s);
143       if(s[0]=='Q')
144       {
145         scanf("%d",&x);
146         printf("%d
",query(1,pos[x]));
147       }
148       else
149       {
150         scanf("%d%d%d",&x,&y,&z);
151         if(s[0]=='D') z=-z;
152         change(x,y,z);
153       }
154     }
155   }
156   return 0;
157 }
原文地址:https://www.cnblogs.com/hchlqlz-oj-mrj/p/6040775.html