[USACO19JAN]Train Tracking 2——神仙结论题+DP

原题链接
orz xzz巨佬
首先发现一个结论:两个相邻的(c)值如果不相同的话,就可以固定某个位置的值了
这启示我们把连续且相等的(c)给单独拿出来看,也就是对于一些(c_i=c_{i+1}=...=c_j=v),能不能从中得出一些东西
这一段代表的区间总长为(j-i+k),所有的数都大于等于(v),同时每(k)个中就有至少一个(v),有一个比较显然的(dp):设(f[i])表示最后一个(v)(i)位置时的合法方案数,(p)(1e9-v)那么有转移

[f[i]=sumlimits_{j=i-k}^{i-1}p^{i-j-1}f[j] ]

这样是(O(nk))的,不可接受,于是用一个错位相消就可以得到一个可以(O(1))转移的式子

[f[i]=(p+1)f[i-1]-p^k*f[i-k-1] ]

其实上面这个式子也有实际意义,可以直接推出来
然后把这一段的(c)放到序列中看看,考虑它们左右的两个数(c_{i-1})(c_{j+1})影响((a)为原序列)
①如果(c_{i-1}>c_i),表明(a_{i+k-1}=v)(c_i~c_j)中的前(k-1)个在前一段中已经被考虑过了,因此本次需要考虑的长度减少(k-1+1=k)
②如果(c_{j+1}>c_j),表明(a_{i-k+1}=v)(c_i~c_j)中的后(k-1)个在后一段中已经被考虑过了,因此本次需要考虑的长度也减少(k-1+1=k)
这意味着我们只需要在(i-j+k)的基础上减掉几个(k)就可以将其化归到上一个模型上去了
具体实现的话,我们只需要找出所有的极大连续相等子段并把它们的贡献累乘起来就行了
代码

#include <algorithm>
#include  <iostream>
#include   <cstdlib>
#include   <cstring>
#include    <cstdio>
#include    <random>
#include    <string>
#include    <vector>
#include     <cmath>
#include     <ctime>
#include     <queue>
#include       <map>
#include       <set>

#define IINF 0x3f3f3f3f3f3f3f3fLL
#define u64 unsigned long long
#define pii pair<int, int>
#define mii map<int, int>
#define u32 unsigned int
#define lbd lower_bound
#define ubd upper_bound
#define INF 0x3f3f3f3f
#define vi vector<int>
#define ll long long
#define mp make_pair
#define pb push_back
#define is insert
#define se second
#define fi first
#define ps push

#define $SHOW(x) cout << #x" = " << x << endl
#define $DEBUG() printf("%d %s
", __LINE__, __FUNCTION__)

using namespace std;

#define MAXN 100000
#define MOD 1000000007

int n, k, c[MAXN + 5], f[MAXN + 5];

int fpow(int x, int p) {
	int ret = 1;
	while (p) {
		if (p & 1) ret = 1LL * ret * x % MOD;
		x = 1LL * x * x % MOD;
		p >>= 1;
	}
	return ret;
}

int solve(int v, int l) {
	int p = 1000000000 - v, pk = fpow(p, k);
	f[0] = f[1] = 1;
	for (int i = 2; i <= l + 1; ++i) {
		f[i] = 1LL * (p + 1) * f[i - 1] % MOD;
		if(i - k - 1 >= 0) f[i] = (f[i] - 1LL * pk * f[i - k - 1] % MOD + MOD) % MOD;
	}
	return f[l + 1]; // 注意这里返回l+1而不是l,否则就会钦定a[l]为v了
}

int main() {
	scanf("%d%d", &n, &k);
	for (int i = 1; i <= n - k + 1; ++i) scanf("%d", &c[i]);
	int ans = 1;
	for (int i = 1, j, len; i <= n - k + 1; i = j + 1) {
		j = i;
		while (c[j+1] == c[i]) j++;
		len = j - i + k;
		if (i != 1 && c[i - 1] > c[i]) len -= k;
		if (j != n - k + 1 && c[j + 1] > c[i]) len -= k;
		if (len > 0) ans = 1LL * ans * solve(c[i], len) % MOD;
	}
	printf("%d
", ans);
    return 0;
}
原文地址:https://www.cnblogs.com/dummyummy/p/11039476.html