UOJ 34 多项式乘法：FFT 模板

http://uoj.ac/problem/34

  

fft真是一个丧心病狂的东西

递归版 FFT

#include<cstdio>
#include<cmath>
// Inclusive loop helper used by main: i runs over s, s+1, ..., t.
// (`register` is omitted: it is ill-formed since C++17.)
#define FOR(i,s,t) for(int i=s;i<=t;++i)
typedef double db;
const db pi=acos(-1);
// Capacity of the coefficient arrays and of fft's scratch buffer.
const int N=500011;
// Minimal complex number: r is the real part, i the imaginary part.
struct complex{
	db r,i;
	typedef complex cp;
	inline cp operator+(cp A)const{return cp{r+A.r,i+A.i};}
	inline cp operator-(cp A)const{return cp{r-A.r,i-A.i};}
	inline cp operator*(cp A)const{return cp{r*A.r-i*A.i,r*A.i+i*A.r};}
}a[N],b[N];
typedef complex cp;
// Recursive radix-2 Cooley-Tukey FFT, in place on x[0..n).
// n must be a power of two and n <= N.
// type = 1: forward transform with root e^{+2*pi*i/n}.
// type = -1: inverse transform, unscaled (the caller divides by n).
// Even/odd samples are regrouped through one static buffer instead of two
// variable-length arrays per level (non-standard, and ~32*n bytes of stack in total).
inline void fft(cp *x,int n,int type){
	if(n==1)return;
	int hf=n>>1;
	// buf is only used before the recursive calls, so sharing it across levels is safe.
	static cp buf[N];
	for(int i=0;i<hf;++i)buf[i]=x[i<<1],buf[i+hf]=x[i<<1|1];
	for(int i=0;i<n;++i)x[i]=buf[i];
	// x[0..hf) now holds the even samples and x[hf..n) the odd ones; transform both halves.
	fft(x,hf,type);fft(x+hf,hf,type);
	cp wn=cp{cos(2*pi/n),sin(type*2*pi/n)},w=cp{1,0};
	// Butterfly: each pair (i, i+hf) only reads and writes its own two slots.
	for(int i=0;i<hf;++i,w=w*wn){
		cp t=w*x[i+hf];
		x[i+hf]=x[i]-t;
		x[i]=x[i]+t;
	}
}
int n,m,x;
// Reads degrees n and m, then n+1 and m+1 coefficients (lowest degree first),
// and prints the m+n+1 coefficients of the product separated by spaces.
int main(){
	scanf("%d%d",&n,&m);
	for(int i=0;i<=n;++i)scanf("%d",&x),a[i].r=x;
	for(int i=0;i<=m;++i)scanf("%d",&x),b[i].r=x;
	// m becomes the product degree; n the smallest power of two strictly greater than it.
	m+=n;for(n=1;n<=m;n<<=1);
	fft(a,n,1);fft(b,n,1);
	// Pointwise product over exactly the n transformed samples (a[n] is not part of the transform).
	for(int i=0;i<n;++i)a[i]=a[i]*b[i];
	fft(a,n,-1);
	// The inverse transform is unscaled: divide by n, then round to the nearest integer.
	// std::lround also rounds negative values correctly, unlike (int)(v+0.5).
	for(int i=0;i<=m;++i)
		printf("%d ",(int)std::lround(a[i].r/n));
	return 0;
}

  

迭代版 FFT

