LG P3803 【模板】多项式乘法

( ext{FFT}) 模板

#include <cstdio> 
#include <iostream>
#include <cmath>
#define re register
using namespace std;

const int N = 2e6 + 1e5;
int rev[N], n, m;

inline int read()
{
	char ch = getchar(); int f = 1, x = 0;
	while (ch < '0' || ch > '9') f = (ch == '-' ? -1 : f), ch = getchar();
	while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
	return x * f;
}

const double Pi = acos(-1.0);
struct complex{
	double x, y;
	inline complex operator + (const complex &a) const {return complex{x + a.x, y + a.y};}
	inline complex operator - (const complex &a) const {return complex{x - a.x, y - a.y};}
	inline complex operator * (const complex &a) const {return complex{x * a.x - y * a.y, x * a.y + y * a.x};}
}a[N], b[N];

inline void FFT(complex *a, int lim, int inv)
{
	if (lim == 1) return;
	for(re int i = 0; i < lim; i++)
	if (i < rev[i]) swap(a[i], a[rev[i]]);
	for(re int mid = 1; mid < lim; mid <<= 1)
	{
		complex I = complex{cos(Pi / mid), inv * sin(Pi / mid)};
		for(re int i = 0; i < lim; i += (mid << 1))
		{
			complex W = complex{1, 0};
			for(re int j = 0; j < mid; j++, W = W * I)
			{
				complex x = a[i + j], y = W * a[i + j + mid];
				a[i + j] = x + y, a[i + j + mid] = x - y;
			}
		}
	}
}

int main()
{
	n = read(), m = read();
	for(re int i = 0; i <= n; i++) a[i].x = read();
	for(re int i = 0; i <= m; i++) b[i].x = read();
	
	int limit = 1;
	while (limit <= n + m) limit <<= 1;
	int bit = 0;
	while ((1 << bit) < limit) ++bit;
	for(re int i = 0; i < limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
	
	FFT(a, limit, 1), FFT(b, limit, 1);
	for(re int i = 0; i < limit; i++) a[i] = a[i] * b[i];
	FFT(a, limit, -1);
	for(re int i = 0; i <= n + m; i++) printf("%d ", (int)(a[i].x / limit + 0.5));
}

( ext{NTT}) 模板

#include <cstdio> 
#include <iostream>
#define LL long long
#define re register
using namespace std;

const int N = 2e6 + 1e5;
const int P = 998244353, g = 3;
int n, m, rev[N], a[N], b[N];

inline void read(int &x)
{
	x = 0; char ch = getchar(); int f = 1;
	while (ch < '0' || ch > '9') f = (ch == '-' ? -1 : f), ch = getchar();
	while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
	x *= f;
}

inline int fpow(int x, int y)
{
	int res = 1;
	for(; y; y >>= 1)
	{
		if (y & 1) res = 1LL * res * x % P;
		x = 1LL * x * x % P;
	}
	return res;
}

inline void NTT(int *a, int lim, int inv)
{
	if (lim == 1) return;
	for(re int i = 0; i < lim; i++)
	if (i < rev[i]) swap(a[i], a[rev[i]]);
	for(re int mid = 1; mid < lim; mid <<= 1)
	{
		int I = fpow(g, (P - 1) / (mid << 1));
		if (inv == -1) I = fpow(I, P - 2);
		for(re int i = 0; i < lim; i += (mid << 1))
		{
			int W = 1;
			for(re int j = 0; j < mid; j++, W = 1LL * W * I % P)
			{
				LL x = a[i + j], y = 1LL * W * a[i + j + mid] % P;
				a[i + j] = (x + y) % P, a[i + j + mid] = (x - y + P) % P;
			}
		}
	}
}

int main()
{
	read(n), read(m);
	for(re int i = 0; i <= n; i++) read(a[i]);
	for(re int i = 0; i <= m; i++) read(b[i]);
	
	int limit = 1;
	while (limit <= n + m) limit <<= 1;
	int bit = 0;
	while ((1 << bit) < limit) ++bit;
	for(re int i = 0; i < limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
	
	NTT(a, limit, 1), NTT(b, limit, 1);
	for(re int i = 0; i < limit; i++) a[i] = 1LL * a[i] * b[i] % P;
	NTT(a, limit, -1);
	int inv = fpow(limit, P - 2);
	for(re int i = 0; i <= n + m; i++) printf("%d ", 1LL * a[i] * inv % P);
}
原文地址:https://www.cnblogs.com/leiyuanze/p/15138928.html