Splay

前言

2020 年学的最后一个算法!

然而到了现在才来填

有的人说 Splay 常数大,还难打。

但是这迟早是要学的:总会遇到 LCT

基本操作

约定

\(cnt_i\) :节点 \(i\) 重复的个数

\(val_i\):节点 \(i\) 的权值

\(sz_i\):节点 \(i\) 的子树大小

\(ch_{i,0/1}\):节点 \(i\) 的左右儿子

\(fa_i\):节点 \(i\) 的父亲

\(root\):当前根节点

\(tot\):总共节点数

\(\text{Get(x)}\):获得 \(x\) 是左子树还是右子树

\(\text{Up(x)}\):更新当前子树大小

int cnt[N],val[N],sz[N],fa[N],ch[N][2],root,tot;
// ! ch[x][0]:left
// ! ch[x][1]:right
inline int Get(int x) { return ch[fa[x]][1]==x; }
inline void Up(int x) { sz[x]=cnt[x]+sz[ch[x][0]]+sz[ch[x][1]]; }

旋转 Rotate

作用:把 x 旋转到 y 的位置并且保持 BST 性质
  1. y 是 z 的哪个儿子, x 就是 z 的哪个儿子
  2. x 是 y 的哪个儿子, y 就是 x 的那个儿子的兄弟
  3. x 是 y 的哪个儿子, y 的那个儿子就是 x 的那个儿子的兄弟
inline void Rot(int x) {
	register int y=fa[x],z=fa[y],k=Get(x);
	ch[z][Get(y)]=x,fa[x]=z;
	ch[y][k]=ch[x][k^1],fa[ch[x][k^1]]=y;
	ch[x][k^1]=y,fa[y]=x;
	Up(y),Up(x);
}

伸展 Splay

既然叫 Splay ,肯定有 Splay 这个操作。

作用:将 \(x\) 旋转到 \(goal\) 下面

最朴素的方法是一直 \(\text{Rotate}\),但是要考虑一种情况

如果是一条链,这样子 Splay 后仍然是一条链

只要先旋转父亲,可以保证链深度减半

inline void Splay(int x,int goal=0) {
	for(int y;(y=fa[x])^goal;Rot(x))
	if(fa[y]^goal)Rot(Get(x)^Get(y)?x:y);
	if(!goal)root=x;
}

正规操作

Find

找到最接近值为 \(x\) 的点,并伸展到根

inline void Find(int x) {
	if(!root)return;
	register int u=root;
	while(ch[u][x>val[u]] && x^val[u])
		u=ch[u][x>val[u]];
	Splay(u);
}

Insert

插入值为 \(x\) 的点

  1. 找到插入点
  2. 如果有重复,直接加对应 \(cnt\)
  3. 否则新加一个点
  4. 把新节点伸展到根
inline void Ins(int x) {
	register int u=root,fu=0;
	while(u && val[u]^x)fu=u,u=ch[u][x>val[u]];
	if(u)cnt[u]++;
	else {
		u=++tot;
		if(fu)ch[fu][x>val[fu]]=u;
		ch[u][0]=ch[u][1]=0;
		fa[u]=fu,val[u]=x;
		cnt[u]=sz[u]=1;
	}
	Splay(u);
}

前驱和后继

可以写成一个函数

因为只要先 \(\text{Find(x)}\),那么

前驱就是根的左子树里最大的一个

后继就是根的右子树里最小的一个

inline int Nxt(int x,int f) {
	Find(x);
	register int u=root;
	if(val[u]>x && f)return u;
	if(val[u]<x &&!f)return u;
	u=ch[u][f];
	while(ch[u][f^1])u=ch[u][f^1];
	return u;
}

Delete

删除值为 \(x\) 的点:有更快的写法,但蒟蒻不会

找到 \(x\) 的前驱和后继

\(\text{Splay(pre),Splay(suc,pre)}\)

好了现在直接删点即可

inline void Del(int x) {
	register int lst=Nxt(x,0),nxt=Nxt(x,1);
	Splay(lst),Splay(nxt,lst);
	register int del=ch[nxt][0];
	if(cnt[del]>1)--cnt[del],Splay(del);
	else ch[nxt][0]=0;
}

Kth

找到排名为 \(k\) 的节点权值

inline int Kth(int k) {
	register int u=root,sn=0;
	for(;;) {
		sn=ch[u][0];
		if(k>sz[sn]+cnt[u])
			k-=sz[sn]+cnt[u],u=ch[u][1];
		else if(sz[sn]>=k)u=sn;
		else return val[u];
	}
}

例 1

普通平衡树

Tips:\(\text{Splay}\)一般先加两个哨兵节点,即 \(+\infin,-\infin\)

防止 Splay 出锅

