题目大意
有(n)个士兵((1 leq n leq 10^5)),第(i)个士兵的身高为(h_{i}),现在要求把士兵按照原来的顺序分成连续的若干组,要求每组的士兵数量不超过(len)。
同时,我们设每组的最后一个士兵的身高为(b_{i}),则有(b_{i} > b_{i - 1})((b_0 = 0)),现在我们设每种分组方案的价值为(sum b_{i}^2 - b_{i - 1}),求能得到的最大价值为多少?
题解
我们设(dp[i])表示前(i)个士兵分成任意组的最大价值,容易得到:
[dp[i] = underset{i - len leq j < i}{max} { dp[j] + k_{i}^2 - k_{j} }
]
整理一下,得到:
[dp[i] = k_{i}^2 + underset{i - len leq j < i}{max} { dp[j] - k_{j} }
]
我们可以用线段树来维护(underset{i - len leq j < i}{max} { dp[j] - k_{j} }).
但是。如何保证题目中要求的(b_{i} > b_{i - 1})呢?
其实,对于每个士兵,我们可以先按照身高来进行升序排列,如果身高相同,我们就按照编号(原来的顺序)降序排列,然后对于排序后的士兵(i),我们设他原来的编号为(idx_{i}),则我们就查找线段树上([idx_{i} - len, idx_{i} - 1])的价值,同时更新也是更新线段树上的(idx_{i})的位置。
因为对于每个士兵(i),如果在排序前能找到和他进行状态转移的士兵(j),那么排序后,肯定有(idx_{j} in [idx_{i} - len, idx_{i} - 1]),这个大家可以自己试几个情况,所以这样做即可。
#include <iostream>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#define MAX_N (100000 + 5)
#define SIZE (1 << 21)
#define lowbit(x) ((x) & -(x))
#define Getchar() (p1 == p2 && (p2 = (p1 = fr) + fread(fr, 1, SIZE, stdin), p1 == p2) ? EOF : *p1++)
using namespace std;
char fr[SIZE], * p1 = fr, * p2 = fr;
void Read(int & res)
{
res = 0;
char ch = Getchar();
while(!isdigit(ch)) ch = Getchar();
while(isdigit(ch)) res = res * 10 + ch - '0', ch = Getchar();
return;
}
struct Node
{
int h;
int idx;
friend inline bool operator < (Node a, Node b)
{
if(a.h != b.h) return a.h < b.h;
return a.idx > b.idx;
}
};
int T;
int n, len;
Node a[MAX_N];
long long s[MAX_N << 2];
void Modify(int x, int l, int r, int pos, long long val)
{
if (r < pos || pos < l) return;
if (l == r)
{
s[x] = val;
return;
}
int mid = l + r >> 1;
Modify(x << 1, l, mid, pos, val);
Modify(x << 1 | 1, mid + 1, r, pos, val);
s[x] = max(s[x << 1], s[x << 1 | 1]);
return;
}
long long Query(int x, int l, int r, int L, int R)
{
if (r < L || R < l) return -0x7f7f7f7f7f7f7f7f;
if (L <= l && r <= R) return s[x];
int mid = l + r >> 1;
return max(Query(x << 1, l, mid, L, R), Query(x << 1 | 1, mid + 1, r, L, R));
}
int main()
{
Read(T);
for (int I = 1; I <= T; ++I)
{
memset(s, -0x7f, sizeof s);
Read(n); Read(len);
for (int i = 1; i <= n; ++i)
{
Read(a[i].h);
a[i].idx = i;
}
sort(a + 1, a + n + 1);
long long tmp;
printf("Case #%d: ", I);
Modify(1, 0, n, 0, 0);
for (int i = 1; i <= n; ++i)
{
tmp = Query(1, 0, n, max(0, a[i].idx - len), a[i].idx - 1);
if (tmp < -0x7f7f7f7f)
{
if (a[i].idx == n)
{
printf("No solution
");
break;
}
continue;
}
if (a[i].idx == n)
{
printf("%lld
", (long long)a[i].h * a[i].h + tmp);
break;
}
Modify(1, 0, n, a[i].idx, (long long)a[i].h * a[i].h + tmp - a[i].h);
}
}
return 0;
}