#include<cstdio>
#include<cmath>
#include<algorithm>
// Single-character input shorthand.
#define gc getchar()
// Inclusive counting loop: i = s, s+1, ..., t.
#define FOR(i,s,t) for(int i=(s);i<=(t);++i)
using std::swap;
typedef double db;
// pi computed as arccos(-1).
const db pi=acos(-1.0);
// Plain complex value: r is the real part, i the imaginary part.
struct complex{
	db r,i;
	typedef complex cp;
	inline cp operator+(cp A)const{
		cp res;
		res.r=r+A.r;
		res.i=i+A.i;
		return res;
	}
	inline cp operator-(cp A)const{
		cp res;
		res.r=r-A.r;
		res.i=i-A.i;
		return res;
	}
	inline cp operator*(cp A)const{
		cp res;
		res.r=r*A.r-i*A.i;
		res.i=r*A.i+A.r*i;
		return res;
	}
}a[1<<18],b[1<<18],wn[18];
// a, b: the two polynomials being multiplied; wn[t]: unit root used by fft at level t.
typedef complex cp;
// p: bit-reversal permutation; n: transform length; m: degree bound; lg2: log2(n).
int p[1<<18];
int n,m,lg2;
// In-place iterative radix-2 FFT of a[0..n), using the global length n, the
// bit-reversal table p[] and the per-level unit roots wn[] prepared by main.
// The direction depends on which roots main stored in wn[]; no 1/n scaling here.
inline void fft(cp *a){
	// Bit-reversed reordering; the idx < p[idx] test swaps every pair only once.
	for(int idx=0;idx<n;++idx){
		if(idx<p[idx])swap(a[idx],a[p[idx]]);
	}
	// half = size of the sub-transforms merged at this level; lvl indexes wn[].
	int lvl=0;
	for(int half=1;half<n;half<<=1){
		const cp step=wn[lvl++];
		for(int base=0;base<n;base+=half<<1){
			cp rot={1,0};
			for(int k=base;k<base+half;++k){
				cp y=rot*a[k+half];
				a[k+half]=a[k]-y;
				a[k]=a[k]+y;
				rot=rot*step;
			}
		}
	}
}
// Reads the next integer from stdin.
// Every character that is neither a digit nor '-' is treated as a separator:
// spaces, '\n', '\r' from CRLF files, tabs, and so on.
// Returns 0 if input ends before a digit is seen.
// The character that terminates the number is consumed.
inline int read(){
	int c=getchar();  // int, not char, so EOF stays distinguishable from real bytes
	while(c!=EOF&&c!='-'&&(c<'0'||c>'9'))c=getchar();
	bool neg=false;
	if(c=='-'){neg=true;c=getchar();}
	int data=0;
	while(c>='0'&&c<='9'){
		data=data*10+(c-'0');
		c=getchar();
	}
	return neg?-data:data;
}
// Digit stack for write(): wr[0] is the stack height, wr[1..wr[0]] are pending digits.
int wr[51];
// Writes x to stdout in decimal with no separator.
// Negative values get a leading '-'; INT_MIN is handled through unsigned arithmetic.
inline void write(int x){
	unsigned u=static_cast<unsigned>(x);
	if(x<0){
		putchar('-');
		u=0u-u;  // magnitude of x, valid even for INT_MIN
	}
	if(!u){
		putchar(48);
		return;
	}
	while(u)wr[++wr[0]]=u%10,u/=10;
	while(wr[0])putchar(48+wr[wr[0]--]);
}
int main(){
	n=read();m=read();
	FOR(i,0,n)a[i].r=1.00*read();
	FOR(i,0,m)b[i].r=1.00*read();
	m+=n;for(n=1;n<=m;n<<=1)++lg2;
	FOR(i,0,n-1)p[i]=(p[i>>1]>>1)^((i&1)<<(lg2-1));
	for(register int i=1,t=0;i<n;i<<=1,++t)wn[t]=(cp){cos(pi/i),sin(pi/i)};
	fft(a);fft(b);
	FOR(i,0,n-1)a[i]=a[i]*b[i];
	for(register int i=1,t=0;i<n;i<<=1,++t)wn[t]=(cp){cos(pi/i),sin(-pi/i)};
	fft(a);
	FOR(i,0,m)write((int)(a[i].r/n+0.5)),putchar(' ');
	return 0;
}

  

NTT 版（模数 479·2^21+1）

