【模板】树链剖分

[ZJOI2008]树的统计

洛谷传送门

第一遍树链剖分,打的很难受。

其中拉闸了,检查真是费劲。

树链剖分是什么?

树链剖分,计算机术语,指一种对树进行划分的算法,它先通过轻重边剖分将树分为多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、SBT、SPLAY、线段树等)来维护每一条链。

树链剖分可以支持链上求和,链上求最值,链上修改等线段树的操作。

但若断开一条边或者连接两个点,保证两个点连接后依然是棵树。这样树链剖分就虚了,因为线段树不支持这种操作,就需要把线段树换成splay,于是LCT = 树剖 + splay。

将树中的边分为:轻边和重边 ž定义size(X)为以X为根的子树的节点个数。 ž令V为U的儿子节点中size值最大的节点,那么边(U,V)被称为重边,树中重边之外的边被称为轻边。
性质:ž轻边(U,V),size(V)<=size(U)/2。 ž从根到某一点的路径上,不超过O(logN)条轻边,不超过O(logN)条重路径。

说明:

重孩子:儿子节点所有孩子中size最大的

轻孩子:儿子节点中除了重儿子的节点

重边:连接重儿子的边

轻边:连接轻儿子的边

重链:重边连成的链

轻链:轻边连成的链

a[i] 表示节点 i 权值

f[i] 表示节点 i 的父亲在原树中的位置

son[i] 表示节点 i 的重儿子在原树中的位置

top[i] 表示节点 i 所在链的顶端节点在原树中的位置,就是深度最小的

size[i] 表示以 i 为根的子树节点个数

tid[i] 表示树中节点 i 剖分后的新编号

rank[i] 表示剖分后的节点 i 在原树中的位置

deep[i] 表示节点 i 深度,根节点深度为 1 

实现方法:

第一遍dfs可以预处理出size,deep,f,son数组

第二遍dfs可以预处理出top,tid,rank数组,通过优先搜索重边,然后搜索轻边

树链剖分目的是把树上的边剖分成一个链,就是一个线段,标号是连续的。

为什么要先搜索重边呢?

可以看出,这样搜可以使得重链上的点的dfs序是连续的,可以用线段树来维护

如何查询呢?

