平衡树SPLAY

平衡树SPLAY

平衡树这个东西一点用没有,非常有用,而且锻炼码力,是个非常好的模板题!

那么SPLAY这个东西 十分的毒瘤 我只调了一上午就调出来了!(我真棒)

首先 我们要知道平衡树是一棵二叉查找树 他可以处理:加点,删点,前驱,后继,num的排名,某排名的num,合并两棵平衡树,分离两棵平衡树

关于其原理 网上有很多 而且比较简单

对于splay来说 原理并不是最难的 码代码才是最难的!

那么就提一提细节吧:

当这一棵树为空的时候,(tot)不一定=0,而(root=0)

当这一棵树为空的时候,如果新加入一个节点,那么没有必要对于这个节点进行splay操作,(反正我的代码会死循环QAQ)

在rotate完了时候 记得update呀

对于该有返回值的函数 一定要返回值!(可怜孩子调了半天)

还有就是在:(insert(build),find,dele,oprank)函数里面要(splay)

注意:在我的函数(rank)中 我只能找到在平衡树中有的(val)的排名 而不存在的(val)这个函数是查不到的

注释掉的代码也可以查询rank 但是 好丑啊!!

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
const int maxn=150005;
const int INF=2147480000;
int n,op,num;
inline int read()
{
	int x=0,f=1;
	char c=getchar();
	while(c>57||c<48)
	{
	    if(c=='-') f=0;
         c=getchar();
	}
	while(c>47&&c<58)	
	x=(x<<3)+(x<<1)+(c&15),c=getchar();
   	return f?x:~x+1;
}
struct wow{
	int chi[2],size,v,fa,times;
}node[maxn];
int tot,tot_size;
#define root node[0].chi[1]
int identify(int x){
	return (x==node[node[x].fa].chi[0])?0:1;
}
void connect(int x,int father,int son){
	node[x].fa=father;
	node[father].chi[son]=x;
}
void update(int x){
	node[x].size=node[node[x].chi[0]].size+node[node[x].chi[1]].size+node[x].times;
}
void rotate(int x){
	int y=node[x].fa;
	int mroot=node[y].fa;
	int mrootson=identify(y);
	int yson=identify(x);
	int b=node[x].chi[yson^1];
	connect(b,y,yson);connect(y,x,yson^1);connect(x,mroot,mrootson);
	update(y);update(x);
}
void splay(int at,int to){
	to=node[to].fa;
	while(node[at].fa!=to){
		int up=node[at].fa;
		if(node[up].fa==to)	rotate(at);
		else if(identify(at)==identify(up)){
			rotate(up);rotate(at);
         }
		else{
			rotate(at);rotate(at);
		}
	}
}
	
int create_point(int v,int father){
	tot++;
	node[tot].fa=father;node[tot].v=v;
	node[tot].size=1;node[tot].times=1;
	return tot;
}
void insert(int v){
	tot_size++;
	if(root==0){
		root=1;create_point(v,0);
	}
	else{
		int now=root;
		while(1){
			node[now].size++;
			if(v==node[now].v){
				node[now].times++;
				splay(now,root);
				return ;
			}
			int nxt=(v<node[now].v)?0:1;
			if(!node[now].chi[nxt]){
				int p=create_point(v,now);node[now].chi[nxt]=p;
				splay(p,root);
				return ;
			}
			now=node[now].chi[nxt];	
         }
	}
}
int find(int v){
	int now=root;
	while(1){
		if(node[now].v==v){
			splay(now,root);return now;
		}
		int nxt=(v<node[now].v)?0:1;
		if(!node[now].chi[nxt])	return 0;
		now=node[now].chi[nxt];
	}
}
void dele(int v){
	int pos=find(v);
	if(!pos) return ;
	tot_size--;
	if(node[pos].times>1){
		node[pos].times--;node[pos].size--;
	}
	else{
		if(!node[pos].chi[0] && !node[pos].chi[1])	root=0;
		else if(!node[pos].chi[0]){
			root=node[pos].chi[1];
			node[root].fa=0;
		}
		else{
			int left=node[pos].chi[0];
			while(node[left].chi[1]) left=node[left].chi[1];
			splay(left,node[pos].chi[0]);
			int right=node[pos].chi[1];
			connect(right,left,1);connect(left,0,1);
			update(left);
		}
	}
	return ;
}
int rank(int v){
	/*
	int ans=0,now=root;
	while(1){
		if(node[now].v==v){
			splay(now,root);
			return ans+node[node[now].chi[0]].size+1;
		} 
		if(!now)	return 0;
		if(v<node[now].v) now=node[now].chi[0];
		else{
			ans+=node[node[now].chi[0]].size+node[now].times;
			now=node[now].chi[1];
		}
	}*/
	int pos=find(v);
   	return node[node[pos].chi[0]].size+1;
}
int oprank(int x){
	int sum=0,now=root;
	if(x>tot_size)	return 0;
	while(1){
		int mleftsum=node[now].size-node[node[now].chi[1]].size;
		if(x>node[node[now].chi[0]].size && x<=mleftsum) break;
		if(x<mleftsum)	now=node[now].chi[0];
		else{
			x-=mleftsum;
			now=node[now].chi[1];
		} 
	}
	splay(now,root);
	return node[now].v;
}
int upper(int v){
	int now=root,ans=INF;
	while(now){
		if(node[now].v>v && node[now].v<ans) ans=node[now].v;
		if(v<node[now].v) now=node[now].chi[0];
		else now=node[now].chi[1];
	}
	return ans;
}
int lower(int v){
	int now=root,ans=-INF;
	while(now){
		if(node[now].v<v && node[now].v>ans) ans=node[now].v;
		if(v>node[now].v) now=node[now].chi[1];
		else now=node[now].chi[0];
	}
	return ans;
}
int main(){
	n=read();
	for(int i=1;i<=n;i++){
		 op=read();num=read();
		 if(op==1)
		 	insert(num);
		 else if(op==2)
		 	dele(num);
		 else if(op==3)
		 	printf("%d
",rank(num));
		 else if(op==4)
		 	printf("%d
",oprank(num));
		 else if(op==5)
		 	printf("%d
",lower(num));
		 else
		 	printf("%d
",upper(num));
	}
}
原文地址:https://www.cnblogs.com/mendessy/p/11755771.html