Matrix Recurrence

给定矩阵$A,B$,且有

$$
f(0) = A ,f(i) =B * prod_{i=w(i)}^{i-1}f(i)
$$

求f(n)

其中,当w(i)单增时,可以做到$O(n*m^3)$,注意要优化取模运算。

对于加入的f(i),我们压入栈中,维护栈的 元素积。

同时维护栈之前的一段元素的后缀积,当w(i)超过非栈元素的右边界时,将栈内元素暴力化为后缀积。

 1 #include <iostream>
 2 #include <cstdio>
 3 #include <cstring>
 4 
 5 #define LL long long
 6 #define N 1000010
 7 
 8 using namespace std;
 9 
10 int P;
11 
12 int m,n;
13 
14 struct MA
15 {
16         LL a[5][5];
17         void scan()
18         {
19                 for(int i=0,j;i<m;i++)
20                         for(j=0;j<m;j++) scanf("%lld",&a[i][j]);
21         }
22         void init()
23         {
24                 memset(a,0,sizeof(a));
25                 for(int i=0;i<m;i++) a[i][i]=1;
26         }
27         void print()
28         {
29                 for(int i=0;i<m;i++)
30                 {
31                         for(int j=0;j<m;j++) printf("%lld ",a[i][j]);
32                         printf("
");
33                 }
34         }
35 }A0,B;
36 
37 MA sta[N];
38 MA pre[N];
39 MA sumv,A;
40 int c[N],tot,r;
41 
42 MA mul(MA x,MA y)
43 {
44         MA ans;
45         for(int i=0,j,k;i<m;i++)
46                 for(j=0;j<m;j++)
47                 {
48                         ans.a[i][j]=0;
49                         for(k=0;k<m;k++)
50                                 ans.a[i][j] += x.a[i][k]*y.a[k][j];
51                 }
52         for(int i=0,j;i<m;i++)
53                         for(j=0;j<m;j++) ans.a[i][j]%=P;
54         return ans;
55 }
56 
57 void build()
58 {
59         int tmp=r;
60         for(int i=1;i<=tot;i++) pre[++r]=sta[i];
61         tot=0;
62         for(int i=r-1;i>=tmp+1;i--) pre[i]=mul(pre[i], pre[i+1]);
63         sumv.init();
64 }
65 
66 int main()
67 {
68         while(~scanf("%d%d%d",&n,&m,&P))
69         {
70                 A0.scan();
71                 B.scan();
72                 for(int i=1;i<=n;i++) scanf("%d",&c[i]);
73                 for(int i=0;i<=n;i++) pre[i].init();
74                 r=0;
75                 tot=0;
76                 pre[0]=A0;
77                 sumv.init();
78                 for(int i=1;i<=n;i++)
79                 {
80                         if(c[i]>r) build();
81                         A=mul(pre[c[i]],sumv);
82                         A=mul(A,B);
83                         sta[++tot]=A;
84                         sumv=mul(sumv,A);
85                 }
86                 A.print();
87         }
88         return 0;
89 }
View Code

当w(i)不单增时,我们可以维护$8$个长度为$6,6^2,6^3...6^8$的队列,每一次将新加入的元素先压入长度为$6$的队列,并$O(m^3*6)$维护后缀积,当队列满了之后,将其作为一个元素加入$6^2$的队列,同时维护至多$6$个元素的后缀积,当$6^2$满了之后$O(m^3*6^2)$ 暴力将其变为一个元素(算出$6^2$个元素的后缀积),并作为整体压入下一序列。

每个元素最多被更新了8次,所以 $O(8*n*m^3)$

原文地址:https://www.cnblogs.com/lawyer/p/6443625.html