[洛谷P3301][BZOJ3129][SDOI2013]方程(扩展Lucas+容斥)

Solution

  • 先考虑\(n_1=0\)的情况
  • 那么只要考虑形如\(X_i\ge A_i\)的限制
  • 注意求的是正整数解的个数,即对于\(i>n_2\)(没有额外限制的变量)也有\(X_i\ge 1\),相当于令\(A_i=1\)
  • \(\sum_{i=1}^{n}B_i=m\)的非负整数解的个数为\(C(m+n-1,m)\)
  • 解释:序列共\(m+n-1\)个位置,选\(n-1\)个位置出来当隔板,把序列分为长度之和为\(m\)的\(n\)段(可能存在长度为\(0\)的段,即隔板相邻的情况)
  • 现在为了满足这些限制,令\(B_i=X_i-A_i\),则\(B_i\)非负整数解的个数就是原题的合法解的个数
  • 那么\(m\)要减掉\(\sum_{i=1}^{n}A_i\)
  • 考虑\(n_1>0\)的情况,用总方案数减去存在\(X_i\ge A_i+1\ (1\le i\le n_1)\)的方案数
  • 即考虑容斥:不考虑前\(n_1\)个数的限制的方案数\(-\)钦定前\(n_1\)个数中\(1\)个不满足条件的方案数之和\(+\)钦定\(2\)个不满足条件的方案数之和\(-\)……;钦定集合\(S\)中的变量不满足条件(即\(X_i\ge A_i+1\))时,只需在\(m\)中再减去\(\sum_{i\in S}A_i\),然后同样用隔板法计数
  • 发现\(n,m\)很大,但任意一组数据的\(p\)都可以拆成\(\prod_{i=1}^{k}p_i^{q_i}\),且\(p_i\le 10007\),那么对每个\(p_i^{q_i}\)用扩展\(Lucas\)求组合数取模,再用中国剩余定理合并即可

Code

#include <bits/stdc++.h>

using namespace std;

#define ll long long

// Read the next non-negative decimal integer from stdin into res.
// Any non-digit characters before the number are skipped; the single character
// following the number is consumed. If EOF is reached before any digit, res is
// set to 0 instead of looping forever (the previous `char` version never
// terminated at EOF and passed possibly-negative chars to isdigit, which is UB).
template <class t>
inline void read(t & res)
{
	int ch;  // int, not char: must be able to hold EOF and is a valid isdigit argument
	while ((ch = getchar()) != EOF && !isdigit(ch));
	if (ch == EOF)
	{
		res = 0;
		return;
	}
	res = ch ^ 48;  // '0'..'9' ^ 48 == 0..9
	while ((ch = getchar()) != EOF && isdigit(ch))
		res = res * 10 + (ch ^ 48);
}

// Generic capacity for the per-factor / per-variable arrays below.
const int o = 2000;
// Row width of f[][]: init() fills f[i][0..a[i]], so every prime power a[i]
// of p must satisfy a[i] < MAXPK. Bounding only the primes (p_i <= 10007) is
// not enough: e.g. 437367875 = 5^3 * 7^3 * 101^2 contains 101^2 = 10201, which
// overran the previous 10010-wide rows.
const int MAXPK = 10210;
// a[i]: i-th prime power of p (a[0] = number of factors); c[i], d[i]: its prime
// and exponent; b[i]: C(n, m) mod a[i]; pk: modulus of the factor currently
// being solved; now: index of that factor; h[]: the bounds A_i; ans: result.
int a[o], b[o], pk, p, c[o], d[o], tst, n1, n2, n, m, ans, h[o], now;
// f[i][j]: product of 1..j skipping multiples of c[i], modulo a[i].
int f[20][MAXPK];
// vis[i]: whether upper-bounded variable i is forced to violate its bound (dfs/pd).
bool vis[o];
// Accumulator for the exponent of the current prime in a factorial (fac).
long long tot;

// Extended Euclid: returns g = gcd(a, b) and stores coefficients with a*x + b*y = g.
// The recursive call writes its coefficients into (y, x) swapped, so only y
// needs the correction term afterwards.
inline int exgcd(int a, int b, int &x, int &y)
{
	if (b == 0)
	{
		x = 1, y = 0;
		return a;
	}
	const int g = exgcd(b, a % b, y, x);
	y -= a / b * x;
	return g;
}

// Binary exponentiation: returns x^y reduced modulo the global pk.
inline int ksm(int x, ll y)
{
	int acc = 1;
	for (; y != 0; y >>= 1)
	{
		if (y & 1)
			acc = (ll)acc * x % pk;
		x = (ll)x * x % pk;  // square the base for the next bit
	}
	return acc;
}

inline int fac(int n, int p, int k)
{
	if (n == 1 || n == 0) return 1;
	ll cnt = n / p, bl = n / pk, res = fac(n / p, p, k), i, tmp;
	tot += cnt;
	tmp = f[now][pk - 1];
	tmp = ksm(tmp, bl);
	res = res * tmp % pk;
	res = res * f[now][n % pk] % pk;
	return res;
}

