POJ3904 Sky Code

vjudge传送


题面:给定(n)个数,求从中选出的四元组((a,b,c,d))数目,满足(gcd(a,b,c,d)=1).((nleqslant 10^4, 1leqslant a_i leqslant 10^4)


这题虽然不是很难,但我还是没想出来,惭愧。
首先四个数的最大公约数为1,肯定不代表两两互质,这就不是很好做。于是正难则反,求出所有满足最大公约数大于1的四元组数目(ans),那么答案就是(C_n^4 - ans).
至于(ans)如何计算?考虑对于任意一个最大公约数(x),如果我们能在只(x)的最小质因子处筛掉(x),那自然是最好不过了。但是这样难于实现,因为无法确定哪一个数可能成为四元组的最大公因数。


那么可以换一种思路,我们可以用容斥计算(x)的贡献,即将(x)质因数分解后,(x=p_1^{c_1}*p_2^{c_2}*cdots p_k^{c_k}),那么加上(p_1,p_2,cdots,p_n)的贡献,减去(p_1*p_2,p_1*p_3,cdots,p_i*p_j)的贡献,再加上三个质因数相乘的贡献……这样我们只用将(n)个数质因数分解,然后记录他们所有质因子的次数不大于1的因子的个数,这样答案就是(sum (-1)^{k+1}C_d^4)了((k)(d)中含有的不同质因子的个数)。


我没看懂大多数题解上的二进制容斥,就先自己预处理出质因子次数不大于1的数了。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<queue>
#include<assert.h>
#include<ctime>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
#define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e4 + 5;
In ll read()
{
	ll ans = 0;
	char ch = getchar(), las = ' ';
	while(!isdigit(ch)) las = ch, ch = getchar();
	while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
	if(las == '-') ans = -ans;
	return ans;
}
In void write(ll x)
{
	if(x < 0) x = -x, putchar('-');
	if(x >= 10) write(x / 10);
	putchar(x % 10 + '0');
}

int n, a[maxn];

bool one[maxn], num[maxn];
ll C[maxn];
In int calc1(int n)
{
	int ret = 0, x = n;
	for(int i = 2; i * i <= n; ++i)
		if(n % i == 0)
		{
			int tp = 0;
			while(n % i == 0) n /= i, ret ^= 1, tp++;
			if(tp > 1) one[x] = 0;
		} 
	if(n > 1) ret ^= 1;
	return ret;
}
int ans[maxn];
In void calc2(int n)
{
	for(int i = 2; i * i <= n; ++i)
		if(n % i == 0)
		{
			if(one[i]) ans[i]++;
			if(i * i < n && one[n / i]) ans[n / i]++;
		}
	if(n > 1 && one[n]) ans[n]++;
}

In void init()
{
	for(int i = 4; i < maxn; ++i) C[i] = 1LL * i * (i - 1) * (i - 2) * (i - 3) / 24;
	for(int i = 2; i < maxn; ++i) one[i] = 1, num[i] = calc1(i);
}

int main()
{
	init();
	while(scanf("%d", &n) != EOF)
	{
		Mem(ans, 0);
		for(int i = 1; i <= n; ++i) calc2(read());
		ll Ans = 0;
		for(int i = 2; i < maxn; ++i) Ans += (num[i] ? C[ans[i]] : -C[ans[i]]);	
		write(C[n] - Ans), enter;
	}
	return 0;
}
原文地址:https://www.cnblogs.com/mrclr/p/14879019.html