模板——Splay

$Splay$

#include <bits/stdc++.h>
#define inf (int)1e9
using namespace std;
// Node pool capacity. Node 0 is the shared null sentinel (clear() resets it).
const int N=1e5+100;
// n: number of operations read in main; tot: next unused node id (newnode returns tot++);
// root: current tree root (0 when the tree is empty).
// Per-node arrays: val = key stored in the node, sz = total multiplicity of the subtree
// (pushup sums both children plus re), son[x][0]/son[x][1] = left/right child.
int n,tot,root,val[N],sz[N],son[N][2];
// fa = parent, sf = which child slot (0 left / 1 right) the node occupies under fa (set by connect),
// re = number of copies of val[x] held by node x (insert increments it for duplicates).
int fa[N],sf[N],re[N];
// Hand out the next unused node id. Ids are never recycled in this program.
int newnode()
{
    int id=tot;
    tot=tot+1;
    return id;
}
// Make x the `dir` child (0 = left, 1 = right) of y and record the back links on x.
void connect(int x,int y,int dir)
{
    fa[x]=y;
    sf[x]=dir;
    son[y][dir]=x;
}
// Recompute the subtree multiplicity of x: own copy count plus both children's subtree totals.
void pushup(int x)
{
    const int left=son[x][0];
    const int right=son[x][1];
    sz[x]=sz[left]+sz[right]+re[x];
}
// Restore the null sentinel (node 0) to all zeros; connect() writes into it whenever
// a rotated node's grandparent or moved subtree is 0.
void clear()
{
    val[0]=0;
    sz[0]=0;
    re[0]=0;
    fa[0]=0;
    sf[0]=0;
    son[0][0]=0;
    son[0][1]=0;
}
void rotate(int x)
{
    int f,gf,xd,fd,s;
    f=fa[x];gf=fa[f];
    xd=sf[x];fd=sf[f];
    s=son[x][xd^1];
    connect(x,gf,fd);connect(f,x,xd^1);connect(s,f,xd);
    clear();
    pushup(f);pushup(x);
}
// Rotate x upward until its parent is y; with y == 0 x becomes the new root.
// Zig-zig (x and parent on the same side) rotates the parent first; zig-zag rotates x twice.
void splay(int x,int y)
{
    for (int p=fa[x];p!=y;p=fa[x])
    {
        if (fa[p]!=y)
            rotate(sf[p]==sf[x]?p:x);
        rotate(x);
    }
    if (!y)
        root=x;
}
// Walk from the root toward v, stopping at the node holding v or at the last node of the
// search path, and splay that node to the root.
void find(int v)
{
    int cur=root;
    for (;;)
    {
        if (val[cur]==v)
            break;
        const int next=son[cur][v>val[cur]];
        if (!next)
            break;
        cur=next;
    }
    splay(cur,0);
}
// Return the node with the largest key strictly smaller than v.
// After find(v) the root is v itself or a neighbour of v; if it is not already smaller,
// the answer is the rightmost node of its left subtree (0 when there is none).
int per(int v)
{
    find(v);
    int best=root;
    if (val[best]>=v)
    {
        best=son[root][0];
        while (son[best][1])
            best=son[best][1];
    }
    return best;
}
// Return the node with the smallest key strictly greater than v.
// After find(v) the root is v itself or a neighbour of v; if it is not already larger,
// the answer is the leftmost node of its right subtree (0 when there is none).
int suc(int v)
{
    find(v);
    int best=root;
    if (val[best]<=v)
    {
        best=son[root][1];
        while (son[best][0])
            best=son[best][0];
    }
    return best;
}
// Insert one copy of v.
// A duplicate only bumps the multiplicity of the existing node; otherwise a new leaf is hung
// off the end of the search path. The touched node is then splayed to the root, and the
// rotations refresh the sizes of every ancestor on the way.
void insert(int v)
{
    if (root==0)
    {
        // Empty tree: the new node becomes the root. Its copy count must be 1 as well;
        // leaving re[x] at 0 made later pushup() calls drop this value from all subtree sizes.
        int x=newnode();
        val[x]=v;re[x]=sz[x]=1;
        root=x;
        return;
    }
    int cur=root;
    while (son[cur][v>val[cur]] && val[cur]!=v)
      cur=son[cur][v>val[cur]];
    if (val[cur]==v)
    {
        // Already stored: count one more copy; splay fixes the ancestors' sizes.
        re[cur]++;sz[cur]++;
        splay(cur,0);
        return;
    }
    int x=newnode();
    val[x]=v;re[x]=sz[x]=1;
    connect(x,cur,v>val[cur]);
    pushup(cur);
    splay(x,0);
}
void del(int v)
{
    int p,s;
    p=per(v);s=suc(v);
    splay(p,0);
    splay(s,p);
    re[son[s][0]]--;sz[son[s][0]]--;
    if (re[son[s][0]]==0)
      son[s][0]=0;
    pushup(s);pushup(p);
}
// Count the stored copies in the left subtree of the node find(v) splays to the root.
// When v is stored this counts everything smaller than v plus the -inf entry main() inserts,
// i.e. the 1-based rank of v.
int rk(int v)
{
    find(v);
    const int left=son[root][0];
    return sz[left];
}
// Return the node holding the k-th smallest copy (1-based, multiplicities counted) in the
// subtree rooted at x. Walks down iteratively instead of recursing.
int kth(int x,int k)
{
    for (;;)
    {
        const int leftSize=sz[son[x][0]];
        if (k>leftSize+re[x])
        {
            k-=leftSize+re[x];
            x=son[x][1];
        }
        else if (k<=leftSize)
            x=son[x][0];
        else
            return x;
    }
}
int main()
{
    tot=1;
    insert(inf);insert(-inf);
    scanf("%d",&n);
    for (int i=1;i<=n;i++)
    {
        int op,x;
        scanf("%d%d",&op,&x);
        if (op==1) insert(x);
        if (op==2) del(x);
        if (op==3) printf("%d
",rk(x));
        if (op==4) printf("%d
",val[kth(root,x+1)]);
        if (op==5) printf("%d
",val[per(x)]);
        if (op==6) printf("%d
",val[suc(x)]);
    }
}

维护数列

#include <bits/stdc++.h>
#define inf (int)1e9
using namespace std;
// Node pool capacity: ids 1..N-1 are handed out fresh, after that freed ids are recycled via q.
// Node 0 is the null sentinel.
const int N=500100;
// n, m: initial sequence length and number of commands; tot: next fresh node id;
// root: current root; a: input buffer for build() (main stores the -inf boundary values in a[0], a[n+1]).
int n,m,tot,root,a[N];
struct node
{
    // sz: subtree size; val: this element; sum: subtree sum;
    // res: lazy reverse tag - this node's own children are already swapped (with their lx/rx),
    //      the reversal is still owed to the children's subtrees;
    // lx / rx: best prefix / suffix sum of the subtree, never below 0;
    // tx: best non-empty contiguous subarray sum of the subtree;
    // vc: lazy "assign all" value still owed to the children (inf means none).
    // NOTE(review): sums over the whole tree include both -1e9 boundary nodes; with large negative
    // inputs this may exceed the int range - confirm against the input bounds.
    int sz,val,sum,res,lx,rx,tx,vc;
    // son: left/right child; fa: parent; sf: which child slot (0/1) this node occupies under fa.
    int son[2],fa,sf;
}sh[N+100];
// Ids released by clear(); newnode() draws from here only once tot has reached N.
queue <int> q;
// Allocate a node id (fresh ids while the pool has room, recycled ids from q afterwards)
// and reset every field: no children, no parent, no pending reverse, no pending assignment.
int newnode()
{
    int x;
    if (tot<N)
        x=tot++;
    else
    {
        x=q.front();
        q.pop();
    }
    sh[x].sz=0;
    sh[x].val=0;
    sh[x].sum=0;
    sh[x].res=0;
    sh[x].lx=0;
    sh[x].rx=0;
    sh[x].tx=0;
    sh[x].son[0]=0;
    sh[x].son[1]=0;
    sh[x].fa=0;
    sh[x].sf=0;
    sh[x].vc=inf;  // inf marks "no pending assignment"
    return x;
}
// Return every node id of the subtree rooted at x to the free queue q, in pre-order.
// Uses an explicit stack because a subtree reshaped by many splays can be very deep, and
// recursion would then risk overflowing the call stack. The null sentinel 0 is never recycled.
void clear(int x)
{
    if (x==0)
        return;
    std::vector<int> stk;
    stk.push_back(x);
    while (!stk.empty())
    {
        const int cur=stk.back();
        stk.pop_back();
        q.push(cur);
        // Right child pushed first so the left subtree is handled first (same order as before).
        if (sh[cur].son[1]) stk.push_back(sh[cur].son[1]);
        if (sh[cur].son[0]) stk.push_back(sh[cur].son[0]);
    }
}
// Make x the `dir` child (0 = left, 1 = right) of y and record the back links on x.
void connect(int x,int y,int dir)
{
    sh[x].fa=y;
    sh[x].sf=dir;
    sh[y].son[dir]=x;
}
// Apply a reversal to node c: swap its children, flip their side bits and their prefix/suffix
// maxima, and toggle c's own tag so the reversal is still owed one level further down.
static void reverse_children(int c)
{
    const int l=sh[c].son[0];
    const int r=sh[c].son[1];
    sh[l].sf^=1;
    sh[r].sf^=1;
    sh[c].son[0]=r;
    sh[c].son[1]=l;
    std::swap(sh[r].lx,sh[r].rx);
    std::swap(sh[l].lx,sh[l].rx);
    sh[c].res^=1;
}

// Set every element under node c to v: c's aggregates become those of sz copies of v,
// and v is recorded as c's pending assignment.
static void assign_all(int c,int v)
{
    const int total=v*sh[c].sz;
    sh[c].val=v;
    sh[c].vc=v;
    sh[c].sum=total;
    sh[c].tx=std::max(v,total);
    sh[c].lx=std::max(0,total);
    sh[c].rx=sh[c].lx;
}

// Hand x's lazy tags (reversal first, then assignment) down to its existing children.
void pushdown(int x)
{
    const int kids[2]={sh[x].son[0],sh[x].son[1]};
    if (sh[x].res==1)
    {
        for (int i=0;i<2;i++)
            if (kids[i]) reverse_children(kids[i]);
        sh[x].res=0;
    }
    if (sh[x].vc!=inf)
    {
        for (int i=0;i<2;i++)
            if (kids[i]) assign_all(kids[i],sh[x].vc);
        sh[x].vc=inf;
    }
}
// Recompute x's aggregates from its children: size, sum, best prefix (lx), best suffix (rx)
// and best non-empty subarray (tx), the last possibly spanning across x itself.
void pushup(int x)
{
    const int l=sh[x].son[0];
    const int r=sh[x].son[1];
    const int v=sh[x].val;
    sh[x].sz=sh[l].sz+sh[r].sz+1;
    sh[x].sum=sh[l].sum+sh[r].sum+v;
    sh[x].lx=std::max(sh[l].lx,sh[l].sum+sh[r].lx+v);
    sh[x].rx=std::max(sh[r].rx,sh[r].sum+sh[l].rx+v);
    const int across=sh[l].rx+sh[r].lx+v;
    sh[x].tx=std::max(std::max(sh[l].tx,sh[r].tx),across);
}
// Single rotation: lift x above its parent while preserving in-order order, then refresh the
// parent (now x's child) and x.
void rotate(int x)
{
    const int p=sh[x].fa;
    const int g=sh[p].fa;
    const int side=sh[x].sf;
    const int pside=sh[p].sf;
    const int inner=sh[x].son[side^1];  // subtree that changes owner from x to p
    connect(x,g,pside);
    connect(p,x,side^1);
    connect(inner,p,side);
    // connect() may have written into the null sentinel; restore its structural fields.
    sh[0].fa=0;
    sh[0].son[0]=0;
    sh[0].son[1]=0;
    sh[0].sf=0;
    pushup(p);
    pushup(x);
}
// Rotate x upward until its parent is y; with y == 0 x becomes the new root.
// Zig-zig (x and parent on the same side) rotates the parent first; zig-zag rotates x twice.
void splay(int x,int y)
{
    for (int p=sh[x].fa;p!=y;p=sh[x].fa)
    {
        if (sh[p].fa!=y)
            rotate(sh[x].sf==sh[p].sf?p:x);
        rotate(x);
    }
    if (!y)
        root=x;
}
// Return the node at 1-based in-order position k within the subtree of x, pushing lazy tags
// down on every visited node so its children are up to date before descending.
int find(int x,int k)
{
    for (;;)
    {
        pushdown(x);
        const int leftSize=sh[sh[x].son[0]].sz;
        if (k>leftSize+1)
        {
            k-=leftSize+1;
            x=sh[x].son[1];
        }
        else if (k<=leftSize)
            x=sh[x].son[0];
        else
            return x;
    }
}
// Build a balanced subtree over a[l..r] (kept in order), record father/dir as its root's parent
// link, and return that root. An empty range (l > r) yields 0 instead of a node; previously it
// created a stray node from a[mid].
int build(int l,int r,int father,int dir)
{
    if (l>r)
        return 0;
    int mid=(l+r)>>1;
    int x=newnode();
    sh[x].val=a[mid];
    sh[x].fa=father;sh[x].sf=dir;
    if (l==r)
    {
        // Leaf: aggregates of a single element (prefix/suffix maxima are clamped at 0).
        sh[x].sz=1;
        sh[x].sum=sh[x].tx=a[mid];
        sh[x].lx=sh[x].rx=max(0,a[mid]);
        return x;
    }
    if (l<=mid-1)
      sh[x].son[0]=build(l,mid-1,x,0);
    if (r>=mid+1)
      sh[x].son[1]=build(mid+1,r,x,1);
    pushup(x);
    return x;
}
// Isolate sequence positions l..r as a single subtree and return its root.
// main() keeps a boundary node at in-order position 1, so in-order position l holds the element
// just before the range and position r+2 the one just after. Splaying them to the root and just
// below it leaves the range as exactly the latter's left subtree (0 when the range is empty).
int split(int l,int r)
{
    const int before=find(root,l);
    const int after=find(root,r+2);
    splay(before,0);
    splay(after,before);
    return sh[after].son[0];
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++)
      scanf("%d",&a[i]);
    tot=1;sh[0].tx=a[0]=a[n+1]=-inf;
    root=build(0,n+1,0,0);
    while (m--)
    {
        char ch[20];
        scanf("%s",ch);
        if (ch[0]=='I')
        {
            int pos,tot;
            scanf("%d%d",&pos,&tot);
            for (int i=1;i<=tot;i++)
              scanf("%d",&a[i]);
            int x=build(1,tot,0,0);
            int A,B;
            A=find(root,pos+1);B=find(root,pos+2);
            splay(A,0);splay(B,A);
            connect(x,B,0);
            pushup(B);pushup(A);
        }
        if (ch[0]=='D')
        {
            int pos,tot;
            scanf("%d%d",&pos,&tot);
            int per,suc;
            per=find(root,pos);suc=find(root,pos+tot+1);
            splay(per,0);
            splay(suc,per);
            clear(sh[suc].son[0]);
            sh[suc].son[0]=0;
            pushup(suc);pushup(per);
        }
        if (ch[0]=='M' && ch[2]=='K')
        {
            int pos,tot,c;
            scanf("%d%d%d",&pos,&tot,&c);
            int x=split(pos,pos+tot-1);
            sh[x].val=sh[x].vc=c;
            sh[x].sum=c*sh[x].sz;
            sh[x].tx=max(sh[x].val,c*sh[x].sz);
            sh[x].lx=sh[x].rx=max(0,c*sh[x].sz);
            pushup(sh[x].fa);pushup(sh[sh[x].fa].fa);
        }
        if (ch[0]=='R')
        {
            int pos,tot;
            scanf("%d%d",&pos,&tot);
            int x=split(pos,pos+tot-1);
            sh[x].res^=1;
            int ls,rs;
            ls=sh[x].son[0];rs=sh[x].son[1];
            sh[ls].sf^=1;sh[rs].sf^=1;
            swap(sh[x].son[0],sh[x].son[1]);
            swap(sh[ls].lx,sh[ls].rx);
            swap(sh[rs].lx,sh[rs].rx);
            pushup(x);pushup(sh[x].fa);pushup(sh[sh[x].fa].fa);
        }
        if (ch[0]=='G')
        {
            int pos,tot;
            scanf("%d%d",&pos,&tot);
            int x=split(pos,pos+tot-1);
            printf("%d
",sh[x].sum);
        }
        if (ch[0]=='M' && ch[2]=='X')
        {
            printf("%d
",sh[root].tx);
        }
    }
}
原文地址:https://www.cnblogs.com/huangchenyan/p/11216762.html