【知识点】多项式相关算法

FFT&NTT:

前面有,不说了。

#include<bits/stdc++.h>
#define maxn 5000005
#define maxm 500005
#define inf 0x7fffffff
#define mod 998244353
#define g 3
#define ll long long
#define rint register ll
#define debug(x) cerr<<#x<<": "<<x<<endl
#define fgx cerr<<"--------------"<<endl
#define dgx cerr<<"=============="<<endl

using namespace std;
char s1[maxn],s2[maxn];
ll A[maxn],B[maxn],C[maxn],ind[maxn],res[maxn],N=1;

inline ll read(){
    ll x=0,f=1; char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    return x*f;
}

inline ll power(ll a,ll b){
    ll ans=1;
    while(b){
        if(b&1) ans=ans*a%mod;
        a=a*a%mod,b>>=1;
    }
    return ans;
}

inline void ntt(ll *a,ll op){
    for(rint i=0;i<N;i++)
        if(ind[i]<i)
            swap(a[i],a[ind[i]]);
    for(rint l=2;l<=N;l<<=1){
        ll p=power(g,(mod-1)/l);
        if(op==-1) p=power(p,mod-2);
        for(rint i=0;i<N;i+=l)
            for(ll j=i,w=1;j<i+l/2;j++,w=w*p%mod){
                ll x=a[j],y=w*a[j+l/2]%mod; 
                a[j]=(x+y)%mod;
                a[j+l/2]=(x-y+mod)%mod;
            }
    }
    if(op==-1){
        ll pw=power(N,mod-2);
        for(rint i=0;i<N;i++) a[i]=a[i]*pw%mod;
    }
}

