NTT模板

敲了一份NTT模板。在答案需要取余的很多场合，NTT有较好的效果。

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<ctime>
#include<string>
#include<iomanip>
#include<algorithm>
#include<map>
using namespace std;
// Shorthand for the 64-bit signed integer type used throughout the file.
#define LL long long
// Base name of the input/output files; main() appends ".in" / ".out" to it.
// NOTE: from this line on the identifier FILE expands to a string literal,
// so it can no longer be used as the <cstdio> stream type in this file.
#define FILE "dealing"
// Inclusive counting loop: i runs from j up to and including n (index type LL).
// The arguments are pasted unparenthesized and n is re-evaluated on every iteration.
#define up(i,j,n) for(LL i=j;i<=n;++i)
// Shorthand for double.
#define db double
// Shorthand for the 64-bit unsigned integer type.
#define ull unsigned long long
// Small floating-point tolerance constant.
#define eps 1e-10
// Shorthand for a pair of two LL values.
#define pii pair<LL,LL>
// Reads the next decimal integer from stdin.
// Every character before the first digit is skipped; if any skipped character
// is '-', the result is negated. The first character after the digit run is
// consumed as well.
// Returns 0 when end of input is reached before any digit is found (the
// skip loop stops on EOF instead of spinning forever).
long long read(){
	long long value=0,sign=1;
	int ch=getchar();
	while(ch!=EOF&&(ch<'0'||ch>'9')){
		if(ch=='-')sign=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		value=value*10+(ch-'0');
		ch=getchar();
	}
	return sign*value;
}
// maxn is the capacity of the file-scope arrays a, b and ch.
const LL maxn=802200;
const LL maxm=20000;
// 1e9+7, written as an integer literal.
const LL mod=1000000007LL;
// 1e15, written as an integer literal.
const LL inf=1000000000000000LL;
// Replaces a with b when a<b. Returns true if a was updated, false otherwise.
template<class T>bool cmax(T& a,T b){
	if(a<b){
		a=b;
		return true;
	}
	return false;
}
// Replaces a with b when a>b. Returns true if a was updated, false otherwise.
template<class T>bool cmin(T& a,T b){
	if(a>b){
		a=b;
		return true;
	}
	return false;
}
// Lengths of the two input sequences; main() reads them and adds one to each
// before using them as element counts.
LL n;
LL m;
// Number-theoretic transform (NTT) based polynomial multiplication modulo a prime P.
//
// Usage: call getr(P) once to select the modulus and a primitive root, then
// solve(c,d,n,m,ch) to multiply two coefficient arrays.
//
// Requirements that the code does not verify:
//  * P is prime and P-1 is divisible by the transform length (a power of two
//    >= n+m-1); otherwise (P-1)/len truncates and the results are wrong.
//  * P is small enough that (P-1)*(P-1) fits in a signed 64-bit integer.
namespace NTT{
	typedef long long i64;
	// Capacity of every work buffer; the transform length may not exceed it.
	const i64 maxn=1000400;
	// r: primitive root modulo P.  P: modulus.  L: transform length (power of two).
	// H: log2(L).  R: bit-reversal permutation.  w: twiddle factors of the current stage.
	i64 r,P,H=0,L=1,R[maxn],w[maxn];
	// Work copies of the two operands, zero-padded to length L.
	i64 a[maxn],b[maxn];

	// Returns a^b mod P by binary exponentiation (b >= 0).
	i64 fast(i64 a,i64 b){
		i64 ans=1;
		a%=P;
		for(;b;b>>=1,a=a*a%P)
			if(b&1)ans=ans*a%P;
		return ans;
	}

	// prime[1..tail]: primes up to limit.  B[x]!=0 marks x as composite.
	i64 prime[maxn],tail=0,B[maxn],limit=(i64)(1e6+1);

	// Linear sieve filling prime[1..tail]. Runs only once; later calls are no-ops
	// so that repeated getr() calls do not append duplicate primes.
	void getprime(){
		if(tail)return;
		for(i64 i=2;i<=limit;++i){
			if(!B[i])prime[++tail]=i;
			for(i64 j=1;j<=tail&&prime[j]*i<=limit;j++){
				B[i*prime[j]]=1;
				if(i%prime[j]==0)break;
			}
		}
	}

	// q[1..head]: distinct prime factors of P-1 found by the last getr() call.
	i64 q[maxn],head=0;

	// Selects `mod` as the modulus P and stores a primitive root of it in r.
	// 998244353 is special-cased (root 3). For any other modulus, P-1 is
	// factored with the sieved primes and the smallest g >= 2 with
	// g^((P-1)/f) != 1 for every prime factor f is chosen; r is 1 if no such
	// g exists in [2, P-1].
	void getr(i64 mod){
		P=mod;
		if(mod==998244353){r=3;return;}
		getprime();
		head=0;                       // drop factors left over from an earlier call
		const i64 N=mod-1;
		i64 D=N;
		// i<=tail keeps the scan inside the sieved primes (prime[] is 0 beyond tail).
		for(i64 i=1;i<=tail&&prime[i]*prime[i]<=D;i++){
			if(D%prime[i]!=0)continue;
			q[++head]=prime[i];
			while(D%prime[i]==0)D/=prime[i];
		}
		if(D!=1)q[++head]=D;          // whatever remains is a single prime factor
		r=1;
		for(i64 g=2;g<=N;++g){
			bool rejected=false;
			for(i64 j=1;j<=head;++j)
				if(fast(g,N/q[j])==1){rejected=true;break;}
			if(!rejected){r=g;break;}
		}
	}

