Luogu 5170 【模板】类欧几里得算法

原理不难但是写起来非常复杂的东西。

我觉得讲得非常好懂的博客。   传送门

我们设

$$f(a, b, c, n) = sum_{i = 0}^{n}left lfloor frac{ai + b}{c} ight floor$$

$$g(a, b, c, n) = sum_{i = 0}^{n}ileft lfloor frac{ai + b}{c} ight floor$$

$$h(a, b, c, n) = sum_{i = 0}^{n}left lfloor frac{ai + b}{c} ight floor^2$$

先考虑一个结论

$$left lfloor frac{Ax}{y} ight floor = left lfloor frac{A(x mod y)}{y} ight floor + Aleft lfloor frac{x}{y} ight floor$$

那么有

$$left lfloor frac{ai + b}{c} ight floor = left lfloor frac{(a mod c)i + (b mod c)}{c} ight floor + ileft lfloor frac{a}{c} ight floor + left lfloor frac{b}{c} ight floor$$

这个东西可以把$a geq c$ 或者$b geq c$的情况转化成$a,b < c$的情况。

F

先看看这个比较好做的$f$。

注意到当$a geq c$或者$b geq c$的时候,我们可以用以上的结论把$a$或者$b$降下来,有

$$f(a, b, c, n) = left lfloor frac{b}{c} ight floor(n + 1) + left lfloor frac{a}{c} ight floor frac{n(n + 1)}{2} + f(a mod c, b mod c, c, n)$$

那么当$a, b < c$的时候,有

$$f(a, b, c, n) = sum_{i = 0}^{n}sum_{j = 1}^{m}[left lfloor frac{ai + b}{c} ight floor geq j]$$

为了方便把$left lfloor frac{an + b}{c} ight floor$记为$m$。

$$f(a, b, c, n) = sum_{i = 0}^{n}sum_{j = 0}^{m - 1}[left lfloor frac{ai + b}{c} ight floor geq j + 1]$$

$$= sum_{i = 0}^{n}sum_{j = 0}^{m - 1}[ai geq cj + c - b]$$

$$= sum_{i = 0}^{n}sum_{j = 0}^{m - 1}[ai > cj + c - b - 1]$$

$$= sum_{i = 0}^{n}sum_{j = 0}^{m - 1}[i > left lfloor frac{cj + c - b - 1}{a} ight floor]$$

注意到这时候$i$这一项可以直接算了。

$$f(a, b, c, n) = sum_{j = 0}^{m - 1}(n -  left lfloor frac{cj + c - b - 1}{a} ight floor)$$

$$= nm - f(c, c - b - 1, a, m - 1)$$

比较方便。

G

按照套路先做$a geq c$或者$b geq c$的情况。

$$g(a, b, c, n) = left lfloor frac{b}{c} ight floorfrac{n(n + 1)}{2} + left lfloor frac{a}{c} ight floor frac{n(n + 1)(2n + 1)}{6}  + g(a mod c, b mod c, c, n)$$

然后接着按照上述套路弄$a, b < c$的情况。

$$g(a, b, c, n) = sum_{j = 0}^{m - 1}frac{(n - left lfloor frac{cj + c - b - 1}{a} ight floor)(n + 1 + left lfloor frac{cj + c - b - 1}{a} ight floor)}{2}$$

$$= frac{1}{2}(mn(n + 1) - f(c, c - b - 1, a, m - 1) - h(c, c - b - 1, a, m - 1))$$

还要解决$h$。

H

当$a geq c$或者$b geq c$的时候

$$h(a, b, c, n) = (n + 1)left lfloor frac{b}{c} ight floor^2 + frac{n(n + 1)(2n + 1)}{6}left lfloor frac{a}{c} ight floor^2 + n(n + 1)left lfloor frac{b}{c} ight floorleft lfloor frac{a}{c} ight floor + h(a mod c, b mod c, c, n) + 2left lfloor frac{a}{c} ight floor g(a mod c, b mod c, c, n) + 2left lfloor frac{b}{c} ight floor f(a mod c, b mod c, c, n)$$

好麻烦啊……

当$a, b < c$的时候需要一些操作

$$n^2 = 2 imes frac{n(n + 1)}{2} - n = 2sum_{i = 1}^{n}i - n$$

可以得到

$$h(a, b, c, n) = sum_{i = 0}^{n}(2sum_{j = 1}^{left lfloor frac{ai + b}{c} ight floor}j - left lfloor frac{ai + b}{c} ight floor)$$

