树套树

树套树

这玩意没什么新东西,就是树里面再套树,但是码量极大,及其难调。

树套树本身也不是一种特定的数据结构,它是一种思想,将两个树套一起的思想。


具体怎么回事?

比如我们用线段树维护一个序列。这个线段树的每个节点都代表着一段子序列,我们对每个节点再开一棵平衡树维护这个序列,那么这个杂合子数据结构就叫 “线段树套平衡树”。

树套树分外层树内层树

外层树,就是最外面的的那颗树,它的每个节点都有一棵内层树维护。常见的一般用线段树,树状数组。

内层树,单独维护外层树各个节点信息的树。一般是某种平衡树。大部分时候我们可以直接用 STL

从定义上讲,你可以随便抓两个树套在一起。

这玩意真没啥新定义,所以我们以题目为纲,看一下这个东西的思想方法。


例题

树套树-Lite

2s/64M

请你写出一种数据结构,来维护一个长度为 (n) 的序列,其中需要提供以下操作:

  1. 1 pos x,将 (pos) 位置的数修改为 (x)
  2. 2 l r x,查询整数 (x) 在区间 ([l,r]) 内的前驱(前驱定义为小于 (x),且最大的数)。
    数列中的位置从左到右依次标号为 (1∼n)

区间 ([l,r]) 表示从位置 (l) 到位置 (r) 之间(包括两端点)的所有数字。

区间内排名为 (k) 的值指区间内从小到大排在第 (k) 位的数值。(位次从 (1) 开始)

输入格式

第一行包含两个整数 (n,m),表示数列长度以及操作次数。

第二行包含 (n) 个整数,表示有序数列。

接下来 (m) 行,每行包含一个操作指令,格式如题目所述。

输出格式

对于所有操作 (2),每个操作输出一个查询结果,每个结果占一行。

数据范围

(1≤n,m≤5×10^4,\ 1≤l≤r≤n,\ 1≤pos≤n,\ 0≤x≤10^8,)

有序数列中的数字始终满足在 ([0,10^8]) 范围内,
数据保证所有操作一定合法,所有查询一定有解。

输入样例:

5 3
3 4 2 1 5
2 2 4 4
1 3 5
2 2 4 4

输出样例:

2
1

解析

求区间内前驱,带修。

我们先从查询入手。

如果没有区间限制,我们可以迅速地利用 STL 中的 set 中的 lower_boundupper_bound 得到答案。

而考虑到,(x) 的前驱是指 “(<x) 的最大的数”,是一个带有最大值属性的值。

而最大值一般是可以合并的。或者就题而言,设已知 ([l,r])(x) 的前驱为 (p),则对于 (forall {[a,b] | [l,r]subset [a,b]}) ,都有 (x)([a,b]) 的前驱 (qge p)。并且,由于 (qin [a,b]) 所以 (q) 也一定是某个子区间内 (x) 的前驱。我们只需要利用线段树将整个区间分成几个零区间,每个区间单独求前驱 ,然后再将答案合并。

每一个区间内部求就比较简单了。直接套 set 即可。

再来说修改。

修改仍然是如同线段树一逐层修改。由于单点修改,我们每一层有且只有一个区间会被修改到。所以修改的次数复杂度和树的高度复杂度一样,都是 (O(log n))

#include <bits/stdc++.h>
using namespace std;

const int N=5e5+10,INF=1e8;

struct Node
{
	int l,r;
	multiset<int> s;//set(本质平衡树)维护当前的区间
} tree[N<<2];

int n,m;
int w[N];

#define lnode node<<1
#define rnode node<<1|1

void build(int node,int start,int end)//建树
{
	tree[node].l=start,tree[node].r=end;
	tree[node].s.insert(-INF),tree[node].s.insert(INF);//插入哨兵节点
	for(int i=start;i<=end;i++) tree[node].s.insert(w[i]);//将区间内的数逐个插入
	if(start==end) return ;
	int mid=start+end>>1;
	build(lnode,start,mid);
	build(rnode,mid+1,end);
}

void update(int node,int pos,int x)
{
	tree[node].s.erase(tree[node].s.find(w[pos]));//先将这个位置的数字删去
	tree[node].s.insert(x);//再插入我们想要的数字
	if(tree[node].l==tree[node].r) return ;
	int mid=tree[node].l+tree[node].r>>1;
	if(pos<=mid) update(lnode,pos,x);
	else update(rnode,pos,x);
}

