普通平衡树(AVL树)

某辣鸡考研党复习到平衡树时突然心血来潮,想自己实现一下 AVL 树 QAQ。快一年没敲代码了,码力下降严重,断断续续写了好久 QAQ。写了快 300 行,不过全凭自己的感觉写的,也算是完成了当年没完成的心愿吧(自己独立写出一种平衡树)。

代码:

#include <bits/stdc++.h>
// Shorthand accessors for a node pool named `tr`, which must be in scope
// wherever these macros are expanded:
//   ls(pos) / rs(pos) -> child slots ch[0] / ch[1] of node `pos`
//   Fa(pos)           -> parent slot of node `pos`
#define ls(pos) tr[pos].ch[0]
#define rs(pos) tr[pos].ch[1]
#define Fa(pos) tr[pos].fa
using namespace std;
// Capacity constant for fixed-size node pools.
const int maxn = 101010;
/**
 * AVL tree over int keys with duplicate counting, stored in an index-based pool.
 *
 * Nodes live in `tr` and refer to each other by index; index 0 is a permanent
 * all-zero sentinel that stands for "no node" (its dep/sz/cnt of 0 let the
 * maintenance code read a missing child without branching).
 *
 * Query helpers report through members instead of return values, so callers
 * must seed the member before the call:
 *   - pre / Next        -> pre_pos / Next_pos (left untouched when nothing matches)
 *   - rank_of_val       -> adds to Rank
 *   - val_of_rank       -> ans (left untouched when the rank is out of range)
 *
 * The pool grows on demand; indices of removed nodes are not recycled.
 */
struct AVL_tree {
	// tot: index of the most recently created node; root: index of the root (0 = empty tree).
	int tot = 0, root = 0;
	struct node {
		int ch[2], fa;             // ch[0]/ch[1]: left/right child index, fa: parent index
		int bal, val, dep, sz, cnt; // bal = dep(left) - dep(right); dep = height in nodes (leaf = 1);
		                            // sz = stored elements in the subtree; cnt = multiplicity of val

		// Resets every field to zero.
		void init() {
			ch[0] = ch[1] = fa = bal = val = dep = sz = cnt = 0;
		}
	};

	// Node pool; always holds at least the sentinel at index 0.
	std::vector<node> tr = std::vector<node>(1);
	int Next_pos = 0, pre_pos = 0, Rank = 0, ans = 0;

	// Drops every node and returns the tree to the empty state.
	void init() {
		tr.assign(1, node{});
		tot = 0;
		root = 0;
		Next_pos = pre_pos = Rank = ans = 0;
	}

	/**
	 * Allocates a leaf holding `val` whose parent index is `fa` and returns its index.
	 * The parent's child slot is NOT updated here; the caller links the node in.
	 * Growing the pool may move it, so never hold a reference into `tr` across this call.
	 */
	int creat(int val, int fa) {
		node fresh{};
		fresh.val = val;
		fresh.fa = fa;
		fresh.sz = 1;
		fresh.dep = 1;
		fresh.cnt = 1;
		tr.push_back(fresh);
		tot = static_cast<int>(tr.size()) - 1;
		return tot;
	}

	// Returns 0 if `pos` is the left child of its parent, 1 otherwise.
	int son (int pos) {
		return tr[tr[pos].fa].ch[0] == pos ? 0 : 1;
	}

	// Right rotation around `pos`: its left child becomes the root of this subtree.
	void rrotate(int pos) {
		const int up = tr[pos].ch[0];
		if (up == 0) return; // nothing to rotate with
		const int parent = tr[pos].fa;
		const int moved = tr[up].ch[1]; // subtree that changes sides
		if (parent != 0) tr[parent].ch[son(pos)] = up;
		tr[up].fa = parent;
		tr[up].ch[1] = pos;
		tr[pos].fa = up;
		tr[pos].ch[0] = moved;
		if (moved != 0) tr[moved].fa = pos;
		// pos is now below up, so refresh it first.
		maintain1(pos);
		maintain1(up);
		if (pos == root) root = up;
	}

	// Left rotation around `pos`: its right child becomes the root of this subtree.
	void lrotate(int pos) {
		const int up = tr[pos].ch[1];
		if (up == 0) return; // nothing to rotate with
		const int parent = tr[pos].fa;
		const int moved = tr[up].ch[0]; // subtree that changes sides
		if (parent != 0) tr[parent].ch[son(pos)] = up;
		tr[up].fa = parent;
		tr[up].ch[0] = pos;
		tr[pos].fa = up;
		tr[pos].ch[1] = moved;
		if (moved != 0) tr[moved].fa = pos;
		maintain1(pos);
		maintain1(up);
		if (pos == root) root = up;
	}

