洛谷 4723 【模板】线性递推——常系数线性齐次递推

题目:https://www.luogu.org/problemnew/show/P4723

题解:https://www.luogu.org/problemnew/solution/P4723

特征多项式:( f(x) = x^k - sumlimits_{i=1}^{k}f_i x^{n-i} )

这个多项式是转移矩阵 M 的化零多项式,所以 ( M^n ) 可以对该多项式取模,从而化成 ( M^n = sumlimits_{i=0}^{k-1} g_i M^i )

如果知道 ( g_i ) ,考虑给上式两边乘上初始的系数矩阵,得到 ( A(x)*M^n = sumlimits_{i=0}^{k-1}g_i A(x) M^i )

左边关注的只有该行向量的第一个位置的值。把所有行向量 “(A(x)*M^i)” 都改成其第一个位置的值(据说也是成立的?),发现左边是 a[n] 、右边是 a[i] 。

所以 ( a[n] = sumlimits_{i=0}^{k-1} g_i a[i] )

考虑求出 ( g_i ) 。

其实就是多项式 ( x^n ) 对多项式 ( x^k - sumlimits_{i=1}^{k}f_ix^{n-i} ) 取模得到的多项式。

用多项式 ( x ) 快速幂乘出 ( x^n ) 。一边乘一边对该多项式取模即可。

关于多项式取模的注意事项:

  ( A(x) = G(x)B(x)+R(x) ),其中 (A(x))是 n 次,(G(x))是 m 次的模数,(B(x))是 n-m 次的商,(R(x)) 是 m-1 次的余数。

  ( A^R(x) = G^R(x)B^R(x) + x^{n-m+1}R^R(x) )

  ( A^R(x) = G^R(x)B^R(x) ( mod x^{n-m+1} ) )

  ( B^R(x) = frac{ A^R(x) }{ G^R(x) } ( mod x^{n-m+1} ) )

  注意这一步!!!

  1. (A^R(x)) 是先翻转之后再对 (x^{n-m+1}) 取模。当然不对之取模也可。

  2.可以预处理 (G^R(x)) 的逆元,求逆就是在 mod (x^{n-m+1}) 意义下的;但注意在求逆之前,原始的数组不要对 ( x^{n-m+1} ) 取模!

  3.得出的 (B^R(x)) 需要取一下模。注意是先取了模,在翻转回去得到 (B(x))

  4.各种时刻要把临时数组使劲清空!!!!!!一直清空到 len 的程度。如果 len 变化了,还要一直清空到新的 len 的范围为止!!!

  得到 (B(x)) 之后 ( R(x) = A(x) - G(x)*B(x) )

  注意这里没有取模。但是得出的 ( R(x) ) 应该是 m-1 次多项式。需要算好之后手动把 m 次项及以后的系数清零。

  注意如果没有把 ( A(x) ) 化成点值,就不要写成 for( i=0;i<len;i++ ) ta[ i ] = a[ i ] - ta[ i ]*g[ i ] ; !!!

  可以预处理 ( G(x) ) 的逆元。反正 len 只有两种。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=(1<<17)+5,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

