SPOJ FFT TSUM

第一道FFT的题目。

在网上找了很多FFT的资料,但一直都看不懂,最后是看算法导论学的FFT,算法导论上面写的很详细,每一步推导过程都有严格的证明。

下面说这道题

题意:

给一个序列s,有n个不互相同的整数。现在从这个序列中选出一个包含3个不同的整数的集合,对于他们的和为sum来说,求一共有多少种选法。(注意:3个数的先后顺序都看做一种选法)

分析:

构造一个多项式A(x),这n个数作为多项式的指数。

A3(x)中的每一项的指数对应三个数的和,前面的系数是取数的方案数。

然而这并不是题目所求,这样的选法是任意取三个数,可能相同可能不同。

其中多计算了不合法的方案:

任意取三个数的方案数 = 取三个相同的数 + 取两个相同的数和另一个不同的数 + 三个互不相同的数

用式子表达出来就是: (图片来自叉姐PPT)

整理一下,答案就是:

 1 #include <iostream>
 2 #include <cstdio>
 3 #include <cstring>
 4 #include <algorithm>
 5 #include <complex>
 6 #include <cmath>
 7 using namespace std;
 8 
 9 typedef long long LL;
10 const double PI = acos(-1.0);
11 typedef complex<double> Complex;
12 
13 const int maxn = (1 << 17);
14 
15 void FFT(Complex P[], int n, int oper)
16 {
17     for(int i = 1, j = 0; i < n - 1; i++)
18     {
19         for(int s = n; j ^= s >>= 1, ~j & s; );
20         if(i < j) swap(P[i], P[j]);
21     }
22 
23     int log = 0;
24     while((n & (1 << log)) == 0) log++;
25     for(int s = 0; s < log; s++)
26     {
27         int m = (1 << s);
28         int m2 = m * 2;
29         Complex wm = Complex(cos(PI / m), sin(PI / m) * oper);
30         for(int k = 0; k < n; k += m2)
31         {
32             Complex w(1, 0);
33             for(int j = 0; j < m; j++, w = w * wm)
34             {
35                 Complex t = w * P[k + j + m];
36                 Complex u = P[k + j];
37                 P[k + j] = u + t;
38                 P[k + j + m] = u - t;
39             }
40         }
41     }
42 
43     if(oper == -1) for(int i = 0; i < n; i++) P[i].real() /= n;
44 }
45 
46 int A[maxn], A2[maxn], A3[maxn];
47 Complex a[maxn], b[maxn];
48 
49 int main()
50 {
51     int n; scanf("%d", &n);
52     while(n--)
53     {
54         int x; scanf("%d", &x);
55         x += 20000;
56         A[x]++;
57         A2[x*2]++;
58         A3[x*3]++;
59     }
60     for(int i = 0; i < maxn; i++) a[i] = A[i], b[i] = A2[i];
61 
62     FFT(a, maxn, 1);
63     FFT(b, maxn, 1);
64     for(int i = 0; i < maxn; i++) a[i] = a[i] * (a[i] * a[i] - b[i] * 3.0);
65     FFT(a, maxn, -1);
66 
67     for(int i = 0; i < maxn; i++)
68     {
69         LL ans = (LL)((a[i].real() + 0.5) + A3[i] * 2) / 6;
70         if(ans > 0) printf("%d : %lld
", i - 60000, ans);
71     }
72 
73     return 0;
74 }
代码君
原文地址:https://www.cnblogs.com/AOQNRMGYXLMV/p/4808247.html