AGC035F

题目大意

题解

怎么又不是正解啊

考虑算重的情况:

有一个格子(i,j),(i,1..j)和(1..i-1,j)刚好被算了一次,横竖就可以在(i,j)上有两种放法

硬点一下,当第i行选了ki时(i,ki +1)不能被竖列放,这样就不会算重

把每一列的生成函数搞出来是这样:

(A(x)=sum frac{n+1-i}{i!}x^i)

最后(A(x)[x^i])表示有i行确定,那么就有n-i行刚好放了m(要除(n-i)!),所以答案就是

(ans=sum A^m(x)[x^i]/(n-i)!)

(A^m(x))可以快速幂求,但是(应该)过不了

这个i!看着就很EGF,用泰勒公式搞♂一下

泰勒公式:(e^x=1+x+frac{x^2}{2!}+frac{x^3}{3!}+...=sum frac{x^i}{i!})

(A(x)=sum frac{n+1-i}{i!}x^i)

(=sum frac{n+1}{i!}x^i-sum frac{i}{i!}x^i)

(=(n+1)e^x-xsum_{i<n} frac{x^i}{i!})

(=(n+1)e^x-xe^x)

(=(n+1-x)e^x)

那么(A^m(x))就是

(A^m(x)=(n+1-x)^me^{mx})

左边二项式展开,右边泰勒展开,卷一下即可

简单又自然

code

#include <bits/stdc++.h>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define fd(a,b,c) for (a=b; a>=c; a--)
#define C(n,m) (jc[n]*Jc[m]%998244353*Jc[(n)-(m)]%998244353)
#define min(a,b) (a<b?a:b)
#define mod 998244353
#define Mod 998244351
#define ll long long
#define G 3
//#define file
using namespace std;

ll A[1048576],B[1048576],a[1048576],b[1048576],w[500001],jc[500001],Jc[500001],ans;
int N,len,n,m,i,j,k,l;

ll qpower(ll a,int b) {ll ans=1;while (b) {if (b&1) ans=ans*a%mod;a=a*a%mod;b>>=1;} return ans;}
void swap(int &x,int &y) {int z=x;x=y;y=z;}
ll dft(ll *a,int type)
{
	int i,j,k,l,S=N,s1=2,s2=1;
	
	fo(i,0,N-1)
	{
		j=i;k=0;
		fo(l,1,len) k=k*2+(j&1),j>>=1;
		A[k]=a[i];
	}
	memcpy(a,A,N*8);
	
	fo(i,1,len)
	{
		ll W=(type==1)?qpower(G,(mod-1)/s1):qpower(G,(mod-1)-(mod-1)/s1);
		S>>=1;
		
		fo(j,0,S-1)
		{
			ll w=1;
			fo(k,0,s2-1)
			{
				ll u=a[j*s1+k],v=a[j*s1+k+s2]*w;
				a[j*s1+k]=(u+v)%mod;
				a[j*s1+k+s2]=(u-v)%mod;
				w=w*W%mod;
			}
		}
		s1<<=1,s2<<=1;
	}
}

void mul(ll *a,ll *b)
{
	ll s=qpower(N,Mod);
	int i;
	
	memset(B,0,sizeof(B));
	memcpy(B,b,4*N);
	dft(a,1);
	dft(B,1);
	fo(i,0,N-1) a[i]=a[i]*B[i]%mod;
	dft(a,-1);
	fo(i,0,N/2-1) a[i]=a[i]*s%mod;
	fo(i,N/2,N-1) a[i]=0;
}

int main()
{
	#ifdef file
	freopen("agc035F.in","r",stdin);
	#endif
	
	scanf("%d%d",&n,&m);len=ceil(log2(n+1))+1;N=qpower(2,len);
	if (n>m) swap(n,m);
	jc[0]=jc[1]=Jc[0]=Jc[1]=w[1]=1;fo(i,2,500000) w[i]=mod-w[mod%i]*(mod/i)%mod,jc[i]=jc[i-1]*i%mod,Jc[i]=Jc[i-1]*w[i]%mod;
	
	fo(i,0,n) a[i]=qpower(n+1,m-i)*C(m,i)*qpower(-1,i)%mod; //or min(n,m)
	fo(i,0,n) b[i]=qpower(m,i)*Jc[i]%mod;
	mul(a,b);
	
	fo(i,0,n) ans=(ans+Jc[n-i]*a[i])%mod;
	fo(i,1,n) ans=ans*i%mod;
	printf("%lld
",(ans+mod)%mod);
	
	fclose(stdin);
	fclose(stdout);
	return 0;
}
原文地址:https://www.cnblogs.com/gmh77/p/12834532.html