多项式板子

FFT

[w_n=cosfrac{2pi}{n}+sinfrac{2pi}{n} i ]

[F(omega_n^k)=A(omega_{n/2}^k)+omega_{n}^k imes B(omega_{n/2}^k) ]

[F(omega_n^{k+n/2})=A(omega_{n/2}^k)-omega_{n}^k imes B(omega_{n/2}^k) ]

struct Complex{
    db x,y;
    Complex(db x_=0,db y_=0):x(x_),y(y_){}
    Complex operator +(Complex a){return Complex(x+a.x,y+a.y);}
    Complex operator -(Complex a){return Complex(x-a.x,y-a.y);}
    Complex operator *(Complex a){return Complex(x*a.x-y*a.y,y*a.x+x*a.y);}
};

inline void FFT(){
    for(rint i=0;i<n;i++)
        if(i<rk[i]) swap(a[i],a[rk[i]]);
    for(rint p=2;p<=n;p<<=1){
        int len=p>>1;
        Complex ret(cos(2.0*PI/p),sin(2.0*PI/p)*opt);
        for(rint k=0;k<n;k+=p){
            Complex now=Complex(1.0,0.0);
            for(rint l=k;l<k+len;l++){
                Complex t=now*a[l+len];
                a[l+len]=a[l]-t;
                a[l]=a[l]+t;
                now=now*ret;
            }
        }
    }
}

int main(){
    for(int i=0;i<n;i++)
        rk[i]=(rk[i>>1]>>1)|((i&1)? n>>1:0);
    opt=1;FFT();
    for(rint i=0;i<n;++i) a[i]=a[i]*a[i];
    opt=-1;FFT();
    for(rint i=0;i<=m;++i) printf("%.0lf ",fabs(a[i].y)/n/2.0);
}

NTT

(p) 的原根 (g) 替换 (omega),因为原根有类似的性质。

[varphi(p)=2^p imes r ]

其中 (2^p) 决定了最大长度。

void NTT(bool op,int n,ll *F){
    for(int i=0;i<n;i++)
        if(i<rk[i]) swap(F[i],F[rk[i]]);
    for(int p=2;p<=n;p<<=1){
        int len=p>>1;
        ll w=qpow(op? G:Gi,(Mod-1)/p);
        for(int k=0;k<n;k+=p){
            ll now=1;
            for(int l=k;l<k+len;l++){
                ll t=F[l+len]*now%Mod;
                F[l+len]=(F[l]-t+Mod)%Mod;
                F[l]=(F[l]+t)%Mod;
                now=now*w%Mod;
            }
        }
    }
}

Mul

分别 NTT 后,将点值乘起来,再 NTT 回去。注意最后要除 (n)

(n)(m) 都是最高次幂大小。注意清空 (x)(y)

inline void Cop(int n,ll *a,ll *b){for(int i=0;i<n;i++)a[i]=b[i];}
inline void Clear(int n,ll *F){for(int i=0;i<n;i++)F[i]=0;}
inline void Rk(int n){for(int i=0;i<n;i++)rk[i]=(rk[i>>1]>>1)|(i&1? n>>1:0);}

void Mul(ll *X,int n,int m,ll *a,ll *b){
    static ll x[N],y[N];
    Cop(n+1,x,a),Cop(m+1,y,b);
    for(m+=n,n=1;n<=m;n<<=1); Rk(n);
    NTT(1,n,x),NTT(1,n,y);
    for(int i=0;i<n;i++) x[i]=x[i]*y[i]%Mod;
    NTT(0,n,x); ll inv=qpow(n);
    for(int i=0;i<=m;i++) X[i]=x[i]*inv%Mod;
    Clear(n,x),Clear(n,y);
}

Inv

(A(x)) 的逆元 (B(x))(a_0) 非零。

先求出 (A(x)) 的常数项的逆元,设为初始的 (B(x))。现在已知

[A(x) equiv B(x) pmod{x^n} ]

可以得到

[A(x)B(x) equiv 1 pmod{x^n} ]

[ig(A(x)B(x)-1ig)^2 equiv 0 pmod{x^{2n}} ]

[A(x)ig(2B(x)-A(x)B(x)^2ig) equiv 0 pmod{x^{2n}} ]

新的 (B(x)) 就是 (2B(x)-A(x)B(x)^2) 。递归即可,复杂度 (O(n log n))