#include<bits/stdc++.h>
using namespace std;
const int N=100005,INF=2147483647;
int cnt[N],val[N],sz[N],fa[N],ch[N][2],root,tot;
// ! ch[x][0]:left
// ! ch[x][1]:right
inline int Get(int x) { return ch[fa[x]][1]==x; }
inline void Up(int x) { sz[x]=cnt[x]+sz[ch[x][0]]+sz[ch[x][1]]; }
inline void Rot(int x) {
	register int y=fa[x],z=fa[y],k=Get(x);
	ch[z][Get(y)]=x,fa[x]=z;
	ch[y][k]=ch[x][k^1],fa[ch[x][k^1]]=y;
	ch[x][k^1]=y,fa[y]=x;
	Up(y),Up(x);
}
inline void Splay(int x,int goal=0) {
	for(int y;(y=fa[x])^goal;Rot(x))
	if(fa[y]^goal)Rot(Get(x)^Get(y)?x:y);
	if(!goal)root=x;
}
inline void Find(int x) {
	if(!root)return;
	register int u=root;
	while(ch[u][x>val[u]] && x^val[u])
		u=ch[u][x>val[u]];
	Splay(u);
}
inline void Ins(int x) {
	register int u=root,fu=0;
	while(u && val[u]^x)fu=u,u=ch[u][x>val[u]];
	if(u)cnt[u]++;
	else {
		u=++tot;
		if(fu)ch[fu][x>val[fu]]=u;
		ch[u][0]=ch[u][1]=0;
		fa[u]=fu,val[u]=x;
		cnt[u]=sz[u]=1;
	}
	Splay(u);
}
inline int Nxt(int x,int f) {
	Find(x);
	register int u=root;
	if(val[u]>x && f)return u;
	if(val[u]<x &&!f)return u;
	u=ch[u][f];
	while(ch[u][f^1])u=ch[u][f^1];
	return u;
}
inline void Del(int x) {
	register int lst=Nxt(x,0),nxt=Nxt(x,1);
	Splay(lst),Splay(nxt,lst);
	register int del=ch[nxt][0];
	if(cnt[del]>1)--cnt[del],Splay(del);
	else ch[nxt][0]=0;
}
inline int Kth(int k) {
	register int u=root,sn=0;
	for(;;) {
		sn=ch[u][0];
		if(k>sz[sn]+cnt[u])
			k-=sz[sn]+cnt[u],u=ch[u][1];
		else if(sz[sn]>=k)u=sn;
		else return val[u];
	}
}
int T;
int main() {
	Ins(-INF);
	Ins(+INF);
	scanf("%d",&T);
	for(int opt,x;T--;) {
		scanf("%d%d",&opt,&x);
		if(opt==1)Ins(x);
		else if(opt==2)Del(x);
		else if(opt==3)Find(x),printf("%d\n",sz[ch[root][0]]);
		else if(opt==4)printf("%d\n",Kth(x+1));
		else if(opt==5)printf("%d\n",val[Nxt(x,0)]);
		else if(opt==6)printf("%d\n",val[Nxt(x,1)]);
	}
} 

不正常操作

区间翻转

由于原数列的顺序已经给定,就不能按照权值排序。考虑按照点的编号排序。

这样的话建树也方便了许多,\(org_{i}\) 代表 \(i\) 对应的值

int bui(int l,int r,int f) {
	if(l>r)return 0;
	register int mid=l+r>>1,u=++tot;
	fa[u]=f,sz[u]=1,vl[u]=org[mid];
	ch[u][0]=bui(l,mid-1,u);
	ch[u][1]=bui(mid+1,r,u);
	return Up(u),u;
}

对于翻转,考虑用 \(\text{Splay}\) 的性质,将 \(l-1\) 翻转到根,再将 \(r+1\) 翻转到 \(l-1\)

然后将 \([l,r]\) 对应子树打上懒标记即可。

翻转可以用第 K 大结合 \(\text{Splay}\)

输出直接中序遍历

#include<bits/stdc++.h>
using namespace std;
const int N=100005,inf=2147483647;
int fa[N],vl[N],ch[N][2],sz[N],tg[N],org[N],root,tot;
inline int Get(int x) { return ch[fa[x]][1]==x; }
inline void Up(int x) { sz[x]=1+sz[ch[x][0]]+sz[ch[x][1]]; }
inline void Rot(int x) {
	register int y=fa[x],z=fa[y],k=Get(x);
	ch[z][Get(y)]=x,fa[x]=z;
	ch[y][k]=ch[x][k^1],fa[ch[x][k^1]]=y;
	ch[x][k^1]=y,fa[y]=x;
	Up(y),Up(x);
}
inline void Splay(int x,int goal=0) {
	for(int y;(y=fa[x])^goal;Rot(x))
	if(fa[y]^goal)Rot(Get(x)^Get(y)?x:y);
	if(!goal)root=x;
}
int bui(int l,int r,int f) {
	if(l>r)return 0;
	register int mid=l+r>>1,u=++tot;
	fa[u]=f,sz[u]=1,vl[u]=org[mid];
	ch[u][0]=bui(l,mid-1,u);
	ch[u][1]=bui(mid+1,r,u);
	return Up(u),u;
}
inline void down(int x) {
	if(!tg[x])return;
	tg[ch[x][0]]^=1;
	tg[ch[x][1]]^=1;
	tg[x]=0;
	swap(ch[x][0],ch[x][1]);
}
inline int Kth(int x){
	register int u=root,sn=0;
	for(;;) {
		down(u),sn=ch[u][0];
		if(x==sz[sn]+1)return u;
		else if(x<=sz[sn])u=sn;
		else x-=sz[sn]+1,u=ch[u][1];
	}
}
inline void rev(int l,int r) {
	register int p=Kth(l-1),q=Kth(r+1),rv;
	Splay(p),Splay(q,p);
	rv=ch[root][1],rv=ch[rv][0],tg[rv]^=1;
}
void print(int x) {
	down(x);
	if(ch[x][0])print(ch[x][0]);
	if(-inf<vl[x] && vl[x]<inf)printf("%d ",vl[x]);
	if(ch[x][1])print(ch[x][1]);
}
int n,T; 
int main() {
	scanf("%d%d",&n,&T);
	org[1]=-inf,org[n+2]=inf;
	for(int i=1;i<=n;i++)org[i+1]=i;
	root=bui(1,n+2,0);
	for(int l,r;T--;) {
		scanf("%d%d",&l,&r);
		rev(l+1,r+1);
	}
	print(root);
}
原文地址:https://www.cnblogs.com/KonjakLAF/p/14398355.html