uoj#209【UER #6】票数统计

题目

做UER的A题涨信心

首先我们注意到这个所谓的至少有一条正确在(x)(y)不相等的时候非常弱,当(x<y)时,只有可能是后(y)位用户有(x)个通过;当(x>y)时,只有可能是前(x)位用户有(y)个通过。也就是说这些信息都能被转化成一些用来限制前后缀和的信息。

(pre_i)表示序列的前缀和,对于一条前(x)位用户有(y)个通过的限制,我们可以拆成(pre_x=y);对于一条后(y)位用户有(x)个通过的信息,可以视为(pre_n-pre_{n-y}=x),即(pre_{n-y}=pre_n-x)

如果我们知道(pre_n)的值,那么就只剩下了一些前缀和的信息了。所以我们可以直接枚举(pre_n)的值。这些关于前缀和的限制又将整个序列分割成了一些区间,每个区间的区间和也都被限制好了,直接使用组合数把每个区间的方案算出来就好了,答案就是每一个区间组合数的乘积。

但是上述的做法均不能处理(x=y)的情况,当(x=y)的时候,意味着有一个长度为(x)的前缀或后缀全都是(1)。这个(x)越大限制性必然越强,于是我们只需要考虑最大的(x=y),满足了最大的(x=y)剩下的(x=y)必然也都满足了。

我们枚举这个(x=y)限制前缀还是限制后缀,限制前缀就拆成(pre_x=x),限制后缀就拆成(pre_{n-x}=x)。但是如果有一种方案既有一段全是(1)的前缀,也有一段全是(1)的后缀,就会被计算两次。所以我们把两条限制条件都加上,再减掉这样的方案就好了。

代码

#include<bits/stdc++.h>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
inline int read() {
	char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
	while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=5e3+5;
const int mod=998244353;
int T,n,m,M;
int fac[maxn],ifac[maxn],inv[maxn];
int a[maxn],b[maxn],c[maxn],d[maxn],t[2],pre[maxn];
inline int C(int n,int m) {
	return m>n?0:1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
inline int solve(int sum) {
	for(re int i=1;i<=t[0];i++) {
		if(pre[a[i]]!=-1&&pre[a[i]]!=c[i]) return 0;
		pre[a[i]]=c[i];
	}
	for(re int i=1;i<=t[1];i++) {
		if(pre[n-b[i]]!=-1&&pre[n-b[i]]!=sum-d[i]) return 0;
		pre[n-b[i]]=sum-d[i]; 
	}
	if(pre[0]!=-1&&pre[0]!=0) return 0;
	pre[0]=0;int l=0,tot=1;
	for(re int i=1;i<=n;i++) {
		if(pre[i]==-1) continue;
		if(pre[i]<pre[l]) return 0;
		tot=1ll*tot*C(i-l,pre[i]-pre[l])%mod;l=i;
	}
	return tot;
}
int main() {
	T=read();fac[0]=ifac[0]=inv[1]=1;
	for(re int i=1;i<maxn;i++) fac[i]=1ll*fac[i-1]*i%mod;
	for(re int i=2;i<maxn;i++) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
	for(re int i=1;i<maxn;i++) ifac[i]=1ll*ifac[i-1]*inv[i]%mod;
	while(T--) {
		n=read(),m=read();int x,y;t[0]=t[1]=M=0;
		for(re int i=1;i<=m;i++) {
			x=read(),y=read();
			if(x<y) b[++t[1]]=y,d[t[1]]=x;
			if(x>y) a[++t[0]]=x,c[t[0]]=y;
			if(x==y) M=max(M,x);
		}
		int ans=0,now=M;
		for(re int i=1;i<=t[0];i++) now=max(now,c[i]);
		for(re int i=1;i<=t[1];i++) now=max(now,d[i]);
		for(re int i=now;i<=n;i++) {
			memset(pre,-1,sizeof(pre));
			pre[n]=i,pre[M]=M;
			ans=(ans+solve(i))%mod;
			if(!M) continue;
			memset(pre,-1,sizeof(pre));
			pre[n]=i,pre[n-M]=i-M;
			ans=(ans+solve(i))%mod;
			memset(pre,-1,sizeof(pre));
			pre[n]=i,pre[M]=M;pre[n-M]=i-M;
			if(M==n-M&&M!=i-M) continue;
			ans=(ans-solve(i)+mod)%mod;
		}
		printf("%d
",ans);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/asuldb/p/11391021.html