【模板】树链剖分

洛咕

已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1:格式:1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2:格式:2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3:格式:3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4:格式:4 x 表示求以x为根节点的子树内所有节点值之和

分析:学会树链剖分之前,一定要学会线段树(包括延迟标记,标记下传).因为个人认为树链剖分并不难理解,码量也不大.而之所以树链剖分的题往往至少一百行,都是线段树的锅.树链剖分之于线段树,就像分块之于莫队一样,只是个前置技能.

树链剖分,顾名思义就是把树剖分成很多条链.根据剖分方式,分有重链剖分,长链剖分等等.而我们平常的树链剖分一般就是指的重链剖分(因为一般只用重链剖分).

树链剖分的核心就是如何恰当地将树剖分成若干条链.剖分完之后,把一条条链当做一个个区间,就可以进行恶心的线段树操作了.

首先我们要明确"重边"和"轻边",定义size[u]表示以u为根的子树的节点个数.对于一个节点u,若v是u的儿子节点中size最大的节点,则(u,v)是重边.(注意:如果有两个节点v1,v2同时满足条件,则任意选一个点就行了),然后u到其它儿子的路径就全都是轻边.全由重边构成的路径叫做重路径.

明确几个数组的含义:

fa[x]    //x在树上的父亲节点
deep[x]  //x在树中的深度
size[x]  //以x为根的子树的节点个数
son[x]   //x的重儿子
top[x]   //x所在重路径的顶部节点(深度最小的节点)
seg[x]   //x在线段树中的位置(下标)
rev[x]   //线段树中第x个位置对应的树中节点编号
//最后两个数组是为了便于线段树和树的节点转换,有点绕.

dfs1函数可以计算前4个值

void dfs1(int u,int father){
    size[u]=1;deep[u]=deep[father]+1;fa[u]=father;
    for(int i=head[u];i;i=nxt[i]){
		int v=to[i];
		if(v==father)continue;
		dfs1(v,u);
		size[u]+=size[v];
		if(size[son[u]]<size[v])son[u]=v;//更新重儿子
    }
}

dfs2函数计算后3个值

void dfs2(int u,int father){
    if(son[u]){//优先处理重儿子
		top[son[u]]=top[u];
		seg[son[u]]=++seg[0];
		rev[seg[0]]=son[u];
		dfs2(son[u],u);
    }
    for(int i=head[u];i;i=nxt[i]){
		int v=to[i];
		if(!top[v]){
	    	top[v]=v;
	    	seg[v]=++seg[0];
	    	rev[seg[0]]=v; 
	    	dfs2(v,u);
		}
    }
}

至此,树链剖分就大功告成了,接下来就是线段树上场了.

还有一处代码想特别讲一下,自己学习的时候卡了很久才理解.有点类似于树上倍增LCA,那里是(2^k)一步往上跳,而这里是借助重链往上跳,当我们确定两个节点不在同一条重链上时,就可以直接让深度较大的节点x跳到它所在的重路径的顶端.因为我们dfs2函数中每次优先处理重儿子,所以保证了重路径对应在线段树上一定会是一段连续的区间,这样就可以很方便地处理了.

