hdu 1402 A * B Problem Plus (FFT&DFT)

http://acm.hdu.edu.cn/showproblem.php?pid=1402

  大数乘法,将乘法转换成多项式来解,用DFT加上分治法,将O(n^2)的复杂度降至O(n logn).

  在FFT(快速傅里叶变换)中,构造旋转因子并且利用蝴蝶操作将原来的每个系数离散化,然后将两个多项式对应的系数乘起来。因为傅里叶变换是一个可逆的操作,所以最后IDFT,将答案的每个系数还原,最后输出结果。

  目前我对FFT的机理还没完全理解,只能明白是利用DFT(离散傅里叶变换)将一个多项式离散化,构造出傅里叶级数,然后再利用其周期性的性质,对应系数相乘,这时相当于每一位都和相对的每一位相乘,然后将结果储存到应在的位置。前面的FFT复杂度是O(n logn),最后才进行进位的操作,复杂度是O(n)。因此,总的复杂度就从O(n^2)降到了O(n logn)了。

View Code
  1 #include <cstdio>
  2 #include <cmath>
  3 #include <cstring>
  4 #include <cstdlib>
  5 
  6 const int maxn = 1 << 16;
  7 const double pi = acos((double)-1);
  8 
  9 struct virt{ // 定义虚数,并重载虚数的运算符
 10     double r;
 11     double i;
 12     void ins(double a = 0.0, double b = 0.0){
 13         r = a;
 14         i = b;
 15     }
 16     virt operator + (const virt &x){
 17         virt ret;
 18         ret.ins(r + x.r, i + x.i);
 19         return ret;
 20     }
 21     virt operator - (const virt &x){
 22         virt ret;
 23         ret.ins(r - x.r, i - x.i);
 24         return ret;
 25     }
 26     virt operator * (const virt &x){
 27         virt ret;
 28         ret.ins(r * x.r - i * x.i, r * x.i + i * x.r);
 29         return ret;
 30     }
 31 }revtmp[maxn << 1], val1[maxn << 1], val2[maxn << 1];
 32 int ep[21], ans[maxn << 1];
 33 double base[19];
 34 
 35 void pre(){
 36     ep[0] = 1;
 37     for (int i = 1; i <= 20; i++){
 38         ep[i] = ep[i - 1] << 1;
 39     }
 40     for (int i = 0; i < 19; i++){
 41         base[i] = 2 * pi / ep[i + 1];
 42     }
 43 }
 44 
 45 int rev(int x, int l){
 46     int ret = 0;
 47 
 48     while (l--){
 49         ret |= (x & 1);
 50         ret <<= 1;
 51         x >>= 1;
 52     }
 53 
 54     return ret >> 1;
 55 } // 将数x在二进制形式的最后l位倒置
 56 
 57 void reverse(virt *a, int len){
 58     int bits;
 59 
 60     bits = (int)ceil(log((double)len) / log(2.0));
 61     for (int i = 0; i < len; i++){
 62         revtmp[rev(i, bits)] = a[i];
 63     }
 64     for (int i = 0; i < len; i++){
 65         a[i] = revtmp[i];
 66     }
 67 } // 重组虚数数组,以便之后分治处理DFT,假设长度为8,重置后按自然序排列的倒位序相应为0 4 2 6 1 5 3 7
 68 
 69 void fft(virt *a, int len, bool idft){
 70     virt u, t;
 71     double bits = (int)ceil(log((double)len) / log(2.0));
 72 
 73     reverse(a, len); // 先将数倒位排序
 74     for (int i = 0; i < bits; i++){
 75         virt wi, w; // wi是单位旋转因子
 76 
 77         wi.ins(cos((idft ? -1 : 1) * base[i]), sin((idft ? -1 : 1) * base[i])); // 判断是dft还是idft,两种操作相逆
 78         for (int k = 0; k < len; k += ep[i + 1]){
 79             w.ins(1.0, 0.0); // 重置旋转因子
 80             for (int j = 0; j < ep[i]; j++){
 81 /************下面几行是蝴蝶操作的几个步骤************/
 82                 t = w * a[k + j + ep[i]];
 83                 u = a[k + j];
 84                 a[k + j] = u + t;
 85                 a[k + j + ep[i]] = u - t;
 86                 w = w * wi;
 87             }
 88         }
 89     }
 90     if (idft){ // 如果是idft,要将每一位除以多项式的长度
 91         for (int i = 0; i < len; i++){
 92             a[i].r /= len;
 93         }
 94     }
 95 }
 96 
 97 void cal(virt *a, virt *b, int len){
 98     fft(a, len, false);
 99     fft(b, len, false);
100     for (int i = 0; i < len; i++){
101         a[i] = a[i] * b[i];
102     }
103     fft(a, len, true);
104 #ifndef ONLINE_JUDGE
105     puts("answer:");
106     for (int i = 0; i < len; i++){
107         printf("%2d : %10.5f %10.5f\n", i, a[i].r, a[i].i);
108     }
109     puts("");
110 #endif
111 }
112 
113 void deal(char *a, char *b){
114     int len_a = strlen(a);
115     int len_b = strlen(b);
116     int len = 1;
117 
118     if (len_a > len_b) len = (int)ceil(log((double)len_a) / log(2.0));
119     else len = (int)ceil(log((double)len_b) / log(2.0));
120     len = ep[len + 1];
121 #ifndef ONLINE_JUDGE
122     printf("len %d\n", len);
123 #endif
124     for (int i = 0; i < len_a; i++)
125         val1[i].ins((double)a[len_a - i - 1] - '0');
126     for (int i = len_a; i < len; i++)
127         val1[i].ins();
128     for (int i = 0; i < len_b; i++)
129         val2[i].ins((double)b[len_b - i - 1] - '0');
130     for (int i = len_b; i < len; i++)
131         val2[i].ins();
132 
133     cal(val1, val2, len);
134     for (int i = 0; i < len; i++){
135         ans[i] = val1[i].r + 0.5;
136     }
137     for (int i = 1; i < len; i++){
138         ans[i] += ans[i - 1] / 10;
139         ans[i - 1] %= 10;
140     }
141 
142     len = len_a + len_b;
143     while (ans[len] <= 0) len--;
144     while (len >= 0){
145         printf("%d", ans[len]);
146         len--;
147     }
148     puts("");
149 }
150 
151 char in_a[maxn], in_b[maxn];
152 
153 int main(){
154     pre();
155     while (~scanf("%s%s", in_a, in_b)){
156         if (in_a[0] == '0' || in_b[0] == '0'){
157             printf("0\n");
158             continue;
159         }
160         deal(in_a, in_b);
161     }
162 
163     return 0;
164 }

——written by Lyon

原文地址:https://www.cnblogs.com/LyonLys/p/hdu_1402_Lyon.html