bzoj4128 Matrix

传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=4128

【题解】

矩阵版本的 BSGS(Baby-Step Giant-Step,大步小步算法)。

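至于如何不需要求逆:取 m = ⌈√p⌉,把 x 写成 x = i·m − j(1 ≤ i ≤ m+1,0 ≤ j < m)。因为 A 可逆,A^x = B 等价于 A^(i·m) = B·A^j。于是先把所有 B·A^j 的哈希值存入表中(小步),再依次枚举 A^(i·m) 查表(大步),第一个命中即给出最小的 x。详见:http://www.cnblogs.com/galaxies/p/bzoj2480.html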

# include <map>
# include <math.h>
# include <stdio.h>
# include <string.h>
# include <iostream>
# include <algorithm>
// # include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;  // 64-bit type used by matrix hashing/accumulation
const int M = 5e5 + 10;          // NOTE(review): not referenced in the visible code

// Shorthand macros, unused in the visible code. `register` is no longer a
// storage-class specifier as of C++17, so RG should not be used in new code.
# define RG register
# define ST static

// n: matrix dimension; mod: modulus for all matrix arithmetic (both read in main).
int n, mod;

// Square matrix over the integers modulo the global `mod`, stored 1-based:
// only a[1..n][1..n] is meaningful, so n must be at most 71.
// The arithmetic assumes every entry lies in [0, mod); set() and operator*
// both maintain that invariant.
struct matrix {
    int n, a[72][72];

    // Set the dimension to _n and zero the entire storage.
    inline void init(int _n) {
        n = _n;
        memset(a, 0, sizeof a);
    }

    // Set the dimension to _n and read _n*_n integers from stdin, row by row.
    // Each entry is reduced into [0, mod) so that equal residues hash equally
    // and operator* can rely on non-negative operands below mod.
    // Requires mod > 0.
    inline void set(int _n) {
        n = _n;
        for (int i=1; i<=n; ++i)
            for (int j=1; j<=n; ++j) {
                scanf("%d", &a[i][j]);
                a[i][j] %= mod;
                if (a[i][j] < 0) a[i][j] += mod;
            }
    }

    // Product x*y modulo mod; the result has dimension x.n.
    // Operands are passed by const reference to avoid copying two ~20 KB
    // matrices on every call.
    friend matrix operator * (const matrix &x, const matrix &y) {
        // Entries are < mod <= 2^31 - 1, so every term is < 2^62. Keeping the
        // running sum below 2^63 before each addition means it can never
        // overflow 64 bits. `%` is therefore only paid when the sum grows large,
        // plus once per entry. An int accumulator would overflow for mod > 2^30.
        const ull limit = 1ull << 63;
        matrix c; c.init(x.n);
        for (int i=1; i<=x.n; ++i)
            for (int j=1; j<=x.n; ++j) {
                ull s = 0;
                for (int k=1; k<=x.n; ++k) {
                    s += 1ull * x.a[i][k] * y.a[k][j];
                    if (s >= limit) s %= mod;
                }
                c.a[i][j] = static_cast<int>(s % mod);
            }
        return c;
    }

    // a^b modulo mod by binary exponentiation. Requires b >= 0; a negative b
    // never reaches zero under `>>=`. a^0 is the identity of dimension a.n.
    friend matrix operator ^ (matrix a, int b) {
        matrix c; c.init(a.n);
        for (int i=1; i<=a.n; ++i) c.a[i][i] = 1;
        while(b) {
            if(b&1) c = c * a;
            a = a * a;
            b >>= 1;
        }
        return c;
    }

    // True when the top-left x.n x x.n entries agree (y.n itself is not compared).
    friend bool operator == (const matrix &x, const matrix &y) {
        for (int i=1; i<=x.n; ++i)
            for (int j=1; j<=x.n; ++j)
                if(x.a[i][j] != y.a[i][j]) return false;
        return true;
    }

    // 64-bit polynomial hash of the n x n entries (arithmetic wraps mod 2^64).
    // Distinct matrices can collide, so equal hashes do not prove equality.
    inline ull ghash() const {
        ull ret = 0;
        for (int i=1; i<=n; ++i)
            for (int j=1; j<=n; ++j)
                ret = ret * 20001130ull + a[i][j];
        return ret;
    }
} A, B;  // the two input matrices; the program searches for x with A^x == B

map<ull, int> mp;

inline int BSGS(int P) {
    mp.clear();
    int m = ceil(sqrt(1.0 * P));
    matrix t = B; ull tem;
    for (int i=0; i<m; ++i) {
        mp[t.ghash()] = i;
        t = t * A;
    }
    matrix g = A^m; t = g;
    for (int i=1; i<=m+1; ++i) {
        tem = t.ghash();
        if(mp.count(tem)) return i * m - mp[tem];
        t = t * g;
    }
    return -1;
}

// Entry point: reads the dimension and modulus, then matrices A and B from
// stdin, and prints the exponent reported by BSGS (or -1) on its own line.
int main() {
    scanf("%d%d", &n, &mod);
    A.set(n);
    B.set(n);
    const int answer = BSGS(mod);
    printf("%d\n", answer);
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/galaxies/p/bzoj4128.html