Hihocoder-1286 子矩阵求和

解题思路:

看到这个题目的时候是很懵逼的= =矩阵是无限的

但是其实没那么刚,只需要巧妙的转换下就可以得到结果。

对于矩阵:

1 1 1 1 1 1 1 1
1 2 2 2 2 2 2 2
1 2 3 3 3 3 3 3
1 2 3 4 4 4 4 4
1 2 3 4 5 5 5 5
1 2 3 4 5 6 6 6
1 2 3 4 5 6 7 7
1 2 3 4 5 6 7 8
对于这样的矩阵,我们能看到两个规律:

用s[i][j]表示左上角左边为(i, j)的子矩阵和

1.对于要求的N*M的子矩阵,当行数超过M后,s[i][j] = s[i+1][j](i >= M),当列数超过N的时候,s[i][j] = s[i][j+1](j >= N)

2.s[i][j] + x * N * M = s[i+x][j+x]

所以这题的做法就出来了,只需要枚举第一行和第一列,对于每个位置,求(s[i][j] + x*N*M) mod k = 0的最小x,这个时候很自然的想到扩展欧几里得算法。所以求一发就解出来了。


代码:

#include <set>
#include <map>
#include <cmath>
#include <queue>
#include <stack>
#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;

typedef long long LL;
const LL inf = 1e18 + 5;

LL x, y, ans_x, ans_y;

LL extend_gcd(LL a, LL b, LL& x, LL &y) {
    if ( b == 0 ) { x = 1; y = 0; return a; }
    LL ans = extend_gcd( b, a % b, y, x );
    y -= a / b * x;
    return ans;
}
LL solve( LL n, LL m, LL cur, LL sum ) {
    LL i = n + cur - 1;
    if ( i > m ) sum += (m + 1) * m / 2;
    else{
        sum += (i + 1) * i / 2;
        sum += i * (m - i);
    }
    if ( cur - 1 > m ) sum -= (m + 1) * m / 2;
    else {
        sum -= cur * (cur - 1) / 2;
        sum -= (cur - 1) * (m - cur + 1);
    }
    return sum;
}
bool judge( LL a, LL b, LL c, LL d, LL base1, LL base2 ) {
    if ( (c % d) == 0 ) {
        a /= d; b /= d; c /= d;
        LL tmp = ((x % b) * (c % b)) % b;
        while(tmp < 0) tmp += b >= 0 ? b : -b;

        LL tt = tmp * 2 + base1 + base2;
        if ( tt < ans_x + ans_y || (tt == ans_x + ans_y && tmp + base1 < ans_x) || (tt == ans_x + ans_y && tmp + base1 == ans_x && tmp + base2 < ans_y) ) {
            ans_x = base1 + tmp;
            ans_y = base2 + tmp;
        }
        return true;
    }
    return false;
}
int main() {
    LL q, n, m, k;
    cin >> q;
    while(q--) {
        LL init = 0;
        cin >> n >> m >> k;
        for ( LL i = 1; i <= n; ++i ) {
            if ( i > m ) init += (m + 1) * m / 2;
            else {
                init += (i + 1) * i / 2;
                init += i * (m - i);
            }
        }

        ans_x = ans_y = inf;
        bool flag = false;
        LL A = -n * m, B = k, C = init;

        LL D = extend_gcd(A, B, x, y);
        flag |= judge(A, B, C, D, 1, 1);

        LL sum = init;
        for ( LL i = 2; i <= n; ++i ) {
            sum = solve(m, n, i, sum);
            flag |= judge(A, B, sum, D, 1, i);
        }

        sum = init;
        for ( LL i = 2; i <= m; ++i ) {
            sum = solve(n, m, i, sum);
            flag |= judge(A, B, sum, D, i, 1);
        }

        if ( !flag ) cout << "-1" << endl;
        else cout << ans_x << " " << ans_y << endl;
    }
    return 0;
}


原文地址:https://www.cnblogs.com/wiklvrain/p/8179327.html