[Codeforces868F]Yet Another Minimization Problem

题意:给定一个序列 (a),要把它分成 (k) 个子段。每个子段的费用是其中相同元素的对数。求所有子段的费用之和的最小值。

序列分割问题的权值真是花样百出。

先写出 dp 方程(省略了第二维段数):

[dp[i] = min {dp[j] + pairs(j + 1, i)} ]

打表发现可以证明该式满足决策单调性,即函数 (pairs(i, j)) 满足四边形不等式。

(l_1 le l_2 le r_1 le r_2),要证

[pairs(l_1, r_1) + pairs(l_2, r_2) le pairs(l_1, r_2) + pairs(l_2, r_1) ]

即证

[pairs(l_1, r_2) - pairs(l_1, r_1) ge pairs(l_2, r_2) - pairs(l_2, r_1) ]

因为

[pairs(i, j + 1) = pairs(i, j) + cnt_{a_{j + 1}}(i, j) ]

所以即证

[sum_{j = r_1 + 1}^{r_2}left(cnt_{a_j}(l_1, j - 1) - cnt_{a_j}(l_2, j - 1) ight) ge 0 ]

显然成立。

现在问题是 (pairs(i, j)) 的求解,发现难以用数据结构维护,但移动左右端点时(像莫队一样)可以很容易维护。

所以用分治优化决策单调性,这样移动左右端点的次数是 (nlog n) 级别的。

复杂度 (O(knlog n))

#include <bits/stdc++.h>
#ifdef LOCAL
#define dbg(args...) std::cerr << "33[32;1m" << #args << " -> ", err(args)
#else
#define dbg(...)
#endif
inline void err() { std::cerr << "33[0m
"; }
template<class T, class... U>
inline void err(const T &x, const U &... a) { std::cerr << x << ' '; err(a...); }
template <class T>
inline void readInt(T &w) {
  char c, p = 0;
  while (!isdigit(c = getchar())) p = c == '-';
  for (w = c & 15; isdigit(c = getchar());) w = w * 10 + (c & 15);
  if (p) w = -w;
}
template <class T, class... U>
inline void readInt(T &w, U &... a) { readInt(w), readInt(a...); }
template <class T, class U>
inline bool smin(T &x, const U &y) { return y < x ? x = y, 1 : 0; }
template <class T, class U>
inline bool smax(T &x, const U &y) { return x < y ? x = y, 1 : 0; }

typedef long long LL;
typedef std::pair<int, int> PII;

constexpr int N(1e5 + 5);

int n, k, a[N];
LL f[N], g[N];
LL ask(int x, int y) {
  static int l = 1, r = 0, cnt[N];
  static LL ans;
  while (r < y) ans += cnt[a[++r]]++;
  while (r > y) ans -= --cnt[a[r--]];
  while (l < x) ans -= --cnt[a[l++]];
  while (l > x) ans += cnt[a[--l]]++;
  return ans;
}
void solve(int l, int r, int x, int y) {
  if (l > r) return;
  int mid = l + r >> 1, p;
  for (int i = x, e = std::min(mid - 1, y); i <= e; i++)
    if (smin(f[mid], g[i] + ask(i + 1, mid)))
      p = i;
  solve(l, mid - 1, x, p), solve(mid + 1, r, p, y); 
}
int main() {
  readInt(n, k);
  for (int i = 1; i <= n; i++) readInt(a[i]);
  memset(f, 0x3f, sizeof f);
  f[0] = 0;
  for (int i = 1; i <= k; i++) {
    memcpy(g, f, sizeof f);
    memset(f, 0x3f, sizeof f);
    solve(i, n, i - 1, n);
  }
  std::cout << f[n];
  return 0;
}
原文地址:https://www.cnblogs.com/HolyK/p/13975161.html