AcWing 955. 维护数列(splay插入,删除,区间修改,区间翻转,区间求和,区间求最大子段和)

题目链接

解题思路

  模板题，考察 Splay 的基础操作：插入、删除、区间修改（区间赋值）、区间翻转、区间求和以及区间求最大子段和。

// Capacity of the node pool (tr/q) and of the input buffer a.
const int maxn = 1000010;
// n: length of the initial sequence, m: number of operations.
int n, m;
// Input buffer: holds the initial sequence in a[1..n] (sentinels in a[0] and
// a[n+1]) and is reused for each INSERT batch in a[0..tot-1].
int a[maxn];
// One splay-tree node (index 0 acts as the null node).
//   s[0], s[1] : left / right child index (0 = no child)
//   v          : value stored at this node
//   p          : parent index
//   sz         : number of nodes in this subtree
//   sum        : sum of the values in this subtree
//   sm         : best sum of a non-empty contiguous segment in this subtree
//   lm, rm     : best prefix / suffix sum, allowed to be empty (so >= 0)
//   rev        : pending "reverse the children of my subtree" tag
//   same       : pending "every value in my subtree equals v" tag
struct Node {
    int s[2], v, p, sz;
    int sum, sm, lm, rm;
    int rev, same;
    // Turn this slot into a fresh leaf holding _v under parent _p.
    // Slots are recycled through the free pool, so every field is reset
    // explicitly instead of relying on zero-initialisation.
    void init(int _v, int _p) {
        v = _v;
        p = _p;
        sz = 1;
        sum = _v;
        sm = _v;
        const int clipped = max(_v, 0);
        lm = clipped;
        rm = clipped;
        s[0] = 0;
        s[1] = 0;
        rev = 0;
        same = 0;
    }
} tr[maxn];
// Index of the current splay root.
int rt;
// q[1..tt] is a stack of free node indices; build() pops from it and rec()
// pushes deleted nodes back.
int q[maxn];
int tt;
inline void push_up(int u) {
    Node &now = tr[u];
    Node ls = tr[now.s[0]];
    Node rs = tr[now.s[1]];    
    now.sz = ls.sz+rs.sz+1;
    now.sum = ls.sum+rs.sum+now.v;
    now.lm = max(ls.lm, ls.sum+rs.lm+now.v);
    now.rm = max(rs.rm, rs.sum+ls.rm+now.v);
    now.sm = max({ls.sm, rs.sm, ls.rm+now.v+rs.lm});
}
inline void push_down(int u) {
    Node &now = tr[u];
    Node &ls = tr[now.s[0]];
    Node &rs = tr[now.s[1]];
    if (now.same) {
        now.same = now.rev = 0;
        //如果u是叶子就不能更新了,也就是看u有没有儿子
        if (now.s[0]) {
            ls.same = 1; 
            ls.v = now.v;
            ls.sum = ls.sz*now.v;
            if (now.v>0) ls.sm = ls.lm = ls.rm = ls.sum;
            //最大子段和至少要有一个数,所以ls.sm至少为now.v
            else ls.sm = now.v, ls.lm = ls.rm = 0; 
        }
        if (now.s[1]) {
            rs.same = 1; 
            rs.v = now.v;
            rs.sum = rs.sz*now.v;
            if (now.v>0) rs.sm = rs.lm = rs.rm = rs.sum;
            else rs.sm = now.v, rs.lm = rs.rm = 0;
        }
    }
    else if (now.rev) {
        ls.rev ^= 1;
        rs.rev ^= 1;
        swap(ls.s[0], ls.s[1]);
        swap(rs.s[0], rs.s[1]);
        swap(ls.lm, ls.rm);
        swap(rs.lm, rs.rm);
        now.rev = 0;
    }
}

