「任意模数多项式乘法」

「任意模数多项式乘法」

前置知识

多项式乘法

基本问题

给定一个 (n) 次多项式 (F(x)) 和一个 (m) 次多项式,求出

[F(x) imes G(x) ]

系数对 (p) 取模,且不保证 (p) 可以分解成 (p=2^ka+1) 之形式,(0leq a_i,b_ileq 10^9)(2leq pleq 10^9+9)

考虑直接用 (FFT),但是值域太大,(long; double) 都炸了,精度也无法保证

直接用 (NTT),但是是任意模数,根本用不了

处理这种问题,我们常有两种做法:

三模NTT

(1202) 年了,不会还有人写三模 (NTT)

好吧,其实是我不会

另一种比较常用的做法是

MTT

既然 (FFT) 处理不了值域很大的情况,我们就从问题入手,将值域缩小

不妨将两个多项式拆成:

[F(x)=M imes A(x)+B(x) ]

[G(x)=M imes C(x)+D(x) ]

(M=2^{15}) 时,可以完美避免炸 (double) 的问题

现在问题就转化为了

[(M imes A(x)+B(x)) imes (M imes C(x)+D(x)) ]

[M^2A(x)C(x)+M(B(x)C(x)+A(x)D(x))+B(x)D(x) ]

所以直接 (7)(FFT) 就解决啦

(PS):可以省到 (4)(FFT),但是我还不会

代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>

typedef long long ll;
typedef unsigned long long ull;

using namespace std;

const int maxn = 3e5 + 50, INF = 0x3f3f3f3f;
const double pi = acos (-1);

inline int read () {
	register int x = 0, w = 1;
	register char ch = getchar ();
	for (; ch < '0' || ch > '9'; ch = getchar ()) if (ch == '-') w = -1;
	for (; ch >= '0' && ch <= '9'; ch = getchar ()) x = x * 10 + ch - '0';
	return x * w;
}

inline void write (register int x) {
	if (x / 10) write (x / 10);
	putchar (x % 10 + '0');
}

int n, m, M, mod, len = 1, bit;
int rev[maxn], f[maxn];

struct Complex {
	double x, y;
	Complex () {}
	Complex (register double a, register double b) { x = a, y = b; }
	inline Complex operator + (const Complex &a) const { return Complex (x + a.x, y + a.y); }
	inline Complex operator - (const Complex &a) const { return Complex (x - a.x, y - a.y); }
	inline Complex operator * (const Complex &a) const { return Complex (x * a.y + y * a.x, y * a.y - x * a.x); }
} g[maxn], a[maxn], b[maxn], c[maxn], d[maxn], omega[maxn];

inline void FFT (register int len, register Complex * a, register int opt) {
	for (register int i = 1; i < len; i ++) if (i < rev[i]) swap (a[i], a[rev[i]]);
	for (register int d = 1; d < len; d <<= 1) {
		for (register int i = 0; i < len; i += d << 1) {
			for (register int j = 0; j < d; j ++) {
				register Complex w = omega[len / (d << 1) * j]; w.x *= opt;
				register Complex x = a[i + j], y = w * a[i + j + d];
				a[i + j] = x + y, a[i + j + d] = x - y;
			}
		}
	}
}

int main () {
	n = read(), m = read(), mod = read(), M = 1 << 15;
	while (len <= n + m) len <<= 1, bit ++;
	for (register int i = 0; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit - 1), omega[i] = Complex (sin (2 * pi * i / len), cos (2 * pi * i / len));
	for (register int i = 0, x; i <= n; i ++) x = read(), a[i].y = x / M, b[i].y = x % M;
	for (register int i = 0, x; i <= m; i ++) x = read(), c[i].y = x / M, d[i].y = x % M;
	FFT (len, a, 1), FFT (len, b, 1), FFT (len, c, 1), FFT (len, d, 1);
	for (register int i = 0; i < len; i ++) g[i] = a[i] * c[i]; FFT (len, g, -1);
	for (register int i = 0; i <= n + m; i ++) f[i] = (f[i] + (ll) (g[i].y / len + 0.5) % mod * M % mod * M % mod) % mod;
	for (register int i = 0; i < len; i ++) g[i] = a[i] * d[i] + b[i] * c[i]; FFT (len, g, -1);
	for (register int i = 0; i <= n + m; i ++) f[i] = (f[i] + (ll) (g[i].y / len + 0.5) % mod * M % mod) % mod;
	for (register int i = 0; i < len; i ++) g[i] = b[i] * d[i]; FFT (len, g, -1);
	for (register int i = 0; i <= n + m; i ++) f[i] = (f[i] + (ll) (g[i].y / len + 0.5) % mod) % mod;
	for (register int i = 0; i <= n + m; i ++) printf ("%d ", f[i]); putchar ('
');
	return 0;
}
原文地址:https://www.cnblogs.com/Rubyonly233/p/14220625.html