splay模板整理

1.插入一个数

// Inserts value x into the tree and splays the touched node to the root.
// Index 0 is the null node in this routine. A value that is already stored
// only has its multiplicity (cnt) increased; otherwise a new leaf is created.
void insert(int x)
{
    // Empty tree: the new node becomes the root.
    if (!root)
    {
        ++tot;
        e[tot].left = e[tot].right = e[tot].fa = 0;
        e[tot].v = x;
        e[tot].cnt = sizee[tot] = 1;
        root = tot;
        return;
    }
    int now = root,fa = 0;
    while (1)
    {
        if (e[now].v == x)
        {
            e[now].cnt++;
            update(now);
            // fa is still 0 here when the match is the root itself.
            update(fa);
            // splay() takes the target parent explicitly; 0 means "lift to the root".
            splay(now,0);
            break;
        }
        fa = now;
        // Smaller-or-equal keys go left, larger keys go right.
        if (e[now].v < x)
            now = e[now].right;
        else
            now = e[now].left;
        if (!now)
        {
            // Fell off the tree: attach a fresh leaf below fa.
            ++tot;
            e[tot].left = e[tot].right = 0;
            e[tot].fa = fa;
            e[tot].v = x;
            e[tot].cnt = sizee[tot] = 1;
            if (e[fa].v < x)
                e[fa].right = tot;
            else
                e[fa].left = tot;
            update(fa);
            splay(tot,0);
            break;
        }
    }
}

2.splay

// Rotates node x upward until its parent is yy; yy == 0 makes x the root
// and records it in `root`. Index 0 is the null node.
// Before every step the pending lazy tags of the grandparent, parent and x
// are pushed down in that order, so the routine does not depend on the
// caller having cleaned the path beforehand.
void splay(int x,int yy)
{
    while (e[x].fa != yy)
    {
        int y = e[x].fa;
        int z = e[y].fa;
        // Top-down order matters: an ancestor's tag has to reach y before
        // y's own tag is handed on to x.
        if (z != 0)
            pushdown(z);
        pushdown(y);
        pushdown(x);
        if (z == yy || z == 0)
        {
            // Single step: y is the last node between x and the target.
            if (x == e[y].left)
                turnr(x);
            else
                turnl(x);
        }
        else if (e[z].left == y && e[y].left == x)
        {
            // Left-left chain: rotate the parent first, then x.
            turnr(y);
            turnr(x);
        }
        else if (e[z].right == y && e[y].right == x)
        {
            // Right-right chain.
            turnl(y);
            turnl(x);
        }
        else if (e[z].left == y && e[y].right == x)
        {
            // Left-right bend: rotate x twice.
            turnl(x);
            turnr(x);
        }
        else
        {
            // Right-left bend.
            turnr(x);
            turnl(x);
        }
    }
    if (yy == 0)
        root = x;
}

3.左右旋

// Right rotation: lifts x (the left child of its parent y) into y's place.
// y becomes x's right child and x's former right subtree becomes y's left
// subtree. Index 0 is the null node.
void turnr(int x)
{
    pushdown(x);
    int y = e[x].fa;
    int z = e[y].fa;
    // Hand x's right subtree over to y.
    e[y].left = e[x].right;
    if (e[x].right != 0)
        e[e[x].right].fa = y;
    // Hook x under the old grandparent, if there is one.
    e[x].fa = z;
    if (z != 0)
    {
        if (e[z].left == y)
            e[z].left = x;
        else
            e[z].right = x;
    }
    e[x].right = y;
    e[y].fa = x;
    // y now hangs below x, so y has to be recomputed before x;
    // the reverse order would leave x built from y's old aggregate.
    update(y);
    update(x);
}

// Left rotation: lifts x (the right child of its parent y) into y's place.
// y becomes x's left child and x's former left subtree becomes y's right
// subtree. Index 0 is the null node.
void turnl(int x)
{
    pushdown(x);
    int y = e[x].fa;
    int z = e[y].fa;
    // Hand x's left subtree over to y.
    e[y].right = e[x].left;
    if (e[x].left != 0)
        e[e[x].left].fa = y;
    // Hook x under the old grandparent, if there is one.
    e[x].fa = z;
    if (z != 0)
    {
        if (e[z].left == y)
            e[z].left = x;
        else
            e[z].right = x;
    }
    e[x].left = y;
    e[y].fa = x;
    // y now hangs below x, so y has to be recomputed before x;
    // the reverse order would leave x built from y's old aggregate.
    update(y);
    update(x);
}

