并不对劲的DFT

FFT是一个很多人选择背诵全文的算法。

#include<algorithm>
#include<cmath>
#include<complex>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
// Inclusive counting loops: rep walks i upward from x to y, dwn walks i
// downward from x to y. The `register` specifier was dropped because it is
// no longer a valid storage class as of C++17.
#define rep(i,x,y) for(int i=(x);i<=(y);++i)
#define dwn(i,x,y) for(int i=(x);i>=(y);--i)
#define cd complex<double>
// Assumed upper bound on each input degree (TODO confirm against the task limits).
#define maxn 1000110
// Capacity of r[], a[] and b[]. main() pads the transform length to the
// smallest power of two above n+m, which is already 2^21 once n+m reaches
// 2^20; the previous capacity maxn<<1 (2000220) was smaller than that.
#define maxm (1<<21)
using namespace std;
// Reads the next decimal integer (with an optional leading '-') from stdin,
// skipping any other characters in front of it.
// Returns 0 if the input ends before a digit or '-' is found.
int read()
{
	int x=0,f=1;
	// Keep getchar()'s result in an int: EOF does not fit in a char, and
	// isdigit() is only defined for unsigned-char values and EOF.
	int ch=getchar();
	// Stop at EOF as well, otherwise this loop would never terminate.
	while(ch!=EOF&&!isdigit(ch)&&ch!='-')ch=getchar();
	if(ch=='-')f=-1,ch=getchar();
	// isdigit(EOF) is false, so this loop also ends at end of input.
	while(isdigit(ch))x=x*10+(ch-'0'),ch=getchar();
	return x*f;
}
// Prints x in decimal to stdout, followed by a newline.
void write(int x)
{
	// Widen first so that negating INT_MIN cannot overflow.
	long long v=x;
	if(v<0)putchar('-'),v=-v;
	// Digits are produced least-significant first, then emitted in reverse.
	// do/while guarantees that 0 still prints one digit.
	char digits[20];
	int cnt=0;
	do digits[cnt++]=char('0'+v%10),v/=10;while(v);
	while(cnt)putchar(digits[--cnt]);
	putchar('\n');
}
// pi, used by fft() to build the roots of unity
const double pi=acos(-1.0);
// n, m: polynomial degrees as read by main(); main() later reuses n as the
// transform length and m as the degree of the product
int n,m;
// r[i]: bit-reversed index of position i; len: number of doublings used to
// reach the transform length (its base-2 logarithm)
int r[maxm],len;
// coefficient buffers of the two polynomials; a ends up holding the product
cd a[maxm];
cd b[maxm];
// In-place iterative radix-2 FFT of c[0..n-1], where n is the global
// transform length and r[] holds the bit-reversal permutation (both are set
// up by main()).
// f = 1 uses the roots of unity, f = -1 their conjugates; the division by n
// for the inverse transform is left to the caller.
void fft(cd * c,double f)
{
	// Bring the input into bit-reversed order; the i<r[i] test makes sure
	// every pair is swapped only once.
	for(int i=0;i<n;++i)
		if(i<r[i])swap(c[i],c[r[i]]);
	// half = size of the two halves merged in this pass
	for(int half=1;half<n;half<<=1)
	{
		// primitive (2*half)-th root of unity, conjugated when f = -1
		const cd step(cos(pi/half),sin(f*pi/half));
		const int blockLen=half<<1;
		for(int base=0;base<n;base+=blockLen)
		{
			cd w(1,0);
			for(int k=0;k<half;++k)
			{
				// butterfly on the k-th element of each half
				const cd lo=c[base+k];
				const cd hi=w*c[base+half+k];
				c[base+k]=lo+hi;
				c[base+half+k]=lo-hi;
				w*=step;
			}
		}
	}
}
int main()
{
	n=read(),m=read();
	rep(i,0,n)a[i]=read();
	rep(i,0,m)b[i]=read();
	m+=n;
	for(n=1;n<=m;n<<=1)len++;
	rep(i,0,n-1)r[i]=(r[i>>1]>>1)|((i&1)<<(len-1));
	fft(a,1),fft(b,1);
	rep(i,0,n-1)a[i]*=b[i];
	fft(a,-1);
	rep(i,0,m)printf("%d ",int(a[i].real()/n+0.5));
	return 0;
}

并不对劲的大剑使认为有必要说明为什么IDFT最后要除以n。
假设原函数是 \(F(x)=a_0x^0+a_1x^1+...+a_{n-1}x^{n-1}\)
那么对它进行DFT,将 \(\omega_n^0,\omega_n^1,...,\omega_n^{n-1}\) 依次代入,得到 \(y_0,y_1,...,y_{n-1}\)
\(G(x)=y_0x^0+y_1x^1+...+y_{n-1}x^{n-1}\)
IDFT就是已知G每一项的系数,求F每一项的系数
先对G进行DFT,将 \(\omega_n^0,\omega_n^{-1},...,\omega_n^{-(n-1)}\) 依次代入,得到 \(z_0,z_1,...,z_{n-1}\)
那么就会有 \(z_k=\sum\limits_{i=0}^{n-1}{y_i*(\omega_n^{-k})^i}\)

\[\space\space\space=\sum_{i=0}^{n-1}{(\sum_{j=0}^{n-1}a_j*(\omega_n^i)^j)(\omega_n^{-k})^i} \]

\[\space=\sum_{i=0}^{n-1}{\sum_{j=0}^{n-1}a_j*\omega_n^{i*j}*\omega_n^{-k*i}} \]

\[=\sum_{i=0}^{n-1}{\sum_{j=0}^{n-1}a_j*\omega_n^{i*(j-k)}}\space\space\space\space\space \]