int main(){
    scanf("%s%s",s1,s2);
    ll n=strlen(s1),m=strlen(s2);
    for(rint i=0;i<n;i++) A[n-i-1]=s1[i]-'0';
    for(rint i=0;i<m;i++) B[m-i-1]=s2[i]-'0';
    while(N<n+m) N<<=1;
    for(rint i=0;i<N;i++) ind[i]=(ind[i>>1]>>1)|((i&1)?(N>>1):0);
    ntt(A,1),ntt(B,1);
    for(rint i=0;i<N;i++) C[i]=A[i]*B[i]%mod;
    ntt(C,-1);
    for(rint i=0;i<(N<<1);i++){
        ll x=C[i],p=1;
        res[i]=x%10,x/=10;
        while(x) C[i+1]+=x%10*p,x/=10,p*=10;
    }
    rint p=N<<1; while(res[p]==0) p--;
    for(;p>=0;p--) printf("%d",res[p]);
    printf("
");
    return 0;
}
NTT

任意模数NTT:

设模数为m,假如不取模,可以得到每项系数的上界是$nm^{2}$。于是在模$geq nm^{2}$意义下做就行了。

可以直接拿三个NTT模数做个CRT合并。我一般取998244353,1004535809,469762049。

注意合并完前两个之后第三个直接合并会爆longlong,推一下系数之后在模m意义下算答案就行了。

这里复习一下CRT的公式:设n个方程组形如$x equiv a_i pmod{m_i }$,令$M_i =frac{prod limits_{j=1}^{n}{m_j }}{m_i }$。

那么有$x=sum limits_{i=1}^{n}{a_i M_i (frac{1}{M_i }pmod{m_i })}$。

#include<bits/stdc++.h>
#define maxn 1200005
#define maxm 500005
#define inf 0x7fffffff
#define m0 998244353
#define m1 1004535809
#define m2 469762049
#define g 3
#define ll long long
#define rint register ll
#define debug(x) cerr<<#x<<": "<<x<<endl
#define fgx cerr<<"--------------"<<endl
#define dgx cerr<<"=============="<<endl

using namespace std;
ll A[maxn],B[maxn],ind[maxn],D[maxn];
ll tA[maxn],tB[maxn],C[3][maxn],N=1;

inline ll read(){
    ll x=0,f=1; char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    return x*f;
}

inline ll mul(ll a,ll b,ll mod){
    ll ans=0;
    while(b){
        if(b&1) ans=(ans+a)%mod;
        a=(a+a)%mod,b>>=1;
    }
    return ans;
}
inline ll gcd(ll a,ll b){return (b==0)?a:gcd(b,a%b);}
inline ll power(ll a,ll b,ll mod){
    ll ans=1;
    while(b){
        if(b&1) ans=ans*a%mod;
        a=a*a%mod,b>>=1; 
    }
    return ans;
}

inline void ntt(ll opt,ll *a,ll op){
    for(ll i=0;i<N;i++) 
        if(ind[i]<i)
            swap(a[ind[i]],a[i]);
    ll mod=(opt==0)?m0:(opt==1?m1:m2);
    for(ll l=2,p;l<=N;l<<=1){
        p=power(g,(mod-1)/l,mod);
        if(op==-1) p=power(p,mod-2,mod);
        for(ll i=0;i<N;i+=l)
            for(ll j=i,w=1;j<i+(l>>1);j++,w=w*p%mod){
                ll x=a[j],y=w*a[j+(l>>1)]%mod;
                a[j]=(x+y)%mod,a[j+(l>>1)]=(x-y+mod)%mod; 
            }
    }
    if(op==-1){
        ll inv=power(N,mod-2,mod);
        for(ll i=0;i<N;i++) a[i]=a[i]*inv%mod;
    }
}

int main(){
    ll n=read(),m=read(),P=read();
    for(ll i=0;i<=n;i++) A[i]=read();
    for(ll i=0;i<=m;i++) B[i]=read();
    while(N<=n+m+1) N<<=1;
    for(ll i=0;i<N;i++) ind[i]=(ind[i>>1]>>1)|((i&1)?(N>>1):0);
    for(ll i=0;i<3;i++){
        memcpy(tA,A,sizeof(A));
        memcpy(tB,B,sizeof(B));
        ntt(i,tA,1),ntt(i,tB,1);
        ll mod=(i==0)?m0:(i==1?m1:m2);
        for(ll j=0;j<N;j++) C[i][j]=tA[j]*tB[j]%mod;
        ntt(i,C[i],-1);
    }
    for(ll i=0;i<N;i++){
        ll a0=C[0][i],a1=C[1][i],a2=C[2][i],M=(ll)m0*(ll)m1/gcd(m0,m1);
        //cout<<a0<<" "<<a1<<" "<<a2<<" "<<M<<endl;
        ll res=(mul(M/m0,power((M/m0)%m0,m0-2,m0)*a0%M,M)+mul(M/m1,power((M/m1)%m1,m1-2,m1)*a1%M,M))%M;
        ll k=(a2%m2-res%m2+m2)*power(M%m2,m2-2,m2)%m2;
        M%=P,res%=P,D[i]=(k%P*M%P+res%P)%P;
    }
    for(ll i=0;i<=n+m;i++) printf("%lld ",D[i]);
    printf("
");
    return 0;
}
任意模数NTT

分治NTT:

问题大概是给一个多项式$G(x)$,求$F(x)$满足$F(i)=sum limits_{j=1}^{i}{F(j)G(i-j)}$。

直接自己卷自己没得操作,考虑CDQ分治,每次计算左边对右边的贡献。

对于$igeq mid$的i,左边给它的贡献就是$sum limits_{j=1}^{mid}{F(j)G(i-j)}$。

那么每次直接F卷G,然后把贡献对位加过去。

在做多次ntt时需要注意把高位清0。

#include<bits/stdc++.h>
#define maxn 1000005
#define maxm 500005
#define inf 0x7fffffff
#define mod 998244353
#define g 3
#define ll long long
#define rint register ll
#define debug(x) cerr<<#x<<": "<<x<<endl
#define fgx cerr<<"--------------"<<endl
#define dgx cerr<<"=============="<<endl

using namespace std;
ll N=1,mxn=1,G[maxn],F[maxn];
ll A[maxn],B[maxn],ind[maxn];

inline ll read(){
    ll x=0,f=1; char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    return x*f;
}

inline ll power(ll a,ll b){
    ll ans=1;
    while(b){
        if(b&1) ans=ans*a%mod;
        a=a*a%mod,b>>=1;
    }
    return ans;
}

inline void ntt(ll *a,ll op){
    for(ll i=0;i<N;i++) 
        if(ind[i]<i)
            swap(a[ind[i]],a[i]);
    for(ll l=2;l<=N;l<<=1){
        ll p=power(g,(mod-1)/l);
        if(op==-1) p=power(p,mod-2);
        for(ll i=0;i<N;i+=l)
            for(ll j=i,w=1;j<i+(l>>1);j++,w=w*p%mod){
                ll x=a[j],y=w*a[j+(l>>1)]%mod;
                a[j]=(x+y)%mod,a[j+(l>>1)]=(x-y+mod)%mod;
            }
    }
    if(op==-1){
        ll inv=power(N,mod-2);
        for(ll i=0;i<N;i++) a[i]=a[i]*inv%mod;
    }
}

inline void solve(ll l,ll r){
    if(l==r) return;
    ll mid=l+r>>1;
    solve(l,mid);
    for(ll i=l;i<=mid;i++) A[i-l]=F[i];
    for(ll i=mid+1;i<=r;i++) A[i-l]=0;
    for(ll i=l;i<=r;i++) B[i-l]=G[i-l];
    N=1; while(N<r-l+1) N<<=1;
    for(ll i=r+1;i<N;i++) A[i]=B[i]=0;
    for(ll i=0;i<N;i++) ind[i]=(ind[i>>1]>>1)|((i&1)?(N>>1):0);
    ntt(A,1),ntt(B,1);
    for(ll i=0;i<N;i++) A[i]=A[i]*B[i]%mod;
    ntt(A,-1);
    for(ll i=mid+1;i<=r;i++) F[i]=(F[i]+A[i-l])%mod;
    solve(mid+1,r);
}

int main(){
    ll n=read();
    for(ll i=1;i<n;i++) G[i]=read();
    while(mxn<n) mxn<<=1;
    F[0]=1,solve(0,mxn-1);
    for(ll i=0;i<n;i++) printf("%d ",F[i]);
    printf("
");
    return 0;
}
分治NTT

牛顿迭代法:

思路:

问题大概是给定多项式函数$G(F(x))$(例如多项式开根,多项式平方之类的)。

请你构造一个$F(x)$使得$G(F(x))equiv 0 pmod {x^{n}}$(即保留$x^{0},x^{1},cdots x^{n-1}$,去掉更高次项)。

考虑类似于求逆的方法,我们从小到大倍增做,设$G(F_0 (x))equiv 0 pmod {x^{frac{n}{2}}}$。

那么将G在$F_0 (x)$处泰勒展开(变量为多项式也可以做),得到

$G(F(x))=G(F_0 (x))+G'(F_0 (x))(F(x)-F_0(x))+frac{G''(F_0 (x))(F(x)-F_0 (x))^{2}}{2}+cdots$

注意到由于最后$G(F(x))$的每一项系数都得是0,而$G(F_0(x))$已有的${frac{n}{2}}$项系数肯定也得是0。

那么由答案的唯一性,$F(x)$和$F_0(x)$在前${frac{n}{2}}$项是相同的,也就是$(F(x)-F_0 (x))^{2}equiv 0 pmod{x^{n}}$。

于是实际上是$G(F(x))equiv G(F_0 (x))+G'(F_0 (x))(F(x)-F_0 (x)) pmod{x^{n}}$。

由于$G(F(x))equiv 0 pmod{x^{n}}$,可以得到$F(x)equiv F_0 (x)-frac{G(F_0 (x))}{G'(F_0 (x))}pmod{x^{n}}$。

复杂度$O(nlog{n})$,跟求逆一样。

应用:

1.多项式求逆:

给定$A(x)$,求$F(x)A(x)equiv 1pmod{x^{n}}$。

令$G(F(x))=F(x)A(x)-1$,根据上面牛顿迭代法的公式推一下,有

$F(x)equiv F_0 (x)-frac{G(F_0 (x))}{G'(F_0 (x))}pmod{x^{n}}$

$equiv F_0 (x)-frac{F_0(x)A(x)-1}{A(x)}pmod{x^{n}}$

$equiv F_0 (x)-(F_0(x)A(x)-1)F_0(x)pmod{x^{n}}$

$equiv 2F_0(x)-F_0 ^{2}(x)A(x)pmod{x^{n}}$

于是直接递归即可。

#include<bits/stdc++.h>
#define maxn 5000005
#define maxm 500005
#define inf 0x7fffffff
#define mod 998244353
#define g 3
#define ll long long
#define rint register ll
#define debug(x) cerr<<#x<<": "<<x<<endl
#define fgx cerr<<"--------------"<<endl
#define dgx cerr<<"=============="<<endl

using namespace std;
ll A[maxn],B[maxn],ind[maxn],F[maxn],G[maxn],N;

inline ll read(){
    ll x=0,f=1; char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    return x*f;
}

inline ll power(ll a,ll b){
    ll ans=1;
    while(b){
        if(b&1) ans=ans*a%mod;
        a=a*a%mod,b>>=1;
    }
    return ans;
}

inline void ntt(ll *a,ll op){
    for(ll i=0;i<N;i++)
        if(ind[i]<i)
            swap(a[ind[i]],a[i]);
    for(ll l=2;l<=N;l<<=1){
        ll p=power(g,(mod-1)/l);
        if(op==-1) p=power(p,mod-2);
        for(ll i=0;i<N;i+=l)
            for(ll j=i,w=1;j<i+(l>>1);j++,w=w*p%mod){
                ll x=a[j],y=w*a[j+(l>>1)]%mod;
                a[j]=(x+y)%mod,a[j+(l>>1)]=(x-y+mod)%mod; 
            }
    }
    if(op==-1){
        ll inv=power(N,mod-2);
        for(ll i=0;i<N;i++) a[i]=a[i]*inv%mod;
    }
}

inline void solve(ll n){
    if(n==1){B[0]=power(F[0],mod-2);return;}
    solve((n+1)>>1),N=1; while(N<=(n<<1)) N<<=1;
    for(ll i=0;i<N;i++) ind[i]=(ind[i>>1]>>1)|((i&1)?(N>>1):0);
    for(ll i=0;i<n;i++) A[i]=F[i];
    for(ll i=n;i<N;i++) A[i]=0;
    ntt(A,1),ntt(B,1);
    //for(int i=0;i<N;i++) cout<<A[i]<<" "<<B[i]<<endl;
    for(ll i=0;i<N;i++) B[i]=B[i]*(2-A[i]*B[i]%mod+mod)%mod;
    ntt(B,-1); 
    for(ll i=n;i<N;i++) B[i]=0;
}

int main(){
    ll n=read();
    for(ll i=0;i<n;i++) F[i]=read();
    solve(n);
    for(ll i=0;i<n;i++) printf("%d ",B[i]);
    printf("
");
    return 0;
}
多项式求逆

2.多项式开根:

给定$A(x)$,求满足$F^{2}(x)equiv A(x) pmod{x^{n}}$的$F(x)$。

令$G(F(x))=F^{2}(x)-A(x)$,套用上面的做法即可。

3.多项式除法:

给定$A(x),B(x)$,求$C(x),D(x)$满足$A(x)=B(x)C(x)+D(x)$。

如果没有D那就是直接求逆了,所以我们考虑怎么把D消掉。

令$A^{R}(x)=x^{n}A(frac{1}{x})$(其中n为A的最高次项),显然$A^{R}(x)$表示将$A(x)$的所有系数翻转形成的多项式。

显然原式翻转之后也成立,于是有$A^R(x)=B^R(x)C^R(x)+x^{n-m+1}D^R(x)$。

那么放到模$x^{n-m+1}$的意义下,就是求$C^R(x)equiv A^R(x)(B^R(x))^{-1}pmod {x^{n-m+1}}$。

直接求逆做出C然后减一下即可。

4.多项式ln:

给定$F(x)$,求$G(x)=ln F(x)$。

直接做不会做,考虑两边求导,有$G'(x)=frac{F'(x)}{F(x)}$。

求导的逆运算是积分,于是$G(x)=int frac{F'(x)}{F(x)}$。

做个多项式除法然后用高中数学方法积分即可。

5.多项式exp:

给定$A(x)$,求$F(x)=e^{A(x)}$。

直接做不会做,考虑两边取对数,有$ln {F(x)} -A(x)=0$。

令$G(F(x))=ln {F(x)} -A(x)$,套用牛顿迭代法即可。

拉格朗日插值:

求一个穿过n+1个点的n次多项式。

构造可得$f(x)=sum limits_{i=1}^{n+1}{y_{i}prod limits_{j eq i}{frac{x-x_{j}}{x_{i}-x_{j}}}}$。

#include<bits/stdc++.h>
#define maxn 100005
#define maxm 500005
#define inf 0x7fffffff
#define mod 998244353
#define ll long long
#define rint register ll
#define debug(x) cerr<<#x<<": "<<x<<endl
#define fgx cerr<<"--------------"<<endl
#define dgx cerr<<"=============="<<endl

using namespace std;
struct node{ll x,y;}P[maxn];

inline ll read(){
    ll x=0,f=1; char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    return x*f;
}

inline ll power(ll a,ll b){
    ll ans=1;
    while(b){
        if(b&1) ans=ans*a%mod;
        a=a*a%mod,b>>=1;
    }
    return ans;
}
inline ll inv(ll x){return power(x%mod,mod-2);}

int main(){
    ll n=read(),k=read();
    for(ll i=1;i<=n;i++) P[i].x=read(),P[i].y=read();
    ll res=0;
    for(ll i=1;i<=n;i++){
        ll mul=1;
        for(ll j=1;j<=n;j++){
            if(i==j) continue;
            mul=mul*(k-P[j].x+mod)%mod*inv(P[i].x-P[j].x+mod)%mod;
        }
        res=(res+mul*P[i].y%mod)%mod;
    }
    printf("%lld
",res);
    return 0;
}
拉格朗日插值
原文地址:https://www.cnblogs.com/YSFAC/p/13041515.html