多项式模板集

NTT && FFT

NTT板子

typedef long long ll;  // 64-bit type for intermediate modular products

// NTT-friendly prime 998244353 = 119 * 2^23 + 1, with primitive root g = 3.
const int P = 998244353;
const int g = 3;
// Size bound for the global work arrays (they are allocated as maxn << 2).
const int maxn = 1111111;

// Modular addition for a, b in [0, P): returns (a + b) mod P without using '%'.
int inc(int a, int b) {
	int s = a + b;
	if (s >= P) s -= P;
	return s;
}
// Binary exponentiation: returns a^b mod P.
int qpow(int a, int b) {
	int res = 1;
	while (b != 0) {
		if (b & 1) res = 1ll*res*a%P;
		a = 1ll*a*a%P;
		b >>= 1;
	}
	return res;
}
// Twiddle-factor table and modular-inverse table (4x the size bound).
int W[maxn << 2];
int inv[maxn << 2];

// For every power of two len < n (strict): W[len + j] = w^j, j in [0, len),
// where w = g^((P-1)/(2*len)) is a primitive (2*len)-th root of unity.
// Also fills inv[i] = i^{-1} mod P for 1 <= i <= n (inclusive).
// ntt() of length m reads W[1..m-1] and inv[m], so call this with n >= m.
void prework(int n) {
	for (int len = 1; len < n; len <<= 1) {
		int step = qpow(g, (P-1)/len>>1);
		int *row = W + len;
		row[0] = 1;
		for (int j = 1; j < len; j++) row[j] = 1ll*row[j-1]*step%P;
	}
	inv[1] = 1;
	// Linear-time inverses: i^{-1} = -(P / i) * (P mod i)^{-1} (mod P).
	for (int i = 2; i <= n; i++) inv[i] = 1ll*(P-P/i)*inv[P%i]%P;
}
// In-place iterative NTT of length n over Z_P (n is assumed to be a power of two; not checked).
// opt != -1 (so ~opt != 0): forward transform.
// opt == -1: inverse transform, done as a forward transform followed by reversing
// a[1..n-1] and scaling by inv[n].
// Requires prework(m) with m >= n so that W[1..n-1] and inv[n] are filled.
void ntt(int *a, int n, int opt) {
	static int rev[maxn << 2] = {0}; // zero-initialized: rev[0] is never written and must be 0
	// Bit-reversal permutation: rev[i] is derived from rev[i>>1]; swap each pair once (when rev[i] > i).
	for (int i = 1; i < n; i++)
		if ((rev[i] = rev[i>>1]>>1|(i&1?n>>1:0)) > i) std::swap(a[i], a[rev[i]]);
	// Butterflies: q is the half-length of the current block, W[q+i] its i-th twiddle factor.
	for (int q = 1; q < n; q <<= 1)
		for (int p = 0; p < n; p += q<<1)
			for (int i = 0, t; i < q; i++)
				t = 1ll*a[p+q+i]*W[q+i]%P, a[p+q+i] = inc(a[p+i], P-t), a[p+i] = inc(a[p+i], t);
	if (~opt) return;
	// Forward-transforming the spectrum puts n*a_j at index (n-j) mod n: undo the order, divide by n.
	std::reverse(a+1, a+n);
	for (int i = 0; i < n; i++) a[i] = 1ll*a[i]*inv[n]%P;
}

// Smallest power of two that is >= n (1 when n <= 1).
int getsize(int n) {
	int x = 1;
	while (x < n) x *= 2;
	return x;
}

FFT板子

// Hand-rolled complex number (used instead of std::complex): x = real part, y = imaginary part.
struct comp {
	double x;
	double y;
};
comp operator + (comp a, comp b) { comp r = {a.x + b.x, a.y + b.y}; return r; }
comp operator - (comp a, comp b) { comp r = {a.x - b.x, a.y - b.y}; return r; }
// (a.x + i*a.y) * (b.x + i*b.y) = (a.x*b.x - a.y*b.y) + i*(a.x*b.y + a.y*b.x)
comp operator * (comp a, comp b) {
	comp r = {a.x*b.x - a.y*b.y, a.x*b.y + a.y*b.x};
	return r;
}

// Complex twiddle table: W[len + j] = exp(i * PI * j / len) for every power of two len < n.
comp W[maxn << 2];
void prework(int n) {
	for (int len = 1; len < n; len <<= 1) {
		for (int j = 0; j < len; j++) {
			// Each entry is computed directly with cos/sin rather than by repeated
			// multiplication, so rounding error does not accumulate (important for MTT).
			W[len+j] = (comp){cos(PI/len*j), sin(PI/len*j)};
		}
	}
}

// In-place iterative complex FFT of length n (n assumed to be a power of two; not checked).
// opt != -1: forward transform. opt == -1: inverse (reverse a[1..n-1], then divide by n).
// Requires the complex W table to be filled by prework(m) with m >= n.
void fft(comp *a, int n, int opt) { // same structure as ntt()
	static int rev[maxn << 2] = {0};
	// Bit-reversal permutation, built incrementally from rev[i>>1].
	for (int i = 1; i < n; i++)
		if ((rev[i] = rev[i>>1]>>1|(i&1?n>>1:0)) > i) std::swap(a[i], a[rev[i]]);
	// Butterflies: q is the half-length of the current block, W[q+i] its i-th twiddle factor.
	for (int q = 1; q < n; q <<= 1)
		for (int p = 0; p < n; p += q<<1)
			for (int i = 0; i < q; i++) {
				comp t = a[p+q+i]*W[q+i]; a[p+q+i] = a[p+i]-t, a[p+i] = a[p+i]+t;
			}
	if (~opt) return;
	// Inverse: a forward transform of the spectrum yields n*a_j at index (n-j) mod n.
	std::reverse(a+1, a+n);
	for (int i = 0; i < n; i++) a[i].x /= n, a[i].y /= n;
}

任意模数NTT(MTT)

MTT事实上就是将一个大整数拆分成两部分 $a\times Base+b$(代码中 $Base=2^{15}$),分开相乘最后相加就能保证精度了。直接做FFT次数很多、常数很大,在%完myy的论文和代码后发现有一种只做4次FFT的方法,需要一些trick和引理。

考虑构造 $A_j=a_j+b_ji$ 与 $B_j=a_j-b_ji$,如果能快速求出 $FFT(A)$ 和 $FFT(B)$,那么就能加减消元求出 $FFT(a)$ 和 $FFT(b)$。myy论文中讲,如果求出了 $FFT(A)$,可以利用它直接推出 $FFT(B)$。

原文是这样说的:

$$B(\omega^k)=\overline{A(\omega^{-k})}$$

其中上划线表示共轭复数

设 $\omega^k$ 的幅角为 $\theta$(于是 $\omega^{jk}=\cos(j\theta)+i\sin(j\theta)$),并注意 $a_j,b_j$ 都是实数,有

$$\begin{aligned} B(\omega^k)&=\sum_{j=0}^{n-1}B_j\omega^{jk}=\sum_{j=0}^{n-1}(a_j-b_ji)(\cos(j\theta)+i\sin(j\theta))\\ &=\sum_{j=0}^{n-1}\Big((a_j\cos(j\theta)+b_j\sin(j\theta))+i(a_j\sin(j\theta)-b_j\cos(j\theta))\Big)\\ &=\sum_{j=0}^{n-1}\Big((a_j\cos(-j\theta)-b_j\sin(-j\theta))-i(a_j\sin(-j\theta)+b_j\cos(-j\theta))\Big)\\ &=\overline{\sum_{j=0}^{n-1}\Big((a_j\cos(-j\theta)-b_j\sin(-j\theta))+i(a_j\sin(-j\theta)+b_j\cos(-j\theta))\Big)}\\ &=\overline{\sum_{j=0}^{n-1}(a_j+b_ji)(\cos(-j\theta)+i\sin(-j\theta))}\\ &=\overline{\sum_{j=0}^{n-1}A_j\omega^{-jk}}=\overline{A(\omega^{-k})} \end{aligned}$$

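运用这个性质优化可以大幅减少FFT的次数:由 $A=a+bi$、$B=a-bi$ 得 $FFT(a)_k=\frac{A(\omega^k)+\overline{A(\omega^{-k})}}2$,$FFT(b)_k=\frac{A(\omega^k)-\overline{A(\omega^{-k})}}{2i}$,所以只需对 $A$ 做一次FFT。两个乘数各做一次正变换,再把四个乘积两两打包成复数做两次逆变换,总共4次FFT。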

void conv(int *x, int *y, int *z, int n) { // z=x*y,长度为n
	for (int i = 0; i < n; i++) x[i] %= P, y[i] %= P; // 提前取模
	static comp a[maxn << 2], b[maxn << 2], da[maxn << 2], db[maxn << 2], dc[maxn << 2], dd[maxn << 2];
	for (int i = 0; i < n; i++)
		a[i] = (comp){x[i] >> 15, x[i] & 32767}, b[i] = (comp){y[i] >> 15, y[i] & 32767}; // 将整数拆分成两部分
	fft(a, n, 1), fft(b, n, 1);
	for (int i = 0; i < n; i++) {
		int j = (n-1) & (n-i);
		static comp a1, a2, b1, b2; // 分离出x和y每个部分的插值结果
		a1 = (a[i] + conj(a[j])) * (comp){0.5, 0};
		a2 = (a[i] - conj(a[j])) * (comp){0, -0.5};
		b1 = (b[i] + conj(b[j])) * (comp){0.5, 0};
		b2 = (b[i] - conj(b[j])) * (comp){0, -0.5};
		da[i] = a1*b1, db[i] = a1*b2, dc[i] = a2*b1, dd[i] = a2*b2;
	}
	for (int i = 0; i < n; i++)
		a[i] = da[i] + db[i]*(comp){0, 1}, b[i] = dc[i] + dd[i]*(comp){0, 1}; // IDFT(x+yi)=IDFT(x)+iIDFT(y),这个将用于下文
	fft(a, n, -1), fft(b, n, -1);
	for (int i = 0; i < n; i++) {
		int ax = (ll)(a[i].x+0.5)%P, ay = (ll)(a[i].y+0.5)%P, bx = (ll)(b[i].x+0.5)%P, by = (ll)(b[i].y+0.5)%P; // 一定要转化成ll(数值为2^30*n)
		z[i] = (((ll)ax << 30) + ((ll)(ay + bx) << 15) + by) % P;
	}
}

快速沃尔什变换(FWT)

用于解决对下标进行位运算卷积问题的方法。即

$$c_k=\sum_{i\oplus j=k}a_i\times b_j$$

其中 $\oplus$ 分别为 `|`、`&`、`^`(按位或、按位与、按位异或)的情况。

先考虑 `|`。FWT 干这样的事情:类似于FFT,对于每一个 $i$,它要求出 $fwt(a)_i=\sum_{j|i=i}a_j$(即对 $i$ 的所有子集 $j$ 求和),并且能在 $\mathcal O(n\log n)$ 的时间复杂度内在 $a$ 与 $fwt(a)$ 两者之间快速变换。然后注意到 $j|i=i,\ k|i=i\Leftrightarrow(j|k)|i=i$。

如果求出了 $fwt(a)$ 和 $fwt(b)$,发现有

$$\begin{aligned} fwt(a)_i\times fwt(b)_i&=\left(\sum_{j|i=i}a_j\right)\left(\sum_{k|i=i}b_k\right)\\ &=\sum_{j|i=i,\,k|i=i}a_jb_k=\sum_{(j|k)|i=i}a_jb_k=\sum_{t|i=i}\sum_{j|k=t}a_jb_k=fwt(c)_i\end{aligned}$$

很像系数变点值!考虑怎么变换。我们按下标最高位为0或1把序列分成两半 $a^{[0]}$ 和 $a^{[1]}$,不难发现

$$fwt(a)=\mathrm{merge}\big(fwt(a^{[0]}),\ fwt(a^{[0]})+fwt(a^{[1]})\big)$$

最高位为0的下标,其子集的最高位也只能是0,所以左半部分只来自 $a^{[0]}$;最高位为1的下标,其子集的最高位既可以是0也可以是1(去掉最高位后仍是包含关系),所以右半部分是两者相加。

同理我们不难发现

$$a=\mathrm{merge}\big(a^{[0]},\ a^{[1]}-a^{[0]}\big)$$

// OR-convolution transform, in place, values mod P (n assumed to be a power of two).
// opt != -1: forward, fwt(a)_i = sum of a_j over all subsets j of i.
// opt == -1: inverse (subtract instead of add).
void fwt_or(int *a, int n, int opt) {
	const bool inverse = (opt == -1);
	for (int half = 1; half < n; half <<= 1) {
		for (int base = 0; base < n; base += 2*half) {
			int *lo = a + base;
			int *hi = a + base + half;
			for (int k = 0; k < half; k++) {
				int delta = inverse ? P-lo[k] : lo[k];
				hi[k] = inc(hi[k], delta);
			}
		}
	}
}

对于 `&`,定义 $fwt(a)_i=\sum_{j\&i=i}a_j$(对 $i$ 的所有超集求和),它也满足 $fwt(a)_i\times fwt(b)_i=fwt(c)_i$。同上我们也可以推导出

$$fwt(a)=\mathrm{merge}\big(fwt(a^{[0]})+fwt(a^{[1]}),\ fwt(a^{[1]})\big)$$

$$a=\mathrm{merge}\big(a^{[0]}-a^{[1]},\ a^{[1]}\big)$$

对于 `^` 稍微麻烦些。定义 $x\otimes y$ 为 $x\&y$ 在二进制下1的个数对2取模,有

$$(i\otimes j)\ \mathrm{xor}\ (i\otimes k)=i\otimes(j\ \mathrm{xor}\ k)$$

构造 $fwt(a)_i=\sum_{i\otimes j=0}a_j-\sum_{i\otimes j=1}a_j$,则

$$\begin{aligned} fwt(a)_i\times fwt(b)_i&=\left(\sum_{i\otimes j=0}a_j-\sum_{i\otimes j=1}a_j\right)\left(\sum_{i\otimes k=0}b_k-\sum_{i\otimes k=1}b_k\right)\\ &=\sum_{i\otimes j=0,\,i\otimes k=0}a_jb_k-\sum_{i\otimes j=0,\,i\otimes k=1}a_jb_k-\sum_{i\otimes j=1,\,i\otimes k=0}a_jb_k+\sum_{i\otimes j=1,\,i\otimes k=1}a_jb_k\\ &=\sum_{(i\otimes j)\ \mathrm{xor}\ (i\otimes k)=0}a_jb_k-\sum_{(i\otimes j)\ \mathrm{xor}\ (i\otimes k)=1}a_jb_k\\ &=\sum_{i\otimes(j\ \mathrm{xor}\ k)=0}a_jb_k-\sum_{i\otimes(j\ \mathrm{xor}\ k)=1}a_jb_k\\ &=fwt(c)_i \end{aligned}$$

符合要求,所以能推导出

$$fwt(a)=\mathrm{merge}\big(fwt(a^{[0]})+fwt(a^{[1]}),\ fwt(a^{[0]})-fwt(a^{[1]})\big)$$

$$a=\mathrm{merge}\Big(\frac{a^{[0]}+a^{[1]}}2,\ \frac{a^{[0]}-a^{[1]}}2\Big)$$

// AND-convolution transform, in place, values mod P (n assumed to be a power of two).
// opt != -1: forward, fwt(a)_i = sum of a_j over all supersets j of i.
// opt == -1: inverse (subtract instead of add).
void fwt_and(int *a, int n, int opt) {
	const bool inverse = (opt == -1);
	for (int half = 1; half < n; half <<= 1)
		for (int base = 0; base < n; base += 2*half)
			for (int k = base; k < base + half; k++) {
				int other = a[k+half];
				a[k] = inc(a[k], inverse ? P-other : other);
			}
}

// XOR-convolution (Walsh-Hadamard) transform, in place, values mod P
// (n assumed to be a power of two; P must be odd so that 2 is invertible).
// opt != -1: forward butterfly (u, v) -> (u + v, u - v).
// opt == -1: inverse; every layer also halves both outputs, for a total factor of 1/n.
void fwt_xor(int *a, int n, int opt) {
	// Inverse of 2 modulo the odd prime P, i.e. (P + 1) / 2, written without overflow.
	// Computed locally so this routine does not depend on a global inv2.
	const int half_inv = P / 2 + 1;
	for (int q = 1; q < n; q <<= 1)
		for (int p = 0; p < n; p += q<<1)
			for (int i = 0; i < q; i++) {
				int t = a[p+q+i]; a[p+q+i] = inc(a[p+i], P-t); a[p+i] = inc(a[p+i], t);
				if (opt == -1) a[p+i] = 1ll*a[p+i]*half_inv%P, a[p+q+i] = 1ll*a[p+q+i]*half_inv%P;
			}
}

多项式

#include <bits/stdc++.h>
using std::reverse; using std::vector; using std::swap; using std::max;
// Size bound for the global tables (allocated as N*4).
const int N = 100005;
// NTT-friendly prime modulus and the inverse of 2 modulo it.
const int P = 998244353;
const int inv2 = (P + 1) >> 1;
// A polynomial is its coefficient vector: A[i] is the coefficient of x^i, kept in [0, P).
typedef vector<int> Poly;
typedef long long LL;
// (a + b) mod P for a, b already in [0, P).
int inc(int a, int b) {
	a += b;
	return a >= P ? a - P : a;
}
// Square-and-multiply: returns a^b mod P.
int pow(int a, int b) {
	int t = 1;
	while (b != 0) {
		if (b & 1) t = 1LL*t*a%P;
		b >>= 1;
		a = 1LL*a*a%P;
	}
	return t;
}
int W[N*4], inv[N*4];
void prework(int n) {
	for (int i = 1; i < n; i <<= 1)
		for (int j = W[i] = 1, Wn = pow(3, (P-1)/i>>1); j < i; j++)
			W[i+j] = 1LL*W[i+j-1]*Wn%P;
	inv[1] = 1;
	for (int i = 2; i <= n; i++) inv[i] = 1LL*(P-P/i)*inv[P%i]%P;
}
// In-place NTT of a Poly over Z_P with length n (n assumed to be a power of two).
// a is first resized to n (zero-padded or truncated).
// opt != 0: forward transform. opt == 0: inverse transform. Note this differs from the
// opt == -1 convention of the array-based ntt().
// Requires W filled by prework(m) with m >= n.
void fft(Poly &a, int n, int opt) {
	a.resize(n);
	static int rev[N*4]; // static storage, so zero-initialized; rev[0] stays 0
	// Bit-reversal permutation, built incrementally from rev[i>>1].
	for (int i = 1; i < n; i++)
		if ((rev[i] = rev[i>>1]>>1|(i&1?n>>1:0)) > i) swap(a[i], a[rev[i]]);
	// Butterflies: q is the half-length of the current block, W[q+i] its i-th twiddle factor.
	for (int q = 1; q < n; q <<= 1)
		for (int p = 0; p < n; p += q<<1)
			for (int i = 0, t; i < q; i++)
				t = 1LL*W[q+i]*a[p+q+i]%P, a[p+q+i] = inc(a[p+i], P-t), a[p+i] = inc(a[p+i], t);
	if (opt) return;
	// Inverse: scale by n^{-1} (the local inv shadows the global table), then reverse a[1..n-1].
	for (int i = 0, inv = pow(n, P-2); i < n; i++) a[i] = 1LL*a[i]*inv%P;
	reverse(a.begin()+1, a.end());
}
// Power-series inverse: returns B of the same length as A with A * B == 1 mod x^{A.size()}.
// Requires A non-empty with A[0] != 0 mod P (B starts as A[0]^{-1}).
// Newton iteration, doubling the precision L each step: B <- B * (2 - A*B) mod x^{2L}.
Poly poly_inv(Poly A) {
	Poly B(1, pow(A[0], P-2)), C(2); // C's initial contents are never used
	for (int L = 1; L < A.size(); L <<= 1) {
		// C = A mod x^{2L}; B^2 * C has degree <= 4L-3, so length-4L transforms do not wrap.
		(C = A).resize(L*2); fft(B, L*4, 1), fft(C, L*4, 1);
		for (int i = 0; i < L*4; i++) B[i] = (2*B[i]-1LL*B[i]*B[i]%P*C[i]%P+P)%P;
		fft(B, L*4, 0); B.resize(L*2);
	}
	return B.resize(A.size()), B;
}
// Smallest power of two >= n (1 when n <= 1); used to pick an NTT length.
int getsz(int n) {
	int sz = 1;
	for (; sz < n; sz <<= 1) {}
	return sz;
}
// Trim trailing zero coefficients so that A.size()-1 is the degree.
// Never trims below one coefficient (an empty A stays empty).
void fix(Poly &A) {
	size_t len = A.size();
	while (len > 1 && A[len-1] == 0) --len;
	A.resize(len);
}
// Coefficient-wise sum mod P; the result is trimmed with fix().
Poly operator + (Poly A, Poly B) {
	if (A.size() < B.size()) A.resize(B.size());
	for (size_t i = 0; i < B.size(); ++i) A[i] = inc(A[i], B[i]);
	fix(A);
	return A;
}
// Coefficient-wise difference mod P; the result is trimmed with fix().
Poly operator - (Poly A, Poly B) {
	if (A.size() < B.size()) A.resize(B.size());
	for (size_t i = 0; i < B.size(); ++i) A[i] = inc(A[i], P-B[i]);
	fix(A);
	return A;
}
// Scalar multiple k * A mod P (the result is not trimmed).
// k may be any int, including negative values such as the exponent passed by poly_pow.
// It is normalized into [0, P) first, so every coefficient stays in [0, P).
Poly operator * (int k, Poly A) {
	k %= P;
	if (k < 0) k += P;
	for (size_t i = 0; i < A.size(); i++) A[i] = 1LL*k*A[i]%P;
	return A;
}
// Polynomial product via NTT. The length is the next power of two >= A.size() + B.size() - 1,
// so the cyclic convolution equals the full product. The result is trimmed with fix().
Poly operator * (Poly A, Poly B) {
	const int len = getsz(A.size()+B.size()-1);
	fft(A, len, 1);
	fft(B, len, 1);
	for (int i = 0; i < len; i++) A[i] = 1LL*A[i]*B[i]%P;
	fft(A, len, 0);
	fix(A);
	return A;
}
// Polynomial quotient floor(A / B), computed by the reversal trick:
// rev(Q) = rev(A) * rev(B)^{-1} mod x^n, with n = A.size() - B.size() + 1.
// B must be non-empty with a nonzero leading coefficient, as fix() guarantees.
// If A.size() < B.size() the quotient is the zero polynomial {0}. Without this guard,
// n would be <= 0 and resize(n) would fail; this case is hit by % during multipoint evaluation.
Poly operator / (Poly A, Poly B) {
	if (A.size() < B.size()) return Poly(1, 0);
	int n = A.size()-B.size()+1;
	reverse(A.begin(), A.end()); A.resize(n);
	reverse(B.begin(), B.end()); B.resize(n);
	A = A * poly_inv(B);
	A.resize(n);
	reverse(A.begin(), A.end());
	fix(A);
	return A;
}
// Remainder of A modulo B, computed as A - (A / B) * B and trimmed by operator-.
Poly operator % (Poly A, Poly B) {
	Poly quotient = A / B;
	return A - quotient * B;
}
// Formal derivative: returns A', one coefficient shorter than A (values mod P).
// A constant polynomial yields an empty vector, as before. An empty input also yields an
// empty vector instead of underflowing A.size()-1 and writing out of bounds.
Poly poly_deri(Poly A) {
	if (A.size() <= 1) return Poly();
	for (int i = 0; i + 1 < (int)A.size(); i++) A[i] = 1LL*(i+1)*A[i+1]%P;
	A.pop_back();
	return A;
}
// Formal integral with zero constant term, truncated to A's length:
// result[i] = A[i-1] / i for 1 <= i < A.size(), result[0] = 0. The top coefficient of A is dropped.
// Uses inv[1..A.size()-1], which prework() must have filled.
// An empty A is returned unchanged instead of writing A[-1] and A[0] out of bounds.
Poly poly_int(Poly A) {
	if (A.empty()) return A;
	for (int i = (int)A.size()-1; i > 0; i--) A[i] = 1LL*A[i-1]*inv[i]%P;
	A[0] = 0;
	return A;
}
// Power-series square root: returns B of the same length as A with B^2 == A mod x^{A.size()}.
// Starts from B = 1, so it assumes A[0] == 1 (not checked).
// Newton iteration, doubling L each step: B <- (B^2 + A) / (2B) mod x^{2L}.
Poly poly_sqrt(Poly A) {
	Poly B(1, 1), iB, C(2);
	for (int L = 1; L < A.size(); L <<= 1) {
		// B carries only L meaningful coefficients here; iB = B^{-1} mod x^{2L}.
		(C = A).resize(L*2); B.resize(L*2); iB = poly_inv(B);
		fft(B, L*4, 1), fft(iB, L*4, 1), fft(C, L*4, 1);
		// Pointwise (B^2 + C) * iB * inv2; the product has degree <= 4L-2, so length 4L does not wrap.
		for (int i = 0; i < L*4; i++) B[i] = (1LL*B[i]*B[i]+C[i])%P*iB[i]%P*inv2%P;
		fft(B, L*4, 0); B.resize(L*2);
	}
	return B.resize(A.size()), B;
}
// Power-series logarithm: ln A = integral(A' / A), truncated to A.size() coefficients.
// Uses poly_inv(A), so A[0] must be invertible; ln is meaningful for A[0] == 1.
Poly poly_ln(Poly A) {
	Poly ratio = poly_deri(A) * poly_inv(A);
	ratio.resize(A.size());
	return poly_int(ratio);
}
// Power-series exponential: returns B of the same length as A with B == exp(A) mod x^{A.size()}.
// Starts from B = 1, so it assumes A[0] == 0 (not checked).
// Newton iteration, doubling L each step: B <- B * (1 - ln B + A) mod x^{2L}.
Poly poly_exp(Poly A) {
	Poly B(1, 1), C;
	// B*C is not truncated. The extra high coefficients are harmless: a Newton step at
	// precision 2L only needs the incoming B to be correct modulo x^L.
	for (int L = 1; L < A.size(); L <<= 1)
		B.resize(L*2), C = A + Poly(1, 1) - poly_ln(B), C.resize(L*2), B = B*C;
	return B.resize(A.size()), B;
}
// A^k computed as exp(k * ln A), truncated to A.size() coefficients; requires A[0] == 1.
Poly poly_pow(Poly A, int k) {
	Poly scaled = k * poly_ln(A);
	return poly_exp(scaled);
}
#define lc (o << 1)
#define rc (o << 1 | 1)
// Subproduct tree in segment-tree layout (root o = 1): Q[o] = prod of (x - x[i]) over node o's range.
Poly Q[N*4];
// Build the subproduct tree over x[l..r], with each x[i] assumed to be in [0, P).
// Leaves are reset before being filled, so the tree can be rebuilt for a new point set
// (e.g. poly_calc followed by poly_inter). x[l] == 0 gives constant term 0, not the unreduced P.
void build(int o, int l, int r, int x[]) {
	if (l == r) {
		Q[o].clear();
		Q[o].push_back((P - x[l]) % P), Q[o].push_back(1);
		return;
	}
	int mid = l+r>>1;
	build(lc, l, mid, x), build(rc, mid+1, r, x);
	Q[o] = Q[lc] * Q[rc];
}
// Multipoint evaluation over the subproduct tree: reduces A modulo Q[o] on the way down.
// At the leaf for index l, A has been reduced mod (x - x_l), so A[0] = A(x_l).
// Overwrites x[l..r] with these values.
void calc(Poly A, int o, int l, int r, int x[]) {
	if (l == r) {
		x[l] = A[0];
		return;
	}
	const int mid = (l + r) >> 1;
	calc(A % Q[lc], lc, l, mid, x);
	calc(A % Q[rc], rc, mid + 1, r, x);
}
// Evaluate A at the n points x[1..n] (1-indexed); afterwards x[i] holds A(x[i]).
void poly_calc(Poly A, int n, int x[]) {
	build(1, 1, n, x);  // subproduct tree of the points
	calc(A, 1, 1, n, x);
}
// Bottom-up combination of Lagrange terms. A leaf contributes c_l = y[l] / x[l] mod P.
// poly_inter passes x[l] = Q'(x_l) here. Node o returns the sum over i in [l, r] of
// c_i * prod_{j in [l, r], j != i} (x - x_j), using the subproducts Q.
Poly inter(int o, int l, int r, int x[], int y[]) {
	if (l == r) {
		int coef = 1LL*y[l]*pow(x[l], P-2)%P;
		return Poly(1, coef);
	}
	const int mid = (l + r) >> 1;
	Poly left = inter(lc, l, mid, x, y);
	Poly right = inter(rc, mid + 1, r, x, y);
	return left*Q[rc] + right*Q[lc];
}
// Lagrange interpolation through (x[i], y[i]) for i = 1..n.
// Builds the subproduct tree Q[1] = prod (x - x_i), then overwrites x[i] with Q'(x_i),
// then combines the terms. The caller's x[] no longer holds the points afterwards.
Poly poly_inter(int n, int x[], int y[]) {
	build(1, 1, n, x);
	calc(poly_deri(Q[1]), 1, 1, n, x);
	return inter(1, 1, n, x, y);
}
int n;
Poly A;
// Reads n and the n coefficients of A, then prints the first n coefficients of exp(A).
int main() {
	scanf("%d", &n);
	A.resize(n);
	prework(n*2);  // root and inverse tables for the transforms used by poly_exp
	for (int i = 0; i < n; ++i) scanf("%d", &A[i]);
	A = poly_exp(A);
	for (int i = 0; i < n; ++i) printf("%d ", A[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/ac-evil/p/13048974.html