树上的简单操作——树链剖分

某神犇:树链剖分什么垃圾,能做的LCT都能做,不能做的LCT也能做

前置条件:

线段树,(都会线段树了应该知道什么是树吧)

前言

现在考虑一棵树,每个节点都有一个点权,要求给x到y路径上的点都加上k,这个问题可以用树上差分很简单地在O(m+n)的复杂度内解决。再考虑一个问题,要求查找树上x到y这条路径上的权值和,也可以先求出每个点到根的dis,然后求出x和y的LCA,最后用公式:dis(x,y)=dis(x,root)+dis(y,root)-2*dis(LCA,root)简单地在O(mlogn+n)的时间内解决。那如果我们把着两种操作整合在一起呢?树链剖分就这么诞生出来了。

正篇

树链剖分,顾名思义,就是把一棵树残忍地剖成一条一条链,然后通过链之间的特性来用某些数据结构去维护它们。我们剖树的时候通常会遵循两大准则:重链剖分和实链剖分,本文暂时只讨论重链剖分。
一般来讲树链剖分的码量都很大,所以可以看作是一种模拟

重链

说到重链,我们先谈谈什么是重儿子。对于某一个树上节点u,它的重儿子就是它儿子里面那个size最大的儿子。可以看成,一个节点只能有一个重儿子,而其他的儿子被称作轻儿子。又重儿子组成的链叫做重链,由轻儿子组成的链叫做轻链。

来张图康康,重链都被加粗显示了:

在这张图里面,我们可以看到,1的儿子中3的size更大,所以1->3就被划分成了一条重链,同样,虽然1->2不是重链,但是2->5也可以是重链这样整棵树就被划分成了重链和轻链。

具体程序

首先,我们先声明一些变量:

int size[maxn],dep[maxn],f[maxn],hson[maxn];

size就是子树大小,dep是节点深度,f是节点父亲,hson是节点的重儿子

接着,我们开始写第一个dfs:

void dfs1(int u,int fa,int d){
	size[u]=1;
	f[u]=fa;
	dep[u]=d;
	int maxs=-1;
	for(int i=0;i<gpe[u].size();i++){
		int v=gpe[u][i];
		if(v==fa) continue;
		dfs1(v,u,d+1);
		size[u]+=size[v];
		if(size[v]>maxs) hson[u]=v,maxs=size[v];
	}
}

整个程序还是非常简单的,我们已经处理出来了整棵树的一些基本信息,那么我们现在就要开始把链整合在一起了。

第二个dfs: 这里我们引入一个dfs序的东西。

其实也蛮简单,就是对于某一个节点u,它在dfs的过程中被访问到的顺序,再来张图:

可能有人会问,图是不是错了啊,因为id[2]不应该是2吗,为什么id[3]是2啊?因为我们有一个规定,在第二次dfs的时候优先对重儿子进行搜索。因为我们必须保证任何一条重链上的点的dfs序是连续的,如果我们优先搜索2,那么3的dfs序就是5了,和1的dfs序不连续,也就失去了意义。我们还要对每一条重链进行标识。或者是说,对于某一个点u,我们要给出它所在的重链的顶部节点(轻节点的顶部节点是其本身),这里我们用top[u]来表示顶部节点。再最后,我们把点值也转移到另外一个数组里面。具体细节看程序吧,反正也不是很长:

int id[maxn],wt[maxn],top[maxn],cnt=0;
void dfs2(int u,int tc){
	id[u]=++cnt;
	wt[cnt]=a[u];
	top[u]=tc;
	if(!hson[u]) return;
	dfs2(hson[u],tc);
	for(int i=0;i<gpe[u].size();i++){
		int v=gpe[u][i];
		if(v==f[u]||v==hson[u]) continue;
		dfs2(v,v);
	}
}

id就是dfs序,wt是原点值a的转换,top就是链顶,cnt是用来计算dfs序的。

通过观察上面的图,我们发现,对于某一条链,可以把它拆成重链和轻链的组合,虽然我们没办法维护轻链,因为它们的dfs序并不连续,但是重链的dfs序是连续的。如果提到维护数列的区间和,那么我们肯定会想到 (分块) 线段树。接下来就是套一个线段树的模板了,这里就不多说,直接上代码:

struct node{
	ll sum,tag;
}t[maxn*2];
ll ans=0;
void update(ll pos){
	t[pos].sum=(t[pos<<1].sum+t[pos<<1|1].sum)%MOD;
}
void build(ll l,ll r,ll pos){
	if(l==r){
		t[pos].sum=wt[l];
		return;
	}
	ll mid=(l+r)/2;
	build(l,mid,pos<<1);
	build(mid+1,r,pos<<1|1);
	update(pos);
}
void change(ll pos,ll l,ll r,ll k)
{
    t[pos].tag=(t[pos].tag+k)%MOD;
    t[pos].sum=(t[pos].sum+k*(r-l+1))%MOD;
}
void pushdown(ll l,ll r,ll pos){
	if(!t[pos].tag) return;
	ll mid=(l+r)/2;
	change(pos<<1,l,mid,t[pos].tag);
	change(pos<<1|1,mid+1,r,t[pos].tag);
	t[pos].tag=0;
}
void add(ll tl,ll tr,ll l,ll r,ll v,ll pos){
	if(tl<=l&&tr>=r){
		t[pos].sum+=v*(r-l+1);
		t[pos].tag+=v;
		return;
	}
	if(r<tl||l>tr){
		return;
	}
	ll mid=(l+r)/2;
	pushdown(l,r,pos);
	add(tl,tr,l,mid,v,pos<<1);
	add(tl,tr,mid+1,r,v,pos<<1|1);
	update(pos);
}
void query(ll tl,ll tr,ll l,ll r,ll pos){
	if(tl<=l&&tr>=r){
		ans+=t[pos].sum;
		ans%=MOD;
		return;
	}
	if(r<tl||l>tr){
		return;
	}
	ll mid=(l+r)/2;
	pushdown(l,r,pos);
	query(tl,tr,l,mid,pos<<1);
	query(tl,tr,mid+1,r,pos<<1|1);
	return;
}

那么,对于某一个询问x和y之间的点值之和的询问,我们可以把它分成两部分:

  1. x到top[x]的重链区间和
  2. top[x]到top[y]的轻链和
  3. top[y]到y的重链和

其实现实中情况比这个复杂,打个比方,有一种很奇怪的食物,两块面包中间由一根面条连接(?)我们可以一口吃掉一块面包O(logn),但是吃面条要用到O(n),那么我们最简单的想法就是从这种奇怪的食物的某一个节点id[x]吃到id[top[x]]来吃掉一个面包(重链),然后从id[top[x]]到f[top[x]]去吃掉一根面条(轻链),就这么下去直到吃掉最后一块面包。程序如下:

