HDU 5446 lucas CRT

n中选m个模M,M为多个素数之积 $n, m, k (1 leq m leq n leq 10^{18}, 1 leq k leq 10)$,$M = p_1 · p_2 · · · p_k ≤ 10^{18}$,$p_i leq 10^5$

由于n,m很大组合数自然想到lucas,但是如果直接用M会因为M太大lucas就没什么用了,所以考虑以构成M的素因子为模数分别对组合数的lucas构建k个同余方程,这样就能得到模M下组合数了。了解题目意思后就很裸了

注意每个不同模数下的逆元、阶乘的模数也不同阿...

/** @Date    : 2017-10-11 12:56:59
  * @FileName: J.cpp
  * @Platform: Windows
  * @Author  : Lweleth (SoungEarlf@gmail.com)
  * @Link    : https://github.com/
  * @Version : $Id$
  */
#include <bits/stdc++.h>
#define LL long long
#define PII pair
#define MP(x, y) make_pair((x),(y))
#define fi first
#define se second
#define PB(x) push_back((x))
#define MMG(x) memset((x), -1,sizeof(x))
#define MMF(x) memset((x),0,sizeof(x))
#define MMI(x) memset((x), INF, sizeof(x))
using namespace std;

const int INF = 0x3f3f3f3f;
const int N = 1e5+20;
const double eps = 1e-8;

LL fac[N];
LL inv[N];
LL p[20];
LL r[20];
LL mod;
void init(int n, LL mod)
{
	fac[0] = fac[1] = 1;
	inv[0] = inv[1] = 1;
	for(int i = 2; i < n; i++)
	{
		fac[i] = fac[i - 1] * i % mod;
		inv[i] = (mod - mod / i) * inv[mod % i] % mod;
	}
	for(int i = 2; i < n; i++)
		(inv[i] *= inv[i - 1]) %= mod;
}

LL C(LL n, LL m, LL mod)
{
	if(m > n)
        return 0;
    LL ans = 0;
    ans = ((fac[n] * inv[m] % mod)* inv[n - m]) % mod;
    return ans;
}

LL lucas(LL n, LL m, LL mod)
{
	if(m == 0)
		return 1;
	return C(n % mod, m % mod, mod) * lucas(n / mod, m / mod, mod) % mod;
}

LL exgcd(LL a, LL b, LL &x, LL &y)
{
	LL d = a;
	if(b == 0)
		x = 1, y = 0;
	else 
	{
		d = exgcd(b, a % b, y, x);
		y -= (a / b) * x;
	}
	return d;
}

LL mul(LL a, LL b, LL mod)
{
	while(a < 0)
		a += mod;
	while(b < 0)
		b += mod;
	LL ans = 0;
	while(b)
	{
		if(b & 1) 
			ans = (ans + a) % mod;
		a = (a + a) % mod;
		b >>= 1;
	}
	return ans;
}

LL CRT(LL n, LL rem[], LL mod[])
{
	LL M = 1, x, y;
	for(int i = 0; i < n; i++)
		M *= mod[i];
	LL res = 0;
	for(int i = 0; i < n; i++)
	{
		LL t = M / mod[i];
		exgcd(t, mod[i], x, y);
		res = (res + mul(mul(t , rem[i], M), x, M)) % M;
	}

	return (res % M + M) % M;
}

int main()
{
	int T;
	cin >> T;
	while(T--)
	{
		LL n, m, k;
		scanf("%lld%lld%lld", &n, &m, &k);
		mod = 1LL;
		for(int i = 0; i < k; i++)
		{
			scanf("%lld", p + i);
			init(p[i], p[i]);
			r[i] = lucas(n, m, p[i]);
		}
		LL ans = CRT(k, r, p);
		printf("%lld
", ans);
	}
    return 0;
}
原文地址:https://www.cnblogs.com/Yumesenya/p/7657678.html