LuoGuP4721:【模板】分治 FFT

Pre

式子变换需要注意一下

Solution

注意到 $f(x)g(x)+f_0=f(x)$(其中 $f_0=1$)

其实开始我没看出来,后来发现仔细分析一下就可以了。

然后式子变换

$$f(x)=\frac{f_0}{1-g(x)}$$

注意这里的 $1-g(x)$ 中,$1$ 只作用在常数项上(常数项变为 $1-g_0$,其余各项系数取相反数),因为这里的 $f(x)$、$g(x)$ 指的是函数(多项式),而不是单个系数。

Code

#include <cstdio>
#include <queue>
#include <cstring>
#define ll long long
#define xx first
#define yy second
using namespace std;
// Exchanges the values held by the two referenced ints.
// Safe when both references name the same object (the value is left unchanged).
inline void swap (int &a, int &b) {
	const int old_b = b;
	b = a;
	a = old_b;
}
// N     : capacity of every global work array. Inv() pads its transforms to the
//         smallest power of two that is >= 2 * deg, so a request for up to
//         131072 coefficients needs 262144 (= 1 << 18) slots. The previous
//         value 250005 was exceeded as soon as deg went above 65536.
// mod   : the prime 998244353 (= 119 * 2^23 + 1) all arithmetic is done in.
// inver : modular inverse of 3 (3 * 332748118 = 998244354 = mod + 1); 3 is the
//         generator the forward transform uses, inver the inverse transform.
const int N = (1 << 18) + 5, mod = 998244353, inver = 332748118;
// nn : number of coefficients to output.
// g  : input polynomial, turned into 1 - g(x) by main().
// f  : output polynomial; receives the inverse of g modulo x^nn.
int nn, g[N], f[N];
// Modular addition: returns u + v with at most one subtraction of mod,
// so the result lies in [0, mod) whenever both operands already do.
inline int add (int u, int v) {
	const int sum = u + v;
	if (sum >= mod) return sum - mod;
	return sum;
}
// Modular subtraction: returns u - v, adding mod once when the difference is
// negative. The result lies in [0, mod) whenever both operands already do.
inline int mns (int u, int v) {
	const int diff = u - v;
	if (diff < 0) return diff + mod;
	return diff;
}
// Modular multiplication: the product is formed in 64-bit arithmetic so it
// cannot overflow before the reduction by mod.
inline int mul (int u, int v) {
	const long long product = 1LL * u * v;
	return product % mod;
}
// Binary exponentiation: returns u raised to the power v, reduced with mul().
// The loop consumes v one bit at a time from the least significant end and
// terminates once v reaches zero, so v is expected to be non-negative.
inline int qpow (int u, int v) {
	int result = 1;
	for (int power = u % mod; v; v >>= 1) {
		// Fold in the current square whenever the matching exponent bit is set.
		if (v & 1) result = mul (result, power);
		power = mul (power, power);
	}
	return result;
}
// c   : scratch copy of the input polynomial used inside Inv().
int c[N];
// rev : bit-reversal permutation table, rebuilt by every NTT() call.
int rev[N];
// In-place number-theoretic transform of a[0..n) modulo `mod`.
//   a    : coefficient buffer, overwritten with the transformed values
//   n    : transform length; the butterfly passes assume a power of two
//   bit  : log2(n), used to build the bit-reversal table
//   flag : false -> forward transform (root derived from generator 3)
//          true  -> inverse transform (root derived from `inver`, followed by
//                   a multiplication of every entry by n^-1)
// Side effect: rewrites the global table rev[0..n) on every call.
inline void NTT (int *a, int n, int bit, bool flag) {
	// A transform of length <= 1 is the identity. Returning here also keeps
	// the permutation below from shifting by bit - 1 == -1, which is undefined.
	if (n <= 1) return;
	for (int i = 0; i < n; ++i) {
		// rev[i] is built from rev[i >> 1]; rev[0] is read before it is written
		// and relies on the zero-initialised global (it stays 0 afterwards).
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
		// Swap only from the larger index so each pair is exchanged once.
		if (i > rev[i]) swap (a[i], a[rev[i]]);
	}
	for (int l = 2; l <= n; l <<= 1) {
		// wi is the root of unity used for butterflies of span l.
		const int wi = qpow (flag ? inver : 3, (mod - 1) / l);
		const int m = l / 2;
		for (int *k = a; k != a + n; k += l) {
			int w = 1;
			for (int i = 0; i < m; ++i) {
				const int tmp = mul (k[i + m], w);
				k[i + m] = mns (k[i], tmp);
				k[i] = add (k[i], tmp);
				w = mul (w, wi);
			}
		}
	}
	// Only the inverse transform needs the 1/n scaling, so the modular
	// inverse of n is no longer computed for forward transforms.
	if (!flag) return;
	const int inv_n = qpow (n, mod - 2);
	for (int i = 0; i < n; ++i) {
		a[i] = mul (a[i], inv_n);
	}
}
// Polynomial inverse by Newton iteration: fills b[0..deg) so that
// a(x) * b(x) = 1 modulo x^deg, with all arithmetic modulo `mod`.
//   a   : input coefficients; only a[0..deg) is read. a[0] must be non-zero
//         modulo mod, otherwise the base case has no inverse to compute.
//   b   : output buffer; must have room for the padded transform length
//         (the smallest power of two >= 2 * deg). Entries from deg up to that
//         length are left at zero.
//   deg : number of coefficients wanted; values <= 0 are a no-op.
// Uses the global scratch array c and, through NTT(), the table rev.
inline void Inv (int *a, int *b, int deg) {
	// Without this guard deg == 0 would recurse on (0 + 1) >> 1 == 0 forever.
	if (deg <= 0) return;
	if (deg == 1) {
		// Constant term: inverse via exponentiation by mod - 2.
		b[0] = qpow (a[0], mod - 2);
		return ;
	}
	const int half = (deg + 1) >> 1;
	Inv (a, b, half);
	int n = 1, bit = 0;
	while (n < (deg << 1)) n <<= 1, ++bit;
	for (int i = 0; i < deg; ++i) c[i] = a[i];
	for (int i = deg; i < n; ++i) c[i] = 0;
	// The recursive call only defines b[0..half). Clear the rest explicitly so
	// the result does not depend on whatever the caller's buffer held before.
	for (int i = half; i < n; ++i) b[i] = 0;
	NTT (c, n, bit, false);
	NTT (b, n, bit, false);
	// Newton step evaluated pointwise: b <- 2b - a * b^2.
	for (int i = 0; i < n; ++i) b[i] = mns (mul (2, b[i]), mul (c[i], mul (b[i], b[i])));
	NTT (b, n, bit, true);
	// Truncate to deg coefficients.
	for (int i = deg; i < n; ++i) b[i] = 0;
}
int main () {
	#ifdef chitongz
	freopen ("x.in", "r", stdin);
	#endif
	scanf ("%d", &nn);
	for (int i = 1; i <= nn - 1; ++i) scanf ("%d", &g[i]), g[i] = mns (mod, g[i]);
	g[0] = add (g[0], 1);
	Inv (g, f, nn);
	for (int i = 0; i < nn; ++i) printf ("%d ", f[i]);
	puts ("");
	return 0;
}

Conclusion

注意一下什么时候是系数减法,什么时候是常数减法。

原文地址:https://www.cnblogs.com/ChiTongZ/p/11351252.html