Size Balanced Tree

具体讲解还是看发明人陈启峰神犇的吧:http://wenku.baidu.com/link?url=Sh3e8rMJ2Pn146yz0_ClcF_bWTu9uwVEuXy8P0y-CwG-2WNmcDRehaUiuOV-4NcVQBQ9Kpwzd-TwMN3uKigQvzYXm2ZC3UPeoLuKv-Hsapa

核心代码:

版本一:(好理解)

void maintain(int &x,int flag)

{

    if(flag)        // right

    {

        // 右孩子的右子树大于左孩子

        if(T[T[T[x].ch[1]].ch[1]].sz > T[T[x].ch[0]].sz)            rotate(x,0);

        // 右孩子的左子树大于左孩子

        else if(T[T[T[x].ch[1]].ch[0]].sz > T[T[x].ch[0]].sz)       rotate(T[x].ch[1],1),rotate(x,0);

        else return;

    }

    else            // left

    {

        // 左孩子的左子树大于右孩子

        if(T[T[T[x].ch[0]].ch[0]].sz > T[T[x].ch[1]].sz)            rotate(x,1);

        // 右孩子的右子树大于右孩子

        else if(T[T[T[x].ch[0]].ch[1]].sz > T[T[x].ch[1]].sz)       rotate(T[x].ch[0],0),rotate(x,1);

        else return;

    }

 

    maintain(T[x].ch[0],false);

    maintain(T[x].ch[1],true);

    maintain(x,false);

    maintain(x,true);

}

版本二:(精简)

#include <cstdio>
#include <cstring>
using namespace std;

const int inf = 1 << 30;
const int maxn = 100000;

int sz[maxn],sn[maxn][2],val[maxn],cnt,root = 0;

int rotate(int &x,int d)
{
    int k = sn[x][d ^ 1];   sn[x][d ^ 1] = sn[k][d];    sn[k][d] = x;
    sz[k] = sz[x];  sz[x] = sz[sn[x][d]] + 1 + sz[sn[x][d ^ 1]];
    x = k;
}

void maintain(int &x,int d)
{
    if(sz[sn[sn[x][d]][d]] > sz[sn[x][d ^ 1]])  rotate(x,d ^ 1);
    else if(sz[sn[sn[x][d]][d ^ 1]] > sz[sn[x][d ^ 1]]) rotate(sn[x][d],d),rotate(x,d ^ 1);
    else    return;
    maintain(sn[x][0],false),   maintain(sn[x][1],true);
    maintain(x,false),          maintain(x,true);
}

void ins(int &x,int v)
{
    if(!x)  {x = ++cnt,val[x] = v,sz[x] = 1,sn[x][0] = sn[x][1] = 0;    return;}
    sz[x] ++;   int d = v >= val[x];
    ins(sn[x][d],v);
    maintain(x,d);
}

int del(int &x,int v)       // 若没有v,则删除其后继
{
    sz[x] --;
    if(v == val[x] || (v < val[x] && !sn[x][0]) || (v > val[x] && !sn[x][1]))
    {
        int tmp = val[x];
        if(!sn[x][0] || !sn[x][1])  x = sn[x][0] + sn[x][1];
        else                        val[x] = del(sn[x][0],inf);
        return tmp;
    }
    else    del(sn[x][v >= val[x]],v);
}

int find(int x,int v)
{
    while(x && v != val[x])
        x = find(sn[x][v >= val[x]],v);
    return x;
}

int rank(int x,int v)
{
    if(!x)  return 1;
    if(v <= val[x]) return rank(sn[x][0],v);
    else            return sz[sn[x][0]] + 1 + rank(sn[x][1],v);
}

int main()
{
    int n,m,t,opt;
    scanf("%d%d",&n,&m);
    for(int i = 0;i < n;i ++)   scanf("%d",&t),ins(root,t);

    printf("root %d
",root);
    for(int i = 1;i <= cnt;i ++)
        printf("T[%d] : ls : %d  rs : %d   val : %d
",i,sn[i][0],sn[i][1],val[i]);

    for(int i = 0;i < m;i ++)
    {
        scanf("%d%d",&opt,&t);
        if(opt == 1)    printf("del %d
",del(root,t));
        else if(opt == 2)   printf("find %d
",val[find(root,t)]);
        else if(opt == 3)   printf("rank %d
",rank(root,t));
    }
    return 0;
}
原文地址:https://www.cnblogs.com/vivym/p/3928028.html