Treap

Treap

您需要写一种数据结构,来维护一些数,其中需要提供以下操作:

  1. 插入 \(x\)
  2. 删除 \(x\) 数(若有多个相同的数,因只删除一个)
  3. 查询 \(x\) 数的排名(排名定义为比当前数小的数的个数 \(+1\) )
  4. 查询排名为 \(x\) 的数
  5. \(x\) 的前驱(前驱定义为小于 \(x\),且最大的数)
  6. \(x\)的后继(后继定义为大于 \(x\),且最小的数)

第一行为 \(n\),表示操作的个数,下面 \(n\) 行每行有两个数 \(\text{opt}\)\(x\)\(\text{opt}\) 表示操作的序号\(( 1≤opt≤6 )\)

对于操作 \(3,4,5,6\) 每行输出一个数,表示对应答案

const ll maxn = 1e5 + 10;
const ll inf = 1e14 + 10;
ll n, tot, root;

struct node
{
    ll l, r;
    ll v, rk;
    ll cnt, siz;
} s[maxn << 1];

inline ll New(ll w)
{
    s[++tot].v = w;
    s[tot].rk = rand();
    s[tot].cnt = s[tot].siz = 1;
    return tot;
}

inline void upd(ll p)
{
    s[p].siz = s[s[p].l].siz + s[s[p].r].siz + s[p].cnt;
}

inline void build()
{
    New(-inf), New(inf);
    root = 1;
    s[1].r = 2;
    upd(root);
    return;
}

inline void zig(ll &p)
{
    ll q = s[p].l;
    s[p].l = s[q].r, s[q].r = p;
    p = q;
    upd(s[p].r);
    upd(p);
}

inline void zag(ll &p)
{
    ll q = s[p].r;
    s[p].r = s[q].l, s[q].l = p;
    p = q;
    upd(s[p].l);
    upd(p);
}

inline void add(ll &p, ll x)
{
    if (p == 0)
    {
        p = New(x);
        return;
    }
    if (s[p].v == x)
    {
        s[p].cnt++;
        upd(p);
        return;
    }
    if (s[p].v < x)
    {
        add(s[p].r, x);
        if (s[p].rk < s[s[p].r].rk)
            zag(p);
    }
    if (s[p].v > x)
    {
        add(s[p].l, x);
        if (s[p].rk < s[s[p].l].rk)
            zig(p);
    }

    upd(p);
}

inline void delt(ll &p, ll x)
{
    if (p == 0)
        return;
    if (s[p].v == x)
    {
        if (s[p].cnt > 1)
        {
            s[p].cnt--;
            upd(p);
            return;
        }
        if (s[p].l || s[p].r)
        {
            if (s[p].r == 0 || s[s[p].l].rk > s[s[p].r].rk)
            {
                zig(p);
                delt(s[p].r, x);
            }
            else
            {
                zag(p);
                delt(s[p].l, x);
            }
            upd(p);
        }
        else
            p = 0;
        return;
    }
    if (s[p].v < x)
    {
        delt(s[p].r, x);
        upd(p);
        return;
    }
    if (s[p].v > x)
    {
        delt(s[p].l, x);
        upd(p);
        return;
    }
}

inline ll getrank(ll p, ll x)
{
    if (p == 0)
        return 0;
    if (s[p].v == x)
    {
        return s[s[p].l].siz + 1;
    }
    if (s[p].v > x)
    {
        return getrank(s[p].l, x);
    }
    else
    {
        return getrank(s[p].r, x) + s[s[p].l].siz + s[p].cnt;
    }
}

inline ll getval(ll p, ll x)
{
    if (p == 0)
        return inf;
    if (s[s[p].l].siz >= x)
    {
        return getval(s[p].l, x);
    }
    if (s[s[p].l].siz + s[p].cnt >= x)
    {
        return s[p].v;
    }
    else
    {
        return getval(s[p].r, x - s[s[p].l].siz - s[p].cnt);
    }
}

inline ll getpre(ll x)
{
    ll ans = 1;
    ll p = root;
    while (p)
    {
        if (x == s[p].v)
        {
            if (s[p].l)
            {
                p = s[p].l;
                while (s[p].r)
                    p = s[p].r;
                ans = p;
            }
            break;
        }
        if (s[p].v < x && s[p].v > s[ans].v)
            ans = p;
        if (x < s[p].v)
            p = s[p].l;
        else
            p = s[p].r;
    }
    return s[ans].v;
}

inline ll getnxt(ll x)
{
    ll ans = 2;
    ll p = root;
    while (p)
    {
        if (x == s[p].v)
        {
            if (s[p].r)
            {
                p = s[p].r;
                while (s[p].l)
                    p = s[p].l;
                ans = p;
            }
            break;
        }
        if (s[p].v > x && s[p].v < s[ans].v)
            ans = p;
        if (x < s[p].v)
            p = s[p].l;
        else
            p = s[p].r;
    }
    return s[ans].v;
}

int main()
{
    build();
    n = read();
    for (int i = 1; i <= n; i++)
    {
        ll op = read(), x = read();
        if (op == 1)
            add(root, x);
        if (op == 2)
            delt(root, x);
        if (op == 3)
            printf("%lld\n", getrank(root, x) - 1);
        if (op == 4)
            printf("%lld\n", getval(root, x + 1));
        if (op == 5)
            printf("%lld\n", getpre(x));
        if (op == 6)
            printf("%lld\n", getnxt(x));
    }
    return 0;
}
原文地址:https://www.cnblogs.com/EdisonBa/p/14948605.html