hdu 1402(FFT乘法 || NTT乘法)

A * B Problem Plus

Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K (Java/Others)
Total Submission(s): 9413    Accepted Submission(s): 1468


Problem Description
Calculate A * B.
 
Input
Each line will contain two integers A and B. Process to end of file.

Note: the length of each integer will not exceed 50000.
 
Output
For each case, output A * B in one line.
 
Sample Input
1 2 1000 2
 
Sample Output
2 2000
 
Author
DOOM III
 
Recommend
DOOM III
 

就一个高精度乘法 FFT加速。

最近正好要捡起fft,就顺便整理了模板。

FFT的原理还是算法导论靠谱,没有那么艰深难懂,就涉及怎么进行FFT和FFT需要的原理和定理。

看看算法导论里FFT的部分,一定要读到迭代实现那部分!!

看了好久求和引理,才发觉他是为了保证$w_n^k$与$w_n^{k+2/h}$的对称性(即$w_n^{k+2/h}=-w_n^k$)的,这个引理是必要的。

对于多项式序列,我们可以用两个O(nlgn)(n>max(len1,len2)*2)的FFT将其系数表示转化为点值表示(DFT),然后用O(n) 相乘,接着用FFT把结果的点值表示变为系数表示(IDFT),总体算起来是3O(nlgn)+O(n),即O(nlgn)的时间复杂度。比O(n^2)好多了。

