多项式取模优化线性递推总结

多项式取模优化线性递推总结

声明:博主已退役,这是以前的总结,如有错误望指正,如有问题不妨看看别人的博客

线性递推

即对于数列({a})

已知前(k)

且对于任意(nge k)

[a_n=sum_{i=0}^{k-1}f_ia_{n-1-i} ]

其中({f})是一个已知的数列

现在要求({a})的第(n)

暴力是(O(n*k))

如果(n)太大就会超时

常用的优化方法是矩阵快速幂

复杂度(O(k^3log n))

但如果(k)比较大也会超时

甚至还不如暴力

(log n)已经很优秀了

但是(k^3)实在太慢

注意到根据上面的式子

({a})所有数都可以被({a_0,a_1,...,a_{k-1}})线性表示

考虑已知(a_n)的线性表示如何求出(a_{2n})的线性表示

这里应用一个性质

[a_{n}=sum_{i=0}^{k-1}b_ia_i\ ]

[a_{n+x}=sum_{i=0}^{k-1}b_ia_{i+x} ]

所以

[a_{2n}=sum_{i=0}^{k-1}b_ia_{n+i}\ =sum_{i=0}^{k-1}b_isum_{j=0}^{k-1}b_ja_{i+j}\ =sum_{i=0}^{2k-2}a_isum_{j=0}^{i}b_jb_{i-j}\ (这里令b_x=0(xge k)) ]

这样就用({a_0,a_1,...,a_{2k-2}})线性表示了(a_{2n})

只要知道({a_k,a_{k+1},...,a_{2k-2}})的线性表示然后带入即可

这一步倒着依次带入

复杂度优化为(O(k^2log n))

Shlw loves matrix I

#include<bits/stdc++.h>

using namespace std;

#define gc c=getchar()
#define r(x) read(x)
#define ll long long

template<typename T>
inline void read(T&x){
    x=0;T k=1;char gc;
    while(!isdigit(c)){if(c=='-')k=-1;gc;}
    while(isdigit(c)){x=x*10+c-'0';gc;}x*=k;
}

const int p=1000000007;
const int N=2000;

inline int add(int a,int b){
    a+=b;
    if(a>=p)a-=p;
    return a;
}

int n,k;

int Tmp[N<<1];

inline void mul(int* a,int *b,int* f){
    memset(Tmp,0,k<<3);
    for(int i=0;i<k;++i){
        for(int j=0;j<k;++j){
            Tmp[i+j]=add(Tmp[i+j],(ll)a[i]*b[j]%p);
        }
    }
    for(int i=(k<<1)-2;i>=k;--i){
        for(int j=0;j<k;++j){
            Tmp[i-j-1]=add(Tmp[i-j-1],(ll)Tmp[i]*f[j]%p);
        }
    }
    memcpy(a,Tmp,k<<2);
}

int base[N],ans[N];

inline int solve(int* a,int* f,int n){
    if(n<k)return a[n];
    base[1]=ans[0]=1;
    for(;n;n>>=1){
        if(n&1)mul(ans,base,f);
        mul(base,base,f);
    }
    int ret=0;
    for(int i=0;i<k;++i)ret=add(ret,(ll)a[i]*ans[i]%p);
    return ret;
}

int a[N],f[N];