// Build a perfectly balanced subtree over a[l..r] (inclusive), taking node
// indices from the free pool q, and return its root index.
// p is stored as the parent of the returned root.
// Returns 0 (the null node) for an empty range (l > r).
int build(int l, int r, int p) {
    // Without this guard an empty range would compute mid outside [l, r]
    // (e.g. a[-1] for build(0, -1, ...)) and allocate a junk node.
    if (l > r) return 0;
    int mid = (l + r) >> 1;
    int u = q[tt--];            // take a free node index from the pool
    tr[u].init(a[mid], p);
    if (l < mid) tr[u].s[0] = build(l, mid - 1, u);
    if (r > mid) tr[u].s[1] = build(mid + 1, r, u);
    push_up(u);
    return u;
}
// Single rotation lifting x above its parent y, preserving in-order order.
// When y is the root or x has no inner subtree, the writes below land on the
// null node tr[0]; its link fields are never followed, so that is harmless.
void rotate(int x) {
    const int y = tr[x].p;              // parent
    const int z = tr[y].p;              // grandparent
    const int k = (tr[y].s[1] == x);    // side of y on which x hangs
    const int inner = tr[x].s[k ^ 1];   // x's subtree that moves over to y

    // x takes y's place under z.
    tr[z].s[tr[z].s[1] == y] = x;
    tr[x].p = z;

    // x's inner subtree becomes y's child on x's old side.
    tr[y].s[k] = inner;
    tr[inner].p = y;

    // y becomes x's child on the opposite side.
    tr[x].s[k ^ 1] = y;
    tr[y].p = x;

    // y is now below x, so refresh it first.
    push_up(y);
    push_up(x);
}
// Rotate x upward until its parent is k; k == 0 makes x the new root, which
// is then recorded in rt.
void splay(int x, int k) {
    while (tr[x].p != k) {
        const int y = tr[x].p;
        const int z = tr[y].p;
        if (z != k) {
            // Zig-zag (x and y on opposite sides): rotate x twice.
            // Zig-zig (same side): rotate y first, then x.
            const bool zigzag = (tr[y].s[1] == x) != (tr[z].s[1] == y);
            rotate(zigzag ? x : y);
        }
        rotate(x);
    }
    if (k == 0) rt = x;
}
// Return the index of the k-th node in in-order (1-based, the leading
// sentinel counts as rank 1), pushing lazy tags down along the search path.
// Returns 0 when no node has rank k.
int get_k(int k) {
    for (int u = rt; u != 0; ) {
        push_down(u);
        const int left = tr[tr[u].s[0]].sz;
        if (k <= left) {
            u = tr[u].s[0];
        } else if (k == left + 1) {
            return u;
        } else {
            k -= left + 1;
            u = tr[u].s[1];
        }
    }
    return 0;
}
// Return every node of the subtree rooted at u to the free pool q[1..tt].
// Iterative on purpose: a splay subtree can be as deep as it is large, so a
// recursive walk may overflow the call stack when a long range is deleted.
// The pool doubles as the BFS queue: indices are appended at q[tt+1..] and
// scanned in order, appending their children behind them. Since every node is
// either live or in the pool, tt never exceeds the pool capacity.
void rec(int u) {
    if (!u) return;
    int head = tt + 1;
    q[++tt] = u;
    while (head <= tt) {
        const int v = q[head++];
        if (tr[v].s[0]) q[++tt] = tr[v].s[0];
        if (tr[v].s[1]) q[++tt] = tr[v].s[1];
    }
}
// Isolate the 1-based sequence range [pos, pos+tot-1] as one subtree.
// Because of the leading sentinel, rank pos is the element just before the
// range and rank pos+tot+1 the element just after it. Splaying the former to
// the root (so rt becomes it) and the latter directly beneath it leaves the
// range as the left subtree of the returned node (index 0 when tot == 0).
static int isolate_range(int pos, int tot) {
    int L = get_k(pos), R = get_k(pos + tot + 1);
    splay(L, 0), splay(R, L);
    return R;
}

int main() {
    IOS;
    // Every index except 0 starts in the free pool; index 0 is the null node.
    for (int i = 1; i < maxn; ++i) q[++tt] = i;
    cin >> n >> m;
    for (int i = 1; i <= n; ++i) cin >> a[i];
    // a[0] and a[n+1] become sentinel nodes at both ends so any range can be
    // isolated between two real nodes; tr[0].sm = -INF stops a missing child
    // from winning the maximum-segment comparison in push_up.
    // NOTE(review): the root's sum includes both sentinels (about -2*INF);
    // depending on INF's definition this may overflow int - confirm.
    a[0] = a[n+1] = tr[0].sm = -INF;
    rt = build(0, n + 1, 0);
    while (m--) {
        char op[20];
        int pos, tot;
        cin >> op;
        if (!strcmp(op, "INSERT")) {
            cin >> pos >> tot;
            // An empty insert has nothing to splice; building from an empty
            // range would read a[-1], so skip it entirely.
            if (tot > 0) {
                // Isolate the empty gap right after the pos-th number.
                int R = isolate_range(pos + 1, 0);
                for (int i = 0; i < tot; ++i) cin >> a[i];
                tr[R].s[0] = build(0, tot - 1, R);
                push_up(R), push_up(rt);
            }
        }
        else if (!strcmp(op, "DELETE")) {
            cin >> pos >> tot;
            int R = isolate_range(pos, tot);
            rec(tr[R].s[0]);            // recycle the removed nodes
            tr[R].s[0] = 0;
            push_up(R), push_up(rt);
        }
        else if (!strcmp(op, "MAKE-SAME")) {
            int c;
            cin >> pos >> tot >> c;
            int R = isolate_range(pos, tot);
            // An empty range has no subtree: tagging index 0 would overwrite
            // the null node's sm = -INF and corrupt every later push_up.
            if (int u = tr[R].s[0]) {
                Node &now = tr[u];
                now.same = 1;
                now.v = c;
                now.sum = now.sz * c;
                if (c > 0) {
                    now.sm = now.lm = now.rm = now.sum;
                } else {
                    // A segment must be non-empty, so the best is one element.
                    now.sm = c;
                    now.lm = now.rm = 0;
                }
                push_up(R), push_up(rt);
            }
        }
        else if (!strcmp(op, "REVERSE")) {
            cin >> pos >> tot;
            int R = isolate_range(pos, tot);
            if (int u = tr[R].s[0]) {
                Node &now = tr[u];
                // Apply the reversal to this node now; children get it lazily.
                now.rev ^= 1;
                swap(now.s[0], now.s[1]);
                swap(now.lm, now.rm);
                push_up(R), push_up(rt);
            }
        }
        else if (!strcmp(op, "GET-SUM")) {
            cin >> pos >> tot;
            int R = isolate_range(pos, tot);
            // For tot == 0 this reads the null node, whose sum is 0.
            cout << tr[tr[R].s[0]].sum << '\n';
        }
        else if (!strcmp(op, "MAX-SUM")) {
            cout << tr[rt].sm << '\n';
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/shuitiangong/p/15093152.html