矩阵乘法优化dp

题表

P3390 【模板】矩阵快速幂

板子

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 const int MAXN=110;
 4 const int mod=1e9+7;
 5 int n;
 6 struct Matrix{
 7     int jz[MAXN][MAXN];
 8     inline void init(){
 9         for(int i=1;i<=n;++i)
10             for(int j=1;j<=n;++j)
11                 jz[i][j]=0;
12         return ;
13     }
14     friend Matrix operator *(Matrix a,Matrix b){
15         Matrix tmp;tmp.init();
16         for(int k=1;k<=n;++k)
17             for(int i=1;i<=n;++i)
18                 for(int j=1;j<=n;++j)
19                     tmp.jz[i][j]=(tmp.jz[i][j]+1LL*a.jz[i][k]*b.jz[k][j]%mod)%mod;
20         return tmp;
21     }
22 };
23 Matrix ans,a;
24 long long K;
25 int main(){
26     //freopen("P3390_1.in","r",stdin);
27     scanf("%d%lld",&n,&K);
28     ans.init();a.init();
29     for(int i=1;i<=n;++i)ans.jz[i][i]=1;
30     for(int i=1;i<=n;++i)
31         for(int j=1;j<=n;++j)
32             scanf("%d",&a.jz[i][j]);
33     for(;K;K>>=1,a=a*a)
34         if(K&1)ans=ans*a;
35     for(int i=1;i<=n;++i,puts(""))
36         for(int j=1;j<=n;++j)
37             printf("%d ",ans.jz[i][j]);
38     return 0;
39 } 
View Code

P1962 斐波那契数列

最简单的优化dp问题,但是不要拘泥于这种优化dp的矩阵书写格式,理解透彻给dp乘系数的算法

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 const int mod=1e9+7;
 4 int F[3];
 5 struct Matrix{
 6     int jz[3][3];
 7     inline void init(){
 8         for(int i=1;i<=2;++i)
 9             for(int j=1;j<=2;++j)
10                 jz[i][j]=0;
11         return ;
12     }
13     friend Matrix operator *(const Matrix a,const Matrix b){
14         Matrix tmp;tmp.init();
15         for(int k=1;k<=2;++k)
16             for(int i=1;i<=2;++i)
17                 for(int j=1;j<=2;++j)
18                     tmp.jz[i][j]=(tmp.jz[i][j]+1LL*a.jz[i][k]*b.jz[k][j]%mod)%mod;
19         return tmp;
20     }
21 };
22 Matrix a,ans;
23 long long n;
24 int main(){
25     //freopen("P1939_3.in","r",stdin);
26     scanf("%lld",&n);
27     if(n<=2)printf("%d\n",1);
28     else{
29         F[1]=F[2]=1;
30         a.jz[1][1]=0;a.jz[1][2]=1;
31         a.jz[2][1]=1;a.jz[2][2]=1;
32         ans.init();
33         for(int i=1;i<=2;++i)ans.jz[i][i]=1;
34         n-=2;
35         for(;n;n>>=1,a=a*a)
36             if(n&1)ans=ans*a;
37         printf("%d\n",(1LL*F[1]*ans.jz[1][2]%mod+1LL*F[2]*ans.jz[2][2]%mod)%mod);
38     }
39     return 0;
40 } 
View Code

P1397 [NOI2013] 矩阵游戏

做题时拘泥于上题的构造矩阵格式,写了一个大小为3*3的矩阵

看题解又写了一个2*2的矩阵$$\left[\begin{matrix} a & 0 \\ 1 & 1 \end{matrix} \right]$$

和答案矩阵$$\left[\begin{matrix} ans & b \\ 0 & 0 \end{matrix} \right]$$ 

然后发现这样行和列答案矩阵不统一,应该把第一个矩阵的1和答案矩阵的b互换

即答案矩阵$ANS$为$$\left[\begin{matrix} ans & 1 \\ 0 & 0 \end{matrix} \right]$$ 

行转移A为$$\left[\begin{matrix} c & 0 \\ d & 1 \end{matrix} \right]$$ 

列转移B为$$\left[\begin{matrix} a & 0 \\ b & 1 \end{matrix} \right]$$ 

答案即为$ANS*B^{m-1}*{(B^{m-1}*A)}^{n-1}$

