从理论上说,经过人们优化的FFT已经十分优秀,能够处理大部分的多项式乘法,但是有的时候仍然会出现下面的情况:
1)常数仍然比较大
2)在进行与整数有关的FFT时,发现得到的结果是一堆诡异的数,你需要不停的和精度搏斗
那么在这时,你就需要学会快速数论变换(NTT)
前置芝士
快速傅里叶变换
你可以上网百度,或者看我的博客
阶与原根
我们由欧拉定理可以知道,对于任意的正整数(a、m),如果满足(gcd(a,m)=1),就有(a^{varphi(m)}equiv 1(mod m))
但我们发现,还有一些数满足(a^pequiv 1(mod m))且(p< varphi(m)),因此人们定义了阶
设(m>1),且(gcd(a,m)=1),则满足(a^pequiv 1(mod m))的最小正整数(p)成为(a)对模(m)的阶,记作(delta_m(a))
于是就会有(delta_m(a)|p),充分性很显然,我们证一下必要性
我们设(p=delta_m(a)·q+r)(其中(0leq r < delta_m(a)))
那么(a^p=a^{delta_m(a)·q}·a^requiv a^requiv 1(mod m))
由(p)是最小正整数知(r=0),所以(delta_m(a)|p)
然后你就有了(delta_m(a)|varphi(p))
好吧这个性质并没有什么用
由上面的欧拉定理,我们不难理解数学家们为什么搞出这么一个蛋疼的定义——原根
如果(delta_m(a)=varphi(m)),那么称(a)是模(m)的一个原根,为了下面表述的方便我们将它记作(g)
我们发现原根有一个这样的性质:(g^0,g^1,g^2,cdots,g^{varphi(m)-1})构成了一个模(m)的完全剩余系
证明:考虑反证法,即假设c存在(i,j)((i>j))满足(g^iequiv g^j(mod m))
两边同时除以(g^j),有(g^{i-j}equiv 1),而很明显(i-j<varphi(m)),这与原根的定义相矛盾
于是性质得证
性质验证
在FFT中,我们使用单位根的原因就是单位根满足的一些性质可以加速计算,如果原根也满足的话,那么我们在计算时可以直接替换
性质1
(w_n^0,w_n^1,w_n^2,cdots,w_n^{n-1})两两不同
这在上面已经得到了证明
性质2
(w_{2n}^{2p}=w_n^p)
如果(w_{n}=g^p),那么就应该有(w_{2n}=g^{frac{p}{2}}),他们在乘上两倍的次幂之后值相等
性质3
(w_{n}^{frac{n}{2}+p}=-w_n^p)
因为有((g^{frac{n}{2}})^2equiv 1),为了保持原根的定义,就会有(g^{frac{n}{2}}equiv -1(mod n))
带回去运算即可
综上所述,原根满足原来的单位根所具有的性质,因此我们可以考虑用原根来代替单位根
实际运用
一点注意事项
1、原根的话在模数不确定的情况下需要自己求,不过如果模数是(998244353)或者(1004535809)的话,它们的原根是3
2、注意在IDFT的时候,原来直接除以的地方要换做求逆元
代码
#include<iostream>
#include<string>
#include<string.h>
#include<stdio.h>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#include<map>
using namespace std;
#define rep(i,a,b) for (i=a;i<=b;i++)
typedef long long ll;
#define maxd 998244353
const double pi=acos(-1.0);
#define int long long
ll n,m,a[5005000],b[5005000];
int lim=1,r[5005000];
int qpow(int x,int y)
{
int ans=1,sum=x;
while (y)
{
int tmp=y%2;y/=2;
if (tmp) ans=(1ll*ans*sum)%maxd;
sum=(1ll*sum*sum)%maxd;
}
return ans;
}
void ntt(int lim,ll *a,int typ)
{
int i;
for (i=0;i<lim;i++)
if (i<r[i]) swap(a[i],a[r[i]]);
int mid;
for (mid=1;mid<lim;mid<<=1)
{
int gn=qpow(3,(maxd-1)/(mid<<1));
int sta,len=mid<<1,j;
for (sta=0;sta<lim;sta+=len)
{
int g=1;
for (j=0;j<mid;j++,g=(g*gn)%maxd)
{
int x1=a[j+sta],y1=(g*a[j+sta+mid])%maxd;
a[j+sta]=(x1+y1)%maxd;
a[j+sta+mid]=(x1-y1+maxd)%maxd;
}
}
}
if (typ==-1) reverse(&a[1],&a[lim]);
}
int read()
{
int x=0,f=1;char ch=getchar();
while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
return x*f;
}
signed main()
{
n=read();m=read();int i,cnt=0;
for (i=0;i<=n;i++) a[i]=read();
for (i=0;i<=m;i++) b[i]=read();
while (lim<=n+m) {lim<<=1;cnt++;}
for (i=0;i<=lim;i++)
r[i]=((r[i>>1]>>1)|((i&1)<<(cnt-1)));
ntt(lim,a,1);
ntt(lim,b,1);
for (i=0;i<=lim;i++) a[i]=(a[i]*b[i])%maxd;
ntt(lim,a,-1);
int tmp=qpow(lim,maxd-2);
for (i=0;i<=n+m;i++)
{
a[i]=(a[i]*tmp)%maxd;
printf("%lld ",a[i]);
}
return 0;
}