多项式

这是优美的多项式家族

快速傅里叶变换(FFT)

问题:多项式乘法

原理先不写了,思想就是把系数表达转化为点值表达,点值运算之后再变回系数表达,复杂度(O(nlogn))

点值选取的是负数域中的n次单位根

有时间会补上这块内容的

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
const int N = 4e6;
const double Pi = acos(-1.0);
using namespace std;
struct node
{
    double x,y;
}a[N + 5],b[N + 5],w[N + 5];
int n,m,maxn,rev[N + 5],lg;
node operator +(node a,node b)
{
    return (node){a.x + b.x,a.y + b.y};
}
node operator -(node a,node b)
{
    return (node){a.x - b.x,a.y - b.y};
}
node operator *(node a,node b)
{
    return (node){a.x * b.x - a.y * b.y,a.x * b.y + a.y * b.x};
}
void fft(node *a,int typ)
{
    for (int i = 0;i < maxn;i++)
        if (i < rev[i])
            swap(a[i],a[rev[i]]);
    for (int i = 1;i < maxn;i <<= 1)
        for (int j = 0;j < maxn;j += i << 1)
            for (int k = 0;k < i;k++)
            {
                node x = a[k + j],t = (node){w[i + k].x,w[i + k].y * typ} * a[k + j + i];
                a[k + j] = x + t;
                a[k + j + i] = x - t;
            }
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i = 0;i <= n;i++)
        scanf("%lf",&a[i].x);
    for (int i = 0;i <= m;i++)
        scanf("%lf",&b[i].x);
    maxn = 1;
    while (maxn <= m + n)
        maxn <<= 1,lg++;
    for (int i = 0;i <= maxn;i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    for (int i = 1;i < maxn;i <<= 1)
        for (int j = 0;j < i;j++)
            w[i + j] = (node){cos(Pi * j / i),sin(Pi * j / i)}; 
    fft(a,1);
    fft(b,1);
    for (int i = 0;i < maxn;i++)
        a[i] = a[i] * b[i];
    fft(a,-1);
    for (int i = 0;i <= n + m;i++)
        printf("%d ",(int)(a[i].x / maxn + 0.1));
    return 0;
}

快速数论变换(NTT)

就是把问题转化为了在模意义下,于是我们可以选择和单位根有类似性质的原根,时间复杂度仍是(O(nlogn))

#include <iostream>
#include <cstdio>
#include <algorithm>
const int N = 5e6;
const int P = 998244353;
using namespace std;
int n,m,rev[N + 5],maxn,lg,a[N + 5],b[N + 5],g[N + 5][3];
int mypow(int a,int x)
{
    int s = 1;
    while (x)
    {
        if (x & 1)
            s = 1ll * s * a % P;
        a = 1ll * a * a % P;
        x >>= 1;
    }
    return s;
}
void ntt(int *a,int typ)
{
    for (int i = 0;i < maxn;i++)
        if (i < rev[i])
            swap(a[i],a[rev[i]]);
    for (int i = 1;i < maxn;i <<= 1)
        for (int j = 0;j < maxn;j += i << 1)
            for (int k = 0;k < i;k++)
            {
                int x = a[k + j],t = 1ll * g[k + i][typ] * a[k + i + j] % P;
                a[k + j] = (x + t) % P;
                a[k + i + j] = ((x - t) % P + P) % P;
            }
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i = 0;i <= n;i++)
        scanf("%d",&a[i]);
    for (int i = 0;i <= m;i++)
        scanf("%d",&b[i]);
    maxn = 1;
    while (maxn <= n + m)
        maxn <<= 1,lg++;
    for (int i = 0;i <= maxn;i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    for (int i = 1;i < maxn;i <<= 1)
    {
        int G1 = mypow(3,(P - 1) / (i << 1)),G2 = mypow(mypow(3,P - 2),(P - 1) / (i << 1));
        g[i][1] = 1;
        g[i][0] = 1;
        for (int j = 1;j < i;j++)
            g[i + j][1] = 1ll * g[i + j - 1][1] * G1 % P,g[i + j][0] = 1ll * g[i + j - 1][0] * G2 % P;
    }
    ntt(a,1);
    ntt(b,1);
    for (int i = 0;i < maxn;i++)
        a[i] = 1ll * a[i] * b[i] % P;
    ntt(a,0);
    int inv = mypow(maxn,P - 2);
    for (int i = 0;i <= n + m;i++)
        printf("%d ",1ll * a[i] * inv % P);
    return 0;
}

多项式求逆