$$= 2sum_{i = 0}^{n}sum_{j = 1}^{left lfloor frac{ai + b}{c} ight floor}j - f(a, b, c, n)$$

$$= 2sum_{j = 0}^{m - 1}(j + 1)sum_{i = 0}^{n}[left lfloor frac{ai + b}{c} ight floor geq j + 1] - f(a, b, c, n)$$

$$= 2sum_{j = 0}^{m - 1}(j + 1)(n - left lfloor frac{cj + c - b - 1}{a} ight floor) - f(a, b, c, n)$$

$$= nm(m + 1) - 2g(c, c - b - 1, a, m - 1) - 2f(c, c - b - 1, a, m - 1) - f(a, b, c, n)$$

大概就是这样了。

在计算的时候如果只考虑$(a, c)$这两项的话相当于每一次把$(a, c)$变成了$(c, a \% c)$,所以时间复杂度和欧几里得算法相同,递归层数是$log$层。

时间复杂度$O(Tlogn)$。

因为递归的式子很一致,在计算的时候应当三个东西一起算比较快。

Code:

#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;

const ll P = 998244353LL;

template <typename T>
inline void read(T &X) {
    X = 0; char ch = 0; T op = 1;
    for (; ch > '9'|| ch < '0'; ch = getchar())
        if (ch == '-') op = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        X = (X << 3) + (X << 1) + ch - 48;
    X *= op;
}

template <typename T>
inline void inc(T &x, T y) {
    x += y;
    if (x >= P) x -= P;
}

template <typename T>
inline void sub(T &x, T y) {
    x -= y;
    if (x < 0) x += P;
}

inline ll fpow(ll x, ll y) {
    ll res = 1;
    for (; y > 0; y >>= 1) {
        if (y & 1) res = res * x % P;
        x = x * x % P;
    }
    return res;
}

namespace Likegcd {
    struct Node {
        ll f, g, h;    
    };
    
    #define f(now) now.f
    #define g(now) now.g
    #define h(now) now.h
    
    const ll inv2 = fpow(2, P - 2);
    const ll inv6 = fpow(6, P - 2);
    
    inline Node solve(ll a, ll b, ll c, ll n) {
        Node res;
        if (!a) {
            f(res) = (b / c) * (n + 1) % P;
            g(res) = (b / c) * (n + 1) % P * n % P * inv2 % P;
            h(res) = (b / c) * (b / c) % P * (n + 1) % P;
            return res;
        }
        
        f(res) = g(res) = h(res) = 0;
        if (a >= c || b >= c) {
            Node tmp = solve(a % c, b % c, c, n);
            inc(f(res), (a / c) * n % P * (n + 1) % P * inv2 % P);
            inc(f(res), (b / c) * (n + 1) % P);
            inc(f(res), f(tmp));
            
            inc(g(res), (a / c) * n % P * (n + 1) % P * ((2 * n + 1) % P) % P * inv6 % P);
            inc(g(res), (b / c) * n % P * (n + 1) % P * inv2 % P);
            inc(g(res), g(tmp));
            
            inc(h(res), (a / c) * (a / c) % P * n % P * (n + 1) % P * ((2 * n + 1) % P) % P * inv6 % P);
            inc(h(res), (b / c) * (b / c) % P * (n + 1) % P);
            inc(h(res), (a / c) * (b / c) % P * n % P * (n + 1) % P);
            inc(h(res), h(tmp));
            inc(h(res), 2LL * (a / c) % P * g(tmp) % P);
            inc(h(res), 2LL * (b / c) % P * f(tmp) % P);
            
            return res;
        }
        
        if (a < c && b < c) {
            ll m = (a * n + b) / c;
            Node tmp = solve(c, c - b - 1, a, m - 1);
            
            f(res) = n * m % P;
            sub(f(res), f(tmp));
            
            g(res) = n * (n + 1) % P * m % P;
            sub(g(res), f(tmp));
            sub(g(res), h(tmp));
            g(res) = g(res) * inv2 % P;
            
            h(res) = n * m % P * (m + 1) % P;
            sub(h(res), 2LL * g(tmp) % P);
            sub(h(res), 2LL * f(tmp) % P);
            sub(h(res), f(res));
            
            return res;
        }
        
        return res;
    }
        
}

int main() {
    int testCase;
    read(testCase);
    for (ll a, b, c, n; testCase--; ) {
        read(n), read(a), read(b), read(c);
        Likegcd :: Node ans = Likegcd :: solve(a, b, c, n);
        printf("%lld %lld %lld
", ans.f, ans.h, ans.g);
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/CzxingcHen/p/10365060.html