常系数齐次线性递推

常系数齐次线性递推

定义

对于一个递推式,如果 (a_n = displaystyle sum_{i=1}^{k}{a_{n-i}*f_i}) ,那么称这个 (a) 序列满足 (n) 阶常系数齐次线性递推关系

矩阵优化

如果我们已知一个满足 (k) 阶常系数齐次线性递推关系的序列 (a) ,关系式为 (a_n = displaystyle sum_{i=1}^{k}{a_{n-i} * f_i}) ,要求出 (a_n) 的值

可以设计出一个转移矩阵进行矩阵优化

如果初始阵为

[A= egin{pmatrix} a_{n-1}\ a_{n-2}\ vdots\ a_{n-k} end{pmatrix} ]

转移阵为

[M= egin{pmatrix} f_1 quad &f_2 quad &f_3 quad &dots quad &f_{k-1}\ 1 quad &0 quad &0 quad &dots quad &0\ 0 quad &1 quad &0 quad &dots quad &0\ vdots quad &vdots quad &vdots quad &ddots quad &vdots \ 0 quad &0 quad &0 quad &dots &1 end{pmatrix} ]

那么 (M imes A) 可以得到矩阵

[egin{pmatrix} a_{n}\ a_{n-1}\ vdots\ a_{n-k-1} end{pmatrix} ]

那么我们可以设计初始矩阵为

[A= egin{pmatrix} a_{k-1}\ a_{k-2}\ vdots\ a_{0} end{pmatrix} ]

此时我们可以用 (M^n imes A) 来得到我们需要的矩阵

特征多项式

  • 若有常数 (lambda) ,向量 (vec{v}) ,满足 (lambda vec{v} = A vec{v}) ,那么我们称 (lambda) 为矩阵 (A) 的特征值,称 (vec{v}) 为矩阵的特征向量

那么我们可以得到 ((lambda I - A) vec{v}= 0) ,其中 (0) 表示零矩阵

此时该式有解当且仅当 (det(lambda I - A) = 0)

这个行列式的展开形式为一个 (k) 次多项式,此时,我们称这个 (k) 次多项式为 (A) 的特征多项式,该多项式的值为 (0) 时的方程称为 (A) 的特征方程

记特征多项式为 (f(x) = det(lambda I - A)) ,那么可以表示为 (f(x) = displaystyle prod_{i}{lambda_i - x})

凯莱-哈密顿定理 (Cayley-Hamilton定理)

  • 对于 (A) 的特征多项式 (f(x)) ,有 (f(A) = 0)

证明

(f(A) =displaystyle prod_{i}{lambda_i I - A})

对于这个 (k) 次的特征多项式,其有 (k) 个解,也就是说矩阵 (A)(k) 个特征值以及 (k) 个线性无关的特征向量,而如果 (f(A)) 得到的矩阵乘上任意一个特征向量都可以得到零矩阵,那么就可以推出 (f(A)) 为零矩阵

首先,可以证明, ((lambda_i I - A)(lambda_j I - A) = (lambda_j I - A)(lambda_i I - A))

那么

[egin{aligned} f(A) imes vec{v_i} &= (displaystyle prod_{j}{lambda_j I - A}) imes vec{v_i} \ &= (displaystyle prod_{j eq i}{lambda_j I - A}) imes ((lambda_i I - A) imes vec{v_i}) end{aligned} ]

由特征值与特征向量的定义式可知: ((lambda_i I - A) vec{v_i} = 0)

所以 (forall f(A) imes vec{v_i} =0)

得证

常系数齐次线性递推优化

设矩阵 (M) 的特征多项式为 (f(x))

对于我们要求的 (M^n) ,可以写出

[M^n = f(M) imes g(M) + R(M) ]

(f(M)=0) ,那么就有 (M^n = R(M))

所以,我们只需要做 (M^n ~\% ~f(M)) 就可以了

考虑 (f(M)) 怎么求

按照定义 (f(x) = det(x I - M)) ,所以这里有

