NTT

NTT

NTT是一种跑得比FFT快的东西(?)。

元素的幂

考虑有限群G,(a in G)。元素的幂就是a的几次方。

使得(a^d=e)的最小正整数d称为a的阶,记作(d=ord(a))

显然,a的幂生成的集合S是G的子群。因此,(a^{|G|}=e)

原根

有个结论:(Z_n^*)存在原根(Leftrightarrow)(n=2,4,p^alpha,2p^alpha),p为奇素数。

tip:469762049,998244353和1004535809都有原根3。

设对于n,我们找到了原根g。设(g_n=g^{frac{p-1}{n}})。那么:

  • (g_n=g^{frac{p-1}{n}})

  • (g^n_n=g^{p-1}=1)

  • (g_{dn}^{dk}=(g^frac{p-1}{dn})^{dk}=(g^{frac{p-1}{n}})^k=g_n^k)(消去引理)

  • ((g_n^k)^2=(g_n^{k+n/2})^2=(g^frac{p-1}{n/2})^k=g_{n/2}^k)(折半引理)

  • 求和引理:(sum_{i=0}^{n-1}(g_n^k)^i =left{ egin{align} n,nmid k \ 0,n mid k end{align} ight.), 可以用类似fft的方法证。

然后我们就可以用原根愉快的做fft了。

inline void plus(int &x, int y){ x=1ll*x*y%mod; }
inline void plus(LL &x, LL y){ x*=y; x%=mod; }
inline void pro(int &x){ if (x<0) x+=mod; }

int fpow(LL a, LL x){
    LL ans=1;
    for (LL base=a; x; x>>=1, plus(base, base))
        if (x&1) plus(ans, base);
    return ans;
}
int inv(int x){ return fpow(x, mod-2); }

void fft(int *a, int l, int flag){
    for (int i=0; i<l; ++i) 
        if (i<rev[i]) swap(a[i], a[rev[i]]);
    LL gn, g, x, y;
    for (int mid=1; mid<l; mid<<=1){  //区间半径 
        gn=fpow(G, (mod-1)/(mid<<1));
        if (flag==-1) gn=inv(gn);
        for (int j=0; j<l; j+=(mid<<1)){ g=1;
            for (int k=j; k<j+mid; ++k, plus(g, gn)){
                x=a[k]; y=g*a[k+mid]%mod;
                a[k]=(x+y)%mod; a[k+mid]=(x-y+mod)%mod; 
            }
        }
    }
}

void ntt(int *a, int *b, int &la, int &lb){  //a=a*b 
    int l=1, bits=0; while (l<=la+lb) l<<=1, ++bits;
    int linv=inv(l);
    for (int i=1; i<l; ++i)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(bits-1));
    fft(a, l, 1); fft(b, l, 1);
    for (int i=0; i<l; ++i) a[i]=1ll*a[i]*b[i]%mod;
    fft(a, l, -1); la=la+lb;
    for (int i=0; i<=la; ++i) a[i]=1ll*a[i]*linv%mod;
}

归纳一下写fft的思路。首先,应该明确我们要写的主要部分,是把系数表示转换成点值表示。先推式子,对于一个系数表示,把它拆成偶数和奇数的两个多项式。然后,把当前多项式写成两个子多项式的式子。利用单位根的性质,使得分叉个数为4。注意数组大小要开4n。

为了方便,把ntt给封装起来。

任意模数ntt

#include <cstdio> 
#include <algorithm>
using namespace std;

typedef long long LL;
const int maxn=4e5+5, maxc=1e5+5;  //maxn必须是4倍数组 
const LL p1=998244353, p2=1004535809, p3=469762049, G=3;
int n, m, k, cntc, rev[maxn], mod, P;

inline void plus(int &x, int y){ x=1ll*x*y%mod; }
inline void plus(LL &x, LL y){ x*=y; x%=mod; }
inline void pro(int &x){ if (x<0) x+=mod; }

LL fmul(LL a, LL b, LL mod){
    LL ans=0;
    for (; b; b>>=1, a+=a, a%=mod)
        if (b&1) ans+=a, ans%=mod;
    return ans;
}
int fpow(LL a, LL x, LL mod){
	LL ans=1;
	for (LL base=a; x; x>>=1, (base*=base)%=mod)
		if (x&1) (ans*=base)%=mod;
	return ans;
}
int inv(int x){ return fpow(x, mod-2, mod); }

void fft(int *a, int l, int flag){
	for (int i=0; i<l; ++i) 
		if (i<rev[i]) swap(a[i], a[rev[i]]);
	LL gn, g, x, y;
	for (int mid=1; mid<l; mid<<=1){  //区间半径 
		gn=fpow(G, (mod-1)/(mid<<1), mod);
		if (flag==-1) gn=inv(gn);
		for (int j=0; j<l; j+=(mid<<1)){ g=1;
			for (int k=j; k<j+mid; ++k, plus(g, gn)){
				x=a[k]; y=g*a[k+mid]%mod;
				a[k]=(x+y)%mod; a[k+mid]=(x-y+mod)%mod; 
			}
		}
	}
}

void ntt(int *a, int *b, int &la, int &lb){  //a=a*b 
	int l=1, bits=0; while (l<=la+lb) l<<=1, ++bits;
	int linv=inv(l);
	for (int i=1; i<l; ++i)
		rev[i]=(rev[i>>1]>>1)|((i&1)<<(bits-1));
	fft(a, l, 1); fft(b, l, 1);
	for (int i=0; i<l; ++i) a[i]=1ll*a[i]*b[i]%mod;
	fft(a, l, -1);
	for (int i=0; i<=la+lb; ++i) a[i]=1ll*a[i]*linv%mod;
}
int A[maxn], B[maxn], C[3][maxn], D[3][maxn];

int crt(LL c1, LL c2, LL c3){
    static LL invp1=fpow(p1, p2-2, p2), invp2=fpow(p2, p1-2, p1);
    static LL p0=p1*p2, invp0=fpow(p0%p3, p3-2, p3);
    LL c0=fmul(p2*invp2, c1, p0)+fmul(p1*invp1, c2, p0); c0%=p0;
    LL k=(c3+p3-c0%p3)*invp0%p3;
    return (c0%mod+k*(p0%mod))%mod;  //k*p0会爆!
}

int main(){
	scanf("%d%d%d", &n, &m, &P);
	for (int i=0; i<=n; ++i) 
		scanf("%d", &A[i]), C[0][i]=C[1][i]=C[2][i]=A[i];
	for (int i=0; i<=m; ++i) 
		scanf("%d", &B[i]), D[0][i]=D[1][i]=D[2][i]=B[i];
	mod=p1; ntt(C[0], D[0], n, m);
	mod=p2; ntt(C[1], D[1], n, m);
	mod=p3; ntt(C[2], D[2], n, m); mod=P; 
	for (int i=0; i<=n+m; ++i)
		printf("%d ", crt(C[0][i], C[1][i], C[2][i]));
	return 0;
}

两天后的PS:注意对于998244352,1004535808,469762048的分解。(998244352=2^{23}*x)(1004535808=2^{21}*x)(469762048=2^{26}*x)。这是因为三个质数本来就是(p=c*2^l+1)的形式。

原文地址:https://www.cnblogs.com/MyNameIsPc/p/9592422.html