树套树三题 题解

1.COGS 1534 [NEERC 2004]K小数

其实是主席树裸题……

(其实这题数据非常水……从O(nlogn)的主席树到O(nlog3n)的树套树+二分到O(nsqrt(n)log2n)的分块套二分套二分到O(n2)的暴力都能过……)

鉴于这就是动态排名系统的静态版,就不说了,贴代码:

线段树套平衡树:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=100010;
struct node{
    int data,size;
    node *lc,*rc,*prt;
    node(int d=0):data(d),size(1),lc(NULL),rc(NULL),prt(NULL){}
    void refresh(){
        size=1;
        if(lc)size+=lc->size;
        if(rc)size+=rc->size;
    }
}*root[maxn<<2]={NULL};
void build(int,int,int);
void query(int,int,int);
void insert(node*,int);
void copy(int,node*);
int rank(int,int);
void splay(node*,node*,int);
void lrot(node*,int);
void rrot(node*,int);
int n,m,x,l,r,k,s,t,tmp,ans;
int main(){
#define MINE
#ifdef MINE
    freopen("kthnumber.in","r",stdin);
    freopen("kthnumber.out","w",stdout);
#endif
    scanf("%d%d",&n,&m);
    build(1,n,1);
    while(m--){
        scanf("%d%d%d",&s,&t,&k);
        l=-(int)(1e9+0.5);r=-l;
        while(l<=r){
            ans=(l+r)>>1;
            tmp=0;
            query(1,n,1);
            if(tmp<k)l=ans+1;
            else r=ans-1;
        }
        printf("%d
",r);
    }
#ifndef MINE
    printf("
--------------------DONE--------------------
");
    for(;;);
#else
    fclose(stdin);
    fclose(stdout);
#endif
    return 0;
}
void build(int l,int r,int rt){
    if(l==r){
        int x;
        scanf("%d",&x);
        insert(new node(x),rt);
        return;
    }
    int mid=(l+r)>>1;
    build(l,mid,rt<<1);
    build(mid+1,r,rt<<1|1);
    copy(rt,root[rt<<1]);
    copy(rt,root[rt<<1|1]);
}
void query(int l,int r,int rt){
    if(s<=l&&t>=r){
        tmp+=rank(ans,rt);
        return;
    }
    int mid=(l+r)>>1;
    if(s<=mid)query(l,mid,rt<<1);
    if(t>mid)query(mid+1,r,rt<<1|1);
}
void copy(int rt,node *x){
    if(!x)return;
    insert(new node(x->data),rt);
    copy(rt,x->lc);
    copy(rt,x->rc);
}
void insert(node *x,int i){
    if(!root[i]){
        root[i]=x;
        return;
    }
    node *rt=root[i];
    for(;;){
        if(x->data<rt->data){
            if(rt->lc)rt=rt->lc;
            else{
                rt->lc=x;
                break;
            }
        }
        else{
            if(rt->rc)rt=rt->rc;
            else{
                rt->rc=x;
                break;
            }
        }
    }
    x->prt=rt;
    while(rt){
        rt->refresh();
        rt=rt->prt;
    }
    splay(x,NULL,i);
}
int rank(int x,int i){
    node *rt=root[i],*y=rt;
    int ans=0;
    while(rt){
        y=rt;
        if(x<=rt->data)rt=rt->lc;
        else{
            if(rt->lc)ans+=rt->lc->size;
            ans++;
            rt=rt->rc;
        }
    }
    splay(y,NULL,i);
    return ans;
}
void splay(node *x,node *tar,int i){
    if(!x)return;
    for(node *rt=x->prt;rt!=tar;rt=x->prt){
        if(rt->prt==tar){
            if(x==rt->lc)rrot(rt,i);
            else lrot(rt,i);
            break;
        }
        if(rt==rt->prt->lc){
            if(x==rt->lc)rrot(rt,i);
            else lrot(rt,i);
            rrot(x->prt,i);
        }
        else{
            if(x==rt->rc)lrot(rt,i);
            else rrot(rt,i);
            lrot(x->prt,i);
        }
    }
}
void lrot(node *x,int i){
    node *y=x->rc;
    if(x->prt){
        if(x==x->prt->lc)x->prt->lc=y;
        else x->prt->rc=y;
    }
    else root[i]=y;
    y->prt=x->prt;
    x->rc=y->lc;
    if(y->lc)y->lc->prt=x;
    y->lc=x;
    x->prt=y;
    x->refresh();
    y->refresh();
}
void rrot(node *x,int i){
    node *y=x->lc;
    if(x->prt){
        if(x==x->prt->lc)x->prt->lc=y;
        else x->prt->rc=y;
    }
    else root[i]=y;
    y->prt=x->prt;
    x->lc=y->rc;
    if(y->rc)y->rc->prt=x;
    y->rc=x;
    x->prt=y;
    x->refresh();
    y->refresh();
}
View Code

树状数组套平衡树:

#include<cstdio>
#include<cstring>
#include<algorithm>
#define lowbit(x) ((x)&(-(x)))
using namespace std;
const int maxn=100010;
struct node{
    int data,size;
    node *lc,*rc,*prt;
    node(int d=0):data(d),size(1),lc(NULL),rc(NULL),prt(NULL){}
    void refresh(){
        size=1;
        if(lc)size+=lc->size;
        if(rc)size+=rc->size;
    }
}*root[maxn]={NULL};
void build(int,int);
int query(int,int);
void insert(node*,int);
int rank(int,int);
void splay(node*,node*,int);
void lrot(node*,int);
void rrot(node*,int);
int n,m,x,l,r,k,s,t,ans;
int main(){
#define MINE
#ifdef MINE
    freopen("kthnumber.in","r",stdin);
    freopen("kthnumber.out","w",stdout);
#endif
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++){
        scanf("%d",&x);
        build(i,x);
    }
    while(m--){
        scanf("%d%d%d",&l,&r,&k);
        s=-1000000000;t=1000000000;
        while(s<=t){
            ans=(s+t)>>1;
            if(query(r,ans)-query(l-1,ans)<k)s=ans+1;
            else t=ans-1;
        }
        printf("%d
",t);
    }
#ifndef MINE
    printf("
--------------------DONE--------------------
");
    for(;;);
#else
    fclose(stdin);
    fclose(stdout);
#endif
    return 0;
}
void build(int x,int d){
    while(x<=n){
        insert(new node(d),x);
        x+=lowbit(x);
    }
}
int query(int x,int d){
    int ans=0;
    while(x){
        ans+=rank(d,x);
        x-=lowbit(x);
    }
    return ans;
}
void insert(node *x,int i){
    if(!root[i]){
        root[i]=x;
        return;
    }
    node *rt=root[i];
    for(;;){
        if(x->data<rt->data){
            if(rt->lc)rt=rt->lc;
            else{
                rt->lc=x;
                break;
            }
        }
        else{
            if(rt->rc)rt=rt->rc;
            else{
                rt->rc=x;
                break;
            }
        }
    }
    x->prt=rt;
    while(rt){
        rt->refresh();
        rt=rt->prt;
    }
    splay(x,NULL,i);
}
int rank(int x,int i){
    node *rt=root[i],*y=rt;
    int ans=0;
    while(rt){
        y=rt;
        if(x<=rt->data)rt=rt->lc;
        else{
            if(rt->lc)ans+=rt->lc->size;
            ans++;
            rt=rt->rc;
        }
    }
    splay(y,NULL,i);
    return ans;
}
void splay(node *x,node *tar,int i){
    if(!x)return;
    for(node *rt=x->prt;rt!=tar;rt=x->prt){
        if(rt->prt==tar){
            if(x==rt->lc)rrot(rt,i);
            else lrot(rt,i);
            break;
        }
        if(rt==rt->prt->lc){
            if(x==rt->lc)rrot(rt,i);
            else lrot(rt,i);
            rrot(x->prt,i);
        }
        else{
            if(x==rt->rc)lrot(rt,i);
            else rrot(rt,i);
            lrot(x->prt,i);
        }
    }
}
void lrot(node *x,int i){
    node *y=x->rc;
    if(x->prt){
        if(x==x->prt->lc)x->prt->lc=y;
        else x->prt->rc=y;
    }
    else root[i]=y;
    y->prt=x->prt;
    x->rc=y->lc;
    if(y->lc)y->lc->prt=x;
    y->lc=x;
    x->prt=y;
    x->refresh();
    y->refresh();
}
void rrot(node *x,int i){
    node *y=x->lc;
    if(x->prt){
        if(x==x->prt->lc)x->prt->lc=y;
        else x->prt->rc=y;
    }
    else root[i]=y;
    y->prt=x->prt;
    x->lc=y->rc;
    if(y->rc)y->rc->prt=x;
    y->rc=x;
    x->prt=y;
    x->refresh();
    y->refresh();
}
View Code

2.COGS 257 动态排名系统

其实正解据说是树状数组套主席树……然而鉴于代码难度太大没敢写,写的是线段树套平衡树+二分答案……

线段树的每个节点存一棵平衡树,表示左右儿子的平衡树“相加”,这样查询第k小的数就可以借助二分+线段树查询区间中比答案小的数实现。

(其实这是O(log3n)的……)

至于单点修改,我们借助线段树的单点修改功能,把所有相关节点的平衡树进行更新(删掉原来的数,插入新的数),复杂度O(log2n)。

代码:

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<algorithm>
  4 using namespace std;
  5 const int maxn=50010;
  6 struct node{
  7     int data,size;
  8     node *lc,*rc,*prt;
  9     node(int d=0):data(d),size(1),lc(NULL),rc(NULL),prt(NULL){}
 10     void refresh(){
 11         size=1;
 12         if(lc)size+=lc->size;
 13         if(rc)size+=rc->size;
 14     }
 15 }*root[maxn<<2];
 16 void build(int,int,int);
 17 void mset(int,int,int);
 18 node *find(int,int);
 19 void erase(node*,int);
 20 void qsum(int,int,int);
 21 void insert(node*,int);
 22 void copy(int,node*);
 23 int rank(int,int);
 24 void splay(node*,node*,int);
 25 void lrot(node*,int);
 26 void rrot(node*,int);
 27 node *findmax(node*);
 28 int T,n,m,a[maxn],x,l,r,k,s,t,tmp,ans;
 29 char c;
 30 int main(){
 31 #define MINE
 32 #ifdef MINE
 33     freopen("dynrank.in","r",stdin);
 34     freopen("dynrank.out","w",stdout);
 35 #endif
 36     scanf("%d",&T);
 37     while(T--){
 38         memset(root,0,sizeof(root));
 39         scanf("%d%d",&n,&m);
 40         build(1,n,1);
 41         while(m--){
 42             scanf(" %c",&c);
 43             if(c=='C'){
 44                 scanf("%d%d",&x,&tmp);
 45                 mset(1,n,1);
 46                 a[x]=tmp;
 47             }
 48             else if(c=='Q'){
 49                 scanf("%d%d%d",&s,&t,&k);
 50                 l=1;r=1000000000;
 51                 while(l<=r){
 52                     ans=(l+r)>>1;
 53                     tmp=0;
 54                     qsum(1,n,1);
 55                     if(tmp<k)l=ans+1;
 56                     else r=ans-1;
 57                 }
 58                 printf("%d
",r);
 59             }
 60         }
 61     }
 62 #ifndef MINE
 63     printf("
--------------------DONE--------------------
");
 64     for(;;);
 65 #else
 66     fclose(stdin);
 67     fclose(stdout);
 68 #endif
 69     return 0;
 70 }
 71 void build(int l,int r,int rt){
 72     if(l==r){
 73         scanf("%d",&a[l]);
 74         insert(new node(a[l]),rt);
 75         return;
 76     }
 77     int mid=(l+r)>>1;
 78     build(l,mid,rt<<1);
 79     build(mid+1,r,rt<<1|1);
 80     copy(rt,root[rt<<1]);
 81     copy(rt,root[rt<<1|1]);
 82 }
 83 void mset(int l,int r,int rt){
 84     erase(find(a[x],rt),rt);
 85     insert(new node(tmp),rt);
 86     if(l==r)return;
 87     int mid=(l+r)>>1;
 88     if(x<=mid)mset(l,mid,rt<<1);
 89     else mset(mid+1,r,rt<<1|1);
 90 }
 91 void qsum(int l,int r,int rt){
 92     if(s<=l&&t>=r){
 93         tmp+=rank(ans,rt);
 94         return;
 95     }
 96     int mid=(l+r)>>1;
 97     if(s<=mid)qsum(l,mid,rt<<1);
 98     if(t>mid)qsum(mid+1,r,rt<<1|1);
 99 }
100 void copy(int rt,node *x){
101     if(!x)return;
102     insert(new node(x->data),rt);
103     copy(rt,x->lc);
104     copy(rt,x->rc);
105 }
106 void insert(node *x,int i){
107     if(!root[i]){
108         root[i]=x;
109         return;
110     }
111     node *rt=root[i];
112     for(;;){
113         if(x->data<rt->data){
114             if(rt->lc)rt=rt->lc;
115             else{
116                 rt->lc=x;
117                 break;
118             }
119         }
120         else{
121             if(rt->rc)rt=rt->rc;
122             else{
123                 rt->rc=x;
124                 break;
125             }
126         }
127     }
128     x->prt=rt;
129     while(rt){
130         rt->refresh();
131         rt=rt->prt;
132     }
133     splay(x,NULL,i);
134 }
135 node *find(int x,int i){
136     node *rt=root[i];
137     while(rt){
138         if(x<rt->data)rt=rt->lc;
139         else if(x>rt->data)rt=rt->rc;
140         else return rt;
141     }
142     return NULL;
143 }
144 void erase(node *x,int i){
145     splay(x,NULL,i);
146     splay(findmax(x->lc),x,i);
147     if(x->lc){
148         x->lc->rc=x->rc;
149         x->lc->prt=NULL;
150         root[i]=x->lc;
151         x->lc->refresh();
152     }
153     else root[i]=x->rc;
154     if(x->rc)x->rc->prt=x->lc;
155     delete x;
156 }
157 int rank(int x,int i){
158     node *rt=root[i],*y=rt;
159     int ans=0;
160     while(rt){
161         y=rt;
162         if(x<=rt->data)rt=rt->lc;
163         else{
164             if(rt->lc)ans+=rt->lc->size;
165             ans++;
166             rt=rt->rc;
167         }
168     }
169     splay(y,NULL,i);
170     return ans;
171 }
172 void splay(node *x,node *tar,int i){
173     if(!x)return;
174     for(node *rt=x->prt;rt!=tar;rt=x->prt){
175         if(rt->prt==tar){
176             if(x==rt->lc)rrot(rt,i);
177             else lrot(rt,i);
178             break;
179         }
180         if(rt==rt->prt->lc){
181             if(x==rt->lc)rrot(rt,i);
182             else lrot(rt,i);
183             rrot(x->prt,i);
184         }
185         else{
186             if(x==rt->rc)lrot(rt,i);
187             else rrot(rt,i);
188             lrot(x->prt,i);
189         }
190     }
191 }
192 void lrot(node *x,int i){
193     node *y=x->rc;
194     if(x->prt){
195         if(x==x->prt->lc)x->prt->lc=y;
196         else x->prt->rc=y;
197     }
198     else root[i]=y;
199     y->prt=x->prt;
200     x->rc=y->lc;
201     if(y->lc)y->lc->prt=x;
202     y->lc=x;
203     x->prt=y;
204     x->refresh();
205     y->refresh();
206 }
207 void rrot(node *x,int i){
208     node *y=x->lc;
209     if(x->prt){
210         if(x==x->prt->lc)x->prt->lc=y;
211         else x->prt->rc=y;
212     }
213     else root[i]=y;
214     y->prt=x->prt;
215     x->lc=y->rc;
216     if(y->rc)y->rc->prt=x;
217     y->rc=x;
218     x->prt=y;
219     x->refresh();
220     y->refresh();
221 }
222 node *findmax(node *x){
223     if(!x)return NULL;
224     while(x->rc)x=x->rc;
225     return x;
226 }
View Code

3.COGS 1594 & bzoj 3196 & Tyvj 1730 二逼平衡树

其实这就是动态排名系统的加强版……

据说正解仍然是树状数组套主席树,然而由于代码能力太烂的原因仍然没敢打,用的还是线段树套平衡树。

(ps:以上两题都可以用树状数组代替外层的平衡树,但是这里必须用线段树,因为求前驱后继时要求不多不少刚好覆盖整个区间,而树状数组只能先多求出左边那段然后再减掉,对于不满足可减性的前驱后继无能为力。)

查询k小数和单点修改同上,新增的三个操作也很容易:

查询排名直接利用线段树套平衡树求区间内比k小的数的个数即可。

查询前驱直接利用对线段树分解成的每个子区间进行平衡树的求前驱操作即可。

查询后继同理。

平衡树一开始用的Splay,然而T了,所以换成AVL树,依然T了……

所以加了快读快写和inline卡常大法,终于AC。

(我只想说加了inline瞬间提速25%是什么鬼……)

贴个代码,我估计这是我OI生涯中写过的最长的代码了……

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<algorithm>
  4 using namespace std;
  5 namespace mine{
  6     inline int getint(){
  7         static int __c,__x;
  8         static bool __neg;
  9         __x=0;
 10         __neg=false;
 11         do __c=getchar();while(__c==' '||__c=='
'||__c=='
'||__c=='	');
 12         if(__c=='-'){
 13             __neg=true;
 14             __c=getchar();
 15         }
 16         for(;__c>='0'&&__c<='9';__c=getchar())__x=__x*10+(__c^48);
 17         if(__neg)return -__x;
 18         return __x;
 19     }
 20     inline void putint(int __x){
 21         static int __a[40],__i,__j;
 22         static bool __neg;
 23         __neg=__x<0;
 24         if(__neg)__x=-__x;
 25         __i=0;
 26         do{
 27             __a[__i++]=__x%10+48;
 28             __x/=10;
 29         }while(__x);
 30         if(__neg)putchar('-');
 31         for(__j=__i-1;__j^-1;__j--)putchar(__a[__j]);
 32     }
 33 }
 34 using namespace mine;
 35 const int maxn=50010;
 36 struct node{
 37     int data,size,h;
 38     node *lc,*rc,*prt;
 39     node(int d=0):data(d),size(1),h(1),lc(NULL),rc(NULL),prt(NULL){}
 40     inline void refresh(){
 41         size=h=1;
 42         if(lc){
 43             size+=lc->size;
 44             h=max(h,lc->h+1);
 45         }
 46         if(rc){
 47             size+=rc->size;
 48             h=max(h,rc->h+1);
 49         }
 50     }
 51     inline int bal(){
 52         if(lc&&rc)return lc->h-rc->h;
 53         if(lc&&!rc)return lc->h;
 54         if(!lc&&rc)return -rc->h;
 55         return 0;
 56     }
 57 }*root[maxn<<2]={NULL};
 58 void build(int,int,int);
 59 void mset(int,int,int);
 60 void qsum(int,int,int);
 61 void qpred(int,int,int);
 62 void qsucc(int,int,int);
 63 node *find(int,int);
 64 void erase(node*,int);
 65 void insert(node*,int);
 66 void copy(int,node*);
 67 int rank(int,int);
 68 node *pred(int,int);
 69 node *succ(int,int);
 70 void lrot(node*,int);
 71 void rrot(node*,int);
 72 node *findmax(node*);
 73 int T,n,m,a[maxn],x,d,l,r,k,s,t,tmp,ans;
 74 char c;
 75 int main(){
 76 #define MINE
 77 #ifdef MINE
 78     freopen("psh.in","r",stdin);
 79     freopen("psh.out","w",stdout);
 80 #endif
 81     n=getint();
 82     m=getint();
 83     build(1,n,1);
 84     while(m--){
 85         d=getint();
 86         if(d==1){//查询k在区间内的排名
 87             s=getint();
 88             t=getint();
 89             ans=getint();
 90             tmp=1;
 91             qsum(1,n,1);
 92             putint(tmp);
 93         }
 94         else if(d==2){//查询区间内排名为k的值
 95             s=getint();
 96             t=getint();
 97             k=getint();
 98             l=0;r=100000000;
 99             while(l<=r){
100                 ans=(l+r)>>1;
101                 tmp=0;
102                 qsum(1,n,1);
103                 if(tmp<k)l=ans+1;
104                 else r=ans-1;
105             }
106             putint(r);
107         }
108         else if(d==3){//修改某一位值上的数值
109             x=getint();
110             tmp=getint();
111             mset(1,n,1);
112             a[x]=tmp;
113         }
114         else if(d==4){//查询k在区间内的前驱(前驱定义为小于x,且最大的数)
115             s=getint();
116             t=getint();
117             ans=getint();
118             tmp=0;
119             qpred(1,n,1);
120             putint(tmp);
121         }
122         else if(d==5){//查询k在区间内的后继(后继定义为大于x,且最小的数)
123             s=getint();
124             t=getint();
125             ans=getint();
126             tmp=100000000;
127             qsucc(1,n,1);
128             putint(tmp);
129         }
130         if(d!=3)putchar('
');
131     }
132 #ifndef MINE
133     printf("
--------------------DONE--------------------
");
134     for(;;);
135 #endif
136     return 0;
137 }
138 inline void build(int l,int r,int rt){
139     if(l==r){
140         scanf("%d",&a[l]);
141         insert(new node(a[l]),rt);
142         return;
143     }
144     int mid=(l+r)>>1;
145     build(l,mid,rt<<1);
146     build(mid+1,r,rt<<1|1);
147     copy(rt,root[rt<<1]);
148     copy(rt,root[rt<<1|1]);
149 }
150 inline void mset(int l,int r,int rt){
151     erase(find(a[x],rt),rt);
152     insert(new node(tmp),rt);
153     if(l==r)return;
154     int mid=(l+r)>>1;
155     if(x<=mid)mset(l,mid,rt<<1);
156     else mset(mid+1,r,rt<<1|1);
157 }
158 inline void qsum(int l,int r,int rt){
159     if(s<=l&&t>=r){
160         tmp+=rank(ans,rt);
161         return;
162     }
163     int mid=(l+r)>>1;
164     if(s<=mid)qsum(l,mid,rt<<1);
165     if(t>mid)qsum(mid+1,r,rt<<1|1);
166 }
167 inline void qpred(int l,int r,int rt){
168     if(s<=l&&t>=r){
169         node *x=pred(ans,rt);
170         if(x)tmp=max(tmp,x->data);
171         return;
172     }
173     int mid=(l+r)>>1;
174     if(s<=mid)qpred(l,mid,rt<<1);
175     if(t>mid)qpred(mid+1,r,rt<<1|1);
176 }
177 inline void qsucc(int l,int r,int rt){
178     if(s<=l&&t>=r){
179         node *x=succ(ans,rt);
180         if(x)tmp=min(tmp,x->data);
181         return;
182     }
183     int mid=(l+r)>>1;
184     if(s<=mid)qsucc(l,mid,rt<<1);
185     if(t>mid)qsucc(mid+1,r,rt<<1|1);
186 }
187 inline void copy(int rt,node *x){
188     if(!x)return;
189     insert(new node(x->data),rt);
190     copy(rt,x->lc);
191     copy(rt,x->rc);
192 }
193 inline void insert(node *x,int i){
194     if(!root[i]){
195         root[i]=x;
196         return;
197     }
198     node *rt=root[i];
199     for(;;){
200         if(x->data<rt->data){
201             if(rt->lc)rt=rt->lc;
202             else{
203                 rt->lc=x;
204                 break;
205             }
206         }
207         else{
208             if(rt->rc)rt=rt->rc;
209             else{
210                 rt->rc=x;
211                 break;
212             }
213         }
214     }
215     x->prt=rt;
216     while(rt){
217         rt->refresh();
218         if(rt->bal()==2){
219             x=rt->lc;
220             if(x->bal()==-1)lrot(x,i);
221             rrot(rt,i);
222             rt=rt->prt;
223         }
224         else if(rt->bal()==-2){
225             x=rt->rc;
226             if(x->bal()==1)rrot(x,i);
227             lrot(rt,i);
228             rt=rt->prt;
229         }
230         rt=rt->prt;
231     }
232 }
233 inline node *find(int x,int i){
234     node *rt=root[i];
235     while(rt){
236         if(x<rt->data)rt=rt->lc;
237         else if(x>rt->data)rt=rt->rc;
238         else return rt;
239     }
240     return NULL;
241 }
242 inline void erase(node *x,int i){
243     if(x->lc&&x->rc){
244         node *y=findmax(x->lc);
245         x->data=y->data;
246         erase(y,i);
247     }
248     else{
249         if(x->lc&&!x->rc){
250             x->lc->prt=x->prt;
251             if(x->prt){
252                 if(x==x->prt->lc)x->prt->lc=x->lc;
253                 else x->prt->rc=x->lc;
254             }
255             else root[i]=x->lc;
256         }
257         else if(!x->lc&&x->rc){
258             x->rc->prt=x->prt;
259             if(x->prt){
260                 if(x==x->prt->lc)x->prt->lc=x->rc;
261                 else x->prt->rc=x->rc;
262             }
263             else root[i]=x->rc;
264         }
265         else{
266             if(x->prt){
267                 if(x==x->prt->lc)x->prt->lc=NULL;
268                 else x->prt->rc=NULL;
269             }
270             else root[i]=NULL;
271         }
272         node *rt=x->prt;
273         delete x;
274         for(;rt;rt=rt->prt){
275             rt->refresh();
276             if(rt->bal()==2){
277                 x=rt->lc;
278                 if(x->bal()==-1)lrot(x,i);
279                 rrot(rt,i);
280                 rt=rt->prt;
281             }
282             else if(rt->bal()==-2){
283                 x=rt->rc;
284                 if(x->bal()==1)rrot(x,i);
285                 lrot(rt,i);
286                 rt=rt->prt;
287             }
288         }
289     }
290 }
291 inline int rank(int x,int i){
292     node *rt=root[i],*y=rt;
293     int ans=0;
294     while(rt){
295         y=rt;
296         if(x<=rt->data)rt=rt->lc;
297         else{
298             if(rt->lc)ans+=rt->lc->size;
299             ans++;
300             rt=rt->rc;
301         }
302     }
303     return ans;
304 }
305 inline node *pred(int x,int i){
306     node *rt=root[i],*y=NULL;
307     while(rt){
308         if(rt->data<x){
309             if(!y||y->data<rt->data)y=rt;
310             rt=rt->rc;
311         }
312         else rt=rt->lc;
313     }
314     return y;
315 }
316 inline node *succ(int x,int i){
317     node *rt=root[i],*y=NULL;
318     while(rt){
319         if(rt->data>x){
320             if(!y||y->data>rt->data)y=rt;
321             rt=rt->lc;
322         }
323         else rt=rt->rc;
324     }
325     return y;
326 }
327 inline void lrot(node *x,int i){
328     node *y=x->rc;
329     if(x->prt){
330         if(x==x->prt->lc)x->prt->lc=y;
331         else x->prt->rc=y;
332     }
333     else root[i]=y;
334     y->prt=x->prt;
335     x->rc=y->lc;
336     if(y->lc)y->lc->prt=x;
337     y->lc=x;
338     x->prt=y;
339     x->refresh();
340     y->refresh();
341 }
342 inline void rrot(node *x,int i){
343     node *y=x->lc;
344     if(x->prt){
345         if(x==x->prt->lc)x->prt->lc=y;
346         else x->prt->rc=y;
347     }
348     else root[i]=y;
349     y->prt=x->prt;
350     x->lc=y->rc;
351     if(y->rc)y->rc->prt=x;
352     y->rc=x;
353     x->prt=y;
354     x->refresh();
355     y->refresh();
356 }
357 inline node *findmax(node *x){
358     if(!x)return NULL;
359     while(x->rc)x=x->rc;
360     return x;
361 }
View Code

【后记】

其实这三个树套树都用的线段树套平衡树(虽然k小数也用树状数组套平衡树水了一发)……

平衡树前两题用的Splay,最后一题因为被卡了换成了AVL树……(我会说其实这些都是从普通平衡树抄来的么……)作为Treap和SBT打死学不会的渣渣表示orzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz……

二逼平衡树的AC标志着自己平衡树系列三题全过……撒花。(好像自己是HZOI2016第一个A掉二逼平衡树的……吓傻orzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz。)

今天也是闲得= =没事非要水树套树……结果把COGS上树套树分类里的4道题水完了……

树套树这种东西联赛肯定用不着,以后还是乖乖复习联赛知识吧……

【9.19新增】

其实昨天那些代码里的平衡树都是抄自己的普通平衡树= =根本没有好好打= =

于是乎今天荒废了将近两节课打线段树套Splay,从头到尾打了将近300行= =

从头到尾手打二逼平衡树,留念。

(这次没加卡常,不过线段树的build改成了左子树复制右子树合并,好像快了不少)

#include<cstdio>
#include<cstring>
#include<algorithm>
#define siz(x) ((x)?(x)->size:0)
using namespace std;
const int maxn=50010;
struct node{//Splay Tree
    int data,size;
    node *lc,*rc,*prt;
    node(int d=0):data(d),size(1),lc(NULL),rc(NULL),prt(NULL){}
    inline void refresh(){size=siz(lc)+siz(rc)+1;}
}*root[maxn<<2];
void build(int,int,int);
void qsum(int,int,int);
void qpred(int,int,int);
void qsucc(int,int,int);
void mset(int,int,int);
void copy(node*&,node*);
void merge(int,node*);
void insert(node*,int);
node *find(int,int);
void erase(node*,int);
int rank(int,int);
node *pred(int,int);
node *succ(int,int);
void splay(node*,node*,int);
void lrot(node*,int);
void rrot(node*,int);
node *findmax(node*);
int n,m,a[maxn],tmp,ans,s,t,k,d,l,r;
int main(){
#define MINE
#ifdef MINE
    freopen("psh.in","r",stdin);
    freopen("psh.out","w",stdout);
#endif
    scanf("%d%d",&n,&m);
    build(1,n,1);
    while(m--){
        scanf("%d",&d);
        if(d==1){//查询k在区间内的排名
            scanf("%d%d%d",&s,&t,&k);
            ans=k;
            tmp=1;
            qsum(1,n,1);
            printf("%d
",tmp);
        }
        else if(d==2){//查询区间内排名为k的值
            scanf("%d%d%d",&s,&t,&k);
            l=0;r=100000000;
            while(l<=r){
                ans=(l+r)>>1;
                tmp=0;
                qsum(1,n,1);
                if(tmp<k)l=ans+1;
                else r=ans-1;
            }
            printf("%d
",r);
        }
        else if(d==3){//修改某一位值上的数值
            scanf("%d%d",&ans,&tmp);
            mset(1,n,1);
            a[ans]=tmp;
        }
        else if(d==4){//查询k在区间内的前驱(前驱定义为小于x,且最大的数)
            scanf("%d%d%d",&s,&t,&k);
            ans=k;
            tmp=0;
            qpred(1,n,1);
            printf("%d
",tmp);
        }
        else if(d==5){//查询k在区间内的后继(后继定义为大于x,且最小的数)
            scanf("%d%d%d",&s,&t,&k);
            ans=k;
            tmp=100000000;
            qsucc(1,n,1);
            printf("%d
",tmp);
        }
    }
#ifndef MINE
    printf("
--------------------DONE--------------------
");
    for(;;);
#endif
    return 0;
}
void build(int l,int r,int rt){
    if(l==r){
        scanf("%d",&a[l]);
        insert(new node(a[l]),rt);
        return;
    }
    int mid=(l+r)>>1;
    build(l,mid,rt<<1);
    build(mid+1,r,rt<<1|1);
    copy(root[rt],root[rt<<1]);
    merge(rt,root[rt<<1|1]);
}
void qsum(int l,int r,int rt){
    if(s<=l&&t>=r){
        tmp+=rank(ans,rt);
        return;
    }
    int mid=(l+r)>>1;
    if(s<=mid)qsum(l,mid,rt<<1);
    if(t>mid)qsum(mid+1,r,rt<<1|1);
}
void qpred(int l,int r,int rt){
    if(s<=l&&t>=r){
        node *x=pred(ans,rt);
        if(x)tmp=max(tmp,x->data);
        return;
    }
    int mid=(l+r)>>1;
    if(s<=mid)qpred(l,mid,rt<<1);
    if(t>mid)qpred(mid+1,r,rt<<1|1);
}
void qsucc(int l,int r,int rt){
    if(s<=l&&t>=r){
        node *x=succ(ans,rt);
        if(x)tmp=min(tmp,x->data);
        return;
    }
    int mid=(l+r)>>1;
    if(s<=mid)qsucc(l,mid,rt<<1);
    if(t>mid)qsucc(mid+1,r,rt<<1|1);
}
void mset(int l,int r,int rt){
    erase(find(a[ans],rt),rt);
    insert(new node(tmp),rt);
    if(l==r)return;
    int mid=(l+r)>>1;
    if(ans<=mid)mset(l,mid,rt<<1);
    else mset(mid+1,r,rt<<1|1);
}
void copy(node *&x,node *y){
    x=new node(y->data);
    if(y->lc){
        copy(x->lc,y->lc);
        x->lc->prt=x;
    }
    if(y->rc){
        copy(x->rc,y->rc);
        x->rc->prt=x;
    }
    x->refresh();
}
void merge(int i,node *x){
    if(!x)return;
    insert(new node(x->data),i);
    merge(i,x->lc);
    merge(i,x->rc);
}
void insert(node *x,int i){
    if(!root[i]){
        root[i]=x;
        return;
    }
    node *rt=root[i];
    for(;;){
        if(x->data<rt->data){
            if(rt->lc)rt=rt->lc;
            else{
                rt->lc=x;
                break;
            }
        }
        else{
            if(rt->rc)rt=rt->rc;
            else{
                rt->rc=x;
                break;
            }
        }
    }
    x->prt=rt;
    for(;rt;rt=rt->prt)rt->refresh();
    splay(x,NULL,i);
}
node *find(int x,int i){
    node *rt=root[i];
    while(rt){
        if(x==rt->data)return rt;
        else if(x<rt->data)rt=rt->lc;
        else rt=rt->rc;
    }
    return NULL;
}
void erase(node *x,int i){
    splay(x,NULL,i);
    if(x->lc){
        splay(findmax(x->lc),x,i);
        x->lc->rc=x->rc;
        if(x->rc)x->rc->prt=x->lc;
        x->lc->prt=NULL;
        root[i]=x->lc;
        x->lc->refresh();
    }
    else{
        if(x->rc)x->rc->prt=NULL;
        root[i]=x->rc;
    }
    delete x;
}
int rank(int x,int i){
    node *rt=root[i],*y=NULL;
    int ans=0;
    while(rt){
        y=rt;
        if(x<=rt->data)rt=rt->lc;
        else{
            ans+=siz(rt->lc)+1;
            rt=rt->rc;
        }
    }
    splay(y,NULL,i);
    return ans;
}
node *pred(int x,int i){
    node *rt=root[i],*y=NULL;
    while(rt){
        if(rt->data<x){
            if(!y||y->data<rt->data)y=rt;
            rt=rt->rc;
        }
        else rt=rt->lc;
    }
    if(y)splay(y,NULL,i);
    return y;
}
node *succ(int x,int i){
    node *rt=root[i],*y=NULL;
    while(rt){
        if(rt->data>x){
            if(!y||y->data>rt->data)y=rt;
            rt=rt->lc;
        }
        else rt=rt->rc;
    }
    if(y)splay(y,NULL,i);
    return y;
}
void splay(node *x,node *tar,int i){
    for(node *rt=x->prt;rt!=tar;rt=x->prt){
        if(rt->prt==tar){
            if(x==rt->lc)rrot(rt,i);
            else lrot(rt,i);
            break;
        }
        if(rt==rt->prt->lc){
            if(x==rt->lc)rrot(rt,i);
            else lrot(rt,i);
            rrot(x->prt,i);
        }
        else{
            if(x==rt->rc)lrot(rt,i);
            else rrot(rt,i);
            lrot(x->prt,i);
        }
    }
}
void lrot(node *x,int i){
    node *y=x->rc;
    if(x->prt){
        if(x==x->prt->lc)x->prt->lc=y;
        else x->prt->rc=y;
    }
    else root[i]=y;
    y->prt=x->prt;
    x->rc=y->lc;
    if(y->lc)y->lc->prt=x;
    y->lc=x;
    x->prt=y;
    x->refresh();
    y->refresh();
}
void rrot(node *x,int i){
    node *y=x->lc;
    if(x->prt){
        if(x==x->prt->lc)x->prt->lc=y;
        else x->prt->rc=y;
    }
    else root[i]=y;
    y->prt=x->prt;
    x->lc=y->rc;
    if(y->rc)y->rc->prt=x;
    y->rc=x;
    x->prt=y;
    x->refresh();
    y->refresh();
}
node *findmax(node *x){
    while(x->rc)x=x->rc;
    return x;
}
View Code
233333333
原文地址:https://www.cnblogs.com/hzoier/p/5882916.html