FFT入门

这篇文章会讲讲FFT的原理和代码。

先贴picks博客(又名FFT从入门到精通):http://picks.logdown.com/posts/177631-fast-fourier-transform

首先FFT是干嘛用的?

额其实在oi中它就是一个用来算“快速卷积”的东西。

卷积是啥?

给定两个数组a、b,求数组c使得:

for(int i=0;i<n;i++)
    for(int j=0;j<n;j++)
        if(i+j<n) c[i+j]+=a[i]*b[j];
这就叫做长度为n的“卷积”。

正常模拟是O(n^2)的,这时候我们就可以用FFT来加速到O(nlogn)!

我们发现,如果我们令a[i]为x^i的系数,那么a、b就可以表示为一个多项式,c就可以被表示为这两个多项式的乘积。

首先我们可以发现,我们对于一个n次多项式,可以用一个多项式的形式来表示它,也可以找到n个位置的值,这样也可以唯一确定这个多项式。

所以我们就初步有了一个思路,我们找到a、b在n个点处的取值,乘在一起,搞回去确定c的多项式形式。

为了和谐,我们一般令n为2的次幂。(注意)

关于这个东西一般有两种写法,一般被称为复数FFT和NTT。

先讲NTT好了......

假设a、b都是整系数多项式,然后模数P十分刺激,满足P是质数,$2^k|P-1$且$2^k>n$时,我们就可以使用NTT。

然后你还要知道原根的有关概念...简单来说就是原根的次幂在模P意义下循环节为$varphi(P)$,对于素数来说就是P-1。

这里就说一点,998244353的原根是3...

设g为P的原根,那么我们令$omega_n=g^{frac{P-1}{n}}$,可以发现:

$omega_{2n}^{2m}=omega_{n}^m$,$omega_{2n}^m=-omega_{2n}^{m+n}$。(确实挺显然的)

那么我们取$omega_n^k$,其中k∈{0...n-1},作为n个点,如何算出这n个点处的取值呢?

我们假设偶次项提出来作为a0,奇次项提出来作为a1。

(例如1+2x+3x^2+4x^3,偶次项提出来为1+3x,奇次项提出来为2+4x,注意这里的次数也要相应改变)

那么我们可以发现

所以我们可以用a0和a1的点值表示算出a的点值表示。

T(n)=2T(n/2)+O(n),由主定理复杂度为O(nlogn)。

接下来转回去的话,由于某种奇怪的性质(详细证明可以看picks博客),我们只要用$omega_{n}^{-m}$代替原来的$omega_n^{m}$,带进去,最后除以n就行了。即把那一堆$omega$翻转一下。

当然如果你真这样瞎搞常数似乎真的挺大的,事实上有一些更靠谱的做法,上图:

image

开始我们把输入的数二进制位翻转,就可以得到左边,然后按这个图上进行蝶形运算(就是刚才那两个公式)就可以算出结果了。

额复数FFT更加简单。

我们令$omega_{n}$为单位根,即满足$x^n=1$的复数,它可以看做复平面上x轴正方向绕逆时针方向旋转$frac{2pi}{n}$的复数。所以$omega_n=cos(frac{2pi}{n})+sin(frac{2pi}{n})i$。

听起来十分靠谱...但是这种东西毕竟自己瞎写的话常数实在太大了...

下面这个是n+e的NTT模板,有改动,uoj34:

#include <iostream>
#include <stdio.h>
#include <math.h>
#include <string.h>
#include <time.h>
#include <stdlib.h>
using namespace std;
#define ll long long
ll MOD=998244353;
ll w[2][666666];
ll qp(ll a,ll b)
{
    ll ans=1;
    while(b)
    {
        if(b&1) ans=ans*a%MOD;
        a=a*a%MOD; b>>=1;
    }
    return ans;
}
int K;
void fftinit(int n)
{
    for(K=1;K<n;K<<=1);
    w[0][0]=w[0][K]=1;
    ll g=qp(3,(MOD-1)/K); //3是原根
    for(int i=1;i<K;i++) w[0][i]=w[0][i-1]*g%MOD;
    for(int i=0;i<=K;i++) w[1][i]=w[0][K-i];
}
void fft(int* x,int v)
{
    for(int i=0,j=0;i<K;i++)
    {
        if(i>j) swap(x[i],x[j]);
        for(int l=K>>1;(j^=l)<l;l>>=1);
    }
    for(int i=2;i<=K;i<<=1)
    {
        for(int j=0;j<K;j+=i)
        {
            for(int l=0;l<i>>1;l++)
            {
                ll t=(ll)x[j+l+(i>>1)]*w[v][K/i*l]%MOD;
                x[j+l+(i>>1)]=(x[j+l]-t+MOD)%MOD;
                x[j+l]=(x[j+l]+t)%MOD;
            }
        }
    }
    if(!v) return;
    ll rv=qp(K,MOD-2);
    for(int i=0;i<K;i++) x[i]=x[i]*rv%MOD;
}
int N,M,a[666666],b[666666],c[666666];
int main()
{
    scanf("%d%d",&N,&M);
    ++N; ++M; int t=N+M-1;
    for(int i=0;i<N;i++) scanf("%d",a+i);
    for(int i=0;i<M;i++) scanf("%d",b+i);
    fftinit(t); fft(a,0); fft(b,0);
    for(int i=0;i<K;i++) c[i]=(ll)a[i]*b[i]%MOD;
    fft(c,1);
    for(int i=0;i<t;i++) printf("%d ",c[i]);
}
原文地址:https://www.cnblogs.com/zzqsblog/p/5665654.html