BM算法学习笔记

一种nb算法,可以求出数列的递推式。

具体过程是这样的。

我们先假设它有一个递推式,然后按位去算他的值。

for(int j=0;j<now.size();++j)(delta[i]+=1ll*now[j]*f[i-j-1]%mod)%=mod;

这是我们算出了f[i]应当是多少,但是f[i]有可能不是我们算出的值,所以我们记录一个delta,为我们算出的值减去f[i]的结果。

然后查看一下之前有没有出过锅。

如果出过,那么就补一个0,然后塞过去。。

if(!cnt){now.resize(i);cnt++;continue;}

cnt记录出锅次数,now记录当前递推式。

然后我们需要构造一个递推式把这一位的delta补上。

然后我们设inv为这一次的dalta除以上一次的delta。

然后我们的递推式就是在last和now之间补0,然后加一个inv,后面把所有的pre*(-inv)加进去,这样最后n这个位置会出现-delta就满足我们的要求了。

最后我们把构造递推式和当前递推式加起来。

再贪心选一个左端点最靠右的出锅递推式作为last。

正确性???

代码

#include<iostream>
#include<cstdio>
#include<vector>
#define N 100009
using namespace std;
typedef long long ll;
const ll mod=65521;
ll n,f[N],delta[N],fail[N],cnt,last;
vector<ll>cur,now,pre;
inline int rd(){
    int x=0;char c=getchar();bool f=0;
    while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    return f?-x:x;
}
inline ll power(ll x,ll y){
    ll ans=1;
    while(y){
        if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1;
    }
    return ans;
}

int main(){
    n=rd();
    for(int i=1;i<=n;++i)f[i]=rd(); 
    for(int i=1;i<=n;++i){
       delta[i]=mod-f[i];
       for(int j=0;j<now.size();++j)(delta[i]+=1ll*now[j]*f[i-j-1]%mod)%=mod;
       if(!delta[i])continue;
       fail[cnt]=i;
       if(!cnt){now.resize(i);cnt++;continue;}
       ll inv=((mod-1ll*delta[i]*power(delta[fail[last]],mod-2)%mod)%mod+mod)%mod;
       cur.clear();cur.resize(i-fail[last]-1);cur.push_back(mod-inv);
       for(int j=0;j<pre.size();++j)cur.push_back(1ll*pre[j]*inv%mod);
       if(now.size()>cur.size())cur.resize(now.size());
       for(int j=0;j<now.size();++j)(cur[j]+=now[j])%=mod;
       if(i-now.size()>=fail[last]-pre.size())pre=now,last=cnt; //fail[last]!!!
       cnt++;now=cur; 
    } 
    for(int i=0;i<now.size();++i)cout<<now[i]<<",";cout<<now.size();
    return 0;
} 

应用

[NOI2007]生成树计数

正解貌似是插头dp+快速幂。

然后我们发现k非常小。。。。

那么就可以对于每一个k打一个表,然后再扔到BM里跑一下,发现转移式子最大只有45。

于是就直接上矩乘。

代码

打表

#include<iostream>
#include<cstdio>
#include<cstring>
#define N 402
using namespace std;
typedef long long ll;
ll a[N][N],n;
const int mod=65521;
inline ll power(ll x,ll y){
    ll ans=1;
    while(y){
        if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1;
    }
    return ans;
}
inline ll ni(ll x){return power(x,mod-2);}
inline ll matr(int n){
    for(int i=1;i<=n;++i){
        for(int j=i+1;j<=n;++j){
            ll x=1ll*a[j][i]*ni(a[i][i])%mod;
            for(int k=i;k<=n;++k)a[j][k]=(a[j][k]-x*a[i][k]%mod+mod)%mod;
        }
    }
    ll ans=1;
    for(int i=1;i<=n;++i)ans=ans*a[i][i]%mod;
    return ans;
}
int main(){
    freopen("out","w",stdout);
    int kk=2;
    for(int n=1;n<=45;++n){
        memset(a,0,sizeof(a));
        for(int i=1;i<=n;++i){
          for(int k=i-1;k>=1&&k>=i-kk;--k)a[i][i]++,a[i][k]--;
          for(int k=i+1;k<=n&&k<=i+kk;++k)a[i][i]++,a[i][k]--;
        }
        printf("%lld,",matr(n-1));
    }
    return 0;
}

