【FFT】学习笔记

首先,多项式有两种表示方式,系数表示和点值表示

对于两个多项式相乘而言,用系数表示进行计算是O(n^2)的

而用点值表示进行计算是O(n)的

那么我们自然就会去想如果把系数表示的多项式转化为点值表示的多项式进行计算,不就可以减少时间复杂度了么

然而,一般情况下系数表示的多项式想要转化成点值表示的多项式,或是点值表示的多项式想要转化成系数表示的多项式,复杂度都是O(n^2)的

但这只是一般情况

我们可以通过取特殊值把系数表示转化成点值表示,这样的话能把复杂度降到O(nlogn),这就是DFT了

同样通过求逆之类的操作可以把点值表示转换为系数表示,同样复杂度为O(nlogn),这就是IDFT了

嘛。。。

简单来说就是这样的吧

代码

//by 减维
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<bitset>
#include<set>
#include<cmath>
#include<vector>
#include<set>
#include<map>
#include<ctime>
#include<algorithm>
#define ll long long
#define il inline
#define rg register
#define db double
#define mpr make_pair
#define maxn 200005
#define inf (1<<30)
#define eps 1e-8
#define pi 3.1415926535897932384626
using namespace std;

inline int read()
{
    int ret=0;bool fla=0;char ch=getchar();
    while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
    if(ch=='-'){fla=1;ch=getchar();}
    while(ch>='0'&&ch<='9'){ret=ret*10+ch-'0';ch=getchar();}
    return fla?-ret:ret;
}

struct cp{
    db x,y;
}A[maxn],B[maxn],C[maxn],D[maxn];

int n,m,mx,len,rev[maxn],a[maxn],cnt[maxn];

cp operator + (const cp &x,const cp &y){return (cp){x.x+y.x,x.y+y.y};}
cp operator - (const cp &x,const cp &y){return (cp){x.x-y.x,x.y-y.y};}
cp operator * (const cp &x,const cp &y){return (cp){x.x*y.x-x.y*y.y,x.x*y.y+x.y*y.x};}

void FFT(cp *a,int op)
{
    for(int i=0;i<n;++i) if(rev[i]>i) swap(a[i],a[rev[i]]);
    for(int k=1;k<n;k<<=1)
    {
        cp omi=(cp){cos(pi/k),sin(pi/k)*op};
        for(int i=0;i<n;i+=(k<<1))
        {
            cp w=(cp){1.0,0.0};
            for(int j=0;j<k;++j,w=w*omi)
            {
                cp x=a[i+j],y=a[i+j+k]*w;
                a[i+j]=x+y,a[i+j+k]=x-y;
            }
        }
    }
    if(op==-1) for(int i=0;i<n;++i) a[i].x/=n;
}

int main()
{
    n=read();
    for(int i=1,x;i<=n;++i) x=read(),mx=max(mx,3*x),A[x].x=1,B[x*2].x=1,C[x*3].x=1;
    m=mx;
    for(n=1;n<=m;n<<=1) len++;
    for(int i=0;i<n;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
    FFT(A,1);FFT(B,1),FFT(C,1);
    for(int i=0;i<n;++i)
    {
        cp tmp1=(cp){1.0/6.0,0};
        cp tmp2=(cp){3.0,0};
        cp tmp3=(cp){2.0,0};
        cp tmp4=(cp){1.0/2.0,0};
        D[i]=(A[i]*A[i]*A[i]-tmp2*B[i]*A[i]+tmp3*C[i])*tmp1;
        D[i]=D[i]+(A[i]*A[i]-B[i])*tmp4;
        D[i]=D[i]+A[i];
    }
    FFT(D,-1);
    for(int i=0;i<n;++i)
    {
        int pri=(int)(D[i].x+0.5);
        if(pri>0) printf("%d %d
",i,pri);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/rir1715/p/8351384.html