【封装】Splay

注意确保操作合法性,否则可能陷入死循环

以点权作为排序依据

struct Splay{   
    #define ls p[u].son[0]
    #define rs p[u].son[1]
    #define maxn 100000

    int root, cnt;
    struct Node{
        int val, fa, size, sum;
        int son[2]; 
    }p[maxn];
    
    inline void destroy(int u){
        p[u].val = ls = rs = p[u].fa = p[u].sum = p[u].size = 0;
    }

    inline int identify(int u){
        return p[p[u].fa].son[1] == u;
    }

    inline void update(int u){
        if(u)   p[u].sum = p[ls].sum + p[rs].sum + p[u].size;
    }

    void rotate(int u){
        int f = p[u].fa,  gf = p[f].fa,  sta = identify(u),  sta_f = identify(f);
        p[f].son[sta] = p[u].son[sta ^ 1];
        p[p[f].son[sta]].fa = f;
        p[u].son[sta^1] = f,  p[f].fa = u,  p[u].fa = gf;
        p[gf].son[sta_f] = u;
        update(f);
    }

    void splay(int u, int goal){
        for(int f; (f = p[u].fa) && (f != goal); rotate(u))
            if(p[f].fa != goal)  rotate(identify(u) == identify(f) ? f : u);
        if(!goal)  root = u;
        update(u);
    }

    void insert(int u){     // 函数结束后权值为u的节点变为根节点
        if(!root){
            p[++cnt].val = u;
            p[cnt].size = p[cnt].sum = 1;
            root = cnt;
            return ;
        }
        int now = root,  f = 0;
        while(true){
            if(u == p[now].val){
                ++p[now].size;
                splay(now, 0);
                return ;
            }
            f = now,  now = p[now].son[p[now].val < u];
            if(!now){
                p[++cnt].val = u;
                p[cnt].size = p[cnt].sum = 1;
                p[cnt].fa = f,  p[f].son[p[f].val < u] = cnt;
                ++p[f].sum;
                splay(cnt, 0);
                return ;
            }
        }
    }

    int find_val(int rank){
        int now = root;
        while(true){
            if(p[now].son[0] && rank <= p[p[now].son[0]].sum)
                now = p[now].son[0];
            else{
                int temp = p[p[now].son[0]].sum + p[now].size;
                if(rank <= temp)   return p[now].val;
                now = p[now].son[1],  rank -= temp;
            }
        }
    }

    int find_rank(int u){     // 函数结束后权值为u的节点是根节点
        int now = root,  rank = 0;
        while(true){
            if(u < p[now].val)    now = p[now].son[0];
            else{
                rank += p[p[now].son[0]].sum;
                if(u == p[now].val){
                    splay(now, 0);
                    return rank + 1;
                }
                rank += p[now].size,  now = p[now].son[1];
            }
        }
    }

    int find_pre(int x){    // 返回x前驱节点编号
        insert(x);
        int now = p[root].son[0];
        while(p[now].son[1])  now = p[now].son[1];
        delete_val(x);
        return now;
    }

    int find_suffix(int x){     // 返回x后继节点编号
        insert(x);
        int now = p[root].son[1];
        while(p[now].son[0])  now = p[now].son[0];
        delete_val(x);
        return now;
    }

    void delete_val(int u){
        find_rank(u);   // 将权值为u的节点旋转到根节点
        if(p[root].size > 1){
            --p[root].size,  --p[root].sum;
            return ;
        }
        if(!p[root].son[0] && !p[root].son[1]){
            destroy(root),  root = 0;
            return ;
        }
        int old_root = root;
        if(!p[root].son[0]){
            root = p[root].son[1];
            p[root].fa = 0;
            destroy(old_root);
            return ;
        }
        if(!p[root].son[1]){
            root = p[root].son[0];
            p[root].fa = 0;
            destroy(old_root);
            return ;
        }
        int left_max = find_pre(u);
        splay(left_max, 0);
        p[root].son[1] = p[old_root].son[1];
        p[p[old_root].son[1]].fa = root;
        destroy(old_root);
        update(root);
    }
}splay;

以序列下标作为排序依据,常用于需要增删序列的操作(区间平衡树)

struct Splay{   
    #define ls p[u].son[0]
    #define rs p[u].son[1]
    #define maxn N
    static const int inf = 1e9;

	// 因为有虚点直接初始化了 
    int root = 1, top = 0, temp = 5e5 + 45;
    int id[N], c[N], cnt[N];
    
    struct Node{
        int fa, size, len, reset, sum;	// size是子树大小,len是除去虚点的子树大小
        int tag, val, l_max, r_max, mid_max;
        int son[2];
        Node() { reset = inf; } 
    }p[maxn];
	
    inline int identify(int u){
        return p[p[u].fa].son[1] == u;
    }
	
	// 空间回收 
	void destroy(int u){
		if(!u)  return ;
		if(ls)  destroy(ls);
		if(rs)  destroy(rs);
		p[u] = p[temp];
		id[++top] = u; 
	}
	