int c_ask(int x,int y){
	int ret=0;
	ans=0;
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		query(id[top[x]],id[x],1,n,1);
		ret=(ret+ans)%MOD;
		ans=0;
		x=f[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	query(id[x],id[y],1,n,1);
	ret=(ret+ans)%MOD;
	ans=0;
	return ret;
}

(由于我线段树板子写的太恶心所以要一遍一遍地重置ans值,但是我懒得改了)

链上修改也很简单,照着套就完事了:

void c_add(int x,int y,int val){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		add(id[top[x]],id[x],1,n,val,1);
		x=f[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	add(id[x],id[y],1,n,val,1);
}

再就是子树修改:

add(id[x],id[x]+size[x]-1,1,n,v%MOD,1);

因为实际上一棵子树的dfs序也是连续的,可以自己手动模拟一下,所以就是简单地加上size[x]-1就好了

子树查询:

query(id[x],id[x]+size[x]-1,1,n,1);

一道模板题:https://www.luogu.com.cn/problem/P3384

AC代码:

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=2*1e6+10;
vector<int> gpe[maxn];
int a[maxn],n,m,r,MOD;
int size[maxn],dep[maxn],f[maxn],hson[maxn];
void dfs1(int u,int fa,int d){
	size[u]=1;
	f[u]=fa;
	dep[u]=d;
	int maxs=-1;
	for(int i=0;i<gpe[u].size();i++){
		int v=gpe[u][i];
		if(v==fa) continue;
		dfs1(v,u,d+1);
		size[u]+=size[v];
		if(size[v]>maxs) hson[u]=v,maxs=size[v];
	}
}
int id[maxn],wt[maxn],top[maxn],cnt=0;
void dfs2(int u,int tc){
	id[u]=++cnt;
	wt[cnt]=a[u];
	top[u]=tc;
	if(!hson[u]) return;
	dfs2(hson[u],tc);
	for(int i=0;i<gpe[u].size();i++){
		int v=gpe[u][i];
		if(v==f[u]||v==hson[u]) continue;
		dfs2(v,v);
	}
}
struct node{
	ll sum,tag;
}t[maxn*2];
ll ans=0;
void update(ll pos){
	t[pos].sum=(t[pos<<1].sum+t[pos<<1|1].sum)%MOD;
}
void build(ll l,ll r,ll pos){
	if(l==r){
		t[pos].sum=wt[l];
		return;
	}
	ll mid=(l+r)/2;
	build(l,mid,pos<<1);
	build(mid+1,r,pos<<1|1);
	update(pos);
}
void change(ll pos,ll l,ll r,ll k)
{
    t[pos].tag=(t[pos].tag+k)%MOD;
    t[pos].sum=(t[pos].sum+k*(r-l+1))%MOD;
}
void pushdown(ll l,ll r,ll pos){
	if(!t[pos].tag) return;
	ll mid=(l+r)/2;
	change(pos<<1,l,mid,t[pos].tag);
	change(pos<<1|1,mid+1,r,t[pos].tag);
	t[pos].tag=0;
}
void add(ll tl,ll tr,ll l,ll r,ll v,ll pos){
	if(tl<=l&&tr>=r){
		t[pos].sum+=v*(r-l+1);
		t[pos].tag+=v;
		return;
	}
	if(r<tl||l>tr){
		return;
	}
	ll mid=(l+r)/2;
	pushdown(l,r,pos);
	add(tl,tr,l,mid,v,pos<<1);
	add(tl,tr,mid+1,r,v,pos<<1|1);
	update(pos);
}
void query(ll tl,ll tr,ll l,ll r,ll pos){
	if(tl<=l&&tr>=r){
		ans+=t[pos].sum;
		ans%=MOD;
		return;
	}
	if(r<tl||l>tr){
		return;
	}
	ll mid=(l+r)/2;
	pushdown(l,r,pos);
	query(tl,tr,l,mid,pos<<1);
	query(tl,tr,mid+1,r,pos<<1|1);
	return;
}
int c_ask(int x,int y){
	int ret=0;
	ans=0;
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		query(id[top[x]],id[x],1,n,1);
		ret=(ret+ans)%MOD;
		ans=0;
		x=f[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	query(id[x],id[y],1,n,1);
	ret=(ret+ans)%MOD;
	ans=0;
	return ret;
}
void c_add(int x,int y,int val){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		add(id[top[x]],id[x],1,n,val,1);
		x=f[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	add(id[x],id[y],1,n,val,1);
}
int main(void){
	scanf("%d %d %d %d",&n,&m,&r,&MOD);
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
	}
	for(int i=1;i<=n-1;i++){
		int u,v;
		scanf("%d %d",&u,&v);
		gpe[u].push_back(v);
		gpe[v].push_back(u);
	}
	dfs1(r,r,1);
	dfs2(r,r);
	build(1,n,1);
	while(m--){
		int opt,x,y,v;
		scanf("%d",&opt);
		if(opt==1){
			scanf("%d %d %d",&x,&y,&v);
			c_add(x,y,v);
		}else if(opt==2){
			scanf("%d %d",&x,&y);
			printf("%d
",c_ask(x,y));
		}else if(opt==3){
			scanf("%d %d",&x,&v);
			add(id[x],id[x]+size[x]-1,1,n,v%MOD,1);
		}else{
			scanf("%d",&x);
			query(id[x],id[x]+size[x]-1,1,n,1);
			printf("%d
",ans);
			ans=0;
		}
	}
}

树链剖分时间复杂度

假设(u,v)是一条轻边,那么size(v)<size(u)/2,并且从根节点到任意节点x之间的路径上轻重链的个数<logn

所以,树链剖分的时间复杂度是O(nlog^2n)

原文地址:https://www.cnblogs.com/jrdxy/p/12350133.html