[国家集训队]Tree II

V.[国家集训队]Tree II

LCT维护这种东西是要比线段树要恶心的多的……毕竟线段树的区间大小是可以直接通过区间左右端点算出的,但是LCT就不行,必须手动维护。并且,线段树的维护是(左儿子+右儿子),但是LCT的维护是(左儿子+自己+右儿子)!

请务必先把线段树模板做掉,关于运算的优先级什么的实在不应该放到这道题里讲吧……

注意细节,比如修改时,乘法\(tag\),加法\(tag\),本身的值,区间的和,都要修改,并且每个修改的方式都还不一样。

代码:

#include<bits/stdc++.h>
using namespace std;
const int mod=51061;
#define lson t[x].ch[0]
#define rson t[x].ch[1]
int n,q;
struct LCT{
	int ch[2],fa,plu,mul,rev,sum,val,sz;
}t[100100];
int identify(int x){
	if(t[t[x].fa].ch[0]==x)return 0;
	if(t[t[x].fa].ch[1]==x)return 1;
	return -1;
}
void REV(int x){
	t[x].rev^=1,swap(lson,rson);
}
void pushdown(int x){
	if(t[x].rev){
		if(lson)REV(lson);
		if(rson)REV(rson);
		t[x].rev=0;
	}
	if(lson)t[lson].val=(1ll*t[lson].val*t[x].mul+t[x].plu)%mod,t[lson].plu=(1ll*t[lson].plu*t[x].mul+t[x].plu)%mod,t[lson].mul=(1ll*t[lson].mul*t[x].mul)%mod,t[lson].sum=(1ll*t[lson].sum*t[x].mul+1ll*t[x].plu*t[lson].sz)%mod;
	if(rson)t[rson].val=(1ll*t[rson].val*t[x].mul+t[x].plu)%mod,t[rson].plu=(1ll*t[rson].plu*t[x].mul+t[x].plu)%mod,t[rson].mul=(1ll*t[rson].mul*t[x].mul)%mod,t[rson].sum=(1ll*t[rson].sum*t[x].mul+1ll*t[x].plu*t[rson].sz)%mod;
	t[x].plu=0,t[x].mul=1;
}
void pushup(int x){
	t[x].sum=(t[lson].sum+t[rson].sum+t[x].val)%mod;
	t[x].sz=t[lson].sz+t[rson].sz+1;
}
void rotate(int x){
	int y=t[x].fa;
	int z=t[y].fa;
	int dirx=identify(x);
	int diry=identify(y);
	int b=t[x].ch[!dirx];
	if(diry!=-1)t[z].ch[diry]=x;t[x].fa=z;
	if(b)t[b].fa=y;t[y].ch[dirx]=b;
	t[x].ch[!dirx]=y,t[y].fa=x;
	pushup(y),pushup(x);
}
void pushall(int x){
	if(identify(x)!=-1)pushall(t[x].fa);
	pushdown(x);
}
void splay(int x){
	pushall(x);
	while(identify(x)!=-1){
		int fa=t[x].fa;
		if(identify(fa)==-1)rotate(x);
		else if(identify(fa)==identify(x))rotate(fa),rotate(x);
		else rotate(x),rotate(x);
	}
}
void access(int x){
	for(int y=0;x;x=t[y=x].fa)splay(x),rson=y,pushup(x);
}
void makeroot(int x){
	access(x),splay(x),REV(x);
}
int findroot(int x){
	access(x),splay(x);
	pushdown(x);
	while(lson)x=lson,pushdown(x);
	splay(x);
	return x;
}
void link(int x,int y){
	makeroot(x),t[x].fa=y;
}
int split(int x,int y){
	makeroot(x),access(y),splay(y);
	return t[y].sum;
}
void cut(int x,int y){
	split(x,y),t[x].fa=t[y].ch[0]=0,pushup(y);
}
int main(){
	scanf("%d%d",&n,&q);
	for(int i=1;i<=n;i++)t[i].val=t[i].mul=t[i].sum=1;
	for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),link(x,y);
	for(int i=1,t1,t2,t3,t4;i<=q;i++){
		char s[10];
		scanf("%s",s);
		if(s[0]=='+')scanf("%d%d%d",&t1,&t2,&t3),split(t1,t2),t[t2].plu=(t[t2].plu+t3)%mod,t[t2].val=(t[t2].val+t3)%mod,t[t2].sum=(1ll*t[t2].sz*t3+t[t2].sum)%mod,pushup(t2);
		if(s[0]=='-')scanf("%d%d%d%d",&t1,&t2,&t3,&t4),cut(t1,t2),link(t3,t4);
		if(s[0]=='*')scanf("%d%d%d",&t1,&t2,&t3),split(t1,t2),t[t2].mul=(1ll*t[t2].mul*t3)%mod,t[t2].val=(1ll*t[t2].val*t3)%mod,t[t2].sum=(1ll*t[t2].sum*t3)%mod,pushup(t2);
		if(s[0]=='/')scanf("%d%d",&t1,&t2),printf("%d\n",split(t1,t2));
	}
	return 0;
}

原文地址:https://www.cnblogs.com/Troverld/p/14601909.html