4.找前驱/后继(若点已经在splay里面)

// Returns the right-most node of x's left subtree, or -1 when x has no left
// child. Null links are -1 in this routine.
// NOTE(review): this is the in-order predecessor only if no ancestor of x
// qualifies, e.g. when x is the root - confirm callers splay x first.
int findl(int x)
{
    int cur = e[x].left;
    if (cur != -1)
    {
        // Keep stepping right for as long as a right child exists.
        for (int nxt = e[cur].right; nxt != -1; nxt = e[cur].right)
            cur = nxt;
    }
    return cur;
}

// Returns the left-most node of x's right subtree, or -1 when x has no right
// child. Null links are -1 in this routine.
// NOTE(review): this is the in-order successor only if no ancestor of x
// qualifies, e.g. when x is the root - confirm callers splay x first.
int findr(int x)
{
    int cur = e[x].right;
    if (cur != -1)
    {
        // Keep stepping left for as long as a left child exists.
        for (int nxt = e[cur].left; nxt != -1; nxt = e[cur].left)
            cur = nxt;
    }
    return cur;
}

若点不在splay里面(按权值x查找,结果写入全局变量temp1/v1与temp2/v2):

// Searches by value for the node holding the largest key strictly below x.
// The result goes to the globals: temp1 receives the node index (-1 if no
// such node) and v1 the distance x - key (0x7fffffff if none).
// Null links are -1 in this routine.
void findl(int x)
{
    temp1 = -1;
    v1 = 0x7fffffff;
    for (int cur = root; cur != -1; )
    {
        if (e[cur].v < x)
        {
            // Candidate on the low side: keep it if it is closer than the best so far.
            if ((x - e[cur].v) < v1)
            {
                temp1 = cur;
                v1 = x - e[cur].v;
            }
            cur = e[cur].right;
        }
        else
            cur = e[cur].left;
    }
}

// Searches by value for the node holding the smallest key strictly above x.
// The result goes to the globals: temp2 receives the node index (-1 if no
// such node) and v2 the distance key - x (0x7fffffff if none).
// Null links are -1 in this routine.
void findr(int x)
{
    int now = root;
    temp2 = -1;
    v2 = 0x7fffffff;
    while (now != -1)
    {
        if (e[now].v > x)
        {
            // Candidate on the high side: keep it if it is closer than the best so far.
            if ((e[now].v - x) < v2)
            {
                temp2 = now;
                v2 = e[now].v - x;
            }
            now = e[now].left;
        }
        else
        {
            // Key <= x: anything strictly larger can only be to the right.
            // Going left on an exact match would miss a successor stored in
            // the right subtree when x itself is present in the tree.
            now = e[now].right;
        }
    }
}

5.删除

// Removes one occurrence of value x. Null links are -1 in this routine.
// A node whose multiplicity is above one is only decremented; otherwise the
// node is unlinked and its subtrees are joined.
// NOTE(review): find(int) and the one-argument splay(int) used here are not
// defined in the visible part of this file.
void del(int x)
{
    int p = find(x);
    // NOTE(review): assumes find() yields -1 for a value that is not stored - confirm.
    // Without this guard a miss would index e[-1] below.
    if (p == -1)
        return;
    splay(p);
    if (e[p].cnt > 1)
    {
        e[p].cnt--;
        update(p);
        return;
    }
    // p is the only node: the tree becomes empty.
    if (e[p].left == -1 && e[p].right == -1)
    {
        root = -1;
        return;
    }
    // Only one child: that child takes over as root.
    if (e[p].left == -1)
    {
        root = e[p].right;
        e[e[p].right].fa = -1;
        return;
    }
    if (e[p].right == -1)
    {
        root = e[p].left;
        e[e[p].left].fa = -1;
        return;
    }
    // Two children: bring the largest node of the left subtree to the top
    // and hang p's right subtree directly below it, bypassing p.
    int j = e[p].left;
    while (e[j].right != -1)
        j = e[j].right;
    splay(j);
    e[j].right = e[p].right;
    e[e[p].right].fa = j;
    update(j);
}

6.查询数x的排名