以下是学习的两个版本。

  1 #include<bits/stdc++.h>
  2 #define clr(x) memset(x,0,sizeof(x))
  3 #define clr_1(x) memset(x,-1,sizeof(x))
  4 #define clrmax(x) memset(x,0x3f3f3f3f,sizeof(x))
  5 #define LL long long
  6 #define mod 1000000007
  7 #define PI 3.1415926535
  8 using namespace std;
  9 char s1[200010],s2[200010];
 10 int a[200010],b[200010];
 11 //复数序列结构体
 12 struct complexed
 13 {
 14         double r,i;
 15         complexed(double _r=0.0,double _i=0.0)
 16         {
 17             r=_r;
 18             i=_i;
 19         }
 20         complexed operator +(complexed b)
 21         {
 22             return complexed(r+b.r,i+b.i);
 23         }
 24         complexed operator -(complexed b)
 25         {
 26             return complexed(r-b.r,i-b.i);
 27         }
 28         complexed operator *(complexed b)
 29         {
 30             return complexed(r*b.r-i*b.i,i*b.r+r*b.i);
 31         }
 32 }num1,num2;
 33 vector<complexed> multi1,multi2;
 34 inline int max(int a,int b)
 35 {
 36     return a>b?a:b;
 37 }
 38 //并将长度变为2…^(k+1)
 39 void changelen(int &len)
 40 {
 41     int mul=1;
 42     while(mul<len)
 43         mul<<=1;
 44     mul<<=1;
 45     len=mul;
 46     return ;
 47 }
 48 //将整数序列复制到复数序列中
 49 void copyed(int *a,vector<complexed> &multi,int len)
 50 {
 51     multi.resize(len);
 52     for(int i=0;i<len;i++)
 53         multi[i]=(complexed){a[i],0};
 54     return;
 55 }
 56 //DFT的话on=1,IDFT on=-1;
 57 void fft(vector<complexed> &multi,int len,int on)
 58 {
 59     complexed wn,w,u,t;
 60     //wn,w,u,t如算法导论中所示
 61     vector<complexed> ans;
 62     ans.resize(len);
 63     //ans存每次操作计算后的y,最后再作为下次的multi。
 64     for(int h=len/2;h>=1;h>>=1)
 65     {
 66         wn=(complexed){cos(2*on*PI/(len/h)),sin(2*on*PI/(len/h))};
 67         for(int i=0;i<h;i++)
 68         {
 69             w=(complexed){1,0};
 70             for(int j=0;j<len/h/2;j++)
 71             {
 72                 //蝴蝶操作
 73                 u=multi[i+2*h*j];
 74                 t=multi[i+2*h*j+h]*w;
 75                 ans[i+h*j]=u+t;
 76                 ans[i+h*j+len/2]=u-t;
 77                 w=w*wn;
 78             }
 79         }
 80         //ans作为下次计算的multi
 81         multi=ans;
 82     }
 83     //IDFT每个元素都得除以n
 84     if(on==-1)
 85         for(int i=0;i<len;i++)
 86             multi[i].r/=len;
 87     return ;
 88 }
 89 int main()
 90 {
 91     int len1,len2,len;
 92     while(scanf("%s%s",s1,s2)!=EOF)
 93     {
 94         len1=strlen(s1);
 95         len2=strlen(s2);
 96         clr(a);
 97         clr(b);
 98         for(int i=0;i<len1;i++)
 99         {
100             a[len1-i-1]=s1[i]-'0';
101         }
102         for(int i=0;i<len2;i++)
103         {
104             b[len2-i-1]=s2[i]-'0';
105         }
106         len=max(len1,len2);
107         //取长度较长者作为长度,并将长度变为2…^(k+1)
108         changelen(len);
109         //将两个整数序列复制到复数序列中
110         copyed(a,multi1,len);
111         copyed(b,multi2,len);
112         //对两个复数序列进行DFT,变为点值表示
113         fft(multi1,len,1);
114         fft(multi2,len,1);
115         //对应点点值相乘
116         for(int i=0;i<len;i++)
117             multi1[i]=multi1[i]*multi2[i];
118         //将的出来的点值表示进行IDFT变为系数表示
119         fft(multi1,len,-1);
120         //四舍五入减小损失精度
121         for(int i=0;i<len;i++)
122         {
123             a[i]=(int)(multi1[i].r+0.5);
124         }
125         //进位
126         for(int i=0;i<len;i++)
127         {
128             a[i+1]=a[i+1]+a[i]/10;
129             a[i]%=10;
130         }
131         len=len1+len2-1;
132         //去掉前导0
133         while(a[len]<=0 && len>0) len--;
134         for(int i=len;i>=0;i--)
135             printf("%d",a[i]);
136         printf("
");
137     }
138     return 0;
139 }
无位逆序置换的步长实现
  1 #include<bits/stdc++.h>
  2 #define clr(x) memset(x,0,sizeof(x))
  3 #define clr_1(x) memset(x,-1,sizeof(x))
  4 #define clrmax(x) memset(x,0x3f3f3f3f,sizeof(x))
  5 #define LL long long
  6 #define mod 1000000007
  7 #define PI 3.1415926535
  8 using namespace std;
  9 char s1[200010],s2[200010];
 10 int a[200010],b[200010];
 11 struct complexed
 12 {
 13         double r,i;
 14         complexed(double _r=0.0,double _i=0.0)
 15         {
 16             r=_r;
 17             i=_i;
 18         }
 19         complexed operator +(complexed b)
 20         {
 21             return complexed(r+b.r,i+b.i);
 22         }
 23         complexed operator -(complexed b)
 24         {
 25             return complexed(r-b.r,i-b.i);
 26         }
 27         complexed operator *(complexed b)
 28         {
 29             return complexed(r*b.r-i*b.i,i*b.r+r*b.i);
 30         }
 31 }num1,num2;
 32 complexed multi1[200010<<2],multi2[200010<<2];
 33 inline int max(int a,int b)
 34 {
 35     return a>b?a:b;
 36 }
 37 void changelen(int &len)
 38 {
 39     int mul=1;
 40     while(mul<len)
 41         mul<<=1;
 42     mul<<=1;
 43     len=mul;
 44     return ;
 45 }
 46 //将整数序列复制到复数序列中
 47 void copyed(int *a,complexed *multi,int len)
 48 {
 49     for(int i=0;i<len;i++)
 50         multi[i]=(complexed){a[i],0};
 51     return;
 52 }
 53 //位逆序变换
 54 void bitchange(complexed *multi,int len)
 55 {
 56     int i,j,k;
 57     for(i = 1, j = len/2;i < len-1; i++)
 58     {
 59         if(i < j)swap(multi[i],multi[j]);
 60         k = len/2;
 61         while( j >= k)
 62         {
 63             j -= k;
 64             k /= 2;
 65         }
 66         if(j < k) j += k;
 67     }
 68     return ;
 69 }
 70 //DFT的话on=1,IDFT on=-1;
 71 void fft(complexed *multi,int len,int on)
 72 {
 73     bitchange(multi,len);//位逆序置换
 74     complexed wn,w,u,t;//如算法导论所示
 75     for(int h=2;h<=len;h<<=1)
 76     {
 77         wn=(complexed){cos(2*on*PI/h),sin(2*on*PI/h)};
 78         for(int i=0;i<len;i+=h)
 79         {
 80             //蝴蝶操作
 81             w=(complexed){1,0};
 82             for(int j=i;j<i+h/2;j++)
 83             {
 84                 u=multi[j];
 85                 t=multi[j+h/2]*w;
 86                 multi[j]=u+t;
 87                 multi[j+h/2]=u-t;
 88                 w=w*wn;
 89             }
 90         }
 91     }
 92     //IDFT每个元素都得除以n
 93     if(on==-1)
 94         for(int i=0;i<len;i++)
 95             multi[i].r/=len;
 96     return ;
 97 }
 98 void mul(int *a,int *b,int &len1,int &len2)
 99 {
100         int len=max(len1,len2);
101         //取长度较长者作为长度,并将长度变为2…^(k+1)
102         changelen(len);
103         //将两个整数序列复制到复数序列中
104         copyed(a,multi1,len);
105         copyed(b,multi2,len);
106         //对两个复数序列进行DFT,变为点值表示
107         fft(multi1,len,1);
108         fft(multi2,len,1);
109         //对应点点值相乘
110         for(int i=0;i<len;i++)
111             multi1[i]=multi1[i]*multi2[i];
112         //将的出来的点值表示进行IDFT变为系数表示
113         fft(multi1,len,-1);
114         //四舍五入减小损失精度
115         for(int i=0;i<len;i++)
116         {
117             a[i]=(int)(multi1[i].r+0.5);
118         }
119         while(len-1>0 && a[len-1]==0)
120                 len--;
121         len1=len;
122         return ;
123 }
124 int main()
125 {
126     int len1,len2,len;
127     while(scanf("%s%s",s1,s2)!=EOF)
128     {
129         len1=strlen(s1);
130         len2=strlen(s2);
131         clr(a);
132         clr(b);
133         for(int i=0;i<len1;i++)
134         {
135             a[len1-i]=s1[i]-'0';
136         }
137         for(int i=0;i<len2;i++)
138         {
139             b[len2-i]=s2[i]-'0';
140         }
141         mul(a+1,b+1,len1,len2);
142         //进位
143         len=len1;
144         for(int i=1;i<len;i++)
145         {
146             a[i+1]=a[i+1]+a[i]/10;
147             a[i]%=10;
148         }
149         while(a[len]>9)
150         {
151             a[len+1]=a[len+1]+a[len]/10;
152             a[len]%=10;
153             len++;
154         }
155         for(int i=len;i>=1;i--)
156             printf("%d",a[i]);
157         printf("
");
158     }
159     return 0;
160 }
位逆序置换的迭代实现

后来看了ntt,小改了下原迭代实现的模板,实现了迭代实现的NTT模板:

  1 #include<bits/stdc++.h>
  2 #define clr(x) memset(x,0,sizeof(x))
  3 #define clr_1(x) memset(x,-1,sizeof(x))
  4 #define clrmax(x) memset(x,0x3f3f3f3f,sizeof(x))
  5 #define LL long long
  6 #define mod 1004535809
  7 #define PI 3.1415926535
  8 #define P 1004535809
  9 #define G 3
 10 using namespace std;
 11 char s1[200010],s2[200010];
 12 LL a[200010],b[200010],c[200010];
 13 LL quick_pow(LL mul,LL n)
 14 {
 15     LL res=1;
 16     mul=(mul%mod+mod)%mod;
 17     while(n)
 18     {
 19         if(n%2)
 20             res=res*mul%mod;
 21         mul=mul*mul%mod;
 22         n/=2;
 23     }
 24     return res;
 25 }
 26 inline int max(int a,int b)
 27 {
 28     return a>b?a:b;
 29 }
 30 void bitchange(LL *a,int len)
 31 {
 32     int i,j,k;
 33     for(i = 1, j = len>>1;i < len-1; i++)
 34     {
 35         if(i < j)swap(a[i],a[j]);
 36         k = len>>1;
 37         while( j >= k)
 38         {
 39             j -= k;
 40             k >>= 1;
 41         }
 42         if(j < k) j += k;
 43     }
 44     return ;
 45 }
 46 void changelen(int &len)
 47 {
 48     int mul=1;
 49     while(mul<len)
 50         mul<<=1;
 51     mul<<=1;
 52     len=mul;
 53     return ;
 54 }
 55 //DFT的话on=1,IDFT on=-1;
 56 void ntt(LL *a,int len,LL on)
 57 {
 58     bitchange(a,len);//位逆序置换
 59     LL wn,w,u,t;//如算法导论所示
 60     for(int h=2;h<=len;h<<=1)
 61     {
 62         wn=quick_pow(G,(P-1)/h)%mod;
 63         for(int i=0;i<len;i+=h)
 64         {
 65             //蝴蝶操作
 66             w=1;
 67             for(int j=i;j<i+h/2;j++)
 68             {
 69                 u=a[j]%mod;
 70                 t=a[j+h/2]*w%mod;
 71                 a[j]=(u+t)%mod;
 72                 a[j+h/2]=(u-t+mod)%mod;
 73                 w=w*wn%mod;
 74             }
 75         }
 76     }
 77     //IDFT调换次序实现wn^-1的情况,并且乘以len的逆元
 78     if(on==-1)
 79     {
 80         //k^0显然不调换次序,但是k^1与k^-1,k^2与k^-2.... k^len/2与k^-len/2 要调换次序
 81         for(int i=1;i<len/2;i++)
 82             swap(a[i],a[len-i]);
 83         LL re=quick_pow(len,P-2);
 84         for(int i=0;i<len;i++)
 85             a[i]=a[i]*re%mod;
 86     }
 87     return ;
 88 }
 89 void mul(LL *a,LL *b,int &len1,int &len2)
 90 {
 91         int len=max(len1,len2);
 92         //取长度较长者作为长度,并将长度变为2…^(k+1)
 93         changelen(len);
 94         //对两个整数序列进行DFT,变为点值表示
 95         ntt(a,len,1);
 96         ntt(b,len,1);
 97         //对应点点值相乘
 98         for(int i=0;i<len;i++)
 99             a[i]=b[i]*a[i]%mod;
100         //将的出来的点值表示进行IDFT变为系数表示
101         ntt(a,len,-1);
102         while(len-1>0 && a[len-1]==0)
103                 len--;
104         len1=len;
105         return ;
106 }
107 int main()
108 {
109     int len1,len2,len;
110     while(scanf("%s%s",s1,s2)!=EOF)
111     {
112         len1=strlen(s1);
113         len2=strlen(s2);
114         clr(a);
115         clr(b);
116         for(int i=0;i<len1;i++)
117         {
118             a[len1-i]=s1[i]-'0';
119         }
120         for(int i=0;i<len2;i++)
121         {
122             b[len2-i]=s2[i]-'0';
123         }
124         mul(a+1,b+1,len1,len2);
125         //进位
126         len=len1;
127         for(int i=1;i<len;i++)
128         {
129             a[i+1]=a[i+1]+a[i]/10;
130             a[i]%=10;
131         }
132         while(a[len]>9)
133         {
134             a[len+1]=a[len+1]+a[len]/10;
135             a[len]%=10;
136             len++;
137         }
138         for(int i=len;i>=1;i--)
139             printf("%lld",a[i]);
140         printf("
");
141     }
142     return 0;
143 }
NTT的迭代实现

NTT需要爆搜下找到该质数的原根(这部分一般不写到代码里,一般是自己找出来以后再直接作为常量放在程序里,建议分解完P-1的质因数后去搜索快点,一般原根都不太大比较好搜)。在比赛中一般给出的质数P,P-1后一般是C*2^k的形式,才能支持2^k的分治。

学习资料推荐:http://blog.sina.com.cn/s/blog_7c4c33190102wht6.html 这个看下原理一类的,包括FFT的。其中笔者把(P-1)*2^m写错写成了P*2^m了。

代码以及等价性参考ACdreamer的代码:http://blog.csdn.net/acdreamers/article/details/39026505

原文地址:https://www.cnblogs.com/wujiechao/p/7299853.html