树套树学习笔记

树套树是处理区间问题/二维数点问题的一种常见的数据结构。

树套树也有很多种。最常见的一般就是线段树(树状数组/平衡树)套线段树(平衡树)共6种。

其实树套树的原理很简单,就是利用外层树的树高为 (O(log n)) 和内层树允许动态开点的性质。依次保证空间复杂度 (O(nlog^2 n))

具体来说,进行一次单点修改,对应就是在外层树对应点的所有祖先分别进行一次内层树的修改。由于树高的限制,这样一次的时间复杂度为 (O(log^2 n))

对于一次区间查询,利用线段树/平衡树的性质分到外层树的 (O(log n)) 个节点上,对于这些节点对应的内层树分别进行查询,同样时间复杂度也为 (O(log^2 n))

由此也可以看出,树套树处理的问题的局限性在于要求询问可以被分成 (log n) 个区间分别处理后合并

应用:
1.二维数点

由于区间 ([l,r]) 本质就是一个二元组,区间之间的包含/相交关系也对应一个矩形,所以很多的区间问题都可以转换成二维数点问题。

具体来说这类问题通常可以转换成:动态修改一个点/矩形,查询某个矩形内点的信息。

这个可以通过树套树实现。具体例题:[ZJOI2017] 树状数组

当然,你硬要说离线大法好我也没什么话说。

2.动态区间第k大

题目看着很吓人,但仔细一想好想也没什么。区间第k大可以用主席树完成,动态第k大可以用平衡树/值域线段树完成。

那么把两个套一起不就好了。考虑线段树套平衡树,某个外层节点的内层平衡树维护的是该节点对应区间的信息。

考虑之前那个性质,我们可以把区间分到若干外层点上,这样可以 (O(log^2 n)) 求出比k大的值。

但是我们要求的不是区间第 (k) 大吗?一种可行的方案要用线段树套值域线段树,是把这 (log n) 外层点揪出来,然后通过比较左子树大小和与k的关系判断范围。

这样是 (O(log nlog a)) 的,不过由于权值线段树的常数跑的并不快。事实上有一种偷懒的写法,直接二分区间第 (k) 大,然后求出区间比 (k) 大的数量即可。复杂度 (O(nlog^3 n)),可以勉强卡过。

例题:二逼平衡树_
这题中多了两个pre和nxt操作。这个其实很好处理,pre就是将 (log n) 个区间分别询问pre,然后取最大值即可。nxt同理。

总时间复杂度 (O(nlog^3 n))

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#define N 50010
#define M N*40
#define inf 2147483647
using namespace std;
int val[M],rnd[M],ch[M][2],siz[M],cnt;
void update(int u){siz[u]=siz[ch[u][0]]+siz[ch[u][1]]+1;}
void rot(int &u,int lf)
{
	int v=ch[u][lf];
	ch[u][lf]=ch[v][!lf],ch[v][!lf]=u;
	update(u),update(v);u=v;
}
int new_node(int v){int u=++cnt;siz[u]=1;rnd[u]=rand();val[u]=v;return u;}
void insert(int &u,int v)
{
	if(!u){u=new_node(v);return;}
	siz[u]++;
	if(v<=val[u]){insert(ch[u][0],v);if(rnd[ch[u][0]]<rnd[u]) rot(u,0);}
	else{insert(ch[u][1],v);if(rnd[ch[u][1]]<rnd[u]) rot(u,1);}
}
void erase(int &u,int v)
{
	if(val[u]==v)
	{
		if(!ch[u][0] || !ch[u][1]){u=ch[u][0]|ch[u][1];return;}
		if(rnd[ch[u][0]]>rnd[ch[u][1]]) rot(u,1),erase(ch[u][0],v);
		else rot(u,0),erase(ch[u][1],v);
	}
	else if(val[u]>v) erase(ch[u][0],v);
	else erase(ch[u][1],v);
	update(u);
}
int rnk(int u,int v)
{
	if(!u) return 1;
	if(val[u]>=v) return rnk(ch[u][0],v);
	else return rnk(ch[u][1],v)+siz[ch[u][0]]+1;
}
int pre(int u,int v)
{
	if(!u) return -inf;
	if(val[u]<v) return max(val[u],pre(ch[u][1],v));
	else return pre(ch[u][0],v);
}
int nxt(int u,int v)
{
	if(!u) return inf;
	if(val[u]>v) return min(val[u],nxt(ch[u][0],v));
	else return nxt(ch[u][1],v);
}
int root[N<<2],a[N];
void build(int u,int l,int r)
{
	for(int i=l;i<=r;i++) insert(root[u],a[i]);
	if(l==r) return;
	int mid=(l+r)>>1;
	build(u<<1,l,mid),build(u<<1|1,mid+1,r);
}
void insert(int u,int l,int r,int p,int v)
{
	erase(root[u],a[p]);
	insert(root[u],v);
	if(l==r){a[p]=v;return;}
	int mid=(l+r)>>1;
	if(p<=mid) insert(u<<1,l,mid,p,v);
	else insert(u<<1|1,mid+1,r,p,v);
}
int rnk(int u,int l,int r,int L,int R,int k)
{
	if(L<=l && r<=R) return rnk(root[u],k)-1;
	int mid=(l+r)>>1,ans=0;
	if(L<=mid) ans+=rnk(u<<1,l,mid,L,R,k);
	if(R>mid) ans+=rnk(u<<1|1,mid+1,r,L,R,k);
	return ans;
}
int pre(int u,int l,int r,int L,int R,int k)
{
	if(L<=l && r<=R) return pre(root[u],k);
	int mid=(l+r)>>1,ans=-inf;
	if(L<=mid) ans=max(ans,pre(u<<1,l,mid,L,R,k));
	if(R>mid) ans=max(ans,pre(u<<1|1,mid+1,r,L,R,k));
	return ans;
}
int nxt(int u,int l,int r,int L,int R,int k)
{
	if(L<=l && r<=R) return nxt(root[u],k);
	int mid=(l+r)>>1,ans=inf;
	if(L<=mid) ans=min(ans,nxt(u<<1,l,mid,L,R,k));
	if(R>mid) ans=min(ans,nxt(u<<1|1,mid+1,r,L,R,k));
	return ans;
}
int main()
{
	int n,m;
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	build(1,1,n);
	for(int i=1;i<=m;i++)
	{
		int opt,l,r,k;
		scanf("%d%d%d",&opt,&l,&r);
		if(opt==3) insert(1,1,n,l,r);
		else
		{
			scanf("%d",&k);
			if(opt==1) printf("%d
",rnk(1,1,n,l,r,k)+1);
			else if(opt==2)
			{
				int lf=0,rf=1e8,ans=0;
				while(lf<=rf)
				{
					int mid=(lf+rf)>>1,p=rnk(1,1,n,l,r,mid);
					if(p<k) lf=mid+1,ans=mid;
					else rf=mid-1;
				}
				printf("%d
",ans);
			}
			else if(opt==4) printf("%d
",pre(1,1,n,l,r,k));
			else if(opt==5) printf("%d
",nxt(1,1,n,l,r,k));
		}
	}
	return 0;
}
原文地址:https://www.cnblogs.com/Flying2018/p/13615844.html