树套树

历时三天终于打过了树套树  激动激动激动

写个博客纪念一下  

二逼平衡树~

// luogu-judger-enable-o2
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<stack>
#include<queue>
using namespace std;
typedef long long ll;

const int maxn = 101000;
const int INF = 1e9+7;

int n,m,tot;
int a[maxn],b[maxn];
int rt[maxn];

struct Node{
    int ch[2];
    int v,fa;
    int sz,cnt;
    
    void init(int val,int f){
        ch[0]=ch[1]=0;
        v=val,fa=f;
        sz=cnt=1;
    }
}t[maxn*100];

void pushup(int i){ t[i].sz=t[t[i].ch[0]].sz+t[t[i].ch[1]].sz+t[i].cnt; }

void rotate(int x){
    int y=t[x].fa,z=t[y].fa;
    int k=(t[y].ch[1]==x);
    t[z].ch[t[z].ch[1]==y]=x;
    t[x].fa=z;
    t[y].ch[k]=t[x].ch[k^1];
    t[t[x].ch[k^1]].fa=y;
    t[x].ch[k^1]=y;
    t[y].fa=x;
    pushup(y); pushup(x);
}

void splay(int x,int g,int ts){
    while(t[x].fa!=g){
        int y=t[x].fa,z=t[y].fa;
        if(z!=g){
            (t[z].ch[1]==y)^(t[y].ch[1]==x)?rotate(x):rotate(y);
        }
        rotate(x);
    }
    if(g==0) rt[ts]=x;
}

void insert(int ts,int k){
    int u=rt[ts],ff=0;
    if(!u){
        rt[ts]=u=++tot;
        t[u].init(k,0);
        return;
    } 
    while(u&&t[u].v!=k) ff=u,u=t[u].ch[k>t[u].v];
    if(u) ++t[u].cnt;
    else{
        u=++tot;
        if(ff!=0) t[ff].ch[k>t[ff].v]=u;
        t[u].init(k,ff);
    }
    splay(u,0,ts);
}

void find(int ts,int x){
    int u=rt[ts];
    if(!u) return;
    while(t[u].ch[x>t[u].v]&&t[u].v!=x){
        u=t[u].ch[x>t[u].v]; 
    }
    splay(u,0,ts); 
}

int nxt(int ts,int x,int c){
    find(ts,x);
    int u=rt[ts];
    if(t[u].v>x && c) return u;
    if(t[u].v<x && !c) return u;
    u=t[u].ch[c];
    while(t[u].ch[c^1]){
        u=t[u].ch[c^1];
    }return u;
}

void del(int ts,int x){
    int pre=nxt(ts,x,0);
    int nx=nxt(ts,x,1);
    splay(pre,0,ts);splay(nx,pre,ts);
    int d=t[nx].ch[0];
    if(t[d].cnt>1){
        --t[d].cnt;
        splay(d,0,ts);
    }else{
        t[nx].ch[0]=0;
    }
}

int rk(int ts,int x){
    find(ts,x);
    int u=rt[ts];
    return t[t[u].ch[0]].sz;
}

void build(int i,int l,int r){
    for(int j=l;j<=r;j++) insert(i,a[j]);
    insert(i,-2147483647); insert(i,2147483647);
    if(l==r){ return; }
    int mid=(l+r)/2;
    build(i<<1,l,mid),build(i<<1|1,mid+1,r);
}

void update(int i,int pos,int k,int l,int r,int o){
    del(i,o);
    insert(i,k);
    if(l==r){ return; }
    int mid=(l+r)/2;
    if(pos<=mid) update(i<<1,pos,k,l,mid,o);
    else update(i<<1|1,pos,k,mid+1,r,o);
}

int qry_rk(int i,int l,int r,int k,int x,int y){
    if(x<=l && r<=y){
        insert(i,k);
        int p=rk(i,k)-1;
        del(i,k);
        return p;
    }
    int mid=(l+r)/2,res=0;
    if(x<=mid) res+=qry_rk(i<<1,l,mid,k,x,y);
    if(y>mid) res+=qry_rk(i<<1|1,mid+1,r,k,x,y);
    return res;
}

int qry_pre(int i,int l,int r,int k,int x,int y){
    if(x<=l && r<=y){
        insert(i,k);
        int u=nxt(i,k,0);
        del(i,k);
        return t[u].v;
    }
    int mid=(l+r)/2;
    int res=-2147483647;
    if(x<=mid) res=max(res,qry_pre(i<<1,l,mid,k,x,y));
    if(y>mid) res=max(res,qry_pre(i<<1|1,mid+1,r,k,x,y));
    return res;
}

int qry_nxt(int i,int l,int r,int k,int x,int y){
    if(x<=l && r<=y){
        insert(i,k);
        int u=nxt(i,k,1);
        del(i,k); 
        return t[u].v;
    }
    int mid=(l+r)/2;
    int res=2147483647;
    if(x<=mid) res=min(res,qry_nxt(i<<1,l,mid,k,x,y));
    if(y>mid) res=min(res,qry_nxt(i<<1|1,mid+1,r,k,x,y));
    return res;
}

int qry_nx(int i,int l,int r,int k,int x,int y,int c){
    if(x<=l && r<=y){
        return t[nxt(i,k,c)].v;
    }
    int mid=(l+r)/2;
    if(c==0){
        int res=-INF;
        if(x<=mid) return max(res,qry_nx(i<<1,l,mid,k,x,y,c));
        if(y>mid) return max(res,qry_nx(i<<1|1,mid+1,r,k,x,y,c));
    }else{
        int res=INF;
        if(x<=mid) return min(res,qry_nx(i<<1,l,mid,k,x,y,c));
        if(y>mid) return min(res,qry_nx(i<<1|1,mid+1,r,k,x,y,c));
    }
}

void print(int i){
    if(!i) return;
    print(t[i].ch[0]);
    printf("%d ",t[i].v);
    print(t[i].ch[1]);
}

void write(int i,int l,int r){
    if(l==r){
        printf("%d : %d\n",rt[i],t[rt[i]].v);
        return;
    }
    printf("%d :",rt[i]);
    print(rt[i]);
    printf("\n");
    int mid=(l+r)/2;
    write(i<<1,l,mid),write(i<<1|1,mid+1,r);
}

ll read(){ ll s=0,f=1; char ch=getchar(); while(ch<'0' || ch>'9'){ if(ch=='-') f=-1; ch=getchar(); } while(ch>='0' && ch<='9'){ s=s*10+ch-'0'; ch=getchar(); } return s*f; }

int main(){
    n=read(),m=read();
    for(int i=1;i<=n;i++) a[i]=read();
    build(1,1,n);
//    write(1,1,n);
    int op,l,r,x;
    for(int i=1;i<=m;i++){
        op=read(),l=read(),r=read();
        if(op==1){
            x=read();
            printf("%d\n",qry_rk(1,1,n,x,l,r)+1); 
        }else if(op==2){
            x=read();
            int L=0,R=1e8,ans=0,mid;
            while(L<=R){
                mid=(L+R)/2;
                if(qry_rk(1,1,n,mid,l,r)<x){
                    L=mid+1;
                }else{
                    ans=mid;
                    R=mid-1;
                }
            }
            printf("%d\n",ans-1);
        }else if(op==3){
            update(1,l,r,1,n,a[l]);
            a[l]=r;
        }else if(op==4){
            x=read();
            printf("%d\n",qry_pre(1,1,n,x,l,r)); 
        }else if(op==5){
            x=read();
            printf("%d\n",qry_nxt(1,1,n,x,l,r)); 
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/tuchen/p/10332505.html