暑假学习日记:Splay树

    从昨天开始我就想学这个伸展树了,今天花了一个上午2个多小时加下午2个多小时,学习了一下伸展树(Splay树),学习的时候主要是看别人博客啦~发现下面这个博客挺不错的http://zakir.is-programmer.com/posts/21871.html.在里面有连接到《运用伸展树解决数列维护问题》的文章,里面对伸展树的旋转操作讲得很仔细,而且也讲清楚了伸展树是怎么样维护一个数列的,一开始我是小白,觉得树和数列根本没什么关系,但看了之后就会明白,实际上树上的结点是维护该结点的值的,而这个值是原来数列里的哪一项呢?如果该结点对应的中序遍历数k,那么就是对应原数列a中的a[k]这一项.理解了这个之后我就豁然开朗了,要提取一个区间[a,b],实际上只需要将a-1,Splay为根,b+1Splay到根下的右儿子,则根下的右儿子的左儿子就是[a,b]这个区间,这是由平衡树,左小右大的性质决定的.

    所以无论做什么,首先是将该区间提取出来,然后对对应结点做就好了.问题是有时a-1不存在,b+1也不存在,所以一开始人为的做两个头尾的结点.而且很多时候为了避免对NULL的特殊处理,我们会构造一个实的null,让它的sz=0;sum=0;这样就不会影响一些情况的处理

    伸展树的特性是可以反转,注意到,一棵树,如果我们将它的每个结点的左右儿子都互换一次,它的中序遍历就刚好是原来的中序遍历倒过来,利用这个性质可以实现序列反转.而且还可以添加,假如要添加一个串{b1,b2,b3,b4..}在ak之后的位置,首先调出[ak,ak+1]这个区间,然后将{b1,b2...}建一棵伸展树,然后将结点粘在root->ch[1]的左儿子上即可.删除则是同理.我我还可以提取一个区间出来,反转,再加到我想加的地方.操作都是类似的.

    伸展树的优势除了它支持上面的操作外,它还兼容线段树的add,set操作,同样也是每个结点存lazy标记就可以了,然后写一个类似线段树的pushDown,pushUp,维护好区间的信息就可以了~

    下面给出的代码很大程度上(90%)是从上面网站的代码上copy下来的,将它改成自己的习惯的变量名,然后自己多写了一个add标记,原来的代码还能求最大子段和,但加了add之后再求就有点麻烦了,所以就删掉了原本维护最大子段和的代码,自己写了个驱动程序,调了一下感觉还行.

    代码的参数设置可能会不同,像add()函数传的是从哪个位置(pos),加多少个(tot),大可直接写成l,r,传参数的姿势不同罢了,但注意的是,当要在l位置开始加的时候,传进去的是l+1,是因为前面的头指针占了一位,看到输出之后就大概明白为什么要加1了.对了,因为区间的标记的lazy的,所以直接中序遍历得不出实际的序列(因为有些标记没往下传),所以写了个maintain()先把所有标记下传,实际上是不需要的,随用随查就好了,这么写是为了方便debug~

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<vector>
#define INF 0x3fffffff
#define maxn 500000
using namespace std;

struct Node
{
    Node *pre,*ch[2];
    bool rev,cov; // 结点翻转标记与cover标记
    int add; // 结点add标记(表示加了多少)
    int sz,val,sum; // 结点的size,保存的值,以及以该结点为子树的和
}*root,N[maxn],*null; // 定义了根的指针,人手写的空的指针,以及结点数组N
Node *stack[maxn]; // 用一个栈来回收用过的指针,这是学到的新姿势,这样的话在构造新的结点的时候可以不用一直idx++
int top,idx; // 栈顶指针,以及数组idx指针
int a[maxn+20]; // 用来构造伸展树的数组

Node *addNode(int val) // 产生新结点
{
    Node *p;
    if(top) p=stack[--top]; // 首先从回收栈里取
    else p=&N[idx++]; // 没有的话从N里面取
    //初始化
    p->rev=p->cov=false; 
    p->sz=1;
    p->sum=p->val=val;
    p->ch[0]=p->ch[1]=p->pre=null;
    return p;
}

void Recycle(Node *p) // 递归回收删除掉的指针,这是用来节省空间的
{
    if(p->ch[0]!=null) Recycle(p->ch[0]);
    if(p->ch[1]!=null) Recycle(p->ch[1]);
    stack[++top]=p;
}

