splay单点模板-5203-BZOJ3224 普通平衡树

题目链接

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#define maxn 10000000
using namespace std;
int fa[100005],num[100005],siz[100005],ch[100005][3],val[100005],root,cnt,neww;
void pushup(int x){siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+num[x];}
void table()
{
	fa[2]=1; ch[1][0]=2;
	val[2]=-maxn; val[1]=maxn;
	siz[1]=2; siz[2]=1; num[2]=num[1]=1;
	cnt=2; root=1;
}
void rot(int x,int &f)
{
	int y=fa[x],z=fa[y],l=(ch[y][0]!=x),r=(l^1);
	if(y==f) f=x;
	else ch[z][ch[z][1]==y]=x;
	fa[x]=z; fa[y]=x; fa[ch[x][r]]=y;
	ch[y][l]=ch[x][r]; ch[x][r]=y;
	pushup(y); pushup(x);
}
void splay(int x,int &f)
{
	while(x!=f)
	{
		int y=fa[x],z=fa[y];
		if(y!=f){if(ch[y][0]==x^ch[z][0]==y) rot(x,f);else rot(y,f);}
		rot(x,f);
	}
}
void insert(int x,int rt)
{
	if(val[rt]>x&&ch[rt][0]) insert(x,ch[rt][0]);
	else if(val[rt]<x&&ch[rt][1]) insert(x,ch[rt][1]);
	else if(val[rt]==x) num[rt]++,siz[rt]++,neww=rt;
	else
	{
		if(val[rt]>x) ch[rt][0]=++cnt;
		else if(val[rt]<x) ch[rt][1]=++cnt;
		fa[cnt]=rt; siz[cnt]=1; val[cnt]=x; num[cnt]=1; neww=cnt;
	}
	pushup(rt);
}
int getfront()
{
	int x=ch[root][0];
	while(ch[x][1]) x=ch[x][1];
	return x;
}
int getback()
{
	int x=ch[root][1];
	while(ch[x][0]) x=ch[x][0];
	return x;
}
void del(int x,int rt)
{
	if(val[rt]>x&&ch[rt][0]) del(x,ch[rt][0]);
	else if(val[rt]<x&&ch[rt][1]) del(x,ch[rt][1]);
	else
	{
		num[rt]--;
		if(!num[rt])
		{
			splay(rt,root);
			int head=getfront(),tail=getback();
			splay(head,root); splay(tail,ch[head][1]);
			ch[tail][0]=0;
			siz[rt]=num[rt]=fa[rt]=val[rt]=0;
			pushup(tail); pushup(head);
		}
	}
	pushup(rt);
}
int findx(int x,int rt)
{
	if(val[rt]>x&&ch[rt][0]) return findx(x,ch[rt][0]);
	else if(val[rt]<x&&ch[rt][1]) return findx(x,ch[rt][1]);
	else if(val[rt]==x) return rt;
}
int findk(int x,int rt)
{
	if(siz[ch[rt][0]]+num[rt]>=x&&siz[ch[rt][0]]<x) return rt;
	else if(siz[ch[rt][0]]>=x) return findk(x,ch[rt][0]);
	else return findk(x-siz[ch[rt][0]]-num[rt],ch[rt][1]);
}
int main()
{
	int n;
	scanf("%d",&n); table();
	for(int i=1;i<=n;++i)
	{
		
		int f1,f2,pos; scanf("%d%d",&f1,&f2);
		if(f1==1) insert(f2,root),splay(neww,root);
		else if(f1==2) del(f2,root);
		else if(f1==3) pos=findx(f2,root),splay(pos,root),printf("%d
",siz[ch[root][0]]);
		else if(f1==4) printf("%d
",val[findk(f2+1,root)]);
		else if(f1==5) insert(f2,root),splay(neww,root),printf("%d
",val[getfront()]),del(f2,root);
		else if(f1==6) insert(f2,root),splay(neww,root),printf("%d
",val[getback()]),del(f2,root);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/wuwendongxi/p/13159473.html