[luogu3384][模板]树链剖分

Description

如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

Input

第一行包含4个正整数N、M、R、P,分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。

接下来一行包含N个非负整数,分别依次表示各个节点上初始的数值。

接下来N-1行每行包含两个整数x、y,表示点x和点y之间连有一条边(保证无环且连通)

接下来M行每行包含若干个正整数,每行表示一个操作,格式如下:

操作1: 1 x y z

操作2: 2 x y

操作3: 3 x z

操作4: 4 x

Output

输出包含若干行,分别依次表示每个操作2或操作4所得的结果(对P取模)

Sample Input

5 5 2 24
7 3 7 8 0
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3

Sample Output

2
21

Hint

时空限制:1s,128M

数据规模:

对于30%的数据: N leq 10, M leq 10 N≤10,M≤10

对于70%的数据: N leq {10}^3, M leq {10}^3 N≤10
3
,M≤10
3

对于100%的数据: N leq {10}^5, M leq {10}^5 N≤10
5
,M≤10
5

(其实,纯随机生成的树LCA+暴力是能过的,可是,你觉得可能是纯随机的么233

样例说明:

树的结构如下:
此处输入图片的描述

各个操作如下:
此处输入图片的描述

故输出应依次为2、21(重要的事情说三遍:记得取模)

题解

树链剖分模板

#include<cstdio>
#include<cstring>
#include<iostream>

using namespace std;
typedef long long LL;

LL read()
{
	LL x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

int N,M,R,P,A[100050];
int ecnt,place,head[100050];
int deep[100050],parent[100050],size[100050];
int tid[100050],top[100050],rnk[100050],end[100050];
struct edge{int to,nxt;}e[200050];
struct segmentnode{int sum,tag;}seg[400050];

inline void addedge(int u,int v)
{
	e[ecnt]=(edge){v,head[u]},head[u]=ecnt++;
	e[ecnt]=(edge){u,head[v]},head[v]=ecnt++;
}
inline void pushup(int root)
{
	seg[root].sum=(seg[root<<1].sum+seg[root<<1|1].sum)%P;
}
inline void pushdown(int root,int l,int r)
{
	int x=seg[root].tag;seg[root].tag=0;
	if(x==0)return;
	int mid=(l+r)>>1;
	(seg[root<<1].sum+=(mid-l+1)*x%P)%=P,(seg[root<<1].tag+=x)%=P;
	(seg[root<<1|1].sum+=(r-mid)*x%P)%=P,(seg[root<<1|1].tag+=x)%=P;
}

void getdeep(int root,int step,int fa)
{
	deep[root]=step,parent[root]=fa,size[root]=1;
	for(int i=head[root];~i;i=e[i].nxt)
	{
		int v=e[i].to;if(v==fa)continue;
		getdeep(v,step+1,root);
		size[root]+=size[v];
	}
}

void devide(int root,int chain,int fa)
{
	tid[root]=end[root]=++place,top[root]=chain,rnk[place]=root;
	int k=0;
	for(int i=head[root];~i;i=e[i].nxt)
	{
		int v=e[i].to;if(v==fa)continue;
		if(size[v]>size[k])k=v;
	}
	if(k==0)return;
	devide(k,chain,root),end[root]=max(end[root],end[k]);
	for(int i=head[root];~i;i=e[i].nxt)
	{
		int v=e[i].to;if(v==fa||v==k)continue;
		devide(v,v,root);end[root]=max(end[root],end[v]);
	}
}

void build(int root,int l,int r)
{
	seg[root].tag=0;
	if(l==r){seg[root].sum=A[rnk[l]];return;}
	int mid=(l+r)>>1;
	build(root<<1,l,mid),build(root<<1|1,mid+1,r);
	pushup(root);
}

void updata(int root,int l,int r,int a,int b,int val)
{
	if(l==a&&r==b)
	{
		(seg[root].sum+=(r-l+1)*val%P)%=P;
		(seg[root].tag+=val)%=P;
		return;
	}
	int mid=(l+r)>>1; pushdown(root,l,r);
	if(b<=mid)updata(root<<1,l,mid,a,b,val);
	else if(a>mid)updata(root<<1|1,mid+1,r,a,b,val);
	else 
	{
		updata(root<<1,l,mid,a,mid,val);
		updata(root<<1|1,mid+1,r,mid+1,b,val);
	}
	pushup(root);
}

int getsum(int root,int l,int r,int a,int b)
{
	if(l==a&&r==b)return seg[root].sum;
	int mid=(l+r)>>1;pushdown(root,l,r);
	if(b<=mid)return getsum(root<<1,l,mid,a,b);
	else if(a>mid)return getsum(root<<1|1,mid+1,r,a,b);
	else
	{
		int res=getsum(root<<1,l,mid,a,mid);
		(res+=getsum(root<<1|1,mid+1,r,mid+1,b))%=P;
		return res;
	}
}

void updata(int a,int b,int val)
{
	while(top[a]!=top[b])
	{
		if(deep[top[a]]>deep[top[b]])swap(a,b);
		updata(1,1,N,tid[top[b]],tid[b],val);
		b=parent[top[b]];
	}
	if(deep[a]>deep[b])swap(a,b);
	updata(1,1,N,tid[a],tid[b],val);
}

int getsum(int a,int b)
{
	int res=0;
	while(top[a]!=top[b])
	{
		if(deep[top[a]]>deep[top[b]])swap(a,b);
		(res+=getsum(1,1,N,tid[top[b]],tid[b]))%=P;
		b=parent[top[b]];
	}
	if(deep[a]>deep[b])swap(a,b);
	(res+=getsum(1,1,N,tid[a],tid[b]))%=P;
	return res;
}

int main()
{
	memset(head,-1,sizeof(head));
	N=read(),M=read(),R=read(),P=read();
	for(int i=1;i<=N;i++)A[i]=read()%P;
	for(int i=1;i<N;i++)
	{
		int u=read(),v=read();
		addedge(u,v);
	}
	getdeep(R,1,0),devide(R,R,0),build(1,1,N);
	for(int i=1;i<=M;i++)
	{
		int op=read();
		if(op==1)
		{
			int x=read(),y=read(),z=read();
			updata(x,y,z);
		}
		if(op==2)
		{
			int x=read(),y=read();
			printf("%d
",getsum(x,y));
		}
		if(op==3)
		{
			int x=read(),z=read();
			updata(1,1,N,tid[x],end[x],z);
		}
		if(op==4)
		{
			int x=read();
			printf("%d
",getsum(1,1,N,tid[x],end[x]));
		}
	}
	return 0;
}
原文地址:https://www.cnblogs.com/ljzalc1022/p/8757088.html