int main(){
    r(n);r(k);
    for(int i=0;i<k;++i)r(f[i]),f[i]=add(f[i],p);
    for(int i=0;i<k;++i)r(a[i]),a[i]=add(a[i],p);
    printf("%d
",solve(a,f,n));
}

多项式取模在哪里?

考虑上面代码的这一部分

for(int i=(k<<1)-2;i>=k;--i){
    for(int j=0;j<k;++j){
        Tmp[i-j-1]=add(Tmp[i-j-1],(ll)Tmp[i]*f[j]%p);
    }
}

考虑消去第(n)位的时候

相当于把多项式({-f_{k-1},-f_{k-2},...,-f_0,1})平移了(n-k)

并从原数列中减去它的(Tmp_n)

所以这段代码实际上是在对多项式({-f_{k-1},-f_{k-2},...,-f_0,1}​)取模

于是复杂度可以优化至(O(klog k log n))

【模板】线性递推

#include<bits/stdc++.h>

using namespace std;

#define gc c=getchar()
#define r(x) read(x)
#define ll long long

template<typename T>
inline void read(T&x){
    x=0;T k=1;char gc;
    while(!isdigit(c)){if(c=='-')k=-1;gc;}
    while(isdigit(c)){x=x*10+c-'0';gc;}x*=k;
}

const int N=500000;
const int p=998244353;
const int g=3;

inline int qpow(int a,int b){
    int ans=1;
    for(;b;b>>=1){
        if(b&1)ans=1ll*ans*a%p;
        a=1ll*a*a%p;
    }
    return ans;
}

namespace polynomial{
    int r[N];
    int NOW_LEN;
    inline void ntt(int *A,int len,int opt=1){
        if(len!=NOW_LEN)for(int i=0;i<len;++i)r[i]=(r[i>>1]>>1)|((i&1)*(len>>1));
        NOW_LEN=len;
        for(int i=0;i<len;++i)if(i<r[i])swap(A[i],A[r[i]]);
        for(int i=2;i<=len;i<<=1){
            int wn=qpow(g,(p-1)/i),n=i>>1;
            if(!opt)wn=qpow(wn,p-2);
            for(int j=0;j<len;j+=i){
                int w=1;
                for(int k=0;k<n;++k,w=1ll*w*wn%p){
                    int u=A[j+k],v=1ll*A[j+k+n]*w%p;
                    A[j+k]=(u+v)%p;
                    A[j+k+n]=(u-v+p)%p;
                }
            }
        }
        if(!opt){
            int inv=qpow(len,p-2);
            for(int i=0;i<len;++i)A[i]=1ll*A[i]*inv%p;
        }
    }
    
    int Tmp_mul1[N],Tmp_mul2[N];
    inline void mul(int *A,int *B,int *C,int lenA,int lenB){
        int len=1,lenC=lenA+lenB-1;
        while(len<lenC)len<<=1;
        memcpy(Tmp_mul1,A,lenA<<2);
        memcpy(Tmp_mul2,B,lenB<<2);
        memset(Tmp_mul1+lenA,0,(len-lenA)<<2);
        memset(Tmp_mul2+lenB,0,(len-lenB)<<2);
        ntt(Tmp_mul1,len);ntt(Tmp_mul2,len);
        for(int i=0;i<len;++i)C[i]=1ll*Tmp_mul1[i]*Tmp_mul2[i]%p;
        ntt(C,len,0);
        memset(C+lenC,0,(len-lenC)<<2);
    }
    
    int Tmp_inv[N];
    inline void inverse(int *A,int *Inv,int len){
        memset(Inv,0,len<<2);
        Inv[0]=qpow(A[0],p-2);
        for(int i=2;i<=len;i<<=1){
            memcpy(Tmp_inv,A,i<<2);
            memset(Tmp_inv+i,0,i<<2);
            ntt(Inv,i<<1);ntt(Tmp_inv,i<<1);
            for(int k=0;k<i<<1;++k)Inv[k]=Inv[k]*(2-1ll*Inv[k]*Tmp_inv[k]%p+p)%p;
            ntt(Inv,i<<1,0);
            memset(Inv+i,0,i<<2);
        }
    }
    
    int A0[N],B0[N];
    inline void mod(int A[],int B[],int R[],int lenA,int lenB){
        int len=1,t=lenA-lenB+1;
        while(len<=t)len<<=1;
        reverse_copy(B,B+lenB,A0);
        inverse(A0,B0,len);
        reverse_copy(A,A+lenA,A0);
        mul(A0,B0,A0,t,t);
        reverse(A0,A0+t);
        for(len=1;len<(lenA<<1);len<<=1);
        copy(B,B+lenB,B0);
        mul(A0,B0,R,t,lenB);
        for(int i=0;i<lenB-1;++i)R[i]=(A[i]-R[i]+p)%p;
    }
}

int n,k;

int Tmp[N<<1];

inline void mul(int a[],int b[],int f[]){
    polynomial::mul(a,b,Tmp,k,k);
    polynomial::mod(Tmp,f,a,2*k,k+1);
}

int base[N],ans[N];

inline int solve(int a[],int f[],int n){
    if(n<k)return a[n];
    
    reverse(f,f+k);
    for(int i=0;i<k;++i)f[i]=p-f[i];
    f[k]=1;
    
    base[1]=ans[0]=1;
    for(;n;n>>=1){
        if(n&1)mul(ans,base,f);
        mul(base,base,f);
    }
    int ret=0;
    for(int i=0;i<k;++i)ret=(ret+(ll)a[i]*ans[i]%p)%p;
    return ret;
}

int a[N],f[N];

int main(){
//	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
    r(n);r(k);
    for(int i=0;i<k;++i)r(f[i]),f[i]=(f[i]+p)%p;
    for(int i=0;i<k;++i)r(a[i]),a[i]=(a[i]+p)%p;
    printf("%d
",solve(a,f,n));
}
原文地址:https://www.cnblogs.com/yicongli/p/11143002.html