NTT算法小结

从理论上说,经过人们优化的FFT已经十分优秀,能够处理大部分的多项式乘法,但是有的时候仍然会出现下面的情况:

1)常数仍然比较大

2)在进行与整数有关的FFT时,发现得到的结果是一堆诡异的数,你需要不停的和精度搏斗

那么在这时,你就需要学会快速数论变换(NTT)

前置芝士

快速傅里叶变换

你可以上网百度,或者看我的博客

阶与原根

我们由欧拉定理可以知道,对于任意的正整数(a、m),如果满足(gcd(a,m)=1),就有(a^{varphi(m)}equiv 1(mod m))

但我们发现,还有一些数满足(a^pequiv 1(mod m))(p< varphi(m)),因此人们定义了阶

(m>1),且(gcd(a,m)=1),则满足(a^pequiv 1(mod m))的最小正整数(p)成为(a)对模(m)的阶,记作(delta_m(a))

于是就会有(delta_m(a)|p),充分性很显然,我们证一下必要性

我们设(p=delta_m(a)·q+r)(其中(0leq r < delta_m(a))

那么(a^p=a^{delta_m(a)·q}·a^requiv a^requiv 1(mod m))

(p)是最小正整数知(r=0),所以(delta_m(a)|p)

然后你就有了(delta_m(a)|varphi(p))

好吧这个性质并没有什么用

由上面的欧拉定理,我们不难理解数学家们为什么搞出这么一个蛋疼的定义——原根

如果(delta_m(a)=varphi(m)),那么称(a)是模(m)的一个原根,为了下面表述的方便我们将它记作(g)

我们发现原根有一个这样的性质:(g^0,g^1,g^2,cdots,g^{varphi(m)-1})构成了一个模(m)的完全剩余系

证明:考虑反证法,即假设c存在(i,j)((i>j))满足(g^iequiv g^j(mod m))

​ 两边同时除以(g^j),有(g^{i-j}equiv 1),而很明显(i-j<varphi(m)),这与原根的定义相矛盾

​ 于是性质得证

性质验证

在FFT中,我们使用单位根的原因就是单位根满足的一些性质可以加速计算,如果原根也满足的话,那么我们在计算时可以直接替换

性质1

(w_n^0,w_n^1,w_n^2,cdots,w_n^{n-1})两两不同

这在上面已经得到了证明

性质2

(w_{2n}^{2p}=w_n^p)

如果(w_{n}=g^p),那么就应该有(w_{2n}=g^{frac{p}{2}}),他们在乘上两倍的次幂之后值相等

性质3

(w_{n}^{frac{n}{2}+p}=-w_n^p)

因为有((g^{frac{n}{2}})^2equiv 1),为了保持原根的定义,就会有(g^{frac{n}{2}}equiv -1(mod n))

带回去运算即可

综上所述,原根满足原来的单位根所具有的性质,因此我们可以考虑用原根来代替单位根

实际运用

一点注意事项

1、原根的话在模数不确定的情况下需要自己求,不过如果模数是(998244353)或者(1004535809)的话,它们的原根是3

2、注意在IDFT的时候,原来直接除以的地方要换做求逆元

代码

#include<iostream>
#include<string>
#include<string.h>
#include<stdio.h>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#include<map>
using namespace std;
#define rep(i,a,b) for (i=a;i<=b;i++)
typedef long long ll;
#define maxd 998244353
const double pi=acos(-1.0);
#define int long long 
ll n,m,a[5005000],b[5005000];
int lim=1,r[5005000];

int qpow(int x,int y)
{
    int ans=1,sum=x;
    while (y)
    {
        int tmp=y%2;y/=2;
        if (tmp) ans=(1ll*ans*sum)%maxd;
        sum=(1ll*sum*sum)%maxd;
    }
    return ans;
}

void ntt(int lim,ll *a,int typ)
{
    int i;
    for (i=0;i<lim;i++)
        if (i<r[i]) swap(a[i],a[r[i]]);
    int mid;
    for (mid=1;mid<lim;mid<<=1)
    {
        int gn=qpow(3,(maxd-1)/(mid<<1));
        int sta,len=mid<<1,j;
        for (sta=0;sta<lim;sta+=len)
        {
            int g=1;
            for (j=0;j<mid;j++,g=(g*gn)%maxd)
            {
                int x1=a[j+sta],y1=(g*a[j+sta+mid])%maxd;
                a[j+sta]=(x1+y1)%maxd;
                a[j+sta+mid]=(x1-y1+maxd)%maxd;
            }
        }
    }
    if (typ==-1) reverse(&a[1],&a[lim]);
}

int read()
{
    int x=0,f=1;char ch=getchar();
    while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
    while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
    return x*f;
}

signed main()
{
    n=read();m=read();int i,cnt=0;
    for (i=0;i<=n;i++) a[i]=read();
    for (i=0;i<=m;i++) b[i]=read();
    while (lim<=n+m) {lim<<=1;cnt++;}
    for (i=0;i<=lim;i++)
        r[i]=((r[i>>1]>>1)|((i&1)<<(cnt-1)));
    ntt(lim,a,1);
    ntt(lim,b,1);
    for (i=0;i<=lim;i++) a[i]=(a[i]*b[i])%maxd;
    ntt(lim,a,-1);
    int tmp=qpow(lim,maxd-2);
    for (i=0;i<=n+m;i++) 
    {
        a[i]=(a[i]*tmp)%maxd;
        printf("%lld ",a[i]);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/encodetalker/p/10285657.html