储能表[SDOI2016]

【题目描述】
有一个 (n)(m) 列的表格,行从 (0)(n-1) 编号,列从 (0)(m-1) 编号。
每个格子都储存着能量。最初,第 (i) 行第 (j) 列的格子储存着 ((i ext{xor} j)) 点能量。所以,整个表格储存的总能量是:

(sumlimits_{i=0}^{n-1}sumlimits_{j=0}^{m-1}(i ext{xor} j))

随着时间的推移,格子中的能量会渐渐减少。一个时间单位后,每个格子中的能量都会减少 (1)。显然,一个格子的能量减少到 (0) 之后就不会再减少了。
也就是说,(k) 个时间单位后,整个表格储存的总能量是:

(sumlimits_{i=0}^{n-1}sumlimits_{j=0}^{m-1}max((i xor j)-k,0))

给出一个表格,求 (k) 个时间单位后它储存的总能量。
由于总能量可能较大,输出时对 (p) 取模。

题解

先把原式拆开 原式=所有满足(i^j)>k 的(i^j)之和 - 所有满足(i^j)>k 的数对(i,j)的数量 * k

可以使用数位DP来求解有多少对(i,j)满足(i^j)>k 以及 它们的(i^j)之和是多少

具体做法就是记忆化搜索 记录(limn, limm, limk)分别表示当前枚举二进制前(i)位,是否卡满(n,m)上界以及(k)下界

这个东西我也不知道怎么讲清楚 不是数位DP的基本套路吗

我用了个pair来存DP答案 first存的是有多少对满足条件 second存的是满足条件的(i^j)之和

注意取模

【代码】

#include <bits/stdc++.h>
#define mp make_pair
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pii;

inline ll read() {
	ll x = 0, f = 1; char ch = getchar();
	for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') f = -1;
	for (; ch <= '9' && ch >= '0'; ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ '0');
	return x * f;
}

ll t, n, m, k, mod, mx;
pii dp[105][2][2][2], ans;
bool vis[105][2][2][2];

inline ll getl(ll x) {
	ll ret = 0;
	while (x) ret++, x >>= 1;
	return ret;
}

pii dfs(ll d, bool limn, bool limm, bool limk) {
	if (d > mx) return mp(1, 0);
	if (vis[d][limn][limm][limk]) return dp[d][limn][limm][limk];
	ll N = (n>>(mx-d)) & 1, M = (m>>(mx-d)) & 1, K = (k>>(mx-d)) & 1;
	for (ll i = 0; i <= (limn ? N : 1); i++) {
		for (ll j = 0; j <= (limm ? M : 1); j++) {
			if (limk && (i^j) < K) continue;
			pii res = dfs(d+1, limn&&(i==N), limm&&(j==M), limk&&((i^j)==K));
			dp[d][limn][limm][limk].first = (dp[d][limn][limm][limk].first + res.first) % mod;
			dp[d][limn][limm][limk].second = (((dp[d][limn][limm][limk].second + res.second) % mod) + (1ll << (mx - d)) * (i^j) % mod * res.first % mod) % mod;
		}
	}
	vis[d][limn][limm][limk] = true;
	return dp[d][limn][limm][limk];
}

int main() {
	t = read(); 
	while (t--) {
		n = read(); m = read(); k = read(); mod = read(); 
		n--, m--;
		memset(dp, 0, sizeof(dp)); memset(vis, 0, sizeof(vis));
		mx = max(max(getl(n), getl(m)), getl(k));
		ans = dfs(1, 1, 1, 1);
		printf("%lld
", (ans.second - k % mod * ans.first % mod + mod) % mod);
	}
	return 0;
} 
原文地址:https://www.cnblogs.com/ak-dream/p/AK_DREAM36.html