void pushDown(Node *p) // 核心函数,用来处理标记的
{
    if(p==null||!p) return; // 遇到空指针返回
    if(p->rev)  // 先处理反转标记
    {
        swap(p->ch[0],p->ch[1]); // 交换子树
        if(p->ch[0]!=null) p->ch[0]->rev^=1; // 标记下传
        if(p->ch[1]!=null) p->ch[1]->rev^=1; // 标记下传
        p->rev=false; // 标记取消
    }
    if(p->cov) //下面的cov和add标记的更新与下传与线段树相同
    {
        if(p->ch[0]!=null){
            p->ch[0]->val=p->val;
            p->ch[0]->sum=p->val*p->ch[0]->sz;
            p->ch[0]->cov=true;
            p->ch[0]->add=0;
        }
        if(p->ch[1]!=null){
            p->ch[1]->val=p->val;
            p->ch[1]->sum=p->val*p->ch[1]->sz;
            p->ch[1]->cov=true;
            p->ch[1]->add=0;
        }
        p->cov=false;
    }
    if(p->add)
    {
        if(p->ch[0]!=null){
            p->ch[0]->val+=p->add;
            p->ch[0]->sum+=p->ch[0]->sz*p->add;
            p->ch[0]->add+=p->add;
        }
        if(p->ch[1]!=null){
            p->ch[1]->val+=p->add;
            p->ch[1]->sum+=p->ch[1]->sz*p->add;
            p->ch[1]->add+=p->add;
        }
        p->add=0;
    }
}

void pushUp(Node *p) // 核心函数,维护信息
{
    if(p==null) return;
    pushDown(p); 
    pushDown(p->ch[0]);
    pushDown(p->ch[1]);
    p->sz=p->ch[0]->sz+p->ch[1]->sz+1;
    p->sum=p->val+p->ch[0]->sum+p->ch[1]->sum;
}

void rotate(Node *x,int c) // Splay树的旋转函数,标准姿势
{
    Node *y=x->pre;
    pushDown(y);pushDown(x);
    y->ch[c^1]=x->ch[c];
    if(x->ch[c]!=null)
        x->ch[c]->pre=y;
    x->pre=y->pre;
    if(y->pre!=null)
        if(y->pre->ch[0]==y)
            y->pre->ch[0]=x;
        else
            y->pre->ch[1]=x;
    x->ch[c]=y;y->pre=x;
    if(y==root) root=x;
    pushUp(y);
}

void Splay(Node *x,Node *f) // 将x结点转到f下
{
    pushDown(x);
    while(x->pre!=f)
    {
        Node *y=x->pre,*z=y->pre;
        if(x->pre->pre==f)
            rotate(x,x->pre->ch[0]==x);
        else
        {
            if(z->ch[0]==y){
                if(y->ch[0]==x) {rotate(y,1);rotate(x,1);}
                else {rotate(x,0);rotate(x,1);}
            }
            else{
                if(y->ch[1]==x) {rotate(y,0),rotate(x,0);}
                else {rotate(x,1),rotate(x,0);}
            }
        }
    }
    pushUp(x);
}

Node *select(int kth) // 选出第k个点,返回对应结点
{
    int tmp;
    Node *t=root;
    while(1){
        pushDown(t);
        tmp=t->ch[0]->sz;
        if(tmp+1==kth) break;
        if(kth<=tmp) {t=t->ch[0];}
        else { kth-=tmp+1;t=t->ch[1];}
    }
    return t;
}

Node *build(int L,int R) // 建树,有点像线段树
{
    if(L>R) return null;
    int M=(L+R)>>1;
    Node *p=addNode(a[M]);
    p->ch[0]=build(L,M-1);
    if(p->ch[0]!=null){
        p->ch[0]->pre=p;
    }
    p->ch[1]=build(M+1,R);
    if(p->ch[1]!=null){
        p->ch[1]->pre=p;
    }
    pushUp(p);
}

void remove(int pos,int tot) // 从pos位置开始,删除tot个(包括pos)
{
    Splay(select(pos-1),null);
    Splay(select(pos+tot),root);
    if(root->ch[1]->ch[0]!=null){
        Recycle(root->ch[1]->ch[0]);
        root->ch[1]->ch[0]=null;
    }
    pushUp(root->ch[1]);pushUp(root);
    Splay(root->ch[1],null);
}

