P3369 【模板】普通平衡树

题目链接: https://www.luogu.com.cn/problem/P3369

以下代码均参考 AgOH 大佬的实现。第一份为递归版 Splay。

#include <bits/stdc++.h>
using namespace std;
#define ll long long
// `register` is no longer a valid storage-class specifier in C++17 (clang
// rejects it, GCC warns), so `re` expands to nothing; existing `re int ...`
// declarations keep compiling unchanged.
#define re
const int N=1e6+10;   // capacity of the node pool (node indices start at 1; 0 means "no node")
/// Reads the next decimal integer from stdin into `a`.
/// Non-digit characters before the number are skipped; if a '-' appears among
/// them the result is negated. The character right after the last digit is
/// consumed. If end-of-file is reached before any digit, `a` is set to 0
/// (without the EOF test the skip loop would call getchar() forever).
void read(int &a)
{
    a=0;
    int d=1;
    int ch;   // int, not char: getchar() signals end-of-file with EOF
    while(ch=getchar(),ch!=EOF&&(ch>'9'||ch<'0'))
        if(ch=='-')
            d=-1;
    if(ch==EOF) return;
    a=ch-'0';
    while(ch=getchar(),ch>='0'&&ch<='9')
        a=a*10+(ch-'0');
    a*=d;
}
// Node pool of the splay tree. Nodes are allocated from index 1 upward;
// index 0 is never allocated and stands for "no node".
struct note
{
    int l, r;   // left / right child indices (0 = none)
    int siz;    // total multiplicity stored in this subtree
    int cnt;    // multiplicity of `val` in this node
    int val;    // key
} spl[N];
int rt, cnt;    // root index, number of nodes allocated so far
void newnode(int &now,int val)
{
    spl[now=++cnt].val=val;
    spl[now].siz=1;
    spl[now].cnt=1;
}
/// Recomputes spl[now].siz as its own multiplicity plus both children's sizes.
void update(int now)
{
    int total=spl[now].cnt;
    total+=spl[spl[now].l].siz;
    total+=spl[spl[now].r].siz;
    spl[now].siz=total;
}
/// Right rotation at `now`: its left child takes its place (the link `now` is
/// rewritten). The demoted node is re-sized before the promoted one.
void zig(int &now)
{
    const int top=now;
    const int child=spl[top].l;
    spl[top].l=spl[child].r;
    spl[child].r=top;
    update(top);
    update(child);
    now=child;
}
/// Left rotation at `now`: its right child takes its place (the link `now` is
/// rewritten). The demoted node is re-sized before the promoted one.
void zag(int &now)
{
    const int top=now;
    const int child=spl[top].r;
    spl[top].r=spl[child].l;
    spl[child].l=top;
    update(top);
    update(child);
    now=child;
}
/// Recursively rotates node `x` upward until it is the root of the subtree whose
/// root link is `y` (e.g. `rt` for the whole tree). `y` is a reference so the
/// rotations rewrite the parent's child link in place.
/// Navigation compares keys, so `x` is assumed to lie inside y's subtree; keys
/// are distinct within the tree because ins() merges equal keys into `cnt`.
void splaying(int x,int &y)
{
    if(x==y) return;
    // References to y's child links, bound to the node y held on entry.
    int &l=spl[y].l,&r=spl[y].r;
    if(x==l) zig(y);        // x is the left child: single right rotation
    else if(x==r) zag(y);   // x is the right child: single left rotation
    else if(spl[x].val<spl[y].val)
    {
        // x is deeper in the left subtree: first splay it up to a grandchild
        // position, then finish with zig-zig (left-left) or zig-zag (left-right).
        if(spl[x].val<spl[l].val) splaying(x,spl[l].l),zig(y),zig(y);
        else splaying(x,spl[l].r),zag(l),zig(y);
    }
    else
    {
        // Mirror image for the right subtree.
        if(spl[x].val<spl[r].val) splaying(x,spl[r].l),zig(r),zag(y);
        else splaying(x,spl[r].r),zag(y),zag(y);
    }
}
/// Removes one copy of the key held by node `now`, which must be in the tree.
/// The node is splayed to the root first. Then either its count is decremented,
/// or it is unlinked by joining its two subtrees.
void delnode(int now)
{
    splaying(now,rt);
    if(spl[now].cnt>1)
    {
        --spl[now].cnt;
        --spl[now].siz;
        return;
    }
    if(!spl[now].r)
    {
        rt=spl[now].l;   // no right subtree: the left subtree becomes the tree
        return;
    }
    // Bring the leftmost node of the right subtree (the in-order successor) to
    // the top of that subtree. It has no left child, so the old root's left
    // subtree can be hung there.
    int succ=spl[now].r;
    while(spl[succ].l) succ=spl[succ].l;
    splaying(succ,spl[now].r);
    const int newroot=spl[rt].r;
    spl[newroot].l=spl[rt].l;
    rt=newroot;
    update(rt);
}
/// Inserts one copy of `val` into the subtree linked by `now` (`rt` at the top),
/// then splays the new or updated node to the root. Sizes along the search path
/// are refreshed by the rotations performed during that splay.
void ins(int &now,int val)
{
    if(now==0)
    {
        newnode(now,val);
        splaying(now,rt);
        return;
    }
    if(val<spl[now].val) { ins(spl[now].l,val); return; }
    if(val>spl[now].val) { ins(spl[now].r,val); return; }
    ++spl[now].cnt;   // key already present: bump its multiplicity
    ++spl[now].siz;
    splaying(now,rt);
}
/// Deletes one copy of `val` from the subtree rooted at `now`.
/// Does nothing if `val` is not present.
void del(int &now,int val)
{
    // Missing key. Without this guard the search recurses on the null node's
    // child link forever. If val == 0, it matches spl[0].val and splays node 0
    // into the tree.
    if(!now) return;
    if(spl[now].val==val) delnode(now);
    else if(spl[now].val>val) del(spl[now].l,val);
    else del(spl[now].r,val);
}
/// Returns the value with 1-based rank `rk` (multiplicities counted) and splays
/// its node to the root. If `rk` is out of range, the walk ends at node 0 and
/// spl[0].val is returned.
int getnum(int rk)
{
    int cur=rt;
    while(cur)
    {
        const int leftsz=spl[spl[cur].l].siz;
        if(rk<=leftsz)
        {
            cur=spl[cur].l;
            continue;
        }
        if(rk<=leftsz+spl[cur].cnt)
        {
            splaying(cur,rt);
            break;
        }
        rk-=leftsz+spl[cur].cnt;
        cur=spl[cur].r;
    }
    return spl[cur].val;
}
/// Returns 1 + the number of stored values strictly less than `val`
/// (multiplicities counted). If `val` is present, its node is splayed to the root.
int getrank(int val)
{
    int rank=1;
    for(int cur=rt;cur;)
    {
        const int lc=spl[cur].l;
        if(val<spl[cur].val) cur=lc;
        else if(val>spl[cur].val)
        {
            rank+=spl[lc].siz+spl[cur].cnt;
            cur=spl[cur].r;
        }
        else
        {
            rank+=spl[lc].siz;
            splaying(cur,rt);
            break;
        }
    }
    return rank;
}
int main()
{
    int n;read(n);
    for(re int i=1,op,x;i<=n;i++)
    {
        read(op),read(x);
        if(op==1) ins(rt,x);
        else if(op==2) del(rt,x);
        else if(op==3) printf("%d
",getrank(x));
        else if(op==4) printf("%d
",getnum(x));
        else if(op==5) printf("%d
",getnum(getrank(x)-1));
        else printf("%d
",getnum(getrank(x+1)));
    }
    return 0;
}

AVL 版

#include <bits/stdc++.h>
using namespace std;
#define ll long long
// `register` is no longer a valid storage-class specifier in C++17 (clang
// rejects it, GCC warns), so `re` expands to nothing; existing `re int ...`
// declarations keep compiling unchanged.
#define re
const int N=1e6+10;   // capacity of the node pool (node indices start at 1; 0 means "no node")
/// Reads the next decimal integer from stdin into `a`.
/// Non-digit characters before the number are skipped; if a '-' appears among
/// them the result is negated. The character right after the last digit is
/// consumed. If end-of-file is reached before any digit, `a` is set to 0
/// (without the EOF test the skip loop would call getchar() forever).
void read(int &a)
{
    a=0;
    int d=1;
    int ch;   // int, not char: getchar() signals end-of-file with EOF
    while(ch=getchar(),ch!=EOF&&(ch>'9'||ch<'0'))
        if(ch=='-')
            d=-1;
    if(ch==EOF) return;
    a=ch-'0';
    while(ch=getchar(),ch>='0'&&ch<='9')
        a=a*10+(ch-'0');
    a*=d;
}
// Node pool of the AVL tree. Nodes are allocated from index 1 upward;
// index 0 is never allocated and stands for "no node" (height 0, size 0).
struct note
{
    int l, r;     // child indices (0 = none)
    int val;      // key (equal keys are stored as separate nodes)
    int height;   // subtree height, maintained by update()
    int siz;      // number of nodes in this subtree
} avl[N];
int cnt, rt;      // nodes allocated so far, root index
/// Allocates a new node holding `val` and stores its index in `now`.
/// height is not set here; ins() calls check() on the node right after.
void newnode(int &now,int val)
{
    now=++cnt;
    avl[now].val=val;
    avl[now].siz=1;
}
void update(int now)
{
    avl[now].siz=avl[avl[now].l].siz+avl[avl[now].r].siz+1;
    avl[now].height=max(avl[avl[now].l].height,avl[avl[now].r].height)+1;
}
int factor(int now){return avl[avl[now].l].height-avl[avl[now].r].height;}
/// Right rotation at `now`: its left child takes its place (the link `now` is
/// rewritten). The demoted node is refreshed before the promoted one.
void zig(int &now)
{
    const int top=now;
    const int child=avl[top].l;
    avl[top].l=avl[child].r;
    avl[child].r=top;
    update(top);
    update(child);
    now=child;
}
/// Left rotation at `now`: its right child takes its place (the link `now` is
/// rewritten). The demoted node is refreshed before the promoted one.
void zag(int &now)
{
    const int top=now;
    const int child=avl[top].r;
    avl[top].r=avl[child].l;
    avl[child].l=top;
    update(top);
    update(child);
    now=child;
}
/// Restores the AVL balance at `now` and refreshes its height and size.
/// `now` is a reference so a rotation can replace the subtree root in the
/// parent's link. It must be called bottom-up on every node whose subtree changed.
void check(int &now)
{
    int nf=factor(now);
    if(nf>1)   // left-heavy
    {
        int lf=factor(avl[now].l);
        // A single rotation is correct when the left child is left-heavy OR
        // balanced. The balanced case (lf == 0) can occur after a deletion. A
        // double rotation there can leave the demoted child with factor 2.
        if(lf>=0) zig(now);
        else zag(avl[now].l),zig(now);
    }
    else if(nf<-1)   // right-heavy (mirror image)
    {
        int rf=factor(avl[now].r);
        if(rf<=0) zag(now);
        else zig(avl[now].r),zag(now);
    }
    else if(now) update(now);   // balanced: only refresh; skip the null node
}
/// Inserts `val` as a new node into the subtree linked by `now` (equal keys go
/// to the right), rebalancing every node on the way back up.
void ins(int &now,int val)
{
    if(now==0) newnode(now,val);
    else if(avl[now].val<=val) ins(avl[now].r,val);
    else ins(avl[now].l,val);
    check(now);
}
/// Detaches the minimum (leftmost) node of the subtree linked by `now` and
/// returns its index. `fa` is the node whose left link is `now`. On the
/// top-level call from del() both arguments are the same right-subtree root.
/// Each ancestor is rebalanced with check() on the way back up. Because `now` is
/// a reference, a rotation there rewrites the caller's link (including del's
/// local `r`).
int getv(int &now,int fa)
{
    int ret;
    if(!avl[now].l)
    {
        ret=now;
        // Splice out the minimum by lifting its right subtree into the parent's
        // left link. On the top-level call (fa == now) this writes the minimum's
        // own left link, which del() overwrites afterwards.
        avl[fa].l=avl[now].r;
    }
    else
    {
        ret=getv(avl[now].l,now);
        check(now);
    }
    return ret;
}
/// Deletes one node with key `val` from the subtree linked by `now`, then
/// rebalances on the way back up. Does nothing if `val` is not present.
void del(int &now,int val)
{
    // Missing key. Without this guard the search reaches the null node and
    // keeps recursing on avl[0].l (which is 0) forever.
    if(!now) return;
    if(val==avl[now].val)
    {
        int l=avl[now].l,r=avl[now].r;
        if(!l||!r) now=l+r;   // at most one child: l+r is whichever one is non-zero
        else
        {
            // Two children: replace the node by the minimum of its right
            // subtree. getv() may rebalance and thereby change `r`.
            now=getv(r,r);
            if(now!=r) avl[now].r=r;
            avl[now].l=l;
        }
    }
    else if(val<avl[now].val) del(avl[now].l,val);
    else del(avl[now].r,val);
    check(now);
}
/// Returns 1 + the number of stored keys strictly less than `val`.
int getrank(int val)
{
    int rank=1;
    int cur=rt;
    while(cur!=0)
    {
        if(avl[cur].val<val)
        {
            rank+=avl[avl[cur].l].siz+1;   // whole left subtree plus this node
            cur=avl[cur].r;
        }
        else cur=avl[cur].l;
    }
    return rank;
}
/// Returns the key with 1-based rank `rk`. If `rk` is out of range, the walk
/// ends at node 0 and avl[0].val is returned.
int getnum(int rk)
{
    int cur=rt;
    while(cur)
    {
        const int leftsz=avl[avl[cur].l].siz;
        if(rk==leftsz+1) break;
        if(rk<=leftsz) cur=avl[cur].l;
        else
        {
            rk-=leftsz+1;
            cur=avl[cur].r;
        }
    }
    return avl[cur].val;
}
int main()
{
    int n;read(n);
    for(re int i=1,op,x;i<=n;i++)
    {
        read(op),read(x);
        if(op==1) ins(rt,x);
        else if(op==2) del(rt,x);
        else if(op==3) printf("%d
",getrank(x));
        else if(op==4) printf("%d
",getnum(x));
        else if(op==5) printf("%d
",getnum(getrank(x)-1));
        else printf("%d
",getnum(getrank(x+1)));
    }
    return 0;
}

非递归版 Splay

#include <bits/stdc++.h>
using namespace std;
#define ll long long
// `register` is no longer a valid storage-class specifier in C++17 (clang
// rejects it, GCC warns), so `re` expands to nothing; existing `re int ...`
// declarations keep compiling unchanged.
#define re
const int N=1e6+10;   // capacity of the node pool (node indices start at 1; 0 means "no node")
/// Reads the next decimal integer from stdin into `a`.
/// Non-digit characters before the number are skipped; if a '-' appears among
/// them the result is negated. The character right after the last digit is
/// consumed. If end-of-file is reached before any digit, `a` is set to 0
/// (without the EOF test the skip loop would call getchar() forever).
void read(int &a)
{
    a=0;
    int d=1;
    int ch;   // int, not char: getchar() signals end-of-file with EOF
    while(ch=getchar(),ch!=EOF&&(ch>'9'||ch<'0'))
        if(ch=='-')
            d=-1;
    if(ch==EOF) return;
    a=ch-'0';
    while(ch=getchar(),ch>='0'&&ch<='9')
        a=a*10+(ch-'0');
    a*=d;
}
// Node pool of the splay tree. Nodes are allocated from index 1 upward; index 0
// is never allocated and stands for "no node", although connect() may still
// write spl[0].ch[] / spl[0].fa when it links through it.
struct note
{
    int ch[2];   // ch[0] = left child, ch[1] = right child (0 = none)
    int siz;     // total multiplicity stored in this subtree
    int cnt;     // multiplicity of `val` in this node
    int val;     // key
    int fa;      // parent index (0 for the root)
} spl[N];
int rt, cnt;     // root index, number of nodes allocated so far
void newnode(int &now,int fa,int val)
{
    spl[now=++cnt].val=val;
    spl[now].siz=spl[now].cnt=1;
    spl[now].fa=fa;
}
/// Returns 1 if `x` is the right child of `f`, otherwise 0.
int ident(int x,int f)
{
    return spl[f].ch[1]==x?1:0;
}
/// Makes `x` the child of `f` on side `s` (0 = left, 1 = right) and records `f`
/// as x's parent. Either index may be 0, in which case node 0's fields are written.
void connect(int x,int f,int s)
{
    spl[x].fa=f;
    spl[f].ch[s]=x;
}
/// Recomputes spl[now].siz as its own multiplicity plus both children's sizes.
void update(int now)
{
    const int *kid=spl[now].ch;
    spl[now].siz=spl[now].cnt+spl[kid[0]].siz+spl[kid[1]].siz;
}
/// Rotates `x` above its parent `f`, keeping the in-order sequence, the
/// grandparent `ff`'s child link, and all parent pointers consistent.
/// When f is the root, ff is 0 and node 0's child slot is written.
void rotat(int x)
{
    int f=spl[x].fa,ff=spl[f].fa,k=ident(x,f);   // k: which side of f x hangs on
    connect(spl[x].ch[k^1],f,k);   // x's inner subtree moves over to f
    connect(x,ff,ident(f,ff));     // x takes f's place under ff (side read before f moves)
    connect(f,x,k^1);              // f becomes x's child on the opposite side
    update(f),update(x);           // f is now below x, so refresh it first
}
/// Splays `x` upward until its parent is `top`. With top == 0, x becomes the
/// root and `rt` is set to x up front.
void splaying(int x,int top)
{
    if(top==0) rt=x;
    for(int f=spl[x].fa;f!=top;f=spl[x].fa)
    {
        const int g=spl[f].fa;
        if(g!=top)
        {
            // Zig-zag (x and f on different sides): rotate x twice.
            // Zig-zig (same side): rotate the parent first.
            if(ident(f,g)!=ident(x,f)) rotat(x);
            else rotat(f);
        }
        rotat(x);
    }
}
/// Removes one copy of the key held by node `x`, which must be in the tree.
void delnode(int x)
{
    splaying(x,0);   // x becomes the root
    if(spl[x].cnt>1)
    {
        --spl[x].cnt;
        --spl[x].siz;
        return;
    }
    if(spl[x].ch[1]==0)
    {
        // No right subtree: the (possibly empty) left subtree becomes the tree.
        rt=spl[x].ch[0];
        spl[rt].fa=0;
        return;
    }
    // Splay x's in-order successor up to be x's right child. It has no left
    // child, so x's left subtree is attached there and it becomes the root.
    int succ=spl[x].ch[1];
    while(spl[succ].ch[0]) succ=spl[succ].ch[0];
    splaying(succ,x);
    connect(spl[x].ch[0],succ,0);
    rt=succ;
    spl[succ].fa=0;
    update(rt);
}
/// Inserts one copy of `val` below the link `now` (default: the root link),
/// whose owner is `fa`, then splays the new or updated node to the root.
/// Sizes along the path are refreshed by the rotations during that splay.
void ins(int val,int &now=rt,int fa=0)
{
    if(now==0)
    {
        newnode(now,fa,val);
        splaying(now,0);
        return;
    }
    if(val<spl[now].val) { ins(val,spl[now].ch[0],now); return; }
    if(val>spl[now].val) { ins(val,spl[now].ch[1],now); return; }
    ++spl[now].cnt;   // key already present: bump its multiplicity
    ++spl[now].siz;
    splaying(now,0);
}
/// Deletes one copy of `val` from the subtree rooted at `now` (default: the root).
/// Does nothing if `val` is not present.
void del(int val,int now=rt)
{
    // Missing key: stop at the null node. Without this guard the search would
    // continue into spl[0].ch[], which connect() may have pointed back into the
    // tree. If val == 0, it would match spl[0].val and try to delete node 0.
    if(!now) return;
    if(spl[now].val==val) delnode(now);
    else if(spl[now].val>val) del(val,spl[now].ch[0]);
    else del(val,spl[now].ch[1]);
}
/// Returns the value with 1-based rank `rk` (multiplicities counted) and splays
/// its node to the root. If `rk` is out of range, the walk ends at node 0 and
/// spl[0].val is returned.
int getnum(int rk)
{
    int cur=rt;
    while(cur)
    {
        const int leftsz=spl[spl[cur].ch[0]].siz;
        if(rk<=leftsz)
        {
            cur=spl[cur].ch[0];
            continue;
        }
        if(rk<=leftsz+spl[cur].cnt)
        {
            splaying(cur,0);
            break;
        }
        rk-=leftsz+spl[cur].cnt;
        cur=spl[cur].ch[1];
    }
    return spl[cur].val;
}
/// Returns 1 + the number of stored values strictly less than `val`
/// (multiplicities counted). If `val` is present, its node is splayed to the root.
int getrank(int val)
{
    int rank=1;
    for(int cur=rt;cur;)
    {
        const int lc=spl[cur].ch[0];
        if(val<spl[cur].val) cur=lc;
        else if(val>spl[cur].val)
        {
            rank+=spl[lc].siz+spl[cur].cnt;
            cur=spl[cur].ch[1];
        }
        else
        {
            rank+=spl[lc].siz;
            splaying(cur,0);
            break;
        }
    }
    return rank;
}
int main()
{
    int n;read(n);
    for(re int i=1,op,x;i<=n;i++)
    {
        read(op),read(x);
        if(op==1) ins(x,rt);
        else if(op==2) del(x,rt);
        else if(op==3) printf("%d
",getrank(x));
        else if(op==4) printf("%d
",getnum(x));
        else if(op==5) printf("%d
",getnum(getrank(x)-1));
        else printf("%d
",getnum(getrank(x+1)));
    }
    return 0;
}
原文地址:https://www.cnblogs.com/acm1ruoji/p/11941531.html