[f(x)= egin{vmatrix} x-a_1 quad &-a_2 quad &-a_3 quad &dots quad &-a_{k-1}quad &-a_{k}\ -1 quad &x quad &0 quad &dots quad &0 quad &0\ 0 quad &-1 quad &x quad &dots quad &0 quad &0\ vdots quad &vdots quad &vdots quad &ddots quad &vdots quad &vdots\ 0 quad &0 quad &0 quad &dots &-1 quad &x end{vmatrix} ]

将其进行展开,有

[egin{aligned} f(x) &= displaystyle (x-a_1)M_{11} - a_2 M_{12} dotsb - a_k M_{1k}\ &= x^k - a_1 x^{k-1} - a_2 x^{k-2} - dotsb a_k end{aligned} ]

处理 (M^n ~\%~ f(M)) 我们可以在做快速幂的时候进行实现,所以这里的实现只需要在快速幂的时候做多项式取模即可,复杂度为 (O(k^2 log n))

而这里我们做快速幂的时候还会涉及多项式乘法,那么可以进行 NTTFFT 优化,做到 (O(k log k log n))

那么我们这里已经快速处理出了 (M^n) ,之后直接和初始的矩阵 (A) 相乘即可求得答案

例题

首先,恰好 (K) 个的概率不容易处理,可以考虑将其处理为至少有 (K) 个的概率减去至少有 (K-1) 个的概率

(f_i) 表示在底部的一个宽为 (i) 的矩形,并且第 (i) 个位置恰好为不合法的位置

那么最终答案就是 (frac{f_{n+1}}{1-q})

这里有 (f_n = displaystyle sum_{i=1}^{n}{f_{n-i+1} * g_i}) ,这里 (g_i) 表示出现长度宽为 (i) 的矩形的概率

(dp_{i,j}) 表示一个宽为 (i) ,高位 (j) 的矩形

那么这里 (g_i = displaystyle sum_{j=1}^{infty}{dp_{i,j}})

而这个 (dp_{i,j}) 实际上也是可以递推的,有递推式为

[dp_{i,j} = [i*(j-1) leq K] (1-q) q^{j-1} displaystyle sum_{k=1}^{i}{(displaystyle sum_{q > j} dp_{k-1,q})(displaystyle sum_{q geq j}{dp_{i-k,q}})} ]

表示 (dp_{i,j}) 可以由宽 (k-1) 中那些高度大于 (j) 的矩形的情况在和宽 (i-k) ,高大于等于 (j) 的那些矩形拼起来,再乘上当前宽度为 (i) 的这个地方的高度只有 (j) 的部分的概率

这样,这里 (i imes (j-1) leq K) ,所以对 (i,j) 的枚举的复杂度为 (O(K log K)) ,再加上枚举 (k,q) 的枚举,复杂度为 (O(K^2 log^2 K))

而这个式子中对 (q) 枚举的部分是可以后缀和优化的(并且在 (f) 的求解中应用),那么此时求 (dp) 数组的复杂度可以被优化到 (O(k^2 log k))

(f) 数组的求解显然满足常系数其次线性递推的形式,可以直接套用优化,那么总复杂度为 (O(k^2 log k)) (完全没有必要用 FFTNTT 优化,直接暴力做多项式取模即可)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#include<cstring>
#define ll long long
#define ld long double

