【模板】Treap

Treap,又称树堆,是一种通过堆性质来维持BST平衡的数据结构。具体体现在对于树上每一个点来说,既有BST维护的值,又有一个堆维护的随机生成的值。维护平衡性的办法是根据堆维护的值的相对大小关系进行左旋和右旋这两种操作,在旋转的前后,依然满足BST性质

latest updated:2019.2.25
代码如下

#include <bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
const int inf=0x3f3f3f3f;

struct node{
    #define ls(p) t[p].lc
    #define rs(p) t[p].rc
    int lc,rc,val,rd,cnt,size;
}t[maxn];
int tot,root;
inline int newnode(int val){
    ++tot,t[tot].val=val,t[tot].rd=rand(),t[tot].size=t[tot].cnt=1;
    return tot;
}
inline void pushup(int p){
    t[p].size=t[ls(p)].size+t[rs(p)].size+t[p].cnt;
}
inline void zig(int &p){
    int lson=ls(p);
    ls(p)=rs(lson),rs(lson)=p,p=lson;
    pushup(rs(p)),pushup(p);
}
inline void zag(int &p){
    int rson=rs(p);
    rs(p)=ls(rson),ls(rson)=p,p=rson;
    pushup(ls(p)),pushup(p);
}
void insert(int &p,int val){
    if(!p)p=newnode(val);
    else if(val==t[p].val)++t[p].size,++t[p].cnt;
    else if(val<t[p].val){
        ++t[p].size,insert(ls(p),val);
        if(t[ls(p)].rd>t[p].rd)zig(p);
    }else{
        ++t[p].size,insert(rs(p),val);
        if(t[rs(p)].rd>t[p].rd)zag(p);
    }
}
void remove(int &p,int val){
    if(!p)return;
    else if(val<t[p].val)--t[p].size,remove(ls(p),val);
    else if(val>t[p].val)--t[p].size,remove(rs(p),val);
    else{
        if(t[p].cnt>1)--t[p].cnt,--t[p].size;
        else if(!ls(p)||!rs(p))p=ls(p)+rs(p);
        else if(t[ls(p)].rd>t[rs(p)].rd)zig(p),remove(p,val);
        else zag(p),remove(p,val);
    }
}
int kth(int p,int k){
    if(k<=t[ls(p)].size)return kth(ls(p),k);
    else if(k>t[ls(p)].size+t[p].cnt)return kth(rs(p),k-t[ls(p)].size-t[p].cnt);
    else return t[p].val;
}
int getrank(int p,int val){
    if(t[p].val==val)return t[ls(p)].size+1;
    else if(val<t[p].val)return getrank(ls(p),val);
    else return getrank(rs(p),val)+t[p].cnt+t[ls(p)].size;
}
int getpre(int p,int val){
	if(!p)return -inf;
	else if(t[p].val>=val)return getpre(ls(p),val);
	else return max(getpre(rs(p),val),t[p].val);
}
int getnxt(int p,int val){
	if(!p)return inf;
	else if(t[p].val<=val)return getnxt(rs(p),val);
	else return min(getnxt(ls(p),val),t[p].val);
}

int main(){
    int n;scanf("%d",&n);
    while(n--){
        int opt,val;
        scanf("%d%d",&opt,&val);
        switch(opt){
            case 1:insert(root,val);break;
            case 2:remove(root,val);break;
            case 3:printf("%d
",getrank(root,val));break;
            case 4:printf("%d
",kth(root,val));break;
            case 5:printf("%d
",getpre(root,val));break;
            case 6:printf("%d
",getnxt(root,val));break;
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/wzj-xhjbk/p/9939823.html