替罪羊树

替罪羊树

一种基于部分重建的自平衡二叉搜索树。在替罪羊树上,插入或删除节点的平摊最坏时间复杂度是(O(log n)),搜索节点的最坏时间复杂度是(O(log n))

我们定义一个平衡树因子(alpha)。对于替罪羊树的每个节点(t),需要满足(max(siz[ls],siz[rs]<alpha * size[t])),其中(ls,rs)分别是(t)的左儿子,右儿子。

通俗的来讲,就是要保证每一个节点的左右子树的大小都不超过它本身大小的(alpha)倍,否则就把这个节点及它的子树重构,使其满足这个性质

一般取(alpha=0.75),使其达到最佳性能

#include<bits/stdc++.h>
using namespace std;
const int N = 100005;

struct node 
{
	int l, r, val, siz, cnt;
}nod[N];

int n, cnt, root;
vector<int> p;

void pushup(int rt)
{
	int l = nod[rt].l, r = nod[rt].r;
	nod[rt].siz = nod[l].siz + nod[r].siz;
}

int build(int l, int r)//按照线段树的方法建树 
{
	if(l > r) return 0;
	int mid = (l + r) >> 1;
	nod[p[mid]].l = build(l, mid - 1);
	nod[p[mid]].r = build(mid + 1, r);
	pushup(p[mid]);
	return p[mid];
}

void dfs(int rt)//要保证大小关系 
{
	if(rt == 0) return;
	dfs(nod[rt].l);
	p.push_back(rt);
	dfs(nod[rt].r);
}

void rebuild(int &rt)
{
	if(rt == 0) return;
	if(nod[rt].siz * 0.75 < nod[nod[rt].l].siz || nod[rt].siz * 0.75 < nod[nod[rt].r].siz)
	{
		p.clear();
		p.push_back(-1);//为了保证下标从1开始 
		dfs(rt);
		rt = build(1, p.size() - 1);
	}
}

void insert(int &rt, int x)
{
	if(rt == 0)
	{
		rt = ++cnt;
		nod[rt].val = x;
		nod[rt].siz = nod[rt].cnt = 1;
		return; 
	}
	rebuild(rt);
	if(nod[rt].val == x)
	{
		nod[rt].cnt ++;
		nod[rt].siz ++;
		return;
	}
	if(nod[rt].val < x)
	{
		insert(nod[rt].r, x);
		pushup(rt);
		return;
	}
	if(nod[rt].val > x)
	{
		insert(nod[rt].l, x);
		pushup(rt);
		return;
	}
}

int delmin(int &rt)
{
	if(nod[rt].l)//向左儿子跳 
	{
		int ret = delmin(nod[rt].l);
		pushup(rt);
		return ret;
	}
	int ret = rt;
	rt = nod[rt].r;//传址符 
	return ret;
}

void del(int &rt,int x)
{
	if(nod[rt].val > x)
	{
		del(nod[rt].l, x);
		pushup(rt); 
	}
	if(nod[rt].val < x)
	{
		del(nod[rt].r, x);
		pushup(rt);
	}
	if(nod[rt].val == x)
	{
		if(nod[rt].cnt > 1)
		{
			nod[rt].cnt --;
			nod[rt].siz --;
			return;
		}
		if(nod[rt].l == 0)
		{
			rt = nod[rt].r;
			return;
		}
		if(nod[rt].r == 0)
		{
			rt = nod[rt].l;
			return;
		}
		int tmp = delmin(nod[rt].r);
		nod[rt].val = nod[tmp].val;
		nod[rt].cnt = nod[tmp].cnt;
		pushup(rt);
		return;
	}
}

int getk(int rt, int x)
{
	if(nod[rt].val == x) return nod[nod[rt].l].siz + 1;
	if(nod[rt].val < x) return nod[nod[rt].l].siz + nod[rt].cnt + getk(nod[rt].r, x);
	if(nod[rt].val > x) return getk(nod[rt].l, x);
}

int getkth(int rt, int x)
{
	if(nod[nod[rt].l].siz + 1 <= x && x <= nod[nod[rt].l].siz + nod[rt].cnt) return nod[rt].val;
	if(nod[nod[rt].l].siz + 1 > x) return getkth(nod[rt].l, x);
	if(nod[nod[rt].l].siz + nod[rt].cnt < x) return getkth(nod[rt].r, x-(nod[nod[rt].l].siz + nod[rt].cnt));
}

int getpre(int rt, int x)
{
	int p = rt, ans;
	while(p)
	{
		if(x <= nod[p].val) p = nod[p].l;
		else 
		{
			ans = p;
			p = nod[p].r;
		}
	}
	return ans;
}
int getsuc(int rt, int x)
{
	int p = rt, ans;
	while(p)
	{
		if(nod[p].val <= x) p = nod[p].r;
		else 
		{
			ans = p;
			p = nod[p].l;
		}
	}
	return ans;
}

int main()
{
	scanf("%d", &n);
	while(n --)
	{
		int opt, x;
		scanf("%d%d", &opt, &x);
		if(opt == 1) insert(root, x);
		if(opt == 2) del(root, x);
		if(opt == 3) printf("%d
", getk(root, x));
		if(opt == 4) printf("%d
", getkth(root, x));
		if(opt == 5) printf("%d
", nod[getpre(root, x)].val);
		if(opt == 6) printf("%d
", nod[getsuc(root, x)].val);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/lcezych/p/12266508.html