平衡树

平衡树

概述

一种数据结构。代码巨长。

其实平衡树的思想挺简单的,代码也不难写。

平衡树满足的性质:

(1、)左儿子权值小于父亲,右儿子权值大于父亲

(2、)左右儿子分别是平衡树

若仅是这样,很容易被毒瘤出题人卡成链,所以我们再人为的(虽然之前的性质也是人为的)给他加上一个性质(k),让这棵树不仅是权值满足上述性质,(k)也满足上述性质。

(k)是随机化出来的。我们可以依据(k)通过旋转改变树高,这样复杂度就变低了。期望时间复杂度是(O(nlogn))的。

例题

普通平衡树

题目描述

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

(1.)插入(x)

(2.)删除(x)数(若有多个相同的数,因只删除一个)

(3.)查询(x)数的排名(排名定义为比当前数小的数的个数(+1)。若有多个相同的数,因输出最小的排名)

(4.)查询排名为(x)的数

(5.)(x)的前驱(前驱定义为小于(x),且最大的数)

(6.)(x)的后继(后继定义为大于(x),且最小的数)

输入输出格式

输入格式:

第一行为(n),表示操作的个数,下面(n)行每行有两个数(opt)(x)(opt)表示操作的序号( (1 leq opt leq 6))

输出格式:

对于操作(3,4,5,6)每行输出一个数,表示对应答案

懒得放样例了= =

定义

struct szh{
	int v, k, ls, rs, su, si, f;
    //权值 k 左儿子 右儿子 v出现的次数 树的大小 父亲;
	szh(){v = k = -inf, ls = rs = -1, su = si = 0, f = -1;}
}tr[100005];

初始化

我们搞一个权值炒鸡大,(k)炒鸡小的炒鸡点来当做根。

void T_begin(){
	tr[0].v = inf, tr[0].k = -inf, tr[0].su = tr[0].si = 1;
}

添加

要加入一个数,首先我们要找到他应该位于哪个位置。按照性质,只要左右判断一下就好啦。

void add(int u, int v){ //寻找添加点  
	Dier &t = treap[u];
	if(t.v == -INF){build(u, v); return;} //如果这个节点不存在的话,就新建一个
	if(t.v == v){add(u); return;} //如果当前点就是我们要找的点,就给这个点的计数器加一
    //以上两个函数下面会说
	if(t.v > v){ //若要加入的值大于当前点,说明它应该在它的左子树里
		if(t.ls == -1) t.ls = ++cnt, treap[cnt].f = u; //若左子树没有,就新建一个
		add(t.ls, v); //递归左子树
	}
	else{ //若大于它,就在它的右子树里
		if(t.rs == -1) t.rs = ++cnt, treap[cnt].f = u; //没有就新建
		add(t.rs, v); //递归右子树
	}
	turn(u), updata(u); //旋转,更新,一会讲
}

新建一个子树非常容易,只要记录一下(v、sum、size、k)就好啦

void build(int u, int v){ //初始化 
	treap[u].v = v, treap[u].size = treap[u].sum = 1, treap[u].k = rand();
}

若这个点出现多次,就++计数器

void add(int u){ //添加 
	++treap[u].sum, ++treap[u].size;
}

喜欢压行的可以把这三个函数变成一个

(updata)

更新节点信息的时候,只需要修改(size)

void updata(int u){ //更新节点信息 
	Dier &t = treap[u];
	t.size = t.sum;
	if(t.ls != -1) t.size += treap[t.ls].size; //若左子树不为空
	if(t.rs != -1) t.size += treap[t.rs].size; //若右子树不为空
}

旋转

旋转是按照我们随机化出来的(k)的大小操作的。每加入一个值,我们就看一下这棵树需不需要旋转。显然,每次我们只需要旋转一次就够了。证明脑补。

每次旋转,分为左旋和右旋。这里的方向与平时我们所说的方向相反,即左旋为将左儿子作为根,原根的左儿子变成新根的右儿子,右旋相反。

void turn(int u){ //旋转 
	Dier &t = treap[u];
	if((t.ls != -1) && (t.k > treap[t.ls].k)) left_turn(u);
	else if((t.rs != -1) && (t.k > treap[t.rs].k)) right_turn(u);
}

这是判断该左旋还是右旋

