用 FFT/NTT 求多项式乘法（可用于高精度乘法）

加了蝴蝶变换（位逆序置换）优化的非递归快速傅里叶变换（FFT）。

 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cmath>
 4 using namespace std;
 5 const int MAXN=1e7+10;
 6 inline int read()
 7 {
 8     char c=getchar();int x=0,f=1;
 9     while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
10     while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
11     return x*f;
12 }
// pi to full double precision, computed as arccos(-1).
const double Pi = std::acos(-1.0);
14 struct complex
15 {
16     double x,y;
17     complex (double xx=0,double yy=0){x=xx,y=yy;}
18 }a[MAXN],b[MAXN];
19 complex operator + (complex a,complex b){ return complex(a.x+b.x , a.y+b.y);}
20 complex operator - (complex a,complex b){ return complex(a.x-b.x , a.y-b.y);}
21 complex operator * (complex a,complex b){ return complex(a.x*b.x-a.y*b.y , a.x*b.y+a.y*b.x);}//不懂的看复数的运算那部分 
22 int N,M;
23 int l,r[MAXN];
24 int limit=1;
25 void fast_fast_tle(complex *A,int type)
26 {
27     for(int i=0;i<limit;i++) 
28         if(i<r[i]) swap(A[i],A[r[i]]);//求出要迭代的序列 
29     for(int mid=1;mid<limit;mid<<=1)//待合并区间的中点
30     {
31         complex Wn( cos(Pi/mid) , type*sin(Pi/mid) ); //单位根 
32         for(int R=mid<<1,j=0;j<limit;j+=R)//R是区间的右端点,j表示前已经到哪个位置了 
33         {
34             complex w(1,0);//
35             for(int k=0;k<mid;k++,w=w*Wn)//枚举左半部分 
36             {
37                  complex x=A[j+k],y=w*A[j+mid+k];//蝴蝶效应 
38                 A[j+k]=x+y;
39                 A[j+mid+k]=x-y;
40             }
41         }
42     }
43 }
44 int main()
45 {
46     int N=read(),M=read();
47     for(int i=0;i<=N;i++) a[i].x=read();
48     for(int i=0;i<=M;i++) b[i].x=read();
49     while(limit<=N+M) limit<<=1,l++;
50     for(int i=0;i<limit;i++)
51         r[i]= ( r[i>>1]>>1 )| ( (i&1)<<(l-1) ) ;
52     // 在原序列中 i 与 i/2 的关系是 : i可以看做是i/2的二进制上的每一位左移一位得来
53     // 那么在反转后的数组中就需要右移一位,同时特殊处理一下复数 
54     fast_fast_tle(a,1);
55     fast_fast_tle(b,1);
56     for(int i=0;i<=limit;i++) a[i]=a[i]*b[i];
57     fast_fast_tle(a,-1);
58     for(int i=0;i<=N+M;i++)
59         printf("%d ",(int)(a[i].x/limit+0.5));
60     return 0;
61 }

据说比 FFT 更快的 NTT（快速数论变换）：全程在模 998244353 意义下做整数运算，没有浮点误差。

 1 #include<cstdio>
 2 #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1<<21, stdin), p1 == p2) ? EOF : *p1++)
 3 #define swap(x,y) x ^= y, y ^= x, x ^= y
 4 #define LL long long 
 5 const int MAXN = 3 * 1e6 + 10, P = 998244353, G = 3, Gi = 332748118; 
 6 char buf[1<<21], *p1 = buf, *p2 = buf;
 7 inline int read() { 
 8     char c = getchar(); int x = 0, f = 1;
 9     while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
10     while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
11     return x * f;
12 }
13 int N, M, limit = 1, L, r[MAXN];
14 LL a[MAXN], b[MAXN];
15 inline LL fastpow(LL a, LL k) {
16     LL base = 1;
17     while(k) {
18         if(k & 1) base = (base * a ) % P;
19         a = (a * a) % P;
20         k >>= 1;
21     }
22     return base % P;
23 }
24 inline void NTT(LL *A, int type) {
25     for(int i = 0; i < limit; i++) 
26         if(i < r[i]) swap(A[i], A[r[i]]);
27     for(int mid = 1; mid < limit; mid <<= 1) {    
28         LL Wn = fastpow( type == 1 ? G : Gi , (P - 1) / (mid << 1));
29         for(int j = 0; j < limit; j += (mid << 1)) {
30             LL w = 1;
31             for(int k = 0; k < mid; k++, w = (w * Wn) % P) {
32                  int x = A[j + k], y = w * A[j + k + mid] % P;
33                  A[j + k] = (x + y) % P,
34                  A[j + k + mid] = (x - y + P) % P;
35             }
36         }
37     }
38 }
39 int main() {
40     N = read(); M = read();
41     for(int i = 0; i <= N; i++) a[i] = (read() + P) % P;
42     for(int i = 0; i <= M; i++) b[i] = (read() + P) % P;
43     while(limit <= N + M) limit <<= 1, L++;
44     for(int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));    
45     NTT(a, 1);NTT(b, 1);    
46     for(int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % P;
47     NTT(a, -1);    
48     LL inv = fastpow(limit, P - 2);
49     for(int i = 0; i <= N + M; i++)
50         printf("%d ", (a[i] * inv) % P);
51     return 0;
52 }
原文地址:https://www.cnblogs.com/St-Lovaer/p/13907000.html