普通平衡树 Splay

Code:

#include<algorithm>
#include<climits>
#include<cstdio>
#include<cstring>
using namespace std;
// Node pool for the splay tree. Nodes are allocated by ++cnt, so index 0 is
// never a real node and serves as the "null" child/parent everywhere.
const int maxn = 400006;
int ch[maxn][2];  // ch[p][0] / ch[p][1]: left / right child of node p (0 = none)
int f[maxn];      // f[p]: parent of node p (0 for the root)
int siz[maxn];    // siz[p]: number of stored values in p's subtree, counting duplicates
int num[maxn];    // num[p]: how many copies of val[p] node p holds
int val[maxn];    // val[p]: key stored in node p
int root;         // index of the current root, 0 when the tree is empty
int cnt;          // number of nodes allocated so far
// Which side of its parent node x hangs on: 1 for the right child, 0 otherwise.
int get(int x)
{
	const int parent = f[x];
	return ch[parent][1] == x ? 1 : 0;
}
// Recompute siz[x] from its children's sizes plus the copies held at x itself.
void pushup(int x)
{
	const int lhs = ch[x][0];
	const int rhs = ch[x][1];
	siz[x] = siz[lhs] + siz[rhs] + num[x];
}
void rotate(int x)
{
	int old=f[x],oldf=f[old],which=get(x);
	ch[old][which]=ch[x][which^1],f[ch[old][which]]=old;
	ch[x][which^1]=old,f[old]=x,f[x]=oldf;
	if(oldf)ch[oldf][ch[oldf][1]==old]=x;
	pushup(old);pushup(x);
}
// Return the node that stores key x by a plain BST descent from the root.
// Precondition: x must currently be stored in the tree. Otherwise the descent
// reaches node 0, whose children are 0, and loops forever (or returns 0 when x == 0).
int findx(int x){
	int cur = root;
	for (; val[cur] != x; cur = ch[cur][val[cur] < x]) {
	}
	return cur;
}
// Splay node x upward until it occupies the position currently held by tar.
// tar is a reference to the slot holding that subtree root (e.g. `root` or a
// child entry), and it is overwritten with x at the end. Rotation stops once
// x's parent equals tar's original parent. When x and its parent are children
// on the same side, the parent is rotated first (zig-zig); otherwise x is
// rotated twice (zig-zag).
void splay(int x,int& tar){
	const int stop = f[tar];
	while (f[x] != stop) {
		const int parent = f[x];
		if (f[parent] != stop) {
			if (get(x) == get(parent))
				rotate(parent);
			else
				rotate(x);
		}
		rotate(x);
	}
	tar = x;
}
int x_rank(int x){                    
	splay(findx(x),root);return siz[ch[root][0]]+1;
}
// Return the x-th smallest stored value (1-based, duplicates counted
// separately) and splay its node to the root.
// Precondition: 1 <= x <= siz[root]. Out-of-range x never terminates.
int rank_x(int x){
	int cur = root;
	for (;;) {
		const int left = siz[ch[cur][0]];
		if (x <= left) {
			cur = ch[cur][0];
			continue;
		}
		// Skip the left subtree and the copies held at cur.
		x -= left + num[cur];
		if (x <= 0) {
			splay(cur, root);
			return val[cur];
		}
		cur = ch[cur][1];
	}
}
int pre_x(int x){                      
	int ans;
	int p=root;
	while(p){
		if(val[p]<x){ans=val[p];p=ch[p][1];}
		else p=ch[p][0];
	}
	return ans;
}
int aft_x(int x){                     
	int ans;
	int p=root;
	while(p){
		if(val[p]>x){ans=val[p],p=ch[p][0];}
		else p=ch[p][1];
	}
	return ans;
}
void insert_x(int x){
	if(!root){
		++cnt;root=cnt,val[root]=x,num[root]=1;pushup(root);return;
	}
	int p,fa;
	p=fa=root;
	while(p&&val[p]!=x)fa=p,p=ch[p][x>val[p]];
	if(!p){
		++cnt;val[cnt]=x,num[cnt]=1,f[cnt]=fa,ch[fa][x>val[fa]]=cnt;
		pushup(cnt);splay(cnt,root);return;
	}
	++num[p];pushup(p);splay(p,root);
}
void delete_x(int x){
	int p=findx(x);splay(p,root);
	if(num[root]>1){--num[root];pushup(root);return;}
	if(!ch[root][0]&&!ch[root][1])root=0;
	else if(!ch[root][0])root=ch[root][1],f[root]=0;
	else if(!ch[root][1])root=ch[root][0],f[root]=0;
	else{
		p=ch[root][0];
		while(ch[p][1])p=ch[p][1];
		splay(p,ch[root][0]);
		ch[p][1]=ch[root][1];f[ch[p][1]]=p,f[p]=0;pushup(p);
		root=p;
	}
}
int main(){
	int N;scanf("%d",&N);
	for(int i=1;i<=N;++i)
	{
		int opt,x;scanf("%d%d",&opt,&x);
		if(opt==1)insert_x(x);
		if(opt==2)delete_x(x);
		if(opt==3)printf("%d
",x_rank(x));
		if(opt==4)printf("%d
",rank_x(x));
		if(opt==5)printf("%d
",pre_x(x));
		if(opt==6)printf("%d
",aft_x(x));
	}
	return 0;
}

  

原文地址:https://www.cnblogs.com/guangheli/p/9845185.html