#include<cstdio>
#include<algorithm>
using namespace std;
// NTT modulus 479*2^21+1 = 1004535809; maxn bounds every working array.
const int mod=(479<<21)|1,maxn=1e6;
int a[maxn],b[maxn],p[maxn],s[maxn],gn[maxn];
int n,m,lg2,g,ny;
// Binary exponentiation: returns a^b modulo `mod`.
// Expects b >= 0; a is presumably already reduced into [0, mod).
inline int fp(int a,int b){
	int acc=1;
	for(;b;b>>=1,a=1ll*a*a%mod)
		if(b&1)acc=1ll*acc*a%mod;
	return acc;
}
// Returns the smallest primitive root modulo p, where p is assumed to be prime (not checked).
// A candidate c is a primitive root iff c^((p-1)/q) != 1 (mod p) for every distinct prime q | p-1.
// All powers are taken modulo p itself, not the global `mod`, so any prime works.
inline int get_g(int p){
	if(p==2)return 1;  // the multiplicative group mod 2 is {1}
	int fac[32],cnt=0;  // distinct prime factors of p-1 (at most 9 for p < 2^31)
	int x=p-1;
	for(int i=2;i<=x/i;++i)  // i <= x/i avoids overflowing i*i
		if(x%i==0){
			fac[cnt++]=i;
			while(x%i==0)x/=i;
		}
	if(x>1)fac[cnt++]=x;
	for(int cand=2;;++cand){
		bool primitive=true;
		for(int j=0;j<cnt&&primitive;++j){
			// r = cand^((p-1)/fac[j]) mod p
			long long r=1,base=cand%p;
			for(int e=(p-1)/fac[j];e;e>>=1,base=base*base%p)
				if(e&1)r=r*base%p;
			if(r==1)primitive=false;
		}
		if(primitive)return cand;
	}
}
// In-place iterative number-theoretic transform of a[0..m) modulo `mod`.
// Uses the global length m, bit-reversal table p[] and roots gn[] prepared by main.
// gn[t] is the root for merging blocks of size 2^(t+1).
// No reversal or 1/m scaling is done here.
inline void ntt(int *a){
	for(int idx=0;idx<m;++idx)
		if(idx<p[idx])swap(a[idx],a[p[idx]]);
	int lvl=0;
	for(int half=1;half<m;half<<=1,++lvl){
		const int root=gn[lvl];
		for(int base=0;base<m;base+=half<<1){
			int w=1;
			for(int k=base;k<base+half;++k){
				const int v=1ll*w*a[k+half]%mod;
				const int u=a[k];
				a[k+half]=(u-v+mod)%mod;
				a[k]=(u+v)%mod;
				w=1ll*w*root%mod;
			}
		}
	}
}
// Reads degrees n and m and both coefficient lists (lowest degree first).
// Multiplies the polynomials modulo `mod` with the NTT and prints the product's coefficients.
int main(){
	g=get_g(mod);
	scanf("%d%d",&n,&m);
	for(int i=0;i<=n;++i)scanf("%d",a+i);
	for(int i=0;i<=m;++i)scanf("%d",b+i);
	// From here n is deg(a*b) and m the transform length: smallest power of two > n.
	n+=m;for(m=1;m<=n;m<<=1)++lg2;
	// Bit-reversal permutation.
	// (m>>1) equals 1<<(lg2-1) whenever m >= 2 and is 0 when m == 1.
	// For m == 1, shifting by lg2-1 == -1 would be undefined behaviour.
	for(int i=0;i<m;++i)p[i]=(p[i>>1]>>1)^((i&1)*(m>>1));
	for(int i=1,t=0;i<m;i<<=1,++t)gn[t]=fp(g,(mod-1)/(i<<1));
	ntt(a);ntt(b);
	for(int i=0;i<m;++i)a[i]=1ll*a[i]*b[i]%mod;
	// Inverse transform = forward transform, reverse a[1..m), then multiply by m^{-1}.
	ntt(a);
	reverse(a+1,a+m);
	ny=fp(m,mod-2);
	for(int i=0;i<m;++i)a[i]=1ll*a[i]*ny%mod;
	for(int i=0;i<=n;++i)printf("%d ",a[i]);
	return 0;
}

  

