loj6402. yww 与校门外的树

题意

略。(语文太差)

题解

首先一个结论:随机一个([0, 1])之间的实数序列,只用到各个位置相互之间的大小关系,每种关系出现的概率等同于随机一个([1, n])的排列。
原因是出现某些位置值相同的概率是无穷小,并且只有有限种情况出现相同的值,因此可以忽略。
然后在草稿纸上画一画,就知道如果(p_1, p_2, ldots, p_x in [n - x + 1, n]),那么(1, 2, ldots, x)(x + 1, x + 2, ldots, n)之间必然没有边。
考虑dp。但是如果直接dp答案会有点难,所以设(f_n)表示长度为(n)的序列,不存在一个(x),使得(1, 2, ldots, x)(x + 1, x + 2, ldots, n)没有边的不同序列数,并规定(f_0 = 1)
考虑补集转化,有

[f_n = n! - sum_{i = 1} ^ {n - 1} f_{n - i} i! \ 2n! = sum_{i = 0} ^ {n} f_{n - i} i! + [n = 0] \ ]

考虑其生成函数

[F(x) = sum_{i geq 0} f_i x ^ i ]

与一个辅助生成函数

[G(x) = sum_{i geq 0} i! x ^ i ]

[2G(x) = G(x)F(x) + 1 \ F(x) = frac{2G(x) - 1}{G(x)} \ ]

这可以用多项式求逆求出。
考虑答案。设(H(x) = sum_{i geq 0} i f_i x ^ i)(此时(f_i)已求出),则有

[ans = [x ^ n] sum_{k geq 0} H(X) ^ k = [x ^ n] frac{1}{1 - H(x)} ]

再做一次多项式求逆即可。
复杂度(mathcal O(n log n))

#pragma GCC optimize(2)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector <int> poly;
const int N = 1 << 20, mod = 998244353, G = 3;
const ll infty = 1ll * mod * mod;
int power (int a, int b) {
	int ret = 1;
	if (b < 0) {
		b += mod - 1;
	}
	for ( ; b; b >>= 1, a = 1ll * a * a % mod) {
		if (b & 1) {
			ret = 1ll * ret * a % mod;
		}
	}
	return ret;
}
namespace {
	void printp (poly a) {
		if (!a.size()) {
			return;
		}
		printf("%d", a[0]);
		for (int i = 1; i < (int)a.size(); ++i) {
			printf(" %d", a[i]);
		}
		putchar('
');
	}
	int adjust (int n) {
		int ret = 1;
		for ( ; ret < n; ret <<= 1);
		return ret;
	}
	poly trans (int n, int *a) {
		int m = adjust(n);
		poly ret; ret.resize(m, 0);
		for (int i = 0; i < n; ++i) {
			ret[i] = a[i];
		}
		return ret;
	}
	void dnt (int n, poly &_a) {
		static int rev[N], a[N], wi[N];
		for (int i = 0; i < n; ++i) {
			rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? n >> 1 : 0);
			a[i] = _a[rev[i]];
		}
		for (int l = 2, _w; l <= n; l <<= 1) {
			_w = power(G, (mod - 1) / l), wi[l >> 1] = 1;
			for (int i = (l >> 1) + 1; i < l; ++i) {
				wi[i] = 1ll * wi[i - 1] * _w % mod;
			}
		}
		for (int l = 2, m = 1; l <= n; m = l, l <<= 1) {
			for (int i = 0; i < n; i += l) {
				int *u = a + i, *v = a + i + m, *w = wi + m;
				for (int j = 0, x, y; j < m; ++u, ++v, ++w, ++j) {
					x = *u, y = 1ll * (*v) * (*w) % mod;
					*u = (x + y) % mod, *v = (x - y + mod) % mod;
				}
			}
		}
		for (int i = 0; i < n; ++i) {
			_a[i] = a[i];
		}
	}
	void idnt (int n, poly &_a) {
		static int rev[N], a[N], wi[N];
		for (int i = 0; i < n; ++i) {
			rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? n >> 1 : 0);
			a[i] = _a[rev[i]];
		}
		for (int l = 2, _w; l <= n; l <<= 1) {
			_w = power(G, mod - 1 - (mod - 1) / l), wi[l >> 1] = 1;
			for (int i = (l >> 1) + 1; i < l; ++i) {
				wi[i] = 1ll * wi[i - 1] * _w % mod;
			}
		}
		for (int l = 2, m = 1; l <= n; m = l, l <<= 1) {
			for (int i = 0; i < n; i += l) {
				int *u = a + i, *v = a + i + m, *w = wi + m;
				for (int j = 0, x, y; j < m; ++u, ++v, ++w, ++j) {
					x = *u, y = 1ll * (*v) * (*w) % mod;
					*u = (x + y) % mod, *v = (x - y + mod) % mod;
				}
			}
		}
		int invn = power(n, mod - 2);
		for (int i = 0; i < n; ++i) {
			_a[i] = 1ll * a[i] * invn % mod;
		}
	}
	poly conv (int n, poly a, poly b, int f = 0) {
		a.resize(n, 0), b.resize(n, 0);
		n <<= 1, a.resize(n, 0), b.resize(n, 0);
		dnt(n, a), dnt(n, b);
		for (int i = 0; i < n; ++i) {
			a[i] = 1ll * a[i] * b[i] % mod;
			if (f) {
				a[i] = 1ll * a[i] * b[i] % mod;
			}
		}
		idnt(n, a);
		return a;
	}
	poly plu (int n, poly a, poly b) {
		for (int i = 0; i < n; ++i) {
			if ((a[i] += b[i]) >= mod) {
				a[i] -= mod;
			}
		}
		return a;
	}
	poly minus (int n, poly a, poly b) {
		for (int i = 0; i < n; ++i) {
			if ((a[i] -= b[i]) < 0) {
				a[i] += mod;
			}
		}
		return a;
	}
	poly inv (int n, poly f) {
		if (n == 1) {
			return (poly) {power(f[0], mod - 2)};
		}
		f.resize(n);
		poly g = inv(n >> 1, f);
		g.resize(n, 0);
		f = conv(n, f, g, 1);
		g = minus(n, plu(n, g, g), f);
		return g.resize(n), g;
	}
}
int fac[N], c[N];
poly A, B, C;
int solve (int n) {
	fac[0] = 1;
	for (int i = 1; i <= n; ++i) {
		fac[i] = 1ll * fac[i - 1] * i % mod;
	}
	A = trans(n, fac);
	B = plu(A.size(), A, A), B[0] = (B[0] - 1 + mod) % mod;
	A = inv(A.size(), A);
	B = conv(B.size(), B, A);
	for (int i = 0; i < n; ++i) {
		c[i] = (mod - 1ll * i * B[i] % mod) % mod;
	}
	c[0] = (c[0] + 1) % mod;
	C = trans(n, c);
	C = inv(C.size(), C);
	return C[n - 1];
}
int n;
int main () {
	cin >> n;
	cout << solve(n + 1) << endl;
	return 0;
}
原文地址:https://www.cnblogs.com/psimonw/p/11637430.html