注意矩阵乘不满足交换律,满足结合律,所以顺序不能所以调换

 1 //灵活运用矩阵优化
 2 //十进制快速幂 
 3 #include<bits/stdc++.h>
 4 using namespace std;
 5 const int MAXN=1e6+17;
 6 const int mod=1e9+7;
 7 char n[MAXN],m[MAXN];
 8 int a,b,c,d;
 9 struct Matrix{
10     int jz[2][2];
11     void init(){
12         jz[0][0]=jz[0][1]=jz[1][0]=jz[1][1]=0;
13         return ;
14     }
15     void pt(){
16         cout<<"PRT"<<endl;
17         cout<<jz[0][0]<<" "<<jz[0][1]<<endl;
18         cout<<jz[1][0]<<" "<<jz[1][1]<<endl;
19         return ;
20     }
21     friend Matrix operator *(const Matrix x,const Matrix y){
22         Matrix tmp;tmp.init();
23         tmp.jz[0][0]=(1LL*x.jz[0][0]*y.jz[0][0]%mod+1LL*x.jz[0][1]*y.jz[1][0]%mod)%mod;
24         tmp.jz[0][1]=(1LL*x.jz[0][0]*y.jz[0][1]%mod+1LL*x.jz[0][1]*y.jz[1][1]%mod)%mod;
25         tmp.jz[1][0]=(1LL*x.jz[1][0]*y.jz[0][0]%mod+1LL*x.jz[1][1]*y.jz[1][0]%mod)%mod;
26         tmp.jz[1][1]=(1LL*x.jz[1][0]*y.jz[0][1]%mod+1LL*x.jz[1][1]*y.jz[1][1]%mod)%mod;
27         return tmp;
28     }
29 }ans,base,tmp;
30 int len,jie;
31 int main(){
32     scanf("%s%s%d%d%d%d",n+1,m+1,&a,&b,&c,&d);
33     len=strlen(n+1);jie=1;
34     for(int i=len;i>=1;--i){
35         int idx=n[i]-'0';
36         if(jie<=idx){n[i]=idx-jie+'0';break; }
37         n[i]=idx+10-jie+'0',jie=1;
38     }
39     len=strlen(m+1);jie=1;
40     for(int i=len;i>=1;--i){
41         int idx=m[i]-'0';
42         if(jie<=idx){m[i]=idx-jie+'0';break; }
43         m[i]=idx+10-jie+'0',jie=1;
44     }
45     ans.init();base.init();
46     ans.jz[0][0]=ans.jz[1][1]=1;
47     base.jz[0][0]=a;base.jz[0][1]=0;
48     base.jz[1][0]=b;base.jz[1][1]=1;
49     len=strlen(m+1);
50     for(int i=len;i>=1;--i){
51         int idx=m[i]-'0';
52         while(idx--){
53             ans=ans*base;
54             //ans.pt();
55         }
56         base=base*base*base*base*base*base*base*base*base*base;
57     }
58     base.jz[0][0]=c;base.jz[0][1]=0;
59     base.jz[1][0]=d;base.jz[1][1]=1;
60     tmp=ans;base=base*ans;
61     ans.init();ans.jz[0][0]=ans.jz[0][1]=1;
62     ans=ans*tmp;
63     len=strlen(n+1);
64     for(int i=len;i>=1;--i){
65         int idx=n[i]-'0';
66         while(idx--){
67             ans=ans*base;
68             //ans.pt();
69         }
70         base=base*base*base*base*base*base*base*base*base*base;
71     }
72     printf("%d\n",ans.jz[0][0]);
73     return 0;
74 } 
View Code

P2461 [SDOI2008] 递归数列

维护k+1*k+1的矩阵,多出来的一个维护答案,因为答案是合,矩阵k位的每次更新乘对应系数,矩阵中小于k的位置直接把平移

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 #define LL long long
 4 int k,p;
 5 int b[110],c[110];
 6 LL n,m;
 7 struct Matrix{
 8     int jz[21][21];
 9     inline void init(){
10         for(int i=1;i<=k+1;++i)
11             for(int j=1;j<=k+1;++j)
12                 jz[i][j]=0;
13         return ;
14     }
15     inline void pt(){
16         cout<<"EEEE "<<endl;
17         for(int i=1;i<=k+1;++i,puts(""))
18             for(int j=1;j<=k+1;++j)
19                 printf("%d ",jz[i][j]);
20         return ;
21     }
22     friend Matrix operator *(const Matrix a,const Matrix b){
23         Matrix tmp;tmp.init();
24         for(int o=1;o<=k+1;++o)
25             for(int i=1;i<=k+1;++i)
26                 for(int j=1;j<=k+1;++j)
27                     tmp.jz[i][j]=(tmp.jz[i][j]+1LL*a.jz[i][o]*b.jz[o][j]%p)%p;
28         return tmp;
29     }
30 }F,a;
31 int Get(LL x){
32     if(x<=(LL)k){
33         int ans=0;
34         for(int i=1;i<=x;++i)ans=(ans+b[i])%p;
35         return ans;
36     }
37     F.init();a.init();
38     for(int i=1;i<k;++i)F.jz[1][i]=b[i],F.jz[1][k+1]=(F.jz[1][k+1]+b[i])%p;
39     F.jz[1][k]=b[k];
40     //F.pt();
41     a.jz[k+1][k+1]=1;a.jz[k][k+1]=1;
42     for(int i=1;i<k;++i)a.jz[i+1][i]=1;
43     for(int j=1;j<=k;++j)a.jz[j][k]=c[k-j+1];
44     //a.pt();
45     x-=k-1;
46     for(;x;x>>=1,a=a*a)
47         if(x&1)F=F*a;
48     //for(int i=1;i<=x;++i)F=F*a,F.pt();
49     return F.jz[1][k+1];
50 }
51 int main(){
52     scanf("%d",&k);
53     for(int i=1;i<=k;++i)scanf("%d",&b[i]);
54     for(int i=1;i<=k;++i)scanf("%d",&c[i]);
55     scanf("%lld%lld%d",&m,&n,&p);
56     //cout<<m-1<<" "<<Get(m-1)<<endl;
57     //cout<<n<<" "<<Get(n)<<endl;
58     printf("%d\n",(Get(n)-Get(m-1)+p)%p);
59     return 0;
60 } 
61 /*
62 
63 2
64 1 1
65 1 1
66 2 10 1000003
67 
68 */
View Code

 

原文地址:https://www.cnblogs.com/2018hzoicyf/p/15530949.html