[CF960G] Bandit Blues

problem

求满足(sum_i[p_i=max_{j=1}^i p_j]=a)(sum_i[p_i=max_{j=i}^n p_j]=b)的1到n的排列p的个数。

solution

设f[i,j]为从大到小地向序列中加入i个数,形成了j个前缀最大值的情况,转移有

[egin{aligned} f[0,0]=1,&&f[i,j]=f[i-1,j-1]+(i-1)f[i-1,j] end{aligned} ]

显然这恰是第一类斯特林数,即(f[i,j]=s(i,j))

一个数集与一个操作方案能对应一个序列。考虑枚举数n的位置,那么答案为

[sum_{i=1}^ns(i-1,a-1)s(n-i,b-1) imes C(n-1,i-1) ]

这相当于是把1到n-1给分成a+b-2个环的方案数(其中环有两类,每类分别由a+1个和b+1个)即答案

[s(n-1,a+b-2) imes C(a+b-2,a-1) ]

至此问题已完结。

#include <bits/stdc++.h>
#define ll long long 
using namespace std;

const int N=2e5+10;
const int mod=998244353;
const int inf=0x3f3f3f3f;

inline ll qpow(ll x,ll y) {
	ll c=1;
	for(; y; y>>=1,x=x*x%mod)
		if(y&1) c=x*c%mod;
	return c;
}
int p,pcur,rev[N];
inline void ntt_init(int len) {
	for(p=1,pcur=0; p<(len<<1);) p<<=1,pcur++;
	for(int i=0; i<p; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(pcur-1));
}
inline void ntt(ll*a,int tp) {
	for(int i=0; i<p; ++i) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int m=1; m<p; m<<=1) {
		int wm=qpow(3,(mod-1)/(m<<1)); if(tp<0) wm=qpow(wm,mod-2);
		for(int i=0; i<p; i+=(m<<1)) { ll w=1,tmp;
			for(int j=0; j<m; ++j,w=w*wm%mod) {
				tmp=w*a[i+j+m]%mod;
				a[i+j+m]=(a[i+j]-tmp+mod)%mod;
				a[i+j]=(a[i+j]+tmp)%mod;
			}
		} 
	}
	if(tp<0) {
		ll tmp=qpow(p,mod-2);
		for(int i=0; i<p; ++i) a[i]=tmp*a[i]%mod;
	}
}
inline void chm(ll*A,ll*B) {
	ntt(A,1); ntt(B,1);
	for(int i=0; i<p; ++i) (A[i]*=B[i])%=mod;
	ntt(A,-1); 
}
ll fac[N],fav[N],A[N],B[N];
void calc(int n,ll*s) {
	if(n==0) {s[0]=1; return;}
	if(n==1) {s[1]=1; return;}
	int m(n/2); calc(m,s); ntt_init(m+1);
	for(int i=0; i<=m; ++i) A[m-i]=fac[i]*s[i]%mod;
	for(int i=0; i<=m; ++i) B[i]=fav[i]*qpow(m,i)%mod;
	for(int i=m+1; i<p; ++i) A[i]=B[i]=0;
	chm(A,B);
	for(int i=0; i<=m; ++i) B[i]=A[m-i]*fav[i]%mod;
	for(int i=0; i<=m; ++i) A[i]=s[i];
	for(int i=m+1; i<p; ++i) A[i]=B[i]=0;
	chm(A,B);
	for(int i=0; i<=m+m; ++i) s[i]=A[i];
	if(n&1)
	for(int i=n; i>=0; --i) s[i]=((i?s[i-1]:0)+(n-1)*s[i]%mod)%mod;
}

ll s[N];
int main() {
	fac[0]=fac[1]=fav[0]=fav[1]=1;
	for(int i=2; i<N; ++i) fav[i]=fav[mod%i]*(mod-mod/i)%mod;
	for(int i=2; i<N; ++i) fav[i]=fav[i-1]*fav[i]%mod,fac[i]=fac[i-1]*i%mod;
	//int n; scanf("%d",&n); calc(n,s);
	//for(int i=0; i<=n; ++i) printf("s(%d,%d)=%d
",n,i,s[i]);
	int n,a,b;
	scanf("%d%d%d",&n,&a,&b);
	calc(n-1,s);
	printf("%lld",fac[a+b-2]*fav[a-1]%mod*fav[b-1]%mod*s[a+b-2]%mod);
	return 0;
}
原文地址:https://www.cnblogs.com/nosta/p/10972968.html