splay好板子

找到一份比较好的 Splay 模板,链接:https://blog.csdn.net/crazy_ac/article/details/8034190

#include<cstdio>
#include<cstdlib>
// Large sentinel value: an all-ones unsigned int shifted right by two
// (0x3FFFFFFF when unsigned int is 32 bits wide). Not referenced elsewhere
// in this listing.
const int inf  = ~0u>>2;
// Shorthand for the left / right child of the node whose index is held in a
// variable literally named `x`. They expand to ch[x][0] / ch[x][1], so they
// only work where both `x` and `ch` are in scope (i.e. inside SplayTree's
// member functions), and no other identifier may be called L or R.
#define L ch[x][0]
#define R ch[x][1]
// Left child of the root's right child. Not used elsewhere in this listing.
#define KT (ch[ ch[rt][1] ][0])
// Size of every per-node array in SplayTree. Index 0 is used as the "no node"
// marker, so at most maxn-1 real nodes fit.
const int maxn = 500010;
// File-scope threshold variable. NOTE(review): SplayTree also declares a
// member named `lim`; inside its member functions the member is the one that
// is found, so assigning this global does not affect SplayTree::del().
int lim;
/*
 * Array-backed splay tree that stores a multiset of ints.
 *
 * Nodes are identified by indices 1..top into the parallel arrays below;
 * index 0 means "no node" and its sz/cnt/children are kept at 0 so that
 * sz[ch[x][k]] can be read without a null check. Equal keys share one node
 * and are counted in cnt[]. Node indices are never recycled and `top` is not
 * checked against maxn, so at most maxn-1 distinct keys may be created after
 * an init().
 *
 * The L / R macros (ch[x][0] / ch[x][1]) are used throughout, which is why
 * the current-node variable is always called `x`.
 */
struct SplayTree {
    int sz[maxn];       // sz[x]: number of stored values (duplicates included) in x's subtree
    int ch[maxn][2];    // ch[x][0] / ch[x][1]: left / right child index, 0 = none
    int pre[maxn];      // pre[x]: parent index, 0 for the root
    int rt,top;         // rt: root index (0 = empty tree); top: last node index handed out

    // Recompute sz[x] from x's own multiplicity and its children's sizes.
    inline void up(int x){
        sz[x]  = cnt[x]  + sz[ L ] + sz[ R ];
    }

    // Rotate x above its parent. f==1 is a right rotation (x must be the left
    // child), f==0 a left rotation (x must be the right child).
    // Only the old parent's size is refreshed; the caller refreshes x.
    inline void Rotate(int x,int f){
        int y=pre[x];
        ch[y][!f] = ch[x][f];
        pre[ ch[x][f] ] = y;        // may write pre[0] when that child is absent
        pre[x] = pre[y];
        if(pre[x]) ch[ pre[y] ][ ch[pre[y]][1] == y ] =x;
        ch[x][f] = y;
        pre[y] = x;
        up(y);
    }

    // Splay x upwards until its parent is `goal` (goal==0 makes x the root).
    // `goal` must be 0 or a proper ancestor of x. Calling with x==0 is a
    // no-op; previously that overwrote rt with 0 and lost the tree.
    inline void Splay(int x,int goal){
        if(!x) return ;
        while(pre[x] != goal){
            if(pre[pre[x]] == goal) Rotate(x , ch[pre[x]][0] == x);   // zig
            else   {
                int y=pre[x],z=pre[y];
                int f = (ch[z][0]==y);
                if(ch[y][f] == x) Rotate(x,!f),Rotate(x,f);           // zig-zag
                else Rotate(y,f),Rotate(x,f);                         // zig-zig
            }
        }
        up(x);
        if(goal==0) rt=x;
    }

    // Splay the node holding the k-th smallest value (1-based, duplicates
    // counted) to just below `goal`. Starts from the current rt.
    // Does nothing when k is out of range.
    inline void RTO(int k,int goal){
        int x=rt;
        while(x){
            if(k <= sz[ L ]) x=L;
            else if(k > sz[ L ] + cnt[x]) {
                // Skip the left subtree and every copy stored at x.
                k -= sz[ L ] + cnt[x];
                x = R;
            }
            else break;
        }
        if(x) Splay(x,goal);
    }

    // Debug helper: print the subtree rooted at x in pre-order.
    inline void vist(int x){
        if(x){
            printf("结点%2d : 左儿子  %2d   右儿子  %2d   val:%2d sz=%d  cnt:%d\n",x,L,R,val[x],sz[x],cnt[x]);
            vist(L);
            vist(R);
        }
    }

    // Debug helper: dump the whole tree framed by blank lines.
    void debug() {
        puts("");
        vist(rt);
        puts("");
    }

    // Allocate a fresh leaf with value c and parent f, and store its index
    // into x (normally a reference to the parent's child slot or to rt).
    inline void Newnode(int &x,int c,int f){
        x=++top;
        L = R = 0;
        pre[x] = f;
        sz[x]=1; cnt[x]=1;
        val[x] = c;
    }

    // Reset to an empty tree and re-establish the all-zero node 0.
    inline void init(){
        ch[0][0]=ch[0][1]=pre[0]=sz[0]=0;
        rt=top=0; cnt[0]=0;
    }