// C(n, m) modulo the id-th prime power a[id] = c[id]^d[id] (one extended-Lucas step).
// Sets the globals pk (current modulus) and tot (used by fac). The caller
// (cc) sets `now = id` beforehand so that fac() reads the matching table f[now].
inline int solve(int n, int m, int id)
{
	int p = c[id], k = d[id];
	pk = a[id];
	// C(n, m) vanishes outside 0 <= m <= n. Guarding here also guarantees that
	// the p-exponent t computed below is non-negative.
	if (m < 0 || m > n) return 0;
	tot = 0;
	int ra = fac(m, p, k); ll ta = tot;
	tot = 0;
	int rb = fac(n - m, p, k); ll tb = tot;
	tot = 0;
	int rc = fac(n, p, k); ll tc = tot;
	// Exponent of p in C(n, m). An exponent must never be reduced modulo k (the
	// old `(t % k + k) % k` fallback was mathematically invalid). Once t >= k,
	// C(n, m) is divisible by p^k and its residue is exactly 0.
	ll t = tc - ta - tb;
	if (t >= k) return 0;
	// ra and rb have every factor p stripped by fac, so they are invertible mod pk.
	int ia, ib, unused;
	exgcd(ra, pk, ia, unused);
	exgcd(rb, pk, ib, unused);
	if (ia < 0) ia += pk;
	if (ib < 0) ib += pk;
	return (ll)rc * ia % pk * ib % pk * ksm(p, t) % pk;
}

inline void init()
{
	int i, s = sqrt(p), lp = p, j;
	for (i = 2; i <= s; i++)
	if (lp % i == 0)
	{
		int t = 0, r = 1;
		while (lp % i == 0) 
		{
			t++;
			r *= i;
			lp /= i;
		}
		a[++a[0]] = r; 
		c[a[0]] = i;
		d[a[0]] = t;
	}
	if (lp != 1) 
	{
		a[++a[0]] = lp;
		c[a[0]] = lp;
		d[a[0]] = 1;
	}
	for (i = 1; i <= a[0]; i++)
	{
		f[i][0] = 1;
		for (j = 1; j <= a[i]; j++)
		if (j % c[i]) f[i][j] = (ll)f[i][j - 1] * j % a[i];
		else f[i][j] = f[i][j - 1];
	}
}

// C(n, m) mod p using extended Lucas: compute the residue modulo every prime
// power a[id] of p via solve(), then combine the residues with the CRT.
// Returns 0 when m < 0 or m > n.
inline int cc(ll n, ll m, int p)
{
	if (m < 0 || m > n) return 0;
	for (int id = 1; id <= a[0]; ++id)
	{
		now = id;  // solve()/fac() read the table f[now]
		b[id] = solve(n, m, id);
	}
	int res = 0;
	for (int id = 1; id <= a[0]; ++id)
	{
		const int mod = a[id];
		const int others = p / mod;  // product of the remaining prime powers
		int inv, unused;
		exgcd(others, mod, inv, unused);  // inv = others^{-1} mod `mod`, possibly negative
		res = (res + (ll)others * inv % p * b[id] % p + p) % p;
	}
	return res;
}

// x = x + y with a single conditional subtraction of the global modulus p.
// Keeps x in [0, p) as long as 0 <= x < p and 0 <= y <= p on entry.
inline void add(int &x, int y)
{
	const int sum = x + y;
	x = (sum >= p) ? sum - p : sum;
}

inline void pd()
{
	int i, tm = m, cnt = 0;
	for (i = 1; i <= n1; i++)
	if (vis[i])
	{
		cnt++;
		tm -= h[i] + 1;
	}
	else tm--;
	if (!cnt) return;
	if (cnt & 1) add(ans, p - cc(tm + n - 1, tm, p));
	else add(ans, cc(tm + n - 1, tm, p));
} 

// Enumerate every subset of the upper-bounded variables k..n1. For each one,
// mark it in vis[] (first "not chosen", then "chosen") and evaluate its
// inclusion–exclusion term with pd() once all n1 flags are fixed.
inline void dfs(int k)
{
	if (k == n1 + 1)
	{
		pd();
		return;
	}
	for (int chosen = 0; chosen <= 1; ++chosen)
	{
		vis[k] = chosen;
		dfs(k + 1);
	}
}

int main()
{
	int i;
	read(tst); read(p);
	init();
	while (tst--)
	{
		read(n); 
		read(n1); 
		read(n2);
		read(m);
		int tmp = n1 + n2;
		for (i = 1; i <= tmp; ++i) read(h[i]);
		m -= n - n1 - n2;
		for (i = n1 + 1; i <= n2 + n1; i++) m -= h[i];
		int tm = m - n1;
		ans = cc(tm + n - 1, tm, p);
		dfs(1);
		printf("%d\n", ans);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/cyf32768/p/12196441.html