CF1093F Vasya and Array 更优秀的做法

CF1093F Vasya and Array 更优秀的做法

摘要

本题有一个经典的 DP + 容斥做法,时间复杂度是 (O(nk))

本文作者在此基础上,创新性地发掘题目性质,简化 DP 状态,利用数据结构优化 DP,提出了一个时间复杂度 (O(n)) 的做法。显然相比经典做法,是更加优秀的。

本文将分别介绍这两种做法。

题目大意

题目链接

给出一段长度为 (n) 的整数序列,一个正整数 (k) ,一个正整数 ( ext{len})。序列中的所有数要么在 ([1,k]) 之间,要么等于 (-1)

我们称一个序列是好的,当且仅当不存在 ( ext{len}) 个连续的相同的数字。

你可以将每个 (-1),替换成任意一个 ([1,k]) 之间的整数。求有多少种方案,使得最终序列是好的。答案对 (998244353) 取模。

数据范围:(1leq nleq 10^5)(1leq kleq 100)(1leq ext{len}leq n)

经典做法

( ext{dp}_0(i,j)) 表示考虑了前 (i) 个位置,最后一个位置上填的数是 (j),前 (i) 个位置组成的序列合法的方案数。设 ( ext{sdp}(i)=sum_{j=1}^{k} ext{dp}_0(i,j))

首先,当 (a_i eq -1)(a_i eq j) 时,( ext{dp}_0(i,j)=0)

否则,我们暂时令 ( ext{dp}_0(i,j)= ext{sdp}(i - 1))。但是这会把一些不合法的方案算进去。具体来说,这样有可能出现 ([a_{i- ext{len}+1}dots a_i]) 全部相同的情况。这种情况会出现当且仅当如下两个条件都满足:

  • (igeq ext{len})
  • ([a_{i- ext{len}+1}dots a_i]) 中每个数都等于 (-1)(j)

所以,我们还要减去这种情况的数量:( ext{sdp}(i- ext{len})- ext{dp}_0(i- ext{len},j))。其中 ( ext{sdp}(i - ext{len})) 表示 ([a_{i- ext{len}+1}dots a_i]) 全部等于 (j) 时,前 (i - ext{len}) 位的填写方案。不过这些方案中,有一些方案可能在前 (i-1) 位就已经导致不合法了。这些提前不合法的方案本来就没有被算在 ( ext{sdp}(i-1)) 中,所以不需要被减去,它们的数量是 ( ext{dp}_0(i - ext{len},j))

综上所述,可以得到转移式:

[ ext{dp}_0(i,j)=egin{cases} 0&& a_i eq0 ext{ 且 }a_i eq j\ ext{sdp}(i-1)-( ext{sdp}(i- ext{len})- ext{dp}_0(i- ext{len},j))cdot [ ext{上文中两个条件}]&& ext{otherwise} end{cases} ]

时间复杂度(O(nk))

更优秀的做法

首先,当 ( ext{len} = 1) 时,答案一定是 (0)。以下只讨论 ( ext{len} > 1) 的情况。

先考虑一种朴素的 DP。设 ( ext{dp}_1(i,j,l)) 表示考虑了前 (i) 个位置,第 (i) 位上填的数是 (j),最后一个 ( eq j) 的位置是 (l),此时使得前 (i) 位组成的序列合法的方案数。

转移时,考虑当前位填了什么:

[egin{cases} ext{dp}_1(i-1,j,l) o ext{dp}_1(i,j, l) && ext{if }lgeq i - ext{len}+1\ ext{dp}_1(i-1,j,l) o ext{dp}_1(i,x, i) && ext{if }x eq j end{cases} ]

初始状态为 ( ext{dp}_1(0,0,0) = 1)。答案是 (sum_{j = 1}^{k}sum_{l = n - ext{len}+1}^{n - 1} ext{dp}_1(n,j,l))

这个朴素 DP 的时间复杂度是 (O(n^2k))


优化它!设上一位填的数为 (j),当前位填的数为 (x)。我们发现,当 (x)(j) 不同时,(x) 的值具体是什么其实不重要:对所有 (x eq j),它们的转移是一模一样的。这就给了我们简化的空间。