多项式求逆（NTT，模数 998244353）

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
// The working buffers must hold a transform of length p, the smallest power of two
// >= 2*(n+1), as computed in solve().
// With the former 2e5+5, p (262144) already overflowed them once n+1 > 65536.
// (1<<19)+5 covers every n+1 <= 262144.
const int mod=998244353,maxn=(1<<19)+5;
int a[maxn],b[maxn],tmp[maxn],s[maxn],gn[maxn];
int n;
// Binary exponentiation: returns a^b modulo `mod`.
// Expects b >= 0; a is presumably already reduced into [0, mod).
inline int fp(int a,int b){
	int ret=1;
	while(b){
		if(b&1)ret=1ll*a*ret%mod;
		a=1ll*a*a%mod;b>>=1;
	}
	return ret;
}
inline void ntt(int *a,int p,int f){
	for(register int i=0;i<p;++i)
		if(i<s[i])
			swap(a[i],a[s[i]]);
	for(register int i=1,t=0,g,w,v;i<p;i<<=1,++t){
		g=gn[t];
		for(register int j=0;j<p;j+=(i<<1)){
			w=1;
			for(register int k=j;k<i+j;++k,w=1ll*w*g%mod){
				v=1ll*w*a[i+k]%mod;
				a[i+k]=(a[k]-v+mod)%mod;
				a[k]=(a[k]+v)%mod;
			}
		}
	}
	if(f==1)return;
	reverse(a+1,a+p);
	int ny=fp(p,mod-2);
	for(register int i=0;i<p;++i)
		a[i]=1ll*a[i]*ny%mod;
}
// Newton iteration for the power-series inverse of the global polynomial a.
// On return b[0..deg) holds a^{-1} mod x^deg; entries of b at index >= deg may hold
// leftover values.
// Each level doubles the precision: b <- 2b - a*b^2 (mod x^deg).
// Needs a[0] invertible modulo `mod`.
inline void solve(int *b,int deg){
	if(deg==1){
		b[0]=fp(a[0],mod-2);
		return;
	}
	const int lower=(deg+1)>>1;  // precision delivered by the recursive call
	solve(b,lower);
	// len: smallest power of two >= 2*deg, so a*b^2 does not wrap around; bits = log2(len).
	int len=1,bits=0;
	while(len<(deg<<1)){
		len<<=1;
		++bits;
	}
	for(int i=0;i<len;++i)tmp[i]=(i<deg)?a[i]:0;
	std::fill(b+lower,b+len,0);
	for(int i=0;i<len;++i)s[i]=(s[i>>1]>>1)^((i&1)<<(bits-1));
	ntt(tmp,len,1);
	ntt(b,len,1);
	for(int i=0;i<len;++i){
		const int twice=2ll*b[i]%mod;
		const int cubic=1ll*tmp[i]*b[i]%mod*b[i]%mod;
		b[i]=(twice-cubic+mod)%mod;
	}
	ntt(b,len,-1);
}
// Reads n and the n+1 coefficients of a (lowest degree first).
// Prints the first n+1 coefficients of a^{-1} modulo 998244353.
int main(){
	// gn[t] = 3^((mod-1)/2^(t+1)), with 3 used as the primitive root of 998244353.
	for(int t=0;t<=20;++t)
		gn[t]=fp(3,(mod-1)/(2<<t));
	scanf("%d",&n);
	for(int i=0;i<=n;++i)scanf("%d",&a[i]);
	solve(b,n+1);
	for(int i=0;i<=n;++i)printf("%d ",b[i]);
	return 0;
}

  

原文地址:https://www.cnblogs.com/Stump/p/8001123.html