「算法笔记」快速傅里叶变换(FFT)

一、引入

首先,定义多项式的形式为 (f(x)=sum_{i=0}^n a_ix^i),其中 (a_i) 为系数,(n) 为次数,这种表示方法称为“系数表示法”,一个多项式是由其系数确定的。

可以证明,(n+1) 个点可以唯一确定一个 (n) 次多项式。对于 (f(x)),代入 (n+1) 个不同的 (x),得到 (n+1) 个不同的 (y)。一个 (n) 次的多项式就可以等价地换成 (n+1) 个等式,相当于平面上的 (n+1) 组坐标 ((x_i,y_i)),这种表示方法称为“点值表示法”

多项式乘法 (卷积):(C(x)=A(x)cdot B(x))(A,B,C) 的系数构成的数列分别为 (a,b,c),则 (c_k=sum_{i=0}^ka_ib_{k-i})。理解:因为 (x^i imes x^{k-i}=x^k)(a_i)(b_{k-i}) 相乘后,它们后面的未知数就变成了 (x^k),对 (c_k) 产生贡献。

暴力求解:两个 (n) 次多项式相乘,时间复杂度 (mathcal{O}(n^2))

快速傅里叶变换(Fast Fourier Transform,简称 FFT)可以 (mathcal{O}(nlog n)) 求解。

二、基本步骤

利用点值表示法,分三步快速求出多项式乘积:

  1. 由系数表示法转换成点值表示法。(DFT)
  2. 利用点值表示法,求两个多项式的乘积。
  3. 再将点值表示法转化成系数表示法。(IDFT)