	/**
	 * Restores the AVL invariant at `pos`, whose `bal` must be up to date.
	 * A double rotation is used only when the heavy child leans the opposite
	 * way. A heavy child with balance 0 (which happens after deletions) needs
	 * a single rotation; a double rotation there would leave that child with
	 * a height difference of 2.
	 */
	void rotate(int pos) {
		if (tr[pos].bal > 1) {
			if (tr[tr[pos].ch[0]].bal < 0) lrotate(tr[pos].ch[0]);
			rrotate(pos);
		} else if (tr[pos].bal < -1) {
			if (tr[tr[pos].ch[1]].bal > 0) rrotate(tr[pos].ch[1]);
			lrotate(pos);
		}
	}

	// Recomputes dep, bal and sz of `pos` from its children; no-op for the sentinel.
	void maintain1(int pos) {
		if (pos == 0) return;
		const node &l = tr[tr[pos].ch[0]];
		const node &r = tr[tr[pos].ch[1]];
		tr[pos].dep = std::max(l.dep, r.dep) + 1;
		tr[pos].bal = l.dep - r.dep;
		tr[pos].sz = l.sz + r.sz + tr[pos].cnt;
	}

	// Refreshes the cached fields of `pos` and rebalances it if needed.
	void maintain(int pos) {
		if (pos == 0) return;
		maintain1(pos);
		if (tr[pos].bal > 1 || tr[pos].bal < -1) {
			rotate(pos);
		}
	}

	/**
	 * Inserts `val` into the subtree rooted at `pos`, rebalancing on the way back up.
	 * An existing key only has its multiplicity increased.
	 * Passing pos == 0 inserts into the whole tree, creating the root if it is empty.
	 */
	void insert(int pos, int val) {
		if (pos == 0) {
			if (root == 0) root = creat(val, 0);
			else insert(root, val);
			return;
		}
		if (tr[pos].val == val) {
			++tr[pos].cnt;
		} else {
			const int side = val > tr[pos].val ? 1 : 0;
			const int child = tr[pos].ch[side];
			if (child == 0) {
				// creat() may reallocate the pool: take the index first, then link.
				const int fresh = creat(val, pos);
				tr[pos].ch[side] = fresh;
			} else {
				insert(child, val);
			}
		}
		maintain(pos);
	}

	// Returns the index of the node holding `x` in the subtree at `pos`, or 0 if absent.
	int find(int pos, int x) {
		while (pos != 0 && tr[pos].val != x) {
			pos = tr[pos].ch[tr[pos].val > x ? 0 : 1];
		}
		return pos;
	}

	// Stores in pre_pos the node with the largest key < x; pre_pos is untouched if none exists.
	void pre(int pos, int x) {
		while (pos != 0) {
			if (tr[pos].val >= x) {
				pos = tr[pos].ch[0];
			} else {
				pre_pos = pos;
				pos = tr[pos].ch[1];
			}
		}
 	}

	// Stores in Next_pos the node with the smallest key > x; Next_pos is untouched if none exists.
	void Next(int pos, int x) {
		while (pos != 0) {
			if (tr[pos].val <= x) {
				pos = tr[pos].ch[1];
			} else {
				Next_pos = pos;
				pos = tr[pos].ch[0];
			}
		}
	}

	/**
	 * Unlinks node `pos` if it has at most one child and clears its slot.
	 * Returns false, changing nothing, for the sentinel or a node with two children.
	 * Ancestors are not re-maintained here; erase() does that.
	 */
	bool del(int pos) {
		if (pos == 0) return false;
		const int l = tr[pos].ch[0];
		const int r = tr[pos].ch[1];
		if (l != 0 && r != 0) return false;
		const int parent = tr[pos].fa;
		const int heir = l != 0 ? l : r; // 0 when pos is a leaf
		// Never write through the sentinel: the root has no parent slot to patch.
		if (parent != 0) tr[parent].ch[son(pos)] = heir;
		if (heir != 0) tr[heir].fa = parent;
		if (root == pos) root = heir;
		tr[pos].init();
		return true;
	}

