HDU 4965 Fast Matrix Calculation

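按题意的步骤直接计算显然是不行的:$C = AB$ 是 $N \times N$(最大 $1000 \times 1000$)的矩阵,每做一次 $N \times N$ 矩阵乘法最坏需要 $O(N^3) = O(1000^3)$ 的计算量,而快速幂还需要做约 $\log_2(N^2) \approx 20$ 次这样的乘法,显然超时。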

需要转换一下公式。由矩阵乘法的结合律,$(AB)^{N^2} = A\,(BA)^{N^2-1}\,B$,而 $BA$ 只是 $K \times K$(K 最大只有 6)的矩阵,所以可以先对 $BA$ 做快速幂,最后再左乘 $A$、右乘 $B$ 即可。

#include<cstdio>
#include<cstring>
#include<cmath>
#include<vector>
#include<algorithm>
using namespace std;

// N: size of the square result; A is read as N x M and B as M x N.
// M: the inner dimension shared by A and B (read together with N in main()).
int N, M;
// Modulus applied to every product and partial sum in the matrix arithmetic.
const long long m = 6;

// Small dense matrix backed by a fixed 10 x 10 buffer and used 1-indexed:
// the valid entries are A[1..R][1..C], so R and C must not exceed 9.
struct Matrix
{
    long long A[10][10]; // entries; row 0 and column 0 are never used
    int R, C;            // number of rows / columns currently in use
    Matrix operator*(Matrix b); // this * b with every entry reduced modulo m
};

// X: M x M working matrix; init() sets it to B * A, work() squares it repeatedly.
// Y: M x M fast-power accumulator; init() sets it to the identity.
Matrix X, Y;
// Input matrices, stored 1-indexed: A is N x M, B is M x N.
long long A[1000 + 10][1000 + 10];
long long B[1000 + 10][1000 + 10];
// ANS1 = A * Y (N x M); ANS2 = ANS1 * B (N x N), whose entries are summed.
long long ANS1[1000 + 10][1000 + 10];
long long ANS2[1000 + 10][1000 + 10];

Matrix Matrix::operator*(Matrix b)
{
    // Multiply this (R x C) by b (b.R x b.C); assumes C == b.R.
    // Partial sums are reduced modulo m at every step, so each entry of the
    // result lies in [0, m). Only the 1-indexed region [1..R][1..b.C] is written.
    Matrix prod;
    prod.R = R;
    prod.C = b.C;
    for (int row = 1; row <= R; ++row)
    {
        for (int col = 1; col <= b.C; ++col)
        {
            long long acc = 0;
            for (int mid = 1; mid <= C; ++mid)
            {
                acc = (acc + A[row][mid] * b.A[mid][col] % m) % m;
            }
            prod.A[row][col] = acc;
        }
    }
    return prod;
}

void init()
{
    // Y starts as the M x M identity matrix: the accumulator for the fast power.
    memset(Y.A, 0, sizeof Y.A);
    for (int d = 1; d <= M; ++d)
    {
        Y.A[d][d] = 1;
    }
    Y.R = M;
    Y.C = M;

    // X = B * A (M x M), computed modulo m; this is the small matrix that
    // work() raises to a power.
    X.R = M;
    X.C = M;
    for (int r = 1; r <= M; ++r)
    {
        for (int c = 1; c <= M; ++c)
        {
            long long acc = 0;
            for (int t = 1; t <= N; ++t)
            {
                acc = (acc + B[r][t] * A[t][c] % m) % m;
            }
            X.A[r][c] = acc;
        }
    }
}

void work()
{
    // Uses (AB)^(N*N) == A * (BA)^(N*N - 1) * B. init() has already set
    // X = B*A (M x M) and Y = identity, so only the small matrix is powered.
    // The exponent is computed in long long so N*N cannot overflow, and it is
    // clamped at 0: for N == 0 it would be -1, and right-shifting -1 keeps
    // yielding -1 on common compilers, so the loop below would never end.
    long long tmp = (long long)N * N - 1;
    if (tmp < 0) tmp = 0;

    // Binary exponentiation: afterwards Y = X^tmp (all entries modulo m).
    while (tmp)
    {
        if (tmp % 2 == 1) Y = Y * X;
        tmp = tmp >> 1;
        X = X * X;
    }

    // ANS1 = A * Y, an N x M matrix.
    for (int i = 1; i <= N; i++)
    {
        for (int j = 1; j <= M; j++)
        {
            ANS1[i][j] = 0;
            for (int k = 1; k <= M; k++)
                ANS1[i][j] = (ANS1[i][j] + (A[i][k] * Y.A[k][j]) % m) % m;
        }
    }

    // ANS2 = ANS1 * B, the N x N matrix (AB)^(N*N) with entries modulo m.
    for (int i = 1; i <= N; i++)
    {
        for (int j = 1; j <= N; j++)
        {
            ANS2[i][j] = 0;
            for (int k = 1; k <= M; k++)
                ANS2[i][j] = (ANS2[i][j] + (ANS1[i][k] * B[k][j]) % m) % m;
        }
    }

    // Entries are already reduced modulo m; the answer is their plain sum.
    long long sum = 0;
    for (int i = 1; i <= N; i++)
        for (int j = 1; j <= N; j++)
            sum = sum + ANS2[i][j];
    printf("%lld\n", sum);
}

void read()
{
    // Read A (N x M) and then B (M x N), both stored 1-indexed.
    // The arrays hold long long, so the conversion must be %lld: passing a
    // long long* to %d is undefined behavior.
    for (int i = 1; i <= N; i++)
        for (int j = 1; j <= M; j++)
            scanf("%lld", &A[i][j]);
    for (int i = 1; i <= M; i++)
        for (int j = 1; j <= N; j++)
            scanf("%lld", &B[i][j]);
}

int main()
{
    // Handle test cases until the "0 0" terminator or until a header line can
    // no longer be read. Requiring both values (== 2) also stops on malformed
    // input: there scanf returns 0 without consuming anything, and the old
    // ~scanf(...) test (nonzero for 0) would loop forever.
    while (scanf("%d%d", &N, &M) == 2)
    {
        if (N == 0 && M == 0) break;
        read();
        init();
        work();
    }
    return 0;
}
原文地址:https://www.cnblogs.com/zufezzt/p/5242404.html