问题:给定一个多项式(F(x)),求一个多项式(G(x)),满足(F(x)G(x)equiv 1(mod x^n))

假设我们已经求出了一个(F(x))(mod x^n)下的逆(G'(x)),我们要求在(mod x^{2n})下的逆(G(x))

那么考虑

[egin{aligned} F(x)G'(x)&equiv 1(mod x^n) \ F(x)G(x)&equiv 1(mod x^n) \ F(x)(G(x)-G'(x))&equiv 0(mod x^n) \ F^2(x)(G(x)-G'(x))^2&equiv 0(mod x^{2n}) \ F^2(x)G^2(x)-2F^2(x)G(x)G'(x)+F^2(x)G'^2(x)&equiv 0(mod x^{2n}) \ 1-2F(x)G'(x)+F^2(x)G'^2(x)&equiv 0(mod x^{2n}) \ G(x)-2G'(x)+F(x)G'^2(x)&equiv 0(mod x^{2n}) \ G(x)&equiv 2G'(x)-F(x)G'^2(x)(mod x^{2n}) end{aligned} ]

于是就可以愉快地递归求解了,时间复杂度(T(n)=T(n/2)+O(nlogn)=O(nlogn))

Code

int INVa[N + 5];
void INV(int *a,int *ans,int n)
{
    if (n == 1)
    {
        ans[0] = mypow(a[0],p - 2);
        return;
    }
    INV(a,ans,n + 1 >> 1);
    pre(n * 2);
    for (int i = 0;i < n;i++)
        INVa[i] = a[i];
    clear(INVa,maxn,n);
    ntt(INVa,1);
    ntt(ans,1);
    for (int i = 0;i < maxn;i++)
        ans[i] = (2ll * ans[i] % p - 1ll * INVa[i] * ans[i] % p * ans[i] % p) % p;
    ntt(ans,0);
    clear(ans,maxn,n);
}

多项式对数函数(多项式 ln)

问题:给出 (n-1) 次多项式 (A(x)),求一个 (mod{:x^n}) 下的多项式 (B(x)),满足 (B(x) equiv ln A(x)).

对两边同时求导(B'(x)equiv frac{A'(x)}{A(x)})

积分回去(B(x)equiv int frac{A'(x)}{A(x)}dx)

然后就是求导公式和积分公式

[x^{a'}=ax^{a-1} ]

[int x^adx=frac{1}{a+1}x^{a+1} ]

Code

int Lna[N + 5],Lnb[N + 5];
void DOV(int *a,int *f,int n)
{
    for (int i = 1;i < n;i++)
        f[i - 1] = 1ll * i * a[i] % p;
    f[n - 1] = 0;
}
void DOVINV(int *a,int *f,int n)
{
    f[0] = 0;
    for (int i = 1;i < n;i++)
        f[i] = 1ll * mypow(i,p - 2) * a[i - 1] % p;
}
void Ln(int *a,int *ans,int n)
{
    DOV(a,Lna,n);
    pre(n * 2);
    clear(Lnb,maxn);
    INV(a,Lnb,n);
    pre(n * 2);
    clear(Lna,maxn,n);
    ntt(Lna,1);
    ntt(Lnb,1);
    for (int i = 0;i < maxn;i++)
        Lna[i] = 1ll * Lna[i] * Lnb[i] % p;
    ntt(Lna,0);
    DOVINV(Lna,ans,n);
    clear(ans,maxn,n);
}

多项式指数函数(多项式 exp)

问题:给出 (n-1) 次多项式 (A(x)),保证(A_0=0),求一个 (mod{:x^n}) 下的多项式 (B(x)),满足 (B(x) equiv ext e^{A(x)})

考虑用牛顿迭代解决这个问题

[B(x)equiv e^{A(x)} ]

[lnB(x)-A(x)equiv 0 ]

(F(B(x))=lnB(x)-A(x))

(A(x))看作常数项,所以(F'(B(x))=frac{1}{B(x)})

代入牛顿迭代的式子有

[B(x)equiv B_0(x)-frac{F(B(x))}{F'(B(x))} ]

[B(x)equiv B_0(x)(1-lnB_0(x)+A(x)) ]

倍增求解即可

Code

int expa[N + 5],expb[N + 5];
void exp(int *a,int *ans,int n)
{
    if (n == 1)
    {
        ans[0] = 1;
        return;
    }
    exp(a,ans,n + 1 >> 1);
    Ln(ans,expa,n);
    pre(n * 2);
    for (int i = 0;i < n;i++)
        expb[i] = a[i];
    clear(expb,maxn,n);
    ntt(ans,1);
    ntt(expa,1);
    ntt(expb,1);
    for (int i = 0;i < maxn;i++)
        ans[i] = 1ll * ans[i] * ((1 - expa[i] + expb[i]) % p) % p;
    ntt(ans,0);
    clear(ans,maxn,n);
}

多项式快速幂

问题:给定一个 (n-1) 次多项式 (A(x)),求一个在 (mod x^n) 意义下的多项式 (B(x)),使得 (B(x) equiv A^k(x) (mod x^n))

我们对两边先ln再exp可以得到

[B(x)equiv exp(k imes ln(A(x))) ]

于是(k)也可以取模了

然后注意到数据不一定保证(A_0=1),那么我们可以找到第一个非(0)的项(a),把(A(x))的每一项都除以(a),变成(frac{A(x)}{a}),并将后面的移到前面,这样就可以保证(A_0=1),最后再乘(a^k)并且处理(0)即可

Code

int pa[N + 5];
void mypow(int *a,int *ans,int n,int k)
{
    Ln(a,pa,n);
    for (int i = 0;i < n;i++)
        pa[i] = 1ll * pa[i] * k % p;
    exp(pa,ans,n);
}

多项式开根

问题:给定一个(n-1)次多项式(A(x)),求一个在(mod x^n)意义下的多项式(B(x)),使得(B^2(x) equiv A(x) (mod x^n))。若有多解,请取零次项系数较小的作为答案。

(H^2(x)equiv F(x)(mod x^n))

那么考虑

[egin{aligned} G(x)&equiv H(x)(mod x^n) \ G(x)-H(x)&equiv 0(mod x^n) \ (G(x)-H(x))^2&equiv 0(mod x^2n) \ G^2(x)-2H(x)G(x)+H^2(x)&equiv 0(mod x^{2n}) \ G(x)&equiv frac{F(x)+H^2(x)}{2H(x)}(mod x^{2n}) end{aligned} ]

倍增即可,只有一项的时候需要用二次剩余求根号

不过其实也可以先ln再exp回去

Code

int sqra[N + 5],sqrtmp[N + 5];
void sqr(int *a,int *ans,int n)
{
    if (n == 1)
    {
        ans[0] = sq;
        return;
    }
    sqr(a,ans,n + 1 >> 1);
    pre(n * 2);
    clear(sqra,maxn);
    clear(sqrtmp,maxn);
    INV(ans,sqra,n);
    pre(n * 2);
    for (int i = 0;i < n;i++)
        sqrtmp[i] = a[i];
    ntt(sqra,1);
    ntt(sqrtmp,1);
    ntt(ans,1);
    int t = mypow(2,p - 2);
    for (int i = 0;i < maxn;i++)
        ans[i] = 1ll * ((sqrtmp[i] + 1ll * ans[i] * ans[i] % p) % p) * t % p * sqra[i] % p;
    ntt(ans,0);
    int inv = mypow(maxn,p - 2);
    for (int i = 0;i < n;i++)
        ans[i] = 1ll * ans[i] * inv % p;
    clear(ans,maxn,n);
}

多项式除法

问题:给定一个(n)次多项式(F(x))和一个(m)次多项式(G(x)),求出多项式(Q(x),R(x))满足:

  • (Q(x))次数为(n-m)(R(x))次数小于(m)
  • (F(x)=Q(x)G(x)+R(x))

首先设一个(n)项多项式(A(x)),假设一个(r)操作使得(A_r(x)=x^nA(frac{1}{x}))

那么可以看出(A_r[i]=A[n-i])

然后考虑下面的式子

[egin{aligned} F(x)&=Q(x)G(x)+R(x) \ F(frac{1}{x})&=Q(frac{1}{x})G(frac{1}{x})+R(frac{1}{x}) \ x^nF(frac{1}{x})&=x^{n-m}Q(frac{1}{x})x^mG(frac{1}{x})+x^{n-m+1}cdot x^{m-1}R(frac{1}{x}) \ F_r(x)&=Q_r(x)G_r(x)+x^{n-m+1}R_r(x) \ F_r(x)&equiv Q_r(x)G_r(x)+x^{n-m+1}R_r(x)(mod x^{n-m+1}) \ F_r(x)&equiv Q_r(x)G_r(x)(mod x^{n-m+1}) \ Q_r(x)&equiv F_r(x)G^{-1}_r(x)(mod x^{n-m+1}) end{aligned} ]

于是我们对(G_r(x))求逆,然后求得(Q_r(x)),再带回得到(Q(x))

最后根据(R(x)=F(x)-Q(x)G(x))求得(R(x))

时间复杂度(O(nlogn))

Code

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
const int P = 998244353;
const int N = 1e6;
using namespace std;
int mypow(int a,int x)
{
    int s = 1;
    while (x)
    {
        if (x & 1)
            s = 1ll * s * a % P;
        a = 1ll * a * a % P;
        x >>= 1;
    }
    return s;
}
int n,m,F[N + 5],G[N + 5],Q[N + 5],GR[N + 5],w[N + 5][3],maxn,lg,rev[N + 5],Gi[N + 5],c[N + 5],FR[N + 5];
void R(int *a,int *b,int n)
{
    for (int i = 0;i <= n;i++)
        b[i] = a[n - i];
}
void ntt(int *a,int typ)
{
    for (int i = 0;i < maxn;i++)
        if (i < rev[i])
            swap(a[i],a[rev[i]]);
    for (int i = 1;i < maxn;i <<= 1)
        for (int j = 0;j < maxn;j += i << 1)
            for (int k = 0;k < i;k++)
            {
                int x = a[j + k],t = 1ll * w[i + k][typ] * a[i + j + k] % P;
                a[j + k] = (x + t) % P;
                a[j + k + i] = ((x - t) % P + P) % P;
            }
}
void ntt_pre(int n)
{
    maxn = 1;
    lg = 0;
    while (maxn <= n)
        maxn <<= 1,lg++;
    for (int i = 0;i < maxn;i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));    
}
void INV(int n,int *a,int *b)
{
    if (n == 1)
    {
        b[0] = mypow(a[0],P - 2);
        return;
    }
    INV((n + 1) >> 1,a,b);
    ntt_pre(n << 1);
    for (int i = 0;i < n;i++)
        c[i] = a[i];
    for (int i = n;i < maxn;i++)
        c[i] = 0;
    ntt(c,1);
    ntt(b,1);
    for (int i = 0;i < maxn;i++)
        b[i] = ((2ll * b[i] % P - 1ll * c[i] * b[i] % P * b[i] % P) % P + P) % P;
    ntt(b,2);
    int inv = mypow(maxn,P - 2);
    for (int i = 0;i < n;i++)
        b[i] = 1ll * b[i] * inv % P;
    for (int i = n;i < maxn;i++)
        b[i] = 0;
}
void NR(int *a,int *b,int n)
{
    for (int i = 0;i <= n;i++)
        b[n - i] = a[i];
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i = 0;i <= n;i++)
        scanf("%d",&F[i]);
    for (int i = 0;i <= m;i++)
        scanf("%d",&G[i]);
    maxn = 1;
    while (maxn <= (n + m) * 2)
        maxn <<= 1;
    for (int i = 1;i < maxn;i <<= 1)
    {
        int G1 = mypow(3,(P - 1) / (i << 1)),G2 = mypow(mypow(3,P - 2),(P - 1) / (i << 1));
        w[i][1] = w[i][2] = 1;
        for (int j = 1;j < i;j++)
            w[i + j][1] = 1ll * w[i + j - 1][1] * G1 % P,w[i + j][2] = 1ll * w[i + j - 1][2] * G2 % P;
    }
    R(G,GR,m);
    INV(n - m + 2,GR,Gi);
    R(F,FR,n);   
    ntt_pre(n * 2 - m + 2);
    ntt(FR,1);    
    ntt(Gi,1);
    for (int i = 0;i < maxn;i++)
        Gi[i] = 1ll * Gi[i] * FR[i] % P;
    ntt(Gi,2);
    int inv = mypow(maxn,P - 2); 
    for (int i = 0;i < maxn;i++)
        Gi[i] = 1ll * Gi[i] * inv % P;
    NR(Gi,Q,n - m);
    for (int i = 0;i <= n - m;i++)
        printf("%d ",Q[i]);
    cout<<endl;
    for (int i = n - m + 1;i < maxn;i++)
        Q[i] = 0;
    ntt_pre(n + m);
    ntt(Q,1);
    ntt(G,1);
    ntt(F,1);
    for (int i = 0;i < maxn;i++)    
        F[i] = ((F[i] - 1ll * Q[i] * G[i] % P) % P + P) % P;
    ntt(F,2);
    inv = mypow(maxn,P - 2);
    for (int i = 0;i < m;i++)
        printf("%d ",1ll * F[i] * inv % P);
    return 0;
}
原文地址:https://www.cnblogs.com/sdlang/p/13068320.html