数据结构:树套树-线段树套平衡树

BZOJ3196

这是第二道树套树的题了,如果说是树套树的板子题,其实也不过分,毕竟树套树应该算是数据结构,乃至整个OI里,最难写的之一

这种嵌套形式比较好理解,用线段树来与要查询的区间对齐 ,在每一个线段树节点(区间)内建立一棵平衡树来维护区间内的这些数据

其实刚开始我一直不明白,树套树,两层树都有信息,是不是要同步记录一起维护?其实不是这样的

我们回想替罪羊树套权值线段树,实现动态区间的第k大查询,我们外层平衡树的子树节点数信息是依靠内层权值线段树的sum来维护的

也就是说我要把主要维护的信息放在内层树,而其他的东西放在外面?

有时候是这样的,就比如这道题,所有的操作都是平衡树操作,那么肯定节点信息就都放在平衡树里了,这里的线段树只是一个和一棵棵平衡树对齐的架子而已,没有灵魂

但是可能以后做题多了就会知道,什么安排在内层树,什么安排在外层

扯远了。。

这道题比平衡树在外面的好理解很多,平衡树在外面里面的东西感觉乱乱的,都得拆了拼拼了拆。。

const int maxn=200005;
const int maxm=3000005;
const int INF=100000000;
int n,m,sz,tmp;
int a[maxn],s[maxm],w[maxm],v[maxm],rnd[maxm],lch[maxm],rch[maxm],root[maxn];

序列长度为n,m个询问,sz是平衡树节点数,tmp是全局变量(这个学会了)

a是初始的数据数组,s是平衡树子树节点数,w是平衡树每个节点的权,v是平衡树每个节点的值,rnd是Treap的标志来维护优先级堆结构的,lch和rch是平衡树的,root是每一线段树节点对应的平衡树的根

建树,建外层的线段树,那个空壳

void build(int k,int l,int r,int x,int num)
{
    insert(root[k],num);
    if(l==r) return;
    int mid=(l+r)>>1;
    if(x<=mid) build(k<<1,l,mid,x,num);
    else build(k<<1|1,mid+1,r,x,num);
}

我们看到线段树的每一个节点都插入了一个平衡树节点,也就是线段树的每一个区间都对应了一个平衡树

然后这里的线段树不是像上一题那样散称链表的,可以直接2k,2k+1了,数组无敌(如果写垃圾回收可能还得那么干)

在k这个线段树节点所引出来的平衡树的根是root[k],然后我们看一下往root[k]这棵树里面插入一个数,怎么搞的

void insert(int &k,int num)
{
    if(k==0) {k=++sz;s[k]=w[k]=1;v[k]=num;rnd[k]=rand();return;}
    s[k]++;
    if(num==v[k]) w[k]++;
    else if(num<v[k])
    {
        insert(lch[k],num);
        if(rnd[lch[k]]<rnd[k]) rturn(k);
    }
    else
    {
        insert(rch[k],num);
        if(rnd[rch[k]]<rnd[k]) lturn(k);
    }
}

这里超级简单,裸的Treap插入,具体细节详见本博客重量平衡树值Treap那篇

本题的左旋右旋和更新s数组的函数也奉上:

void update(int k)
{
    s[k]=s[lch[k]]+s[rch[k]]+w[k];
}
void rturn(int &k)
{
    int t=lch[k];
    lch[k]=rch[t];
    rch[t]=k;
    s[t]=s[k];
    update(k);
    k=t;
}
void lturn(int &k)
{
    int t=rch[k];
    rch[k]=lch[t];
    lch[t]=k;
    s[t]=s[k];
    update(k);
    k=t;
}

像这种LL,RR,几乎就是固定的,如果一时理解不了,就在纸上画画

第一个查询,得到k数的排名

void ask_rank(int k,int num)
{
    if(k==0) return;
    if(num==v[k]) {tmp+=s[lch[k]];return;}
    else if(num<v[k]) ask_rank(lch[k],num);
    else {tmp+=s[lch[k]]+w[k];ask_rank(rch[k],num);}
}
void get_rank(int k,int l,int r,int x,int y,int num)
{
    if(l==x&&r==y) {ask_rank(root[k],num);return;}
    int mid=(l+r)>>1;
    if(mid>=y) get_rank(k<<1,l,mid,x,y,num);
    else if(mid<x) get_rank(k<<1|1,mid+1,r,x,y,num);
    else
    {
        get_rank(k<<1,l,mid,x,mid,num);
        get_rank(k<<1|1,mid+1,r,mid+1,y,num);
    }
}

简介一下,线段树找到对应区间,然后直接在对应平衡树里找,因为是找排名,这里巧妙利用了tmp和s[]来记录结果

查询排名为x的数

