[CF833B] The Bakery

[CF833B] The Bakery - dp,线段树

Description

将长 n 的序列分成 k 段,每段的价值为区间内不同数字个数,使得总价值最大。(n le 35000, k le 50)

Solution

(f[i][j]=max f[l][j-1]+C(l+1,i))

现在,对与所有的 l,我们已经知道 C(l+1,i-1),我们希望快速得到 C(l+1,i)

对于一个特定的 l,如果 ([l+1,i-1]) 中能找到 (a[i]),那么没有贡献

(a[i]) 上次出现的位置是 (last[i]),那么 ([last[i],i-1]) 的 l 位置会 +1

我们用线段树维护(对当前的 i,j,所有 l)(f[l-1][j-1]+C(l,i)),每次修改 ([last[i]+1,i]),每次查询最大值

注意在 dp 时,我们外层枚举 j,内层枚举 i,那么每次换 j 时需要开一棵新的线段树

#include <bits/stdc++.h>
using namespace std;

#define int long long
const int N = 70005;
const int M = 55;

struct Node
{
    int val;
    int tag;

    void set(int x)
    {
        val += x;
        tag += x;
    }

    Node operator+(const Node &rhs) const
    {
        return {max(val, rhs.val), 0};
    }
} node[4 * N];

void pushup(int p)
{
    node[p] = node[p * 2] + node[p * 2 + 1];
}

void pushdown(int p)
{
    if (node[p].tag)
    {
        node[p * 2].set(node[p].tag);
        node[p * 2 + 1].set(node[p].tag);
        node[p].tag = 0;
    }
}

void build(int p, int l, int r, vector<int> &src)
{
    if (l == r)
        node[p] = {src[l], 0};
    else
        build(p * 2, l, (l + r) / 2, src),
            build(p * 2 + 1, (l + r) / 2 + 1, r, src),
            pushup(p);
}

void modify(int p, int l, int r, int ql, int qr)
{
    if (l > qr || r < ql)
        return;
    if (l >= ql && r <= qr)
        node[p].set(1);
    else
        pushdown(p),
            modify(p * 2, l, (l + r) / 2, ql, qr),
            modify(p * 2 + 1, (l + r) / 2 + 1, r, ql, qr),
            pushup(p);
}

int query(int p, int l, int r, int ql, int qr)
{
    if (l > qr || r < ql)
        return 0;
    if (l >= ql && r <= qr)
        return node[p].val;
    else
    {
        pushdown(p);
        return max(query(p * 2, l, (l + r) / 2, ql, qr),
                   query(p * 2 + 1, (l + r) / 2 + 1, r, ql, qr));
    }
}

int n, k, a[N], f[N][M], last[N], pos[N];
signed main()
{
    ios::sync_with_stdio(false);
    cin >> n >> k;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    for (int i = 1; i <= n; i++)
        last[i] = pos[a[i]], pos[a[i]] = i;
    for (int j = 1; j <= k; j++)
    {
        vector<int> src;
        src.push_back(0);
        for (int i = 1; i <= n; i++)
            src.push_back(f[i - 1][j - 1]);
        build(1, 1, n, src);
        for (int i = 1; i <= n; i++)
        {
            modify(1, 1, n, last[i] + 1, i);
            f[i][j] = query(1, 1, n, 1, i);
        }
    }
    cout << f[n][k] << endl;
}

原文地址:https://www.cnblogs.com/mollnn/p/14369513.html