csp2019 Emiya家今天的饭题解

qwq

由于窝太菜了,实在是不会,所以在题解的帮助下过掉了这道题。
写此博客来整理一下思路

正文

传送
简化一下题意:现在有(n)(m)列数,选(k)个数的合法方案需满足:
1.一行最多选一个
2.一列最多选(lfloor frac{k}{2} floor)个数
当然,如果你在某一行里选了0,就相当于没有在这一行里选数
选一次对答案的贡献是你选的所有不为零的数的乘积。对于任意的(k),只要有合法方案,就能取。
(希望没有把题目变得更复杂叭)
根据上面的要求,我们发现(k)的取值范围是([1,n])。而且根据要求2,如果某个方案在满足1的前提下,是不合法的,那么这个方案里面一定有且仅有1列选了超出(lfloor frac{k}{2} floor)个数,因为不可能有两列选的数同时超过(lfloor frac{k}{2} floor)个。我们现在知道了不合法方案的一个特征,那么我们不妨试试总方案数-不合法方案数这个思路。

因为 满足1情况的总方案数-满足1而且不合法的方案数=乱选方案数-不满足1或不满足2的方案数 ,所以我们接下来计算方案数都在满足1的条件下来计算。

计算总方案数:设(all[i][j])表示前(i)行,每行至多选一个,一共选了(j)个的方案数,那么(all[i][j]=all[i-1][j]+sum_{l=1}^m{all[i-1][j-1] imes a[i][l]})。用(sum[i])表示第(i)行所有数的和,那么(all[i][j]=all[i-1][j]+all[i-1][j-1] imes sum[i])

我们再来看看不合法方案怎么算。上面说到一个不合法方案一定只有1行选的数超过了(lfloor frac{k}{2} floor)个,所以我们可以枚举每一列。但是我们不知道(k)。那么我们可以设(no[i][j][l])表示前(i)行,该列选了(j)个,其他列选了(l)个。(no[i][j][l]=no[i-1][j][l]+no[i-1][j-1][l] imes a[i][j]+no[i-1][j][l-1] imes(sum[i]-a[i][j]))这样就可以由(j,l)确定唯一的(k)。枚举列:(O(m)),枚举(i):(O(n)),因为选数的个数最多是(n),所以枚举(j,l)都是(O(n)),总复杂度(O(mn^3))

显然是不够的,需要优化。发现我们其实并不需要具体的(k),只需要知道当前列和其他列选的数的差值即可。为什么呢?不妨设当前列选的数为(x+j)个,其他列选的数为(x)个,那么一共选的数就是(2x+j)个。这里(x)取值任意(只要合法就行),所以可以(2x+j)覆盖所有的(k)。所以设(no[i][j])表示前(i)行,当前枚举的列比其他列多选了(j)个的方案数。
(no[i][j]=no[i-1][j]+no[i-1][j-1] imes a[i][j]+no[i-1][j+1] imes (sum[i]-a[i][j]))
注意这里有个坑:枚举到第(i)行的时候,当前列最多会比其他列少(i)个数,所以(j)应该从(-i)开始枚举,而不是0。考虑到不能出现负下标,所以在代码中将每个下标+n。
如果这个方案是不合法方案,那么对应的(j)一定大于0。
最终答案就是(sum_{j=1}^n{all[n][j]}-sum_{j=1}^n{no[n][j]})

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#include<map>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
inline ll read()
{
	char ch=getchar();
	ll x=0;bool f=0;
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<3)+(x<<1)+(ch^48);
		ch=getchar();
	}
	return f?-x:x;
}
const ll mod=998244353;
ll n,m,a[209][2009],sum[109],all[209][2009];
ll no[109][2109];
ll ans;
int main()
{
    n=read();m=read();	
    for(int i=1;i<=n;i++)
     for(int j=1;j<=m;j++)
      a[i][j]=read(),sum[i]=(sum[i]+a[i][j])%mod;//保险起见随时随地模一下
    all[0][0]=1;  
    for(int i=1;i<=n;i++)
		for(int j=0;j<=n;j++)
			all[i][j]=(all[i-1][j]+all[i-1][j-1]*sum[i]%mod+mod)%mod;  
	for(int j=1;j<=n;j++)
	 ans=(ans+all[n][j])%mod;
	for(int lie=1;lie<=m;lie++)
	{
		memset(no,0,sizeof(no));
		no[0][n]=1;
		for(int i=1;i<=n;i++)
	    {
		    for(int j=n-i;j<=n+i;j++)
		   {
			 no[i][j]=(no[i-1][j]+no[i-1][j-1]*a[i][lie]%mod+no[i-1][j+1]*(sum[i]-a[i][lie])%mod+mod)%mod;
		   }
	    }
	    for(int j=n+1;j<=2*n;j++)
	     ans=(ans-no[n][j]+mod)%mod;
	}
	cout<<ans;
	return 0;
}
原文地址:https://www.cnblogs.com/lcez56jsy/p/12079053.html