算法初探 平衡树

更新记录

【1】2020.08.14-17:01

  • 1.完善Splay内容

正文

Splay

Splay是一种没有用随机数函数的平衡树,它依靠伸展操作来维持平衡

不过也正是这样,导致其能维护数列的某些特殊的区间操作,例如区间反转

直接来说Splay的核心:rotate与splay操作

rotate就是旋转,又称之为上旋

就是将一个非根结点向上旋转,旋转前后都满足二叉搜索树的性质

那么易得一个左儿子n旋转到father的位置:

  • 其左儿子比n小,不动
  • 其右儿子比n大,但是比father小,连接到father的左儿子上

看个例子,考虑一棵完全二叉搜索树:

将其中结点2旋转到根上

按照之前的结论,将2转到根上,然后将3连接到5上

最后就是这样:

很简单,对吧?

所以我们很容易的就写出了rotate的代码

inline void rotate(int n){
	int fa=t[n].fa,grfa=t[t[n].fa].fa;
	bool np=confirm(n),fap=confirm(t[n].fa);
	connect(fa,t[n].son[np^1],np);
	connect(n,fa,np^1);
	connect(grfa,n,fap);
	update(fa);update(n);
}

在这里可能你不知道某些函数是干嘛的
没关系,这些函数因为太简单了被我放在了后面统一说明

splay就是多个rotate集合,用来将一个结点旋转到根上

考虑结点n,其父亲fa,父亲的父亲grfa

  • 如果三点一线:先旋转fa,再旋转n
  • 其他情况:旋转两次n

对于三点一线的情况,如果旋转两次n,极易出现链的情况且易被卡,考虑一条链即可

所以我们依然能够非常容易的写出代码

inline void splay(int n){
	while(t[n].fa){
		if(t[t[n].fa].fa)
			if(confirm(t[n].fa)==confirm(n)) rotate(t[n].fa);
			else rotate(n);
		rotate(n);
	}
	root=n;
}

好了这就是splay的核心操作了,是不是非常简单呢?

接下来是其他非主要函数的讲解

connect连边函数,将两个点连在一起

inline void connect(int up,int down,bool r){
	if(up) t[up].son[r]=down;
	t[down].fa=up;
}

update更新以此结点为根的子树的大小

inline void update(int n){
	if(n){
		t[n].size=t[n].num;
		if(t[n].son[0]) t[n].size+=t[t[n].son[0]].size;
		if(t[n].son[1]) t[n].size+=t[t[n].son[1]].size;
	}
}

confirm函数用来确认这个结点是父结点的左儿子还是右儿子

inline bool confirm(int n){
	return t[t[n].fa].son[1]==n;
}

clear用来清除一个结点的全部信息

inline void clear(int n){
	t[n].fa=t[n].num=t[n].size=t[n].v=t[n].son[0]=t[n].son[1]=0;
}

insert插入结点

  • 当树中没有结点时:新建结点
  • 当树中有结点的时候
    • 当要插入的结点已经存在,直接num+1即可
    • 不存在时,新建结点

具体细节:

inline void insert(int n){
	if(!root){
		root=++node;t[root].v=n;
		t[root].num=t[root].size=1;
		return;
	}
	int nd=root,fa=0;
	while(1){
		if(t[nd].v==n){
			t[nd].num+=1;
			update(nd);update(fa);
//更新
			splay(nd);
//旋转
			return;
		}
		fa=nd,nd=t[nd].son[n>t[nd].v];
		if(!nd){
			node+=1;t[node].v=n;
			t[node].num=t[node].size=1;
			t[node].fa=fa;
			t[fa].son[n>t[fa].v]=node;
//n>t[fa].v是根据二叉搜索树的性质来确定这个结点在哪里
			update(fa);splay(node);
			return;
		}
	}
}

numrank用来查看一个数的排名

依然考虑二叉搜索树的性质:

  • 如果这个数比这个结点小,那么搜索左子树
  • 如果这个数和这个结点一样大,那么它肯定比左子树的所有节点都大,sum加上左子树的size,之后+1就是排名
  • 如果这个数比这个结点大,那么sum加上左子树的size,加上这个结点的num,之后搜索右子树
inline int numrank(int n){
	int nd=root,sum=0;
	while(1){
		if(n<t[nd].v){
			nd=t[nd].son[0];
			continue;
		}
		sum+=t[t[nd].son[0]].size;
		if(n==t[nd].v){
			splay(nd);
			return sum+1;
		}
		sum+=t[nd].num;
		nd=t[nd].son[1];
	}
}

ranknum用来查询处于这个排名的数

  • 如果现在排名等于小于左子树的结点的个数,那么搜索左子树
  • 否则,如果现在排名小于左子树的结点的个数+这个结点的num,那么就是这个结点
  • 否则,排名减去左子树的size+结点的num,搜索右子树
