BZOJ3224:普通平衡树——题解

http://www.lydsy.com/JudgeOnline/problem.php?id=3224

题面源于洛谷

题目描述

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  1. 插入x数

  2. 删除x数(若有多个相同的数,因只删除一个)

  3. 查询x数的排名(排名定义为比当前数小的数的个数+1。若有多个相同的数,因输出最小的排名)

  4. 查询排名为x的数

  5. 求x的前驱(前驱定义为小于x,且最大的数)

  6. 求x的后继(后继定义为大于x,且最小的数)

输入输出格式

输入格式:

第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号( 1 leq opt leq 61opt6 )

输出格式:

对于操作3,4,5,6每行输出一个数,表示对应答案

输入输出样例

输入样例#1: 
10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598
输出样例#1: 
106465
84185
492737

——————————————————————————————

http://blog.csdn.net/clove_unique/article/details/50630280

所以最开始学习splay就从上面的博客学的话做这道题就是切。

简单讲几个问题吧。

1.findx操作。

和我前面博客的kthmin(或者query)思路一致,就是变成了非递归的。

2.find操作。

从根出发,判断当前左右子结点的值和x比较就可以知道要往哪里走了。

同时沿途中顺便记录排名即可,方法同findx操作。

3.del操作。

如果看过我前面的博客的话就会发现del操作只能删除根节点。

那么我们怎么删除任意结点呢?很简单,和find函数联动即可。

因为find(x)函数会返回x的排名,但同时会把x放在根上,所以又转换成了删除根节点的问题了。