// Rank lookup by value. Returns (number of stored elements smaller than x) + 1
// when x is present; when x is absent it returns just the number of stored
// elements smaller than x. Null links are -1 in this routine.
int query(int x)
{
    int below = 0;
    for (int cur = root; cur != -1; )
    {
        if (x < e[cur].v)
        {
            cur = e[cur].left;
            continue;
        }
        // Everything in the left subtree is smaller than x.
        below += getsize(e[cur].left);
        if (x == e[cur].v)
            return below + 1;
        // This node's copies are smaller as well; continue on the right.
        below += e[cur].cnt;
        cur = e[cur].right;
    }
    return below;
}

7.查询排名为x的数

// Returns the value whose rank is x (1-based, duplicates counted through cnt).
// Null links are -1 in this routine; when x exceeds the number of stored
// elements the walk runs off the tree and -1 is returned.
int query2(int x)
{
    int p = root;
    while (p != -1)
    {
        // The child must be compared against -1, the same null marker the
        // loop uses. A truthiness test would accept -1 as "present" and
        // read sizee[-1].
        if (e[p].left != -1 && x <= sizee[e[p].left])
            p = e[p].left;
        else
        {
            // Ranks covered by the left subtree plus this node's copies.
            int temp = getsize(e[p].left) + e[p].cnt;
            if (x <= temp)
                return e[p].v;
            x -= temp;
            p = e[p].right;
        }
    }
    return p;
}

8.各种区间操作(bzoj1500代码):

#include <stack>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

// Capacity of the node pool and of the input buffer.
const int maxn = 500010;
// Magnitude used as "minus infinity" for the maximum fields.
const int inf = 707406378;
// Indices of released nodes, handed out again by build().
std::stack<int> s;
// n: initial sequence length, m: number of commands.
int n, m;
// Scratch buffer holding the values that build() turns into a subtree.
int a[maxn];
// root: index of the tree root, tot: highest node index handed out so far.
int root, tot;
// Subtree sizes, indexed by node.
int sizee[maxn];

// One tree node. Links are indices into e[]; index 0 stands for "no node".
struct node
{
    int fa;     // parent
    int left;   // left child
    int right;  // right child
    int v;      // value stored at this node
    int sum;    // sum of the values in the subtree
    int ans;    // best sum of a contiguous run inside the subtree
    int lmax;   // best sum of a run starting at the subtree's left end
    int rmax;   // best sum of a run ending at the subtree's right end
    int tag;    // pending "reverse" flag for the children
    int cnt;    // pending "assign v" flag for the children
}e[maxn];

void pushup(int x)
{
    sizee[x] = 1 + sizee[e[x].left] + sizee[e[x].right];
    e[x].sum = e[e[x].left].sum + e[e[x].right].sum + e[x].v;
    e[x].lmax = max(e[e[x].left].lmax,e[e[x].left].sum + e[x].v + max(0,e[e[x].right].lmax));
    e[x].rmax = max(e[e[x].right].rmax,e[e[x].right].sum + e[x].v + max(0,e[e[x].left].rmax));
    e[x].ans = max(e[e[x].left].ans,e[e[x].right].ans);
    e[x].ans = max(e[x].ans,max(e[e[x].left].rmax,0) + e[x].v + max(0,e[e[x].right].lmax));
}

void fan(int x)
{
    int t = e[x].lmax;
    e[x].lmax = e[x].rmax;
    e[x].rmax = t;
    t = e[x].left;
    e[x].left = e[x].right;
    e[x].right = t;
    e[x].tag ^= 1;
}

// Assigns value y to every position in the subtree rooted at x: the node's
// own aggregates are rewritten immediately and the assign flag (cnt) is set
// so the children are updated on the next pushdown.
// x == 0 (empty range) is ignored; node 0 is the shared null child whose
// -inf maxima must stay intact for pushup().
void cover(int x,int y)
{
    if (!x)
        return;
    e[x].v = y;
    e[x].sum = sizee[x] * y;
    // With a non-positive value the best non-empty run is a single element;
    // otherwise it is the whole subtree.
    if (y <= 0)
        e[x].lmax = e[x].rmax = e[x].ans = y;
    else
        e[x].lmax = e[x].rmax = e[x].ans = y * sizee[x];
    e[x].cnt = 1;
}