定义 (a_{n+1} = k+1)。定义 ( ext{nxt}_i) 表示位置 (i) 后面第一个 (a_{i'} eq -1) 的位置 (i')。设 ( ext{dp}_2(i,jin{0,1},l)) 表示考虑了前 (i) 个位置,第 (i) 位上填的数是 / 否等于 (a_{ ext{nxt}_i}),前 (i) 位里最后一个填的数与第 (i) 位上不同的位置是 (l),此时使得前 (i) 位组成的序列合法的方案数。

转移分 (a_i) 是否为 (-1) 两种情况。

(a_i eq -1) 时,枚举 (l)。则有如下转移:

[egin{cases} ext{dp}_2(i-1,0,l) o ext{dp}_2(i,[a_i=a_{ ext{nxt}_i}],i - 1)\ ext{dp}_2(i-1,1,l) o ext{dp}_2(i,[a_i=a_{ ext{nxt}_i}],l) && ext{if }lgeq i - ext{len} + 1 end{cases} ]

(a_i = -1) 时,显然 ( ext{nxt}_{i} = ext{nxt}_{i-1})。枚举 (l)。我们分别考虑如下情况:

  • (i-1) 位填的数与 (a_{ ext{nxt}_{i}}) 不同:
    • (i) 位上填的数与第 (i-1) 位上填的数相同。
    • (i) 位上填的数与第 (a_{ ext{nxt}_i}) 相同。
    • (i) 位上填的数,既不等于第 (i-1) 位上填的数,也不等于 (a_{ ext{nxt}_{i}})
  • (i-1) 位填的数与 (a_{ ext{nxt}_{i}}) 相同:
    • (i) 位填的数与 (a_{ ext{nxt}_{i}}) 不同。
    • (i) 位填的数与 (a_{ ext{nxt}_{i}}) 相同。

这五种情况分别对应如下转移:

[egin{cases} ext{dp}_2(i - 1, 0, l)& o ext{dp}_2(i, 0, l)&& ext{if }lgeq i - ext{len} + 1\ ext{dp}_2(i - 1, 0, l)& o ext{dp}_2(i, 1, i - 1)\ ext{dp}_2(i - 1, 0, l) imes(k - 2) & o ext{dp}_2(i,0,i-1)\ ext{dp}_2(i - 1, 1, l) imes(k - 1) & o ext{dp}_2(i, 0, i - 1)\ ext{dp}_2(i - 1, 1, l)& o ext{dp}_2(i,1,l) && ext{if }lgeq i - ext{len} + 1 end{cases} ]

上述式子里默认 (1 < i < ext{nxt}_i leq n)。当 (i = 1)( ext{nxt}_i = n+1) 时,有一些特殊情况要考虑。为了表述简洁,这里就不细写了。

现在,这个 DP 的时间复杂度是 (O(n^2))。虽然无法 AC,但这是迈向 (O(n)) 做法的关键一步。我将这一 DP 的代码附在了本文末尾:点击跳转


在这个 DP 里,第 (1) 维的枚举不可避免,第 (2) 维的状态数已经被我们优化到 (O(1))。考虑优化第 (3) 维。

观察转移式,发现从 (i-1) 变成 (i) 时,第 (3) 维的转移,相当于做如下操作:

  1. 区间求和。对所有 (lin[i - ext{len},i - 2]) 求和。
  2. 单点加。
  3. 把一段区间的值覆盖为 (0)
  4. (a_i eq -1)(a_i eq a_{ ext{nxt}_i}) 时,需要把 ( ext{dp}_2(i,1)) 中的一段 (l),复制到 ( ext{dp}_2(i,0)) 对应的位置上。

按第二维的 (0,1),维护两棵线段树,通过打懒标记即可实现这四种操作。

时间复杂度 (O(nlog n))


继续优化,发现区间操作都是假的。

  • 操作 (3) 要么是单点清空,要么是全局清空。
  • 操作 (1) 和操作 (4),因为其他位置都清空了,所以区间求和、区间复制,其实就是全局求和、全局交换。

所以我们只需要用两个数组来维护。

时间复杂度 (O(n))

参考代码

最终代码

友情提醒:使用读入、输出优化可以使代码更快,详见本博客公告。

// problem: CF1093F
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;

template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }

const int MAXN = 1e5;
const int MOD = 998244353;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }


int n, K, len, a[MAXN + 5];

struct FantasticDataStructure {
	int sum;
	int arr[MAXN + 5];
	int TIM;
	int tim[MAXN + 5];
	void upd(int p) {
		if (tim[p] < TIM) {
			tim[p] = TIM;
			arr[p] = 0;
		}
	}
	void point_add(int p, int v) {
		upd(p);
		add(arr[p], v);
		add(sum, v);
	}
	void point_set0(int p) {
		upd(p);
		sub(sum, arr[p]);
		arr[p] = 0;
	}
	void global_set0() {
		TIM++;
		sum = 0;
	}
	int query() {
		return sum;
	}
	FantasticDataStructure() {}
};
FantasticDataStructure S[2];
int id[2];

