树链剖分--P3384 【模板】轻重链剖分

经过了一系列的前置知识,终于学会了树链剖分!!

重链剖分的思想:

重链剖分可以将树上的任意一条路径划分成不超过(O(logn))条连续的链,每条链上的点深度互不相同(即是自底向上的一条链,链上所有点的(LCA)$为链的一个端点)。

重链剖分还能保证划分出的每条链上的节点(DFS)序连续,因此可以方便地用一些维护序列的数据结构(如线段树)来维护树上路径的信息。

如:

  1. 修改 树上两点之间的路径上 所有点的值。

  2. 查询 树上两点之间的路径上 节点权值的 和/极值/其它(在序列上可以用数据结构维护,便于合并的信息)。

我们给出一些定义:

  1. 重子节点 :表示其子节点中子树最大的子结点。如果有多个子树最大的子结点,取其一。如果没有子节点,就无重子节点。

  2. 轻子节点 :表示剩余的所有子结点。

  3. 重边 :从这个结点到重子节点的边为重边 。

  4. 轻边 :到其他轻子节点的边为 轻边 。

  5. 重链 :若干条首尾衔接的重边构成 重链 。

实现:

树剖的实现分两个(DFS)的过程

第一个(DFS)记录每个结点的父节点、深度、子树大小、重子节点。

  • (siz_x),表示子树(x)的大小

  • (dep_x),表示点(x)的深度

  • (fa_x),表示点(x)的父亲

  • (son_x),表示点(x)的重儿子

void dfs1(int x){
	siz[x] = 1;dep[x] = dep[fa[x]]+1;
	for (int i = head[x];i;i = ed[i].nxt){
		int to = ed[i].to;
		if (to == fa[x]) continue;
		fa[to] = x;
		dfs1(to);
		if (siz[to] > siz[son[x]]) son[x] = to;
		siz[x] += siz[to];
	}
}

第二个(DFS)记录所在链的链顶((root),应初始化为结点本身)、重边优先遍历时的(DFS)((dfn))(DFS)序对应的节点编号((pos))

  • (str_x),表示(x)所在重链的链顶

  • (dfn_x),表示点(x)(dfs)

  • (pos_x),表示(dfs)序为(x)的点

void dfs2(int x,int root){
	str[x] = root;
	dfn[x] = ++cnt;pos[cnt] = x;
	if (son[x]) dfs2(son[x],root);
	for (int i = head[x];i;i = ed[i].nxt){
		int to = ed[i].to;
		if (to == fa[x]||to == son[x]) continue;
		dfs2(to,to);
	}
}

路径上修改和查询:

链上的(DFS)序是连续的,可以使用线段树、树状数组维护,每次选择深度较大的链往上跳,直到两点在同一条链上。

