【模板】常系数线性递推

问题描述

求一个满足 $K$ 阶齐次线性递推数列 $a_i$ 的第 $n$ 项,即:$a_n = sum_{i=1}^k f_i imes a_{n-i}$.

分析

  首先写成矩阵快速幂

$$left( egin{bmatrix} f_1 &f_2 &f_3 &f_4 & cdots &f_{k-2} &f_{k-1} \ 1 &0 &0 &0 & cdots &0 &0 \ 0 &1 &0 &0 & cdots &0 &0\ cdots & cdots& cdots & cdots & cdots & cdots & cdots\ 0 &0 &0 &0 & cdots &1 &0 end{bmatrix} ight) ^n imes egin{bmatrix} a_{k-1} \ a_{k-2} \ cdots \ a_{1} \ a_{0}end{bmatrix} =egin{bmatrix} a_{n+k-1} \ a_{n+k-2} \ cdots \ a_{n+1} \ a_{n}end{bmatrix}$$

所以我们只需要算出 $M^N imes A$,然后取最后一个数即可。

使用矩阵快速幂,复杂度 $O(k^3 log_2n)$.

  Carlay-Hamilton定理

设有 $k$ 个特征值的矩阵 $A$的特征多项式为 $f(lambda ) =prod_{i=1}^k(lambda_i - x)$,则有 $f(A) = 0$,$0$  为零矩阵。

  用这个定理来优化递推

由前面的矩阵快速幂,我们只要求出 $M^n$就可以了。

我们考虑 $M$ 的特征多项式 $f(x)$,这是一个 $k$ 次多项式。我们对 $M^n$ 做带余除法 $M^n = f(M) imes g(M) + R(M)$。

由于 $f(M) = 0$,所以 $M^n equiv  R(M) (mod f(M))$,$R(M)$ 是一个次数不超过 $k-1$ 的多项式。

也就是说,我们只要求出 $M^n \% f(M)$就可以了

但是要怎么求呢?我们考虑快速幂的过程(就是倍增)

假设我们现在已知 $g(M)=M^{2^i} \% f(M)$,现在要求  $h(M)= M^{2^{i+1}} \% f(M)$。

一个直接的想法是令 $H(M)=g(M) imes g(M)$。但是这样做 $H(x)$ 的次数是 $2k-2$次的。

那么我们考虑原本的递推关系,$a_n=sumlimits_{i=1}^{k}a_{n-i}*f_i$

不难得到 $M^n=sumlimits_{i=1} ^{k} M^{n-i} imes f_{i}$

所以我们可以用这个式子将多余的系数都向前压一位。

这样我们可以得到一个 $O(k^2  log_2 n)$ 的做法。

那么有没有优化的余地呢?我们从倍增的过程入手,可以发现 $H(M) = g(M) imes g(M)$ 的过程可以用FFT/NTT加速至 $O(k log_2k)$。

 现在只要解决压系数就可以了,把 $H(M)$ 模 $f(M)$ 即可。

我们的推导一直用到这个特征多项式 $f(x)$,如何求得呢?

根据定义, $f(x) = det(xI - M)$,得到

$$f(x) = |x I - M| = egin{bmatrix} x- a_1 & -a_2 & -a_3 & cdots & -a_{k - 2} & -a_{k - 1} & -a_k \ -1 & x & 0 & cdots & 0 & 0 & 0 \ 0 & -1 & x & cdots & 0 & 0 & 0 \ 0 & 0 & -1 & cdots & 0 & 0 & 0 \ vdots & vdots & vdots & ddots & vdots & vdots & vdots \ 0 & 0 & 0 & cdots & -1 & x & 0 \ 0 & 0& 0 & cdots & 0 & -1 & xend{bmatrix}$$

对第一行进行展开,得到

$$f(x) = (x - a_1)M_{11} + (-a_2)M_{12} + cdots + (-a_k)M_{1n} = x ^ k - a_1 x ^ {k - 1} - a_2x ^ {k - 2} - cdots - a_k$$

代码1:

$O(k log_2k log_2n)$的做法

思路其实就是去做一个类似快速幂的操作,然后把乘法改成多项式下的,取模也改成多项式下的

// luogu-judger-enable-o2
#include<cstdio>
#include<algorithm>
using namespace std;

