NTT

假设我们现在要多项式除法并且取模,(FFT)就会很难受了,因为它用的是复数,并且还有精度差。

这时我们需要一个能替代单位复根的东西:原根。

一、为什么原根能替代单位复根

考虑为什么单位复根能用来做(FFT),因为它有很多性质,而我们会发现原根也具有这些性质:

以下的(n)都为(2)的正整数次幂。

(g_n=g^{frac{p-1}{n}})(p)是质数且(nmid(p-1))(g)是模(p)意义下的原根。

1.(omega_n^n=omega_n^0=1)

(g_n^nequiv g^{frac{n(p-1)}{n}}equiv g^{p-1}equiv 1pmod{p})

2.(omega_n^{frac{n}{2}}=-1)

(g_n^{frac{n}{2}}equiv g^{frac{n(p-1)}{2n}}equiv g^{frac{p-1}{2}}pmod{p})

又因为((g^{frac{p-1}{2}})^2equiv g^{p-1}equiv1pmod{p})

方程(x^2equiv 1pmod{p})(p)为质数时只有(1,-1)两种取值,而(g^{frac{(p-1)}{2}} otequiv g^{p-1} equiv 1pmod{p}),因此(g^{frac{(p-1)}{2}}equiv -1pmod{p})

3.(omega_n^k=omega_{dn}^{dk})

$g_{dn}^{dk} = g^{frac{dk(p-1)}{dn}} = g^{frac{k(p-1)}{n}} = g_n^k $

于是单位复根有的性质原根都有

二、NTT:

于是我们开始魔改(FFT)

首先我们的的模数(p)要满足$ acdot 2^k +1$的形式,并且这个 (2) 的幂要大于 (n)。常见的有两种:
1.$1004535809 = 479 imes 2^{21} + 1 $,它的最小正原根是 (3)
2.(998244353 = 2^{23} imes 7 imes 17 + 1),最小正原根也是 (3)

实现时将(FFT)中的单位复根换成原根即可,最后逆变换时要乘逆元。

模板题

code:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=4e6+10;
const ll mod=998244353;
const ll g=3;
const ll invg=332748118;
int n,m,lim=1,len;
int pos[maxn];
ll a[maxn],b[maxn];
inline ll power(ll x,ll k)
{
	ll res=1;
	while(k)
	{
		if(k&1)res=res*x%mod;
		x=x*x%mod;k>>=1;
	}
	return res;
}
inline void ntt(ll* a,int op)
{
	for(int i=0;i<lim;i++)if(i<pos[i])swap(a[i],a[pos[i]]);
	for(int mid=1;mid<lim;mid<<=1)
	{
		ll wn=power((op==1)?g:invg,(mod-1)/(mid<<1));
		for(int i=0,l=(mid<<1);i<lim;i+=l)
		{
			ll w=1;
			for(int j=0;j<mid;j++,w=w*wn%mod)
			{
				ll x=a[i+j]%mod,y=w*a[i+mid+j]%mod;
				a[i+j]=(x+y)%mod,a[i+mid+j]=(x-y+mod)%mod;
			}
		}
	}
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=0;i<=n;i++)scanf("%lld",&a[i]),a[i]=(a[i]+mod)%mod;
	for(int i=0;i<=m;i++)scanf("%lld",&b[i]),b[i]=(b[i]+mod)%mod;
	while(lim<=n+m)lim<<=1,len++;
	for(int i=0;i<lim;i++)pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
	ntt(a,1);ntt(b,1);
	//for(int i=0;i<lim;i++)cerr<<a[i]<<' '<<b[i]<<endl;
	for(int i=0;i<lim;i++)a[i]=(a[i]*b[i])%mod;
	ntt(a,-1);
	ll inv=power(lim,mod-2);
	for(int i=0;i<=n+m;i++)printf("%lld ",a[i]*inv%mod);
	return 0;
}
原文地址:https://www.cnblogs.com/nofind/p/12118673.html