二逼平衡树——线段树套平衡树

注意空间大小,以及建树时的细节

#include<cstdio>
#include<algorithm>
#include<iostream>
using namespace std;
// Capacity of the balanced-tree node pool: every per-node array below is
// indexed by a node id in [1, cnt], with id 0 meaning "no node".
const int maxn = 5000000 + 3;
// Capacity of the per-position arrays (input sequence and build scratch).
const int N = 50000+3;
// Sentinel magnitude: +/-maxm stands for "no successor / no predecessor",
// and maxm is also the upper bound of the value binary search.
const int maxm = 100000000 + 3;

int root[maxn];     // root[o]: root node id of the tree attached to segment-tree node o
int n;              // length of the sequence (read in main)
int spare[N];       // scratch: sorted copy of the range being built
int A[N];           // the sequence itself, 1-based
int m;              // number of operations (read in main)
int nex[N];         // scratch: index of the first element after a run of equal values
int bucket[N];      // scratch: multiplicity of each distinct value in arr

int ch[maxn][2];    // ch[x][0] / ch[x][1]: left / right child of node x
int f[maxn];        // f[x]: parent of node x
int siz[maxn];      // siz[x]: number of stored elements in x's subtree (with multiplicity)
int num[maxn];      // num[x]: multiplicity of val[x]
int val[maxn];      // val[x]: key stored at node x
int cnt;            // last node id handed out

int arr[N];         // scratch: distinct sorted values handed to Splay_Tree::build

// Returns 1 when x is the right child of its parent, 0 otherwise.
inline int get(int x)
{
    const int parent = f[x];
    return x == ch[parent][1];
}

// Left child of x.
inline int ls(int x)
{
    return ch[x][0];
}

// Right child of x.
inline int rs(int x)
{
    return ch[x][1];
}
/*
 * Splay-tree operations over the shared node pool (ch / f / siz / num / val).
 *
 * The struct holds no state of its own: each tree is identified by `ty`,
 * an index into root[], and node id 0 always means "no node".  Nodes are
 * taken from the pool with ++cnt and are never handed out again.
 */
struct Splay_Tree
{
    // Recomputes siz[x] from x's own multiplicity and its children's sizes.
    inline void pushup(int x)
    {
        siz[x] = num[x] + siz[ls(x)]+ siz[rs(x)];
    }

    // Builds a balanced tree from arr[l..r] (keys) and bucket[l..r]
    // (multiplicities), storing the new subtree's root id in `o` and linking
    // it to parent `fa`.  For the result to be a search tree arr[l..r] must
    // be ascending; build_Tree sorts before calling.  Does nothing (and
    // leaves `o` untouched) when l > r.
    void build(int l,int r,int& o,int fa)
    {
        if(l > r)return;
        o = ++cnt;
        int mid = (l+r) >> 1;
        val[o] = arr[mid], f[o] = fa, num[o] = siz[o] = bucket[mid];
        if(mid - 1 >= l) build(l,mid-1,ch[o][0],o);
        if(mid + 1 <= r) build(mid+1,r,ch[o][1],o);
        pushup(o);
    }

    // Clears every field of node x.
    // Was declared `int` without a return statement (undefined behaviour);
    // no caller used the result, so it is now `void`.
    inline void remove_x(int x)
    {
        ch[x][0] = ch[x][1] = f[x] = 0;
        val[x] = num[x] = siz[x] = 0;
    }

    // Rotates x above its parent, keeping in-order sequence and sizes intact.
    inline void rotate(int x)
    {
        int old = f[x], oldf = f[old], which = get(x);
        ch[old][which] = ch[x][which^1], f[ch[old][which]] = old;
        ch[x][which^1] = old, f[old] = x, f[x] = oldf;
        if(oldf) ch[oldf][ch[oldf][1] == old] = x;
        pushup(old); pushup(x);
    }

    // Rotates x upward until its parent is the parent that the node stored
    // in `tar` had on entry, then writes x into `tar`.  `tar` is the slot
    // (root[ty] or a child link) that should end up referring to x.
    inline void splay(int x,int &tar)
    {
        int a = f[tar];
        for(int fa; (fa = f[x]) != a ;rotate(x))
            if(f[fa] != a) rotate(get(x) == get(fa) ? fa : x);
        tar = x;
    }

    // Returns how many stored elements of tree `ty` are strictly less than x.
    // When x is present its node is splayed to the root.
    inline int query_rank(int x,int ty)
    {
        int p = root[ty];
        int tmp = 0;
        while(p && val[p] != x)
        {
            // Going right skips the whole left subtree plus this node.
            if(x>val[p])tmp += (siz[ls(p)]+num[p]);
            p = ch[p][x>val[p]];
        }
        if(p == 0)return tmp;
        splay(p,root[ty]);
        return siz[ls(root[ty])];
    }