注意 (n) 是项数,也即多项式长度,而我们上述式子所倍增的是多项式最高次幂,也就是平方能得到的是最高次幂的倍增,不是长度的倍增。而且会发现最高次幂的倍增会慢于长度的倍增,这就会导致求出来的最后几项实际上是虚拟的。一个解决的办法是 (n) 取大于等于 (2m) 的值,这样虽然会慢,但求出来一定是对的。

void Inv(int n,ll *a,ll *b){
    static ll x[N];
    if(n==1){b[0]=qpow(a[0]);return;}
    Inv((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1); Rk(n);
    Clear(n,x),Cop(m,x,a);
    NTT(1,n,x),NTT(1,n,b);
    for(rint i=0;i<n;i++)
        b[i]=b[i]*(2-x[i]*b[i]%Mod+Mod)%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    for(rint i=0;i<m;i++) b[i]=b[i]*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

ln

给定 (A(x)),且 (a_0=1)。求 (B(x)=ln A(x))

求导,有

[B'(x)=frac{A'(x)}{A(x)} ]

求逆即可,得到 (B'(x)),再积分回去。

void ln(int n,ll *a,ll *b){
    static ll x[N];
    Clear(n,x); Inv(n,a,x);
    for(int i=0;i<n-1;i++)
        b[i]=a[i+1]*(i+1)%Mod; b[n-1]=0;
    Mul(x,n-1,n-1,b,x);
    for(int i=1;i<n;i++) b[i]=x[i-1]*qpow(i)%Mod; b[0]=0;
}

exp

(B(x)=e^{A(x)})

[g(B(x))=ln B(x)-A(x)equiv 0 pmod {x^n} ]

也就是要求 (g) 的一个多项式根。假如现在已经知道了 (B) 的前 (n) 项,即

[B(x)equiv B_0(x) pmod {x^n} ]

(x=B_0(x)) 处泰勒展开,有

[egin{align} 0 &=g(B_0(x)) \ &=g(B_0(x))+g'(B_0(x))ig(B(x)-B_0(x)ig)+frac{g''(B_0(x))}{2}ig(B(x)-B_0(x)ig)^2+dots \ &=g(B_0(x))+g'(B_0(x))ig(B(x)-B_0(x)ig) pmod {x^{2n}} end{align} ]

化简得

[B(x)equiv B_0(x)-frac{gig(B_0(x)ig)}{g'ig(B_0(x)ig)} ]

代入 (g)

[B(x)equiv B_0(x)Big(1-ln B_0(x)+A(x)Big) pmod {x^{2n}} ]

倍增的时候 (m) 也要翻倍,和 (Inv) 同理。

void exp(int n,ll *a,ll *b){
    static ll x[N],y[N];
    if(n==1){b[0]=1;return;}
    exp((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1); 
    Clear(n,x),Clear(n,y),Cop(m,y,a),ln(m,b,x); // x=ln(b)
    NTT(1,n,x),NTT(1,n,y),NTT(1,n,b);
    for(rint i=0;i<n;i++) b[i]=(1-x[i]+y[i]+Mod)%Mod*b[i]%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    for(rint i=0;i<n;i++) b[i]=b[i]*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

Sqrt

(B(x)^2equiv A(x)),保证 (a_0=1).

[egin{align} B(x)& equiv B_0(x) pmod{x^n}\ ig(B(x)- B_0(x) ig )^2&equiv 0 pmod{x^{2n}}\ B^2(x)+B_0(x)^2&equiv 2B(x)B_0(x) pmod{x^{2n}}\ A(x)+B_0(x)^2&equiv 2B(x)B_0(x) pmod{x^{2n}}\ B(x)&equiv frac{A(x)+B_0(x)^2}{2B_0(x)} pmod{x^{2n}} end{align} ]

void Sqrt(int n,ll *a,ll *b){
    static ll x[N],y[N];
    if(n==1){b[0]=1;return;}
    Sqrt((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1);
    Clear(n,x),Clear(n,y),Inv(m,b,x);
    Cop(m,y,a),Mul(y,m-1,m-1,x,y);
    ll inv=qpow(2ll);
    for(rint i=0;i<m;i++) b[i]=(b[i]+y[i])%Mod*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

完整板子

#include<stdio.h>
#define rint register int

typedef long long ll;

inline int read(){
    int x=0,flag=1; char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')flag=0;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-48;c=getchar();}
    return flag? x:-x; 
}

const int N=(1<<22)+7;
const int Mod=998244353;
const int G=3;

ll qpow(ll x,int y=Mod-2){
    ll ret=1;
    while(y){
        if(y&1) ret=ret*x%Mod;
        x=x*x%Mod,y>>=1; 
    }
    return ret;
}

const int Gi=qpow(G);

int rk[N];
inline void swap(ll &x,ll &y){x^=y,y^=x,x^=y;}
void NTT(bool op,int n,ll *F){
    for(rint i=0;i<n;i++)
        if(i<rk[i]) swap(F[i],F[rk[i]]);
    for(rint p=2;p<=n;p<<=1){
        rint len=p>>1;
        ll w=qpow(op? G:Gi,(Mod-1)/p);
        for(rint k=0;k<n;k+=p){
            ll now=1;
            for(rint l=k;l<k+len;l++){
                ll t=F[l+len]*now%Mod;
                F[l+len]=(F[l]-t+Mod)%Mod;
                F[l]=(F[l]+t)%Mod;
                now=now*w%Mod;
            }
        }
    }
}

inline void Cop(int n,ll *a,ll *b){for(int i=0;i<n;i++)a[i]=b[i];}
inline void Clear(int n,ll *F){for(int i=0;i<n;i++)F[i]=0;}
inline void Rk(int n){for(int i=0;i<n;i++)rk[i]=(rk[i>>1]>>1)|(i&1? n>>1:0);}

void Mul(ll *X,int n,int m,ll *a,ll *b){
    static ll x[N],y[N];
    Cop(n+1,x,a),Cop(m+1,y,b);
    for(m+=n,n=1;n<=m;n<<=1); Rk(n);
    NTT(1,n,x),NTT(1,n,y);
    for(rint i=0;i<n;i++) x[i]=x[i]*y[i]%Mod;
    NTT(0,n,x); ll inv=qpow(n);
    for(rint i=0;i<=m;i++) X[i]=x[i]*inv%Mod;
    Clear(n,x),Clear(n,y);
}

void Inv(int n,ll *a,ll *b){
    static ll x[N];
    if(n==1){b[0]=qpow(a[0]);return;}
    Inv((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1); Rk(n);
    Clear(n,x),Cop(m,x,a);
    NTT(1,n,x),NTT(1,n,b);
    for(rint i=0;i<n;i++)
        b[i]=b[i]*(2-x[i]*b[i]%Mod+Mod)%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    for(rint i=0;i<m;i++) b[i]=b[i]*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

void ln(int n,ll *a,ll *b){
    static ll x[N];
    Clear(n,x); Inv(n,a,x);
    for(int i=0;i<n-1;i++)
        b[i]=a[i+1]*(i+1)%Mod; b[n-1]=0;
    Mul(x,n-1,n-1,b,x);
    for(int i=1;i<n;i++) b[i]=x[i-1]*qpow(i)%Mod; b[0]=0;
}

void exp(int n,ll *a,ll *b){
    static ll x[N],y[N];
    if(n==1){b[0]=1;return;}
    exp((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1); 
    Clear(n,x),Clear(n,y),Cop(m,y,a),ln(m,b,x); // x=ln(b)
    NTT(1,n,x),NTT(1,n,y),NTT(1,n,b);
    for(rint i=0;i<n;i++) b[i]=(1-x[i]+y[i]+Mod)%Mod*b[i]%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    for(rint i=0;i<n;i++) b[i]=b[i]*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

void Sqrt(int n,ll *a,ll *b){
    static ll x[N],y[N];
    if(n==1){b[0]=1;return;}
    Sqrt((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1);
    Clear(n,x),Clear(n,y),Inv(m,b,x);
    Cop(m,y,a),Mul(y,m-1,m-1,x,y);
    ll inv=qpow(2ll);
    for(rint i=0;i<m;i++) b[i]=(b[i]+y[i])%Mod*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

int n,m;
ll a[N],b[N];

int main(){
    n=read();
    for(rint i=0;i<n;i++) a[i]=read();
    Func();
    for(rint i=0;i<n;i++) printf("%lld ",b[i]);
}
原文地址:https://www.cnblogs.com/wwlwQWQ/p/14930844.html