对于步骤二,给出两个 (n) 次多项式 (A)(B) 的点值表达式,我们可以 (mathcal{O}(n)) 求出其乘积 (C) 的点值表达式。显然 (C)(2n) 次的,取 (2n+1) 个不同的 (x_i)

  • 代入 (A)({(x_0,y_0),(x_1,y_1),cdots,(x_{2n},y_{2n})})
  • 代入 (B)({(x_0,y'_0),(x_1,y'_1),cdots,(x_{2n},y'_{2n})})
  • (C) 的点值表达式为:({(x_0,y_0y'_0),(x_1,y_1y'_1),cdots,(x_{2n},y_{2n}y'_{2n})})

下面重点分析步骤一。有一些前置概念。

三、复数

我们把形如 (z=a+bi)(a,b) 为实数)的数称为 复数,其中,(a,b) 分别叫做复数 (z) 的实部与虚部,(i) 为虚数单位,(i^2=-1)

我们把复数表示在 复平面 上((x) 轴叫实轴,(y) 轴叫虚轴),就像把实数表示在数轴上。如图,复数 (z=a+bi) 与复平面内的点 (Z(a,b)) 一一对应。

连接 (OZ)。复数 (z=a+bi) 与平面向量 (vec{OZ}) 一一对应。

复数的四则运算:

  • 加减法:((a+bi)pm (c+di)=(apm c)+(bpm d)i)
  • 乘法:((a+bi)(c+di)=ac+adi+bci+bdi^2=(ac-bd)+(bc+ad)i)
  • 除法:(frac{a+bi}{c+di}=frac{(a+bi) imes (c-di)}{(c+di) imes(c-di)}=frac{(ac+bd)+(bc-ad)i}{c^2+d^2}=frac{ac+bd}{c^2+d^2}+frac{bc-ad}{c^2+d^2}i)
complex<double>a;    //STL。a.real() 返回复数 a 的实部,a.imag() 返回复数 a 的虚部 
struct cp{    //手写,比 STL 快一点
    double a,b;
    cp operator+(cp &x){return (cp){a+x.a,b+x.b};} 
    cp operator-(cp &x){return (cp){a-x.a,b-x.b};}
    cp operator*(cp &x){return (cp){a*x.a-b*x.b,a*x.b+b*x.a};} 
    cp operator/(cp &x){double v=x.a*x.a+x.b*x.b; return (cp){(a*x.a+b*x.b)/v,(b*x.a-a*x.b)/v};}    //除法一般不用 qwq
};

我们称 (overline{z}=a-bi) 为复数 (z=a+bi)共轭复数

对于复数 (z=a+bi)模长 (|z|) 为点 (Z(a,b)) 到原点的距离,幅角 为对应的向量 (vec{OZ}) 与横轴正半轴的夹角。

复数的乘法可以表达为:模长相乘,辐角相加

四、单位根

1. 定义

(n) 次单位根是满足 (x^n=1) 的复数 (x)

首先,单位根的模长必然为 (1)。因为若 (|x|>1),则 (|x^n|=|x|^n>1);若 (|x|<1),则 (|x^n|=|x|^n<1)

所以单位根表示的点一定在 单位圆(圆心为原点,半径为 (1))上。

其次,单位根的辐角 ( heta) 一定满足 (frac{n heta}{2pi}inmathbb{Z})。也就是一个向量从 ((1,0)) 开始,每次旋转 ( heta) 的角度,旋转 (n) 次后还落在 ((1,0)) 上,那么它一定旋转了整数圈。

然后发现 (n) 次单位根正好是模长为 (1),辐角为 (frac{2kpi}{n}) 的向量对应的复数。

记模长为 (1),辐角为 (frac{2kpi}{n}) 的向量对应的 (n) 次单位根为 (omega_n^k),称为第 (k)(n) 次单位根。

还能发现,(omega_n^k=omega_n^{kmod n}),所以一般情况下,我们认为 (n) 次单位根有 (n) 个,即 (omega_n^0,omega_n^1,cdots,omega_n^{n-1})

2. 性质

单位根的性质:(这些性质在后文会被用到)

  • 性质 1:(n) 次单位根对应的向量将单位圆 (n) 等分。
    两个相邻的 (n) 次单位根对应的向量的夹角相等。单位根的辐角是周角的 (frac{1}{n})

  • 性质 2:(omega_n^k=omega_n^{kmod n})
    在弧度制下,任意弧度 ( heta)( heta+2kpi\,(kinmathbb{Z})) 表示相同的角。

  • 性质 3:({(omega_n^k)}^p=omega_n^{kp})
    (k)(n) 次单位根对应向量的辐角变为原来的 (p) 倍,相当于 (omega_n^{kp}) 对应的向量。

  • 性质 4:(omega_{dn}^{dk}=omega_{n}^k)
    考虑两者对应向量的辐角,(frac{2dkpi}{dn}=frac{2kpi}{n})。也可以这样理解:(dn) 次单位根对应的向量将单位圆 (dn) 等分,取第 (dk) 个。(n) 次单位根对应的向量将单位圆 (n) 等分,取第 (k) 个。两者等价。

  • 性质 5:(omega_n^{k+n/2}=-omega_n^k)。其中 (n) 为偶数。
    相当于一个复数对应的向量进行一次中心对称,(a+bi) 变为 (-a-bi)

根据性质 3 有,({(omega_n^k)}^2=omega_{n}^{2k})。根据性质 4 有,(omega_n^{2k}=omega_{n/2}^k),其中 (n) 为偶数。

3. 求法

根据性质 3,有 (omega_n^k=(omega_n^1)^k)。也就是说,只要求出 (omega_n^1),就能得到 (omega_n^0,omega_n^1,cdots,omega_n^{n-1})

(omega_n^1) 所对应的向量模长为 (1),辐角为 (frac{2pi}{n}),得到 (omega_n^1) 所对应的点为 ((cos(frac{2pi}{n}),sin(frac{2pi}{n})))

(pi)double pi=acos(-1)

补充:(omega_n^k=e^{frac{2pi ik}{n}}=cos(frac{2pi k}{n})+icdot sin(frac{2pi k}{n}))。其中 (i) 为虚数单位。

五、DFT

将系数表示法转换成点值表示法。

1. 基本思路

对于 (n-1) 次多项式(也就是有 (n) 项) (f(x)=sum_{i=0}^{n-1} a_ix^i),将奇偶次数分离。

(f(x)=(a_0+a_2x^2+a_4x^4+cdots+a_{n-2}x^{n-2})+(a_1x+a_3x^3+a_5x^5+cdots+a_{n-1}x^{n-1}))

定义两个新的多项式 (f_1(x))(f_2(x))

  • (f_1(x)=a_0+a_2x+a_4x^2+cdots+a_{n-2}x^{{n/2-1}})
  • (f_2(x)=a_1+a_3x+a_5x^2+cdots+a_{n-1}x^{n/2-1})

于是有 (f(x)=f_1(x^2)+xf_2(x^2))

(omega_n^k\,(k<frac{n}{2})) 代入得:(f(omega_n^k)=f_1(omega_n^{2k})+omega_n^kf_2(omega_n^{2k})=f_1(omega_{n/2}^k)+omega_n^kf_2(omega_{n/2}^k))

同理,将 (omega_n^{k+n/2}\,(k<frac{n}{2})) 代入得:(f(omega_n^{k+n/2})=f_1(omega_n^{2k+n})+omega_n^{k+n/2}f_2(omega_n^{2k+n}))
(=f_1(omega_n^{2k})+omega_n^{k+n/2}f_2(omega_n^{2k})=f_1(omega_{n/2}^k)+omega_n^{k+n/2}f_2(omega_{n/2}^k)=f_1(omega_{n/2}^k)-omega_n^kf_2(omega_{n/2}^k))

发现两者的右边只有正负号的区别。

第一个式子的 (k) 取遍 ([0,frac{n}{2}-1]) 时,(k+frac{n}{2}) 取遍 ([frac{n}{2},n-1])

如果我们知道 (f_1(x),f_2(x)) 分别在 (x=omega_{n/2}^0,omega_{n/2}^1,cdots,omega_{n/2}^{n/2-1}) 的点值表示,就可以 (mathcal{O}(n)) 求出 (f(x))(x=omega_n^0,omega_n^1,cdots,omega_n^{n-1}) 的点值表示。

然后,发现 (f_1(x),f_2(x))(f(x)) 的性质完全相同,这样就把问题分成了两个子问题,对于这两个子问题再进行递归求解。这样就可以在 (mathcal{O}(nlog n)) 的时间复杂度内求出点值表达式。

2. 代码实现

DFT 是利用单位根的特殊性质进行分治。

考虑到它能处理的多项式长度只能为 (2^k)(k) 为整数),否则在分治时左右的项数就会不同。我们可以在最高次补一些系数为 (0) 的项,把原来的 (n) 补到 (2^k\,(2^kgeq n))。这样不影响计算结果。

别忘了:(f(omega_n^k)=f_1(omega_{n/2}^k)+omega_n^kf_2(omega_{n/2}^k),f(omega_n^{k+n/2})=f_1(omega_{n/2}^k)-omega_n^kf_2(omega_{n/2}^k))

递归实现 DFT:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e6+5;
int n,len;
double x,pi=acos(-1);
complex<double>a[N];
void DFT(complex<double>*a,int n){
    if(n==1) return ;    //边界
    int m=n/2;
    complex<double>a1[m],a2[m];
    for(int i=0;i<m;i++) a1[i]=a[i<<1],a2[i]=a[i<<1|1];    //按奇偶分类 
    DFT(a1,m),DFT(a2,m);    //处理子问题
    complex<double>x(cos(2*pi/n),sin(2*pi/n)),w(1,0);    //x: w_n^1。w: 当前的 n 次单位根,初始值为 w_n^0,即 1。 
    for(int i=0;i<m;i++)
        a[i]=a1[i]+w*a2[i],a[i+m]=a1[i]-w*a2[i],w*=x;    //w*=x: 得到下一个单位根。 
} 
signed main(){
    scanf("%lld",&n);
    for(int i=0;i<n;i++)
        scanf("%lf",&x),a[i]=x;    //a[i] 的实部赋值为 x 
    for(len=1;len<n;len<<=1);    //把项数补到 2 的整幂,高次项的系数默认为 0 
    DFT(a,len);
    for(int i=0;i<len;i++)
        printf("(%.4lf,%.4lf)
",a[i].real(),a[i].imag());
    return 0;
}

六、IDFT

DFT 的逆运算。将点值表示法转化成系数表示法。

结论:把 DFT 中的 (omega_n^1) 换成它的共轭复数,即 ((cos(frac{2pi}{n}),-sin(frac{2pi}{n}))),得到的系数再除以 (n) 即可。证明略。

可以把 IDFT 和 DFT 放在一起写。

void FFT(complex<double>*a,int n,int opt){    //opt=1 为 DFT,opt=-1 为 IDFT
    if(n==1) return ;
    int m=n/2;
    complex<double>a1[m],a2[m];
    for(int i=0;i<m;i++) a1[i]=a[i<<1],a2[i]=a[i<<1|1];
    FFT(a1,m,opt),FFT(a2,m,opt);
    complex<double>x(cos(2*pi/n),sin(2*pi/n)*opt),w(1,0);
    for(int i=0;i<m;i++)
        a[i]=a1[i]+w*a2[i],a[i+m]=a1[i]-w*a2[i],w*=x; 
} 

七、迭代实现

目前的代码如下。可以发现它的效率并不是很高。

//Luogu P3803
#include<bits/stdc++.h>
using namespace std;
const int N=3e6+5;
int n,m,len;
double x,pi=acos(-1);
complex<double>a[N],b[N];
void FFT(complex<double>*a,int n,int opt){    //opt=1/-1: DFT/IDFT
    if(n==1) return ;
    int m=n/2;
    complex<double>a1[m],a2[m];
    for(int i=0;i<m;i++) a1[i]=a[i<<1],a2[i]=a[i<<1|1];
    FFT(a1,m,opt),FFT(a2,m,opt);
    complex<double>x(cos(2*pi/n),sin(2*pi/n)*opt),w(1,0);
    for(int i=0;i<m;i++)
        a[i]=a1[i]+w*a2[i],a[i+m]=a1[i]-w*a2[i],w*=x;     //蝴蝶操作(只是一个名字 qwq)。这里 w*a2[i] 算了两次,先记录下来再算可以减小常数。
} 
signed main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)
        scanf("%lf",&x),a[i]=x;
    for(int i=0;i<=m;i++)
        scanf("%lf",&x),b[i]=x;
    n=n+m+1;    //一个 n 次多项式和一个 m 次多项式的乘积是一个 n+m 次多项式 (有 n+m+1 项)
    for(len=1;len<n;len<<=1); 
    FFT(a,len,1),FFT(b,len,1);
    for(int i=0;i<len;i++) a[i]*=b[i];    //点值直接乘
    FFT(a,len,-1);
    for(int i=0;i<n;i++)
        printf("%d%c",(int)(a[i].real()/len+0.5),i==n-1?'
':' ');    //注意这里是除以 len
    return 0;
}

如图,考虑递归的结构:

求出第 (4) 层的状态,就能往上合并求出前 (3) 层的状态。

观察发现,最后的数组下标的序列是原序列的 二进制翻转。比如 (6=(110)_2),反过来就是 ((011)_2=3)。而第 (4)(a_3) 的位置就是原来 (a_6) 的位置。

//r[i] 表示 i 二进制翻转后的结果。求 0~len-1 在二进制位数为 log2(len)-1 意义下的二进制翻转。
for(int i=0;i<len;i++)    //len 是 2 的幂次 
    r[i]=(r[i>>1]>>1)|((i&1)?len>>1:0);

理解:考虑 (i)(frac{i}{2}) 在二进制下的关系。(i) 可以看作是 (frac{i}{2}) 在二进制下的每一位左移一位得到。翻转后,(i)(frac{i}{2}) 在二进制下的每一位右移一位得到,然后判一下最后一位即可。

迭代实现:

//Luogu P3803
#include<bits/stdc++.h>
using namespace std;
const int N=3e6+5;
int n,m,len,r[N];
double x,pi=acos(-1);
complex<double>a[N],b[N];
void FFT(complex<double>*a,int n,int opt){    //opt=1/-1: DFT/IDFT
    for(int i=0;i<n;i++) 
        if(i<r[i]) swap(a[i],a[r[i]]);    //求出最后一层的序列 
    for(int k=2;k<=n;k<<=1){    //枚举区间长度 
        int m=k>>1;     //待合并的长度 
        complex<double>x(cos(2*pi/k),sin(2*pi/k)*opt),w(1,0),v;
        for(int i=0;i<n;i+=k,w=1)    //枚举起始点 
            for(int j=i;j<i+m;j++)    //遍历区间
                v=w*a[j+m],a[j+m]=a[j]-v,a[j]=a[j]+v,w*=x;    //蝴蝶操作。注意先 a[j+m]=a[j]-v 再 a[j]=a[j]+v。
    }
} 
signed main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)
        scanf("%lf",&x),a[i]=x;
    for(int i=0;i<=m;i++)
        scanf("%lf",&x),b[i]=x;
    n=n+m+1;
    for(len=1;len<n;len<<=1); 
    for(int i=0;i<len;i++)    //二进制翻转 
        r[i]=(r[i>>1]>>1)|((i&1)?len>>1:0);
    FFT(a,len,1),FFT(b,len,1);
    for(int i=0;i<len;i++) a[i]*=b[i];
    FFT(a,len,-1);
    for(int i=0;i<n;i++)
        printf("%d%c",(int)(a[i].real()/len+0.5),i==n-1?'
':' ');
    return 0;
}

八、模板

P3803 【模板】多项式乘法(FFT) 为例。

递归实现:

#include<bits/stdc++.h>
using namespace std;
const int N=3e6+5;
int n,m,len;
double x,pi=acos(-1);
complex<double>a[N],b[N];
void FFT(complex<double>*a,int n,int opt){    //opt=1/-1: DFT/IDFT
    if(n==1) return ;
    int m=n/2;
    complex<double>a1[m],a2[m];
    for(int i=0;i<m;i++) a1[i]=a[i<<1],a2[i]=a[i<<1|1];
    FFT(a1,m,opt),FFT(a2,m,opt);
    complex<double>x(cos(2*pi/n),sin(2*pi/n)*opt),w(1,0);
    for(int i=0;i<m;i++)
        a[i]=a1[i]+w*a2[i],a[i+m]=a1[i]-w*a2[i],w*=x; 
} 
signed main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)
        scanf("%lf",&x),a[i]=x;
    for(int i=0;i<=m;i++)
        scanf("%lf",&x),b[i]=x;
    n=n+m+1;
    for(len=1;len<n;len<<=1); 
    FFT(a,len,1),FFT(b,len,1);
    for(int i=0;i<len;i++) a[i]*=b[i];
    FFT(a,len,-1);
    for(int i=0;i<n;i++)
        printf("%d%c",(int)(a[i].real()/len+0.5),i==n-1?'
':' ');
    return 0;
}

迭代实现:(比递归快)

#include<bits/stdc++.h>
using namespace std;
const int N=3e6+5;
int n,m,len,r[N];
double x,pi=acos(-1);
complex<double>a[N],b[N];
void FFT(complex<double>*a,int n,int opt){    //opt=1/-1: DFT/IDFT
    for(int i=0;i<n;i++)
        if(i<r[i]) swap(a[i],a[r[i]]);
    for(int k=2;k<=n;k<<=1){
        int m=k>>1; 
        complex<double>x(cos(2*pi/k),sin(2*pi/k)*opt),w(1,0),v;
        for(int i=0;i<n;i+=k,w=1)
            for(int j=i;j<i+m;j++) v=w*a[j+m],a[j+m]=a[j]-v,a[j]=a[j]+v,w*=x;
    }
} 
signed main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)
        scanf("%lf",&x),a[i]=x;
    for(int i=0;i<=m;i++)
        scanf("%lf",&x),b[i]=x;
    n=n+m+1;
    for(len=1;len<n;len<<=1); 
    for(int i=0;i<len;i++)
        r[i]=(r[i>>1]>>1)|((i&1)?len>>1:0);
    FFT(a,len,1),FFT(b,len,1);
    for(int i=0;i<len;i++) a[i]*=b[i];
    FFT(a,len,-1);
    for(int i=0;i<n;i++)
        printf("%d%c",(int)(a[i].real()/len+0.5),i==n-1?'
':' ');
    return 0;
}

记忆:

  • (f(x)=f_1(x^2)+xf_2(x^2))
  • (f(omega_n^k)=f_1(omega_{n/2}^k)+omega_n^kf_2(omega_{n/2}^k))
  • (f(omega_n^{k+n/2})=f_1(omega_{n/2}^k)-omega_n^kf_2(omega_{n/2}^k))
转载请注明原文链接
原文地址:https://www.cnblogs.com/maoyiting/p/14390149.html