    // Returns the x-th smallest stored value (1-based, counting
    // multiplicities) of tree `ty` and splays its node to the root.
    // Returns 0 when the descent runs off the tree.
    inline int query_num(int x,int ty)
    {
        int p = root[ty];
        while(p)
        {
            if(x<=siz[ls(p)])p=ls(p);
            else
            {
                x-=(siz[ls(p)]+num[p]);
                if(x<=0)
                {
                    splay(p,root[ty]);
                    return val[p];
                }
                p=rs(p);
            }
        }
        return 0;
    }

    // Inserts one occurrence of x into tree `ty`; the touched node becomes
    // the root.
    inline void insert_x(int x,int ty)
    {
        if(!root[ty])
        {
            root[ty] = ++cnt, val[root[ty]] = x, num[root[ty]]=1 ;
            pushup(root[ty]); return;
        }
        // fa tracks the last node visited; initialised so it is never read
        // uninitialised.
        int p = root[ty], fa = 0;
        while(p && val[p] != x) fa = p,p = ch[p][x > val[p]];
        if(p)
        {
            // Key already present: bump its multiplicity.
            ++ num[p];
            pushup(p);
            splay(p,root[ty]);
        }
        else
        {
            ++ cnt, num[cnt] = 1, val[cnt] = x;
            ch[fa][x>val[fa]] = cnt, f[cnt] = fa;
            pushup(cnt);
            splay(cnt, root[ty]);
        }
    }

    // Removes one occurrence of x from tree `ty`; does nothing when x is
    // absent.
    inline void delete_x(int x,int ty)
    {
        int p = root[ty];
        while(p && val[p] != x) p = ch[p][x>val[p]];
        if(p == 0)return;
        splay(p,root[ty]);
        if(num[root[ty]] > 1){
            --num[root[ty]];
            pushup(root[ty]);
            return;
        }
        const int a = root[ty];
        if(!ls(a) || !rs(a))
        {
            // At most one child: promote it (0 when the tree becomes empty).
            const int child = ls(a) ? ls(a) : rs(a);
            root[ty] = child;
            // Detach the new root from the recycled node so its parent link
            // does not dangle.
            if(child) f[child] = 0;
            remove_x(a);
            return;
        }
        // Two children: bring the in-order predecessor to the top of the
        // left subtree (it then has no right child) and hang the old right
        // subtree under it.
        p = ls(a);
        while(rs(p))p = rs(p);
        splay(p, ch[a][0]);
        ch[p][1]=ch[a][1];
        f[ch[p][1]]=p,f[p]=0;
        pushup(p);
        root[ty]=p;
        remove_x(a);
    }

    // Largest stored value strictly less than x in tree `ty`, or -maxm when
    // there is none.  Does not restructure the tree.
    inline int pre_x(int x,int ty)
    {
        int ans = -maxm , p = root[ty];
        while(p)
        {
            if(val[p] < x) ans = val[p], p = ch[p][1];
            else p = ch[p][0];
        }
        return ans;
    }