	// In-place transform of a[0..L-1]; flag==false is the forward transform,
	// flag==true the inverse (including the division by L).
	// Expects L, R[0..L-1], P and r to be set up already (solve() does this).
	void NTT(i64* a,bool flag){
		// Bit-reversal permutation over the whole transform length.
		for(i64 i=0;i<L;++i)
			if(i<R[i])std::swap(a[i],a[R[i]]);
		for(i64 len=2;len<=L;len<<=1){
			const i64 half=len>>1;
			i64 g=fast(r,(P-1)/len);  // primitive len-th root of unity
			if(flag)g=fast(g,P-2);    // its modular inverse for the inverse transform
			w[0]=1;
			for(i64 i=1;i<half;++i)w[i]=w[i-1]*g%P;
			for(i64 st=0;st<L;st+=len){
				for(i64 k=0;k<half;k++){
					const i64 x=a[st+k],y=w[k]*a[st+k+half]%P;
					a[st+k]=(x+y)%P;
					a[st+k+half]=(x-y+P)%P;
				}
			}
		}
		if(flag){
			const i64 inv=fast(L,P-2);
			for(i64 i=0;i<L;++i)a[i]=a[i]*inv%P;
		}
	}

	// Multiplies c[0..n-1] by d[0..m-1] modulo P (index i = coefficient of x^i).
	// The n+m-1 product coefficients, each in [0,P), are written to
	// ch[1..n+m-1] (1-based). Inputs are reduced into [0,P) first, so negative
	// or oversized coefficients are accepted.
	// Returns the number of coefficients written, or 0 (ch untouched) when an
	// operand is empty, no modulus has been selected, or the transform would not
	// fit into the work buffers.
	i64 solve(i64* c,i64* d,i64 n,i64 m,i64* ch){
		if(n<=0||m<=0||P<2)return 0;
		if(n>maxn||m>maxn)return 0;
		const i64 total=n+m-1;
		for(H=0,L=1;L<total;H++)L<<=1;
		if(L>maxn)return 0;
		for(i64 i=0;i<L;++i){
			a[i]=i<n?((c[i]%P)+P)%P:0;
			b[i]=i<m?((d[i]%P)+P)%P:0;
		}
		// Starting at i=1 guarantees H>=1 here, so the shift count is never negative.
		R[0]=0;
		for(i64 i=1;i<L;++i)R[i]=(R[i>>1]>>1)|((i&1)<<(H-1));
		NTT(a,0);
		NTT(b,0);
		for(i64 i=0;i<L;++i)a[i]=a[i]*b[i]%P;
		NTT(a,1);
		for(i64 i=0;i<total;++i)ch[i+1]=a[i];
		return total;
	}
};
// Input sequences read by main() and the buffer that receives their product
// (the product is stored starting at index 1).
LL a[maxn];
LL b[maxn];
LL ch[maxn];
// Reads two coefficient sequences from "dealing.in", multiplies them modulo
// 998244353 with NTT::solve and writes the result to "dealing.out".
//
// Input : two integers n and m, then n+1 values of the first sequence and
//         m+1 values of the second.
// Output: the n+m+1 product coefficients, each followed by a single space.
// Returns 1 (with a message on stderr) if a file cannot be opened or the
// sizes do not fit the statically sized buffers; 0 otherwise.
int main(){
	if(!freopen(FILE".in","r",stdin)){
		fprintf(stderr,"cannot open %s\n",FILE".in");
		return 1;
	}
	if(!freopen(FILE".out","w",stdout)){
		fprintf(stderr,"cannot open %s\n",FILE".out");
		return 1;
	}
	n=read();m=read();
	n++,m++;
	// Reject sizes that would overrun a, b, ch (ch is filled up to index n+m-1).
	const long long capA=sizeof(a)/sizeof(a[0]);
	const long long capB=sizeof(b)/sizeof(b[0]);
	const long long capOut=sizeof(ch)/sizeof(ch[0]);
	if(n<1||m<1||n>capA||m>capB||n+m>capOut){
		fprintf(stderr,"input size out of range\n");
		return 1;
	}
	// The transform length is the smallest power of two >= n+m-1 and has to
	// stay inside the NTT work buffers.
	long long len=1;
	while(len<n+m-1)len<<=1;
	if(len>=NTT::maxn){
		fprintf(stderr,"input size out of range\n");
		return 1;
	}
	for(long long i=0;i<n;++i)a[i]=read();
	for(long long i=0;i<m;++i)b[i]=read();
	NTT::getr(998244353LL);
	NTT::solve(a,b,n,m,ch);
	for(long long i=1;i<=n+m-1;++i)printf("%lld ",ch[i]);
	return 0;
}

  

原文地址:https://www.cnblogs.com/chadinblog/p/6527143.html