void left_turn(int u){ //向左转 
	Dier &t = treap[u];
	int f = t.f, newt = t.ls;
	t.ls = treap[newt].rs; //原根的左儿子连到新根的右儿子上
	if(treap[newt].rs != -1) treap[treap[newt].rs].f = u; //更新爸爸
	t.f = newt, treap[newt].rs = u; //原根的父亲连到新根上,新根的右儿子是原根
	if(treap[f].ls == u) treap[f].ls = newt, treap[newt].f = f; //将原根的爸爸与新根相连
	else treap[f].rs = newt, treap[newt].f = f;
	updata(u), updata(newt); //更新一下,注意顺序
}
void right_turn(int u){ //向右转 
	Dier &t = treap[u];
	int f = t.f, newt = t.rs;
	t.rs = treap[newt].ls; //原根的右儿子即新根的左儿子
	if(treap[newt].ls != -1) treap[treap[newt].ls].f = u;
	t.f = newt, treap[newt].ls = u; //原根与新根相连
	if(treap[f].ls == u) treap[f].ls = newt, treap[newt].f = f; //原根的爸爸与新根相连
	else treap[f].rs = newt, treap[newt].f = f;
	updata(u), updata(newt); //更新
}

删除

同添加,先找到点,再删除

void del(int u, int v){ //寻找删除点 
	Dier &t = treap[u];
	if(t.v == v){del(u); return;} //若就是当前点,直接删除
	if(t.v > v) del(t.ls, v); //递归左儿子
	else del(t.rs, v); //递归右儿子
	updata(u); //不要忘记更新
}
void del(int u){ //删除点 
	if(treap[u].sum != 1) --treap[u].size, --treap[u].sum; //如果这个点出现过很多次,计数器--
	else end(u); //否则,我们要将他旋转到叶子结点再删除
}

旋转到叶子结点,就按照旋转的规则,转下去就好了

void end(int u){ //将某个点旋转到叶子结点 
	Dier &t = treap[u];
	t.k = INF;
	while(t.ls != -1 || t.rs != -1) //只要她还有儿子
		if(t.ls != -1) //若有左儿子
			if(t.rs != -1) //若也有右儿子
				if(treap[t.ls].k < treap[t.rs].k) left_turn(u); //若左儿子的k小,左旋
				else right_turn(u); //反正右旋
			else left_turn(u); //反之左旋
		else right_turn(u); //反之右旋
	if(treap[t.f].ls == u) treap[t.f].ls = -1; //删除
	else treap[t.f].rs = -1;
	for(int i = t.f; i != -1; i = treap[i].f) updata(i); //更新这条路径上的所有点
}

查询

查询(x)的排名

int rak(int u,int k){ //查询x数的排名
	Dier &t = treap[u];
	if(t.v == k) return treap[t.ls].size + 1; //根据定义
	if(t.v > k) return rak(t.ls, k);
	return rak(t.rs, k) + treap[t.ls].size + t.sum;
}

查询排名为(k)的数

int ask_rak(int u, int k){ //查询排名为x的数
	Dier &t = treap[u];
	if(treap[t.ls].size >= k) return ask_rak(t.ls, k);
	int s = treap[t.ls].size + t.sum;
	if(s >= k) return t.v;
	return ask_rak(t.rs, k - s);
}

找第一个大于他的数

int ask_upper(int u,int k){ //找第一个小于他的数 
	if(u == -1) return -INF;
	Dier &t = treap[u];
	if(t.v < k) return max(t.v, ask_upper(t.rs, k));
	return ask_upper(t.ls, k);
}

找第一个小于他的数

int ask_lower(int u, int k){//找第一个大于他的数 
	if(u == -1) return INF;
	Dier &t = treap[u];
	if(t.v > k) return min(t.v, ask_lower(t.ls, k));
	return ask_lower(t.rs, k);
}

完整代码

