[CF960G] Bandit Blues

题目描述

(CF)题面:https://codeforces.com/problemset/problem/960/G

洛谷题面(带翻译):https://www.luogu.org/problemnew/show/CF960G

Solution

考虑序列可以被前缀(后缀)最大值分成(a+b-2)个块,注意我们忽略了中间的大小为(n)的数。

设这些最大值为(p_i),那么每个块就是([p_i,p_{i+1}-1])

注意到我们可以随意分配,每次还要把最大的放在最前面(最后面),所以可以注意到这是个圆排列,所以分成这么多块的方案数就是(s(n-1,a+b-2))(s)为第一类斯特林数。

然后我们要把(a-1)的块放在前面,所以乘上一个组合数,答案就是:

[s(n-1,a+b-2)cdot inom{a+b-2}{a-1} ]

斯特林数可以分治(FFT)求,复杂度(O(nlog ^2 n))


第一类斯特林数的求法如下,我们可以构造生成函数:

[prod _{i=0}^{n-1}(x+i) ]

那么这个生成函数的第(k)项就是(s(n,k))


#include<bits/stdc++.h>
using namespace std;
 
void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
 
void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('
');}

#define lf double
#define ll long long 

const int maxn = 6e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;

int qpow(int a,int x) {
	int res=1;
	for(;x;x>>=1,a=1ll*a*a%mod) if(x&1) res=1ll*res*a%mod;
	return res;
}

int f[maxn],a,b,n,w[maxn],rw[maxn],pos[maxn],N,mxn,bit,fac[maxn];

void prepare() {
	w[0]=1,w[1]=qpow(3,(mod-1)/mxn);
	for(int i=2;i<=mxn;i++) w[i]=1ll*w[i-1]*w[1]%mod;
	rw[0]=1,rw[1]=qpow(qpow(3,mod-2),(mod-1)/mxn);
	for(int i=2;i<=mxn;i++) rw[i]=1ll*rw[i-1]*rw[1]%mod;
}

void ntt(int *r,int op) {
	for(int i=1;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
	for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1)
		for(int j=0;j<N;j+=i<<1)
			for(int k=0;k<i;k++) {
				int x=r[j+k],y=1ll*r[i+j+k]*(op==1?w:rw)[k*d]%mod;
				r[j+k]=(x+y)%mod,r[i+j+k]=(x-y+mod)%mod;
			}
	if(op==-1) {
		int inv=qpow(N,mod-2);
		for(int i=0;i<N;i++) r[i]=1ll*r[i]*inv%mod;
	}
}

int tmp[18][maxn],tmp1[maxn],tmp2[maxn];

void solve(int l,int r,int d) {
	if(l==r) return tmp[d][0]=l,tmp[d][1]=1,void();
	int mid=(l+r)>>1;
	solve(l,mid,d+1);
	for(int i=0;i<=mid-l+1;i++) tmp[d][i]=tmp[d+1][i];
	solve(mid+1,r,d+1);
	for(int i=0;i<=r-mid;i++) tmp2[i]=tmp[d+1][i];
	for(bit=0,N=1;N<(r-l+1)<<1;N<<=1,bit++);
	for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
	for(int i=mid-l+2;i<N;i++) tmp[d][i]=0;
	for(int i=r-mid+1;i<N;i++) tmp2[i]=0;
	ntt(tmp[d],1),ntt(tmp2,1);
	for(int i=0;i<N;i++) tmp[d][i]=1ll*tmp[d][i]*tmp2[i]%mod;
	ntt(tmp[d],-1);
	for(int i=r-l+2;i<N;i++) tmp[d][i]=0;
}

int main() {
	read(n),read(a),read(b);
	if(!a||!b||n<a+b-1) return 0*puts("0");
	if(n==1) return 0*puts("1");
	fac[0]=1;
	for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
	for(mxn=1;mxn<=(n-1)<<1;mxn<<=1);
	prepare();
	solve(0,n-2,0);
	write(1ll*tmp[0][a+b-2]*fac[a+b-2]%mod*qpow(1ll*fac[a-1]*fac[b-1]%mod,mod-2)%mod);
	return 0;
}
原文地址:https://www.cnblogs.com/hbyer/p/10574863.html