int n,n2,m,f[N],h[N],g[N],ig[N],tp[N];
int len,r[N],wn[N],wn2[N],inv[N];
int ta[N],ret[N];
void Rev(int *a,int n)
{ int k=(n+1)>>1; for(int i=0;i<k;i++)swap(a[i],a[n-i]);}
void ntt_pre()
{
  int lm=(1<<17);
  for(int R=2;R<=lm;R<<=1)
    {
      wn[R]=pw(3,(mod-1)/R);
      wn2[R]=pw(3,(mod-1)-(mod-1)/R);
      inv[R]=pw(R,mod-2);
    }
}
void ntt_r()
{
  for(int i=0,j=len>>1;i<len;i++)
    r[i]=(r[i>>1]>>1)+((i&1)?j:0);
}
void ntt(int *a,bool fx)
{
  for(int i=0;i<len;i++)
    if(i<r[i])swap(a[i],a[r[i]]);
  for(int R=2;R<=len;R<<=1)
    {
      int Wn=(fx?wn2[R]:wn[R]);
      for(int i=0,m=R>>1;i<len;i+=R)
    for(int j=0,w=1;j<m;j++,w=(ll)w*Wn%mod)
      {
        int x=a[i+j],y=(ll)w*a[i+m+j]%mod;
        a[i+j]=upt(x+y); a[i+m+j]=upt(x-y);
      }
    }
  if(!fx)return; int iv=inv[len];
  for(int i=0;i<len;i++)a[i]=(ll)a[i]*iv%mod;
}
void get_inv(int *a,int lm)
{
  memset(ret,0,sizeof ret);
  ret[0]=pw(a[0],mod-2);
  for(int t=2,yt=1,i,j;yt<lm;yt=t,t=len)
    {
      len=t<<1; ntt_r();
      for(i=0;i<t;i++)ta[i]=a[i];for(;i<len;i++)ta[i]=0;
      ntt(ret,0); ntt(ta,0);
      for(i=0;i<len;i++)
    ret[i]=(ll)ret[i]*upt(2-(ll)ret[i]*ta[i]%mod)%mod;
      ntt(ret,1);
      for(i=t;i<len;i++)ret[i]=0;
    }
  for(int i=lm;i<len;i++)ret[i]=0;
  memcpy(a,ret,sizeof ret);
}
void Mul(int *a,int *b)//(m-1)
{
  for(len=1;len<=n2;len<<=1); ntt_r();
  memcpy(ta,b,sizeof b);
  ntt(a,0); ntt(ta,0);
  for(int i=0;i<len;i++)a[i]=(ll)a[i]*ta[i]%mod;
  ntt(a,1);
}
void get_mod(int *a)//n2 % m
{
  int d=n2-m+1;
  for(len=1;len<d<<1;len<<=1); ntt_r();
  memcpy(ta,a,sizeof 4*(n2+1));
  Rev(ta,n2);
  for(int i=d;i<=len;i++)ta[i]=0;
  //i<=len not n2//rev before mod
  //not mod is ok,but clear (n2,len)!!!
  ntt(ta,0);
  for(int i=0;i<len;i++)ta[i]=(ll)ta[i]*ig[i]%mod;
  ntt(ta,1); for(int i=d;i<len;i++)ta[i]=0;
  Rev(ta,d-1);//mod before rev

  for(len=1;len<=n2;len<<=1); ntt_r();
  for(int i=d;i<len;i++)ta[i]=0;//////new len
  ntt(ta,0);
  for(int i=0;i<len;i++)
    ta[i]=(ll)ta[i]*g[i]%mod;
    //ta[i]=upt((a[i]-(ll)ta[i]*g[i])%mod);
  ntt(ta,1);
  for(int i=0;i<m;i++)ta[i]=upt(a[i]-ta[i]);
  for(int i=m;i<len;i++)ta[i]=0;//
  memcpy(a,ta,sizeof ta);
}
int main()
{
  n=rdn();m=rdn(); n2=2*(m-1); ntt_pre();
  for(int i=1;i<=m;i++)f[i]=rdn();
  for(int i=0;i<m;i++)h[i]=upt(rdn());
  for(int i=1;i<=m;i++)g[m-i]=ig[m-i]=upt(-f[i]);
  g[m]=ig[m]=1; Rev(ig,m);
  int d=n2-m+1;
  get_inv(ig,d);
  for(len=1;len<d<<1;len<<=1); ntt_r(); ntt(ig,0);//d<<1
  for(len=1;len<=n2;len<<=1); ntt_r(); ntt(g,0);
  memset(f,0,sizeof f); f[0]=tp[1]=1;
  while(n)
    {
      if(n&1){ Mul(f,tp);get_mod(f);}
      for(len=1;len<=n2;len<<=1); ntt_r();
      ntt(tp,0);
      for(int i=0;i<len;i++)tp[i]=(ll)tp[i]*tp[i]%mod;
      ntt(tp,1); get_mod(tp); n>>=1;
    }
  int ans=0;
  for(int i=0;i<m;i++)
    ans=(ans+(ll)f[i]*h[i])%mod;
  printf("%d
",ans);
  return 0;
}
View Code
原文地址:https://www.cnblogs.com/Narh/p/10944223.html