[bzoj3196]二逼平衡树

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)

树套树;

外层线段树,内层splay;

具体的做法是在线段树的每个节点上建立一颗splay,利用splay维护每个线段树节点上的信息;

线段树一共有logn层,每层的大小是n,空间复杂度是nlogn,时间复杂度是nlognlogn(?);

常数大得惊人,10s的时限,跑了9s9,汗。

实际上在这种不需要树的合并的场合,用treap就行了;

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<string>
#include<ctime>
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<algorithm>
#include<iomanip>
#include<stack>
using namespace std;
#define FILE "dealing"
#define up(i,j,n) for(int i=(j);i<=(n);i++)
#define pii pair<int,int>
#define LL int
#define mem(f,g) memset(f,g,sizeof(f))
namespace IO{
	char buf[1<<15],*fs,*ft;
	int gc(){return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?-1:*fs++;}
	int read(){
		int ch=gc(),f=0,x=0;
		while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=gc();}
		while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=gc();}
		return f?-x:x;
	}
	int readint(){
		int ch=getchar(),f=0,x=0;
		while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
		while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
		return f?-x:x;
	}
}using namespace IO;
const int maxn=4001000,inf=1000000000;
int n,m;
int a[maxn];
int c[maxn][2],v[maxn],siz[maxn],fa[maxn],root[maxn],t[maxn],cnt;
void updata(int x){siz[x]=siz[c[x][0]]+siz[c[x][1]]+t[x];}
void rotate(int x){
	int k=fa[x];
	int d=(c[k][1]==x);
	fa[x]=fa[k];fa[k]=x;fa[c[x][d^1]]=k;
	c[k][d]=c[x][d^1];c[x][d^1]=k;
	if(fa[x])c[fa[x]][c[fa[x]][1]==k]=x;
	updata(k);updata(x);
}
void splay(int x,int s,int rt){
	if(fa[x]==s)return;
	while(fa[x]!=s){
		if(fa[fa[x]]==s)rotate(x);
		else {
			int y=fa[x],z=fa[y];
			if(c[y][1]==x^c[z][1]==y)rotate(x);
			else rotate(y);
			rotate(x);
		}
	}
	if(!s)root[rt]=x;
}
void insert(int key,int rt){
	if(!root[rt]){
		root[rt]=++cnt,t[cnt]=siz[cnt]=1,v[cnt]=key,fa[cnt]=0;
		return;
	}
	int now=root[rt],y;
	while(now){
		if(key==v[now]){splay(now,0,rt),t[now]++,updata(now);return;}
		y=now,now=c[now][key>v[now]];
	}
	now=++cnt;t[now]=siz[now]=1;fa[now]=y;c[y][key>v[y]]=now;v[now]=key;
	splay(now,0,rt);
}
void build(int l,int r,int rt){
	if(l>r)return;
	insert(-inf,rt);//哨兵
	insert(inf<<1,rt);
	up(i,l,r)insert(a[i],rt);
	int mid=(l+r)>>1;
	if(l!=r){
		build(l,mid,rt<<1);
		build(mid+1,r,rt<<1|1);
	}
}
int x,y,sum,key,pos;
int find(int key,int rt){//在rt的子树中寻找最小的大于key的节点
	int now=root[rt],id=0;
	while(now){
		if(v[now]>key&&(v[now]<v[id]||!id))id=now;
		now=c[now][key>=v[now]];
	}
	splay(id,0,rt);
	return id;
}
int findpre(int key,int rt){//在rt的子树中寻找最大的小于key的节点
	int now=root[rt],id=0;
	while(now){
		if(v[now]<key&&(v[now]>v[id]||!id))id=now;
		now=c[now][key>v[now]];
	}
	splay(id,0,rt);
	return id;
}
void query_rank(int l,int r,int rt){
	if(l>y||r<x)return;
	if(x<=l&&r<=y){
		int now=find(key,rt);
		sum+=siz[c[now][0]]-1;
		return;
	}
	int mid=(l+r)>>1;
	query_rank(l,mid,rt<<1);
	query_rank(mid+1,r,rt<<1|1);
}
int getrank(int k,int l,int r){
	key=k;sum=0;x=l,y=r;
	query_rank(1,n,1);
	return sum;
}
int getK(int k,int l,int r){
	int x=0,y=inf,mid;
	while(x+1!=y){
		mid=(x+y)>>1;
		if(getrank(mid,l,r)<=k)x=mid;
		else y=mid;
	}
	if(getrank(x,l,r)<=k&&getrank(y,l,r)>k)return y;
	return x;
}
int getl(int x){while(c[x][0])x=c[x][0];return x;}
int getr(int x){while(c[x][1])x=c[x][1];return x;}
void delet(int k,int rt){
	int x=find(k-1,rt);
	splay(x,0,rt);
	int l=c[x][0],r=c[x][1];
	l=getr(l);r=getl(r);
	splay(l,0,rt);splay(r,l,rt);
	if(t[x]>1){t[x]--;updata(x);}
	else c[r][0]=0;
	updata(r),updata(l);
}
void Change(int l,int r,int rt){
	if(l>pos||r<pos)return;
	delet(a[pos],rt);
	insert(key,rt);
	int mid=(l+r)>>1;
	if(l!=r){
		Change(l,mid,rt<<1);
		Change(mid+1,r,rt<<1|1);
	}
	return;
}
void change(int p,int k){
	pos=p,key=k;
	Change(1,n,1);
}
void query_min(int l,int r,int rt){
	if(l>y||r<x)return;
	if(l>=x&&r<=y){
		int y=findpre(key,rt);
		if(v[y]>sum)sum=v[y];
		return;
	}
	int mid=(l+r)>>1;
	query_min(l,mid,rt<<1);
	query_min(mid+1,r,rt<<1|1);
}
int getleft(int k,int l,int r){
	key=k;x=l,y=r;sum=-inf;
	query_min(1,n,1);
	return sum;
}
void query_max(int l,int r,int rt){
	if(l>y||r<x)return;
	if(l>=x&&r<=y){
		int y=find(key,rt);
		if(v[y]<sum)sum=v[y];
		return;
	}
	int mid=(l+r)>>1;
	query_max(l,mid,rt<<1);
	query_max(mid+1,r,rt<<1|1);
}
int getright(int k,int l,int r){
	key=k;x=l,y=r;sum=inf;
	query_max(1,n,1);
	return sum;
}
int main(){
	freopen(FILE".in","r",stdin);
	freopen(FILE".out","w",stdout);
	n=read(),m=read();
	up(i,1,n)a[i]=read();
	build(1,n,1);
	int ch,l,r,k,pos;
	while(m--){
		ch=read();
		if(ch!=3)l=read(),r=read(),k=read();
		else pos=read(),k=read();
		switch (ch){
			case 1:printf("%d
",getrank(k-1,l,r)+1);break;
			case 2:printf("%d
",getK(k-1,l,r));break;
			case 3:change(pos,k);a[pos]=k;break;
			case 4:printf("%d
",getleft(k,l,r));break;
			case 5:printf("%d
",getright(k,l,r));break;
		}
	}
	return 0;
}

treap套树状数组,可能由于后两个操作由logn变成log^2n的缘故,时间上的进步并不明显(bzoj上的srand不能用,但在校内oj上测跑得比原来快了1s);

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<string>
#include<ctime>
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<algorithm>
#include<iomanip>
#include<stack>
using namespace std;
#define FILE "dealing"
#define up(i,j,n) for(int i=(j);i<=(n);i++)
#define pii pair<int,int>
#define LL int
#define mem(f,g) memset(f,g,sizeof(f))
namespace IO{
	char buf[1<<15],*fs,*ft;
	int gc(){return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?-1:*fs++;}
	int read(){
		int ch=gc(),f=0,x=0;
		while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=gc();}
		while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=gc();}
		return f?-x:x;
	}
	int readint(){
		int ch=getchar(),f=0,x=0;
		while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
		while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
		return f?-x:x;
	}
}using namespace IO;
const int maxn=2001000,inf=1000000000;
int n,m,C[maxn],a[maxn];
int lowbit(int x){return x&-x;}
namespace treap{
	int c[maxn][2],v[maxn],siz[maxn],t[maxn],cnt=0,w[maxn];
	void updata(int x){siz[x]=siz[c[x][0]]+siz[c[x][1]]+t[x];}
	void rotate(int &o,int d){int k=c[o][d];c[o][d]=c[k][d^1];c[k][d^1]=o;updata(o);updata(k);o=k;}
	void insert(int& o,int key){
		if(!o){o=++cnt;siz[o]=t[o]=1;v[o]=key;w[o]=rand();return;}
		if(v[o]==key){t[o]++;updata(o);return;}
		int d=(key>v[o]);
		insert(c[o][d],key);
		updata(o);
		if(w[c[o][d]]>w[o])rotate(o,d);
	}
	void delet(int& o,int key){
		if(v[o]==key){
			if(t[o]>1){t[o]--;updata(o);return;}
			if(!c[o][0])o=c[o][1];
			else if(!c[o][1])o=c[o][0];
			else {
				int d=(w[c[o][1]]>w[c[o][0]]);
				rotate(o,d);
				delet(c[o][d^1],key);
				updata(o);
			}
		}
		else delet(c[o][key>v[o]],key),updata(o);
	}
	int getrank(int o,int key){//在o所在的treap内有多少点的值小于等于key
		int ans=0;
		while(o){
			if(key>=v[o])ans+=siz[c[o][0]]+t[o],o=c[o][1];
			else o=c[o][0];
		}
		return ans;
	}
};
namespace Bit{
	int getrank(int key,int l,int r){//返回在[l,r]内有多少点值小于等于key
		int ans=0;l--;
		while(r)ans+=treap::getrank(C[r],key),r-=lowbit(r);
		while(l)ans-=treap::getrank(C[l],key),l-=lowbit(l);
		return ans;
	}
	int getK(int k,int l,int r){
		int left=0,right=inf,mid;
		while(left+1<right){
			mid=(left+right)>>1;
			if(getrank(mid,l,r)>k)right=mid;
			else left=mid;
		}
		if(getrank(left,l,r)<=k&&getrank(right,l,r)>=k)return right;
		return left;
	}
	void change(int pos,int key){
		int i=pos;
		while(pos<=n){
			treap::delet(C[pos],a[i]);
			treap::insert(C[pos],key);
			pos+=lowbit(pos);
		}
	}
	int getleft(int key,int l,int r){
		int k=getrank(key-1,l,r);
		int left=0,right=key,mid;
		while(left+1<right){
			mid=(left+right)>>1;
			if(getrank(mid,l,r)==k)right=mid;
			else left=mid;
		}
		if(getrank(right,l,r)==k&&getrank(left,l,r)<k)return right;
		else return left;
	}
	int getright(int key,int l,int r){
		int k=getrank(key,l,r);
		int left=key,right=inf,mid;
		while(left+1<right){
			mid=(left+right)>>1;
			if(getrank(mid,l,r)==k)left=mid;
			else right=mid;
		}
		if(getrank(right,l,r)>k&&getrank(left,l,r)==k)return right;
		else return left;
	}
};
int main(){
	int __size__ = 20 << 20; // 20MB
	char *__p__ = (char*)malloc(__size__) + __size__;
	__asm__("movl %0, %%esp
" :: "r"(__p__));
	n=read(),m=read();
	srand((int)time(NULL));
	up(i,1,n)a[i]=read();
	up(i,1,n){
		int k=i;
		while(k<=n)treap::insert(C[k],a[i]),k+=lowbit(k);
	}
	int ch,l,r,k,pos;
	int cnt=0;
	while(m--){
		ch=read();
		if(ch!=3)l=read(),r=read(),k=read(),cnt++;
		else pos=read(),k=read();
		switch (ch){
			case 1:printf("%d
",Bit::getrank(k-1,l,r)+1);break;
			case 2:printf("%d
",Bit::getK(k-1,l,r));break;
			case 3:Bit::change(pos,k);a[pos]=k;break;
			case 4:printf("%d
",Bit::getleft(k,l,r));break;
			case 5:printf("%d
",Bit::getright(k,l,r));break;
		}
	}
	return 0;
}

  

原文地址:https://www.cnblogs.com/chadinblog/p/6201679.html