void get_index(int x,int y,int z)
{
    int l=0,r=INF,ans;
    while(l<=r)
    {
        int mid=(l+r)>>1;
        tmp=1;get_rank(1,1,n,x,y,mid);
        if(tmp<=z) {l=mid+1;ans=mid;}
        else r=mid-1;
    }
    printf("%d
",ans);
}

这个就不能那么干了,所以借助get_rank,二分出来就可以了,还是这种二分法舒服

接着是修改,纯粹的区间点修改

void change(int k,int l,int r,int x,int num,int y)
{
    del(root[k],y);
    insert(root[k],num);
    if(l==r) return;
    int mid=(l+r)>>1;
    if(x<=mid) change(k<<1,l,mid,x,num,y);
    else change(k<<1|1,mid+1,r,x,num,y);
}

由于我改一个点,这个点要影响到logn个线段树节点,每个线段树节点都要有一棵平衡树

所以这里得递归了,改其实就是纯粹改这logn个平衡树了,先删值,再插值

删除的给出来:

void del(int &k,int num)
{
    if(v[k]==num)
    {
        if(w[k]>1){w[k]--;s[k]--;return;}
        if(lch[k]*rch[k]==0) k=lch[k]+rch[k];
        else if(rnd[lch[k]]<rnd[rch[k]]){rturn(k);del(k,num);}
        else {lturn(k);del(k,num);}
    }
    else if(num<v[k]){del(lch[k],num);s[k]--;}
    else {del(rch[k],num);s[k]--;}
}

这里的删除也是模板

查前驱,我有必要理解一下这里的max函数是个啥了,或者说前驱是个啥

void before(int k,int num)
{
    if(k==0) return;
    if(v[k]<num) {tmp=max(v[k],tmp);before(rch[k],num);}
    else before(lch[k],num);
}
void ask_before(int k,int l,int r,int x,int y,int num)
{
    if(l==x&&r==y){before(root[k],num);return;}
    int mid=(l+r)>>1;
    if(mid>=y) ask_before(k<<1,l,mid,x,y,num);
    else if(mid<x) ask_before(k<<1|1,mid+1,r,x,y,num);
    else
    {
        ask_before(k<<1,l,mid,x,mid,num);
        ask_before(k<<1|1,mid+1,r,mid+1,y,num);
    }
}

还是很容易看懂的,一个线段树区间查询,一个平衡树内求前驱

树套树就是麻烦,求答案得合并,很蛋疼,还是得借助tmp数组

当然,求后继是对称的

void after(int k,int num)
{
    if(k==0) return;
    if(v[k]>num) {tmp=min(v[k],tmp);after(lch[k],num);}
    else after(rch[k],num);
}
void ask_after(int k,int l,int r,int x,int y,int num)
{
    if(l==x&&r==y) {after(root[k],num);return;}
    int mid=(l+r)>>1;
    if(mid>=y) ask_after(k<<1,l,mid,x,y,num);
    else if(mid<x) ask_after(k<<1|1,mid+1,r,x,y,num);
    else
    {
        ask_after(k<<1,l,mid,x,mid,num);
        ask_after(k<<1|1,mid+1,r,mid+1,y,num);
    }    
}

这道树套树之二结束了,我要去进攻线段树套权值线段树了(顺便进攻一下嵌套版的二维线段树,四分树的是有毒)

下面给出完整的实现:

  1 #include<cstdio>
  2 #include<cstdlib>
  3 #include<algorithm>
  4 using namespace std;
  5 const int maxn=200005;
  6 const int maxm=3000005;
  7 const int INF=100000000;
  8 int n,m,sz,tmp;
  9 int a[maxn],s[maxm],w[maxm],v[maxm],rnd[maxm],lch[maxm],rch[maxm],root[maxn];
 10 void update(int k)
 11 {
 12     s[k]=s[lch[k]]+s[rch[k]]+w[k];
 13 }
 14 void rturn(int &k)
 15 {
 16     int t=lch[k];
 17     lch[k]=rch[t];
 18     rch[t]=k;
 19     s[t]=s[k];
 20     update(k);
 21     k=t;
 22 }
 23 void lturn(int &k)
 24 {
 25     int t=rch[k];
 26     rch[k]=lch[t];
 27     lch[t]=k;
 28     s[t]=s[k];
 29     update(k);
 30     k=t;
 31 }
 32 void insert(int &k,int num)
 33 {
 34     if(k==0) {k=++sz;s[k]=w[k]=1;v[k]=num;rnd[k]=rand();return;}
 35     s[k]++;
 36     if(num==v[k]) w[k]++;
 37     else if(num<v[k])
 38     {
 39         insert(lch[k],num);
 40         if(rnd[lch[k]]<rnd[k]) rturn(k);
 41     }
 42     else
 43     {
 44         insert(rch[k],num);
 45         if(rnd[rch[k]]<rnd[k]) lturn(k);
 46     }
 47 }
 48 void del(int &k,int num)
 49 {
 50     if(v[k]==num)
 51     {
 52         if(w[k]>1){w[k]--;s[k]--;return;}
 53         if(lch[k]*rch[k]==0) k=lch[k]+rch[k];
 54         else if(rnd[lch[k]]<rnd[rch[k]]){rturn(k);del(k,num);}
 55         else {lturn(k);del(k,num);}
 56     }
 57     else if(num<v[k]){del(lch[k],num);s[k]--;}
 58     else {del(rch[k],num);s[k]--;}
 59 }
 60 void build(int k,int l,int r,int x,int num)
 61 {
 62     insert(root[k],num);
 63     if(l==r) return;
 64     int mid=(l+r)>>1;
 65     if(x<=mid) build(k<<1,l,mid,x,num);
 66     else build(k<<1|1,mid+1,r,x,num);
 67 }
 68 void ask_rank(int k,int num)
 69 {
 70     if(k==0) return;
 71     if(num==v[k]) {tmp+=s[lch[k]];return;}
 72     else if(num<v[k]) ask_rank(lch[k],num);
 73     else {tmp+=s[lch[k]]+w[k];ask_rank(rch[k],num);}
 74 }
 75 void get_rank(int k,int l,int r,int x,int y,int num)
 76 {
 77     if(l==x&&r==y) {ask_rank(root[k],num);return;}
 78     int mid=(l+r)>>1;
 79     if(mid>=y) get_rank(k<<1,l,mid,x,y,num);
 80     else if(mid<x) get_rank(k<<1|1,mid+1,r,x,y,num);
 81     else
 82     {
 83         get_rank(k<<1,l,mid,x,mid,num);
 84         get_rank(k<<1|1,mid+1,r,mid+1,y,num);
 85     }
 86 }
 87 void get_index(int x,int y,int z)
 88 {
 89     int l=0,r=INF,ans;
 90     while(l<=r)
 91     {
 92         int mid=(l+r)>>1;
 93         tmp=1;get_rank(1,1,n,x,y,mid);
 94         if(tmp<=z) {l=mid+1;ans=mid;}
 95         else r=mid-1;
 96     }
 97     printf("%d
",ans);
 98 }
 99 void change(int k,int l,int r,int x,int num,int y)
100 {
101     del(root[k],y);
102     insert(root[k],num);
103     if(l==r) return;
104     int mid=(l+r)>>1;
105     if(x<=mid) change(k<<1,l,mid,x,num,y);
106     else change(k<<1|1,mid+1,r,x,num,y);
107 }
108 void before(int k,int num)
109 {
110     if(k==0) return;
111     if(v[k]<num) {tmp=max(v[k],tmp);before(rch[k],num);}
112     else before(lch[k],num);
113 }
114 void ask_before(int k,int l,int r,int x,int y,int num)
115 {
116     if(l==x&&r==y){before(root[k],num);return;}
117     int mid=(l+r)>>1;
118     if(mid>=y) ask_before(k<<1,l,mid,x,y,num);
119     else if(mid<x) ask_before(k<<1|1,mid+1,r,x,y,num);
120     else
121     {
122         ask_before(k<<1,l,mid,x,mid,num);
123         ask_before(k<<1|1,mid+1,r,mid+1,y,num);
124     }
125 }
126 void after(int k,int num)
127 {
128     if(k==0) return;
129     if(v[k]>num) {tmp=min(v[k],tmp);after(lch[k],num);}
130     else after(rch[k],num);
131 }
132 void ask_after(int k,int l,int r,int x,int y,int num)
133 {
134     if(l==x&&r==y) {after(root[k],num);return;}
135     int mid=(l+r)>>1;
136     if(mid>=y) ask_after(k<<1,l,mid,x,y,num);
137     else if(mid<x) ask_after(k<<1|1,mid+1,r,x,y,num);
138     else
139     {
140         ask_after(k<<1,l,mid,x,mid,num);
141         ask_after(k<<1|1,mid+1,r,mid+1,y,num);
142     }   
143 }
144 int main()
145 {
146     scanf("%d%d",&n,&m);
147     for(int i=1;i<=n;i++) scanf("%d",&a[i]);
148     for(int i=1;i<=n;i++) build(1,1,n,i,a[i]);
149     for(int i=1;i<=m;i++)
150     {
151         int f;scanf("%d",&f);
152         int x,y,k;
153         switch(f)
154         {
155             case 1:scanf("%d%d%d",&x,&y,&k);tmp=1;
156                 get_rank(1,1,n,x,y,k);printf("%d
",tmp);break;
157             case 2:scanf("%d%d%d",&x,&y,&k);
158                 get_index(x,y,k);break;
159             case 3:scanf("%d%d",&x,&y);change(1,1,n,x,y,a[x]);a[x]=y;break;
160             case 4:scanf("%d%d%d",&x,&y,&k);tmp=0;ask_before(1,1,n,x,y,k);printf("%d
",tmp);break;
161             case 5:scanf("%d%d%d",&x,&y,&k);tmp=INF;ask_after(1,1,n,x,y,k);printf("%d
",tmp);break;
162         }
163     }
164     return 0;
165 }
原文地址:https://www.cnblogs.com/aininot260/p/9368669.html