[LOJ#3315]「ZJOI2020」抽卡

生成函数神题 QAQ,orz EI

我的多项式水平是外国人水平

题目链接

题目传送门

简要算法

概率与期望、容斥、生成函数、拉格朗日反演、牛顿迭代

(O(m^2))

(O(m^2)) 做法有很多,如 min-max 容斥,下面介绍一种看上去比较有优化空间的做法。

对于 (Ssubseteq{1,2,dots,m}),定义 (end_S=0/1) 表示 (S) 是否存在 (k) 个连续的数。

考虑每一轮的贡献,第 (x) 轮的贡献就是前 (x) 轮操作之后不会到达终态的概率:

[ans=sum_{end_S=0}sum_{xge 0}P(x轮之后选过的数组成的集合恰好为S) ]

对于 (x) 轮之后选过的数组成的集合恰好(S) 的概率,考虑容斥计算:

[P(x轮之后选过的数组成的集合恰好为S)=sum_{Tsubseteq S}(-1)^{|S|-|T|}(frac{sum_{iin T}is_i}m)^x ]

其中 (is_i) 表示是否存在编号为 (i) 的卡。

[ans=sum_{end_S=0}sum_{Tsubseteq S}(-1)^{|S|-|T|}sum_{xge 0}(frac{sum_{iin T}is_i}m)^x=sum_{end_S=0}sum_{Tsubseteq S}(-1)^{|S|-|T|}frac m{m-sum_{iin T}is_i} ]

(w_i(x)=x^{is_i}-1)(G(x)=sum_{end_S=0}prod_{iin S}w_i(x)),则答案为 (sum_{i=0}^{m-1}frac m{m-i}[x^i]G(x))

由于 (is_i=0)(w_i(x)=0)(is_i=1)(w_i(x)=x-1),故只要先 DP (f_{i,j}) 表示前 (i) 种编号选出 (j) 个均为 (is=1) 的方案数(转移可以容斥掉最后一段长为 (k) 的方案),则 (G(x)=sum_{ige 0}f_{max,i}(x-1)^i),可以直接计算。

Solution by EI

对于 (G(x)=sum_{ige 0}f_{max,i}(x-1)^i) 的每一项,注意到 ([x^i]G(x)=sum_{jge i}(-1)^{j-i}inom jif_{max,j}),可以一次卷积求出。

对于上面的 DP,实际上可以把输入的 (a) 数组排序之后分成一些值域连续段,求出每个连续段((is) 全为 (1))中选出 (0,1,dots) 个元素的方案数,最后用一次分治 NTT (O(mlog^2m)) 求出。

现在要解决的问题就是给定 (n),如何对于每个 (i=0,1,dots,n) 计算出在 (n) 个元素中选出 (i) 个使得没有任意连续的 (k) 个元素被选出的方案数。

可以转化成对于每个 (i=0,1,dots,n) 计算出把 (n+1) 拆分成 (n+1-i) 个不超过 (k) 的正整数之和的方案数。转化方法即为增加一个不能选的元素 (n+1),以所有不选的元素为右端点,把该元素左边有被选上的一段元素并起来作为一段。

也就是对于任意 (1le mle n+1) 求出:

[[x^{n+1}](sum_{i=1}^kx_i)^m ]

也就是求二元生成函数:

[frac1{1-u(sum_{i=1}^kx_i)} ]

(x^{n+1}) 次项。

对于只能求某一项的问题我们通常考虑拉格朗日反演,设 (G(x)=sum_{i=1}^kx_i)(G(x)) 的复合逆为 (G^{-1}(x)),我们有:

[[x^{n+1}]frac1{1-u(sum_{i=1}^kx_i)}=frac1{n+1}[x^n]frac u{(1-ux)^2}(frac x{G^{-1}(x)})^{n+1} ]

由于 (frac u{(1-ux)^2})(u) 的次数总是比 (x) 的次数多 (1),故如果求出了 ((frac x{G^{-1}(x)})^{n+1}),就能枚举 (frac u{(1-ux)^2})(x) 的次数计算这两个式子积的第 (n) 项了。

现在要求的就是 (F(x)=G^{-1}(x))。由于 (G(x)=sum_{i=1}^kx_i=frac{x-x^{k+1}}{1-x}),故我们有:

[frac{F(x)-F^{k+1}(x)}{1-F(x)}=x ]

[(1+x)F(x)-F^{k+1}(x)-x=0 ]

可以牛顿迭代。

总复杂度 (O(mlog^2m)),瓶颈在分治 NTT,但牛顿迭代部分的常数还不止一个 log

Code

#include <bits/stdc++.h>

template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}

