BZOJ2956: 模积和

Description

 求∑∑((n mod i)*(m mod j))其中1<=i<=n,1<=j<=m,i≠j。

  

Input

第一行两个数n,m。

Output

  一个整数表示答案mod 19940417的值

Sample Input

3 4

Sample Output

1

样例说明
  答案为(3 mod 1)*(4 mod 2)+(3 mod 1) * (4 mod 3)+(3 mod 1) * (4 mod 4) + (3 mod 2) * (4 mod 1) + (3 mod 2) * (4 mod 3) + (3 mod 2) * (4 mod 4) + (3 mod 3) * (4 mod 1) + (3 mod 3) * (4 mod 2) + (3 mod 3) * (4 mod 4) = 1

数据规模和约定
  对于100%的数据n,m<=10^9。

Solution

题目就是求

[∑_{i=1}^n∑_{j=1}^m[i≠j](nspace modspace i)(mspace modspace j) ]

先讨论不考虑i≠j的限制条件的情况

[large egin{align*} &sum_{i=1}^nsum_{j=1}^m(nspace modspace i)(mspace modspace j)\ &=sumsum{(n-frac{n}{i}*i)(m-frac{m}{j}*j)}\ &=sum_{i=1}^{n}sum_{j=1}^{m}{nm-frac{n}{i}*i*m-n*frac{m}{j}*j+i*j*frac{n}{i}*frac{m}{j}}\ &=n^2m^2-nm^2sum_{i=1}^{n}{frac{n}{i}*i}-n^2*msum_{j=1}^m{frac{m}{j}*j}+nmsum_{i=1}^{n}{i*frac{n}{i}*}sum_{j=1}^{m}{j*frac{m}{j}} end{align*} ]

这是一种方法

然而还有更简便的方法

[large sum{nspace modspace i}*sum{mspace modspace j} ]

直接用余数之和那题的方法求这个就好(不知道余数之和那题怎么写的戳这里

就不用上面一大堆码起来也麻烦的式子了

对于i==j的情况

[large egin{align*} &sum_{i=1}^{k=min(n,m)}{(n-frac{n}{i}*i)(m-frac{m}{i}*i)}[i==j]\ &=sum_{i=1}^{k}{nm-m*frac{n}{i}*i-n*frac{m}{i}*i+i^2*frac{n}{i}*frac{m}{i}}\ &=knm-kmsum_{i=1}^{k}{frac{n}{i}*i}-knsum_{i=1}^{k}{frac{m}{i}*i}+ksum_{i=1}^{k}{i^2}sum_{i=1}^{k}{frac{n}{i}}sum_{i=1}^{k}{frac{m}{i}} end{align*} ]

利用数论分块(O(sqrt{n}))求出上面两式,将两式相减即可

P.S:(sum_{i=1}^n{i^2}=frac{n*(n+1)*(2n+1)}{6})

#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define N 2010
#define mod 19940417
const ll m6 = 3323403;
ll n, m;
ll ans = 0;

ll sum(ll l, ll r) {
	return (r - l + 1) * (l + r) / 2 % mod;
}

ll calc(ll k) {
	ll ans = k * k % mod;
	for(int l = 1, r; l <= k; l = r + 1) {
		r = k / (k / l);
		ans = ((ans - sum(l, r) * (k / l) % mod) % mod + mod) % mod; 
	} 
	return ans;
}

ll cal(ll x) {
	return x * (x + 1) % mod * (2 * x + 1) % mod * m6 % mod;
}

ll sum2(ll l, ll r) {
	return (cal(r) - cal(l - 1) + mod) % mod;
} 

int main() {
	scanf("%lld%lld", &n, &m);
	if(n > m) swap(n, m);
	ans = calc(n) * calc(m) % mod;
	ans = ((ans - n * n % mod * m % mod) % mod + mod) % mod; 
	for(int l = 1, r; l <= n; l = r + 1) {
		r = min(n / (n / l), m / (m / l));
		ans = (ans + sum(l, r) * ((n/l)*m % mod + (m/l)*n % mod) % mod % mod);
		ans = (ans - sum2(l, r) * (n/l) % mod * (m/l) % mod + mod) % mod;
	}
	printf("%lld
", (ans % mod + mod) % mod);
	return 0;
}
原文地址:https://www.cnblogs.com/henry-1202/p/10201032.html