[arc102E]Stop. Otherwise...[容斥+二项式定理]

题意

给你 (n) 个完全相同骰子,每个骰子有 (k) 个面,分别标有 (1)(k) 的所有整数。对于([2,2k]) 中的每一个数 (x) 求出有多少种方案满足任意两个骰子的和都不为 (x) 的方案数。

分析

  • 对于每个 (x) ,考虑当 (ile x) 时, (i)(x-i) 只能出现一个。将他们看成同一种权值,数量记为 (w) ,剩余权值数量记位 (cnt) ,然后枚举有多少种特殊权值没出现 ((ans)) 并容斥:

[ans_i=2^{w-i}sumlimits_{j=i}^w(-1)^{j-i}inom{n+cnt-j-1}{cnt-j-1}inom{w}{j}inom{j}{i} ]

这样可以 (O(n^3)) 求解。

  • 考虑枚举 (ans) 的过程中和 (j) 这一项有关的内容:

    [egin{aligned}val_j&=sum_limits{i=0}^j(-1)^{j-i}inom{n+cnt-j-1}{cnt-j-1}inom{w}{j}inom{j}{i}2^{w-i}\&=(-1)^jinom{w}{j}inom{n+cnt-j-1}{cnt-j-1}2^wsum_{i=0}^jinom{j}{i}(-1)^{i}2^{-i}\&=(-1)^jinom{w}{j}inom{n+cnt-j-1}{cnt-j-1}2^{w-j}sum_{i=0}^jinom{j}{i}(-1)^{i}2^{j-i}\&=(-1)^jinom{w}{j}inom{n+cnt-j-1}{cnt-j-1}2^{w-j}(2-1)^jend{aligned} ]

    可以 (O(1)) 求一个 (val) ,于是复杂度优化到了 (O(n^2))

  • 注意当 (x) 为偶数时候单独讨论 (frac{x}{2}) 这个权值。

代码

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
#define go(u) for(int i = head[u], v = e[i].to; i; i=e[i].lst, v=e[i].to)
#define rep(i, a, b) for(int i = a; i <= b; ++i)
#define pb push_back
#define re(x) memset(x, 0, sizeof x)
inline int gi() {
    int x = 0,f = 1;
    char ch = getchar();
    while(!isdigit(ch)) { if(ch == '-') f = -1; ch = getchar();}
    while(isdigit(ch)) { x = (x << 3) + (x << 1) + ch - 48; ch = getchar();}
    return x * f;
}
template <typename T> inline void Max(T &a, T b){if(a < b) a = b;}
template <typename T> inline void Min(T &a, T b){if(a > b) a = b;}
const int N = 4007, mod = 998244353;
int n, K, ans;
int fac[N], invfac[N], inv[N], bin[N], suf0[N], suf1[N];
int C(int n, int m) {
	if(n < m) return 0;
	return 1ll * fac[n] * invfac[m] % mod * invfac[n - m] % mod;
}
void add(int &a, int b) {
	a += b;if(a >= mod) a -= mod;
}
void solve(int n, int cnt, int w) {
	for(int i = 0; i <= w; ++i)
		add(ans, 1ll * (i & 1 ? mod - 1: 1) * C(n + cnt - i - 1, cnt - i - 1) % mod * C(w, i)% mod * bin[w - i] % mod);
}
int main() {
	K = gi(), n = gi();
	inv[1] = fac[0] = invfac[0] = 1, bin[0] = 1;
	rep(i, 1, 4000) {
		if(i ^ 1) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
		fac[i] = 1ll * fac[i - 1] * i % mod;
		invfac[i] = 1ll * invfac[i - 1] * inv[i] % mod;
		bin[i] = 1ll * bin[i - 1] * 2 % mod;
	}
	rep(k, 2, 2 * K) {
		ans = 0;
		int w = min(k / 2, K - (k - 1) / 2), cnt = K - w;
		if(k % 2 == 0 && K >= k / 2) {
			solve(n, cnt, w - 1);
			solve(n - 1, cnt, w - 1);
		}
		else solve(n, cnt, w);
		printf("%d
", ans);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/yqgAKIOI/p/10204393.html