二逼平衡树 题解(树套树)

题面

我 想 扇 死 自 己

void up(int x)
    {
        if(x)
        {
            size[x]=cnt[x];//我TM这行忘了
            if(son[x][0])size[x]+=size[son[x][0]];
            if(son[x][1])size[x]+=size[son[x][1]];
        }
    }

4个小时!调一道模板!我敲里码!

上道splay刚因为细节打错浪费了3个小时时间,这次就又**重现了

不多说了,先把splay抄上10遍,手写!

-----------以下是正经题解----------------

第一道树套树:线段树套splay

对于线段树的每一段区间建splay维护这段的信息

在合并时:

排名相加;

前驱取max;

后继取min;

比较麻烦的是查询数值,需要二分答案.

以数值为值域进行二分,不断询问mid的排名来缩小范围。

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=4000005,inf=1e9;
int n,m,a[N];
    int root[N],son[N][3],fa[N],key[N],size[N],type,cnt[N];
    void clear(int x)
    {
        if(!x)return ;
        fa[x]=cnt[x]=son[x][0]=son[x][1]=size[x]=key[x]=0;
    }
    int pre(int k)
    {
        int now=son[root[k]][0];
        while(son[now][1])now=son[now][1];
        return now;
    }
    bool judge(int x)
    {
        return son[fa[x]][1]==x;
    }
    void up(int x)
    {
        if(x)
        {
            size[x]=cnt[x];
            if(son[x][0])size[x]+=size[son[x][0]];
            if(son[x][1])size[x]+=size[son[x][1]];
        }
    }
    void rotate(int x)
    {
        int old=fa[x],oldf=fa[old],lr=judge(x);
        son[old][lr]=son[x][lr^1];
        fa[son[old][lr]]=old;
        son[x][lr^1]=old;
        fa[old]=x;
        fa[x]=oldf;
        if(oldf)son[oldf][son[oldf][1]==old]=x;
        up(old);up(x);
    }
    void splay(int k,int x)
    {
        for(int f;f=fa[x];rotate(x))
            if(fa[f])rotate(judge(x)==judge(f)?f:x);
        root[k]=x;
    }
    void ins(int k,int x)
    {
        if(!root[k])
        {
            type++;
            key[type]=x;
            root[k]=type;
            cnt[type]=size[type]=1;
            fa[type]=son[type][0]=son[type][1]=0;
            return ;
        }
        int now=root[k],f=0;
        while(1)
        {
            if(x==key[now])
            {
                cnt[now]++;
                up(now);
                up(f);
                splay(k,now);
                return ;
            }
            f=now;now=son[now][key[now]<x];
            if(!now)
            {
                type++;
                size[type]=cnt[type]=1;
                son[type][0]=son[type][1]=0;
                son[f][x>key[f]]=type;
                fa[type]=f;
                key[type]=x;
                up(f);splay(k,type);
                return ;
            }
        }
    }
    int getrank(int k,int x)
    {
        int now=root[k],ans=0;
        while(1)
        {
            if(!now)return ans;
            if(x==key[now])return (son[now][0]?size[son[now][0]]:0)+ans;
            else if(x>key[now])
            {
                ans+=(son[now][0]?size[son[now][0]]:0)+cnt[now];
                now=son[now][1];
            }
            else if(x<key[now])now=son[now][0];
        }
    }
    int findpos(int k,int x)
    {
        int now=root[k];
        while(1)
        {
            if(x==key[now])return now;
            else if(x<key[now])now=son[now][0];
            else now=son[now][1];
        }
    }
    int findpre(int k,int x)
    {
        int now=root[k],ans=0;
        while(now)
        {
            if(key[now]<x)
            {
                if(ans<key[now])ans=key[now];
                now=son[now][1];
            }
            else now=son[now][0];
        }
        return ans;
    }
    int findnxt(int k,int x)
    {
        int now=root[k],ans=inf;
        while(now)
        {
            if(key[now]>x)
            {
                if(ans>key[now])ans=key[now];
                now=son[now][0];
            }
            else now=son[now][1];
        }
        return ans;
    }
    void del(int k,int x)
    {
        int now=findpos(k,x);
        splay(k,now);
        if(cnt[root[k]]>1)
        {
            cnt[root[k]]--;
            up(root[k]);
            return ;
        }
        else if(!son[root[k]][0]&&(!son[root[k]][1]))
        {
            clear(root[k]);
            root[k]=0;
            return ;
        }
        int old=root[k];
        if(son[root[k]][0]*son[root[k]][1]==0)
        {
            if(!son[root[k]][0])root[k]=son[root[k]][1];
            else root[k]=son[root[k]][0];
            fa[root[k]]=0;
            clear(old);
            return ;
        }
        int L=pre(k);
        splay(k,L);
        son[root[k]][1]=son[old][1];
        fa[son[old][1]]=root[k];
        clear(old);
        up(root[k]);
    }
    #define ls(k) k<<1
    #define rs(k) k<<1|1
    void update(int k,int l,int r,int pos,int val)
    {
        ins(k,val);
        if(l==r)return ;
        int mid=l+r>>1;
        if(pos<=mid)update(ls(k),l,mid,pos,val);
        else update(rs(k),mid+1,r,pos,val);
        return ;
    }
    int rank(int k,int l,int r,int L,int R,int val)
    {
        if(l>=L&&r<=R)
        {
            int res=getrank(k,val);
            return res;
        }
        int mid=l+r>>1,res=0;
        if(L<=mid)res+=rank(ls(k),l,mid,L,R,val);
        if(R>mid)res+=rank(rs(k),mid+1,r,L,R,val);
        return res;
    }
    void modify(int k,int l,int r,int pos,int val)
    {
        del(k,a[pos]);
        ins(k,val);
        if(l==r)return ;
        int mid=l+r>>1;
        if(pos<=mid)modify(ls(k),l,mid,pos,val);
        else modify(rs(k),mid+1,r,pos,val);
    }
    int getpre(int k,int l,int r,int L,int R,int val)
    {
        if(l>=L&&r<=R)return findpre(k,val);
        int mid=l+r>>1,res=0;
        if(L<=mid)res=max(res,getpre(ls(k),l,mid,L,R,val));
        if(R>mid)res=max(res,getpre(rs(k),mid+1,r,L,R,val));
        return res;
    }
    int getnxt(int k,int l,int r,int L,int R,int val)
    {
        if(l>=L&&r<=R)return findnxt(k,val);
        int mid=l+r>>1,res=inf;
        if(L<=mid)res=min(res,getnxt(ls(k),l,mid,L,R,val));
        if(R>mid)res=min(res,getnxt(rs(k),mid+1,r,L,R,val));
        return res;
    }
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9')
    {if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')
    {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
    return x*f;
}
int main()
{
    n=read();m=read();
    int op,maxx=0;
    for(int i=1;i<=n;i++)
    {
        a[i]=read();
        update(1,1,n,i,a[i]);
        maxx=max(maxx,a[i]);
    }
    while(m--)
    {
        op=read();
        if(op==1)
        {
            int l=read(),r=read(),val=read();
            printf("%d
",rank(1,1,n,l,r,val)+1);
        }
        else if(op==2)
        {
            int l=read(),r=read(),val=read();
            int L=0,R=maxx+1;
            while(L!=R)
            {
                int mid=L+R>>1;
                int res=rank(1,1,n,l,r,mid);
                //cout<<"***"<<res<<endl;
                if(res<val)L=mid+1;
                else R=mid;
            }
            printf("%d
",L-1);
        }
        else if(op==3)
        {
            int pos=read(),val=read();modify(1,1,n,pos,val);
            a[pos]=val;
            maxx=max(maxx,val);
        }
        else if(op==4)
        {
            int l=read(),r=read(),val=read();
            printf("%d
",getpre(1,1,n,l,r,val));
        }
        else if(op==5)
        {
            int l=read(),r=read(),val=read();
            printf("%d
",getnxt(1,1,n,l,r,val));
        }
    }
    return 0;
}

好了。

从上面那段简短而狗屁不通的“题解”和几乎是抄来的代码可以看出来,是什么让当时的我那么垃圾。

不求甚解、生搬硬套、懒于思考、依赖题解。

装模作样打个Splay,考场上没板子真的写得出来?

如果像本题一样,把普通平衡树的操作放到区间上,显然是无法只用平衡树维护的。解决区间问题最有力的武器就是线段树,所以考虑线段树套平衡树解决。

对每个线段树区间建一棵平衡树。建树时直接把所有区间都插入该区间的所有元素,单点修改时把沿路的所有线段树区间上的平衡树都进行改动(删除再插入)。

对于剩下的查询操作,求排名显然可以转化为所有区间小于该数的元素个数之和+1,即$( sum (每个区间求排名结果-1)) +1$,前驱应当是所有区间结果的最大值,同理后继就是最小值。

但用相同的方式求K大是不太可行的,考虑牺牲一下时间复杂度进行二分答案,每次二分出一个数check它的排名即可。这样的话是3个$log$。

平衡树使用的是替罪羊树,一是确实好写且容易封装,二是动态开点删点可以避免内存超限。这样就可以直接粗暴地扔到结构体里而不用像Splay一样使用$root[]$数组了。

上面瞎写的东西我没有删。给自己和大家一个警示以及反面典型。

#include<cstdio>
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
int read()
{
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return x*f;
}
const int N=1e5+5,inf=2147483647;
const double al=0.7;
int n,m,a[N];
struct Scapegoat
{
    struct node
    {
        node *l,*r;
        int val,size,cnt;
        bool del;
        bool bad()
        {
            return l->cnt>al*cnt+5||r->cnt>al*cnt+5;
        }
        void up()
        {
            size=!del+l->size+r->size;
            cnt=1+l->cnt+r->cnt;
        }
    };
    node *null,**badtag;
    void dfs(node *k,vector<node*> &v)
    {
        if(k==null)return ;
        dfs(k->l,v);
        if(!k->del)v.push_back(k);
        dfs(k->r,v);
        if(k->del)delete k;
    }
    node *build(vector<node*> &v,int l,int r)
    {
        if(l>=r)return null;
        int mid=l+r>>1;
        node *k=v[mid];
        k->l=build(v,l,mid);
        k->r=build(v,mid+1,r);
        k->up();
        return k;
    }
    void rebuild(node* &k)
    {
        vector<node*> v;
        dfs(k,v);
        k=build(v,0,v.size());
    }
    void insert(int x,node* &k)
    {
        if(k==null)
        {
            k=new node;
            k->l=k->r=null;
            k->del=0;
            k->size=k->cnt=1;
            k->val=x;
            return ;
        }
        ++k->size;++k->cnt;
        if(x>=k->val)insert(x,k->r);
        else insert(x,k->l);
        if(k->bad())badtag=&k;
        else if(badtag!=&null)
            k->cnt-=(*badtag)->cnt-(*badtag)->size;
    }
    void ins(int x,node* &k)
    {
        badtag=&null;
        insert(x,k);
        if(badtag!=&null)rebuild(*badtag);
    }
    int getrk(node *now,int x)
    {
        int ans=1;
        while(now!=null)
        {
            if(now->val>=x)now=now->l;
            else
            {
                ans+=now->l->size+!now->del;
                now=now->r;
            }
        }
        return ans;
    }
    int kth(node *now,int x)
    {
        while(now!=null)
        {
            if(!now->del&&now->l->size+1==x)
                return now->val;
            if(now->l->size>=x)now=now->l;
            else
            {
                x-=now->l->size+!now->del;
                now=now->r;
            }
        }
        return -1;
    }
    void erase(node *k,int rk)
    {
        if(!k->del&&rk==k->l->size+1)
        {
            k->del=1;
            --k->size;
            return ;
        }
        --k->size;
        if(rk<=k->l->size+!k->del)erase(k->l,rk);
        else erase(k->r,rk-k->l->size-!k->del);
    }
    node* root;
    Scapegoat()
    {
        null=new node;
        root=null;
    }
}s[N<<3];
#define ls(k) (k)<<1
#define rs(k) (k)<<1|1
void build(int k,int l,int r)
{
    for(int i=l;i<=r;i++)
        s[k].ins(a[i],s[k].root);
    if(l==r)return ;
    int mid=l+r>>1;
    build(ls(k),l,mid);
    build(rs(k),mid+1,r);
}
int askrk(int k,int l,int r,int L,int R,int val)
{
    if(L<=l&&R>=r)return s[k].getrk(s[k].root,val)-1;
    int mid=l+r>>1,res=0;
    if(L<=mid)res+=askrk(ls(k),l,mid,L,R,val);
    if(R>mid)res+=askrk(rs(k),mid+1,r,L,R,val);
    return res;
}
void update(int k,int l,int r,int pos,int val)
{
    s[k].erase(s[k].root,s[k].getrk(s[k].root,a[pos]));
    s[k].ins(val,s[k].root);
    if(l==r)return ;
    int mid=l+r>>1;
    if(pos<=mid)update(ls(k),l,mid,pos,val);
    else update(rs(k),mid+1,r,pos,val);
}
int askpre(int k,int l,int r,int L,int R,int val)
{
    if(L<=l&&R>=r)return s[k].kth(s[k].root,s[k].getrk(s[k].root,val)-1);
    int res=-inf,mid=l+r>>1;
    if(L<=mid)
    {
        int ret=askpre(ls(k),l,mid,L,R,val);
        if(ret==-1)res=max(res,-inf);
        else res=max(res,ret);
    }
    if(R>mid)
    {
        int ret=askpre(rs(k),mid+1,r,L,R,val);
        if(ret==-1)res=max(res,-inf);
        else res=max(res,ret);
    }
    return res;
}
int asknxt(int k,int l,int r,int L,int R,int val)
{
    if(L<=l&&R>=r)return s[k].kth(s[k].root,s[k].getrk(s[k].root,val+1));
    int res=inf,mid=l+r>>1;
    if(L<=mid)
    {
        int ret=asknxt(ls(k),l,mid,L,R,val);
        if(ret==-1)res=min(res,inf);
        else res=min(res,ret);
    }
    if(R>mid)
    {
        int ret=asknxt(rs(k),mid+1,r,L,R,val);
        if(ret==-1)res=min(res,inf);
        else res=min(res,ret);
    }
    return res;
}
int askth(int L,int R,int val)
{
    int l=0,r=1e8,res;
    while(l<=r)
    {
        int mid=l+r>>1;
        if(askrk(1,1,n,L,R,mid)+1<=val)res=mid,l=mid+1;
        else r=mid-1;
    }
    return res;
}

int main()
{
    n=read();m=read();
    for(int i=1;i<=n;i++)
        a[i]=read();
    build(1,1,n);
    while(m--)
    {
        int op=read();
        if(op==1){int l=read(),r=read(),K=read();printf("%d
",askrk(1,1,n,l,r,K)+1);}
        if(op==2){int l=read(),r=read(),K=read();printf("%d
",askth(l,r,K));}
        if(op==3){int pos=read(),K=read();update(1,1,n,pos,K);a[pos]=K;}
        if(op==4){int l=read(),r=read(),K=read();printf("%d
",askpre(1,1,n,l,r,K));}
        if(op==5){int l=read(),r=read(),K=read();printf("%d
",asknxt(1,1,n,l,r,K));}
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Rorschach-XR/p/11019342.html