    // Smallest stored value strictly greater than x in tree `ty`, or maxm
    // when there is none.  Does not restructure the tree.
    inline int aft_x(int x,int ty)
    {
        int ans = maxm, p = root[ty];
        while(p){
            if(val[p] > x) ans = val[p],p = ch[p][0];
            else p = ch[p][1];
        }
        return ans;
    }
}T;
// Builds the tree for segment-tree node `o`, which covers A[l..r], then
// recurses into both halves.  Each node gets its own balanced tree holding
// the distinct values of its range together with their multiplicities.
void build_Tree(int l,int r,int o)
{
    if (r < l) return;

    // Sorted copy of the range in spare[1..len].
    const int len = r - l + 1;
    for (int k = 1; k <= len; ++k)
        spare[k] = A[l + k - 1];
    std::sort(spare + 1, spare + len + 1);

    // nex[k] = index of the first element after the run of values equal to
    // spare[k].
    nex[len] = len + 1;
    for (int k = len - 1; k > 0; --k)
    {
        if (spare[k] == spare[k + 1]) nex[k] = nex[k + 1];
        else nex[k] = k + 1;
    }

    // Collapse each run into one (value, multiplicity) pair.
    int distinct = 0;
    int k = 1;
    while (k <= len)
    {
        ++distinct;
        arr[distinct] = spare[k];
        bucket[distinct] = nex[k] - k;
        k = nex[k];
    }

    T.build(1, distinct, root[o], 0);

    if (l == r) return;
    const int mid = (l + r) >> 1;
    build_Tree(l, mid, o << 1);
    build_Tree(mid + 1, r, (o << 1) | 1);
}
// Replaces one occurrence of `origin` by `new_num` in the tree of every
// segment-tree node on the path from node `o` (covering [l, r]) down to the
// leaf for position `pos`.
void update(int l,int r,int pos,int origin,int new_num, int o)
{
    // Iterative walk down the single root-to-leaf path.
    while (l <= r)
    {
        T.delete_x(origin, o);
        T.insert_x(new_num, o);
        if (l == r) break;

        const int mid = (l + r) >> 1;
        if (pos <= mid)
        {
            r = mid;
            o = o << 1;
        }
        else
        {
            l = mid + 1;
            o = (o << 1) | 1;
        }
    }
}
// Counts the elements strictly less than `num` among positions [L, R].
// (l, r, o) describe the current segment-tree node; callers add 1 to obtain
// a 1-based rank.
int query_rank(int l,int r,int L,int R,int num,int o)
{
    // Node fully inside the query range: answer from its own tree.
    if (L <= l && r <= R)
        return T.query_rank(num, o);

    const int mid = (l + r) >> 1;
    int below = 0;
    if (L <= mid)
        below += query_rank(l, mid, L, R, num, o << 1);
    if (mid < R)
        below += query_rank(mid + 1, r, L, R, num, (o << 1) | 1);
    return below;
}
// Returns the largest value v in [1, maxm] whose rank inside positions
// [L, R] (elements strictly less than v, plus one) is at most k, or 0 when
// no such v exists.  Found by binary search on the value.
int query_num(int L,int R,int k)
{
    int lo = 1;
    int hi = maxm;
    int best = 0;
    while (lo <= hi)
    {
        const int mid = (lo + hi) >> 1;
        const int rk = query_rank(1, n, L, R, mid, 1) + 1;
        if (rk > k)
        {
            hi = mid - 1;
        }
        else
        {
            // mid is still feasible; remember it and try larger values.
            best = mid;
            lo = mid + 1;
        }
    }
    return best;
}
// Largest value strictly less than `num` among positions [L, R], or -maxm
// when there is none.  (l, r, o) describe the current segment-tree node.
int pre_x(int l,int r,int L,int R,int num,int o)
{
    if (L <= l && r <= R)
        return T.pre_x(num, o);

    const int mid = (l + r) >> 1;
    int best = -maxm;
    if (L <= mid)
        best = std::max(best, pre_x(l, mid, L, R, num, o << 1));
    if (mid < R)
        best = std::max(best, pre_x(mid + 1, r, L, R, num, (o << 1) | 1));
    return best;
}
// Smallest value strictly greater than `num` among positions [L, R], or
// maxm when there is none.  (l, r, o) describe the current segment-tree node.
int aft_x(int l,int r,int L,int R,int num,int o)
{
    if (L <= l && r <= R)
        return T.aft_x(num, o);

    const int mid = (l + r) >> 1;
    int best = maxm;
    if (L <= mid)
        best = std::min(best, aft_x(l, mid, L, R, num, o << 1));
    if (mid < R)
        best = std::min(best, aft_x(mid + 1, r, L, R, num, (o << 1) | 1));
    return best;
}
int main()
{
  //  freopen("in.txt","r",stdin);
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i)scanf("%d",&A[i]);
    build_Tree(1,n,1);
    for(int i = 1; i <= m;++i)
    {
        int opt,l,r,k,pos;
        scanf("%d",&opt);
        if(opt == 1)
        {
            scanf("%d%d%d",&l,&r,&k);
            printf("%d
",query_rank(1,n,l,r,k,1)+1);
        }
        if(opt == 2)
        {
            scanf("%d%d%d",&l,&r,&k);
            printf("%d
",query_num(l,r,k));
        }
        if(opt == 3)
        {
            scanf("%d%d",&pos,&k);
            update(1,n,pos,A[pos],k,1);
            A[pos] = k;
        }
        if(opt == 4)
        {
            scanf("%d%d%d",&l,&r,&k);
            int ans = pre_x(1,n,l,r,k,1);
            if(ans == -maxm)printf("-2147483647
");
            else printf("%d
",ans);
        }
        if(opt == 5)
        {
            scanf("%d%d%d",&l,&r,&k);
            int ans = aft_x(1,n,l,r,k,1);
            if(ans == maxm)printf("2147483647
");
            else printf("%d
",ans);
        }
    }
    return 0;
}

  

原文地址:https://www.cnblogs.com/guangheli/p/9845220.html