FFT 迭代版（非递归）& NTT 多项式求逆

NTT板子
又重温了一遍，大佬说把板子背过就好。
具体看代码

想要看懂 NTT 板子，先看懂 FFT 的迭代版（非递归）模板。

FFT 迭代版本（非递归）

#include<iostream>
#include<cstdio>
#include<cmath>
using namespace std;
// Capacity of the global coefficient / permutation arrays; the transform length `limit`
// computed in main must not exceed it.
const int N=1e7+7;
// Minimal complex number type used by the FFT below.
struct complex{
	double x,y; // real and imaginary parts
	// constexpr so that the huge global arrays a[] and b[] get constant (zero)
	// initialization. With a non-constexpr constructor they may instead be filled by
	// a startup loop that writes every one of the 2*N elements (~320 MB).
	constexpr complex(double xx=0,double yy=0):x(xx),y(yy){}
}a[N],b[N]; // a, b: coefficient arrays of the two input polynomials (real parts hold the input)
const double pi=acos(-1.0);
// Arithmetic on complex values; operands are taken by const reference to avoid copies.
complex operator +(const complex &a,const complex &b) {return complex(a.x+b.x,a.y+b.y);}
complex operator -(const complex &a,const complex &b) {return complex(a.x-b.x,a.y-b.y);}
complex operator *(const complex &a,const complex &b) {return complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
int limit=1,n,m,l; // transform length (power of two), input degrees n and m, l = log2(limit)
int r[N];          // bit-reversal permutation of 0..limit-1

// In-place iterative radix-2 FFT over a[0..limit-1].
// f = 1: forward transform; f = -1: inverse transform WITHOUT the 1/limit scaling
// (the caller divides by limit afterwards).
// Uses the globals `limit` (power of two) and `r` (bit-reversal permutation of 0..limit-1).
void FFT(complex *a,int f){
	// Put the elements into bit-reversed order so the butterflies can work in place.
	for(int i=0;i<limit;i++){
		if(r[i]>i) swap(a[i],a[r[i]]);
	}
	// `half` is the size of the already-transformed halves being merged into blocks of 2*half.
	for(int half=1;half<limit;half*=2){
		// Root of unity for blocks of length 2*half (conjugated when f == -1).
		complex step(cos(pi/half),f*sin(pi/half));
		const int span=half*2;
		for(int start=0;start<limit;start+=span){
			complex w(1,0);
			for(int k=0;k<half;k++){
				complex lo=a[start+k];
				complex hi=w*a[start+half+k];
				a[start+k]=lo+hi;
				a[start+half+k]=lo-hi;
				w=w*step;
			}
		}
	}
}

int main(){
	scanf("%d%d",&n,&m); 
	for(int i=0;i<=n;i++) scanf("%lf",&a[i].x);
	for(int i=0;i<=m;i++) scanf("%lf",&b[i].x);
	while(limit<=n+m) limit<<=1,l++;
	for(int i=0;i<limit;i++){
		r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	}
	
	NTT(a,1);
	NTT(b,1);
	for(int i=0;i<=limit;i++) a[i]=a[i]*b[i];
	NTT(a,-1);
	for(int i=0;i<=n+m;i++){
		cout<<(int)(a[i].x/(limit)+0.5)<<" ";
	}
}

多项式求逆

#include<iostream>
#include<cstdio>
using namespace std;
#define int long long
// NTT-friendly prime modulus; its primitive root is 3.
const int p=998244353;
// Array capacity. INV uses a transform length equal to the smallest power of two >= 2n,
// so this capacity covers n up to 262144.
const int N=1e6+7;
int n;    // number of coefficients read from input
int a[N]; // input series a
int b[N]; // result: inverse of a modulo x^n
int c[N]; // scratch copy of a used inside INV
int r[N]; // bit-reversal permutation for the current transform length
// Binary (fast) modular exponentiation: returns a^b mod p for b >= 0.
// Products such as ans*a and a*a are computed in `int`, which is long long here
// because of the #define above.
int ksm(int a,int b){
	int ans=1;
	while(b){
		if(b&1) ans=ans*a%p;
		b>>=1;
		a=a*a%p; // a now holds the base raised to the next power of two
	}
	return ans;
}

// In-place iterative number-theoretic transform of a[0..len-1] modulo p.
// opt = 1: forward transform using powers of the primitive root 3.
// opt = -1: inverse transform (uses the inverse root, then multiplies every entry by len^{-1} mod p).
// Expects len to be a power of two and the global r[] to already hold the bit-reversal
// permutation for this len.
void NTT(int *a,int len,int opt){
	for(int i=0;i<len;i++){
		if(r[i]>i) swap(a[i],a[r[i]]);
	}
	for(int h=1;h<len;h*=2){
		const int blk=h*2;
		// Primitive blk-th root of unity 3^((p-1)/blk). The inverse transform uses its
		// modular inverse, computed by Fermat's little theorem (g^(p-2)).
		int g=ksm(3,(p-1)/blk);
		if(opt==-1) g=ksm(g,p-2);
		for(int s=0;s<len;s+=blk){
			int w=1;
			for(int k=s;k<s+h;k++){
				const int u=a[k];
				const int v=w*a[k+h]%p;
				a[k]=(u+v)%p;
				a[k+h]=(u-v+p)%p;
				w=w*g%p;
			}
		}
	}
	if(opt!=-1) return;
	// Inverse transform: scale by len^{-1} mod p.
	const int invLen=ksm(len,p-2);
	for(int i=0;i<len;i++) a[i]=a[i]*invLen%p;
}

// Computes the first n coefficients of the power-series inverse of a, i.e. b = a^{-1} mod x^n,
// by Newton iteration. From B0 = a^{-1} mod x^{ceil(n/2)} it forms B = B0*(2 - a*B0) mod x^n.
// Preconditions: a[0] is invertible mod p (otherwise the base case yields 0), and b[1..] is
// zero on the first call (true for the zero-initialized global b).
// Side effects: overwrites the globals c[] and r[]; leaves b[n..limit-1] zero.
void INV(int n,int *a,int *b){
	// Empty series: nothing to compute. Without this guard INV(0) would recurse into
	// INV(0) forever, because (0+1)>>1 == 0 never reaches the n == 1 base case.
	if(n<=0) return;
	if(n==1){
		b[0]=ksm(a[0],p-2); // modular inverse of the constant term
		return;
	}
	INV((n+1)>>1,a,b); // b now holds the inverse modulo x^{ceil(n/2)}, round n/2 up
	// Transform length: smallest power of two >= 2n. deg(a*B0*B0) <= 2n-2, so the cyclic
	// convolution does not wrap around.
	int limit=1,l=0;
	while(limit<(n<<1)) limit<<=1,l++;
	for(int i=0;i<limit;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	for(int i=0;i<n;i++) c[i]=a[i]; // work on a copy: a itself must not be modified
	for(int i=n;i<limit;i++) c[i]=0; // terms of a at x^n and above do not affect the result mod x^n
	NTT(c,limit,1),NTT(b,limit,1);
	for(int i=0;i<limit;i++){
		// Pointwise B0*(2 - a*B0) = 2*B0 - a*B0^2, kept non-negative mod p.
		b[i]=(1LL*2*b[i]%p-1LL*b[i]*b[i]%p*c[i]%p+p)%p;
	}
	NTT(b,limit,-1);
	// Truncate to mod x^n. This also keeps b zero-padded for the caller's larger transform.
	for(int i=n;i<limit;i++) b[i]=0;
}

// Reads n and the n coefficients of a. Prints the first n coefficients of a^{-1} mod x^n,
// reduced to [0, p).
signed main(){
	scanf("%lld",&n);
	for(int idx=0;idx<n;idx++){
		scanf("%lld",&a[idx]);
	}
	INV(n,a,b);
	for(int idx=0;idx<n;idx++){
		const int coef=(b[idx]%p+p)%p; // normalize into [0, p)
		cout<<coef<<" ";
	}
}
原文地址:https://www.cnblogs.com/Aswert/p/14264278.html