FFT模板

递归版

#include<math.h>
#include<cstdio>
#include<cstdlib>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<complex>
#include<vector>
//#include<iostream>
using namespace std;

// n, m: first read by main() as the highest coefficient indices of the two
// input polynomials (n+1 and m+1 coefficients are read). Before mul() is
// called, main() overwrites m with m+n and n with the smallest power of two
// that is strictly greater than that sum, i.e. the transform length.
int n,m;
// Capacity (in elements) of the global coefficient arrays below.
#define maxn 2222222
// Complex number type used for all transform data.
typedef complex<double> cp;
// pi, obtained as acos(-1).
const double pi=acos(-1);
// a, b: input polynomials (transformed in place by mul()); c: their product.
// As globals they start zero-initialised, so entries past the coefficients
// that main() reads act as zero padding.
cp a[maxn],b[maxn],c[maxn];

void dft(cp *a,int n,int type)
{
    if (n==1) return;
    int m=n>>1; cp ll[m],rr[m];
    for (int i=0;i<n;i+=2) ll[i>>1]=a[i],rr[i>>1]=a[i+1];
    dft(ll,m,type); dft(rr,m,type);
    cp base=cp(cos(2*pi/n),sin(2*pi*type/n)),t=cp(1,0),tmp;
    for (int i=0;i<m;i++,t=t*base) tmp=rr[i]*t,a[i]=ll[i]+tmp,a[i+m]=ll[i]-tmp;
}

// Multiplies the polynomials stored in a and b and writes the product to c.
//
// Uses the file-level n as the transform length and the file-level m as the
// highest coefficient index of the product. a and b are replaced by their
// forward transforms. Only c[0..m] is divided by n after the inverse
// transform; entries of c beyond m are left unscaled.
void mul(cp *a,cp *b,cp *c)
{
    const int len=n;

    // Forward transforms of both operands.
    dft(a,len,1);
    dft(b,len,1);

    // Pointwise product in the frequency domain.
    for (int k=0;k!=len;++k)
        c[k]=a[k]*b[k];

    // Back to coefficients; the inverse transform is unnormalised.
    dft(c,len,-1);

    const double scale=(double)len;
    for (int k=0;k<=m;++k)
        c[k]/=scale;
}

int main()
{
    scanf("%d%d",&n,&m);
    for (int i=0,x;i<=n;i++) scanf("%d",&x),a[i]=x;
    for (int i=0,x;i<=m;i++) scanf("%d",&x),b[i]=x;
    m=m+n; for (n=1;n<=m;n<<=1); mul(a,b,c);
    for (int i=0;i<=m;i++) printf("%d ",(int)(c[i].real()+0.5));
    return 0;
}

非递归版

#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<math.h>
//#include<iostream>
using namespace std;

// n, m: first read by main() as the highest coefficient indices of the two
// input polynomials (n+1 and m+1 coefficients are read). Before mul() is
// called, main() overwrites m with m+n and n with the smallest power of two
// that is strictly greater than that sum, i.e. the transform length.
// wei: incremented by main() once per doubling of n, so n == 1<<wei when
// mul() runs; it is the bit width used for the bit-reversal table.
int n,m,wei;
// Capacity (in elements) of the global arrays below.
#define maxn 5222222
// Complex number type used for all transform data.
typedef complex<double> cp;
// pi, obtained as acos(-1).
const double pi=acos(-1);
// a, b: input polynomials (transformed in place by mul()); c: their product.
// rev: bit-reversal permutation table, filled by mul() and read by dft().
// As globals, all of these start zero-initialised.
cp a[maxn],b[maxn],c[maxn]; int rev[maxn];
// In-place iterative radix-2 discrete Fourier transform of a[0..n).
//
//   a    - data to transform; overwritten with the result.
//   n    - transform length; must be a power of two.
//   type - +1 for the forward transform, -1 for the inverse transform.
//
// Requires the file-level table rev[0..n) to already hold the bit-reversal
// permutation for length n. No normalisation is applied: after a
// forward/inverse round trip the caller has to divide by n itself.
void dft(cp *a,int n,int type)
{
    // Reorder the input into bit-reversed order; the i<rev[i] test makes sure
    // every pair is swapped exactly once.
    for (int i=0;i<n;i++)
        if (i<rev[i])
        {
            cp t=a[i]; a[i]=a[rev[i]]; a[rev[i]]=t;
        }
    // half is the size of the sub-transforms being merged. The loop must stop
    // before half reaches n: a stage with half==n would combine a[0..n) with
    // a[n..2n), i.e. read and write elements outside the transform.
    for (int half=1;half<n;half<<=1)
    {
        // Primitive root of unity for blocks of length 2*half.
        cp w=cp(cos(pi/half),type*sin(pi/half));
        for (int j=0,len=half<<1;j<n;j+=len)
        {
            cp tw=cp(1,0);
            for (int k=0;k<half;k++,tw*=w)
            {
                // Butterfly on the pair (j+k, j+k+half).
                cp x=a[j+k+half]*tw;
                a[j+k+half]=a[j+k]-x;
                a[j+k]+=x;
            }
        }
    }
}

// Multiplies the polynomials stored in a and b and writes the product to c.
//
// Uses the file-level n as the transform length, wei as its bit width
// (n == 1<<wei) and m as the highest coefficient index of the product.
// a and b are replaced by their forward transforms. Only c[0..m] is divided
// by n after the inverse transform.
void mul(cp *a,cp *b,cp *c)
{
    // Build the bit-reversal table with the O(n) recurrence
    //   rev[i] = (rev[i/2] >> 1) | (lowest bit of i moved to the top bit).
    // Every entry is assigned rather than OR-ed into its previous contents,
    // so the table is correct even if mul() was called before with another
    // size. Starting at i=1 also keeps the shift amount wei-1 non-negative
    // (the loop body is never reached when n==1, wei==0).
    rev[0]=0;
    for (int i=1;i<n;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(wei-1));
    dft(a,n,1); dft(b,n,1);
    // Pointwise product in the frequency domain.
    for (int i=0;i<n;i++) c[i]=a[i]*b[i];
    dft(c,n,-1);
    // The inverse transform is unnormalised; scale the coefficients in use.
    for (int i=0;i<=m;i++) c[i]=c[i]/(double)n;
}

int main()
{
    scanf("%d%d",&n,&m);
    for (int i=0,x;i<=n;i++) scanf("%d",&x),a[i]=x;
    for (int i=0,x;i<=m;i++) scanf("%d",&x),b[i]=x;
    m=m+n; for (n=1,wei=0;n<=m;n<<=1,wei++); mul(a,b,c);
    for (int i=0;i<=m;i++) printf("%d ",(int)(c[i].real()+0.5));
    return 0;
}
原文地址:https://www.cnblogs.com/Blue233333/p/8420830.html