树链剖分笔记(轻重链剖分)

我今天居然一次提交就A了!惊喜!

树链剖分可以在树上维护点的权值,可以进行一条链上的求和和修改操作,也可以将一个子树进行求和和修改。

我们的具体做法是要将这个树放到一个序列里,然后使用线段树来维护它。

我们定义重链和轻链,重链是指一个节点连向它最重的子节点的边,我们可以发现我们将相连的重链看做一条重链。

我们可以将一颗树拆分成轻链和重链,优先dfs重链,可以使得一条重链的dfs序是连续的,所以重链在线段树上也是一段连续的区间。

我们来找两个点的最近公共祖先,通过重链来向上跳,可以找到两个点的LCA,由于跳的都是重链,可以线段树区间维护一下。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
#define mid (l+r>>1)
#define B printf("!");
using namespace std;

int read()
{
	int a = 0,x = 1;
	char ch = getchar();
	while(ch > '9' || ch < '0'){
		if(ch == '-') x = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9'){
		a = a*10 + ch-'0';
		ch = getchar();
	}
	return a*x;
}
const int N=1e6+7;
int n,m,r,p;
int arr[N];

int head[N],go[N],nxt[N],cnt;
void add(int u,int v)
{
	go[++cnt] = v;
	nxt[cnt] = head[u];
	head[u] = cnt;
}
int siz[N],son[N],dfn[N],dis[N],str[N],pos[N],fa[N];

void dfs1(int u)
{
	siz[u] = 1;
	for(int e = head[u];e;e = nxt[e]){
		int v = go[e];
		if(v == fa[u]) continue;
		dis[v] = dis[u] + 1;
		fa[v] = u;
		dfs1(v);
		if(siz[v] > siz[son[u]]) son[u] = v;
		siz[u] += siz[v];
	}
}

void dfs2(int u,int h)
{
	str[u] = h;
	dfn[u] = ++cnt;
	pos[cnt] = u;
	if(son[u]) dfs2(son[u],h);
	for(int e = head[u];e;e = nxt[e]){
		int v = go[e];
		if(v == fa[u] || v == son[u]) continue;
		dfs2(v,v);
	}
}

int lazy[N],tre[N];

void build(int root,int l,int r)
{
	if(l == r) {tre[root] = arr[pos[l]];return;}
	build(root<<1,l,mid);build(root<<1|1,mid+1,r);
	tre[root] = (tre[root<<1]+tre[root<<1|1]) % p;
}

void push_down(int root,int l,int r)
{
	(tre[root<<1] += (mid-l+1)*lazy[root]%p)%=p,(tre[root<<1|1] += (r-mid)*lazy[root]%p)%=p;
	(lazy[root<<1] += lazy[root])%=p,(lazy[root<<1|1] += lazy[root])%=p;
	lazy[root] = 0;
}

void modify(int root,int l,int r,int ql,int qr,int x)
{
	if(l >= ql && r <= qr) {(tre[root] += (r-l+1)*x%p)%=p,(lazy[root] += x)%=p;return;}
	if(l > qr || r < ql) return;
	push_down(root,l,r);
	modify(root<<1,l,mid,ql,qr,x);modify(root<<1|1,mid+1,r,ql,qr,x);
	tre[root] = (tre[root<<1] + tre[root<<1|1])%p;
}

int query(int root,int l,int r,int ql,int qr)
{
	if(l >= ql && r <= qr) return tre[root];
	if(l > qr || r < ql) return 0;
	push_down(root,l,r);
	return (query(root<<1,l,mid,ql,qr)+query(root<<1|1,mid+1,r,ql,qr))%p;
}


void LCA_updata(int u,int v,int x)
{
	while(str[u] != str[v]){
//		printf("%d %d
",u,v);
		if(dis[str[u]] < dis[str[v]]) swap(u,v);
		modify(1,1,n,dfn[str[u]],dfn[u],x);
		u = fa[str[u]];
	}
	if(dis[u] > dis[v]) swap(u,v);
	modify(1,1,n,dfn[u],dfn[v],x);
}

void LCA_query(int u,int v)
{
	int ret = 0;
	while(str[u] != str[v]){
		if(dis[str[u]] < dis[str[v]]) swap(u,v);
		(ret += query(1,1,n,dfn[str[u]],dfn[u])) %= p;
		u = fa[str[u]];
	}
	if(dis[u] > dis[v]) swap(u,v);
	(ret += query(1,1,n,dfn[u],dfn[v])) %= p;
	printf("%d
",ret);
}

void tre_updata(int u,int x)
{
	modify(1,1,n,dfn[u],dfn[u]+siz[u]-1,x);
}

void tre_query(int u)
{
	int ret = query(1,1,n,dfn[u],dfn[u]+siz[u]-1);
	printf("%d
",ret);
}
int main()
{
//	freopen("in.in","r",stdin);
//	freopen("out.out","w",stdout);
	n = read(),m = read(),r = read(),p = read();
	for(int i = 1;i <= n;i ++) arr[i] = read();
	for(int i = 1;i < n;i ++){
		int u = read(),v = read();
		add(u,v);add(v,u);
	}
	cnt = 0;
	dfs1(r);dfs2(r,r);

	build(1,1,n);

	for(int i = 1;i <= m;i ++){
		int op = read();
		if(op == 1){
			int x = read(),y = read(),z = read();
			LCA_updata(x,y,z);
		} else if(op == 2) {
			int x = read(),y = read();
			LCA_query(x,y);
		} else if(op == 3) {
			int x = read(),z = read();
			tre_updata(x,z);
		} else if(op == 4) {
			int x = read();
			tre_query(x);
		}
	}
	return 0;
}
原文地址:https://www.cnblogs.com/nao-nao/p/13651390.html