    // Insert `key` into the subtree whose root slot is x (parent f), then
    // splay the affected node to the root. An existing key only has its
    // multiplicity bumped. Written as a loop so a degenerate (chain-shaped)
    // tree cannot overflow the call stack.
    inline void Insert(int &x,int key,int f){
        int *slot = &x;
        int parent = f;
        while(*slot){
            int cur = *slot;
            if(key == val[cur]){
                cnt[cur]++;
                sz[cur]++;
                Splay(cur,0);       // ancestors' sizes are refreshed by the rotations
                return ;
            }
            parent = cur;
            slot = &ch[cur][ key > val[cur] ];
        }
        Newnode(*slot,key,parent);
        Splay(*slot,0);
    }

    // Remove the current root node entirely (regardless of its cnt).
    void Del_root(){
        int t=rt;
        if(ch[rt][1]) {
            // Bring the minimum of the right subtree to the top. Because the
            // old root t is still linked in, it ends up as the new root's
            // left child with an empty right side, so it is unlinked by
            // splicing t's left subtree directly under the new root.
            rt=ch[rt][1];
            RTO(1,0);
            ch[rt][0]=ch[t][0];
            if(ch[rt][0]) pre[ch[rt][0]]=rt;
        }
        else rt=ch[rt][0];
        pre[rt]=0;
        up(rt);
    }

    // Store in `ans` the node with the largest value <= key within the
    // subtree of x. `ans` is left untouched when no such node exists.
    void findpre(int x,int key,int &ans){
        while(x){
            if(val[x] <= key){
                ans=x;
                x=R;
            } else
                x=L;
        }
    }

    // Store in `ans` the node with the smallest value >= key within the
    // subtree of x. `ans` is left untouched when no such node exists.
    void findsucc(int x,int key,int &ans){
        while(x){
            if(val[x] >= key){
                ans=x;
                x=L;
            } else
                x=R;
        }
    }

    // Return the k-th smallest value (1-based, duplicates counted) in the
    // subtree of x and splay its node to the root.
    // Returns -1 when k is out of range (this used to recurse without end).
    inline int find_kth(int x,int k){
        while(x){
            if(k <= sz[ L ]) x=L;
            else if(k > sz[ L ] + cnt[x]) {
                k -= sz[ L ] + cnt[x];
                x = R;
            }
            else {
                Splay(x,0);
                return val[x];
            }
        }
        return -1;
    }

    // Return the node index holding `key` in the subtree of x, or 0 if absent.
    // Does not splay.
    int find(int x,int key){
        while(x && val[x] != key)
            x = (key < val[x]) ? L : R;
        return x;
    }

    // Smallest value in the subtree of x (val[0] if x is 0).
    int getmin(int x){
        while(L) x=L;
        return val[x];
    }

    // Largest value in the subtree of x (val[0] if x is 0).
    int getmax(int x){
        while(R) x=R;
        return val[x];
    }

    // Rank of `key` within the subtree of x: 1 + (number of stored values
    // smaller than key) + cur, where `cur` is the count of smaller values
    // already known to the caller (pass 0 for a whole-tree query).
    // If key is absent, the rank it would have is returned.
    int getrank(int x,int key,int cur){
        while(x){
            if(key == val[x])
                return sz[ L ] + cur + 1;
            if(key < val[x])
                x=L;
            else {
                // Everything in the left subtree and all copies at x are smaller.
                cur += sz[ L ] + cnt[x];
                x=R;
            }
        }
        return cur + 1;
    }

    // Number of stored values strictly less than key in the subtree of x.
    int get_lt(int x,int key){
        int total=0;
        while(x){
            if(val[x] >= key) x=L;
            else {
                total += cnt[x] + sz[ L ];
                x=R;
            }
        }
        return total;
    }

    // Number of stored values strictly greater than key in the subtree of x.
    int get_mt(int x,int key){
        int total=0;
        while(x){
            if(val[x] <= key) x=R;
            else {
                total += cnt[x] + sz[ R ];
                x=L;
            }
        }
        return total;
    }

    // Drop every node whose value is < lim from the subtree in slot x
    // (parent f). `lim` here is the SplayTree member declared below.
    // A too-small node is replaced by its right child, which discards the
    // node together with its whole left subtree; sizes are fixed on the
    // way back up.
    void del(int &x,int f){
        if(!x) return ;
        if(val[x]>=lim){
            del(L,x);
        } else {
            x=R;
            pre[x]=f;
            if(f==0)  rt=x;
            del(x,f);
        }
        if(x)  up(x);
    }

    // Remove all values below the member `lim` from the whole tree.
    inline void update(){
        del(rt,0);
    }

    // Whole-tree count of values strictly greater than key.
    int get_mt(int key) {
        return get_mt(rt,key);
    }

    // Whole-tree count of values strictly less than key.
    int get_lt(int key) {
        return get_lt(rt,key);
    }

    // Insert one copy of key into the tree.
    void insert(int key) {
        Insert(rt,key,0);
    }

    // Remove one copy of key; a key that is not present is ignored.
    void Delete(int key) {
        int node=find(rt,key);
        if(!node) return ;          // absent key: leave the tree untouched
        Splay(node,0);
        cnt[rt]--;
        if(!cnt[rt])Del_root();
        else up(rt);                // keep sz[rt] in step with the reduced cnt
    }

    // k-th smallest value of the whole tree (see find_kth).
    int kth(int k) {
        return find_kth(rt,k);
    }

    int cnt[maxn];      // cnt[x]: how many copies of val[x] are stored at node x
    int val[maxn];      // val[x]: key stored at node x
    int lim;            // threshold read by del()/update(); hides the file-scope `lim` inside the methods
}spt;
原文地址:https://www.cnblogs.com/zsben991126/p/10003544.html