洛谷 P3384 [模板] 树链剖分

传送门

树链剖分

本质上,树链剖分是一种将树肢解成平摊开来,再使用线段树对其进行维护的神奇算法。

我们需要通过两次 (dfs) ,预处理一些我们需要的东西,这里是第一次:

  1. 树上每个节点的父亲,这个不必多说,方便后续找 (LCA) 时的上跳过程;

  2. 树上每个节点的深度,这个也不必多说,方便后续决定对哪个节点进行操作;

  3. 每个节点的子树大小,同时标记每个节点的重儿子

何为重儿子? 顾名思义,对于一个节点,他的所有儿子中子树大小最大的那个儿子就是重儿子。其他的儿子我们就称其为轻儿子。

对于一条树边,连接重儿子的边我们叫他重边,连接轻儿子的边我们叫他轻边(又称轻链),重边连成的链我们叫他重链。

这里是第二次:

  1. 每个节点的 (dfs) 序,将树肢解后节点就按这个顺序平摊,值得注意的是每次要先遍历该节点的重儿子,回溯之后再遍历其他出点;

  2. (dfs) 编号所对应的节点编号,方便逆向访问到这个节点;

  3. 对于每各节点,我们记录顺着该节点所在的链向上走所能到达的最上方的节点(链顶);

形象地说,重链是我们修建的高速公路,在其上我们可以直接快速达到一条链的最顶部,而轻链则不同,在其上只能一步一步向上跳跃(也可以说,轻链的链顶就是当前点的父亲,轻链是长度为1
的重链)。

实际上,所谓轻链重链并没有什么特殊意义,其本质是将一棵树剖成几条链的一个较为方便的策略。

我们容易知道,树上任意两个点的最短路径都可以通过上文提到的轻重链来达到。对于每个链顶深度较大的点,我们让他跳跃到他链顶的父亲的位置,并对他途径的链进行区间操作

这个时候,你就会发现将重链 (dfs) 序连续的妙处所在了:

每一个重链都处在一段连续的区间上,我们可以统一对其进行处理。

而对于轻链,由于轻链长度只有一,所以不在乎所处区间是否连续。

重复上述操作,直到两个点的链顶相同为止。通过这样的步骤,我们将树上两点间的最短路径拆分成了数个区间,然后转化为了区间操作。

那么对子树的操作如何实现呢?

不难发现,每一个节点的子树在区间上都连续,理由很简单,只有当前子树递归完毕之后才会访问另一颗子树。同时,我们也可以得到这个区间的左右端点,若设当前节点的 (dfs) 序为 (x)

(left node : x , right node : x + size[x] - 1)

然后,对这个区间进行区间操作即可。

怎么样,是不是觉得非常简单?

当然,在代码实现的过程中,依旧有一些小小的细节值得注意:

  1. 线段树,不用我说,写错了就拖出去枪毙十分钟(笔者至少被枪毙了半小时);

  2. 对于子树操作(等价于区间操作),参数是点的 (dfs) 序,而对于最短路操作,参数则是点原本的序号,务必要搞清楚 (dfs) 序与原本点的编号的异同;

  3. 先两次 (dfs) ,执行完毕后再建树(这问题太蠢我不忍直视);

  4. (dfs1) 中注意不要又跑回父亲节点了,(dfs2) 中注意到了叶子节点要及时终止函数。

一时间就想到这么多,希望能对大家有所帮助。

以下提供模板代码,为了锻炼读者的代码阅读能力(我懒),没有加任何注释,愿各位食用愉快(逃)

模板代码