inline ll read()
{
	ll x=0,f=1;
	char ch=getchar();
	while(!isdigit(ch))
	{
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(isdigit(ch))
	{
		x=(x<<1)+(x<<3)+ch-'0';
		ch=getchar();
	}
	return x*f;
}

const ll inf=1e18;
const ll maxn=2e3+10;
const ll mod=998244353;
ll N,K,X,Y,p,q;
ll pw[maxn];
ll dp[maxn][maxn],sum[maxn][maxn];
ll I[maxn],A[maxn],M[maxn],ret[maxn],f[maxn];
ll tmp1[maxn],tmp2[maxn];

inline ll ksm(ll a,ll b,ll p)
{
	ll ret=1;
	while(b)
	{
		if(b&1) ret=ret*a%p;
		a=a*a%p;
		b>>=1;
	}
	return ret;
}

inline ll sol(ll x)
{
	ll ans=0;
	memset(M,0,sizeof(M));
	memset(A,0,sizeof(A));
	memset(f,0,sizeof(f));
	memset(I,0,sizeof(I));
	memset(dp,0,sizeof(dp));
	memset(sum,0,sizeof(sum));
	memset(ret,0,sizeof(ret));
	for(int i=0;i<=x+2;i++) sum[0][i]=dp[0][i]=1;
	for(int j=x;j>=1;j--)
	{
		for(int i=1;i*j<=x;i++)
		{
			for(int k=1;k<=i;k++)
			{
				(dp[i][j]+=sum[k-1][j+1]*sum[i-k][j]%mod*p%mod*pw[j]%mod)%=mod;
			}
			sum[i][j]=(sum[i][j+1]+dp[i][j])%mod;
		}
	}
//	for(int j=1;j<=x;j++)
//	{
//		for(int i=1;i*j<=x;i++)
//		{
//			printf("%d %d %lld %lld
",i,j,dp[i][j],sum[i][j]);
//		}
//	}
	x++;
	for(int i=1;i<=x;i++) I[i]=sum[i-1][1]*p%mod;
	A[0]=1;
	for(int i=1;i<=x;i++)
	{
		for(int j=0;j<i;j++)
		{
			(A[i]+=A[j]*I[i-j]%mod)%=mod;
		}
	}
	for(int i=1;i<=x;i++) f[x-i]=mod-I[i];
	f[x]=1;
//	for(int i=0;i<=x;i++) printf("%lld ",I[i]);
//	putchar(10);
//	for(int i=0;i<=x;i++) printf("%lld ",A[i]);
//	putchar(10);
//	for(int i=0;i<=x;i++) printf("%lld ",f[i]);
//	putchar(10);
	ret[0]=1;
	M[1]=1;
	ll b=N+1;
	while(b)
	{
		if(b&1)
		{
			memcpy(tmp1,ret,sizeof(ret));
			memset(ret,0,sizeof(ret));
			for(int i=0;i<=x;i++)
			{
				for(int j=0;j<=x;j++)
				{
					(ret[i+j]+=M[i]*tmp1[j])%=mod;
				}
			}
			for(int i=2*x;i>=x;i--)
			{
				for(int j=0;j<=x;j++)
				{
					(ret[i+j-x]+=mod-ret[i]*f[j]%mod)%=mod;
				}
			}
		}
		memcpy(tmp1,M,sizeof(M));
		memcpy(tmp2,M,sizeof(M));
		memset(M,0,sizeof(M));
		for(int i=0;i<=x;i++)
		{
			for(int j=0;j<=x;j++)
			{
				(M[i+j]+=tmp1[i]*tmp2[j]%mod)%=mod;
			}
		}
		for(int i=2*x;i>=x;i--)
		{
			for(int j=0;j<=x;j++)
			{
				(M[i+j-x]+=mod-M[i]*f[j])%=mod;
			}
		}
		b>>=1;
	}
	for(int i=0;i<=x;i++) (ans+=ret[i]*A[i])%=mod;
//	printf("%lld
",ans);
	return ans*ksm(p,mod-2,mod)%mod;
}

int main(void)
{
//	freopen("1.in","r",stdin);
//	freopen("1.ans","w",stdout);
	N=read(),K=read(),X=read(),Y=read();
	q=X*ksm(Y,mod-2,mod)%mod;
	p=(1-q+mod)%mod;
	pw[0]=1;
	for(int i=1;i<=K;i++) pw[i]=pw[i-1]*q%mod;
	printf("%lld
",(sol(K)-sol(K-1)+mod)%mod);
	return 0;
}
原文地址:https://www.cnblogs.com/jd1412/p/15220789.html