int query(int node,int l,int r,int x)
{
	if(l<=tree[node].l&&tree[node].r<=r)//找对区间
	{
		auto its=tree[node].s.lower_bound(x);
		--its;//迭代器直接写了 auto,原本的迭代器名称又长又臭
		return *its;
	}
	int mid=tree[node].l+tree[node].r>>1,res=-INF;
	if(l<=mid) res=max(res,query(lnode,l,r,x));
	if(r>mid) res=max(res,query(rnode,l,r,x));
	return res;
}

int main()
{
	int n,m;
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++) scanf("%d",&w[i]);
	build(1,1,n);
	for(int i=1;i<=m;i++)
	{
		int opt;
		scanf("%d",&opt);
		if(opt==1)
		{
			int pos,x;
			scanf("%d%d",&pos,&x);
			update(1,pos,x);
			w[pos]=x;//将原数列中的数也要修改掉
		}
		if(opt==2)
		{
			int l,r,x;
			scanf("%d%d%d",&l,&r,&x);
			int ans=query(1,l,r,x);
			printf("%d
",ans);
		}
	}
	return 0;
}

很多时候真的没必要自己手写平衡树,STL 有的东西还是挺优秀的,并且还省下来大量的调试时间。

『模板』树套树

4s/128M

真的模板来了。

请你写出一种数据结构,来维护一个长度为 (n) 的数列,其中需要提供以下操作:

  1. l r x,查询整数 (x) 在区间 ([l,r]) 内的排名。
  2. l r k,查询区间 ([l,r]) 内排名为 (k) 的值。
  3. pos x,将 (pos) 位置的数修改为 (x)
  4. l r x,查询整数 (x) 在区间 ([l,r]) 内的前驱(前驱定义为小于 (x),且最大的数)。
  5. l r x,查询整数 (x) 在区间 ([l,r]) 内的后继(后继定义为大于 (x),且最小的数)。
    数列中的位置从左到右依次标号为 (1sim n)

区间 ([l,r]) 表示从位置 (l) 到位置 (r) 之间(包括两端点)的所有数字。

区间内排名为 (k) 的值指区间内从小到大排在第 (k) 位的数值。(位次从 (1) 开始)

输入格式

第一行包含两个整数 (n,m),表示数列长度以及操作次数。

第二行包含 (n) 个整数,表示有序数列。

接下来 (m) 行,每行包含一个操作指令,格式如题目所述。

输出格式

对于所有操作 (1,2,4,5),每个操作输出一个查询结果,每个结果占一行。

数据范围

(1≤n,m≤5×10^4,\ 1≤l≤r≤n,\ 1≤pos≤n,\ 1≤k≤r−l+1,\ 0≤x≤10^8,)

有序数列中的数字始终满足在 ([0,10^8]) 范围内,
数据保证所有操作一定合法,所有查询一定有解。

输入样例:

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

输出样例:

2
4
3
4
9

解析

区间查询第 (k) 小,(x) 的排名,(x) 的前驱后继,还带修。

趁这个机会我们着重看一下各个函数的实现。

  • 对于求 (x) 前驱,我们已经知道,一个数的前驱就是最大的比它小的数,这是一个最大值属性的信息,我们查询出各个子区间中小于 (x) 的数,然后在其中取最大值就可以得到整个区间的答案。

    求解单个区间的前驱当然可以使用平衡树。

int queryPre(int node,int l,int r,int x)//查找前驱
{
	if(l<=tr1[node].l&&tr1[node].r<=r) return getPre(tr1[node].root,x);//平衡树
	int mid=tr1[node].l+tr1[node].r>>1;
	int res=-INF;
	if(l<=mid) res=max(res,queryPre(lnode,l,r,x));
	if(r>mid) res=max(res,queryPre(rnode,l,r,x));
	return res;
}
  • 对于求 (x) 后缀,后缀的定义是最小的比 (x) 大的数,求法与前驱相同。
int querySuc(int node,int l,int r,int x)//查找后继
{
	if(l<=tr1[node].l && tr1[node].r<=r) return getSuc(tr1[node].root,x);
	int mid=tr1[node].l+tr1[node].r>>1;
	int res=INF;
	if(l<=mid) res=min(res,querySuc(lnode,l,r,x));
	if(r>mid) res=min(res,querySuc(rnode,l,r,x));
	return res;
}

求前驱后继我们已经在上一题中讲过,现在问题在求第 (k) 大和查询排名。

C++ STL 里面的所有平衡树一旦涉及到什么排名第 (k) 大之后就都不能用了。

  • 查询排名就是在问 ([L,R]) 中有多少个数小于 (x) ,个数加 (1) 就是 (x) 的排名。

    区间有多少个数小于 (x) ,这个东西就是可加的了,线段树可以维护。我们对每个区间建立一个平衡树(我用的是 Splay ),按大小关键字排序,可以得出该区间小于 (x) 的数字个数。单次查询复杂度 (O(log^2n))

