BZOJ 3622 Luogu P4859 已经没有什么好害怕的了 (容斥原理、DP)

题目链接

(Luogu) https://www.luogu.org/problem/P4859
(bzoj) https://www.lydsy.com/JudgeOnline/problem.php?id=3622

题解

我依然啥都不会啊……

先给(A,B)数组从小到大排序。
考虑容斥,设(f[j])表示钦定了(j)个满足(A>B), 所有钦定方案的方案数总和。
这个怎么算?dp算。设(dp[i][j])表示前(i)个的(f[j]), 然后发现转移的时候并不知道之前选的那些没有钦定的有几个比当前的大。
怎么办?转换思路,我们考虑先选出钦定的(B), 不考虑剩下未钦定的。那么很容易列出方程: (dp[i][j]=dp[i-1][j]+dp[i-1][j-1] imes (t-i+1)), 其中(t)(a_i)大于(B)序列中的多少个数。
最后(f[j]=dp[n][j] imes (n-j)!), 完美解决。

容斥怎么办?思考组合意义或者二项式反演,总之最后是一个式子(g[i]=sum^n_{j=i}(-1)^{j-i}{jchoose i}f[j]), 其中(g[i])表示恰好有(i)(A>B)的方案数。

时间复杂度(O(n^2)).

代码

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cassert>
#include<iostream>
#include<algorithm>
#define llong long long
using namespace std;

inline int read()
{
	int x=0; bool f=1; char c=getchar();
	for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
	for(; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c^'0');
	if(f) return x;
	return -x;
}

const int N = 2000;
const int P = 1e9+9;
int a[N+3],b[N+3];
llong dp[N+3][N+3];
llong f[N+3];
llong fact[N+3],finv[N+3];
int n,m;

llong quickpow(llong x,llong y)
{
	llong cur = x,ret = 1ll;
	for(int i=0; y; i++)
	{
		if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
		cur = cur*cur%P;
	}
	return ret;
}
llong mulinv(llong x) {return quickpow(x,P-2);}
llong comb(llong x,llong y) {return x<0||y<0||x<y ? 0ll : fact[x]*finv[y]%P*finv[x-y]%P;}

int main()
{
	fact[0] = 1ll; for(int i=1; i<=N; i++) fact[i] = fact[i-1]*i%P;
	finv[N] = quickpow(fact[N],P-2); for(int i=N-1; i>=0; i--) finv[i] = finv[i+1]*(i+1)%P;
	scanf("%d%d",&n,&m);
	if((n+m)&1) {printf("0"); return 0;}
	m = (n+m)>>1;
	for(int i=1; i<=n; i++) scanf("%d",&a[i]);
	for(int i=1; i<=n; i++) scanf("%d",&b[i]);
	sort(a+1,a+n+1); sort(b+1,b+n+1);
	dp[0][0] = 1ll;
	for(int i=1; i<=n; i++)
	{
		int t = lower_bound(b+1,b+n+1,a[i])-b-1;
		for(int j=0; j<=i; j++)
		{
			dp[i][j] = dp[i-1][j];
			if(j>0 && t-j+1>0) {dp[i][j] = (dp[i][j]+dp[i-1][j-1]*(t-j+1))%P;}
		}
	}
	for(int i=0; i<=n; i++) f[i] = dp[n][i]*fact[n-i]%P;
	llong ans = 0ll;
	for(int i=m; i<=n; i++)
	{
		llong tmp = f[i]*comb(i,m);
		if((i-m)&1) {ans = (ans-tmp+P)%P;}
		else {ans = (ans+tmp)%P;}
	}
	printf("%lld
",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/suncongbo/p/11270017.html