[HNOI2010]弹飞绵羊

III.[HNOI2010]弹飞绵羊

首先,可以发现,如果从一个装置的目标向这个装置连一条边,并且建立虚拟节点\(n+1\),向所有可以弹飞的装置连边的话,这肯定构成一颗树。

理解就行。然后我们就可以在每个节点统计一个\(size\),并用LCT在改变弹力系数时修改这棵树。则最终答案为:

int ask(int x){
	makeroot(n+1),access(x),splay(x);
	return t[x].sz;
}

显然,答案为当前节点的深度。我们可以打通从\(x\)\(n+1\)的路径,然后答案即为这个splay的大小。另,最后答案还要减去\(1\),因为有\(n+1\)这个虚拟节点。

代码:

#include<bits/stdc++.h>
using namespace std;
#define lson t[x].ch[0]
#define rson t[x].ch[1] 
int n,m,fa[200100];
struct LCT{
	int ch[2],sz,fa,rev;
}t[200100];
void pushup(int x){
	t[x].sz=t[lson].sz+t[rson].sz+1;
}
void REV(int x){
	t[x].rev^=1,swap(t[x].ch[0],t[x].ch[1]);
}
void pushdown(int x){
	if(!t[x].rev)return;
	if(lson)REV(lson);
	if(rson)REV(rson);
	t[x].rev=0;
}
int identify(int x){
	if(t[t[x].fa].ch[0]==x)return 0;
	if(t[t[x].fa].ch[1]==x)return 1;
	return -1;
}
void rotate(int x){
	int y=t[x].fa;
	int z=t[y].fa;
	int dirx=identify(x);
	int diry=identify(y);
	int b=t[x].ch[!dirx];
	if(diry!=-1)t[z].ch[diry]=x;t[x].fa=z;
	if(b)t[b].fa=y;t[y].ch[dirx]=b;
	t[x].ch[!dirx]=y,t[y].fa=x;
	pushup(y),pushup(x);
}
void pushall(int x){
	if(identify(x)!=-1)pushall(t[x].fa);
	pushdown(x);
}
void splay(int x){
	pushall(x);
	while(identify(x)!=-1){
		int fa=t[x].fa;
		if(identify(fa)==-1)rotate(x);
		else if(identify(fa)==identify(x))rotate(fa),rotate(x);
		else rotate(x),rotate(x);
	}
}
void access(int x){
	for(int y=0;x;x=t[y=x].fa)splay(x),rson=y,pushup(x);
}
void makeroot(int x){
	access(x),splay(x),REV(x);
}
void split(int x,int y){
	makeroot(x),access(y),splay(y);
}
int findroot(int x){
	access(x),splay(x);
	pushdown(x);
	while(lson)x=lson,pushdown(x);
	splay(x);
	return x;
}
void link(int x,int y){
	makeroot(x),t[x].fa=y;
}
void cut(int x,int y){
	split(x,y),t[x].fa=t[y].ch[0]=0,pushup(y);
}
void change(int x,int y){
	cut(x,fa[x]),link(x,fa[x]=y);
}
int ask(int x){
	makeroot(n+1),access(x),splay(x);
	return t[x].sz;
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++)scanf("%d",&fa[i]),fa[i]+=i,fa[i]=min(fa[i],n+1),link(i,fa[i]),t[i].sz=1;
	scanf("%d",&m);
	for(int i=1,t1,t2,t3;i<=m;i++){
		scanf("%d%d",&t1,&t2),t2++;
		if(t1==1)printf("%d\n",ask(t2)-1);
		else scanf("%d",&t3),change(t2,min(t2+t3,n+1));
	}
	return 0;
}

原文地址:https://www.cnblogs.com/Troverld/p/14601886.html