void ask1(int x,int y){
    int ans=0;
    while(top[x]!=top[y]){//如果不在同一条重路径上
		if(deep[top[x]]<deep[top[y]])swap(x,y);
//保证x是深度较大的点
		ans=(ans+query(1,1,seg[0],seg[top[x]],seg[x]))%mod;
//seg[top[x]]到seg[x]在线段树上一定连续
		x=fa[top[x]];
    }
    if(deep[x]>deep[y])swap(x,y);
    ans=(ans+query(1,1,seg[0],seg[x],seg[y]))%mod;
//此时x,y在同一条重路径上,我们保证x的深度较浅
//则seg[x]到seg[y]在线段树上也一定连续
    printf("%d
",ans%mod);
}

over....

#include<bits/stdc++.h>
using namespace std;
inline int read(){
    int s=0,w=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){s=s*10+ch-'0';ch=getchar();}
    return s*w;
}
int n,m,root,mod;
int val[100005],sum[400005],add[400005];
int size[100005],deep[100005],fa[100005],son[100005];
int rev[400005],seg[100005],top[100005];
int tot,head[100005],nxt[200005],to[200005];
void add_edge(int a,int b){
    nxt[++tot]=head[a];head[a]=tot;to[tot]=b;
    nxt[++tot]=head[b];head[b]=tot;to[tot]=a;
}
void dfs1(int u,int father){
    size[u]=1;deep[u]=deep[father]+1;fa[u]=father;
    for(int i=head[u];i;i=nxt[i]){
		int v=to[i];
		if(v==father)continue;
		dfs1(v,u);
		size[u]+=size[v];
		if(size[son[u]]<size[v])son[u]=v;
    }
}
void dfs2(int u,int father){
    if(son[u]){
		top[son[u]]=top[u];
		seg[son[u]]=++seg[0];
		rev[seg[0]]=son[u];
		dfs2(son[u],u);
    }
    for(int i=head[u];i;i=nxt[i]){
		int v=to[i];
		if(!top[v]){
	    	top[v]=v;
	    	seg[v]=++seg[0];
	    	rev[seg[0]]=v; 
	    	dfs2(v,u);
		}
    }
}
void build(int k,int l,int r){
    if(l==r){sum[k]=val[rev[l]]%mod;return;}
    int mid=(l+r)>>1;
    build(k<<1,l,mid);build((k<<1)+1,mid+1,r);
    sum[k]=(sum[k<<1]+sum[(k<<1)+1])%mod;
}
void Add(int k,int l,int r,int val){
    add[k]=(add[k]+val)%mod;
    sum[k]=sum[k]%mod+((r-l+1)*val)%mod;
    return;
}
void pushdown(int k,int l,int r,int mid){
    if(add[k]==0)return;
    Add(k<<1,l,mid,add[k]);
    Add((k<<1)+1,mid+1,r,add[k]);
    add[k]=0;
}
void change(int k,int l,int r,int L,int R,int val){
    if(l>=L&&r<=R)return Add(k,l,r,val);
    int mid=(l+r)>>1;
    pushdown(k,l,r,mid);
    if(L<=mid)change(k<<1,l,mid,L,R,val);
    if(R>mid)change((k<<1)|1,mid+1,r,L,R,val);
    sum[k]=(sum[k<<1]+sum[(k<<1)|1])%mod;
}
void change1(int x,int y,int val){
    while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]])swap(x,y);
		change(1,1,seg[0],seg[top[x]],seg[x],val);
		x=fa[top[x]];
    }
    if(deep[x]>deep[y])swap(x,y);
    change(1,1,seg[0],seg[x],seg[y],val);
}
int query(int k,int l,int r,int L,int R){
    if(l>=L&&r<=R)return sum[k];
    int mid=(l+r)>>1;
    pushdown(k,l,r,mid);
    int res=0;
    if(L<=mid)res=(res+query(k<<1,l,mid,L,R))%mod;
    if(R>mid)res=(res+query((k<<1)|1,mid+1,r,L,R))%mod;
    return res%mod;
}
void ask1(int x,int y){
    int ans=0;
    while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]])swap(x,y);
		ans=(ans+query(1,1,seg[0],seg[top[x]],seg[x]))%mod;
		x=fa[top[x]];
    }
    if(deep[x]>deep[y])swap(x,y);
    ans=(ans+query(1,1,seg[0],seg[x],seg[y]))%mod;
    printf("%d
",ans%mod);
}
void change2(int x,int val){change(1,1,seg[0],seg[x],seg[x]+size[x]-1,val);}
void ask2(int x){
    int ans=0;
    ans=query(1,1,seg[0],seg[x],seg[x]+size[x]-1)%mod;
    printf("%d
",ans%mod);
}
int main(){
    n=read();m=read();root=read();mod=read();
    for(int i=1;i<=n;i++)val[i]=read()%mod;
    for(int i=1;i<n;i++)add_edge(read(),read());
    rev[1]=root;seg[0]=1;seg[root]=1;top[root]=root;
//seg[0]统计线段树上节点个数,我只是省了一个变量而已
    dfs1(root,0);dfs2(root,0);
    build(1,1,seg[0]);
    int x,y,z;
    while(m--){
		int opt=read();
		if(opt==1){x=read();y=read();z=read();change1(x,y,z);continue;}
		if(opt==2){x=read();y=read();ask1(x,y);continue;}
		if(opt==3){x=read();y=read();change2(x,y);continue;}
		if(opt==4){x=read();ask2(x);continue;}
    }
    return 0;
}

原文地址:https://www.cnblogs.com/PPXppx/p/10561580.html