FFT/NTT初探

做了全家桶然后写了几道入门题。
FFT.ref
NTT.ref

Luogu4238 【模板】多项式求逆

Link
套牛顿迭代完事。有一个细节问题是:这次运算多项式有几项就只赋几项的值,其他位置(次数大于n次的项在(mod {x^n})意义下当然为0)一定要设成0(即清空数组),否则会计算错误。

Luogu5205 【模板】多项式开根

Link
开根也是牛顿迭代的问题。运算的时候时刻注意是在(mod {x^n})意义下进行的,这决定了多项式取的位数。
tips:两个长度为n(2的次幂)的多项式相乘,本来应该要在长度为2n下运算,如果在长度为n下运算,那么NTT会起到循环卷积的效果,即后面n项的贡献会平移到前n项来。

Luogu4725 【模板】多项式ln

Link
求导秒掉 求导的时候要注意链式法则。求导和积分最高位也是m-1。

Luogu4726 【模板】多项式exp

Link
牛顿迭代。

Luogu5245 【模板】多项式快速幂

Link
(B(x)=A^k(x) ightarrow B(x)=e^{k*ln(A(x))})

[SDOI2015]序列统计

Link
(f[i*2][c]= {sum_{a*b\%mequiv c}} {f[i][a]*f[i][b]})。快速幂一个log,每层转移是(m^2)的。转移次数不太好优化,考虑优化每层转移速度。有一个很妙的做法是两边求log,(用原根作为log的底数),此时式子就变成:(f[i*2][c]= {sum_{log(a)+log(b)\%mequiv log(c)}} {f[i][a]*f[i][b]}),一个平凡的卷积,NTT硬套完事儿。另外普通快速幂就行了,不需要用到ln+exp。

[ZJOI2014]力

Link
IDFT求完后除以n,此处n指的是化为2的次幂的那个n!另外是精度问题!!!假设要卷积的两个多项式F和G,F的系数均在[(10^{-5},10^{-6})]之间,而G的系数均在[(10^5,10^6)]之间。直接做两遍DFT+一遍IDFT,涉及的精度跨度是(10^{12}),没毛病。但是!!如果用三步并两步优化(一遍DFT+一遍IDFT)的话,涉及的精度跨度上限就是平方,也就是(10^{24})的,而double小数点后有效位数是15~16位,GG。像是这道题如果用了三步并两步,那么就可以获得0分的好成绩(烟)。

板子

FFT

注意idft的时候/n,实部和虚部都要除!

#include<bits/stdc++.h>
using namespace std;
#define REP(i,a,b) for(int i=(a),_ed=(b);i<=_ed;++i)
#define DREP(i,a,b) for(int i=(a),_ed=(b);i>=_ed;--i)
typedef long long ll;
inline int read(){
    register int x=0,f=1;register char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    while(isdigit(ch)){x=x*10+(ch^'0');ch=getchar();}
    return f?x:-x;
}

const int N=(1<<21)+5;
const double Pi=acos(-1.0);
int n,m,trans[N];
struct cmplx{
    double x,y;
    inline cmplx(double _x=0,double _y=0):x(_x),y(_y){}
    inline cmplx operator+(const cmplx& t){return (cmplx){x+t.x,y+t.y};}
    inline cmplx operator-(const cmplx& t){return (cmplx){x-t.x,y-t.y};}
    inline cmplx operator*(const cmplx& t){return (cmplx){x*t.x-y*t.y,x*t.y+y*t.x};}
    inline void operator/=(const int &t){x/=t,y/=t;}
};
cmplx f[N],g[N],w[N];

void FFT(cmplx* f,int n){
    REP(i,0,n-1)if(i<trans[i])swap(f[i],f[trans[i]]);
    for(int len=2,d=1;len<=n;d=len,len<<=1)
	for(int p=0;p<n;p+=len)
	    for(int i=p;i<p+d;++i){
		cmplx t=f[i+d]*w[d+i-p];
		f[i+d]=f[i]-t;f[i]=f[i]+t;
	    }
}
void times(cmplx* f,cmplx* g,int m1,int m2){
    int n=1;for(;n<m1+m2-1;n<<=1);
    REP(i,0,n-1)trans[i]=(trans[i>>1]>>1)|(i&1?(n>>1):0);
    for(int len=2,d=1;len<=n;d=len,len<<=1){
	cmplx nw=cmplx(1,0),e=cmplx(cos(2*Pi/len),sin(2*Pi/len));
	for(int i=0;i<d;++i,nw=nw*e)w[d+i]=nw;
    }
    FFT(f,n),FFT(g,n);
    REP(i,0,n-1)f[i]=f[i]*g[i];
    FFT(f,n);
    reverse(f+1,f+n);
    REP(i,0,n-1)f[i]/=n;
}

int main(){
    //freopen("in.in","r",stdin);
    n=read()+1,m=read()+1;
    REP(i,0,n-1)f[i].x=read();
    REP(i,0,m-1)g[i].x=read();
    times(f,g,n,m);
    REP(i,0,n+m-1-1)printf("%d ",(int)(f[i].x+0.5));
    puts("");
    return 0;
}

NTT

NTT过程用ull存储的话,因为计算t的时候值域最大可达(log(n)*mod*mod),所以在范围较大时(例如(len=1<<17)),要取一次模防止溢出。

#include<bits/stdc++.h>
using namespace std;
#define REP(i,a,b) for(int i=(a),_ed=(b);i<=_ed;++i)
#define DREP(i,a,b) for(int i=(a),_ed=(b);i>=_ed;--i)
typedef long long ll;
typedef unsigned long long ull;
inline int read(){
    register int x=0,f=1;register char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    while(isdigit(ch)){x=x*10+(ch^'0');ch=getchar();}
    return f?x:-x;
}

