FFT多项式乘法

[Luogu P3803]

学了好久才懂了那么一点点哎

Code:

 1 #include <bits/stdc++.h>
 2 #define ll long long
 3 using namespace std;
 4 const int N = 1e7 + 7;
 5 const double Pi = acos(-1.0);
 6 ll read() {
 7     ll re = 0, f = 1;
 8     char ch = getchar();
 9     while (ch < '0' || ch > '9') {if (ch == '-') f = -f; ch = getchar();}
10     while ('0' <= ch && ch <= '9') {re = re * 10 + ch - '0'; ch = getchar();}
11     return re * f;
12 }
13 int n, m;
14 int l, pos[N], limit = 1;
15 struct Complex{
16     double x, y;
17     Complex (double nx = 0, double ny = 0) {x = nx, y = ny;}
18 }a[N], b[N];
19 Complex operator +(Complex a, Complex b) {return Complex(a.x + b.x, a.y + b.y);}
20 Complex operator -(Complex a, Complex b) {return Complex(a.x - b.x, a.y - b.y);}
21 Complex operator *(Complex a, Complex b) {return Complex(a.x*b.x-a.y*b.y, a.x*b.y+b.x*a.y);}
22 void FFT(Complex *A, int f) {
23     for (int i = 0; i < limit; i++) {
24         if (i < pos[i]) swap(A[i], A[pos[i]]);
25     }
26     for (int mid = 1; mid < limit; mid <<= 1) {
27         Complex wn(cos(Pi / mid), f * sin(Pi / mid));
28         for (int r = mid << 1, i = 0; i < limit; i += r) {
29             Complex w(1, 0);
30             for (int k = 0; k < mid; k++, w = w * wn) {
31                 Complex u = A[i + k], v = w * A[i + k + mid];
32                 A[i + k] = u + v;
33                 A[i + k + mid] = u - v;
34             }
35         }
36     }
37     if (f == -1) {
38         for (int i = 0; i < limit; i++) A[i].x /= limit;
39     }
40 }
41 int main () {
42     n = read(), m = read();
43     for (int i = 0; i <= n; i++) {
44         a[i].x = read();
45     }
46     for (int i = 0; i <= m; i++) {
47         b[i].x = read();
48     }
49     while (limit <= n + m) l++, limit <<= 1;
50     for (int i = 0; i < limit; i++) {
51         pos[i] = (pos[i >> 1] >> 1) | ((i & 1) << (l - 1));
52     }
53     FFT(a, 1), FFT(b, 1);
54     for (int i = 0; i < limit; i++) {
55         a[i] = a[i] * b[i];
56     }
57     FFT(a, -1);
58     for (int i = 0; i <= n + m; i++) {
59         printf("%d%c", (int)(a[i].x + 0.5), i == n + m ? '
' : ' ');
60     }
61     return 0;
62 }
View Code
原文地址:https://www.cnblogs.com/Sundial/p/12198653.html