void pushdown(int x)
{
    if (e[x].tag)
    {
        if (e[x].left)
            fan(e[x].left);
        if (e[x].right)
            fan(e[x].right);
        e[x].tag = 0;
    }
    if (e[x].cnt)
    {
        if (e[x].left)
            cover(e[x].left,e[x].v);
        if (e[x].right)
            cover(e[x].right,e[x].v);
        e[x].cnt = 0;
    }
}

void build(int l,int r,int &x,int y)
{
    if (l > r)
        return;
    if (!x)
    {
        if (!s.empty())
        {
            x = s.top();
            s.pop();
        }
        else
            x = ++tot;
    }
    int mid = (l + r) >> 1;
    e[x].v = a[mid];
    e[x].fa = y;
    build(l,mid - 1,e[x].left,x);
    build(mid + 1,r,e[x].right,x);
    pushup(x);
}

// Right rotation: lifts x (the left child of its parent y) into y's place.
// y becomes x's right child and x's former right subtree becomes y's left
// subtree. Index 0 is the null node.
void turnr(int x)
{
    pushdown(x);
    int y = e[x].fa;
    int z = e[y].fa;
    // Hand x's right subtree over to y.
    e[y].left = e[x].right;
    if (e[x].right != 0)
        e[e[x].right].fa = y;
    // Hook x under the old grandparent, if there is one.
    e[x].fa = z;
    if (z != 0)
    {
        if (e[z].left == y)
            e[z].left = x;
        else
            e[z].right = x;
    }
    e[x].right = y;
    e[y].fa = x;
    // y now hangs below x, so y has to be recomputed before x;
    // the reverse order would leave x built from y's old aggregate.
    pushup(y);
    pushup(x);
}

// Left rotation: lifts x (the right child of its parent y) into y's place.
// y becomes x's left child and x's former left subtree becomes y's right
// subtree. Index 0 is the null node.
void turnl(int x)
{
    pushdown(x);
    int y = e[x].fa;
    int z = e[y].fa;
    // Hand x's left subtree over to y.
    e[y].right = e[x].left;
    if (e[x].left != 0)
        e[e[x].left].fa = y;
    // Hook x under the old grandparent, if there is one.
    e[x].fa = z;
    if (z != 0)
    {
        if (e[z].left == y)
            e[z].left = x;
        else
            e[z].right = x;
    }
    e[x].left = y;
    e[y].fa = x;
    // y now hangs below x, so y has to be recomputed before x;
    // the reverse order would leave x built from y's old aggregate.
    pushup(y);
    pushup(x);
}

// Rotates node x upward until its parent is yy; yy == 0 makes x the root
// and records it in `root`. x's aggregates are rebuilt once more at the end.
// Before every step the pending flags of the grandparent, parent and x are
// pushed down in that order. When the caller has already walked the path
// with find() these calls do nothing, but the routine no longer depends on it.
void splay(int x,int yy)
{
    while (e[x].fa != yy)
    {
        int y = e[x].fa;
        int z = e[y].fa;
        // Top-down order matters: an ancestor's flag has to reach y before
        // y's own flag is handed on to x.
        if (z != 0)
            pushdown(z);
        pushdown(y);
        pushdown(x);
        if (z == 0 || z == yy)
        {
            // Single step: y is the last node between x and the target.
            if (e[y].left == x)
                turnr(x);
            else
                turnl(x);
        }
        else if (e[z].left == y && e[y].left == x)
        {
            // Left-left chain: rotate the parent first, then x.
            turnr(y);
            turnr(x);
        }
        else if (e[z].right == y && e[y].right == x)
        {
            // Right-right chain.
            turnl(y);
            turnl(x);
        }
        else if (e[z].left == y && e[y].right == x)
        {
            // Left-right bend: rotate x twice.
            turnl(x);
            turnr(x);
        }
        else
        {
            // Right-left bend.
            turnr(x);
            turnl(x);
        }
    }
    if (yy == 0)
        root = x;
    pushup(x);
}

// Returns the index of the k-th node (1-based, in-order) of the subtree
// rooted at x, pushing pending flags down on every node it visits.
// Returns 0 when k is outside 1..size of the subtree.
// Written as a loop so the depth of a degenerate tree cannot exhaust the
// call stack, and so an out-of-range k terminates at the null node.
int find(int x,int k)
{
    while (x)
    {
        pushdown(x);
        const int here = sizee[e[x].left] + 1;   // position of x inside its subtree
        if (k == here)
            return x;
        if (k > here)
        {
            // Skip the left subtree and x itself.
            k -= here;
            x = e[x].right;
        }
        else
            x = e[x].left;
    }
    return 0;
}

