// [bzoj] 3224 Tyvj 1728 普通平衡树 (Ordinary Balanced Tree) || 平衡树板子题 (balanced-tree template problem)

#include<cstdio>
const int N=100010;   // node-pool capacity; node id 0 serves as the null sentinel
using namespace std;

// Splay-tree node pool, indexed by node id (id 0 means "no node"):
//   ls/rs  left/right child ids      f    parent id
//   val    stored key                cnt  number of copies of val held by the node
//   sze    size of the subtree rooted at the node (maintained by updt)
// root: id of the current root (0 when the tree is empty).
// idx:  last allocated node id (ids are handed out 1, 2, 3, ...).
// n:    number of operations read in main; m is not used in the code shown.
int n,m,ls[N],rs[N],val[N],sze[N],cnt[N],f[N],root,idx;

// True when node x is the left child of its parent.
inline bool which(int x)
{
    return ls[f[x]]==x;
}

// Read the next decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Returns 0 if end of input is reached before any digit.
// The character is kept as an int so that EOF from getchar() is not
// truncated to a char, which would make the skip loop spin forever at EOF.
int read()
{
    int ans=0,fu=1;
    int j=getchar();
    for (;j!=EOF && (j<'0' || j>'9');j=getchar()) if (j=='-') fu=-1;
    for (;j>='0' && j<='9';j=getchar()) ans=ans*10+(j-'0');
    return ans*fu;
}

// Recompute the subtree size of x. A node holds cnt[x] copies of val[x],
// so it contributes cnt[x] (not 1) to every subtree that contains it.
void updt(int x)
{
    sze[x]=sze[ls[x]]+sze[rs[x]]+cnt[x];
}

// Single rotation: lift u one level above its parent v, preserving the
// in-order sequence. b is u's inner subtree, which is handed from u to v.
void Rotate(int u)
{
    int v=f[u],w=f[v],b=which(u)?rs[u]:ls[u];
    // u takes v's place under the grandparent w, on the same side v occupied.
    if (w)
    {
        if (which(v)) ls[w]=u;
        else rs[w]=u;
    }
    if (which(u)) ls[v]=b,rs[u]=v;
    else rs[v]=b,ls[u]=v;
    f[u]=w,f[v]=u;
    if (b) f[b]=v;
    // v is now u's child, so refresh v before u.
    updt(v),updt(u);
}

// Splay x upward until its parent is tar (tar==0 makes x the new root).
// If x and its parent are children on the same side, rotate the parent
// first (zig-zig); otherwise rotate x twice (zig-zag).
void Splay(int x,int tar)
{
    while (f[x]!=tar)
    {
        int p=f[x];
        if (f[p]!=tar)
        {
            if (which(p)==which(x)) Rotate(p);
            else Rotate(x);
        }
        Rotate(x);
    }
    if (!tar) root=x;
}

// Return the node holding value x. If x is absent, return the last node on
// the search path, which holds x's nearest stored predecessor or successor.
// Returns 0 for an empty tree. Does not splay; callers do that.
int find(int x)
{
    int u=root,v=0;
    while (u && val[u]!=x)
    {
        v=u;
        if (x<val[u]) u=ls[u];
        else u=rs[u];
    }
    return u?u:v;
}

// Insert one copy of x, then splay the affected node to the root.
void insert(int x)
{
    int u=root,v=0;
    while (u && val[u]!=x)
    {
        v=u;
        u=x<val[u]?ls[u]:rs[u];
    }
    if (u)
    {
        // x already stored: bump its multiplicity. Splaying u to the root
        // re-runs updt on each former ancestor, so their sizes pick it up.
        cnt[u]++,sze[u]++;
        Splay(u,0);
        return;
    }
    // New leaf holding a single copy of x, attached under v (or as root).
    idx++;
    f[idx]=v;
    ls[idx]=rs[idx]=0;
    val[idx]=x;
    cnt[idx]=1;
    sze[idx]=1;
    if (v)
    {
        if (x<val[v]) ls[v]=idx;
        else rs[v]=idx;
    }
    Splay(idx,0);
}

// Return the leftmost (smallest-key) node in the subtree rooted at x,
// following left links only; returns x itself if it has no left child
// (and 0 when x is 0). Does not splay.
int getmn(int x)
{
    int node=x;
    for (; ls[node]; node=ls[node]) {}
    return node;
}

// Return the rightmost (largest-key) node in the subtree rooted at x,
// following right links only; returns x itself if it has no right child
// (and 0 when x is 0). Does not splay.
int getmx(int x)
{
    int node=x;
    for (; rs[node]; node=rs[node]) {}
    return node;
}

void erase(int x)
{
    int tmp=find(x);
    Splay(tmp,0);
    if (cnt[tmp]>1) cnt[tmp]--,sze[tmp]--;
    else if (!ls[tmp] || !rs[tmp]) root=ls[tmp]+rs[tmp],f[root]=0;
    else
    {
	f[ls[tmp]]=0;
	int u=getmx(ls[tmp]);
	Splay(u,0);
	rs[u]=rs[tmp];
	f[rs[tmp]]=u;
	updt(u);
    }
}

int getkth(int k)
{
    int cur=root;
    while (cur)
    {
	if (sze[ls[cur]]>=k) cur=ls[cur];
	else if (sze[ls[cur]]+cnt[cur]>=k) return val[cur];
	else k-=sze[ls[cur]]+cnt[cur],cur=rs[cur];
    }
    return val[cur];
}

int getrank(int x)
{
    int cur=find(x);
    Splay(cur,0);
    return sze[ls[cur]]+1;
}

// Return the largest stored value strictly smaller than x.
// Every node reached is splayed so the walk stays amortized.
// Returns 0 (val[0]) when no such value exists.
int getpre(int x)
{
    int cur=find(x);
    if (!cur) return val[0];          // empty tree
    Splay(cur,0);
    // On a miss, the last node on the search path may already be the predecessor.
    if (val[cur]<x) return val[cur];
    // Otherwise the predecessor is the maximum of cur's left subtree.
    int p=getmx(ls[cur]);
    if (!p) return val[0];            // no smaller value stored
    Splay(p,0);
    return val[p];
}

int getnxt(int x)
{
    int cur=find(x);
    if (val[cur]>x) return val[cur];
    Splay(cur,0);
    return val[getmn(rs[cur])];
}

int main()
{
    n=read();
    for (int i=1,op,a;i<=n;i++)
    {
	op=read();
	a=read();
	printf("%d %d
",op,a);
	if (op==1) insert(a);
	if (op==2) erase(a);
	if (op==3) printf("%d
",getrank(a));
	if (op==4) printf("%d
",getkth(a));
	if (op==5) printf("%d
",getpre(a));
	if (op==6) printf("%d
",getnxt(a));
    }
    return 0;
}
// 原文地址 (original post): https://www.cnblogs.com/mrha/p/8157645.html