	// Adds to Rank the number of stored elements strictly smaller than `val` in the subtree at `pos`.
	void rank_of_val(int pos, int val) {
		while (pos != 0) {
			if (tr[pos].val < val) {
				Rank += tr[tr[pos].ch[0]].sz + tr[pos].cnt;
				pos = tr[pos].ch[1];
			} else {
				pos = tr[pos].ch[0];
			}
		}
	}

	// Stores in ans the `remain`-th smallest element (1-based, duplicates counted);
	// ans is untouched when `remain` is outside 1..sz.
	void val_of_rank(int pos, int remain) {
		while (pos != 0) {
			const int left_sz = tr[tr[pos].ch[0]].sz;
			if (remain <= left_sz) {
				pos = tr[pos].ch[0];
				continue;
			}
			if (remain <= left_sz + tr[pos].cnt) {
				ans = tr[pos].val;
				return;
			}
			remain -= left_sz + tr[pos].cnt;
			pos = tr[pos].ch[1];
		}
	}

	/**
	 * Removes one occurrence of the key stored at node `pos` (typically the
	 * result of find()); pos == 0 is ignored. The node itself goes away only
	 * when its multiplicity drops to zero. Every ancestor of the structural
	 * change is refreshed and rebalanced.
	 */
	void erase(int pos) {
		if (pos == 0) return;
		int t = tr[pos].fa; // lowest node whose cached fields may be stale afterwards
		if (tr[pos].cnt > 1) {
			--tr[pos].cnt;
			maintain1(pos);
		} else if (!del(pos)) {
			// Two children: move the in-order predecessor's payload here and
			// unlink the predecessor instead (it has no right child).
			int pred = tr[pos].ch[0];
			while (tr[pred].ch[1] != 0) pred = tr[pred].ch[1];
			t = tr[pred].fa;
			tr[pos].val = tr[pred].val;
			tr[pos].cnt = tr[pred].cnt;
			del(pred);
		}
		while (t != 0) {
			maintain(t);
			t = tr[t].fa;
		}
	}
};
// Single global tree instance that main() operates on.
AVL_tree solve;
/**
 * Reads an operation count n, then n lines "<op> <value>", and runs each
 * against the global tree `solve`:
 *   1 v  insert v
 *   2 v  erase one occurrence of v (prints "miss" if v is absent)
 *   3 v  print 1 + the number of stored elements strictly smaller than v
 *   4 k  print the k-th smallest stored element (prints 0 if k is out of range)
 *   5 v  print the largest stored element < v, or "not found"
 *   6 v  print the smallest stored element > v, or "not found"
 * Stops early if the input ends or is malformed.
 */
int main() {
	int n;
	if (!(std::cin >> n)) return 0;
	solve.init();
	for (int i = 1; i <= n; i++) {
		int x, y;
		if (!(std::cin >> x >> y)) break;
		if (x == 1) {
			if (solve.root == 0) {
				solve.root = solve.creat(y, 0);
			} else {
				solve.insert(solve.root, y);
			}
		}
		else if (x == 2) {
			const int node_id = solve.find(solve.root, y);
			if (node_id == 0) {
				printf("miss\n");
			} else {
				solve.erase(node_id);
			}
		}
		else if (x == 3) {
			solve.Rank = 0;
			solve.rank_of_val(solve.root, y);
			printf("%d\n", solve.Rank + 1);
		}
		else if (x == 4) {
			solve.ans = 0;
			solve.val_of_rank(solve.root, y);
			printf("%d\n", solve.ans);
		}
		else if (x == 5) {
			// -1 marks "no predecessor"; an empty tree is never searched.
			solve.pre_pos = -1;
			if (solve.root != 0) solve.pre(solve.root, y);
			if (solve.pre_pos == -1) {
				printf("not found\n");
			} else {
				printf("%d\n", solve.tr[solve.pre_pos].val);
			}
		}
		else if (x == 6) {
			// -1 marks "no successor"; an empty tree is never searched.
			solve.Next_pos = -1;
			if (solve.root != 0) solve.Next(solve.root, y);
			if (solve.Next_pos == -1) {
				printf("not found\n");
			} else {
				printf("%d\n", solve.tr[solve.Next_pos].val);
			}
		}
//		printf("root : %d\n", solve.tr[solve.root].sz);
	}
	return 0;
}

  

原文地址:https://www.cnblogs.com/pkgunboat/p/13860926.html