【CF960G】Bandit Blues

【CF960G】Bandit Blues

题面

洛谷

题解

思路和这道题一模一样,这里仅仅阐述优化的方法。

看看答案是什么:

[Ans=C(a+b-2,a-1)centerdot s(n-1,a+b-2) ]

组合数我们已经可以(O(N))求了,主要是第一类斯特林数存在问题。

考虑它的转移:

[s(n,m)=s(n-1,m-1)+(n-1)*s(n-1,m) ]

根据这个转移,我们写出它(n)固定时的生成函数

[G(x)=prod_{i=0}^{n-1}(x+i) ]

然后每一个(s(n,m))就是升序第(m)项的次数。

为什么生成函数是这个?

引用(yyb)的:

(n)为定值时的所有的第一类斯特林数按照(n)分类分成行,发现每次的(s(n,m))转移必定要从(n−1)行转移过来,而每次转移都是(m−1)变到(m),系数为(1),因此有一项(x),同理有一项(n−1),因此就可以得到上面的那个生成函数。

然后对于这个东西我们用分治+(NTT)就可以(O(nlog^2))地做了。

代码

#include <iostream> 
#include <cstdio> 
#include <cstdlib> 
#include <cstring> 
#include <cmath> 
#include <algorithm>
#include <vector> 
using namespace std;
const int Mod = 998244353;
int fpow(int x, int y) {
	int res = 1; 
	while (y) {
		if (y & 1) res = 1ll * res * x % Mod; 
		y >>= 1; 
		x = 1ll * x * x % Mod; 
	}
	return res; 
} 
const int G = 3, iG = fpow(G, Mod - 2); 
const int MAX_N = 3e5 + 5; 
int rev[MAX_N], Limit; 
void NTT(vector<int> &p, int op) { 
	for (int i = 0; i < Limit; i++) if (i < rev[i]) swap(p[i], p[rev[i]]); 
	for (int i = 1; i < Limit; i <<= 1) { 
		int rot = fpow(op == 1 ? G : iG, (Mod - 1) / (i << 1)); 
		for (int j = 0; j < Limit; j += (i << 1)) { 
			int w = 1; 
			for (int k = 0; k < i; k++, w = 1ll * w * rot % Mod) { 
				int x = p[j + k], y = 1ll * w * p[i + j + k] % Mod; 
				p[j + k] = (x + y) % Mod, p[i + j + k] = (x - y + Mod) % Mod; 
			} 
		} 
	} 
	if (op == -1) { 
		int inv = fpow(Limit, Mod - 2); 
		for (int i = 0; i < Limit; i++) p[i] = 1ll * inv * p[i] % Mod; 
	} 
}

vector<int> mul(vector<int> &A, vector<int> &B) { 
	static vector<int> C; 
	C.clear(); 
	int p = 0, sz = A.size() + B.size() - 2; 
	for (Limit = 1; Limit <= sz + 1; Limit <<= 1, ++p); 
	for (int i = 0; i < Limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (p - 1)); 
	A.resize(Limit), B.resize(Limit);
	NTT(A, 1), NTT(B, 1); 
	for (int i = 0; i < Limit; i++) C.push_back(1ll * A[i] * B[i] % Mod); 
	NTT(C, -1); 
	return C; 
} 
vector<int> Div(int l, int r) { 
	vector<int> L, R; 
	if (l == r) return {l, 1}; 
	int mid = (l + r) >> 1; 
	L = Div(l, mid), R = Div(mid + 1, r);  
	return mul(L, R); 
}
int fac(int x) { int res = 1; for (int i = 1; i <= x; i++) res = 1ll * res * i % Mod; return res; } 
int C(int n, int m) {
	if (m > n) return 0; 
	else return 1ll * fac(n) * fpow(fac(m), Mod - 2) % Mod * fpow(fac(n - m), Mod - 2) % Mod; 
} 
vector<int> Ans; 
int main () {
#ifndef ONLINE_JUDGE 
    freopen("cpp.in", "r", stdin); 
#endif 
	int N, A, B; 
	cin >> N >> A >> B;
	if (!A || !B || A + B - 2 > N - 1) return puts("0") & 0; 
    if (N == 1) return puts("1") & 0; 
	Ans = Div(0, N - 2); 
	printf("%d
", (int)(1ll * Ans[A + B - 2] * C(A + B - 2, A - 1) % Mod)); 
	return 0; 
} 

原文地址:https://www.cnblogs.com/heyujun/p/10360065.html