AT5200 [AGC038C] LCMs

题目描述

给定一个长度为(N)的数列(A_1, A_2, A_3, ldots, A_N)​。
请你求出(sum_{i=1}^{N}sum_{j=i+1}^{N}mathrm{lcm}(A_i,A_j))的值模(998244353)的结果。
(1leq N leq 2 imes 10^5,1 leq A_i leq 10^6)

题解

(sum_{i = 1} ^ {N}sum_{j = i + 1} ^ {N}mathrm{lcm}(A_i, A_j))

(= frac{sum_{i = 1} ^ {N} sum_{j = 1} ^ {N}mathrm{lcm}(A_i, A_j) - sum_{i = 1} ^ {N}A_i}{2})

所以我们只要维护(sum_{i = 1} ^ {N}sum_{j = 1} ^ {N}mathrm{lcm}(A_i, A_j))就很容易得到答案。

(sum_{i = 1} ^ {N}sum_{j = 1} ^ {N}mathrm{lcm}(A_i, A_j))

(=sum_{d = 1} ^ {Max} frac{1}{d} sum_{i = 1} ^ {N}A_i sum_{j = 1} ^ {N}A_j [mathrm{gcd}(A_i, A_j) == d])

(F(d) = sum_{i = 1} ^ {N}A_i sum_{j = 1} ^ {N}A_j [d | mathrm{gcd}(A_i, A_j)], f(d) = sum_{i = 1} ^ {N}A_i sum_{j = 1} ^ {N}A_j [mathrm{gcd}(A_i, A_j) == d])

(F(d) = sum f(e) [d|e])

关于(F), 我们可以计算每个数的所有倍数之和并让他们平方得到(两两组合均合法)

由莫比乌斯反演可得,(f(d) = sum F(e) * mu(frac{e}{d}))

以上各种倍数的枚举均可以(O(n * ln_n))得出,再加上线性处理逆元即可。

#include <iostream>
#include <cstdio>
#define ll long long
#define int long long
using namespace std;
const int N = 2e5 + 5, M = 1e6 + 5;
int n, a[N], mx, v[M], prime[M], tot, mu[M];
ll ans, t[M], inv[M], f[M], sum, F[M];
const int mod = 998244353;
inline int read()
{
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
	while(ch >= '0' && ch <= '9') {x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
	return x * f;
}
void init(int n)
{
	mu[1] = 1;
	for(int i = 2; i <= n; i ++)
	{
		if(!v[i]) {prime[++ tot] = i; mu[i] = -1;}
		for(int j = 1; j <= tot && prime[j] * i <= n; j ++)
		{
			v[i * prime[j]] = 1;
			if(i % prime[j] == 0)
			{
				mu[i * prime[j]] = 0;
				break;
			}
			mu[i * prime[j]] = - mu[i];
		}
	}
	inv[0] = inv[1] = 1;
	for(int i = 2; i <= n; i ++) inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
void work()
{
	n = read();
	for(int i = 1; i <= n; i ++) a[i] = read(), t[a[i]] ++, mx = max(mx, a[i]), sum = (sum + a[i]) % mod;
	init(mx);
	for(int i = 1; i <= mx; i ++)
	{
		for(int j = i; j <= mx; j += i) F[i] = (F[i] + t[j] * j) % mod;
		F[i] = (F[i] * F[i]) % mod;
	}
	for(int i = 1; i <= mx; i ++) for(int j = i; j <= mx; j += i) f[i] = (f[i] + (F[j] * mu[j / i] % mod + mod)) % mod;
	for(int d = 1; d <= mx; d ++) ans = (ans + inv[d] * f[d] % mod) % mod;
	ans = (ans - sum + mod) % mod * inv[2] % mod;
	printf("%lld
", (ans + mod) % mod);
}
signed main() {return work(), 0;}
原文地址:https://www.cnblogs.com/Sunny-r/p/12566673.html