学习BM算法

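BM(Berlekamp–Massey)算法:给定数列的前 $m$ 项,可在 $O(m^2)$ 时间内求出能生成它的最短线性递推式;再配合"对特征多项式取模的快速幂",即可在 $O(k^2\log n)$ 时间内求出第 $n$ 项($k$ 为递推阶数)。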

希望大家别见怪,本博客仅作个人学习记录之用。

【例题】Poor God Water

题意:

有肉、鱼、巧克力三种食物,每小时吃其中一种。对于任意连续三小时所吃的食物,有以下禁忌:

1.这三个食物不能都相同;

2.若三种食物都有的情况,巧克力不能在中间;

3.如果两边是巧克力,中间不能是肉或鱼。

给定 $n$,求 $n$ 小时内合法的进食方案数,答案对 $10^9+7$ 取模。

即任意连续三小时内不能出现 aaa、bbb、ccc、abc、cba、bab、bcb(设 b 为巧克力),因此合法的三元组共有 $27-7=20$ 种。

推导可知:以最后两小时吃的食物作为状态(共 $3\times3=9$ 种),可以用 $9\times9$ 的转移矩阵做矩阵快速幂;也可以先算出答案序列的前若干项,再用 BM 算法求出其线性递推式。

下面直接贴出 CJY 学长的代码(矩阵快速幂解法):

 1 #include <bits/stdc++.h>
 2 
 3 using namespace std;
 4 typedef long long ll;
 5 
 6 const int N=9;
 7 struct Matrix{
 8     ll matrix[N][N];
 9 };
10 
11 const int mod = 1e9 + 7;
12 
13 void init(Matrix &res)
14 {
15     memset(res.matrix,0,sizeof(res.matrix));
16     for(int i=0;i<N;i++)
17         res.matrix[i][i]=1;
18 }
19 Matrix multiplicative(Matrix a,Matrix b)
20 {
21     Matrix res;
22     memset(res.matrix,0,sizeof(res.matrix));
23     for(int i = 0 ; i < N ; i++){
24         for(int j = 0 ; j < N ; j++){
25             for(int k = 0 ; k < N ; k++){
26                 res.matrix[i][j] += a.matrix[i][k]*b.matrix[k][j];
27                 res.matrix[i][j] %= mod;
28             }
29         }
30     }
31     return res;
32 }
33 Matrix pow(Matrix mx,ll m)
34 {
35     Matrix res,base=mx;
36     init(res);
37     while(m)
38     {
39         if(m&1)
40             res=multiplicative(res,base);
41         base=multiplicative(base,base);
42         m>>=1;
43     }
44     return res;
45 }
46 int main()
47 {
48     int t;
49     scanf("%d",&t);
50     while(t--)
51     {
52         ll n,ant = 0;
53         scanf("%lld",&n);
54         if(n == 1)  printf("3
");
55         else if(n == 2) printf("9
");
56         else
57         {
58             Matrix res,ans = {
59                 0,0,0, 1,0,0, 1,0,0,
60                 1,0,0, 0,0,0, 1,0,0,
61                 1,0,0, 1,0,0, 1,0,0,
62 
63                 0,1,0, 0,1,0, 0,0,0,
64                 0,1,0, 0,0,0, 0,1,0,
65                 0,0,0, 0,1,0, 0,1,0,
66 
67                 0,0,1, 0,0,1, 0,0,1,
68                 0,0,1, 0,0,0, 0,0,1,
69                 0,0,1, 0,0,1, 0,0,0
70             };
71             res=pow(ans,n-2);
72 
73             for(int i = 0;i < N;i++)
74                 for(int j = 0;j < N;j++)
75                     ant = (ant + res.matrix[i][j]) % mod;
76             printf("%lld
",ant);
77         }
78     }
79     return 0;
80 }
矩阵快速幂
  1 #include<bits/stdc++.h>
  2 #define rep(i,a,n) for (int i=a;i<n;i++)
  3 #define per(i,a,n) for (int i=n-1;i>=a;i--)
  4 #define pb push_back
  5 #define mp make_pair
  6 #define all(x) (x).begin(),(x).end()
  7 #define fi first
  8 #define se second
  9 #define SZ(x) ((int)(x).size())
 10 using namespace std;
 11 typedef vector<int> VI;
 12 typedef long long ll;
 13 typedef pair<int,int> PII;
 14 const ll mod = 1e9+7;
 15 const int N = 30;
 16 ll powmod(ll a,ll b) {
 17     ll res=1;a%=mod; assert(b>=0);
 18     for(;b;b>>=1){
 19         if(b&1)res=res*a%mod;
 20         a=a*a%mod;
 21     }
 22     return res;
 23 }
 24 
 25 /*
 26     BM模板
 27     begin
 28 */
 29 
 30 // head
 31 
 32 int _,n;
 33 namespace linear_seq {
 34     const int N=10010;
 35     ll res[N],base[N],_c[N],_md[N];
 36 
 37     vector<int> Md;
 38     void mul(ll *a,ll *b,int k) {
 39         rep(i,0,k+k) _c[i]=0;
 40         rep(i,0,k) if (a[i]) rep(j,0,k) _c[i+j]=(_c[i+j]+a[i]*b[j])%mod;
 41         for (int i=k+k-1;i>=k;i--) if (_c[i])
 42             rep(j,0,SZ(Md)) _c[i-k+Md[j]]=(_c[i-k+Md[j]]-_c[i]*_md[Md[j]])%mod;
 43         rep(i,0,k) a[i]=_c[i];
 44     }
 45     int solve(ll n,VI a,VI b) { // a 系数 b 初值 b[n+1]=a[0]*b[n]+...
 46 //        printf("%d
",SZ(b));
 47         ll ans=0,pnt=0;
 48         int k=SZ(a);
 49         assert(SZ(a)==SZ(b));
 50         rep(i,0,k) _md[k-1-i]=-a[i];_md[k]=1;
 51         Md.clear();
 52         rep(i,0,k) if (_md[i]!=0) Md.push_back(i);
 53         rep(i,0,k) res[i]=base[i]=0;
 54         res[0]=1;
 55         while ((1ll<<pnt)<=n) pnt++;
 56         for (int p=pnt;p>=0;p--) {
 57             mul(res,res,k);
 58             if ((n>>p)&1) {
 59                 for (int i=k-1;i>=0;i--) res[i+1]=res[i];res[0]=0;
 60                 rep(j,0,SZ(Md)) res[Md[j]]=(res[Md[j]]-res[k]*_md[Md[j]])%mod;
 61             }
 62         }
 63         rep(i,0,k) ans=(ans+res[i]*b[i])%mod;
 64         if (ans<0) ans+=mod;
 65         return ans;
 66     }
 67     VI BM(VI s) {
 68         VI C(1,1),B(1,1);
 69         int L=0,m=1,b=1;
 70         rep(n,0,SZ(s)) {
 71             ll d=0;
 72             rep(i,0,L+1) d=(d+(ll)C[i]*s[n-i])%mod;
 73             if (d==0) ++m;
 74             else if (2*L<=n) {
 75                 VI T=C;
 76                 ll c=mod-d*powmod(b,mod-2)%mod;
 77                 while (SZ(C)<SZ(B)+m) C.pb(0);
 78                 rep(i,0,SZ(B)) C[i+m]=(C[i+m]+c*B[i])%mod;
 79                 L=n+1-L; B=T; b=d; m=1;
 80             } else {
 81                 ll c=mod-d*powmod(b,mod-2)%mod;
 82                 while (SZ(C)<SZ(B)+m) C.pb(0);
 83                 rep(i,0,SZ(B)) C[i+m]=(C[i+m]+c*B[i])%mod;
 84                 ++m;
 85             }
 86         }
 87         return C;
 88     }
 89     int gao(VI a,ll n) {
 90         VI c=BM(a);
 91         c.erase(c.begin());
 92         rep(i,0,SZ(c)) c[i]=(mod-c[i])%mod;
 93         return solve(n,c,VI(a.begin(),a.begin()+SZ(c)));
 94     }
 95 };
 96 
 97 /*
 98     end
 99 */
100 
101 
102 ll a[10][2],ans[40];
103 
104 void Init(){
105 
106     int op = 0 ;
107     for(int i=1;i<=9;i++){
108         a[i][op] = 1 ;
109     }
110     for(int i=3;i<=15;i++){
111         op ^= 1 ;
112         for(int j=1;j<=9;j++)   a[j][op] = 0;
113 
114         a[1][op] = (a[4][op^1] + a[7][op^1]) % mod;
115         a[2][op] = (a[1][op^1] + a[7][op^1]) % mod;
116         a[3][op] = (a[1][op^1] + a[4][op^1] + a[7][op^1]) % mod ;
117         a[4][op] = (a[2][op^1] + a[5][op^1]) % mod ;
118         a[5][op] = (a[2][op^1] + a[8][op^1]) % mod ;
119         a[6][op] = (a[5][op^1] + a[8][op^1]) % mod ;
120         a[7][op] = (a[3][op^1] + a[6][op^1] + a[9][op^1]) % mod ;
121         a[8][op] = (a[3][op^1] + a[9][op^1]) % mod ;
122         a[9][op] = (a[3][op^1] + a[6][op^1]) % mod ;
123 
124         for(int j=1;j<=9;j++){
125             ans[i] = (ans[i] + a[j][op]) % mod  ;
126         }
127         //printf("%lld
",ans[i]);
128     }
129 }
130 
// Answers for n = 1..15 (Vec[i] is the answer for n = i + 1). They are fed to
// Berlekamp-Massey to recover the linear recurrence.
std::vector<int> Vec = {
    3, 9, 20, 46, 106, 244, 560,
    1286, 2956, 6794, 15610, 35866,
    82416, 189384, 435170,
};
135 int main()
136 {
137     int T;
138 
139     Init();
140     for( scanf("%d",&T) ; T ; T-- ){
141         ll n;
142         scanf("%lld",&n);
143         printf("%lld
",linear_seq::gao(Vec,n-1)%mod);
144     }
145     return 0;
146 }
BM算法

牛客多校训练2 B.Eddy Walker 2

2019牛客暑期多校训练营(第二场) - B - Eddy Walker 2 - BM算法

P4723 【模板】线性递推 题解

【学习笔记】Berlekamp-Massey算法

 

  1 #include<cstdio>
  2 #include<vector>
  3 #include<cstring>
  4 #include<algorithm>
  5 typedef long long ll ;
  6 using namespace std;
  7 
  8 #define rep(i,a,n) for (int i=a;i<n;i++)
  9 #define per(i,a,n) for (int i=n-1;i>=a;i--)
 10 #define pb push_back
 11 #define mp make_pair
 12 #define all(x) (x).begin(),(x).end()
 13 #define fi first
 14 #define se second
 15 #define SZ(x) ((int)(x).size())
 16 using namespace std;
 17 typedef vector<ll> VI;
 18 typedef pair<ll,ll> PII;
 19 const ll mod = 1e9+7;
 20 const int N = 5e4+10;
 21 ll powmod(ll a,ll b) {
 22     ll res=1;a%=mod;
 23     for(;b;b>>=1){
 24         if(b&1)res=res*a%mod;
 25         a=a*a%mod;
 26     }
 27     return res;
 28 }
 29 
 30 /*
 31     BM模板
 32     begin
 33 */
 34 
 35 // head
 36 
 37 int _,n;
 38 namespace linear_seq {
 39     const int N=10010;
 40     ll res[N],base[N],_c[N],_md[N];
 41 
 42     vector<int> Md;
 43     void mul(ll *a,ll *b,int k) {
 44         rep(i,0,k+k) _c[i]=0;
 45         rep(i,0,k) if (a[i]) rep(j,0,k) _c[i+j]=(_c[i+j]+a[i]*b[j])%mod;
 46         for (int i=k+k-1;i>=k;i--) if (_c[i])
 47                 rep(j,0,SZ(Md)) _c[i-k+Md[j]]=(_c[i-k+Md[j]]-_c[i]*_md[Md[j]])%mod;
 48         rep(i,0,k) a[i]=_c[i];
 49     }
 50     int solve(ll n,VI a,VI b) { // a 系数 b 初值 b[n+1]=a[0]*b[n]+...
 51 //        printf("%d
",SZ(b));
 52         ll ans=0,pnt=0;
 53         int k=SZ(a);
 54         rep(i,0,k) _md[k-1-i]=-a[i];_md[k]=1;
 55         Md.clear();
 56         rep(i,0,k) if (_md[i]!=0) Md.push_back(i);
 57         rep(i,0,k) res[i]=base[i]=0;
 58         res[0]=1;
 59         while ((1ll<<pnt)<=n) pnt++;
 60         for (int p=pnt;p>=0;p--) {
 61             mul(res,res,k);
 62             if ((n>>p)&1) {
 63                 for (int i=k-1;i>=0;i--) res[i+1]=res[i];res[0]=0;
 64                 rep(j,0,SZ(Md)) res[Md[j]]=(res[Md[j]]-res[k]*_md[Md[j]])%mod;
 65             }
 66         }
 67         rep(i,0,k) ans=(ans+res[i]*b[i])%mod;
 68         if (ans<0) ans+=mod;
 69         return ans;
 70     }
 71     VI BM(VI s) {
 72         VI C(1,1),B(1,1);
 73         int L=0,m=1,b=1;
 74         rep(n,0,SZ(s)) {
 75             ll d=0;
 76             rep(i,0,L+1) d=(d+(ll)C[i]*s[n-i])%mod;
 77             if (d==0) ++m;
 78             else if (2*L<=n) {
 79                 VI T=C;
 80                 ll c=mod-d*powmod(b,mod-2)%mod;
 81                 while (SZ(C)<SZ(B)+m) C.pb(0);
 82                 rep(i,0,SZ(B)) C[i+m]=(C[i+m]+c*B[i])%mod;
 83                 L=n+1-L; B=T; b=d; m=1;
 84             } else {
 85                 ll c=mod-d*powmod(b,mod-2)%mod;
 86                 while (SZ(C)<SZ(B)+m) C.pb(0);
 87                 rep(i,0,SZ(B)) C[i+m]=(C[i+m]+c*B[i])%mod;
 88                 ++m;
 89             }
 90         }
 91         return C;
 92     }
 93     int gao(VI a,ll n) {
 94         VI c=BM(a);
 95         c.erase(c.begin());
 96         rep(i,0,SZ(c)) c[i]=(mod-c[i])%mod;
 97         return solve(n,c,VI(a.begin(),a.begin()+SZ(c)));
 98     }
 99 };
100 
101 /*
102     end
103 */
104 ll dp[N] ;
105 int main()
106 {
107     int T;
108     for( scanf("%d",&T) ; T ; T-- ){
109         memset(dp,0,sizeof(dp));
110         ll n,k ;
111         VI v;
112         scanf("%lld%lld",&k,&n);
113         if( n==0 ){
114             printf("1
");
115         }else if( n==-1 ){
116             printf("%lld
",(2ll) * powmod(k+1,mod-2) % mod );
117         }else{
118             ll Inv_k = powmod( k ,mod-2) ;
119             dp[0] = 1 ;
120             v.push_back(1);
121             for(int i=1;i<=k;i++){
122                 for(int j=0;j<i;j++){
123                     dp[i] = (dp[i] + dp[j]) % mod;
124                 }
125                 dp[i] = dp[i] * Inv_k % mod ;
126                 v.push_back(dp[i]);
127             }
128             for(int i=k+1;i<=2*k;i++){
129                 for(int j=1;j<=k;j++){
130                     dp[i] = (dp[i] + dp[i-j]) % mod ;
131                 }
132                 dp[i] = dp[i] * Inv_k % mod ;
133                 v.push_back(dp[i]);
134             }
135             printf("%lld
",linear_seq::gao(v,n));
136         }
137     }
138     return 0;
139 }
BM
原文地址:https://www.cnblogs.com/Osea/p/11242723.html