【NOI2005】维护数列

https://daniu.luogu.org/problem/show?pid=2042

一道伸展树维护数列的很悲伤的题目,共要维护两个标记和两个数列信息,为了维护MAX-SUM还要维护从左端开始的数列的最大和及到右端结束的数列的最大和。

按照伸展树的套路,给数列左右两边加上不存在的边界节点,给每个子树的空儿子指向哨兵节点。

维护最大子数列和

题目说的子数列其实要求至少包含一个元素,这就要很恶心的维护方法。

(其实让max_sum可以不含元素也能过90%)

每个节点定义max_sum:该节点的最大数列和(至少包含一个元素)

max_lsum:该节点的从左端开始的最大数列和(可以不包含元素)

max_rsum:该节点的到右端结束的最大数列和(可以不包含元素)

按照分冶法,max_sum=max{左儿子max_sum,右儿子max_sum,左儿子max_rsum+该节点的值+右儿子max_lsum}。

如果它和它的左右儿子都是普通节点,这个转移保证至少有一个元素。

如果它是普通节点或边界节点,它的左或右儿子是哨兵节点,则左儿子max_sum或右儿子max_sum是不可取的。故令哨兵节点的max_sum=-inf。

如果它是边界节点,它必定至多有一个儿子,令它的max_sum等于它的唯一儿子的max_sum,max_lsum与max_rsum同理。

覆盖子数列和翻转子数列

每个节点定义两个标记replaced和reversed。

replaced:这个节点及它的所有后代都应该修改为一个特定的值,但实际上只有这个节点的值已经修改。

reversed:这个节点及它的所有后代都应该交换左右子树(max_lsum和max_rsum也应该跟着交换),但实际上只有这个节点的左右子树已经交换。

可见这两个标记是互斥的,且replaced标记的优先级显然大于reversed标记。

打标记的时候注意维护每个结点的标记至多有一个就可以了。

 

350行的不压行代码,6.33KB,调了近8小时交了差不多二十遍才AC:

#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>
using namespace std;
void getstr(string &s)
{
    int c;
    s = "";
    while (!isalpha(c = getchar()))
    {
        if (c == EOF)
            return;
    }
    do
        s += (char)c;
    while (isalpha(c = getchar()) || c == '-');
}
void getint(int &x)
{
    int c;
    bool flag = false;
    x = 0;
    while (!isdigit(c = getchar()))
    {
        if (c == EOF)
            return;
        if (c == '-')
            flag = true;
    }
    do
        x = x * 10 + c - '0';
    while (isdigit(c = getchar()));
    if (flag)
        x = -x;
}
namespace splay
{
const int inf = 0x7fffffff;
enum direction
{
    l = 0,
    r
};
struct node;
node *nil = 0, *l_edge, *r_edge;
struct node
{
    int val, size;
    node *ch[2];
    int sum;
    int max_sum, max_lsum, max_rsum;
    // max_sum 定义为最少包含一个元素的最大子数列和
    // max_lsum 定义为从左端开始的可以不包含元素的最大子数列和
    // max_lsum 定义为到右端结束的可以不包含元素的最大子数列和

    bool replaced, reversed;
    // 当replaced为true,表示它的所有后代的val应该与这个节点的val相同,但实际上后代节点并没有更新
    // 当reversed为true,表示它已经交换了左右节点和左右最大值,且它的所有后代都应该交换左右子树和左右最大值,但实际上后代节点并没有更新

    node(int v) : val(v), size(1), sum(v), replaced(false), reversed(false)
    {
        ch[l] = ch[r] = nil;
        if (v >= 0)
            max_sum = max_lsum = max_rsum = sum;
        else
        {
            max_sum = v;
            max_lsum = max_rsum = 0;
        }
    }
    int cmp(int k)
    {
        if (k == ch[l]->size + 1 || this == nil)
            return -1;
        else
            return k <= ch[l]->size ? l : r;
    }

    void reverse()
    {
        if (!replaced)
        {
            reversed ^= 1;
            swap(ch[l], ch[r]);
            swap(max_lsum, max_rsum);
        }
    }
    void replace(int v)
    {
        reversed = false;
        replaced = true;
        val = v;
        sum = v * size;
        if (v > 0)
            max_sum = max_lsum = max_rsum = sum;
        else
        {
            max_sum = v; // 由于子数列要求至少有一个元素,故当 val < 0
                         // ,只有一个元素时和最大
            max_lsum = max_rsum = 0;
        }
    }