int getRank(int node,int l,int r,int x)//查找区间内比 x 小的个数
{
	if(l<=tr1[node].l&&tr1[node].r<=r) return get_k(tr1[node].root,x)-1;//记得减去哨兵节点
	int mid=tr1[node].l+tr1[node].r>>1;
	int res=0;
	if(l<=mid) res+=getRank(lnode,l,r,x);
	if(r>mid) res+=getRank(rnode,l,r,x);
	return res;
}
  • (k) 大。我们没有有效的办法通过合并区间信息便利地得到这个答案,所以我们可以二分答案。复杂度也才 (O(log^3 n))
if(opt==2)
{
	int l,r,k;
	scanf("%d%d%d",&l,&r,&k);
	int L=0,R=1e8;
	while(L<R)
	{
		int mid=L+R+1>>1;
		if(getRank(1,l,r,mid)+1<=k) L=mid;
		else R=mid-1;
	}
	printf("%d
",L);
}
  • 修改。这是单点修改,我们每次到达一个线段树区间,都直接寻找到这个区间要修改位置的数在平衡树中对应的节点,将其删去,然后再插入一个新的数。

    对于 Splay 的删去,我们可以找到要删去数 (x) ,将其转到根节点,然后就可以去找它的前驱和后继,将前驱转到根节点,后继转到根节点的右儿子,那么后继节点的左儿子就是我们要删去的节点。(这是 Splay 的内容,不会的去复习一下)

void update(int &root,int x,int y)//插入函数
{
	int u=root;
	while(u)//找到这个节点
	{
		if(tree[u].v==x) break;
		else if(tree[u].v>x) u=tree[u].s[0];
		else u=tree[u].s[1];
	}
	splay(root,u,0);
	int l=tree[u].s[0],r=tree[u].s[1];
	while(tree[l].s[1]) l=tree[l].s[1];
	while(tree[r].s[0]) r=tree[r].s[0];
	splay(root,l,0); splay(root,r,l);
	tree[r].s[0]=0;
	push_up(r),push_up(l);
	insert(root,y);
}
void change(int node,int pos,int x)//修改
{
	update(tr1[node].root,arr[pos],x);
	if(tr1[node].l==tr1[node].r) return ;
	int mid=tr1[node].l+tr1[node].r>>1;
	if(pos<=mid) change(lnode,pos,x);
	else change(rnode,pos,x);
}

这个题中所有的操作就是上面五种了。

一个不太优秀的完整实现:

#include <bits/stdc++.h>
using namespace std;

const int N=1800010,INF=2147483647;

/*----------splay部分----------*/
struct Node1
{
	int s[2],p,v,size;

	void init(int _v,int _p)
	{
		v=_v,p=_p;
		size=1;
	}

} tree[N<<1];
int idx=0;

void push_up(int node)
{
	tree[node].size=tree[tree[node].s[0]].size+tree[tree[node].s[1]].size+1;
}

void rotate(int x)//旋转
{
	int y=tree[x].p, z=tree[y].p;
	int k=tree[y].s[1]==x;
	tree[z].s[tree[z].s[1]==y]=x; tree[x].p=z;//x代y做z儿子
	tree[y].s[k]=tree[x].s[k^1], tree[tree[x].s[k^1]].p=y;//x y 子树互换
	tree[x].s[k^1]=y, tree[y].p=x;//y 做 x 儿子
	push_up(y),push_up(x);
}

void splay(int &root,int x,int k)
{
	while(tree[x].p!=k)
	{
		int y=tree[x].p, z=tree[y].p;
		if(z!=k)
		{
			if((tree[y].s[1]==x)^(tree[z].s[1]==y)) rotate(x);//判断折线形
			else rotate(y);
		}
		rotate(x);
	}
	if(!k) root=x;
}

void insert(int &root,int v)//插入
{
	int u=root,p=0;
	while(u) p=u,u=tree[u].s[tree[u].v<v];
	u=++idx;
	if(p) tree[p].s[v>tree[p].v]=u;
	tree[u].init(v,p);
	splay(root,u,0);
}

int get_k(int &root,int v)//查找比 v 小的数的个数
{
	int u=root,res=0;
	while(u)
	{
		if(tree[u].v<v) res+=tree[tree[u].s[0]].size+1,u=tree[u].s[1];
		else u=tree[u].s[0];
	}
	return res;
}

int getPre(int &root,int v)//查找最大的比 v 小的数
{
	int u=root,res=-INF;
	while(u)
	{
		if(tree[u].v<v) res=max(res,tree[u].v),u=tree[u].s[1];
		else u=tree[u].s[0];
	}
	return res;
}

int getSuc(int &root,int v)//查找最小的比 v 大的数
{
	int u=root,res=INF;
	while(u)
	{
		if(tree[u].v>v) res=min(res,tree[u].v),u=tree[u].s[0];
		else u=tree[u].s[1];
	}
	return res;
}

void update(int &root,int x,int y)//插入函数
{
	int u=root;
	while(u)
	{
		if(tree[u].v==x) break;
		else if(tree[u].v>x) u=tree[u].s[0];
		else u=tree[u].s[1];
	}
	splay(root,u,0);
	int l=tree[u].s[0],r=tree[u].s[1];
	while(tree[l].s[1]) l=tree[l].s[1];
	while(tree[r].s[0]) r=tree[r].s[0];
	splay(root,l,0); splay(root,r,l);
	tree[r].s[0]=0;
	push_up(r),push_up(l);
	insert(root,y);
}

/*----------线段树部分----------*/
struct Node2
{
	int l,r;
	int root;
} tr1[N<<1];

int n,m;
int arr[N];

#define lnode node<<1
#define rnode node<<1|1

void build(int node,int l,int r)//建树
{
	tr1[node].l=l,tr1[node].r=r;
	insert(tr1[node].root,INF); insert(tr1[node].root,-INF);//插入哨兵节点
	for(int i=l;i<=r;i++) insert(tr1[node].root,arr[i]);
	if(l==r) return ;
	int mid=l+r>>1;
	build(lnode,l,mid); build(rnode,mid+1,r);
}

int getRank(int node,int l,int r,int x)//查找区间内比 x 小的个数
{
	if(l<=tr1[node].l&&tr1[node].r<=r) return get_k(tr1[node].root,x)-1;//记得减去哨兵节点
	int mid=tr1[node].l+tr1[node].r>>1;
	int res=0;
	if(l<=mid) res+=getRank(lnode,l,r,x);
	if(r>mid) res+=getRank(rnode,l,r,x);
	return res;
}

void change(int node,int pos,int x)//修改
{
	update(tr1[node].root,arr[pos],x);
	if(tr1[node].l==tr1[node].r) return ;
	int mid=tr1[node].l+tr1[node].r>>1;
	if(pos<=mid) change(lnode,pos,x);
	else change(rnode,pos,x);
}

int queryPre(int node,int l,int r,int x)//查找前驱
{
	if(l<=tr1[node].l&&tr1[node].r<=r) return getPre(tr1[node].root,x);
	int mid=tr1[node].l+tr1[node].r>>1;
	int res=-INF;
	if(l<=mid) res=max(res,queryPre(lnode,l,r,x));
	if(r>mid) res=max(res,queryPre(rnode,l,r,x));
	return res;
}

int querySuc(int node,int l,int r,int x)//查找后继
{
	if(l<=tr1[node].l && tr1[node].r<=r) return getSuc(tr1[node].root,x);
	int mid=tr1[node].l+tr1[node].r>>1;
	int res=INF;
	if(l<=mid) res=min(res,querySuc(lnode,l,r,x));
	if(r>mid) res=min(res,querySuc(rnode,l,r,x));
	return res;
}

int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++) scanf("%d",&arr[i]);
	build(1,1,n);

	for(int i=1;i<=m;i++)
	{
		int opt;
		scanf("%d",&opt);
		if(opt==1)
		{
			int l,r,x;
			scanf("%d%d%d",&l,&r,&x);
			int ans=getRank(1,l,r,x)+1;
			printf("%d
",ans);
		}
		if(opt==2)
		{
			int l,r,k;
			scanf("%d%d%d",&l,&r,&k);
			int L=0,R=1e8;
			while(L<R)
			{
				int mid=L+R+1>>1;
				if(getRank(1,l,r,mid)+1<=k) L=mid;
				else R=mid-1;
			}
			printf("%d
",L);
		}
		if(opt==3)
		{
			int pos,x;
			scanf("%d%d",&pos,&x);
			change(1,pos,x);
			arr[pos]=x;
		}
		if(opt==4)
		{
			int l,r,x;
			scanf("%d%d%d",&l,&r,&x);
			int ans=queryPre(1,l,r,x);
			printf("%d
",ans);
		}
		if(opt==5)
		{
			int l,r,x;
			scanf("%d%d%d",&l,&r,&x);
			int ans=querySuc(1,l,r,x);
			printf("%d
",ans);
		}
	}
}

原文地址:https://www.cnblogs.com/IzayoiMiku/p/14679662.html