树链剖分-树剖换根

Description

这是一道模板题。
给定一棵 n 个节点的树,初始时该树的根为 1 号节点,每个节点有一个给定的权值。下面依次进行 m 个操作,操作分为如下五种类型:
换根:将一个指定的节点设置为树的新根。
修改路径权值:给定两个节点,将这两个节点间路径上的所有节点权值(含这两个节点)增加一个给定的值。
修改子树权值:给定一个节点,将以该节点为根的子树内的所有节点权值增加一个给定的值。
询问路径:询问某条路径上节点的权值和。
询问子树:询问某个子树内节点的权值和。

Input

第一行为一个整数 nnn,表示节点的个数。
第二行 nnn 个整数表示第 iii 个节点的初始权值 ai​​ 。
第三行 n−1 个整数,表示 i+1号节点的父节点编号 fi+1 (1⩽fi+1⩽n)。
第四行一个整数 m,表示操作个数。
接下来 m 行,每行第一个整数表示操作类型编号:(1⩽u,v⩽n)
若类型为 1,则接下来一个整数 u,表示新根的编号。
若类型为 2,则接下来三个整数 u,v,k,分别表示路径两端的节点编号以及增加的权值。
若类型为 3,则接下来两个整数 u,k,分别表示子树根节点编号以及增加的权值。
若类型为 4,则接下来两个整数 u,v,表示路径两端的节点编号。
若类型为 5,则接下来一个整数 u,表示子树根节点编号。

Output

对于每一个类型为 4 或 5 的操作,输出一行一个整数表示答案。


