P3369 普通平衡树(Splay做法)

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  1. 插入 xx 数
  2. 删除 xx 数(若有多个相同的数,因只删除一个)
  3. 查询 xx 数的排名(排名定义为比当前数小的数的个数 +1+1 )
  4. 查询排名为 xx 的数
  5. 求 xx 的前驱(前驱定义为小于 xx,且最大的数)
  6. 求 xx 的后继(后继定义为大于 xx,且最小的数)
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+100;
const int inf=1e9;
struct Splay_tree {
    int fa;
    int cnt;
    int ch[2];
    int v;
    int size;
}t[maxn];
int root,tot;
void update (int x) {
    t[x].size=t[t[x].ch[0]].size+t[t[x].ch[1]].size+t[x].cnt;
}
void rotate (int x) {
    int y=t[x].fa;
    int z=t[y].fa;
    int k=(t[y].ch[1]==x);
    t[z].ch[t[z].ch[1]==y]=x;
    t[x].fa=z;
    t[y].ch[k]=t[x].ch[k^1];
    t[t[x].ch[k^1]].fa=y;
    t[x].ch[k^1]=y;
    t[y].fa=x;
    update(y);
    update(x);
}
void splay (int x,int s) {
    while (t[x].fa!=s) {
        int y=t[x].fa;
        int z=t[y].fa;
        if (z!=s) 
            (t[z].ch[0]==y)^(t[y].ch[0]==x)?rotate(x):rotate(y);
        rotate(x);
    }
    if (s==0) root=x;
}
void find (int x) {
    int u=root;
    if (!u) return;
    while (t[u].ch[x>t[u].v]&&x!=t[u].v) 
        u=t[u].ch[x>t[u].v];
    splay(u,0);
}
void ins (int x) {
    int u=root;
    int fa=0;
    while (u&&t[u].v!=x) {
        fa=u;
        u=t[u].ch[x>t[u].v];
    }
    if (u)
        t[u].cnt++;
    else {
        u=++tot;
        if (fa)
            t[fa].ch[x>t[fa].v]=u;
        t[u].ch[0]=t[u].ch[1]=0;
        t[tot].fa=fa;
        t[tot].v=x;
        t[tot].cnt=1;
        t[tot].size=1;
    }
    splay(u,0); 
}
int Next (int x,int f) {
    find(x);
    int u=root;
    if (t[u].v>x&&f) return u;
    if (t[u].v<x&&!f) return u;
    u=t[u].ch[f];
    while (t[u].ch[f^1]) u=t[u].ch[f^1];
    return u;
} 
void del (int x) {
    int lst=Next(x,0);
    int nxt=Next(x,1);
    splay(lst,0);
    splay(nxt,lst);
    int tt=t[nxt].ch[0];
    if (t[tt].cnt>1) {
        t[tt].cnt--;
        splay(tt,0);
    }
    else 
        t[nxt].ch[0]=0;
}
int kth (int x) {
    int u=root;
    while (t[u].size<x) return 0;
    while (1) {
        int y=t[u].ch[0];
        if (x>t[y].size+t[u].cnt) {
            x-=t[y].size+t[u].cnt;
            u=t[u].ch[1];
        }
        else if (t[y].size>=x) 
            u=y;
        else
            return t[u].v;
    }
}
int main () {
    int n;
    scanf("%d",&n);
    ins(inf);ins(-inf);
    while (n--) {
        int op,x;
        scanf("%d%d",&op,&x);
        if (op==1) ins(x);
        if (op==2) del(x);
        if (op==3) find(x),printf("%d
",t[t[root].ch[0]].size);
        if (op==4) printf("%d
",kth(x+1));
        if (op==5) printf("%d
",t[Next(x,0)].v);
        if (op==6) printf("%d
",t[Next(x,1)].v); 
    }
}
原文地址:https://www.cnblogs.com/zhanglichen/p/13416519.html