void fix1(){
	int a = read(),b = read(),x = read();
	while (str[a] != str[b]){
		if (dep[str[a]] < dep[str[b]]) swap(a,b);
		modify(1,1,n,dfn[str[a]],dfn[a],x);
		a = fa[str[a]];
	}
	if (dep[a] > dep[b]) swap(a,b);
	modify(1,1,n,dfn[a],dfn[b],x);
}
void fix2(){
	int a = read(),b = read();
	int res = 0;
	while (str[a] != str[b]){
		if (dep[str[a]] < dep[str[b]]) swap(a,b);
		(res += query(1,1,n,dfn[str[a]],dfn[a]))%=mod;
		a = fa[str[a]];
	}
	if (dep[a] > dep[b]) swap(a,b);
	(res += query(1,1,n,dfn[a],dfn[b]))%=mod;
	printf("%lld
",res);
}

子树修改和查询:

(DFS)搜索的时候,子树中的结点的(DFS)序是连续的,每一个结点到子树末端的结点的(dfs)序就为他本身的(dfs)序+子树大小-1。

void fix3(){
	int a = read(),x = read();
	modify(1,1,n,dfn[a],dfn[a]+siz[a]-1,x);
}
void fix4(){
	int a = read();
	printf("%lld
",query(1,1,n,dfn[a],dfn[a]+siz[a]-1));	
}

例题的完整代码:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#define int long long
using namespace std;
int read(){
	int x = 1,a = 0;char ch = getchar();
	while (ch < '0'||ch > '9'){if (ch == '-') x = -1;ch = getchar();}
	while (ch >= '0'&&ch <= '9'){a = a*10+ch-'0';ch = getchar();}
	return x*a;
}
const int maxn = 1e5+10;
int n,m,r,mod,a[maxn];
struct node{
	int to,nxt;
}ed[maxn*2];
int head[maxn*2],tot;
void add(int u,int to){
	ed[++tot].to = to;
	ed[tot].nxt = head[u];
	head[u] = tot;
}
int fa[maxn],siz[maxn],son[maxn],dep[maxn];
void dfs1(int x){
	siz[x] = 1;dep[x] = dep[fa[x]]+1;
	for (int i = head[x];i;i = ed[i].nxt){
		int to = ed[i].to;
		if (to == fa[x]) continue;
		fa[to] = x;
		dfs1(to);
		if (siz[to] > siz[son[x]]) son[x] = to;
		siz[x] += siz[to];
	}
}
int cnt,str[maxn],dfn[maxn],pos[maxn];
void dfs2(int x,int root){
	str[x] = root;
	dfn[x] = ++cnt;pos[cnt] = x;
	if (son[x]) dfs2(son[x],root);
	for (int i = head[x];i;i = ed[i].nxt){
		int to = ed[i].to;
		if (to == fa[x]||to == son[x]) continue;
		dfs2(to,to);
	}
}
int tree[maxn*4],lazy[maxn*4];
int ls(int x){return x<<1;}
int rs(int x){return x<<1|1;}
void pushup(int x){
	tree[x] = tree[ls(x)] + tree[rs(x)];
}
void build(int x,int l,int r){
	if (l == r){tree[x] = a[pos[l]];return;}
	int mid = (l+r)>>1;
	build(ls(x),l,mid);build(rs(x),mid+1,r);
	pushup(x);
}
void tag(int x,int l,int r,int k){
	lazy[x] += k;
	tree[x] += (r-l+1)*k;
}
void pushdown(int x,int l,int r){
	int mid = (l+r)>>1;
	tag(ls(x),l,mid,lazy[x]);
	tag(rs(x),mid+1,r,lazy[x]);
	lazy[x] = 0;
}
void modify(int x,int l,int r,int nl,int nr,int k){
	if (nl <= l&&r <= nr){tag(x,l,r,k);return;}
	int mid = (l+r)>>1;
	pushdown(x,l,r);
	if (nl <= mid) modify(ls(x),l,mid,nl,nr,k);
	if (nr > mid) modify(rs(x),mid+1,r,nl,nr,k);
	pushup(x);
}
int query(int x,int l,int r,int nl,int nr){
	int res = 0;
	if (nl <= l&&r <= nr) return tree[x];
	int mid = (l+r)>>1;
	pushdown(x,l,r);
	if (nl <= mid) (res+=query(ls(x),l,mid,nl,nr))%=mod;
	if (nr > mid) (res+=query(rs(x),mid+1,r,nl,nr))%=mod;
	return res;
}
void fix1(){
	int a = read(),b = read(),x = read();
	while (str[a] != str[b]){
		if (dep[str[a]] < dep[str[b]]) swap(a,b);
		modify(1,1,n,dfn[str[a]],dfn[a],x);
		a = fa[str[a]];
	}
	if (dep[a] > dep[b]) swap(a,b);
	modify(1,1,n,dfn[a],dfn[b],x);
}
void fix2(){
	int a = read(),b = read();
	int res = 0;
	while (str[a] != str[b]){
		if (dep[str[a]] < dep[str[b]]) swap(a,b);
		(res += query(1,1,n,dfn[str[a]],dfn[a]))%=mod;
		a = fa[str[a]];
	}
	if (dep[a] > dep[b]) swap(a,b);
	(res += query(1,1,n,dfn[a],dfn[b]))%=mod;
	printf("%lld
",res);
}
void fix3(){
	int a = read(),x = read();
	modify(1,1,n,dfn[a],dfn[a]+siz[a]-1,x);
}
void fix4(){
	int a = read();
	printf("%lld
",query(1,1,n,dfn[a],dfn[a]+siz[a]-1));	
}
signed main(){
	n = read(),m = read(),r = read(),mod = read();
	for (int i = 1;i <= n;i++) a[i] = read(); 
	for (int i = 1;i <= n-1;i++){
		int x = read(),y = read();
		add(x,y),add(y,x);
	} 
	dfs1(r);dfs2(r,r);
	build(1,1,n);
	for (int i = 1;i <= m;i++){	
		int op = read();
		if (op == 1) fix1();
		if (op == 2) fix2();
		if (op == 3) fix3();
		if (op == 4) fix4();
	}
	return 0;
}
原文地址:https://www.cnblogs.com/little-uu/p/13971680.html