const int N=(1<<18)+5,mod=998244353;
int n,k,f[N];char s[N];
inline int power(int b,int n){int ans=1;for(;n;n>>=1,b=1ll*b*b%mod)if(n&1)ans=1ll*ans*b%mod;return ans;}
inline void inc(int& x,int y){x=x+y<mod?x+y:x+y-mod;}
inline void dec(int& x,int y){x=x-y>=0?x-y:x-y+mod;}

int trans[N],w[N];
void NTT(int* a,int n){
    static ull f[N];
    REP(i,0,n-1)f[i]=a[i];
    REP(i,0,n-1)if(i<trans[i])swap(f[i],f[trans[i]]);
    for(int len=2,d=1;len<=n;d=len,len<<=1){
	for(int p=0;p<n;p+=len)
	    for(int i=p;i<p+d;++i){
		int t=f[i+d]*w[d+i-p]%mod;
		f[i+d]=f[i]+mod-t;f[i]+=t;
	    }
        if(len==1<<17)REP(i,0,n-1)f[i]%=mod;
    }
    REP(i,0,n-1)a[i]=f[i]%mod;
}
void times(int* f,int* a,int m1,int m2,int lim){
    static int g[N];
    int n=1;for(;n<(m1+m2-1);n<<=1);
    REP(i,0,n-1)trans[i]=(trans[i>>1]>>1)|(i&1?(n>>1):0);
    for(int len=2,d=1;len<=n;d=len,len<<=1){
	int e=power(3,(mod-1)/len);
	REP(i,w[d]=1,d-1)w[d+i]=1ll*w[d+i-1]*e%mod;
    }
    REP(i,m1,n-1)f[i]=0;
    REP(i,0,m2-1)g[i]=a[i];REP(i,m2,n-1)g[i]=0;
    NTT(f,n);NTT(g,n);
    REP(i,0,n-1)f[i]=1ll*f[i]*g[i]%mod;
    NTT(f,n);int inv=power(n,mod-2);
    reverse(f+1,f+n);
    REP(i,0,lim-1)f[i]=1ll*f[i]*inv%mod;
    REP(i,lim,n-1)f[i]=0;
}

void inv(int* f,int m){
    static int g[N],p[N];
    int n=1;for(;n<m;n<<=1);
    g[0]=power(f[0],mod-2);
    for(int len=2;len<=n;len<<=1){
	REP(i,0,(len>>1)-1)p[i]=g[i],g[i]=1ll*2*g[i]%mod;
	times(p,p,len>>1,len>>1,len);times(p,f,len,len,len);
	REP(i,0,len-1)inc(g[i],mod-p[i]);
    }
    REP(i,0,m-1)f[i]=g[i];
    REP(i,0,n-1)g[i]=p[i]=0;
}

void drv(int* f,int m){
    REP(i,0,m-2)f[i]=1ll*(i+1)*f[i+1]%mod;
    f[m-1]=0;
}
void itg(int* f,int m){
    DREP(i,m-1,1)f[i]=1ll*power(i,mod-2)*f[i-1]%mod;
    f[0]=0;
}
void ln(int* f,int m){
    static int g[N];
    REP(i,0,m-1)g[i]=f[i];
    inv(f,m);drv(g,m);
    times(f,g,m,m,m);
    itg(f,m);
    REP(i,0,m-1)g[i]=0;
}

void exp(int* f,int m){
    static int g[N],p[N];
    int n=1;for(;n<m;n<<=1);
    g[0]=1;
    for(int len=2;len<=n;len<<=1){
	REP(i,0,(len>>1)-1)p[i]=g[i];
	ln(p,len);REP(i,0,len-1)p[i]=mod-p[i],inc(p[i],f[i]);
	inc(p[0],1);
	times(g,p,len>>1,len,len);
    }
    REP(i,0,m-1)f[i]=g[i];
    REP(i,0,n-1)g[i]=p[i]=0;
}

void spower(int* f,int m,int k){
    ln(f,n);
    REP(i,0,n-1)f[i]=1ll*k*f[i]%mod;
    exp(f,n);
}

void div(int* a,int* b,int m1,int m2,int* r){
    static int f[N],g[N];
    if(m1<m2){
	REP(i,0,m1-1)r[i]=a[i];REP(i,m1,m2-1)r[i]=0;
	return;
    }
    int lim=m1-m2+1;
    REP(i,0,m1-1)f[i]=a[i];REP(i,0,m2-1)g[i]=b[i];
    reverse(f,f+m1),reverse(g,g+m2);
    REP(i,lim,m1-1)f[i]=0;REP(i,lim,m2-1)g[i]=0;
    inv(g,lim);times(f,g,lim,lim,lim);
    reverse(f,f+lim);
    times(f,b,lim,m2,m1);
    REP(i,0,m1-1)r[i]=a[i],dec(r[i],f[i]);
    REP(i,0,m1-1)f[i]=0;REP(i,0,m2-1)g[i]=0;
}

int main(){
    //freopen("in.in","r",stdin);
    n=read();scanf("%s",s);REP(i,0,strlen(s)-1)k=(1ll*k*10+(s[i]^'0'))%mod;
    REP(i,0,n-1)f[i]=read();
    spower(f,n,k);
    REP(i,0,n-1)printf("%d%c",f[i],i==n-1?'
':' ');
    return 0;
}
原文地址:https://www.cnblogs.com/fruitea/p/12018570.html