线段树套平衡树

#include<bits/stdc++.h>
using namespace std;
const long long maxn=50010;
int n,m;
long long p[maxn];
inline long long read(){
    long long x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-'){
            f=-1;
		}
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    return x*f;
}
int lson[maxn*40],rson[maxn*40];//左右儿子 
long long val[maxn*40];//值 
long long pri[maxn*40];//优先级 
int cnt[maxn*40];//个数 
int size[maxn*40];//子树大小 
long long tot,sum;
struct Treap{
	int rt;
	inline void zig(int &x){//左旋 
		long long y=lson[x];
		lson[x]=rson[y];
		rson[y]=x;
		size[y]=size[x];
		size[x]=size[lson[x]]+size[rson[x]]+cnt[x];
		x=y;
	}
	inline void zag(int &x){//右旋 
		long long y=rson[x];
		rson[x]=lson[y];
		lson[y]=x;
		size[y]=size[x];
		size[x]=size[lson[x]]+size[rson[x]]+cnt[x];
		x=y;
	}
	inline void ins(int &x,long long v){//在以x为根的树内加一个权值为 v 的点 
		if(!x){//x 是空节点 
			x=++sum;
			val[x]=v;
			pri[x]=rand();
			size[x]=cnt[x]=1;
			return;
		}
		size[x]++;
		if(v==val[x]){
			cnt[x]++;//考虑重复的点 
			sum++;
		}
		else if(v<val[x]){//比他小进左子树 
			ins(lson[x],v);
			if(pri[lson[x]]<pri[x]){
				zig(x);
			}
		}
		else{//否则进右子树 
			ins(rson[x],v);
			if(pri[rson[x]]<pri[x]){
				zag(x);
			}
		}
	}
	inline void del(int &x,long long v){
		if(val[x]==v){
			if(cnt[x]>1){//考虑重复的点 
				cnt[x]--;
				size[x]--;
			}
			else if(!lson[x]||!rson[x]){//链节点或叶子 
				x=lson[x]+rson[x];
			}
			else{//双儿子
				if(pri[lson[x]]<pri[rson[x]]){//右儿子优先级高就左旋 
					zig(x);
				}
				else zag(x);//反之则右旋 
				del(x,v);
			}
		}
		else{
			size[x]--;
			if(v<val[x]){
				del(lson[x],v);
			}
			else del(rson[x],v);
		}
	}
	inline long long rank(long long k){//比 k 小的数的个数 
		long long x=rt;//当前节点 
		long long temp=0;//已经找到的节点数 
		while(x){
			if(k==val[x]){//若找到则直接返回 
				return temp+size[lson[x]];
			}
			else if(k<val[x]){//目前的数大了就往左 
				x=lson[x];
			}
			else{//目前的数小了,将左子树和自己计入答案 
				temp+=size[lson[x]]+cnt[x];
				x=rson[x];
			}
		}
		return temp;
	}
	inline long long kth(int k){//第x小的数 
		int x=rt;//当前节点 
		while(x){
			if(size[lson[x]]<k&&size[lson[x]]+cnt[x]>=k){//若当前点就是答案则直接返回 
				return val[x];
			}
			if(size[lson[x]]>=k){//若左子树的 size 大于 kth 
				x=lson[x];
			}
			else{
				k-=size[lson[x]]+cnt[x];
				x=rson[x];
			}
		}
		return 0;
	}
	inline long long pre(long long k){//k 的前驱 
		long long x=rt;
		long long temp=-2147483647ll;
		while(x){//能往右就往右,反之跳到左儿子 
			if(val[x]<k){
				temp=val[x];
				x=rson[x];
			}
			else x=lson[x];
		}
		return temp;
	}
	inline long long nxt(long long k){//k 的后继 
		long long x=rt;
		long long temp=2147483647ll;
		while(x){//能往左就往左,反之跳到右儿子 
			if(val[x]>k){
				temp=val[x];
				x=lson[x];
			}
			else x=rson[x];
		}
		return temp;
	}
}a[maxn<<2];
void build(int k,int l,int r){//建树 
	for(int i=l;i<=r;i++){//对于线段树的区间直接暴力插入 
		a[k].ins(a[k].rt,p[i]);//每一个节点最多每层插入一次 总共不超过 n*log^2 
	}
	if(l==r){
		return;
	}
	build(k<<1,l,l+r>>1);
	build(k<<1|1,(l+r>>1)+1,r);
}
long long rnk(int k,int l,int r,int x,int y,long long num){//区间查排名 
	if(l>y||r<x){
		return 0;
	}
	if(x<=l&&r<=y){//在查询区间内 
		return a[k].rank(num);
	}
	long long ret=0;
	ret+=rnk(k<<1,l,l+r>>1,x,y,num);//线段树基操 
	ret+=rnk(k<<1|1,(l+r>>1)+1,r,x,y,num);
	return ret;
}
long long kth(int x,int y,long long v){//区间 kth 
	int l=0,r=1e8,ans=-1; 
	while(l<=r){//二分答案 
		long long mid=l+r>>1; 
		if(rnk(1,1,n,x,y,mid)<v){//若区间内此数排名 <v 
			ans=mid;
			l=mid+1;//则往右 
		}
		else r=mid-1;//反之往左 
	}
	return ans;
}
void modify(int k,int l,int r,int x,long long v){//修改 
	if(x<l||r<x){
		return;
	}
	a[k].del(a[k].rt,p[x]);//在 x 所属的所有平衡树内先 delete 原数,再 insert 修改后的数 
	a[k].ins(a[k].rt,v);
	if(l==r){
		return;
	}
	modify(k<<1,l,l+r>>1,x,v);
	modify(k<<1|1,(l+r>>1)+1,r,x,v);
}
long long pre(int k,int l,int r,int x,int y,long long v){//查前驱 
	if(l>y||r<x){
		return -2147483647ll;
	}
	if(x<=l&&r<=y){//在区间内则直接返回在此平衡树内的前驱 
		return a[k].pre(v);
	}
	return max(pre(k<<1,l,l+r>>1,x,y,v),pre(k<<1|1,(l+r>>1)+1,r,x,y,v));//答案即为左右两区间中前驱的最大值 
}
long long nxt(int k,int l,int r,int x,int y,long long v){//查后继同理 
	if(l>y||r<x){
		return 2147483647ll;
	}
	if(x<=l&&r<=y){
		return a[k].nxt(v);
	}
	return min(nxt(k<<1,l,l+r>>1,x,y,v),nxt(k<<1|1,(l+r>>1)+1,r,x,y,v));
}
int main(){
	srand(time(NULL));
	n=read();
	m=read();
	for(int i=1;i<=n;i++){
		p[i]=read();
	}
	build(1,1,n);
	for(int i=1;i<=m;i++){
		long long l,r,v,op=read();
		if(op==1){
			l=read();
			r=read();
			v=read();
			printf("%lld
",rnk(1,1,n,l,r,v)+1);
		}
		else if(op==2){
			l=read();
			r=read();
			v=read();
			printf("%lld
",kth(l,r,v));
		}
		else if(op==3){
			l=read();
			v=read();
			modify(1,1,n,l,v);
			p[l]=v;
		}
		else if(op==4){
			l=read();
			r=read();
			v=read();
			printf("%lld
",pre(1,1,n,l,r,v));
		}
		else{
			l=read();
			r=read();
			v=read();
			printf("%lld
",nxt(1,1,n,l,r,v));
		}
	}
	return 0;
}

原文地址:https://www.cnblogs.com/xiong-6/p/13750744.html