思路

  • 重构两遍40分,后发现getson函数写错,正确:if(fa[top[u]]==v) return top[u];错误:if(fa[u]==v) return u;
  • 树剖换根相关: root(当前根),x(询问),lca最近公共祖先,以询问总和sum为例:
  • 1. root==x:sum=整棵树之和;
  • 2. root在x的子树中(lca(root,x)==x):sum=整棵树之和-以(x~root路径上靠近x的点)为根的子树;
  • 3. root不在x的子树中(lca(root,x)!=x):sum=以1为根(原树)时x的子树;
  • 以1为根时x的子树在线段树中为连续的一段区间(num[x],num[x]+siz[x]-1);

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#define int long long 
#define maxn 100005
using namespace std;
int n,m,cnt,head[maxn],val[maxn],root;
int fa[maxn],top[maxn],deep[maxn],siz[maxn],son[maxn],num[maxn],fnum[maxn];
struct node{int next,to;}e[maxn<<1];
struct fdfdfd{int l,r,flag,sum,len;}a[maxn<<2];
void addedge(int x,int y){e[++cnt].to=y; e[cnt].next=head[x]; head[x]=cnt;}
void dfs_1(int u)
{
	deep[u]=deep[fa[u]]+1; siz[u]=1;
	for(int i=head[u];i;i=e[i].next)
	{
		int v=e[i].to; dfs_1(v); siz[u]+=siz[v];
		if(son[u]==-1||siz[v]>siz[son[u]]) son[u]=v;
	}
}
void dfs_2(int u,int topp)
{
	top[u]=topp; num[u]=++cnt; fnum[cnt]=u;
	if(son[u]!=-1) dfs_2(son[u],topp);
	for(int i=head[u];i;i=e[i].next)
	{
		int v=e[i].to;
		if(v!=son[u]) dfs_2(v,v);
	}
}
void pushup(int x){a[x].sum=a[x<<1].sum+a[x<<1|1].sum;}
void pushdown(int x)
{
	if(a[x].flag==0) return;
	a[x<<1].flag+=a[x].flag; a[x<<1].sum+=a[x<<1].len*a[x].flag;
	a[x<<1|1].flag+=a[x].flag; a[x<<1|1].sum+=a[x<<1|1].len*a[x].flag;
	a[x].flag=0;
}
void build(int x,int left,int right)
{
	a[x].l=left; a[x].r=right; a[x].len=right-left+1;;
	if(left==right) {a[x].sum=val[fnum[left]]; return;}
	int mid=(left+right)>>1;
	build(x<<1,left,mid); build(x<<1|1,mid+1,right);
	pushup(x);
}
void modify(int x,int left,int right,int d)
{
	if(a[x].r<left||a[x].l>right) return;
	if(left<=a[x].l&&right>=a[x].r) {a[x].flag+=d,a[x].sum+=a[x].len*d; return;}
	pushdown(x);
	modify(x<<1,left,right,d); modify(x<<1|1,left,right,d);
	pushup(x);
}
void change_uv(int u,int v,int k)
{
	while(top[u]!=top[v])
	{
		if(deep[top[u]]<deep[top[v]]) swap(u,v);
		modify(1,num[top[u]],num[u],k);
		u=fa[top[u]];
	}
	if(deep[u]>deep[v]) swap(u,v);
	modify(1,num[u],num[v],k);
}
int query(int x,int left,int right)
{
	if(a[x].r<left||a[x].l>right) return 0;
	if(left<=a[x].l&&right>=a[x].r) return a[x].sum;
	pushdown(x);
	return query(x<<1,left,right)+query(x<<1|1,left,right);
}
int ask_uv(int u,int v)
{
	int ans=0;
	while(top[u]!=top[v])
	{
		if(deep[top[u]]<deep[top[v]]) swap(u,v);
		ans+=query(1,num[top[u]],num[u]);
		u=fa[top[u]];
	}
	if(deep[u]>deep[v]) swap(u,v);
	ans+=query(1,num[u],num[v]);
	return ans;
}
int getlca(int u,int v)
{
	while(top[u]!=top[v])
	{
		if(deep[top[u]]<deep[top[v]]) swap(u,v);
		u=fa[top[u]];
	}
	return deep[u]<deep[v]?u:v;
}
int getson(int u,int v)
{
	while(top[u]!=top[v])
	{
		if(deep[top[u]]<deep[top[v]]) swap(u,v);
		if(fa[top[u]]==v) return top[u];
		u=fa[top[u]];
	}
	return deep[u]<deep[v]?son[u]:son[v];
}
void change_u(int u,int k)
{
	int lca=getlca(u,root),son;
	if(u==root) return modify(1,1,n,k);
	if(u!=lca) return modify(1,num[u],num[u]+siz[u]-1,k);
	if(u==lca) modify(1,1,n,k),son=getson(u,root),modify(1,num[son],num[son]+siz[son]-1,-k);
}
int ask_u(int u)
{
	int lca=getlca(u,root),son,ans=0;
	if(u==root) return query(1,1,n);
	if(u!=lca) return query(1,num[u],num[u]+siz[u]-1);
	if(u==lca) ans=query(1,1,n),son=getson(u,root),ans-=query(1,num[son],num[son]+siz[son]-1);
	return ans;
}
signed main()
{
	memset(son,-1,sizeof(son));
	scanf("%lld",&n);
	for(int i=1;i<=n;++i) scanf("%lld",&val[i]);
	for(int i=1,u;i<n;++i) scanf("%lld",&u),fa[i+1]=u,addedge(u,i+1);
	dfs_1(1); cnt=0; dfs_2(1,1); build(1,1,n); root=1;
	scanf("%lld",&m);
	while(m--)
	{
		int op,u,v,k; scanf("%lld%lld",&op,&u);
		if(op==1) root=u;
		else if(op==2) scanf("%lld%lld",&v,&k),change_uv(u,v,k);
		else if(op==3) scanf("%lld",&k),change_u(u,k);
		else if(op==4) scanf("%lld",&v),printf("%lld
",ask_uv(u,v));
		else printf("%lld
",ask_u(u));
	}
	return 0;
}
原文地址:https://www.cnblogs.com/wuwendongxi/p/13503498.html