#include <iostream>
#include <cstdlib>
#include <cstdio>
using namespace std;
long long read(){
	long long x = 0; int f = 0; char c = getchar();
	while(c < '0' || c > '9') f |= c == '-', c = getchar();
	while(c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
	return f? -x:x;
}

const int inf = 2147483647;
int n, cnt;
struct szh{
	int v, k, ls, rs, su, si, f;
	szh(){v = k = -inf, ls = rs = -1, su = si = 0, f = -1;}
}tr[100005];

void T_begin(){
	tr[0].v = inf, tr[0].k = -inf, tr[0].su = tr[0].si = 1;
}

void u_d(int u){
	szh &t = tr[u];
	t.si = t.su;
	if(t.ls != -1) t.si += tr[t.ls].si;
	if(t.rs != -1) t.si += tr[t.rs].si;
}
void l_t(int u){
	szh &t = tr[u];
	int f = t.f, nt = t.ls;
	t.ls = tr[nt].rs;
	if(tr[nt].rs != -1) tr[tr[nt].rs].f = u;
	t.f = nt, tr[nt].rs = u;
	if(tr[f].ls == u) tr[f].ls = nt, tr[nt].f = f;
	else tr[f].rs = nt, tr[nt].f = f;
	u_d(u), u_d(nt);
}
void r_t(int u){
	szh &t = tr[u];
	int f = t.f, nt = t.rs;
	t.rs = tr[nt].ls;
	if(tr[nt].ls != -1) tr[tr[nt].ls].f = u;
	t.f = nt, tr[nt].ls = u;
	if(tr[f].ls == u) tr[f].ls = nt, tr[nt].f = f;
	else tr[f].rs = nt, tr[nt].f = f;
	u_d(u), u_d(nt);
}
void turn(int u){
	szh &t = tr[u];
	if(t.ls != -1 && tr[t.ls].k < t.k) l_t(u);
	else if(t.rs != -1 && tr[t.rs].k < t.k) r_t(u);
}

void build(int u, int v){
	tr[u].v = v, tr[u].si = tr[u].su = 1, tr[u].k = rand();
}
void add(int u){
	++tr[u].su, ++tr[u].si;
}
void add(int u, int v){
	szh &t = tr[u];
	if(t.v == -inf){build(u, v); return;}
	if(t.v == v){add(u); return;}
	if(t.v > v){
		if(t.ls == -1) t.ls = ++cnt, tr[cnt].f = u;
		add(t.ls, v);
	}
	else{
		if(t.rs == -1) t.rs = ++cnt, tr[cnt].f = u;
		add(t.rs, v);
	}
	turn(u), u_d(u);
}

void end(int u){
	szh &t = tr[u];
	while(t.ls != -1 || t.rs != -1)
		if(t.ls != -1)
			if(t.rs != -1)
				if(tr[t.ls].k < tr[t.rs].k) l_t(u);
				else r_t(u);
			else l_t(u);
		else r_t(u);
	if(tr[t.f].ls == u) tr[t.f].ls = -1;
	else tr[t.f].rs = -1;
	for(int i = t.f; ~i; i = tr[i].f) u_d(i);
}
void del(int u){
	szh &t = tr[u];
	if(t.su != 1) t.su--, t.si--;
	else end(u);
}
void del(int u, int v){
	szh &t = tr[u];
	if(t.v == v){del(u); return;}
	if(t.v > v) del(t.ls, v);
	else del(t.rs, v);
	u_d(u);
}

int rak(int u, int k){
	szh &t = tr[u];
	if(t.v == k) return tr[t.ls].si + 1;
	if(t.v > k) return rak(t.ls, k);
	return rak(t.rs, k) + tr[t.ls].si + t.su;
}
int a_r(int u, int k){
	szh &t = tr[u];
	if(tr[t.ls].si >= k) return a_r(t.ls, k);
	int s = tr[t.ls].si + t.su;
	if(s >= k) return t.v;
	return a_r(t.rs, k - s);
}
int a_u(int u, int k){
	if(u == -1) return -inf;
	szh &t = tr[u];
	if(t.v >= k) return a_u(t.ls, k);
	return max(t.v, a_u(t.rs, k));
}
int a_l(int u, int k){
	if(u == -1) return inf;
	szh &t = tr[u];
	if(t.v > k) return min(t.v, a_l(t.ls, k));
	return a_l(t.rs, k);
}

int main(){
	n = read();
	srand(37022059);
	T_begin();
	while(n--){
		int a, b; a = read(); b = read();
		switch(a){
			case 1: add(0, b); break;
			case 2: del(0, b); break;
			case 3: printf("%d
", rak(0, b)); break;
			case 4: printf("%d
", a_r(0, b)); break;
			case 5: printf("%d
", a_u(0, b)); break;
			case 6: printf("%d
", a_l(0, b)); break;
		}
	}
	return 0;
}

(144)行,是我写过的最长的代码。

原文地址:https://www.cnblogs.com/kylinbalck/p/9909746.html