const int N = 3e6 + 5, djq = 998244353;

int n, m, rev[N], yg[N], a[N], b[N], ff, tot, t1[N], t2[N], t3[N], t4[N], inv[N], f[N],
t5[N], t6[N], t7[N], cnt[N], len, fac[N], invf[N], ans;
std::vector<int> A[N];

int qpow(int a, int b)
{
	int res = 1;
	while (b)
	{
		if (b & 1) res = 1ll * res * a % djq;
		a = 1ll * a * a % djq;
		b >>= 1;
	}
	return res;
}

inline void add(int &a, const int &b) {if ((a += b) >= djq) a -= djq;}

inline void sub(int &a, const int &b) {if ((a -= b) < 0) a += djq;}

void FFT(int n, int *a, int op)
{
	for (int i = 0; i < n; i++) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
	yg[n] = qpow(1312005, (djq - 1) / n * ((n + op) % n));
	for (int i = n >> 1; i; i >>= 1)
		yg[i] = 1ll * yg[i << 1] * yg[i << 1] % djq;
	for (int k = 1; k < n; k <<= 1)
	{
		int x = yg[k << 1];
		for (int i = 0; i < n; i += k << 1)
		{
			int w = 1;
			for (int j = 0, *f1 = a + i, *f2 = a + i + k; j < k; j++, f1++, f2++)
			{
				int u = *f1, v = 1ll * w * (*f2) % djq;
				add(*f1 = u, v); sub(*f2 = u, v);
				w = 1ll * w * x % djq;
			}
		}
	}
	if (op == -1)
	{
		int gg = qpow(n, djq - 2);
		for (int i = 0; i < n; i++) a[i] = 1ll * a[i] * gg % djq;
	}
}

void nealchen(int n)
{
	ff = 1; tot = 0;
	while (ff < n) ff <<= 1, tot++;
	for (int i = 0; i < ff; i++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << tot - 1);
}

void getinv(int n, int *a, int *b)
{
	b[0] = 1;
	for (int k = 1; k <= n; k <<= 1)
	{
		nealchen(k << 2);
		for (int i = k; i < ff; i++) b[i] = 0;
		for (int i = 0; i < ff; i++) t1[i] = i <= n && i < (k << 1) ? a[i] : 0;
		FFT(ff, b, 1); FFT(ff, t1, 1);
		for (int i = 0; i < ff; i++) b[i] = (2ll - 1ll * t1[i] * b[i] % djq
			+ djq) * b[i] % djq;
		FFT(ff, b, -1);
	}
}

void getln(int n, int *a, int *b)
{
	getinv(n, a, t2); b[n] = 0; nealchen(n << 1 | 1);
	for (int i = 1; i <= n; i++) b[i - 1] = 1ll * i * a[i] % djq;
	for (int i = n + 1; i < ff; i++) t2[i] = b[i] = 0;
	FFT(ff, b, 1); FFT(ff, t2, 1);
	for (int i = 0; i < ff; i++) b[i] = 1ll * b[i] * t2[i] % djq;
	FFT(ff, b, -1);
	for (int i = n; i >= 1; i--) b[i] = 1ll * b[i - 1] * inv[i] % djq;
	b[0] = 0;
}

void getexp(int n, int *a, int *b)
{
	b[0] = 1;
	for (int k = 1; k <= n; k <<= 1)
	{
		for (int i = k; i < (k << 2); i++) b[i] = 0;
		getln((k << 1) - 1, b, t3); nealchen(k << 2);
		for (int i = 0; i < ff; i++)
		{
			if (i >= (k << 1)) {t2[i] = 0; continue;}
			t2[i] = i <= n ? a[i] : 0; sub(t2[i], t3[i]); if (!i) add(t2[i], 1);
		}
		FFT(ff, b, 1); FFT(ff, t2, 1);
		for (int i = 0; i < ff; i++) b[i] = 1ll * b[i] * t2[i] % djq;
		FFT(ff, b, -1);
	}
}

void getpow(int n, int k, int *a, int *b)
{
	getln(n, a, t4);
	for (int i = 0; i <= n; i++) t4[i] = 1ll * t4[i] * k % djq;
	getexp(n, t4, b);
}

void calc(int n, int *a)
{
	for (int i = 0; i <= n; i++) t5[i] = f[i];
	getpow(n, n + 1, t5, t6);
	for (int i = 0; i <= n; i++)
		a[n - i] = 1ll * inv[n + 1] * (i + 1) % djq * t6[n - i] % djq;
}