typedef long long ll;
const ll mod=998244353;
const int N=65536+10;
int n;int k;int rv[20][N];ll rt[20][20];int Len;ll tr1[N];ll tr2[N];long long st[N];long long xs[N];
ll sg[N];ll a[N];ll res[N];ll irg[N];ll q[N];ll rf[N];int DL=-1;ll ans=0;ll ret[N];
inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;}
inline void ntt(ll* a,int o,int len,int d)//ntt
{
    for(int i=0;i<len;i++)if(i<rv[d][i])swap(a[i],a[rv[d][i]]);
    for(int k=1,j=1;k<len;k<<=1,j++)
        for(int s=0;s<len;s+=(k<<1))
            for(int i=s,w=1;i<s+k;i++,w=w*rt[o][j]%mod)
            {ll a0=a[i];ll a1=a[i+k]*w%mod;a[i]=(a0+a1)%mod,a[i+k]=(a0+mod-a1)%mod;}
    if(o==1){ll inv=po(len,mod-2);for(int i=0;i<len;i++)(a[i]*=inv)%=mod;}
}
inline void poly_inv(ll* a,ll* b,int len)//求逆
{
    b[0]=po(a[0],mod-2);
    for(int k=1,j=0;k<=len;k<<=1,j++)
    {
        for(int i=0;i<k;i++)tr1[i]=a[i];for(int i=0;i<k;i++)tr2[i]=b[i];
        ntt(tr1,0,k<<1,j);ntt(tr2,0,k<<1,j);
        for(int i=0;i<(k<<1);i++)b[i]=tr2[i]*(2+mod-tr1[i]*tr2[i]%mod)%mod;
        ntt(b,1,k<<1,j);for(int i=k;i<(k<<1);i++)b[i]=0;
    }
}
inline void poly_mod(ll* a)//取模
{
    int mi=(k<<1);while(a[--mi]==0);if(mi<k)return;
    for(int i=0;i<(Len<<1);i++)rf[i]=0;for(int i=0;i<=mi;i++)rf[i]=a[i];
    reverse(rf,rf+mi+1);for(int i=mi-k+1;i<=mi;i++)rf[i]=0;ntt(rf,0,Len<<1,DL+1);
    for(int i=0;i<(Len<<1);i++)q[i]=(rf[i]*irg[i])%mod;ntt(q,1,(Len<<1),DL+1);
    for(int i=mi-k+1;i<=(Len<<1);i++)q[i]=0;reverse(q,q+mi-k+1);ntt(q,0,(Len<<1),DL+1);
    for(int i=0;i<(Len<<1);i++)(q[i]*=sg[i])%=mod;ntt(q,1,(Len<<1),DL+1);
    for(int i=0;i<k;i++)(a[i]+=mod-q[i])%=mod;for(int i=k;i<=mi;i++)a[i]=0;
}
int main()
{
    for(int i=0;i<=15;i++)
        for(int j=0;j<(1<<(i+1));j++)rv[i][j]=(rv[i][j>>1]>>1)|((j&1)<<i);
    for(int t=2,j=1;j<=18;t<<=1,j++)rt[0][j]=po(3,(mod-1)/t);
    for(int t=2,j=1;j<=18;t<<=1,j++)rt[1][j]=po(332748118,(mod-1)/t);
    scanf("%d%d",&n,&k);
    for(Len=1;Len<=k;Len<<=1,DL++); //预处理
    for(int i=1;i<=k;i++){scanf("%lld",&xs[i]);xs[i]=xs[i]<0?xs[i]+mod:xs[i];}
    for(int i=0;i<k;i++){scanf("%lld",&st[i]);st[i]=st[i]<0?st[i]+mod:st[i];}
    for(int i=1;i<=k;i++)sg[k-i]=mod-xs[i];sg[k]=1;for(int i=0;i<=k;i++)ret[i]=sg[i];
    for(int i=0;i<=k;i++)rf[i]=sg[i];reverse(rf,rf+k+1);poly_inv(rf,irg,Len);
    for(int i=0;i<=k;i++)rf[i]=0;ntt(sg,0,Len<<1,DL+1);ntt(irg,0,Len<<1,DL+1);a[1]=1;res[0]=1;
    while(n)//快速幂
    {
        if(n&1)
        {
            ntt(res,0,Len<<1,DL+1);ntt(a,0,Len<<1,DL+1);
            for(int i=0;i<(Len<<1);i++)(res[i]*=a[i])%=mod;
            ntt(res,1,Len<<1,DL+1);ntt(a,1,Len<<1,DL+1);poly_mod(res);
        }ntt(a,0,Len<<1,DL+1);for(int i=0;i<(Len<<1);i++)(a[i]*=a[i])%=mod;
        ntt(a,1,Len<<1,DL+1);poly_mod(a);n>>=1;
    }
    for(int i=0;i<k;i++)(ans+=res[i]*st[i])%=mod;
    printf("%lld",ans);
    return 0;
}

 代码2:

$O(k^2 log_2n)$的做法

 BZOJ4161

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cassert>
typedef long long ll;
typedef unsigned long long ull;
using namespace std;

const int P=1000000007;
const int MAXN=4010;    //2*k+10

int n,k,ans;
int f[MAXN],h[MAXN];

struct Matrix{ //其实是多项式
    int a[MAXN];
    Matrix (){memset(a,0,sizeof a);}
    int& operator [] (const int &i) {return a[i];}
    int operator [] (const int &i) const {return a[i];}
    inline Matrix operator * (const Matrix &rhs) const
    {
        Matrix ret;
        for(int i=0;i<k;i++)
            for(int j=0;j<k;j++)
                (ret[i+j]+=1ll*a[i]*rhs[j]%P)%=P;
        for(int i=2*k-2;i>=k;ret[i--]=0)
            for(int j=1;j<=k;j++) //这里就是多项式取模优化的地方
                (ret[i-j]+=1ll*ret[i]*f[j]%P)%=P; //可以认为是暴力向前压系数
        return ret;
    }
}res;

Matrix ksm(Matrix a,int b)
{
    Matrix ret;
    ret[0]=1;
    for(;b;a=a*a,b>>=1) if(b&1) ret=ret*a;
    return ret;
}

int main()
{
    scanf("%d%d",&n,&k);
    for(int i=1;i<=k;i++) scanf("%d",&f[i]),f[i]=f[i]>0?f[i]:f[i]+P;
    for(int i=0;i<k;i++) scanf("%d",&h[i]),h[i]=h[i]>0?h[i]:h[i]+P;
    if(n<k) printf("%d
",h[n]);
    res[1]=1;ans=0;
    res=ksm(res,n);
    for(int i=0;i<k;i++)  ans=(ans+1ll*res[i]*h[i]%P)%P;
    printf("%d
",ans);
}

参考链接:

1. https://www.luogu.org/problemnew/solution/P4723

2. https://www.luogu.org/blog/Zhang-RQ/chang-ji-shuo-ji-ci-xian-xing-di-tui-chu-tan

原文地址:https://www.cnblogs.com/lfri/p/11236711.html