【NOI2017】泳池

一道很棒的神仙题!

(calc(k)) 表示区域大小不超过 (k) 的概率,那么答案就是 (calc(k) - calc(k - 1))

有一些列的第一个位置就是障碍,它们会把泳池划分为不相关的几个部分。可以从这个角度进行 DP

(f_n) 表示,连续 (n) 列,最后一列的第一个位置是障碍,且最大合法矩形面积不超过 (k) 的概率。

(g_n) 表示,连续 (n) 列,每一列的第一个位置都不是障碍,且最大合法矩形面积不超过 (k) 的概率。

枚举上一个障碍点在哪,我们就得到转移:

[f_n = (1-q) sum_{i = 0} ^ {min{n - 1, k}} f_{n - i - 1} g_{i} ]

边界条件是 (f_0 = 1),我们要求的答案就是 (frac{f_{n+1}}{1-q})

现在我们要求出 (g) 数组。

(dp_{i,j}) 表示连续的 (i) 列,出现的最低障碍是 (j+1), 且最大合法矩形面积不超过 (k) 的概率,则 (g_i = sum_{j = 1} ^ {lfloor frac{k}{i} floor} dp_{i}{j})

我们可以枚举第一个出现的最低障碍进行转移,就得到:

[dp_{i,j} = (1-q) q^j sum_{k = 1} ^ i (sum_{t geq j + 1} dp_{k - 1, t}) (sum_{t geq j} dp_{i - k,j}) ]

因为 (ij leq k), 所以状态数只有 (O(klogk)), 加个后缀和优化就可以 (O(k^2logk)) 处理了。

现在的问题是,怎样求得 (f) 的第 (n) 项。

可以发现上面的式子是一个齐次线性递推,直接拖个板子就可以AC啦。

事实上是 UOJ 上被 hack 成 97 分了。

原因是当线性递推式的阶数为 (1) 时,需要对一个一次多项式取模。快速幂的时候初始多项式是一次的,如果不取模就挂了。

#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define LL long long
#define pii pair<int, int>
using namespace std;
const int N = 2005;
const int mod = 998244353;
 
template <typename T> T read(T &x) {
	int f = 0;
	register char c = getchar();
	while (c > '9' || c < '0') f |= (c == '-'), c = getchar();
	for (x = 0; c >= '0' && c <= '9'; c = getchar())
		x = (x << 3) + (x << 1) + (c ^ 48);
	if (f) x = -x;
	return x;
}

namespace Comb {
	const int Maxn = 1e6 + 10;
	
	int fac[Maxn], fav[Maxn], inv[Maxn];
	
	void comb_init() {
		fac[0] = fav[0] = 1;
		inv[1] = fac[1] = fav[1] = 1;
		for (int i = 2; i < Maxn; ++i) {
			fac[i] = 1LL * fac[i - 1] * i % mod;
			inv[i] = 1LL * -mod / i * inv[mod % i] % mod + mod;
			fav[i] = 1LL * fav[i - 1] * inv[i] % mod;
		}
	}
 
	inline int C(int x, int y) {
		if (x < y || y < 0) return 0;
		return 1LL * fac[x] * fav[y] % mod * fav[x - y] % mod;
	}

	inline int Qpow(int x, int p) {
		int ans = 1;
		for (; p; p >>= 1) {
			if (p & 1) ans = 1LL * ans * x % mod;
			x = 1LL * x * x % mod;
		}
		return ans;
	}

	inline int Inv(int x) {
		return Qpow(x, mod - 2);
	}
 
	inline void upd(int &x, int y) {
		(x += y) >= mod ? x -= mod : 0;
	}

	inline int add(int x, int y) {
		return (x += y) >= mod ? x - mod : x;
	}

	inline int dec(int x, int y) {
		return (x -= y) < 0 ? x + mod : x;
	}

}
using namespace Comb;

namespace Linear {
	int n, k;
	int a[N], h[N], p[N];
	int b[N], c[N];

	void module(int *x) {
		for (int i = k * 2; i >= k; --i) {
			int tmp = 1LL * Inv(p[k]) * x[i] % mod;
			for (int j = 0; j <= k; ++j) {
				x[i - j] = dec(x[i - j], 1LL * p[k - j] * tmp % mod);
			}
		}
	}

	void mul(int *x, int *y, int *z) {
		static int res[N];
		for (int i = 0; i <= k * 2; ++i) res[i] = 0;
		for (int i = 0; i < k; ++i) {
			for (int j = 0; j < k; ++j) {
				upd(res[i + j], 1LL * x[i] * y[j] % mod);
			}
		}
		module(res);
		for (int i = 0; i < k; ++i) z[i] = res[i];
	}
	
	void poly_pow(int p) {
		while (p) {
			if (p & 1) mul(b, c, c);
			mul(b, b, b);
			p >>= 1;
		}
	}
	
	int solve() {
		if (n <= k * 2) return h[n];
		memset(b, 0, sizeof b);
		memset(c, 0, sizeof c);
		memset(p, 0, sizeof p);
		p[k] = 1;
		for (int i = 0; i < k; ++i) p[i] = dec(0, a[k - i]);
		b[1] = 1; c[0] = 1;
		if (k == 1) module(b);
		poly_pow(n - k);
		int ans = 0;
		for (int i = 0; i < k; ++i)
			upd(ans, 1LL * h[i + k] * c[i] % mod);
		return ans;
	}
}

int n, m, x, y, q;
int f[N], g[N], G[N], sdp[N][N], dp[N][N];

int solve() {
	memset(f, 0, sizeof f);
	memset(g, 0, sizeof g);
	memset(dp, 0, sizeof dp);
	memset(sdp, 0, sizeof sdp);
	dp[0][0] = 1;
	for (int i = 1; i <= m; ++i) {
		for (int j = m / i; j >= 0; --j) {
			for (int k = 1; k <= i; ++k) {
				int s1 = 0, s2 = 0;
				if (k == 1) s1 = 1;
				else s1 = sdp[k - 1][j + 1];
				if (i == k) s2 = 1;
				else s2 = sdp[i - k][j];
				upd(dp[i][j], 1LL * s1 * s2 % mod);
			}
			dp[i][j] = 1LL * dp[i][j] * Qpow(q, j) % mod;
			dp[i][j] = 1LL * dp[i][j] * dec(1, q) % mod;
			sdp[i][j] = add(sdp[i][j + 1], dp[i][j]);
		}
	}
	g[1] = dec(1, q);
	for (int i = 2; i <= m + 1; ++i) {
		g[i] = 1LL * sdp[i - 1][1] * dec(1, q) % mod;
	}
	f[0] = 1;
	for (int i = 1; i <= m * 2 + 2; ++i) {
		for (int j = 1; j <= m + 1 && j <= i; ++j) {
			upd(f[i], 1LL * f[i - j] * g[j] % mod);
		}
 	}
	Linear :: n = n + 1;
	Linear :: k = m + 1;
	for (int i = 1; i <= m + 1; ++i) {
		Linear :: a[i] = g[i];
	}
	for (int i = 1; i <= m * 2 + 2; ++i) {
		Linear :: h[i] = f[i];
	}
	return 1LL * Linear :: solve() * Inv(dec(1, q)) % mod;
}

int main() {
	comb_init();
	read(n); read(m); read(x); read(y);
	q = 1LL * x * Inv(y) % mod;
	int ans1 = solve();
	--m;
	int ans2 = solve();
//	cout << ans1 << ' ' << ans2 << endl;
	cout << dec(ans1, ans2) << endl;
	return 0;
}
原文地址:https://www.cnblogs.com/Vexoben/p/11844230.html