原题链接
好妙的一道神仙题
题目大意
让你求在(k)进制下,(frac{x}{y})((xin [1,n],yin [1,m]))中有多少个最简分数是纯循环小数
SOLUTION
首先查一下资料,你会发现在十进制下,一个分数是纯循环小数的充要条件是分母的质因子中不含(2)和(5)。因为(10=2 imes 5),于是我们猜在(k)进制下只要分母与(k)互质即可
orz,猜对了!但是怎么证明呢?
***先在十进制下考虑,看一下题目给的提示,可以知道那些余数其实是(x mod y),(10x mod y),(10^2x mod y)...,余数出现重复,表明如下同余方程有解:
又因为(gcd(x,y)=1),所以(10^lequiv 1(mod y)),然后可以得出(gcd(y,10)=1)
在(k)进制下同理***
于是题目可以等价成让我们求这个式子的值(分数线默认向下取整)
两个和号并一起太丑了,先分开
上个莫比乌斯反演
把(d)往前提
因为(k)只有(2000),所以后面那个和号可以预处理一下然后(O(1))的求。大概就是设(g(n)=sumlimits_{i=1}^{n}[gcd(i,k)=1]),同时有(g(n)=frac{n}{k}g(k)+g(n mod k)),只需要预处理到(g_k)就好了
主要是前面的这一部分,推导过程参考自yyb
设
然后一波天秀的操作(用到了(mu)函数的定义和它的积性)
然后就可以愉快地递归加个记忆化了,但是边界稍微有点麻烦:
当(nleqslant 1)或(k=1)时,它等于(sumlimits_{d=1}^{n}mu (d)),而(n)最大可能有(10^9),所以需要上一个杜教筛
回到这个式子
把后面用(g)替换,得到
发现是一个二维整除分块的形式,然后就没了
代码,用了一点小优化:
#include <bits/stdc++.h>
using namespace std;
#define pii pair<int, int>
#define mp make_pair
#define ll long long
#define MAXN 1000000
#define MAXK 2000
int n, m, k;
int prime[MAXN + 5], mu[MAXN + 5], sum[MAXN + 5], cnt, g0[MAXK + 5];
bool vis[MAXN + 5];
map<int, int> m1;
map<ll, int> m2;
int gcd(int x, int y) {
return !y ? x : gcd(y, x % y);
}
void init() {
mu[1] = sum[1] = 1;
vis[1] = true;
for (int i = 2; i <= MAXN; ++i) {
if (!vis[i]) prime[++cnt] = i, mu[i] = -1;
for (int j = 1; j <= cnt && i * prime[j] <= MAXN; ++j) {
vis[i * prime[j]] = true;
if (i % prime[j] == 0) break;
else mu[i * prime[j]] = -mu[i];
}
sum[i] = mu[i] + sum[i - 1];
}
for (int i = 1; i <= k; ++i) g0[i] = g0[i - 1] + (gcd(i, k) == 1);
}
int getsum(int n) {
if (n <= MAXN) return sum[n];
else if (m1.count(n)) return m1[n];
int ret = 1;
for (int l = 2, r; l <= n; l = r + 1) {
r = n / (n / l);
ret -= (r - l + 1) * getsum(n / l);
}
return m1[n] = ret;
}
int f(int n, int k) {
if (k == 1 || n <= 1) return getsum(n);
else if (m2.count(3000LL * n + k)) return m2[3000LL * n + k];
int ret = 0;
for (int i = 1; i <= k; ++i) {
if (k % i) continue;
if(mu[i]) ret += mu[i] * mu[i] * f(n / i, i); // 优化,如果mu[i]是0就不需要递归了
}
return m2[3000LL * n + k] = ret;
}
int g(int n) {
return n / k * g0[k] + g0[n % k];
}
int main() {
scanf("%d%d%d", &n, &m, &k);
init();
int lim = min(n, m);
ll ans = 0;
for (int l = 1, r; l <= lim; l = r + 1) {
r = min(n / (n / l), m / (m / l));
ans += 1LL * (f(r, k) - f(l - 1, k)) * (n / l) * g(m / l);
}
printf("%lld
", ans);
return 0;
}