「Splay」普通平衡树模板

口诀:

$rotate$:先上再下,最后自己

$splay$:祖父未到旋两次,三点一线旋父亲,三点折线旋自己。

$delete$:没有儿子就删光。单个儿子删自己。两个儿子找前驱。

易错点:

$rotate$:祖父不在自己做根

$delete$:自己做根父亲为0

$kth$:先减排名后转移

/*By DennyQi 2018*/
#include <cstdio>
#include <queue>
#include <cstring>
#include <algorithm>
using namespace std;
// Project-wide aliases and limits.
typedef long long ll;          // 64-bit integer alias (not referenced in the visible code)
const int MAXN = 100000 + 10;  // capacity of every per-node array; presumably sized for ~1e5 operations plus slack — confirm against input limits
const int INF = 0x3f3f3f3f;    // large sentinel (1061109567); not referenced in the visible code
// Larger of two ints (returns b when they are equal).
inline int Max(const int a, const int b){ if(b < a) return a; return b; }
// Smaller of two ints (returns b when they are equal).
inline int Min(const int a, const int b){ if(a < b) return a; return b; }
// Read one (optionally negative) decimal integer from stdin.
// Any characters before the number are skipped; a '-' not followed by digits
// yields 0. At end of input it returns 0 instead of looping forever.
inline int read(){
    int x = 0, w = 1;
    // int, not char: EOF must stay distinguishable from ordinary bytes.
    // (`register` was dropped: it is ill-formed since C++17.)
    int c = getchar();
    while(c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if(c == '-') w = -1, c = getchar();
    for(; c >= '0' && c <= '9'; c = getchar()) x = x * 10 + (c - '0');
    return x * w;
}
// Driver state, filled in by main().
int n, opt, x;         // number of operations, current opcode, current operand
// Node pool for the splay tree. Index 0 is the null sentinel; real nodes are 1..num_node.
int num_node;          // last node index handed out (it only ever grows)
int ch[MAXN][2];       // ch[v][0] / ch[v][1]: left / right child of node v (0 = none)
int fa[MAXN];          // parent of node v (0 for the root)
int val[MAXN];         // key stored at node v
// NOTE(review): with `using namespace std;` in C++17, an unqualified `size`
// also names std::size, so unqualified uses of this array are ambiguous.
int size[MAXN];        // total multiplicity of the subtree rooted at v
int cnt[MAXN];         // multiplicity of v's own key
int root;              // current root (0 means the tree is empty)
struct Splay{
    inline bool rson(int f, int x){
        return ch[f][1] == x;
    }
    inline void update(int x){
        size[x] = size[ch[x][0]] + size[ch[x][1]] + cnt[x];
    }
    inline void clear(int x){
        val[x]=cnt[x]=size[x]=fa[x]=ch[x][0]=ch[x][1]=0;
    }
    inline void rotate(int x){
        int f = fa[x], gf = fa[f];
        bool p = rson(f, x), q = !p;
        if(gf) ch[gf][rson(gf,f)] = x; else root = x; fa[x] = gf;
        ch[f][p] = ch[x][q], fa[ch[x][q]] = f;
        ch[x][q] = f, fa[f] = x;
        update(f), update(x);
    }
    inline void splay(int x, int target){
        while(fa[x] != target){
            int f = fa[x], gf = fa[f];
            if(gf == target){ rotate(x); break;}
            if(rson(gf,f) == rson(f,x)) rotate(f); else rotate(x);
            rotate(x);
        }
    }
    inline void Insert(int v){
        int o = root;
        if(root == 0){
            root = ++num_node;
            cnt[root] = size[root] = 1;
            val[root] = v;
            return;
        }
        for(;o;){
            if(v == val[o]){
                cnt[o]++, size[o]++;
                splay(o, 0);
                return;
            }
            bool b = v>val[o];
            if(!ch[o][b]){
                ch[o][b] = ++num_node;
                cnt[ch[o][b]] = size[ch[o][b]] = 1;
                val[ch[o][b]] = v, fa[ch[o][b]] = o;
                splay(ch[o][b], 0);
                return;
            }
            o = ch[o][v>val[o]];
        }
    }
    inline void Find(int v){
        for(int o = root; o; o = ch[o][v>val[o]]){
            if(val[o] == v){ splay(o, 0); return; }
            if(!ch[o][v>val[o]]) return;
        }
    }
    inline void Delete(int v){
        Find(v);
        if(val[root] != v) return;
        int o = root;
        if(cnt[o] > 1){ --cnt[o],--size[o]; return; }
        if(!ch[o][0] && !ch[o][1]){ root = 0, fa[root] = 0; return; }
        if(!ch[o][0]){ root = ch[o][1], fa[root] = 0; return; }
        if(!ch[o][1]){ root = ch[o][0], fa[root] = 0; return; }
        int l_max = ch[root][0];
        while(ch[l_max][1]) l_max = ch[l_max][1];
        splay(l_max, root);
        ch[l_max][1] = ch[root][1];
        fa[ch[root][1]] = l_max;
        fa[l_max] = 0;
        int pre_root = root;
        root = l_max;
        clear(pre_root);
    }
    inline int Rnk(int x){
        Find(x);
        return size[ch[root][0]] + 1;
    }
    inline int Kth(int k){
        for(int o = root; o;){
            if(size[ch[o][0]] >= k) o = ch[o][0];
            else if(size[ch[o][0]] + cnt[o] < k){
                k -= size[ch[o][0]] + cnt[o];
                o = ch[o][1];
            }
            else{
                splay(o,0);
                return val[o];
            } 
        }
    }
    inline int Pre(int v){
        Insert(v);
        int o = ch[root][0];
        while(ch[o][1]) o = ch[o][1];
        int ans = val[o];
        Delete(v);
        return ans;
    }
    inline int Nxt(int v){
        Insert(v);
        int o = ch[root][1]; 
        while(ch[o][0]) o = ch[o][0];
        int ans = val[o]; 
        Delete(v);
        return ans;
    }
}qxz;
// Reads n operations "opt x" and applies each one to the splay tree qxz:
//   1 insert x, 2 delete one x, 3 print the rank of x, 4 print the x-th smallest,
//   5 print the predecessor of x, 6 print the successor of x.
// Any other opcode is ignored.
int main(){
    // freopen(".in", "r", stdin);   // local-testing hook, intentionally disabled
    n = read();
    for(int i = 1; i <= n; ++i){
        opt = read(), x = read();
        switch(opt){
            case 1: qxz.Insert(x); break;
            case 2: qxz.Delete(x); break;
            case 3: printf("%d\n", qxz.Rnk(x)); break;
            case 4: printf("%d\n", qxz.Kth(x)); break;
            case 5: printf("%d\n", qxz.Pre(x)); break;
            case 6: printf("%d\n", qxz.Nxt(x)); break;
            default: break;
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/qixingzhi/p/9365586.html