题解 BZOJ4709

题目描述

一道简单DP优化调了好久qwq

首先分析题目,发现每次从一边取贝壳是完全没用的,此题本质就是将区间分成数个区间,使区间价值和最大。

可以发现一个性质,那就是最优解的每个区间的两端点一定相同且为选取的(s_0)。因为如果区间两端点的值不同,那么完全可以将多余的值分为另一个区间使价值和更大。

所以可以写出简单的dp式:

(dp[i] = max(dp[j-1] + s[i] * (sum[i] - sum[j]+1)^2) quad (s[j] == s[i]))

其中(sum[i])为1…i中(s[i])的个数,可以简单的(O(1))维护,所以总复杂度(O(n^2))

观察单调性,发现对于决策a,b((a<b))如果在k处a比b优,那么在k之后a也一定比b优,而k可以通过二分(O(log_n))求出

所以可以使用单调栈维护最优决策,对于每次决策,如果栈顶不优了就弹出栈顶。同时,为了维护栈的单调性,每次入栈z时,如果z与栈顶元素的分界点(k_1)比栈顶与栈顶的下一个元素的分界点(k_2)靠后,那么便可以弹出栈顶元素。

/**************************************************************
    Problem: 4709
    User: liuxinyuan
    Language: C++
    Result: Accepted
    Time:328 ms
    Memory:3144 kb
****************************************************************/

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <stack>
#include <vector>
#define gc getchar
#define il inline
#define re register
#define LL long long
#define mid(l, r) (((l) + (r)) >> 1)
#define sqr(x) (1ll * (x) * 1ll * (x))
#define m_p(x, y) make_pair(x, y)
using namespace std;
template <typename T>
void rd(T &s)
{
    s = 0;
    bool p = 0;
    char ch;
    while (ch = gc(), p |= ch == '-', ch < '0' || ch > '9')
        ;
    while (s = s * 10 + ch - '0', ch = gc(), ch >= '0' && ch <= '9')
        ;
    s *= (p ? -1 : 1);
}
template <typename T, typename... Args>
void rd(T &s, Args &... args)
{
    rd(s);
    rd(args...);
}
const int MAXM = 10005;
const int MAXN = 100050;
vector<int> sta[MAXM];
LL dp[MAXN];
int s[MAXN];
int cnt[MAXM], sum[MAXN];
int n;
il LL cal(int x, int y)
{
    return dp[x - 1] + s[x] * 1ll * y * 1ll * y;
}
int lower(int a, int b)
{
    int v = s[a];
    int l = 1, r = cnt[v], ans = cnt[v] + 1;
    while (l <= r)
    {
        int m = mid(l, r);
        if (cal(a, m - sum[a] + 1) >= cal(b, m - sum[b] + 1))
            ans = m,
            r = m - 1;
        else
            l = m + 1;
    }
    return ans;
}
int main()
{
    int v;
    rd(n);
    for (int i = 1; i <= n; ++i)
    {
        rd(s[i]);
        sum[i] = ++cnt[s[i]];
    }
    for (int i = 1; i <= n; ++i)
    {
        v = s[i];
        while (sta[v].size() >= 2 && lower(sta[v][sta[v].size() - 2], sta[v][sta[v].size() - 1]) < lower(sta[v][sta[v].size() - 1], i))
            sta[v].pop_back();
        sta[v].push_back(i);
        while (sta[v].size() >= 2 && lower(sta[v][sta[v].size() - 2], sta[v][sta[v].size() - 1]) <= sum[i])
            sta[v].pop_back();
        dp[i] = cal(sta[v][sta[v].size() - 1], sum[i] - sum[sta[v][sta[v].size() - 1]] + 1);
        // cout << i << " " << sta[v][sta[v].size() - 1] << endl;
    }
    // for (int i = 1; i <= n; ++i)
    printf("%lld", dp[n]);
    return 0;
}
原文地址:https://www.cnblogs.com/happyLittleRabbit/p/10598907.html