    inline void update(int u){
        p[u].size = p[ls].size + p[rs].size + 1;
	    p[u].len = p[ls].len + p[rs].len + (u > 2);  // 判断u > 2是为了除去虚点的影响 
        p[u].sum = p[ls].sum + p[rs].sum + p[u].val;
        p[u].l_max = max(p[ls].l_max, p[ls].sum + p[u].val + p[rs].l_max);
        p[u].r_max = max(p[rs].r_max, p[rs].sum + p[u].val + p[ls].r_max);
        p[u].mid_max = max(p[u].val + p[ls].r_max + p[rs].l_max, max(p[ls].mid_max, p[rs].mid_max));
    }
	
	void change(int u, int val){
		p[u].val = p[u].reset = val;
		p[u].sum = p[u].val * p[u].len;
		p[u].l_max = p[u].r_max = max(0, p[u].sum);
		p[u].mid_max = max(val, p[u].sum);	
	}
	
    inline void pushdown(int u){
        if(p[u].reset != inf){
        	if(ls)  change(ls, p[u].reset);
        	if(rs)  change(rs, p[u].reset);
			p[u].reset = inf,  p[u].tag = 0;
        }
        if(p[u].tag){
            if(ls)  p[ls].tag ^= 1, swap(p[ls].son[0], p[ls].son[1]), swap(p[ls].l_max, p[ls].r_max);
            if(rs)  p[rs].tag ^= 1, swap(p[rs].son[0], p[rs].son[1]), swap(p[rs].l_max, p[rs].r_max);
			p[u].tag = 0;
        }
    }

    void rotate(int u){
        int f = p[u].fa,  gf = p[f].fa,  sta = identify(u),  sta_f = identify(f);
        p[f].son[sta] = p[u].son[sta ^ 1];
        p[p[f].son[sta]].fa = f;
        p[u].son[sta^1] = f,  p[f].fa = u,  p[u].fa = gf;
        p[gf].son[sta_f] = u;
        update(f);
    }

    void splay(int u, int goal){
        for(int f; (f = p[u].fa) && (f != goal); rotate(u)){
	    	if(p[f].fa != goal)  rotate(identify(u) == identify(f) ? f : u);
        }
	    if(!goal)  root = u; 
        update(u);
    }

    int find_Kth(int k){
        int u = root;
        while(1){
            pushdown(u);
            if(p[ls].size + 1 == k)  return u;
            if(p[ls].size >= k)  u = ls;
            else  k -= p[ls].size + 1,  u = rs;
        }
    }
    
    int build(int l, int r, int fa){
        if(l > r)  return 0;
		int mid = (l + r) >> 1, now = cnt[mid];
        if(l == r){
            p[now].val = c[l];
            p[now].fa =  fa;
            update(now);
            p[now].mid_max = p[now].val;
            return now;
        }
        p[now].fa = fa,  p[now].val = c[mid];
        p[now].son[0] = build(l, mid - 1, now);
        p[now].son[1] = build(mid + 1, r, now);
        update(now);
        return now;
    }

    // 在第u个位置后面插入值插入tot个数
    void insert(int u, int tot){
        for(int i = 1; i <= tot; ++i)  scanf("%d", &c[i]), cnt[i] = id[top--];
	    int rt = build(1, tot, 0);
        int L = find_Kth(u),  R = find_Kth(u + 1);
        splay(L, 0),  splay(R, L);
        p[rt].fa = R,  p[R].son[0] = rt;
	    splay(rt, 0);
    }
    
    // 区间删除 
    void delete_range(int pos, int tot){
        int L = find_Kth(pos),  R = find_Kth(pos + tot + 1);
        splay(L, 0),  splay(R, L);
		destroy(p[R].son[0]);
		p[R].son[0] = 0;
    	update(R),  update(L);
    }
	
	// 区间修改 
    void modify_range(int pos, int tot, int set){
        int L = find_Kth(pos),  R = find_Kth(pos + tot + 1);
        splay(L, 0),  splay(R, L);
        int u = p[R].son[0];
        p[u].reset = p[u].val = set;
        p[u].sum = p[u].len * set;
        p[u].l_max = p[u].r_max = max(0, p[u].sum);
        p[u].mid_max = max(p[u].sum, set);
		update(R), update(L);
    }

	// 翻转区间 
    void reverse(int pos, int tot){
        int L = find_Kth(pos),  R = find_Kth(pos + tot + 1);
        splay(L, 0),  splay(R, L);
        int u = p[R].son[0];
        p[u].tag ^= 1;
        swap(ls, rs);
        swap(p[u].l_max, p[u].r_max);
        update(R), update(L);
    }
 
    int get_sum(int pos, int tot){
        int L = find_Kth(pos),  R = find_Kth(pos + tot + 1);
        splay(L, 0),  splay(R, L);
        return p[p[R].son[0]].sum;
    }
    
	// 非空子段最大值	
    int Max(){
        return p[root].mid_max;
    }
    
    // 插入虚点,之后操作要注意下标 
    void init(){
    	for(int i = 3; i <= N - 45; ++i)  id[++top] = i; 
        p[1].son[1] = 2,  p[2].fa = 1;
	    p[1].size = 2,  p[2].size = 1;
	    // 虚点赋值负无穷,消除影响,根据不同题而定 
	    p[0].mid_max = p[1].mid_max = p[2].mid_max = p[0].val = p[1].val = p[2].val = -inf;
    }
}splay;
你只有十分努力,才能看上去毫不费力。
原文地址:https://www.cnblogs.com/214txdy/p/14023952.html