bzoj3157国王奇遇记(秦九韶算法+矩乘)&&bzoj233AC达成

bz第233题,用一种233333333的做法过掉了(为啥我YY出一个算法来就是全网最慢的啊...)

题意:求sigma{(i^m)*(m^i),1<=i<=n},n<=10^9,m<=200

别人的做法: O(m^2logn),O(m^2),甚至O(m)的神做法

学渣的做法:矩乘+秦九韶算法,O(m^3logn),刚好可以过最弱版本的国王奇遇记的数据

(极限数据单点其实是1.2s+,不想继续卡常了…bzoj卡总时限使人懒惰…如果把矩乘的封装拆掉可能会快点吧,然而人弱懒得拆了...)

首先考虑这么一道题:求sigma{i*(m^i),1<=i<=n},n<=10^9

举一个m=2,n=4的例子:

sigma{i*(m^i),1<=i<=n}   =       1*(2^1)+2*(2^2)+3*(2^3)+4*(2^4)

                                               =       2*(1+2*(2^1)+3*(2^2)+4*(2^3))

                                               =       2*(1+2*(2+3*2+4*2^2))

                                               =       2*(1+2*(2+2*(3+4*2)))

从括号最里面向外计算,那么只需要从0的基础上,依次加4乘2;加3乘2;加2乘2;加1乘2。

这样的运算是很有规律的,我们可以构造一个矩阵用矩阵快速幂来进行计算.

如果是sigma{i*i*(m^i),1<=i<=n}?我们需要在矩阵中保存一个i^2,这时候利用(i-1)^2=i^2-2*i+1,在矩阵中同时保存i和i^2即可

对于国王奇遇记这道题,我们需要从0的基础上,依次加n^m再乘m;加(n-1)^m再乘m,加(n-2)^m再乘m…..

看似矩阵中从n^m到(n-1)^m的转换较难完成

利用二项式定理,在矩阵中存储n^1,n^2,n^3,…n^m,就可以完成转移.

#include<cstdio>
#include<cstring>
//#include<ctime>
#include<algorithm>
using namespace std;
const int mod=1000000007,maxn=202;
int sz;
struct matrix{
  int a[maxn][maxn];
  matrix(){
    memset(a,0,sizeof(a));
  }
  matrix(int x){
    memset(a,0,sizeof(a));
    for(int i=0;i<maxn;++i)a[i][i]=1;
  }
  matrix operator *(const matrix &B)const{
    matrix C;
    for(int i=0;i<=sz;++i){
      for(int j=0;j<=sz;++j){
    if(a[i][j]==0)continue;
    for(int k=0;k<=sz;++k){
      if(B.a[j][k]==0)continue;
      C.a[i][k]=(C.a[i][k]+a[i][j]*1LL*B.a[j][k])%mod;
    }
      }
    }
    return C;
  }
}A,ANS(1),B;
int n,m;
void build_matrix(){
  A.a[0][0]=1;
  for(int i=1;i<sz;++i){
    A.a[0][i]=1;
    for(int j=1;j<=i;++j){
      A.a[j][i]=(A.a[j-1][i-1]+A.a[j][i-1])%mod;
    }
  }
  for(int i=1;i<sz;++i){
    for(int j=0;j<=i;++j){
      if((j^i)&1)A.a[j][i]=(mod-A.a[j][i])%mod;
    }
  }
  A.a[sz][sz]=m;A.a[sz-1][sz]=1;
  // for(int i=0;i<=sz;++i){
  //   for(int j=0;j<=sz;++j)printf("%d ",A.a[i][j]);
  //   printf("
");
  // }
}
void quickpow(int x){
  // double t1=clock();
  for(;x;x>>=1,A=A*A){//printf("!");
    if(x&1)ANS=ANS*A;
  }
  // double t2=clock();
}
int pwr[maxn];
int main(){
  scanf("%d%d",&n,&m);
  pwr[0]=1;
  for(int i=1;i<=m;++i){
    pwr[i]=pwr[i-1]*1LL*n%mod;
  }
  sz=m+1;
  build_matrix();
  quickpow(n);
  int ans=0;
  for(int i=0;i<=m;++i){
    ans=(ans+ANS.a[i][sz]*1LL*pwr[i])%mod;
  }
  printf("%lld
",ans*1LL*m%mod);
  return 0;
}
原文地址:https://www.cnblogs.com/liu-runda/p/6208830.html