P3380 【模板】二逼平衡树(树套树) 线段树套平衡树

(color{#0066ff}{ 题目描述 })

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:

  1. 查询k在区间内的排名
  2. 查询区间内排名为k的值
  3. 修改某一位值上的数值
  4. 查询k在区间内的前驱(前驱定义为严格小于x,且最大的数,若不存在输出-2147483647)
  5. 查询k在区间内的后继(后继定义为严格大于x,且最小的数,若不存在输出2147483647)

(color{#0066ff}{输入格式})

第一行两个数 n,m 表示长度为n的有序序列和m个操作

第二行有n个数,表示有序序列

下面有m行,opt表示操作标号

若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名

若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数

若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k

若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱

若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继

(color{#0066ff}{输出格式})

对于操作1,2,4,5各输出一行,表示查询结果

(color{#0066ff}{输入样例})

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

(color{#0066ff}{输出样例})

2
4
3
4
9

(color{#0066ff}{数据范围与提示})

时空限制:2s,128M

(n,m leq 5cdot {10}^4)保证有序序列所有值在任何时刻满足 ([0, {10} ^8])

(color{#0066ff}{ 题解 })

可以线段树套平衡树

对于操作1,线段树每个区间在平衡树上找比k小的数的个数,加起来再加1就是排名

对于操作2,可以二分答案,然后通过操作1来判断(O(log^3n))

对于操作3,相当于删除再插入,注意线段树整个一条链都要改

对于操作4,5,线段树子区间答案取max和min即可

#include<bits/stdc++.h>
#define LL long long
LL in() {
	char ch; LL x = 0, f = 1;
	while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
	for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
	return x * f;
}
const int maxn = 5e4 + 10;
const int inf = 0x7fffffff;
struct Splay {
protected:
	struct node {
		node *ch[2], *fa;
		int val, siz;
		node(node *fa = NULL, int val = 0, int siz = 0): fa(fa), val(val), siz(siz) { ch[0] = ch[1] = NULL; }
		void upd() { siz = (ch[0]? ch[0]->siz : 0) + (ch[1]? ch[1]->siz : 0) + 1; }
		bool isr() { return this == fa->ch[1]; }
		int rk() { return ch[0]? ch[0]->siz + 1 : 1; }
	}*root;
	void rot(node *x) {
		node *y = x->fa, *z = y->fa;
		bool k = x->isr(); node *w = x->ch[!k];
		if(y != root) z->ch[y->isr()] = x;
		else root = x;
		x->ch[!k] = y, y->ch[k] = w;
		y->fa = x, x->fa = z;
		if(w) w->fa = y;
		y->upd(), x->upd();
	}
	void splay(node *o) {
		while(o != root) {
			if(o->fa != root) rot(o->isr() ^ o->fa->isr()? o : o->fa);
			rot(o);
		}
	}
	node *merge(node *x, node *y, node *fa) {
		if(x) x->fa = fa;
		if(y) y->fa = fa;
		if(!x || !y) return x? x : y;
		if(rand() & 1) return x->ch[1] = merge(x->ch[1], y, x), x->upd(), x;
		else return y->ch[0] = merge(x, y->ch[0], y), y->upd(), y;
	}
public:
	int rnk(int val) {
		node *o = root, *lst = root; int rank = 0;
		while(o) {
			lst = o;
			if(val > o->val) rank += o->rk(), o = o->ch[1];
			else o = o->ch[0];
		}
		return splay(lst), rank;
	}
	int kth(int k) {
		node *o = root;
		while(o->rk() != k) {
			if(k > o->rk()) k -= o->rk(), o = o->ch[1];
			else o = o->ch[0];
		}
		return splay(o), o->val;
	}
	int pre(int val) {
		node *o = root, *lst = root;
		while(o) {
			if(o->val < val) lst = o, o = o->ch[1];
			else o = o->ch[0];
		}
		return splay(lst), lst->val;
	}
	int nxt(int val) {
		node *o = root, *lst = root;
		while(o) {
			if(o->val > val) lst = o, o = o->ch[0];
			else o = o->ch[1];
		}
		return splay(lst), lst->val;
	}
	void ins(int val) {
		if(!root) return (void)(root = new node(NULL, val, 1));
		node *o = root, *fa = NULL;
		while(o) fa = o, o = o->ch[val > o->val];
		fa->ch[val > fa->val] = o = new node(fa, val, 1);
		splay(o);
	}
	void del(int val) {
		node *o = root;
		while(o->val != val) o = o->ch[val > o->val];
		if(!o) return;
		splay(o);
		root = merge(o->ch[0], o->ch[1], NULL);
		delete o;
	}
};
struct SGT {
private:
	struct node {
		int l, r;
		node *ch[2];
		Splay *s;
		node(int l = 0, int r = 0, Splay *s = NULL): l(l), r(r), s(s) { ch[0] = ch[1] = NULL; }
	}*root;
	void build(node *&o, int l, int r, int *a) {
		o = new node(l, r, new Splay());
		for(int i = l; i <= r; i++) o->s->ins(a[i]);
		if(l == r) return;
		int mid = (l + r) >> 1;
		build(o->ch[0], l, mid, a), build(o->ch[1], mid + 1, r, a);
	}
	int rnk(node *o, int l, int r, int val) {
		if(o->r < l || o->l > r) return 0;
		if(l <= o->l && o->r <= r) return o->s->rnk(val);
		return rnk(o->ch[0], l, r, val) + rnk(o->ch[1], l, r, val);
	}
	int pre(node *o, int l, int r, int val) {
		if(o->r < l || o->l > r) return inf;
		if(l <= o->l && o->r <= r) return o->s->pre(val);
		int ans = -inf;
		int L = pre(o->ch[0], l, r, val);
		int R = pre(o->ch[1], l, r, val);
		if(L < val) ans = std::max(ans, L);
		if(R < val) ans = std::max(ans, R);
		return ans;
	}
	int nxt(node *o, int l, int r, int val) {
		if(o->r < l || o->l > r) return -inf;
		if(l <= o->l && o->r <= r) return o->s->nxt(val);
		int ans = inf;
		int L = nxt(o->ch[0], l, r, val);
		int R = nxt(o->ch[1], l, r, val);
		if(L > val) ans = std::min(ans, L);
		if(R > val) ans = std::min(ans, R);
		return ans;
	}
	void change(node *o, int pos, int val, int old) {
		if(o->r < pos || o->l > pos) return;
		o->s->del(old);
		o->s->ins(val);
		if(o->l == o->r) return;
		change(o->ch[0], pos, val, old);
		change(o->ch[1], pos, val, old);
	}
public:
	void build(int *a, int l, int r) { build(root, l, r, a); }
	int rnk(int val, int l, int r) { return rnk(root, l, r, val) + 1; }
	int kth(int k, int L, int R) {
		int l = 0, r = 1e8, ans = 0;
		while(l <= r) {
			int mid = (l + r) >> 1;
			if(rnk(mid, L, R) <= k) ans = mid, l = mid + 1;
			else r = mid - 1;
		}
		return ans;
	}
	void change(int pos, int old, int now) { change(root, pos, now, old); }
	int pre(int val, int l, int r) { return pre(root, l, r, val); }
	int nxt(int val, int l, int r) { return nxt(root, l, r, val); }
}v;
int a[maxn];
int main() {
	int p, l, r, k, n = in(), m = in();
	for(int i = 1; i <= n; i++) a[i] = in();
	v.build(a, 1, n);
	while(m --> 0) {
		p = in();
		if(p == 1) l = in(), r = in(), k = in(), printf("%d
", v.rnk(k, l, r));
		if(p == 2) l = in(), r = in(), k = in(), printf("%d
", v.kth(k, l, r));
		if(p == 3) l = in(), k = in(), v.change(l, a[l], k), a[l] = k;
		if(p == 4) l = in(), r = in(), k = in(), printf("%d
", v.pre(k, l, r));
		if(p == 5) l = in(), r = in(), k = in(), printf("%d
", v.nxt(k, l, r));
	}
	return 0;
}
原文地址:https://www.cnblogs.com/olinr/p/10333100.html