splay与他的解析
二叉搜索树
我们搞一棵树,保证左子树所有点的权值比父亲小,右子树所有点权值比父亲大
显然这个玩意可以快速查询一个数存不存在,排名啦啥的
插入的时候直接顺序造节点,删除的时候,断开重连是件很愚蠢的事情
应该把删除节点和他右子树最左边那个或者左子树最右边那个交换,销毁它
这样有什么问题呢?
如果坑逼出题人给你搞成了链,直接起飞成$ O(n)$了
这不就挂了
那么该怎么搞呢
伸展树splay出现了
splay,旋转,自我调整,
旋转
本图片引用自liuzhangfeiabc课件
我们可以看到,转x,就是让f变成x的右儿子,x的右儿子变成f的左儿子,x的右儿子变成f,x代替f
void ro(int x){
int f=po[x].fa;
int gf=po[f].fa;
int xp=ident(x);
int fp=ident(f);
connect(po[x].son[xp^1],f,xp);
connect(f,x,xp^1);
connect(x,gf,fp);
update(f);
update(x);
}
但是出题人还能接着卡你
怎么办呢,我们一次转俩,直接强行链改树
splay
void splay(int x,int to){
to=po[to].fa;
while(po[x].fa!=to){
int f=po[x].fa;
if(po[f].fa==to) ro(x);
else if(ident(x)==ident(f)) ro(f),ro(x);
else ro(x),ro(x);
}
}.
这又是在干什么呢
我们可以把一个节点移到你想去的任何位置如果爹,爹的爹和儿子共线,就要先转爹,不然就转两次儿子,就可以成功的压成树了
更具体的操作
插入 这里没什么大坑,就是记得最后splay一下
void insert(int x){
int now=po[0].son[1];
if(now==0){
creat(x,0);
po[0].son[1]=tot;
}else{
while(1){
po[now].sum++;
if(po[now].val==x){
po[now].cnt++;
splay(now,po[0].son[1]);
return ;
}
int nex=x<po[now].val?0:1;
if(!po[now].son[nex]){
int p=creat(x,now);
po[now].son[nex]=p;
splay(p,po[0].son[1]);
return ;
}
now=po[now].son[nex];
}
}
}
寻找排名
也很好理解,记得splay
int find(int x){
int now=po[0].son[1];
while(1){
if(!now) return 0;
if(po[now].val==x) {
splay(now,po[0].son[1]);
return now;
}
int nex=x<po[now].val?0:1;
now=po[now].son[nex];
}
}
删除
还是很好解决的,先找到它,如果它是叶子,直接删掉,如果它没有左儿子的,就直接把右儿子接上去,如果都有,就找它右子树最左的那个,splay进行交换,销毁
void del(int x){
int pos=find(x);
if(!pos) return ;
if(po[pos].cnt>1){
po[pos].cnt--;
po[pos].sum--;
return ;
}else{
if(!po[pos].son[0]&&!po[pos].son[1]) {
po[0].son[1]=0;
return ;
}else if(!po[pos].son[0]){
po[0].son[1]=po[pos].son[1];
po[po[0].son[1]].fa=0;
return ;
} else{
int le=po[pos].son[0];
while(po[le].son[1]) le=po[le].son[1];
splay(le,po[pos].son[0]);
connect(po[pos].son[1],le,1);
connect(le,0,1);
update(le);
}
}
}
按照排名找数
同理
找前缀
记得一直取max
后缀
和前缀反着
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cstring>
#define int long long
using namespace std;
const int maxn=1e6+10;
const int mod=10007;
const int inf=1e9+10;
struct p{
int fa;
int son[2];
int val;
int cnt;
int sum;
}po[maxn];
int tot;
int n,opt;int x;
int ps;
inline int read(){
int res=0,k=1;
char c=getchar();
while(!isdigit(c)){
if(c=='-')k=-1;
c=getchar();
}
while(isdigit(c)){
res=(res<<1)+(res<<3)+c-48;
c=getchar();
}
return res*k;
}
void update(int x){
po[x].sum=po[po[x].son[0]].sum+po[po[x].son[1]].sum+po[x].cnt;
}
int ident(int x){
return po[po[x].fa].son[0]==x?0:1;
}
void connect(int x,int fa,int s){
po[fa].son[s]=x;
po[x].fa=fa;
}
void ro(int x){
int f=po[x].fa;
int gf=po[f].fa;
int xp=ident(x);
int fp=ident(f);
connect(po[x].son[xp^1],f,xp);
connect(f,x,xp^1);
connect(x,gf,fp);
update(f);
update(x);
}
void splay(int x,int to){
to=po[to].fa;
while(po[x].fa!=to){
int f=po[x].fa;
if(po[f].fa==to) ro(x);
else if(ident(x)==ident(f)) ro(f),ro(x);
else ro(x),ro(x);
}
}
int creat(int val,int f){
++tot;
po[tot].fa=f;
po[tot].sum=po[tot].cnt=1;
po[tot].val=val;
return tot;
}
void insert(int x){
int now=po[0].son[1];
if(now==0){
creat(x,0);
po[0].son[1]=tot;
}else{
while(1){
po[now].sum++;
if(po[now].val==x){
po[now].cnt++;
splay(now,po[0].son[1]);
return ;
}
int nex=x<po[now].val?0:1;
if(!po[now].son[nex]){
int p=creat(x,now);
po[now].son[nex]=p;
splay(p,po[0].son[1]);
return ;
}
now=po[now].son[nex];
}
}
}
int find(int x){
int now=po[0].son[1];
while(1){
if(!now) return 0;
if(po[now].val==x) {
splay(now,po[0].son[1]);
return now;
}
int nex=x<po[now].val?0:1;
now=po[now].son[nex];
}
}
void del(int x){
int pos=find(x);
if(!pos) return ;
if(po[pos].cnt>1){
po[pos].cnt--;
po[pos].sum--;
return ;
}else{
if(!po[pos].son[0]&&!po[pos].son[1]) {
po[0].son[1]=0;
return ;
}else if(!po[pos].son[0]){
po[0].son[1]=po[pos].son[1];
po[po[0].son[1]].fa=0;
return ;
} else{
int le=po[pos].son[0];
while(po[le].son[1]) le=po[le].son[1];
splay(le,po[pos].son[0]);
connect(po[pos].son[1],le,1);
connect(le,0,1);
update(le);
}
}
}
int rak(int x){
int now=po[0].son[1];
int ans=0;
while(1){
if(po[now].val==x) {
// splay(now,po[0].son[1]);
ans+=po[po[now].son[0]].sum+1;
splay(now,po[0].son[1]);;
return ans;
}
int nex=x<po[now].val?0:1;
if(nex==1) ans =ans+po[po[now].son[0]].sum+po[now].cnt;
now=po[now].son[nex];
}
}
int kth(int x){
int now=po[0].son[1];
while(1){
int re=po[now].sum-po[po[now].son[1]].sum;
if(po[po[now].son[0]].sum<x&&x<=re){
return po[now].val;
}
if(x<re) now=po[now].son[0];
else now =po[now].son[1],x-=re;
}
}
int lower(int x){
int now=po[0].son[1];
int ans=-inf;
while(now){
if(po[now].val<x) ans=max(ans,po[now].val);
int nex=x<=po[now].val?0:1;//等于也要往左走
now=po[now].son[nex];
}
return ans;
}
int upper(int x){
int now=po[0].son[1];
int ans=inf;
while(now){
if(po[now].val>x) ans=min(ans,po[now].val);
int nex=x<po[now].val?0:1;
now=po[now].son[nex];
}
return ans;
}
signed main(){
n=read();
while(n--){
opt=read();
x=read();
if(opt==1) insert(x);
else if(opt==2) del(x);
else if(opt==3) printf("%d
",rak(x));
else if(opt==4) printf("%d
",kth(x));
else if(opt==5) printf("%d
",lower(x));
else if(opt==6) printf("%d
",upper(x));
}
return 0;
}