矩阵乘法

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
typedef long long ll;
const int mod=65521;
ll top,n;
int s1[2]={0,1};
int s2[4]={0,3,65520,0};
int s3[8]={0,5,65518,3,65516,1,0,0};
int s4[18]={0,7,65520,65496,31,65469,65437,300,65437,65469,31,65496,65520,7,65520,0,0,0};
int s5[46]={0,8,5,65489,40,364,63172,62845,2793,7304,50170,14272,13974,
32712,27590,63226,30516,31431,62449,44809,2992,62529,20712,3072,34090,35005,2295,
37931,32809,51547,51249,15351,58217,62728,2676,2349,65157,65481,32,65516,65513,1,0,0,0,0};
int a1[2]={0,1};
int a2[4]={0,1,1,3};
int a3[8]={0,1,1,3,16,75,336,1488};
int a4[18]={0,1,1,3,16,125,864,5635,35840,29517,48795,64376,52310,4486,28336,8758,64387,31184};
int a5[46]={0,1,1,3,16,125,1296,12005,38927,26915,65410,9167,63054,58705,18773,9079,38064,46824,
48121,50048,47533,30210,24390,51276,45393,357,44927,15398,15923,31582,56586,25233,41258,21255,
21563,16387,39423,26418,10008,6962,42377,50881,54893,50452,23715,53140};
inline ll power(ll x,ll y){
    ll ans=1;
    while(y){if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1;}
    return ans;
}
struct matrix{
    ll a[48][48];
    matrix(){memset(a,0,sizeof(a));}
    matrix operator *(const matrix &b)const{
        matrix c;
        for(int i=1;i<=top;++i)
          for(int j=1;j<=top;++j){
              for(int k=1;k<=top;++k)
              (c.a[i][j]+=a[i][k]*b.a[k][j]%mod)%=mod; 
          }
        return c;
    }
}ans,Z;
inline void work1(){
    puts("1");
}
inline void work2(){
    if(n<=3){printf("%d
",a2[n]);return;}
    for(int i=1;i<=3;++i){
        ans.a[1][i]=a2[i];
        Z.a[i][3]=s2[3-i+1];
        if(i!=1)Z.a[i][i-1]=1;
    }
    n-=3;top=3;
    while(n){
        if(n&1)ans=ans*Z;
        Z=Z*Z;
        n>>=1;
    }
    printf("%lld",ans.a[1][3]); 
}
inline void work3(){
    if(n<=7){printf("%d
",a3[n]);return;}
    for(int i=1;i<=7;++i){
        ans.a[1][i]=a3[i];
        Z.a[i][7]=s3[7-i+1];
        if(i!=1)Z.a[i][i-1]=1;
    }
    n-=7;top=7;
    while(n){
        if(n&1)ans=ans*Z;
        Z=Z*Z;
        n>>=1;
    }
    printf("%lld",ans.a[1][7]); 
}
inline void work4(){
    if(n<=17){printf("%d
",a4[n]);return;}
    for(int i=1;i<=17;++i){
        ans.a[1][i]=a4[i];
        Z.a[i][17]=s4[17-i+1];
        if(i!=1)Z.a[i][i-1]=1;
    }
    n-=17;top=17;
    while(n){
        if(n&1)ans=ans*Z;
        Z=Z*Z;
        n>>=1;
    }
    printf("%lld",ans.a[1][17]); 
}
inline void work5(){
    if(n<=45){printf("%d
",a5[n]);return;}
    for(int i=1;i<=45;++i){
        ans.a[1][i]=a5[i];
        Z.a[i][45]=s5[45-i+1];
        if(i!=1)Z.a[i][i-1]=1;
    }
    n-=45;top=45;
    while(n){
        if(n&1)ans=ans*Z;
        Z=Z*Z;
        n>>=1;
    }
    printf("%lld",ans.a[1][45]); 
}
int main(){
    int k;
    cin>>k>>n;
    if(k==1)work1();
    else if(k==2)work2();
    else if(k==3)work3();
    else if(k==4)work4();
    else if(k==5)work5();
    return 0;
}
原文地址:https://www.cnblogs.com/ZH-comld/p/10306313.html