矩阵快速幂

不会矩阵乘法和快速幂的同学请按Ctrl+W,以便您有更好的学习体验,不要忘记找度娘问一下哦(⊙o⊙)哦


给定n*n的矩阵A,求A^k

n<=100, k<=10^12, |矩阵元素|<=1000

显然,朴素的矩阵乘法求幂是不可能过的,这辈子是不可能了

那么就要引入一种新的算法------矩阵快速幂

快速幂大家都知道,是将指数化为了二进制进行运算

那么知道了这些,矩阵快速幂也就变得简单了起来

首先声明一个性质,

如下的一个单位矩阵,任何矩阵和他相乘得到的都是这个矩阵本身

$$
egin{bmatrix}
1&0&0\
0&1&0\
0&0&1
end{bmatrix}
$$

那我们就能按照快速幂的方法来实现矩阵快速幂了

说白了矩阵快速幂只是把矩阵快速幂的乘法操作变成了矩阵乘法的操作

首先手动造一个单位矩阵,只需要在输入的时候把对角线全赋值为1就行,向下面一样

for(int i=1; i<=n; i++) {
        for(int j=1; j<=n; j++) {
                s.mat[i][j] = read();
        }
        Ans.mat[i][i] = 1;
}

  

接下来需要分解K,这个也很简单,每次都进行(K&1)的操作,如果等于1,那么就把base矩阵变为平方。

一直到K等于0结束,代码如下

void quick_pow() {
    while(k != 0) {
        if(k & 1) {
            Ans = mat_mul(Ans, base);
        }
        base = mat_mul(base, base);
        k >>= 1;
    }
}

  

当然你需要写一个矩阵乘法的函数像下面这样

matrix mat_mul(matrix a, matrix b) {
    matrix res;
    memset(res.mat, 0, sizeof(res.mat));
    for(int i=1; i<=n; i++) {
        for(int j=1; j<=n; j++) {
            for(int p=1; p<=n; p++) {
                res.mat[i][j] += (a.mat[i][p]%Mod)*(b.mat[p][j]%Mod);
                res.mat[i][j] %= Mod;
            }
        }
    }
    return res;
}

  

所有的代码

#include <iostream>
#include <cstdio>
#include <cstring>
#define MAXN 107
#define Mod int(1e9+7)

using namespace std;

long long n, k;

struct matrix {
    long long mat[MAXN][MAXN];
};

matrix s, base, Ans;

inline long long read() {
    long long x = 0, f = 1;
    char c = getchar();
    while(c < '0'||c > '9') {
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c <= '9'&&c >= '0') {
        x = x*10+c-'0';
        c = getchar();
    }
    return x * f;
}

matrix mat_mul(matrix a, matrix b) {
    matrix res;
    memset(res.mat, 0, sizeof(res.mat));
    for(int i=1; i<=n; i++) {
        for(int j=1; j<=n; j++) {
            for(int p=1; p<=n; p++) {
                res.mat[i][j] += (a.mat[i][p]%Mod)*(b.mat[p][j]%Mod);
                res.mat[i][j] %= Mod;
            }
        }
    }
    return res;
}

void quick_pow() {
    while(k != 0) {
        if(k & 1) {
            Ans = mat_mul(Ans, base);
        }
        base = mat_mul(base, base);
        k >>= 1;
    }
}

int main() {
    n = read(), k = read();
    for(int i=1; i<=n; i++) {
        for(int j=1; j<=n; j++) {
            s.mat[i][j] = read();
        }
        Ans.mat[i][i] = 1;
    }
    base = s;
    quick_pow();
    for(int i=1; i<=n; i++) {
        for(int j=1; j<=n; j++) {
            printf("%lld ", Ans.mat[i][j]%Mod);
        }
        printf("
");
    }
}

  

这就完成了,客官不要忘了给我点赞哦

原文地址:https://www.cnblogs.com/bljfy/p/9191984.html