inline int ranknum(int n){
	int nd=root;
	while(1){
		if(n<=t[t[nd].son[0]].size){
			nd=t[nd].son[0];
			continue;
		}
		else if(n<=t[t[nd].son[0]].size+t[nd].num){
			splay(nd);
			return t[nd].v;
		}
		else{
			n-=t[t[nd].son[0]].size+t[nd].num;
			nd=t[nd].son[1];
		}
	}
}

presuf用来求前趋与后继,本来是两个函数被我合并成了一个

以查前趋为例,插入要查询的n
此时n是根结点,前趋肯定是比n小
于是乎我们从左子树疯狂向右找
最后找到,之后删除n

inline int presuf(int n,bool y,bool r){
	if(y) insert(n);
	int nd=t[root].son[r];
	while(t[nd].son[r^1]) nd=t[nd].son[r^1];
	if(y) del(n);
	return nd;
}

del就是删除结点了

先把n转上来

  • 如果num>1那么直接-1就行了
  • 否则,如果没儿子,那说明就这一个结点,清空
  • 然后哪边缺root是哪边
  • 之后如果都有就转一下,连个边
inline void del(int n){
	numrank(n);
	int cp=root;bool more=1;
	if(t[root].num>1){
		more=0;t[root].num-=1;
		update(root);
	}
	else if(!t[root].son[0]&&!t[root].son[1]) root=0;
	else if(!t[root].son[0]){
		root=t[root].son[1];
		t[root].fa=0;
	}
	else if(!t[root].son[1]){
		root=t[root].son[0];
		t[root].fa=0;
	}
	else{
		splay(presuf(root,0,0));
		connect(root,t[cp].son[1],1);
	}
	if(more) clear(cp);
	update(root);
}

完整Splay代码:

#include<cstdio>
#include<iostream>
#define N 1000100
struct baltree{
	int fa,v,size,num,son[2];
}t[N];
int root,node,n,a,b;
inline void del(int n);
inline int read(){
	int sum=0,chs=1;char c=getchar();
	while(!isdigit(c)){
		if(c=='-') chs=-1;
		c=getchar();
	}
	while(isdigit(c)){
		sum=(sum<<1)+(sum<<3)+c-48;
		c=getchar();
	}
	return sum*chs;
}
inline void clear(int n){
	t[n].fa=t[n].num=t[n].size=t[n].v=t[n].son[0]=t[n].son[1]=0;
}
inline bool confirm(int n){
	return t[t[n].fa].son[1]==n;
}
inline void update(int n){
	if(n){
		t[n].size=t[n].num;
		if(t[n].son[0]) t[n].size+=t[t[n].son[0]].size;
		if(t[n].son[1]) t[n].size+=t[t[n].son[1]].size;
	}
}
inline void connect(int up,int down,bool r){
	if(up) t[up].son[r]=down;
	t[down].fa=up;
}
inline void rotate(int n){
	int fa=t[n].fa,grfa=t[t[n].fa].fa;
	bool np=confirm(n),fap=confirm(t[n].fa);
	connect(fa,t[n].son[np^1],np);
	connect(n,fa,np^1);
	connect(grfa,n,fap);
	update(fa);update(n);
}
inline void splay(int n){
	while(t[n].fa){
		if(t[t[n].fa].fa)
			if(confirm(t[n].fa)==confirm(n)) rotate(t[n].fa);
			else rotate(n);
		rotate(n);
	}
	root=n;
}
inline void insert(int n){
	if(!root){
		root=++node;t[root].v=n;
		t[root].num=t[root].size=1;
		return;
	}
	int nd=root,fa=0;
	while(1){
		if(t[nd].v==n){
			t[nd].num+=1;
			update(nd);update(fa);
			splay(nd);
			return;
		}
		fa=nd,nd=t[nd].son[n>t[nd].v];
		if(!nd){
			node+=1;t[node].v=n;
			t[node].num=t[node].size=1;
			t[node].fa=fa;
			t[fa].son[n>t[fa].v]=node;
			update(fa);splay(node);
			return;
		}
	}
}
inline int numrank(int n){
	int nd=root,sum=0;
	while(1){
		if(n<t[nd].v){
			nd=t[nd].son[0];
			continue;
		}
		sum+=t[t[nd].son[0]].size;
		if(n==t[nd].v){
			splay(nd);
			return sum+1;
		}
		sum+=t[nd].num;
		nd=t[nd].son[1];
	}
}
inline int ranknum(int n){
	int nd=root;
	while(1){
		if(n<=t[t[nd].son[0]].size){
			nd=t[nd].son[0];
			continue;
		}
		else if(n<=t[t[nd].son[0]].size+t[nd].num){
			splay(nd);
			return t[nd].v;
		}
		else{
			n-=t[t[nd].son[0]].size+t[nd].num;
			nd=t[nd].son[1];
		}
	}
}
inline int presuf(int n,bool y,bool r){
	if(y) insert(n);
	int nd=t[root].son[r];
	while(t[nd].son[r^1]) nd=t[nd].son[r^1];
	if(y) del(n);
	return nd;
}
inline void del(int n){
	numrank(n);
	int cp=root;bool more=1;
	if(t[root].num>1){
		more=0;t[root].num-=1;
		update(root);
	}
	else if(!t[root].son[0]&&!t[root].son[1]) root=0;
	else if(!t[root].son[0]){
		root=t[root].son[1];
		t[root].fa=0;
	}
	else if(!t[root].son[1]){
		root=t[root].son[0];
		t[root].fa=0;
	}
	else{
		splay(presuf(root,0,0));
		connect(root,t[cp].son[1],1);
	}
	if(more) clear(cp);
	update(root);
}
signed main(){
	n=read();
	while(n--){
		a=read(),b=read();
		if(a==1) insert(b);
		else if(a==2) del(b);
		else if(a==3) printf("%d\n",numrank(b));
		else if(a==4) printf("%d\n",ranknum(b));
		else if(a==5) printf("%d\n",t[presuf(b,1,0)].v);
		else printf("%d\n",t[presuf(b,1,1)].v);
	}
}