#include<cstdio>
#include<queue>
#include<cctype>
#include<cstring>
#include<cmath>
#include<vector>
#include<algorithm>
using namespace std;
const int N=100001;
inline int read(){
    int X=0,w=0;char ch=0;
    while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
    while(isdigit(ch))X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
    return w?-X:X;
}
int fa[N],tr[N][2],key[N],cnt[N],size[N];
int root,sz;
inline void clear(int x){
    fa[x]=tr[x][0]=tr[x][1]=key[x]=cnt[x]=size[x]=0;
    return;
}
inline bool get(int x){
    return tr[fa[x]][1]==x;
}
inline void update(int x){
    if(x){  
        size[x]=cnt[x];  
        if(tr[x][0])size[x]+=size[tr[x][0]];  
        if(tr[x][1])size[x]+=size[tr[x][1]];  
    }  
    return;
}
inline void rotate(int x){
    int old=fa[x],oldf=fa[old],which=get(x);
    tr[old][which]=tr[x][which^1];fa[tr[old][which]]=old;  
    fa[old]=x;tr[x][which^1]=old;fa[x]=oldf;
    if(oldf)tr[oldf][tr[oldf][1]==old]=x;
    update(old);update(x);
    return;
}
inline void splay(int x){
    int f=fa[x];
    while(f){
        if(fa[f])rotate((get(x)==get(f)?f:x));
        rotate(x);f=fa[x];
    }
    root=x;
    return;
}
inline void insert(int v){
    if(!root){
        sz++;tr[sz][0]=tr[sz][1]=fa[sz]=0;
        key[sz]=v;cnt[sz]=size[sz]=1;root=sz;
        return;
    }
    int now=root,f=0;
    while(233){
        if(key[now]==v){
            cnt[now]++;update(now);update(f);splay(now);
            break;
        }
        f=now;
        now=tr[now][key[now]<v];
        if(!now){
            sz++;tr[sz][0]=tr[sz][1]=0;fa[sz]=f;
            key[sz]=v;cnt[sz]=size[sz]=1;
            tr[f][key[f]<v]=sz;
            update(f);splay(sz);
            break;
        }
    }
    return;
}
inline int find(int v){//查询v的排名
    int ans=0,now=root;
    while(233){
        if(v<key[now])now=tr[now][0];
        else{
            ans+=(tr[now][0]?size[tr[now][0]]:0);
            if(v==key[now]){
                splay(now);
                return ans+1;
            }
            ans+=cnt[now];
            now=tr[now][1];
        }
    }
}
inline int findx(int x){//找到排名为x的点 
    int now=root;
    while(233){
        if(tr[now][0]&&x<=size[tr[now][0]])now=tr[now][0];
        else{
            int temp=(tr[now][0]?size[tr[now][0]]:0)+cnt[now];
            if(x<=temp)return key[now];
            x-=temp;now=tr[now][1];
        }
    }
}
inline int pre(){//前驱
    int now=tr[root][0];  
    while(tr[now][1])now=tr[now][1];  
    return now;  
}     
inline int nxt(){//后继
    int now=tr[root][1];  
    while(tr[now][0])now=tr[now][0];  
    return now;
}
inline void del(int x){  
    find(x);
    if(cnt[root]>1){
        cnt[root]--;return;
    }
    if(!tr[root][0]&&!tr[root][1]){
        clear(root);root=0;return;
    }
    if(!tr[root][0]){  
        int oldroot=root;root=tr[root][1];fa[root]=0;clear(oldroot);return;
    }
    else if(!tr[root][1]){  
        int oldroot=root;root=tr[root][0];fa[root]=0;clear(oldroot);return;
    }
    int leftbig=pre(),oldroot=root;  
    splay(leftbig);
    fa[tr[oldroot][1]]=root;
    tr[root][1]=tr[oldroot][1];
    clear(oldroot);
    update(root);
    return;
}
int main(){
    int n=read();
    for(int i=1;i<=n;i++){
        int opt=read();
        int k=read();
        if(opt==1)insert(k);
        if(opt==2)del(k);
        if(opt==3)printf("%d
",find(k));
        if(opt==4)printf("%d
",findx(k));
        if(opt==5){
            insert(k);
            printf("%d
",key[pre()]);
            del(k);
        }
        if(opt==6){
            insert(k);
            printf("%d
",key[nxt()]);
            del(k);
        }
    }
    return 0;
}

 UPD :Treap

#include<cmath>
#include<cstdio>
#include<cctype>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
const int INF=2147483647;
const int N=1e5+5;
inline int read(){
    int X=0,w=0;char ch=0;
    while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
    while(isdigit(ch))X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
    return w?-X:X;
}
struct treap{
    int l,r,p,size,cnt,key;
#define lc(x) tr[x].l
#define rc(x) tr[x].r
#define v(x) tr[x].key
#define p(x) tr[x].p
#define c(x) tr[x].cnt
#define s(x) tr[x].size
}tr[N];
int sz,rt;
inline int rand(){
    static int seed=233;
    return seed=(ll)seed*482711%998244353;
}
inline void upt(int k){
    s(k)=s(lc(k))+s(rc(k))+c(k);
}
inline void zig(int &k){
    int y=lc(k);lc(k)=rc(y);rc(y)=k;
    s(y)=s(k);upt(k);
    k=y;
}
inline void zag(int &k){
    int y=rc(k);rc(k)=lc(y);lc(y)=k;
    s(y)=s(k);upt(k);
    k=y;
}
inline void insert(int &k,int v){
    if(!k){
    k=++sz;v(k)=v;p(k)=rand();
    c(k)=s(k)=1;lc(k)=rc(k)=0;
    return;
    }
    else s(k)++;
    if(v(k)==v)c(k)++;
    else if(v<v(k)){
    insert(lc(k),v);
    if(p(lc(k))<p(k))zig(k);
    }else{
    insert(rc(k),v);
    if(p(rc(k))<p(k))zag(k);
    }
}
inline void del(int &k,int v){
    if(v(k)==v){
    if(c(k)>1)c(k)--,s(k)--;
    else if(!lc(k) || !rc(k))k=lc(k)+rc(k);
    else if(p(lc(k))<p(rc(k)))zig(k),del(k,v);
    else zag(k),del(k,v);
    return;
    }
    else s(k)--;
    if(v<v(k))del(lc(k),v);
    else del(rc(k),v);
}
inline int find(int v){
    int x=rt,res=0;
    while(x){
    if(v==v(x))return res+s(lc(x))+1;
    if(v<v(x))x=lc(x);
    else res+=s(lc(x))+c(x),x=rc(x);
    }
    return res;
}
inline int findx(int k){
    int x=rt;
    while(x){
    if(s(lc(x))<k&&s(lc(x))+c(x)>=k)return v(x);
    if(s(lc(x))>=k)x=lc(x);
    else k-=s(lc(x))+c(x),x=rc(x);
    }
    return 0;
}
inline int pre(int v){
    int x=rt,res=-INF;
    while(x){
    if(v(x)<v)res=v(x),x=rc(x);
    else x=lc(x);
    }
    return res;
}
inline int nxt(int v){
    int x=rt,res=INF;
    while(x){
    if(v(x)>v)res=v(x),x=lc(x);
    else x=rc(x);
    }
    return res;
}
int main(){
    int n=read();
    for(int i=1;i<=n;i++){
        int opt=read();
        int k=read();
        if(opt==1)insert(rt,k);
        if(opt==2)del(rt,k);
        if(opt==3)printf("%d
",find(k));
        if(opt==4)printf("%d
",findx(k));
        if(opt==5){
            insert(rt,k);
            printf("%d
",pre(k));
            del(rt,k);
        }
        if(opt==6){
            insert(rt,k);
            printf("%d
",nxt(k));
            del(rt,k);
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/luyouqi233/p/8145026.html