codechef JIIT

考虑如何计算操作后的奇数个数。
假设在行操作了(i),列操作(j)次。
由补集转化,操作后奇数个数(=im+jn-ij)
(f_i)表示为行操作(i)次的答案,(g_i)表示列操作(i)次的答案,则答案就是符合要求的所有(f_i*g_j)
列出答案的EGF。
由于每行是相同的,强制让选择的(i)行在网格的前(i)
(f_i=({e^x-e^{-x}over 2})^i*({e^x+e^{-x}over 2})^{n-i}[x^q]*q!*C_n^i)
后面的(C_n^i)表示选择的方案数。
(f_i=2^n(e^x-e^{-x})^i*(e^x+e^{-x})^{n-i}[x^q]*q!*C_n^i)
如果使用二项式定理展开((e^x-e^{-x}))((e^x+e^{-x})),再暴力进行多项式乘法,则生成一个(2n)次关于(e^x)的多项式(指数可能是负的),时间复杂度(O(n^3))
但是注意到(f_{i+1})的后面的项等于(f_i)的多项式乘以((e^x-e^{-x}))再除以((e^x+e^{-x})),所以可以在(O(n))的时间内得到下面的多项式,时间复杂度(O(n^2))
(g)可以同理计算。
这样子已经可以通过本题,然而我们还有更为优秀的做法。
考虑容斥。(CTS2019 珍珠)
(h_i)表示钦定至少(i)行为为奇数,其它任意。
列出答案的EGF。
(h_i=sum C_i^j*g_j)
(h_i=({e^x-e^{-x}over 2})^i*e^{x(n-i)}[x^q]*q!*C_n^i)
(h_i=2^{-i}({e^x-e^{-x}})^i*e^{x(n-i)}[x^q]*q!*C_n^i)
(h_i=2^{-i}frac{1}{e^{xi}}({e^{2x}-1})^i*e^{x(n-i)}[x^q]*q!*C_n^i)
(h_i=2^{-i}({e^{2x}-1})^i*e^{x(n-2i)}[x^q]*q!*C_n^i)
(h_i=C_n^i2^{-i}sum_{j=0}^{i}e^{2j}(-1)^{i-j}e^{x(n-2i)}[x^q]*q!C_{i}^j)
(h_i=C_n^i2^{-i}sum_{j=0}^{i}(-1)^{i-j}e^{x(n-2i+2j)}[x^q]*q!C_{i}^j)
(h_i=C_n^i2^{-i}sum_{j=0}^{i}(-1)^{i-j}e^{x(n-2(i-j))}[x^q]*q!C_{i}^j)
(h_i=C_n^i2^{-i}i!sum_{j=0}^{i}(-1)^{i-j}((n-2(i-j))^qfrac{1}{(i-j)!j!})
(a_i=(-1)^i((n-2i))^qfrac{1}{i!},b_i=frac{1}{i!})
(a*b=h)
考虑二项式反演,(g_i=sum_{jgeq i}C_{j}^i(-1)^{j-i}h_j=frac{1}{i!}sum_{jgeq i}frac{j!}{(j-i)!}(-1)^{j-i}h_j)
(a_i=(-1)^ifrac{1}{i!},b_i=i!h_i)
(a,b)的减法卷积就是(g)
计算答案考虑(im+jn-ijleq k)
(j(n-2i)leq k-im)
根据((n-2i))的正负性分类讨论,使用前缀和计算。
时间复杂度(O(nlog_2n))
细节:
在卷积的时候注意把vector数组resize,以防后面的项错误产生贡献

#include<bits/stdc++.h>
using namespace std;
#define mo 998244353
#define N 500010
#define ll unsigned long long
#define int long long
#define pl vector<int>
int qp(int x,int y){
	int r=1;
	for(;y;y>>=1,x=1ll*x*x%mo)
		if(y&1)r=1ll*r*x%mo;
	return r;
}
int rev[N],v,le,w[N],p[N],ans[N];
void deb(pl x){
	for(int i:x)cout<<i<<' ';
	puts("");
}
void init(int n){
	v=1;
	le=0;
	while(v<n)le++,v*=2;
	for(signed i=0;i<v;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)<<(le-1));
	int g=qp(3,(mo-1)/v);
	w[v/2]=1;
	for(int i=v/2+1;i<v;i++)
		w[i]=1ull*w[i-1]*g%mo;
	for(signed i=v/2-1;~i;i--)
		w[i]=w[i*2];
}
void fft(int v,pl &a,int t){
	static unsigned long long b[N];
	int s=le-__builtin_ctz(v);
   	for(int i=0;i<v;i++)
   		b[rev[i]>>s]=a[i];
	int c=0;
	w[0]=1;
    for(signed i=1;i<v;i*=2,c++)
    	for(signed r=i*2,j=0;j<v;j+=r)
            for(signed k=0;k<i;k++){
               	int tx=b[j+i+k]*w[k+i]%mo;
            	b[j+i+k]=b[j+k]+mo-tx;
            	b[j+k]+=tx;
            }
    for(int i=0;i<v;i++)
    	a[i]=b[i]%mo;
    if(t==0)return;
    int iv=qp(v,mo-2);
    for(signed i=0;i<v;i++)
    	a[i]=1ull*a[i]*iv%mo;
    a.resize(v);
    reverse(a.begin()+1,a.end());
}
pl operator *(pl x,pl y){
	int s=x.size()+y.size()-1;
	if(x.size()<=30||y.size()<=30){
		pl r;
		r.resize(s);
		for(int i=0;i<x.size();i++)
			for(int j=0;j<y.size();j++)
				r[i+j]=(r[i+j]+x[i]*y[j])%mo;
		return r;
	}
	init(s);
	x.resize(v);
	y.resize(v);
	fft(v,x,0);
	fft(v,y,0);
	//deb(x);
	//deb(y);
	for(int i=0;i<v;i++)
		x[i]=x[i]*y[i]%mo;
	fft(v,x,1);
	x.resize(s);
	return x;
}
void ad(pl &x,pl y,int l){
	x.resize(max((int)x.size(),(int)y.size()+l));
	for(int i=0;i<y.size();i++)
		x[i+l]=(x[i+l]+y[i])%mo;
}
pl operator +(pl x,pl y){
	ad(x,y,0);
	return x;
}
int f[N],g[N],n,m,q,k,jc[N],ij[N],h[N],s[N];
int c(int y,int x){
	if(y<0||x<0||y<x)
		return 0;
	return jc[y]*ij[x]%mo*ij[y-x]%mo;
}
void cal(int *f,int l){
	pl x,y;
	x.resize(l+1);
	y.resize(l+1);
	for(int i=0;i<=l;i++){
		x[i]=qp(mo-1,i)*qp(l-2*i,q)%mo*ij[i]%mo;
		y[i]=ij[i];
	}
	x=x*y;
	for(int i=0;i<=l;i++)
		h[i]=x[i]*qp(qp(2,mo-2),i)%mo*c(l,i)%mo*jc[i]%mo;
	for(int i=0;i<=l;i++)
		x[i]=jc[i]*h[i]%mo;
	x.resize(l+1);
	for(int i=0;i<=l;i++)
		y[l-i]=qp(mo-1,i)*ij[i]%mo;
	x=x*y;
	for(int i=0;i<=l;i++)
		f[i]=x[i+l]*ij[i]%mo;
}
signed main(){
	int T;
	jc[0]=1;
	for(int i=1;i<N;i++)
		jc[i]=jc[i-1]*i%mo;
	ij[N-1]=qp(jc[N-1],mo-2);
	for(int i=N-1;i;i--)
		ij[i-1]=ij[i]*i%mo;
	scanf("%lld",&T);
	while(T--){
		memset(f,0,sizeof(f));
		memset(g,0,sizeof(g));
		scanf("%lld%lld%lld%lld",&n,&m,&q,&k);
		cal(f,n);
		cal(g,m);
		int va=0;
		s[0]=g[0];
		for(int i=1;i<=m;i++)
			s[i]=(s[i-1]+g[i])%mo;
		for(int i=0;i<=n;i++){
			int p=n-2*i;
			if(!p){
				if(k-i*m>=0)
					va=(va+f[i]*s[m])%mo;
			}
			if(p<0){
				int v=ceil((long double)(k-i*m)/(long double)(n-2*i));
				if(v<=0){
					va=(va+s[m]*f[i]%mo)%mo;
				}
				else{
					va=(va+(s[m]-s[v-1]+mo)%mo*f[i]%mo)%mo;
				}
			}
			if(p>0){
				int v=floor((long double)(k-i*m)/(long double)(n-2*i));
				if(v>=0)
					va=(va+f[i]*s[min(v,m)])%mo;
			}
		}
		printf("%lld
",va);
	}
}
原文地址:https://www.cnblogs.com/ctmlpfs/p/14141967.html