区间操作

例如反转区间 \([l,r]\)

那么将l的前趋转到根结点,r的后继转到根结点右方

此时r的后继的左方的子树就是区间 \([l,r]\)

打上标记,之后看见标记反转就可以啦

为什么是这样?

先来问两个问题:

  1. Splay旋转的特点?
  2. 归并排序的思想?

那么聪明的同学直接就能想到答案啦

这个子树在根结点和其父结点不被操作的时候是不会改变的

操作的时候呢?
标记就下传啦!!

归并排序的思想就是先将大区间整体反转,然后小区间反转......

此时一定抛弃二叉搜索树的思想,此时的Splay维护的是区间!

#include<iostream>
#include<cstdio>
#define N 1000100
const int INF=2147483647;
using namespace std;
struct baltree{
	int son[2],v,fa,size,ret;
}t[N];
int root,node=0,n,m,a,b;
inline void pd(int n){
	if(t[n].ret){
		t[n].ret=0;
		t[t[n].son[0]].ret^=1;t[t[n].son[1]].ret^=1;
		swap(t[n].son[0],t[n].son[1]);
	}
}
inline void update(int n){
	if(n){
		t[n].size=1;
		if(t[n].son[0]) t[n].size+=t[t[n].son[0]].size;
		if(t[n].son[1]) t[n].size+=t[t[n].son[1]].size;
	}
}
inline bool confirm(int n){
	return t[t[n].fa].son[1]==n;
}
inline void connect(int up,int down,bool r){
	if(up) t[up].son[r]=down;
	t[down].fa=up;
}
inline void rotate(int n){
	int fa=t[n].fa,grfa=t[fa].fa;
	bool np=confirm(n),fap=confirm(t[n].fa);
	connect(fa,t[n].son[np^1],np);
	connect(n,fa,np^1);
	connect(grfa,n,fap);
	update(fa),update(n);
}
inline void splay(int n,int p){
	while(t[n].fa!=p){
		if(t[t[n].fa].fa!=p){
			if(confirm(t[n].fa)==confirm(n)) rotate(t[n].fa);
			else rotate(n);
		}
		rotate(n);
	}
	if(!p) root=n;
}
inline void insert(int v){
	if(!root){
		root=++node;
		t[root].v=v;
		t[root].size=1;
		return;
	}
	int n=root,fa=0;
	while(1){
		fa=n,n=t[n].son[v>t[n].v];
		if(!n){
			t[++node].v=v;
			t[node].size=1;
			t[node].fa=fa;
			t[fa].son[v>t[fa].v]=node;
			update(fa);splay(node,0);
			return;
		}
	}
}
inline int ranknum(int n){
	int nd=root;
	while(1){
		pd(nd);
		if(n<=t[t[nd].son[0]].size){
			nd=t[nd].son[0];
			continue;
		}
		else if(n<=t[t[nd].son[0]].size+1){
			splay(nd,0);
			return t[nd].v;
		}
		else{
			n-=t[t[nd].son[0]].size+1;
			nd=t[nd].son[1];
		}
	}
}
inline void reverse(int l,int r,int lth,int rth){
	splay(lth,0);splay(rth,root);
	t[t[rth].son[0]].ret^=1;
}
inline void outdata(int n){
	pd(n);
	if(t[n].son[0]) outdata(t[n].son[0]);
	if(t[n].v>1&&t[n].v<=(::n+1)) printf("%d ",t[n].v-1);
	if(t[n].son[1]) outdata(t[n].son[1]);
}
signed main(){
	cin>>n>>m;
	for(int i=1;i<=n+2;i++)
		insert(i);
	for(int i=0;i<m;i++){
		cin>>a>>b;
		reverse(a,b,ranknum(a),ranknum(b+2));
	}
	outdata(root);
}

