【模板】Splay

Splay 均摊复杂度证明见此处 ( ightarrow) 链接
代码如下

#include <bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
const int inf=0x3f3f3f3f;

struct node{
	#define ls(x) t[x].ch[0]
	#define rs(x) t[x].ch[1]
	int fa,ch[2],val,size,cnt;
}t[maxn];
int tot,root;
inline int get(int x){return x==rs(t[x].fa);}
inline void pushup(int x){
	t[x].size=t[ls(x)].size+t[rs(x)].size+t[x].cnt;
}
inline int find(int val){
	int x=root;
	while(t[x].val!=val&&t[x].ch[t[x].val<val])x=t[x].ch[t[x].val<val];
	return x;
}
inline void rotate(int x){
	int fa=t[x].fa,gfa=t[fa].fa;
	int d1=get(x),d2=get(fa);
	t[fa].ch[d1]=t[x].ch[d1^1],t[t[x].ch[d1^1]].fa=fa;
	t[x].ch[d1^1]=fa,t[fa].fa=x;
	t[x].fa=gfa,t[gfa].ch[d2]=x;
	pushup(fa),pushup(x);
}
inline void splay(int x,int goal){
	while(t[x].fa!=goal){
		int fa=t[x].fa,gfa=t[fa].fa;
		if(gfa!=goal)get(x)==get(fa)?rotate(fa):rotate(x);
		rotate(x);
	}
	if(!goal)root=x;
}
void insert(int val){
	int x=root,fa=0;
	while(x&&t[x].val!=val)fa=x,x=t[x].ch[t[x].val<val];
	if(x)++t[x].cnt;
	else{
		x=++tot;
		if(fa)t[fa].ch[t[fa].val<val]=x;
		t[x].fa=fa,t[x].val=val,t[x].cnt=t[x].size=1;
	}
	splay(x,0);
}
int kth(int x,int k){
	if(k<=t[ls(x)].size)return kth(ls(x),k);
	else if(k>t[ls(x)].size+t[x].cnt)return kth(rs(x),k-t[ls(x)].size-t[x].cnt);
	else return t[x].val;
}
int getrank(int val){
	splay(find(val),0);
	return t[ls(root)].size;
}
int getpre(int val){
	splay(find(val),0);
	if(t[root].val<val)return root;
	int x=ls(root);
	while(rs(x))x=rs(x);
	return x;
}
int getnxt(int val){
	splay(find(val),0);
	if(t[root].val>val)return root;
	int x=rs(root);
	while(ls(x))x=ls(x);
	return x;
}
void remove(int val){
	int pre=getpre(val),nxt=getnxt(val);
	splay(pre,0),splay(nxt,pre);
	if(t[ls(nxt)].cnt>1)--t[ls(nxt)].cnt,splay(ls(nxt),0);
	else ls(nxt)=0,splay(nxt,0);
}
void initial(){insert(-inf),insert(inf);}

int main(){
	initial();
	int opt,val,n;
    scanf("%d",&n);
    while(n--){
        scanf("%d%d",&opt,&val);
        switch(opt){
            case 1:insert(val);break;
            case 2:remove(val);break;
            case 3:printf("%d
",getrank(val));break;
            case 4:printf("%d
",kth(root,val+1));break;
            case 5:printf("%d
",t[getpre(val)].val);break;
            case 6:printf("%d
",t[getnxt(val)].val);break;
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/wzj-xhjbk/p/9941928.html