[HEOI2016/TJOI2016]求和(第二类斯特林数+NTT)

Address

LuoguP4091

Solution

  • \[ans=\sum_{i=0}^{n}\sum_{j=0}^{i}S(i,j)*2^j*(j!) \]

  • 因为 \(j>i\) 时 \(S(i,j)=0\),所以 \(j\) 的上界可以从 \(i\) 放宽到 \(n\):

\[ans=\sum_{i=0}^{n}\sum_{j=0}^{n}S(i,j)*2^j*(j!) \]

  • 众所周知 :

\[S(i,j)=\frac{1}{j!}\sum_{k=0}^{j}(-1)^k*(j-k)^i*C_j^k \]

因此:

\[ans=\sum_{i=0}^{n}\sum_{j=0}^{n}\sum_{k=0}^{j}(-1)^k*(j-k)^i*C_j^k*2^j \]

  • 发现 \(2^j\) 只包含了变量 \(j\),所以把它提到前面:

\[ans=\sum_{j=0}^{n}2^j*\sum_{i=0}^{n}\sum_{k=0}^{j}(-1)^k*(j-k)^i*C_j^k \]

  • 然后把 \(C_j^k\) 拆成阶乘形式,再整理得:

\[ans=\sum_{j=0}^{n}2^j*(j!)*\sum_{k=0}^{j}\frac{(-1)^k}{k!}*\frac{\sum_{i=0}^{n}(j-k)^i}{(j-k)!} \]

  • 于是令 \(f(i)=\frac{(-1)^i}{i!},g(j)=\frac{\sum_{i=0}^nj^i}{j!}\)
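  • 当 \(j\ge 2\) 时,\(g(j)\) 可以用等比数列求和公式变成:

\[g(j)=\frac{j^{n+1}-1}{j!(j-1)} \]

  • 当 \(j=1\) 时上式分母为 \(0\),需单独计算:\(g(1)=n+1\);另外 \(g(0)=\frac{0^0}{0!}=1\)(约定 \(0^0=1\))。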

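  • 那么用 \(NTT\) 把 \(f\) 和 \(g\) 卷积起来,得到 \(h(j)=\sum_{k=0}^{j}f(k)g(j-k)\),答案即为 \(ans=\sum_{j=0}^{n}2^j*(j!)*h(j)\)。注意 \(NTT\) 的长度必须严格大于 \(2n\),否则循环卷积中下标为 \(2n\) 的项会回绕到 \(h(0)\) 上。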

Code

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

// e: capacity of every coefficient array; it must be at least the NTT length
//    lim (the smallest power of two strictly greater than 2n).
//    NOTE(review): assumes the input n is small enough that lim <= e.
// mod: NTT-friendly prime 998244353 = 119 * 2^23 + 1.
const int e = 1e6 + 5, mod = 998244353;
// a, b: coefficients of f and g, transformed in place by fft();
// lim:  transform length (power of two), rev: bit-reversal permutation for lim;
// n:    the input, ans: the answer accumulated modulo mod.
int a[e], lim = 1, rev[e], b[e], n, ans;

// Binary exponentiation: returns x^y reduced modulo `mod` (1 when y == 0).
// y is expected to be non-negative.
inline int ksm(int x, int y)
{
	long long acc = 1, base = x;
	// Consume the exponent one bit at a time, squaring the base after each bit.
	for (; y; y >>= 1, base = base * base % mod)
		if (y & 1)
			acc = acc * base % mod;
	return static_cast<int>(acc);
}

// In-place iterative number-theoretic transform of a[0..n) modulo `mod`.
// n is expected to be a power of two, and rev[] must already hold the
// bit-reversal permutation for length n.
// op == 1 runs the forward transform (root 3). Any other op runs the inverse
// transform using 998244354 / 3 == 332748118, the inverse of 3 mod 998244353.
// The inverse result is NOT divided by n; the caller must scale it.
inline void fft(int n, int *a, int op)
{
	const int root = (op == 1) ? 3 : 998244354 / 3;

	// Reorder the input into bit-reversed order.
	for (int idx = 0; idx < n; ++idx)
		if (idx < rev[idx])
			swap(a[idx], a[rev[idx]]);

	// Merge butterflies: sub-transforms of length `half` combine into length `width`.
	for (int half = 1; half < n; half <<= 1)
	{
		const int width = half << 1;
		const int step = ksm(root, (mod - 1) / width);  // principal width-th root
		for (int base = 0; base < n; base += width)
		{
			int *lo = a + base;
			int *hi = a + base + half;
			int tw = 1;  // current twiddle factor step^t
			for (int t = 0; t < half; ++t)
			{
				const int u = lo[t];
				const int v = 1ll * tw * hi[t] % mod;
				lo[t] = (u + v) % mod;
				hi[t] = (u - v + mod) % mod;
				tw = 1ll * tw * step % mod;
			}
		}
	}
}

// Reads n and prints, modulo 998244353,
//   sum_{i=0}^{n} sum_{j=0}^{i} S(i,j) * 2^j * j!
// via ans = sum_{j=0}^{n} 2^j * j! * sum_{k=0}^{j} f(k) * g(j-k), where
//   f(k) = (-1)^k / k!   and   g(m) = (sum_{t=0}^{n} m^t) / m!.
// The convolution f*g is computed with an NTT of length lim.
int main()
{
	cin >> n;
	int i, k = 0, fac = 1;
	// f*g has degree 2n, so the cyclic transform must be strictly longer than 2n.
	// With lim == 2n (n a power of two) the degree-2n term would wrap onto
	// index 0 and corrupt the result (e.g. n = 1 would print 1 instead of 3).
	while (lim <= n * 2)
	{
		lim <<= 1;
		k++;
	}
	// Bit-reversal permutation for a transform of length lim = 2^k.
	for (i = 1; i < lim; i++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));

	for (i = 0; i <= n; i++)
	{
		if (i != 0) fac = 1ll * fac * i % mod;
		const int inv_fac = ksm(fac, mod - 2);  // 1 / i!

		// f(i) = (-1)^i / i!
		a[i] = (i & 1) ? mod - inv_fac : inv_fac;

		// geo = sum_{t=0}^{n} i^t by the geometric-series formula; the two
		// inputs where that formula does not apply are handled explicitly.
		int geo;
		if (i == 0) geo = 1;           // only t = 0 contributes (0^0 = 1)
		else if (i == 1) geo = n + 1;  // ratio 1: n + 1 equal terms
		else
			geo = 1ll * (ksm(i, n + 1) + mod - 1) % mod * ksm(i - 1, mod - 2) % mod;

		// g(i) = geo / i!
		b[i] = 1ll * geo * inv_fac % mod;
	}

	fft(lim, a, 1);
	fft(lim, b, 1);
	for (i = 0; i < lim; i++) a[i] = 1ll * a[i] * b[i] % mod;
	fft(lim, a, -1);
	// The inverse transform is unscaled: divide by lim (only indices 0..n are used).
	const int inv_lim = ksm(lim, mod - 2);
	for (i = 0; i <= n; i++) a[i] = 1ll * a[i] * inv_lim % mod;

	// ans = sum_{j=0}^{n} 2^j * j! * (f*g)[j]
	int pow2 = 1;
	fac = 1;
	for (i = 0; i <= n; i++)
	{
		if (i != 0) fac = 1ll * fac * i % mod;
		ans = (ans + 1ll * a[i] * fac % mod * pow2) % mod;
		pow2 = 2ll * pow2 % mod;
	}
	cout << ans << '\n';
	return 0;
}
原文地址:https://www.cnblogs.com/cyf32768/p/12196181.html