看完了板子,来看例题吧

序列终结者
这就是板子的变形啊

#pragma gcc optimize(2)
#pragma gcc optimize(3)
#pragma gcc optimize(-Ofast)
#include<iostream>
#include<cstdio>
#define N 1001000
#define INF 0x3f3f3f3f
int n,m,node,rt,a,b,c,d,rth;
struct baltree{
	int size,son[2],fa,v,ret,add,maxn;
}t[N];
inline int max(int a,int b){return a>b?a:b;}
inline void pu(int n){
	t[n].size=1;
	if(t[n].son[0]) t[n].size+=t[t[n].son[0]].size;
	if(t[n].son[1]) t[n].size+=t[t[n].son[1]].size;
	t[n].maxn=t[n].v;
	if(t[n].son[0]) t[n].maxn=max(t[n].maxn,t[t[n].son[0]].maxn);
	if(t[n].son[1]) t[n].maxn=max(t[n].maxn,t[t[n].son[1]].maxn);
}
inline void pd(int n){
	if(t[n].add){
		t[t[n].son[0]].add+=t[n].add,
		t[t[n].son[0]].v+=t[n].add;
		t[t[n].son[1]].add+=t[n].add,
		t[t[n].son[1]].v+=t[n].add;
		t[t[n].son[0]].maxn+=t[n].add,
		t[t[n].son[1]].maxn+=t[n].add;		
	}
	if(t[n].ret){
		t[t[n].son[0]].ret^=1;
		t[t[n].son[1]].ret^=1;
		std::swap(t[n].son[0],t[n].son[1]);
	}
	t[n].ret=t[n].add=0;
}
inline bool confirm(int n){
	return t[t[n].fa].son[1]==n;
}
inline void connect(int up,int down,bool r){
	if(up) t[up].son[r]=down;
	t[down].fa=up;
}
inline void rotate(int n){
	int fa=t[n].fa,grfa=t[fa].fa;
	bool np=confirm(n),fap=confirm(fa);
	connect(fa,t[n].son[np^1],np);
	connect(n,fa,np^1);
	connect(grfa,n,fap);
	pd(n),pd(fa);
	pu(fa),pu(n);
}
inline void splay(int n,int p){
	while(t[n].fa!=p){
		if(t[t[n].fa].fa!=p){
			if(confirm(t[n].fa)==confirm(n)) rotate(t[n].fa);
			else rotate(n);
		}
		rotate(n);
	}
	if(!p) rt=n;
}
inline void insert(int v){
	if(!rt){
		rt=++node;
		t[rt].v=v;
		t[rt].maxn=v;
		t[rt].size=1;
		return;
	}
	int n=rt,fa=0;
	while(1){
		fa=n,n=t[n].son[v>t[n].v];
		if(!n){
			t[++node].v=v;
			t[node].size=1;
			t[node].fa=fa;
			t[node].maxn=v;
			t[fa].son[v>t[n].v]=node;
			pu(fa);splay(node,0);
			return;
		}
	}
}
inline int ranknode(int rank){
	int n=rt;
	while(1){
		pd(n);
		if(rank<=t[t[n].son[0]].size){
			n=t[n].son[0];
			continue;
		}
		else if(rank<=t[t[n].son[0]].size+1){
			splay(n,0);return n;
		}
		else{
			rank-=t[t[n].son[0]].size+1;
			n=t[n].son[1];
		}
	}
}
inline void reverse(int l,int r){
	t[t[rth].son[0]].ret^=1;
}
inline void add(int l,int r,int v){
	t[t[rth].son[0]].add+=v;
	t[t[rth].son[0]].v+=v;
	t[t[rth].son[0]].maxn+=v;
}
inline void querymax(int l,int r){
	printf("%d\n",t[t[rth].son[0]].maxn);
}
int main(){
	scanf("%d%d",&n,&m);
	for(register int i=0;i<n;++i)
		insert(0);
	insert(INF),insert(-INF);
	for(register int i=0;i<m;++i){
		scanf("%d%d%d",&a,&b,&c);
		rth=ranknode(c+2);
		splay(ranknode(b),0);splay(rth,rt);
		if(a==1){
			scanf("%d",&a);
			add(b,c,a);
		}
		else if(a==2) reverse(b,c);
		else if(a==3) querymax(b,c);
	}
}
原文地址:https://www.cnblogs.com/zythonc/p/13493611.html