std::vector<int> polymul(std::vector<int> a, std::vector<int> b)
{
	int n = a.size(), m = b.size(); nealchen(n + m - 1);
	for (int i = 0; i < ff; i++) t1[i] = i < n ? a[i] : 0, t2[i] = i < m ? b[i] : 0;
	FFT(ff, t1, 1); FFT(ff, t2, 1);
	for (int i = 0; i < ff; i++) t1[i] = 1ll * t1[i] * t2[i] % djq;
	FFT(ff, t1, -1); std::vector<int> res;
	for (int i = 0; i < n + m - 1; i++) res.push_back(t1[i]);
	return res;
}

std::vector<int> nealchen2003(int l, int r)
{
	if (l == r) return A[l];
	int mid = l + r >> 1;
	return polymul(nealchen2003(l, mid), nealchen2003(mid + 1, r));
}

int main()
{
	read(n); read(m); inv[1] = f[0] = fac[0] = invf[0] = 1;
	for (int i = 2; i <= n + 1; i++)
		inv[i] = 1ll * (djq - djq / i) * inv[djq % i] % djq;
	for (int k = 1; k <= n; k <<= 1)
	{
		getpow((k << 1) - 1, m, f, t5); nealchen(k << 2);
		for (int i = k << 1; i < ff; i++) t5[i] = 0;
		for (int i = 0; i < ff; i++) t6[i] = f[i], t7[i] = t5[i];
		FFT(ff, t6, 1); FFT(ff, t7, 1);
		for (int i = 0; i < ff; i++) t6[i] = 1ll * t6[i] * t7[i] % djq;
		FFT(ff, t6, -1);
		for (int i = k << 1; i < ff; i++) t6[i] = 0;
		for (int i = (k << 1) - 1; i >= 0; i--)
			t6[i] = i >= m ? (djq - t6[i - m]) % djq : 0,
			t5[i] = i >= m ? (1ll * djq * djq - 1ll * (m + 1) * t5[i - m]) % djq : 0;
		add(t5[1], 1); add(t5[0], 1); sub(t6[0], 1);
		for (int i = 0; i < k; i++) add(t6[i + 1], f[i]), add(t6[i], f[i]);
		getinv((k << 1) - 1, t5, t7); nealchen(k << 2);
		for (int i = k << 1; i < ff; i++) t6[i] = t7[i] = 0;
		FFT(ff, t6, 1); FFT(ff, t7, 1);
		for (int i = 0; i < ff; i++) t6[i] = 1ll * t6[i] * t7[i] % djq;
		FFT(ff, t6, -1);
		for (int i = 0; i < (k << 1); i++) sub(f[i], t6[i]);
	}
	getinv(n, f, t5); for (int i = 0; i <= n; i++) f[i] = t5[i];
	for (int i = 1; i <= n; i++) read(a[i]); std::sort(a + 1, a + n + 1);
	for (int i = 1; i <= n; i++)
	{
		if (i == 1 || a[i] > a[i - 1] + 1) len++;
		cnt[len]++;
	}
	for (int i = 1; i <= len; i++)
	{
		calc(cnt[i], t7);
		for (int j = 0; j <= cnt[i]; j++) A[i].push_back(t7[j]);
	}
	std::vector<int> nc = nealchen2003(1, len);
	for (int i = 0; i <= n; i++) t1[i] = nc[i];
	for (int i = 1; i <= n; i++) fac[i] = 1ll * fac[i - 1] * i % djq,
		invf[i] = 1ll * invf[i - 1] * inv[i] % djq;
	for (int i = 0; i <= n; i++)
	{
		t1[i] = 1ll * t1[i] * fac[i] % djq;
		if (t2[n - i] = invf[i], i & 1) t2[n - i] = djq - t2[n - i];
	}
	nealchen(n << 1 | 1);
	for (int i = n + 1; i < ff; i++) t1[i] = t2[i] = 0;
	FFT(ff, t1, 1); FFT(ff, t2, 1);
	for (int i = 0; i < ff; i++) t1[i] = 1ll * t1[i] * t2[i] % djq;
	FFT(ff, t1, -1);
	for (int i = 0; i < n; i++) ans = (1ll * invf[i] * t1[n + i]
		% djq * n % djq * inv[n - i] + ans) % djq;
	return std::cout << ans << std::endl, 0;
}
原文地址:https://www.cnblogs.com/xyz32768/p/13295845.html