#include<iostream>
#include<cctype>
#include<cstdio>
using namespace std;
typedef long long ll; 
const int maxn = 50005;
ll read(){
	ll re = 0,ch = getchar();
	while(!isdigit(ch)) ch = getchar();
	while(isdigit(ch)) re = (re<<1) + (re<<3) + ch - '0',ch = getchar();
	return re;
}
int n,m,r,p;
struct edge{
	int v,nxt;
}e[maxn<<1];
int h[maxn],cnt;
void addedge(int u,int v){
	e[++cnt].v = v;
	e[cnt].nxt = h[u];
	h[u] = cnt;
}
int fa[maxn],sz[maxn],son[maxn],dfn[maxn],rev[maxn],val[maxn],dis[maxn],top[maxn];
void dfs1(int u,int f){
	dis[u] = dis[f] + 1;
	fa[u] = f;
	sz[u] = 1;
	for(int i = h[u];i;i = e[i].nxt){
		if(e[i].v != f){
			dfs1(e[i].v,u);
			sz[u] += sz[e[i].v];
			if(sz[e[i].v] > sz[son[u]]) son[u] = e[i].v;
		}
	}
}
void dfs2(int u,int topf){
	dfn[u] = ++cnt;
	rev[cnt] = u;
	top[u] = topf;
	if(!son[u]) return;
	dfs2(son[u],topf);
	for(int i = h[u];i;i = e[i].nxt)
		if(!dfn[e[i].v]) dfs2(e[i].v,e[i].v);
}
struct node{
	int l,r;
	ll sum,add;
	#define l(x) t[x].l 
	#define r(x) t[x].r
	#define sum(x) t[x].sum
	#define add(x) t[x].add
	#define mid(x) (t[x].r + t[x].l >> 1)
}t[maxn<<2];
void pushdown(int x){
	if(add(x)){
		sum(x<<1) += add(x) * (mid(x) - l(x) + 1);
		sum(x<<1|1) += add(x) * (r(x) - mid(x));
		add(x<<1) += add(x);
		add(x<<1|1) += add(x);
		add(x) = 0;
	}
}
void pushup(int x){
	sum(x) = (sum(x<<1) % p + sum(x<<1|1) % p) % p;
}
void build(int x,int l,int r){
	l(x) = l;
	r(x) = r;
	if(l == r){
		sum(x) = val[rev[l]];
		return;
	}
	build(x<<1,l,mid(x));
	build(x<<1|1,mid(x) + 1,r);
	pushup(x);
}
void modify(int x,int l,int r,int v){
	if(l <= l(x) && r >= r(x)){
		sum(x) += (r(x) - l(x) + 1) * v;
		add(x) += v;
		sum(x) %= p;
		add(x) %= p;
		return;
	}
	pushdown(x);
	if(l <= mid(x)) modify(x<<1,l,r,v);
	if(r > mid(x)) modify(x<<1|1,l,r,v);
	pushup(x); 
}
ll quiry(int x,int l,int r){
	ll ans = 0;
	if(l <= l(x) && r >= r(x))
		return sum(x);	
	pushdown(x);
	if(l <= mid(x)) ans += quiry(x<<1,l,r);
	if(r > mid(x)) ans += quiry(x<<1|1,l,r);
	return ans % p;
}
void tadd(int x,int y,int v){
	while(top[x] != top[y]){
		if(dis[top[x]] < dis[top[y]]) swap(x,y);
		modify(1,dfn[top[x]],dfn[x],v);
		x = fa[top[x]];
	}
	if(dis[x] > dis[y]) swap(x,y);
	modify(1,dfn[x],dfn[y],v);
}
ll task(int x,int y){
	ll ans = 0;
	while(top[x] != top[y]){
		if(dis[top[x]] < dis[top[y]]) swap(x,y);
		ans += quiry(1,dfn[top[x]],dfn[x]);
		ans %= p;
		x = fa[top[x]];
	}
	if(dis[x] > dis[y]) swap(x,y);
	ans += quiry(1,dfn[x],dfn[y]);
	return ans % p;
}
int main(){
	n = read(),m = read(),r = read(),p = read();
	for(int i = 1;i <= n;i++) val[i] = read();
	for(int i = 1;i < n;i++){
		int u = read(),v = read();
		addedge(u,v);
		addedge(v,u);
	}
	cnt = 0; 
	dfs1(r,0);
	dfs2(r,0);
	build(1,1,n);
	for(int i = 1;i <= m;i++){
		int op = read(),x,y,z;
		if(op == 1){
			x = read(),y = read(),z = read();
			tadd(x,y,z);		
		}
		if(op == 2){
			x = read(),y = read(); 
			printf("%lld
",task(x,y));
		}
		if(op == 3){
			x = read(),z = read();
			modify(1,dfn[x],dfn[x] + sz[x] - 1,z);
		}
		if(op == 4){
			x = read();
			printf("%lld
",quiry(1,dfn[x],dfn[x] + sz[x] - 1));
		}
	}
	return 0;
}
原文地址:https://www.cnblogs.com/mysterious-garden/p/9859599.html