qbzt day3 晚上 平衡树的一些思想

pks大佬的blog

二叉查找树

任何一个节点左子树的所有元素都小于这个节点,右子树的所有元素都大于这个节点

查找一个节点:从根节点开始,比他小就向左走,比他大就向右走

 

平衡树:解决二叉查找树的一些痛点。

二叉查找树的问题:它的形态并不固定,查找非常依赖于深度

通过一种叫做伸展的操作,让树的深度不那么深

那么什么是伸展?

伸展操作基于一个元操作:旋转(rotate)

如果一个节点之前被访问过,那么之后访问到它的几率会变大

通过旋转把这个点移到根,使下一次访问到它只需要o(1)的时间

Splay操作:把一个点旋转到根或者旋转到某个点下面

有x,y,z三个点,z是y的父亲,y是x的父亲

x,y,z三个点,如果在一条直线上就先转y再转x(让树的深度-1)

如果不在一条直线上,就转两次x

 

 

fa记录每一个点的父亲,ch记录节点的两个儿子(有可能没有)

函数:son(x)表示x是它父亲的左儿子还是右儿子

      Rotate旋转(其实可以说抬升)

  Splay将x的节点转到i的位置上

Cnt表示当前节点的数出现了多少次

Data表示当前节点的数是什么

Size表示当前节点及其子树中共有多少个数
pushup:左节点size+右节点size加上自己的cnt

插入:将x插入到rt节点中

如果x<data[rt],就往左子树插

如果x>data[rt],往右子树插

如果x=data[rt],就cnt[rt]++,size[rt]++,直接return掉

如果rt=0,也就是说这个位置没有出现过,就新建一个位置,rt=++nn(总的节点数),data[rt]=x,size[rt]=cnt[rt]=1,return掉

找前驱(小于x最大的数)和后继(大于x最小的数)

Getpre 按照子节点大小关系往两边找前驱

Getaux 同上,找后继

Getmn 找一颗子树中的最小值 只需要不停往左子树跳就可以了

Delete 删除节点

  先找这个点

  如果在左边,就往左边删除,如果在右边就往右边删除

  如果找到了这个点,分情况。如果不只一个数,就cnt[rt]--,size[rt]--就可以了

  否则,先把rt转到根,在右子树找最小的元素

如果没有右子树就让根变成他的左儿子

否则就让右儿子最小的节点转到根,因为最小的节点一定没有左儿子

Getk 查询x的排名

现在树中查找这个节点

找到之后,把他转到根,看他左子树有多少个数,+1就是3的排名0

Getkth 找到第k个数

如果左儿子的size+1<=k并且左儿子的size加上当前节点的数>k,就意味着第k个数一定是这个节点

如果k<左儿子的size+1,就往左节点找

否则就往右儿子找。注意要把左儿子的size和自己的cnt减去

int fa[N],ch[N][2];
int cnt[N];
int data[N];
int siz[N];

int son(int x)
{
    return x==ch[fa[x]][1];
}

void pushup(int rt)
{
    siz[rt]=siz[ch[rt][0]]+siz[ch[rt][1]]+cnt[rt];
}

void rotate(int x){
    int y=fa[x],z=fa[y],b=son(x),c=son(y),a=ch[x][!b];
    if(z) ch[z][c]=x; else root=x; fa[x]=z;
    if(a) fa[a]=y; ch[y][b]=a;
    ch[x][!b]=y;fa[y]=x;
    pushup(y);pushup(x);
}

void splay(int x,int i){
    while(fa[x]!=i){
        int y=fa[x],z=fa[y];
        if(z==i){
            rotate(x);
        }else{
            if(son(x)==son(y)){
                rotate(y);rotate(x);
            }else{
                rotate(x);rotate(x);
            }
        }
    }
}

void insert(int &rt,int x){
    if(rt==0){
        rt=++nn;
        data[rt]=x;
        siz[rt]=cnt[rt]=1;
        return;
    }
    if(x==data[rt]){
        cnt[rt]++;
        siz[rt]++;
        return;
    }
    if(x<data[rt]){
        insert(ch[rt][0],x);
        fa[ch[rt][0]]=rt;
        pushup(rt);
    }else{
        insert(ch[rt][1],x);
        fa[ch[rt][1]]=rt;
        pushup(rt);
    }
}

int getpre(int rt,int x){
    int p=rt,ans;
    while(p){
        if(x<=data[p]){
            p=ch[p][0];
        }else{
            ans=p;
            p=ch[p][1];
        }
    }
    return ans;
}

int getsuc(int rt,int x){
    int p=rt,ans;
    while(p){
        if(x>=data[p]){
            p=ch[p][1];
        }else{
            ans=p;
            p=ch[p][0];
        }
    }
    return ans;
}

int getmn(int rt){
    int p=rt,ans=-1;
    while(p){
        ans=p;
        p=ch[p][0];
    }
    return ans;
}

void del(int rt,int x){
    if(data[rt]==x){
        if(cnt[rt]>1){
            cnt[rt]--;
            siz[rt]--;
        }else{
            splay(rt,0);
            int p=getmn(ch[rt][1]);
            if(p==-1){
                root=ch[rt][0];
                fa[ch[rt][0]]=0;
            }else{
                splay(p,rt);
                root=p;fa[p]=0;
                ch[p][0]=ch[rt][0];fa[ch[rt][0]]=p;
                pushup(p);
            }
        }
        return;
    }
    if(x<data[rt]){
        del(ch[rt][0],x);
    }else{
        del(ch[rt][1],x);
    }
    pushup(rt);
}

int getk(int rt,int k){
    if(data[rt]==k){
        splay(rt,0);
        if(ch[rt][0]==0){
            return 1;
        }else{
            return siz[ch[rt][0]]+1;
        }
    }
    if(k<data[rt]) return getk(ch[rt][0],k);
    if(data[rt]<k) return getk(ch[rt][1],k);
}

int getkth(int rt,int k){
    int l=ch[rt][0];
    if(siz[l]+1<=k&&k<=siz[l]+cnt[rt]) return data[rt];
    if(k<siz[l]+1) return getkth(l,k);
    else return getkth(ch[rt][1],k-siz[l]-cnt[rt]);
}
原文地址:https://www.cnblogs.com/lcezych/p/11190737.html