判断两点是否属于同一条重链,如果属于,就直接修改,因为他们是连续的,如果不属于,就从深度大点开始不停地找他父亲跳轻链,其中深度是不停地在变的,也就是说,两个点可能会轮着跳,直到属于同一个重链。现在看来,轻边实际上是连接重链的东西。

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <iostream>
  4 #define rt 1, 1, n
  5 #define ls o << 1, l, m
  6 #define rs o << 1 | 1, m + 1, r
  7 
  8 using namespace std;
  9 
 10 const int maxn = 300001;
 11 const int INF = 99999999;
 12 int n, m, q, cnt, tim;
 13 int a[maxn], head[maxn], to[maxn << 2], next[maxn << 2], deep[maxn], size[maxn];
 14 int  son[maxn], top[maxn], f[maxn], tid[maxn], rank[maxn], sumv[maxn], maxv[maxn];
 15 //a节点权值, deep节点深度, size以x为根的子树节点个数, son重儿子, top当前节点所在链的顶端节点
 16 //f当前节点父亲, tid保存树中每个节点剖分后的新编号,  rank保存剖分后的节点在线段树中的位置 
 17 
 18 void add(int x, int y)
 19 {
 20     to[cnt] = y;
 21     next[cnt] = head[x];
 22     head[x] = cnt++;
 23 }
 24 
 25 void dfs1(int u, int father)//记录所有重边 
 26 {
 27     int i, v;
 28     f[u] = father;
 29     size[u] = 1;
 30     deep[u] = deep[father] + 1;
 31     for(i = head[u]; i != -1; i = next[i])
 32     {
 33         v = to[i];
 34         if(v == father) continue;
 35         dfs1(v, u);
 36         size[u] += size[v];
 37         if(son[u] == -1 || size[v] > size[son[u]]) son[u] = v;
 38     }
 39 }
 40 
 41 void dfs2(int u, int tp)
 42 {
 43     int i, v;
 44     top[u] = tp;
 45     tid[u] = ++tim;
 46     rank[tim] = u;
 47     if(son[u] == -1) return;
 48     dfs2(son[u], tp);//重边 
 49     for(i = head[u]; i != -1; i = next[i])
 50     {
 51         v = to[i];
 52         if(v != son[u] && v != f[u]) dfs2(v, v);//轻边 
 53     }
 54 }
 55 
 56 void pushup(int o)
 57 {
 58     sumv[o] = sumv[o << 1] + sumv[o << 1 | 1];
 59     maxv[o] = max(maxv[o << 1], maxv[o << 1 | 1]);
 60 }
 61 
 62 void updata(int o, int l, int r, int d, int x)
 63 {
 64     int m = (l + r) >> 1;
 65     if(l == r)
 66     {
 67         sumv[o] = maxv[o] = x;
 68         return;
 69     }
 70     if(d <= m) updata(ls, d, x);
 71     else updata(rs, d, x);
 72     pushup(o);
 73 }
 74 
 75 void build(int o, int l, int r)
 76 {
 77     int m = (l + r) >> 1;
 78     if(l == r)
 79     {
 80         sumv[o] = maxv[o] = a[rank[l]];
 81         return;
 82     }
 83     build(ls);
 84     build(rs);
 85     pushup(o);
 86 }
 87 
 88 int querymax(int o, int l, int r, int ql, int qr)
 89 {
 90     int m = (l + r) >> 1, ans = -INF;
 91     if(ql <= l && r <= qr) return maxv[o];
 92     if(ql <= m) ans = max(ans, querymax(ls, ql, qr));
 93     if(m < qr) ans = max(ans, querymax(rs, ql, qr));
 94     pushup(o);
 95     return ans;
 96 }
 97 
 98 int qmax(int u, int v)
 99 {
100     int ans = -INF;
101     while(top[u] != top[v])//判断是否在一条重链上 
102     {
103         if(deep[top[u]] < deep[top[v]]) swap(u, v);//深度不同,先处理深度大的 
104         ans = max(ans, querymax(rt, tid[top[u]], tid[u]));
105         u = f[top[u]];
106     }
107     if(deep[u] < deep[v]) swap(u, v);//在同一条重链上了 
108     ans = max(ans, querymax(rt, tid[v], tid[u]));
109     return ans;
110 }
111 
112 int querysum(int o, int l, int r, int ql, int qr)
113 {
114     int m = (l + r) >> 1, ans = 0;
115     if(ql <= l && r <= qr) return sumv[o];
116     if(ql <= m) ans += querysum(ls, ql, qr);
117     if(m < qr) ans += querysum(rs, ql, qr);
118     pushup(o);
119     return ans;
120 }
121 
122 int qsum(int u, int v)
123 {
124     int ans = 0;
125     while(top[u] != top[v])//判断是否在一条重链上
126     {
127         if(deep[top[u]] < deep[top[v]]) swap(u, v);//深度不同,先处理深度大的 
128         ans += querysum(rt, tid[top[u]], tid[u]);
129         u = f[top[u]];
130     }
131     if(deep[u] < deep[v]) swap(u, v);//在同一条重链上了 
132     ans += querysum(rt, tid[v], tid[u]);
133     return ans;
134 }
135 
136 int main()
137 {
138     int i, j, x, y;
139     char s[11];
140     memset(head, -1, sizeof(head));
141     memset(son, -1, sizeof(son));
142     scanf("%d", &n);
143     for(i = 1; i < n; i++)
144     {
145         scanf("%d %d", &x, &y);
146         add(x, y);
147         add(y, x);
148     }
149     for(i = 1; i <= n; i++) scanf("%d", &a[i]);
150     dfs1(1, 1);//根节点和他的父亲
151     dfs2(1, 1);//根节点和链头结点 
152     build(rt);
153     scanf("%d", &q);
154     for(i = 1; i <= q; i++)
155     {
156         scanf("%s %d %d", s, &x, &y);
157         if(s[1] == 'H') updata(rt, tid[x], y);//把位置为x的点修改为y 
158         if(s[1] == 'M') printf("%d
", qmax(x, y));
159         if(s[1] == 'S') printf("%d
", qsum(x, y));
160     }
161     return 0;
162 }
View Code

洛谷模板题

检查了n遍代码,检查了n遍函数,都没有检查出错误来。

最后偶然发现只是因为调用错了函数。。。

这道题要使一个子树的值都加x,且统计子树所有节点的值的和。

可以发现,一个子树上的点在线段树中的编号是连续的,所以可以对区间 ( tim[x], tim[x] + size[x] - 1 ) 进行操作。

——代码

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <iostream>
  4 #define LL long long
  5 #define rt 1, 1, n
  6 #define ls o << 1, l, m
  7 #define rs o << 1 | 1, m + 1, r
  8 
  9 using namespace std;
 10 
 11 const int maxn = 100001;
 12 int n, m, s, cnt, tim;
 13 int head[maxn], next[maxn << 2], to[maxn << 2], deep[maxn], f[maxn], size[maxn], son[maxn], top[maxn], rank[maxn], tid[maxn];
 14 LL p, a[maxn], sumv[maxn << 2], addv[maxn << 2];
 15 
 16 void add(int x, int y)
 17 {
 18     to[cnt] = y;
 19     next[cnt] = head[x];
 20     head[x] = cnt++;
 21 }
 22 
 23 void dfs1(int u, int father)
 24 {
 25     int i, v;
 26     size[u] = 1;
 27     f[u] = father;
 28     deep[u] = deep[father] + 1;
 29     for(i = head[u]; i != -1; i = next[i])
 30     {
 31         v = to[i];
 32         if(v == father) continue;
 33         dfs1(v, u);
 34         size[u] += size[v];
 35         if(son[u] == -1 || size[v] > size[son[u]]) son[u] = v;
 36     }
 37 }
 38 
 39 void dfs2(int u, int tp)
 40 {
 41     int i, v;
 42     top[u] = tp;
 43     tid[u] = ++tim;
 44     rank[tim] = u;
 45     if(son[u] == -1) return;
 46     dfs2(son[u], tp);
 47     for(i = head[u]; i != -1; i = next[i])
 48     {
 49         v = to[i];
 50         if(v != son[u] && v != f[u]) dfs2(v, v);
 51     }
 52 }
 53 
 54 void pushup(int o)
 55 {
 56     sumv[o] = (sumv[o << 1] + sumv[o << 1 | 1]) % p;
 57 }
 58 
 59 void pushdown(int o, int len)
 60 {
 61     addv[o << 1] = (addv[o << 1] + addv[o]) % p;
 62     addv[o << 1 | 1] = (addv[o << 1 | 1] + addv[o]) % p;
 63     sumv[o << 1] = (sumv[o << 1] + addv[o] * (len - (len >> 1))) % p;
 64     sumv[o << 1 | 1] = (sumv[o << 1 | 1] + addv[o] * (len >> 1)) % p;
 65     addv[o] = 0;
 66 }
 67 
 68 void build(int o, int l, int r)
 69 {
 70     if(l == r)
 71     {
 72         sumv[o] = a[rank[l]] % p;
 73         return;
 74     }
 75     int m = (l + r) >> 1;
 76     build(ls);
 77     build(rs);
 78     pushup(o);
 79 }
 80 
 81 void updata(int o, int l, int r, int ql, int qr, LL d)
 82 {
 83     if(ql <= l && r <= qr)
 84     {
 85         addv[o] = (addv[o] + d) % p;
 86         sumv[o] = (sumv[o] + d * (r - l + 1)) % p;
 87         return;
 88     }
 89     if(l > qr || r < ql) return;
 90     if(addv[o]) pushdown(o, r - l + 1);
 91     int m = (l + r) >> 1;
 92     updata(ls, ql, qr, d);
 93     updata(rs, ql, qr, d);
 94     pushup(o);
 95 }
 96 
 97 void qdata(int u, int v, LL d)
 98 {
 99     while(top[u] != top[v])
100     {
101         if(deep[top[u]] < deep[top[v]]) swap(u, v);
102         updata(rt, tid[top[u]], tid[u], d);
103         u = f[top[u]];
104     }
105     if(deep[u] > deep[v]) swap(u, v);
106     updata(rt, tid[u], tid[v], d);
107 }
108 
109 LL querysum(int o, int l, int r, int ql, int qr)
110 {
111     if(ql <= l && r <= qr) return sumv[o];
112     if(l > qr || r < ql) return 0;
113     if(addv[o]) pushdown(o, r - l + 1);
114     int m = (l + r) >> 1;
115     return (querysum(ls, ql, qr) + querysum(rs, ql, qr)) % p;
116 }
117 
118 LL qsum(int u, int v)
119 {
120     LL ans = 0;
121     while(top[u] != top[v])
122     {
123         if(deep[top[u]] < deep[top[v]]) swap(u, v);
124         ans = (ans + querysum(rt, tid[top[u]], tid[u])) % p;
125         u = f[top[u]];
126     }
127     if(deep[u] > deep[v]) swap(u, v);
128     ans = (ans + querysum(rt, tid[u], tid[v])) % p;
129     return ans;
130 }
131 
132 int main()
133 {
134     int i, j, c, x, y;
135     LL z;
136     scanf("%d %d %d %lld", &n, &m, &s, &p);
137     for(i = 1; i <= n; i++) scanf("%lld", &a[i]);
138     memset(head, -1, sizeof(head));
139     memset(son, -1, sizeof(son));
140     for(i = 1; i < n; i++)
141     {
142         scanf("%d %d", &x, &y);
143         add(x, y);
144         add(y, x);
145     }
146     dfs1(s, s);//根节点和他的父亲
147     dfs2(s, s);//根节点和链头结点 
148     build(rt);
149     for(i = 1; i <= m; i++)
150     {
151         scanf("%d", &c);
152         if(c == 1)
153         {
154             scanf("%d %d %lld", &x, &y, &z);
155             qdata(x, y, z % p);
156         }
157         else if(c == 2)
158         {
159             scanf("%d %d", &x, &y);
160             printf("%lld
", qsum(x, y));
161         }
162         else if(c == 3)
163         {
164             scanf("%d %lld", &x, &z);
165             updata(rt, tid[x], tid[x] + size[x] - 1, z % p);
166         }
167         else
168         {
169             scanf("%d", &x);
170             printf("%lld
", querysum(rt, tid[x], tid[x] + size[x] - 1));
171         }
172     }
173     return 0;
174 }
View Code
原文地址:https://www.cnblogs.com/zhenghaotian/p/6705918.html