\[=\sum_{j=0}^{n-1}{a_j*(\sum_{i=0}^{n-1}\omega_n^{i*(j-k)})}\space\space \]

可以将 \(\sum\limits_{i=0}^{n-1}\omega_n^{i*(j-k)}\) 这部分看成一个首项为1,公比为\(\omega_n^{j-k}\)的等比数列,那么由等比数列求和公式可知:
\(j-k\neq0\)时,\(\sum\limits_{i=0}^{n-1}\omega_n^{i*(j-k)}=\frac{(\omega_n^{j-k})^n-1}{\omega_n^{j-k}-1}=\frac{(\omega_n^n)^{j-k}-1}{\omega_n^{j-k}-1}=\frac{(\omega_n^0)^{j-k}-1}{\omega_n^{j-k}-1}=\frac{1^{j-k}-1}{\omega_n^{j-k}-1}=0\)
\(j-k=0\)时,\(\sum\limits_{i=0}^{n-1}\omega_n^{i*(j-k)}=\sum_{i=0}^{n-1}\omega_n^{i*0}=\sum_{i=0}^{n-1}1=n\)
这样就有\(z_k=a_k*n\space \Rightarrow\space a_k=\frac{z_k}{n}\)
也就是说,IDFT相当于对当前多项式进行一次代入\(\omega_n^0,\omega_n^{-1},...,\omega_n^{-(n-1)}\)的DFT,再将每一项的系数除以n。

快速数论变换(FNTT)也是一个很多人选择背诵全文的算法。
只是把单位根换成了模意义下的原根而已

#include<algorithm>
#include<cmath>
#include<complex>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
// Inclusive counting loops: rep counts var upward from lo to hi,
// dwn counts var downward from hi to lo.
#define rep(var,lo,hi) for(int var=(lo);var<=(hi);++var)
#define dwn(var,hi,lo) for(int var=(hi);var>=(lo);--var)
// shorthand for the complex type (relies on `using namespace std`)
#define cd complex<double>
// maxm is the capacity of the global arrays, twice maxn
#define maxn 2000110
#define maxm (maxn<<1)
// 64-bit integer used for intermediate modular products
#define LL long long
using namespace std;
// Reads the next decimal integer (with an optional leading '-') from stdin,
// skipping any other characters in front of it.
// Returns 0 if the input ends before a digit or '-' is found.
int read()
{
    int x=0,f=1;
    // Keep getchar()'s result in an int: EOF does not fit in a char, and
    // isdigit() is only defined for unsigned-char values and EOF.
    int ch=getchar();
    // Stop at EOF as well, otherwise this loop would never terminate.
    while(ch!=EOF&&!isdigit(ch)&&ch!='-')ch=getchar();
    if(ch=='-')f=-1,ch=getchar();
    // isdigit(EOF) is false, so this loop also ends at end of input.
    while(isdigit(ch))x=x*10+(ch-'0'),ch=getchar();
    return x*f;
}
// Prints x in decimal to stdout, followed by a newline.
void write(int x)
{
    // Widen first so that negating INT_MIN cannot overflow.
    long long v=x;
    if(v<0)putchar('-'),v=-v;
    // Digits are produced least-significant first, then emitted in reverse.
    // do/while guarantees that 0 still prints one digit.
    char digits[20];
    int cnt=0;
    do digits[cnt++]=char('0'+v%10),v/=10;while(v);
    while(cnt)putchar(digits[--cnt]);
    putchar('\n');
}
// pi; not referenced by the number-theoretic code shown in this program
const double pi=acos(-1.0);
// modulus of all arithmetic; fntt() uses 3 as a primitive root of it
const LL mod=998244353;
// n, m: polynomial degrees as read by main(); main() later reuses n as the
// transform length and m as the degree of the product
int n,m;
// r[i]: bit-reversed index of position i; len: number of doublings used to
// reach the transform length (its base-2 logarithm)
int r[maxm],len;
// coefficient buffers of the two polynomials; a ends up holding the product
int a[maxm];
int b[maxm];
int mul(int x,int y)
{
    int ans=1;
    while(y)
    {
        if(y&1)ans=((LL)ans*(LL)x)%mod;
        x=((LL)x*(LL)x)%mod,y>>=1;
    }
    return ans;
}
void fntt(int * c,double f)
{
    rep(i,0,n-1)if(i<r[i])swap(c[i],c[r[i]]);
    for(int i=1;i<n;i<<=1)
    {
        int wn=mul(3,(mod-1)/(i<<1)),x,y;
        if(f==-1)wn=mul(wn,mod-2);
        for(int j=0;j<n;j+=(i<<1))
        {
            int w=1;
            rep(k,0,i-1)
                x=c[j+k]%mod,y=((LL)w*(LL)c[j+i+k])%mod,c[j+k]=(x+y)%mod,c[j+i+k]=(x-y+mod)%mod,w=(LL)w*(LL)wn%mod;	
        }
    }
}
int main()
{
    n=read(),m=read();
    rep(i,0,n)a[i]=read();
    rep(i,0,m)b[i]=read();
    m+=n;
    for(n=1;n<m+1;n<<=1)len++;
    rep(i,0,n-1)r[i]=(r[i>>1]>>1)|((i&1)<<(len-1));
    fntt(a,1),fntt(b,1);
    rep(i,0,n-1)a[i]=((LL)a[i]*(LL)b[i])%mod;
    fntt(a,-1);int inv=mul(n,mod-2);
    rep(i,0,m)printf("%d ",(LL)a[i]*(LL)inv%mod);
    return 0;
}


原文地址:https://www.cnblogs.com/xzyf/p/10021329.html