splay树

https://www.luogu.org/problemnew/show/P3369

#include<bits/stdc++.h>
using namespace std;
inline int read(){
	int sum=0,x=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-')
			x=0;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
		sum=(sum<<1)+(sum<<3)+(ch^48),ch=getchar();
	return x?sum:-sum;
}
inline void write(int x){
	if(x<0)
		putchar('-'),x=-x;
	if(x>9)
	 	write(x/10);
	putchar(x%10+'0');
}
const int M=2e5+5;
const int inf=0x3f3f3f3f;
int ch[M][2],val[M],cnt[M],fa[M],sz[M],ncnt,root;
int  ck(int x){
	return ch[fa[x]][1]==x;
}
void up(int x){
	sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
}
void rotate(int x){
	int y=fa[x];
	int z=fa[y];
	int k=ck(x);
	int w=ch[x][k^1];
	ch[y][k]=w,fa[w]=y;
	ch[z][ck(y)]=x,fa[x]=z;
	ch[x][k^1]=y,fa[y]=x;
	up(y),up(x);
}
void splay(int x,int goal=0){//将x旋转为goal的儿子,如果goal是0则旋转到根
	while(fa[x]!=goal){//一直旋转到x成为goal的儿子
		int y=fa[x],z=fa[y];//父节点祖父节点
		if(z!=goal)//如果Y不是根节点,则分为两类来旋转(如果随意旋转会达不到查询效果) 
			ck(x)^ck(y)?rotate(x):rotate(y);
		rotate(x);//无论怎么样最后的一个操作都是旋转x
	}
	if(!goal)
		root=x;//如果goal是0,则将根节点更新为x
}
void find(int x){//查找x的位置,并将其旋转到根节点 
	if(!root)//树空 
		return ;
	int cur=root;
	while(ch[cur][x>val[cur]]&&x!=val[cur])//当存在儿子并且当前位置的值不等于x 
		cur=ch[cur][x>val[cur]];//跳转到儿子,查找x的父节点
	splay(cur);//把当前位置旋转到根节点
}
void insertt(int x){//插入x
	int cur=root,p=0;//当前位置cur,cur的父节点p 
	while(cur&&val[cur]!=x){//当u存在并且没有移动到当前的值
		p=cur;//向下u的儿子,父节点变为u
		cur=ch[cur][x>val[cur]];//大于当前位置则向右找,否则向左找
	}
	if(cur)//存在这个值的位置
		cnt[cur]++;//增加一个数
	else{//不存在这个数字,要新建一个节点来存放
		cur=++ncnt;//新节点的位置
		if(p)//如果父节点非根
			ch[p][x>val[p]]=cur;//不存在儿子
		fa[cur]=p;//父节点
		ch[cur][0]=ch[cur][1]=0;
		val[cur]=x;//值
		cnt[cur]=sz[cur]=1;//数量,大小 
	}
	splay(cur);//把当前位置移到根,保证结构的平衡。注意前面因为更改了子树大小,所以这里必须Splay上去进行pushup保证size的正确。
}
int kth(int k){
	int u=root;
	while(true){
		if(ch[u][0]&&k<=sz[ch[u][0]])
			u=ch[u][0];
		else if(k>sz[ch[u][0]]+cnt[u])
			k-=sz[ch[u][0]]+cnt[u],u=ch[u][1];
		else
			return u;	
	}
}
int  pre(int x){//查找x的前驱(0)或者后继(1)
	find(x);
	if(val[root]<x)//如果当前节点的值小于x并且要查找的是前驱
		return root;
	int u=ch[root][0];//前驱再左儿子上找 
	while(ch[u][1])
		u=ch[u][1];
	return u;//返回位置 
}
int succ(int x){
	find(x);
	if(val[root]>x)//如果当前节点的值大于x并且要查找的是后继
		return root;
	int u=ch[root][1];//后继在右儿子上找 
	while(ch[u][0])
		u=ch[u][0];
	return u;
}
void remove(int x){
	int las=pre(x),nex=succ(x);
	splay(las,0),splay(nex,las);
	//将前驱旋转到根节点,后继旋转到根节点下面
    //很明显,此时后继是前驱的右儿子,x是后继的左儿子,并且x是叶子节点
	int del=ch[nex][0];//后继的左儿子
	if(cnt[del]>1)//如果超过一个
		cnt[del]--,splay(del,0);//直接减少一个
	else
		ch[nex][0]=0;//这个节点直接丢掉(不存在了)
}
void init(){
	memset(ch,0,sizeof(cnt));
	memset(cnt,0,sizeof(cnt));
	memset(sz,0,sizeof(sz));
	memset(val,0,sizeof(val));
	memset(fa,0,sizeof(fa));
	ncnt=root=0;
}
int main(){
	
	
	int n;
	while(~scanf("%d",&n)){
		init();
		insertt(inf);
		insertt(-inf);
		while(n--){
			
			int op=read(),x=read();
			if(op==1)
				insertt(x);
			else if(op==2)
				remove(x);
			else if(op==3)
				find(x),write(sz[ch[root][0]]),putchar('
');
			else if(op==4)
				write(val[kth(x+1)]),putchar('
');
			else if(op==5)
				write(val[pre(x)]),putchar('
');
			else
				write(val[succ(x)]),putchar('
');
		}
		putchar('
');
	}
	return 0;
}  

模板题:https://www.luogu.org/problem/P3380

#include<bits/stdc++.h>
using namespace std;
const int N=2e5+5;
const int M=2e6+6;
const int inf=2147483647;
inline int read(){
    int sum=0,x=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-')
            x=0;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
        sum=(sum<<1)+(sum<<3)+(ch^48),ch=getchar();
    return x?sum:-sum;
}
inline void write(int x){
    if(x<0)
        putchar('-'),x=-x;
    if(x>9)
        write(x/10);
    putchar(x%10+'0');
}
int ans,p,l,r,k,tot,tmp,top,a[N],fa[M],ch[M][2],sz[M],val[M];
struct tree{
    int l,r,root;
}tree[N];
void pushup(int x){
    if(x)
        sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+1;
}
int leftmin(int x){
    if(!ch[x][0])
        return x;
    return leftmin(ch[x][0]);
}
void newnode(int &x,int f,int da){
    if(top)
        x=top,top=0;
    else x=++tot;
    fa[x]=f;
    ch[x][0]=ch[x][1]=0;
    val[x]=da;
    sz[x]=1;
}
void rotate(int x,int k){
    int y=fa[x];
    ch[y][!k]=ch[x][k];
    fa[ch[y][!k]]=y;
    ch[fa[y]][ch[fa[y]][1]==y]=x;
    fa[x]=fa[y];
    ch[x][k]=y;
    fa[y]=x;
    pushup(y);
}
void splay(int x,int goal,int &root){
    while(fa[x]!=goal){
        int y=fa[x],z=fa[y],f=(ch[z][0]==y);
        if(z==goal)rotate(x,ch[y][0]==x);else{
            if(ch[y][f]==x)rotate(x,!f);else rotate(y,f);
            rotate(x,f);
        }
    }
    pushup(x);
    if(!goal)root=x;
}
void kthfind(int x,int key){
    if(!x)return;
    if(key<=val[x])kthfind(ch[x][0],key);else{
        tmp+=sz[ch[x][0]]+1;
        kthfind(ch[x][1],key);
    }
}
int findit(int x,int key){
    if(key<val[x])return findit(ch[x][0],key);
    if(key==val[x])return x;
    return findit(ch[x][1],key);
}
void insert(int &root,int key){
    int x=root;
    if(!root){newnode(root,0,key);return;}
    while(ch[x][key>val[x]])sz[x]++,x=ch[x][key>val[x]];
    sz[x]++;
    newnode(ch[x][key>val[x]],x,key);
    splay(ch[x][key>val[x]],0,root);
}
void del(int &root,int key){
    int x=findit(root,key);
    splay(x,0,root);
    top=root;
    if(!ch[x][1]){
        fa[ch[x][0]]=0;
        root=ch[x][0];
    }else{
        splay(leftmin(ch[x][1]),root,root);
        ch[ch[root][1]][0]=ch[root][0];
        fa[ch[root][1]]=0;
        fa[ch[root][0]]=ch[root][1];
        root=ch[root][1];
    }
    pushup(root);
}
void pre(int x,int key){
    if(!x)return;
    if(val[x]<key){if(val[x]>ans)ans=val[x];pre(ch[x][1],key);}
    else pre(ch[x][0],key);
}
void suc(int x,int key){
    if(!x)return;
    if(val[x]>key){if(val[x]<ans)ans=val[x];suc(ch[x][0],key);}
    else suc(ch[x][1],key);
}
void change(int k,int x,int key,int last){
    del(tree[k].root,last);
    insert(tree[k].root,key);
    int l=tree[k].l,r=tree[k].r,mid=(l+r)>>1;
    if(l==r)return;
    if(x<=mid)change(k<<1,x,key,last);else change(k<<1|1,x,key,last);
}
void getk(int k,int x,int y,int key){
    int l=tree[k].l,r=tree[k].r,mid=(l+r)>>1;
    if(x==l&&r==y)kthfind(tree[k].root,key);
    else{
        if(x>mid)getk(k<<1|1,x,y,key);
        else if(y<=mid)getk(k<<1,x,y,key);
        else {getk(k<<1,x,mid,key);getk(k<<1|1,mid+1,y,key);}
    }
}
void getpre(int k,int x,int y,int key){
    int l=tree[k].l,r=tree[k].r,mid=(l+r)>>1;
    if(x==l&&r==y)
        pre(tree[k].root,key);
    else{
        if(x>mid)
            getpre(k<<1|1,x,y,key);
        else if(y<=mid)
            getpre(k<<1,x,y,key);
        else 
            getpre(k<<1,x,mid,key),getpre(k<<1|1,mid+1,y,key);
    }
}
void getsuc(int k,int x,int y,int key){
    int l=tree[k].l,r=tree[k].r,mid=(l+r)>>1;
    if(x==l&&r==y)
        suc(tree[k].root,key);
    else{
        if(x>mid)
            getsuc(k<<1|1,x,y,key);
        else if(y<=mid)
            getsuc(k<<1,x,y,key);
        else 
            getsuc(k<<1,x,mid,key),getsuc(k<<1|1,mid+1,y,key);
    }
}
 
void buildtree(int k,int l,int r){
    tree[k].l=l;tree[k].r=r;char ch;
    for(int i=l;i<=r;i++)
        insert(tree[k].root,a[i]);
    if(l==r)
        return;
    int mid=(l+r)>>1;
    buildtree(k<<1,l,mid);
    buildtree(k<<1|1,mid+1,r);
}
int main(){
    int n=read(),m=read();
    //cout<<inf<<endl;
    
    for(int i=1;i<=n;i++)
        a[i]=read();
    buildtree(1,1,n);
    while(m--){
        p=read();
        if(p!=3){
            l=read(),r=read(),k=read();
            if(p==1){
                tmp=1;getk(1,l,r,k);
                printf("%d
",tmp);
            }else if(p==2){
                int L=0;
                int R=inf;
                int ans=R;
                while(L<=R){
                    int mid=(L+R)>>1;
                    tmp=1;getk(1,l,r,mid);
                    if(tmp<=k)L=mid+1,ans=mid;else R=mid-1;
                }
                printf("%d
",ans);
            }else if(p==4){
                ans=-inf;
                getpre(1,l,r,k);
                printf("%d
",ans);
            }else if(p==5){
                ans=inf;
                getsuc(1,l,r,k);
                printf("%d
",ans);
            }
        }else{
            int pos=read();
            k=read();
            change(1,pos,k,a[pos]);
            a[pos]=k;
        }
    }
    return 0;
}
View Code

Buy Tickets

http://poj.org/problem?id=2828
题意:给定拆入的位置和值,然后将其序列输出;
考虑用splay树去维护
对于splay树中的一个节点来说左子树上的每一个节点的值一定比该节点小,所以我们把这层关系对应于位置的先后次序:对于一个节点来说,他的左子树的每一个节点一定在这个节点在序列位置的左边,即一定比他先输出,右子树则类似。
然后最终答案即该splay树的中序遍历

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int M=2e5+5;
int ch[M][2],sz[M],cnt[M],fa[M],ncnt,root;
int val[M];
int ck(int x){
    return ch[fa[x]][1]==x;
}
void up(int x){
    sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
}
void rotate(int x){
    int y=fa[x];
    int z=fa[y];
    int k=ck(x);
    int w=ch[x][k^1];
    ch[z][ck(y)]=x,fa[x]=z;
    ch[y][k]=w,fa[w]=y;
    ch[x][k^1]=y,fa[y]=x;
    up(x),up(y);
}
void splay(int x,int goal=0){
    
    while(fa[x]!=goal){
        int y=fa[x],z=fa[y];
        if(z!=goal)
            ck(x)^ck(y)?rotate(x):rotate(y);
        rotate(x);
    }
    if(!goal)
        root=x;
}
int kth(int k){
    int u=root;
    while(true){
        if(ch[u][0]&&k<=sz[ch[u][0]])
            u=ch[u][0];
        else if(k>sz[ch[u][0]]+cnt[u])
            k-=sz[ch[u][0]]+cnt[u],u=ch[u][1];
        else
            return u;
    }
}
void newnode(int &cur,int f,int a){
    cur=++ncnt;
    fa[cur]=f;
    sz[cur]=cnt[cur]=1;
    val[cur]=a;
    ch[cur][0]=ch[cur][1]=0;
}
void insert(int x,int y){
    int p=0;
    if(!root){
        newnode(root,0,y);
        return;
    }
    if(!x){//如果要插入的节点要为head,就将之前的整棵树spaly为该节点的左子树
        p=root;
        sz[p]++;
        while(ch[p][0])
            p=ch[p][0],sz[p]++;
        newnode(ch[p][0],p,y);
        splay(ch[p][0]);
        return ;
    }
    int u=kth(x);
    splay(u);
    newnode(root,0,y);
    ch[root][1]=ch[u][1];
    fa[ch[u][1]]=root;
    ch[u][1]=0;
    ch[root][0]=u;
    fa[u]=root;
    up(u),up(root);
    
}
void output(int x){
    if(ch[x][0])
        output(ch[x][0]);
    printf("%d ",val[x]);
    if(ch[x][1])
        output(ch[x][1]);
}
int main(){
    int n;
    while(~scanf("%d",&n)){
        ncnt=root=0;
        while(n--){
            int id,x;
            scanf("%d%d",&id,&x);
            insert(id,x);
        }
        output(root);
        putchar('
');
    }
    return 0;
}
View Code

 https://www.lydsy.com/JudgeOnline/problem.php?id=1588

输入n个数,每读入一个数,在前面输入的数中找到一个与该数相差最小的一个。把所有的差值加起来得到答案。

前驱:小于当前值的最大值

后继:大于当前值的最小值

其实没必要用splay,直接set找要插入数的左右值比较大小即可

#include<bits/stdc++.h>
using namespace std;
const int M=4e4+4;
const int inf=0x3f3f3f3;
int val[M],fa[M],ch[M][2],cnt[M],ncnt,root;
int ck(int x){
    return ch[fa[x]][1]==x;
}
void rotate(int x){
    int y=fa[x],z=fa[y],k=ck(x);
    int w=ch[x][k^1];
    ch[y][k]=w,fa[w]=y;
    ch[z][ck(y)]=x,fa[x]=z;
    ch[x][k^1]=y,fa[y]=x;
}
int splay(int x,int goal=0){
    while(fa[x]!=goal){
        int y=fa[x],z=fa[y];
        if(z!=goal)
            ck(x)^ck(y)?rotate(x):rotate(y);
        rotate(x);
    }
    if(!goal)
        root=x;
}
int insertt(int x){
    int cur=root,p=0;
    while(cur&&x!=val[cur]){
        p=cur;
        cur=ch[cur][x>val[cur]];
    }
    if(cur){
        cnt[cur]++;
        splay(cur);
        return 1;
    }
    else{
        cur=++ncnt;
        if(p)
            ch[p][x>val[p]]=cur;
        ch[cur][0]=ch[cur][1]=0;
        cnt[cur]=1;
        val[cur]=x;
        fa[cur]=p;
    }
    splay(cur);
    return 0;
}
void find(int x){
    if(!root)
        return ;
    int cur=root;
    while(ch[cur][x>val[cur]]&&x!=val[cur]){
        cur=ch[cur][x>val[cur]];
    }
    splay(cur);
}
int pre(int x){
    find(x);
    if(val[root]<x)
        return root;
    int cur=ch[root][0];
    while(ch[cur][1])
        cur=ch[cur][1];
    return cur;
}
int succ(int x){
    find(x);
    if(val[root]>x)
        return root;
    int cur=ch[root][1];
    while(ch[cur][0])
        cur=ch[cur][0];
    return cur;
}
int main(){
    int n,ans=0;
    scanf("%d",&n);
    insertt(inf);
    insertt(-inf);
    for(int i=1;i<=n;i++){
        int x;
        scanf("%d",&x);
        if(i==1){
            ans+=x;
            insertt(x);
            continue;
        }
        int sign=insertt(x);
        if(sign==1)
            continue;
        int y=inf;
        int p=val[pre(x)],q=val[succ(x)];
        if(p!=-inf)
            y=min(y,x-p);
        if(q!=inf)
            y=min(y,q-x);
        ans+=y;
    }
    printf("%d",ans);
    return 0;
}
View Code

 http://poj.org/problem?id=2828

按排名插入值的写法

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int M=2e5+5;
int ch[M][2],fa[M],sz[M],cnt[M],val[M],root,ncnt;
int ck(int x){
    return ch[fa[x]][1]==x;
}
void push_up(int x){
    sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
}
void Rotate(int x){
    int y=fa[x],z=fa[y],k=ck(x);
    int w=ch[x][k^1];
    ch[y][k]=w,fa[w]=y;
    ch[z][ck(y)]=x,fa[x]=z;
    ch[x][k^1]=y,fa[y]=x;
    push_up(x);push_up(y);
}
void splay(int x,int goal=0){
    while(fa[x]!=goal){
        int y=fa[x],z=fa[y];
        if(z!=goal)
            ck(x)^ck(y)?Rotate(x):Rotate(y);
        Rotate(x);
    }
    if(!goal)
        root=x;
}
void newnode(int &cur,int f,int y){
    cur=++ncnt;
    fa[cur]=f;
    ch[cur][0]=ch[cur][1]=0;
    sz[cur]=cnt[cur]=1;
    val[cur]=y;
}
int Kth(int k){
    int cur=root;
    while(true){
        if(ch[cur][0]&&k<=sz[ch[cur][0]])
            cur=ch[cur][0];
        else if(k>sz[ch[cur][0]]+cnt[cur])
            k-=sz[ch[cur][0]]+cnt[cur],cur=ch[cur][1];
        else
            return cur;
    }
}
void insertt(int x,int y){
    if(!root){///当树为空时
        newnode(root,0,y);
        return ;
    }
    if(!x){///当要插入排名为首个时
        int p=root;
        sz[p]++;
        while(ch[p][0])
            p=ch[p][0],sz[p]++;
        newnode(ch[p][0],p,y);
        splay(ch[p][0]);
        return ;
    }
    int u=Kth(x);
    splay(u);
    newnode(root,0,y);
 /*   ch[root][0]=u,fa[u]=root;
    ch[u][1]=0;
    ch[root][1]=ch[u][1];
    fa[ch[root][1]]=root;
*/

    ch[root][1] = ch[u][1];
    fa[ch[root][1]] = root;
    ch[u][1] = 0;
    ch[root][0] = u;
    fa[u] = root;
    push_up(u),push_up(root);
    return ;
}
void print(int cur){
    if(ch[cur][0])
        print(ch[cur][0]);
    printf("%d ",val[cur]);
    if(ch[cur][1])
        print(ch[cur][1]);
}
int main(){
    int n;
    while(~scanf("%d",&n)){
        root=ncnt=0;
        for(int i=1;i<=n;i++){
            int x,y;
            scanf("%d%d",&x,&y);///x是要插入的排名,y是对应的值
            insertt(x,y);
        }
        print(root);
        printf("
");
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/starve/p/10883512.html