// DP 转移: a[i] != -1 / a[i] == -1
void trans1(int p, int nxtval) {
	int v = S[id[0]].query();
	if (a[p] == nxtval) {
		S[id[0]].global_set0();
		if (p - len >= 0) {
			S[id[1]].point_set0(p - len);
		}
		S[id[1]].point_add(p - 1, v);
	} else {
		swap(id[0], id[1]);
		S[id[1]].global_set0();
		if (p - len >= 0) {
			S[id[0]].point_set0(p - len);
		}
		S[id[0]].point_add(p - 1, v);
	}
}
void trans2(int p, int flag) {
	int v0 = S[id[0]].query();
	int v1 = S[id[1]].query();
	
	if (p == 1) {
		S[id[0]].global_set0();
	} else {
		if (p - len >= 0) {
			S[id[0]].point_set0(p - len);
		}
	}
	
	int toadd = 0;
	if (flag + (p != 1) + 1 <= K) {
		toadd = (ll)v0 * (K - flag - (p != 1)) % MOD;
	}
	add(toadd, (ll)v1 * (K - 1) % MOD);
	S[id[0]].point_add(p - 1, toadd);
	
	if (p - len >= 0) {
		S[id[1]].point_set0(p - len);
	}
	if (flag) {
		S[id[1]].point_add(p - 1, v0);
	}
}
int main() {
	cin >> n >> K >> len;
	for (int i = 1; i <= n; ++i) {
		cin >> a[i];
	}
	if (len == 1) {
		cout << 0 << endl;
		return 0;
	}
	
	id[0] = 0;
	id[1] = 1;
	S[id[1]].point_add(0, 1);
	
	a[0] = K + 1;
	a[n + 1] = K + 2;
	for (int i = 1, j = 2; i <= n; ++i) {
		ckmax(j, i + 1);
		while (a[j] == -1)
			++j;
		if (a[i] != -1) {
			trans1(i, a[j]);
		} else {
			trans2(i, j != n + 1);
		}
	}
	cout << S[id[0]].query() << endl;
	return 0;
}

n^2 DP

为了帮助读者更好地理解题解,这里附上 (O(n^2)) 朴素 DP 的代码。

#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;

template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }

const int MAXN = 1000;
const int MOD = 998244353;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }


int n, K, len, a[MAXN + 5];
int dp[MAXN + 5][2][MAXN + 5];

void trans1(int p, int nxtval) {
	for (int i = max(0, p - len); i <= max(0, p - 2); ++i) {
		add(dp[p][a[p] == nxtval][p - 1], dp[p - 1][0][i]);
		if (i >= p - len + 1)
			add(dp[p][a[p] == nxtval][i], dp[p - 1][1][i]);
	}
}
void trans2(int p, int flag) {
	for (int i = max(0, p - len); i <= max(0, p - 2); ++i) {
		// a[p - 1] 和 nxtval 不同
		if (p != 1 && i >= p - len + 1) {
			add(dp[p][0][i], dp[p - 1][0][i]); // a[p] 和 a[p - 1] 相同
		}
		if (flag) {
			add(dp[p][1][p - 1], dp[p - 1][0][i]); // a[p] 和 nxtval 相同
		}
		if (flag + (p != 1) + 1 <= K) {
			// a[p] 和 nxtval, a[p - 1] 都不同
			add(dp[p][0][p - 1], (ll)dp[p - 1][0][i] * (K - flag - (p != 1)) % MOD);
		}
		
		// a[p - 1] 和 nxtval 相同
		add(dp[p][0][p - 1], (ll)dp[p - 1][1][i] * (K - 1) % MOD);
		if (i >= p - len + 1) {
			add(dp[p][1][i], dp[p - 1][1][i]); // a[p] 和 a[p - 1], nxtval 都相同
		}
	}
}
int main() {
	cin >> n >> K >> len;
	for (int i = 1; i <= n; ++i) {
		cin >> a[i];
	}
	
	dp[0][0][0] = 1;
	a[0] = K + 1;
	a[n + 1] = K + 2;
	
	for (int i = 1, j = 2; i <= n; ++i) {
		ckmax(j, i + 1);
		while (a[j] == -1)
			++j;
		if (a[i] != -1) {
			trans1(i, a[j]);
		} else {
			trans2(i, j != n + 1);
		}
	}
	int ans = 0;
	for (int i = max(0, n - len + 1); i <= n - 1; ++i) {
		add(ans, dp[n][0][i]);
	}
	cout << ans << endl;
	return 0;
}
原文地址:https://www.cnblogs.com/dysyn1314/p/14022343.html