// Releases the whole subtree rooted at x: every node index is pushed onto
// the recycle stack and the node is reset to the state of an unused slot
// (links, flags and sums zero, maxima at -inf). x == 0 is a no-op.
void del(int x)
{
    if (!x)
        return;
    s.push(x);
    del(e[x].left);
    del(e[x].right);
    sizee[x] = e[x].left = e[x].right = e[x].cnt = e[x].tag = e[x].fa = e[x].sum = e[x].v = 0;
    // node has no maxx member; only the three maxima that pushup() reads are reset.
    e[x].lmax = e[x].rmax = e[x].ans = -inf;
}

// Reads the initial sequence and m commands from stdin and maintains the
// sequence in a splay tree. The tree is built over a[0..n + 1], so one extra
// node sits in front of and one behind the real data; a range of real
// positions is isolated by splaying its two outside neighbours p and q, after
// which the range is the left subtree of q.
int main()
{
    // Unused slots (including the null node 0) must not win any max().
    for (int i = 0; i < maxn; i++)
        e[i].lmax = e[i].rmax = e[i].ans = -inf;
    scanf("%d%d",&n,&m);
    for (int i = 1; i <= n; i++)
        scanf("%d",&a[i]);
    build(0,n + 1,root,0);
    root = 1;
    for (int i = 1; i <= m; i++)
    {
        // Named op so it does not hide the global recycle stack s;
        // the width keeps the read inside the buffer.
        char op[15] = "";
        scanf("%14s",op);
        if (op[0] == 'I')
        {
            // Insert y new values after position x.
            int x,y;
            scanf("%d%d",&x,&y);
            for (int j = 1; j <= y; j++)
                scanf("%d",&a[j]);
            int p = find(root,x + 1),q = find(root,x + 2);
            splay(p,0);
            splay(q,p);
            build(1,y,e[q].left,q);
            pushup(q);
            pushup(p);
        }
        if (op[0] == 'D')
        {
            // Delete y values starting at position x.
            int x,y;
            scanf("%d%d",&x,&y);
            int p = find(root,x),q = find(root,x + y + 1);
            splay(p,0);
            splay(q,p);
            del(e[q].left);
            e[q].left = 0;
            pushup(q);
            pushup(p);
        }
        if (op[0] == 'M' && op[2] == 'K')
        {
            // Set y values starting at position x to z.
            int x,y,z;
            scanf("%d%d%d",&x,&y,&z);
            int p = find(root,x),q = find(root,x + y + 1);
            splay(p,0);
            splay(q,p);
            // An empty range has no subtree; never touch the null node 0.
            if (e[q].left)
                cover(e[q].left,z);
            pushup(q);
            pushup(p);
        }
        if (op[0] == 'R')
        {
            // Reverse y values starting at position x.
            int x,y;
            scanf("%d%d",&x,&y);
            int p = find(root,x),q = find(root,x + y + 1);
            splay(p,0);
            splay(q,p);
            // Always apply the reversal: fan() toggles the flag, so a range
            // that already carries a pending reversal is flipped back.
            // Skipping flagged nodes would silently drop every second
            // reversal of the same range.
            if (e[q].left)
                fan(e[q].left);
            pushup(q);
            pushup(p);
        }
        if (op[0] == 'G')
        {
            // Print the sum of y values starting at position x.
            int x,y;
            scanf("%d%d",&x,&y);
            int p = find(root,x),q = find(root,x + y + 1);
            splay(p,0);
            splay(q,p);
            printf("%d\n",e[e[q].left].sum);
        }
        if (op[0] == 'M' && op[2] == 'X')
        {
            // Print the best contiguous-run sum over the whole sequence:
            // isolate everything between the first and the last node.
            int p = find(root,1),q = find(root,sizee[root]);
            splay(p,0);
            splay(q,p);
            printf("%d\n",e[e[q].left].ans);
        }
    }

    return 0;
}
原文地址:https://www.cnblogs.com/zbtrs/p/8276402.html