void insert(int pos,int tot) // 添加,插的是一个数组的时候,要在数组a里面建一颗树,即a[1~N]是要插的数
{
    Node *troot=build(1,tot);
    Splay(select(pos),null);
    Splay(select(pos+1),root);
    root->ch[1]->ch[0]=troot;
    troot->pre=root->ch[1];
    pushUp(root->ch[1]);pushUp(root);
    Splay(troot,null); 
}

void reverse(int pos,int tot) // 从pos开始翻转tot个
{
    Splay(select(pos-1),null);
    Splay(select(pos+tot),root);
    if(root->ch[1]->ch[0]!=null)
    {
        root->ch[1]->ch[0]->rev^=1;
        Splay(root->ch[1]->ch[0],null);
    }
}

void set(int pos,int tot,int c) // 从pos开始将tot个设置为c
{
    Splay(select(pos-1),null);
    Splay(select(pos+tot),root);
    root->ch[1]->ch[0]->val=c;
    root->ch[1]->ch[0]->sum=root->ch[1]->ch[0]->sz*c;
    root->ch[1]->ch[0]->cov=true;
    Splay(root->ch[1]->ch[0],null);
}

void add(int pos,int tot,int c) // 从pos开始将tot个加c
{
    Splay(select(pos-1),null);
    Splay(select(pos+tot),root);
    root->ch[1]->ch[0]->val+=c;
    root->ch[1]->ch[0]->sum+=c*root->ch[1]->ch[0]->sz;
    root->ch[1]->ch[0]->add+=c;
    Splay(root->ch[1]->ch[0],null);
}

int query(int pos,int tot) // 求pos开始tot个的和
{
    Splay(select(pos-1),null);
    Splay(select(pos+tot),root);
    return root->ch[1]->ch[0]->sum;
}

void init() // 初始化函数
{
    idx=top=0; // idx,top归零
    null=addNode(-INF); // 初始化空指针
    null->sz=null->sum=0; // 记住sz和sum一定要设为0
    root=addNode(-INF); // 初始化根指针
    root->sum=0;
    Node *p;
    p=addNode(-INF); // 初始化"树尾"的指针
    root->ch[1]=p;
    p->pre=root;
    p->sum=0;
    pushUp(root->ch[1]);
    pushUp(root);
}
//下面三个函数是调试的时候用的
void maintain(Node *p) // 因为标记是lazy的,所以先将所有标记都下传好
{
    pushDown(p);
    if(p->ch[0]!=null) maintain(p->ch[0]);
    if(p->ch[1]!=null) maintain(p->ch[1]);
}
void dfs(Node *x) // 中序遍历
{
    if(x==null) return;
    dfs(x->ch[0]);
    printf("%d ",x->val);
    dfs(x->ch[1]);
}
void print() // 打印
{
    maintain(root); 
    dfs(root);
    puts("");
}

int main()
{
    int n,m;
    while(cin>>n)
    {
        for(int i=1;i<=n;i++){
            scanf("%d",&a[i]);
        }
        init();
        Node *troot=build(1,n); // 从a数组建一颗Splay树
        root->ch[1]->ch[0]=troot; // 让它和init()里的root,p连上
        troot->pre=root->ch[1]; 
        pushUp(root->ch[1]); // 维护相关信息
        pushUp(root->ch[0]);
        cin>>m;
        int o,l,r,v;
        //支持六种操作,区间add,区间set,区间反转,区间删除,区间添加,区间和
        while(m--)
        {
            scanf("%d",&o);
            if(o==1){
                scanf("%d%d%d",&l,&r,&v);add(l+1,r-l+1,v);print();
            }
            else if(o==2){
                scanf("%d%d%d",&l,&r,&v);set(l+1,r-l+1,v);print();
            }
            else if(o==3){
                scanf("%d%d",&l,&r);reverse(l+1,r-l+1);print();
            }
            else if(o==4){
                scanf("%d%d",&l,&r);remove(l+1,r-l+1);print();
            }
            else if(o==5){
                scanf("%d%d",&l,&v);
                for(int i=1;i<=v;i++){ scanf("%d",&a[i]);}
                insert(l+1,v);
                print();
            }
            else if(o==6){
                scanf("%d%d",&l,&r);
                cout<<query(l+1,r-l+1)<<endl;
            }
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/chanme/p/3272830.html