    void push_down()
    {
        if (replaced)
        {
            ch[l]->replace(val);
            ch[r]->replace(val);
            replaced = false;
        }
        else if (reversed)
        {
            ch[l]->reverse();
            ch[r]->reverse();
            reversed = false;
        }
    }
    void pull_up()
    {
        if (this != nil)
        {
            size = ch[l]->size + ch[r]->size + 1;

            if (!replaced)
                sum = ch[l]->sum + ch[r]->sum + val;
            else
                sum = val * size;

            if (this != l_edge && this != r_edge)
            {
                max_sum = max(ch[l]->max_rsum + val + ch[r]->max_lsum,
                              max(ch[l]->max_sum,
                                  ch[r]->max_sum)); // 更新后 max_sum 至少包含一个元素
                max_lsum = max(
                    ch[l]->max_lsum,
                    ch[l]->sum + val +
                        ch[r]->max_lsum); // 更新后 max_lsum / max_rsum 可以不包含元素
                max_rsum = max(ch[r]->max_rsum, ch[l]->max_rsum + val + ch[r]->sum);
            }
            else if (this == l_edge) // 注意特判左右边界节点
            {
                // 若不特判,当左边界节点为根且整个数列的从左开始的最大值为0时
                // 就会出现 max_sum = ch[l]->max_rsum + val + ch[r]->max_lsum
                // 即 max_sum = 0,这显然不合法
                max_sum = ch[r]->max_sum;
                max_lsum = ch[r]->max_lsum;
                max_rsum = ch[r]->max_rsum;
            }
            else
            {
                // 右边界同理
                max_sum = ch[l]->max_sum;
                max_lsum = ch[l]->max_lsum;
                max_rsum = ch[l]->max_rsum;
            }
        }
    }

    void remove()
    {
        if (this != nil)
        {
            ch[l]->remove();
            ch[r]->remove();
            delete this;
        }
    }
} * root;
void init()
{
    if (!nil)
        nil = new node(0);
    nil->size = 0;
    nil->ch[l] = nil->ch[r] = nil;
    nil->max_sum = -inf;
    l_edge = new node(0), r_edge = new node(0);
    l_edge->max_sum = -inf;
    r_edge->max_sum = -inf;
    root = nil;
}
void rotate(node *&t, int d)
{
    t->push_down();
    t->ch[l]->push_down();
    t->ch[r]->push_down();
    node *k = t->ch[d ^ 1];
    t->ch[d ^ 1] = k->ch[d];
    k->ch[d] = t;
    t->pull_up();
    k->pull_up();
    t = k;
}
void splay(node *&t, int k)
{
    t->push_down();
    int d = t->cmp(k);
    if (d == r)
        k = k - t->ch[l]->size - 1;
    if (d != -1)
    {
        t->ch[d]->push_down();
        int d2 = t->ch[d]->cmp(k);
        int k2 = (d2 == r) ? k - t->ch[d]->ch[l]->size - 1 : k;
        if (d2 != -1)
        {
            splay(t->ch[d]->ch[d2], k2);
            if (d == d2)
            {
                rotate(t, d ^ 1);
                rotate(t, d ^ 1);
            }
            else
            {
                rotate(t->ch[d], d2 ^ 1);
                rotate(t, d ^ 1);
            }
        }
        else
            rotate(t, d ^ 1);
    }
}
void join(node *&t1, node *&t2)
{
    if (t1 == nil)
        swap(t1, t2);
    splay(t1, t1->size);
    t1->ch[r] = t2;
    t2 = nil;
    t1->pull_up();
}
node *split(node *&t, int k)
{
    if (k == 0)
    {
        node *subtree = t;
        t = nil;
        return subtree;
    }
    splay(t, k);
    node *subtree = t->ch[r];
    t->ch[r] = nil;
    t->pull_up();
    return subtree;
}
node *build_tree(int *p, int n)
{
    if (n == 0)
        return nil;
    node *fa;
    node *ch = new node(p[1]);
    for (int i = 2; i <= n; i++)
    {
        fa = new node(p[i]);
        fa->ch[l] = ch;
        fa->pull_up();
        ch = fa;
    }
    return fa;
}
node *select(int p, int tot)
{
    int ln = p, rn = ln + tot - 1;
    splay(root, rn + 1);
    splay(root->ch[l], ln - 1);
    return root->ch[l]->ch[r];
}
}
int n, m;
int num[500005];
int main()
{
    using namespace splay;
    ios::sync_with_stdio(false);
    getint(n);
    getint(m);
    for (int i = 1; i <= n; i++)
        getint(num[i]);
    init();

    node *t1, *t2; // tmp
    root = l_edge;
    t1 = build_tree(num, n);
    join(root, t1);
    t1 = r_edge;
    join(root, t1);

    string opt;
    int posi, tot, c;
    while (m--)
    {
        getstr(opt);
        switch (opt[0])
        {
        case 'I': // INSERT
            getint(posi);
            getint(tot);
            posi++;
            for (int i = 1; i <= tot; i++)
                getint(num[i]);
            t1 = build_tree(num, tot);
            t2 = split(root, posi);
            join(root, t1);
            join(root, t2);
            break;
        case 'D': // DELETE
            getint(posi);
            getint(tot);
            posi++;
            t1 = split(root, posi - 1);
            t2 = split(t1, tot);
            join(root, t2);
            t1->remove();
            break;
        case 'R': // REVERSE
            getint(posi);
            getint(tot);
            posi++;
            t1 = select(posi, tot);
            t1->reverse();
            root->ch[l]->pull_up();
            root->pull_up();
            break;
        case 'G': // GET-SUM
            getint(posi);
            getint(tot);
            posi++;
            t1 = select(posi, tot);
            cout << t1->sum << endl;
            break;
        case 'M':
            if (opt[2] == 'K') // MAKE_SAME
            {
                getint(posi);
                getint(tot);
                getint(c);
                posi++;
                t1 = select(posi, tot);
                t1->replace(c);
                root->ch[l]->pull_up();
                root->pull_up();
            }
            else // MAX_SUM
                cout << root->max_sum << endl;
            break;
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/ssttkkl/p/7105484.html