若一个格子左、上、下都有黑格子,那么该格子是否为黑色是不影响最后的三元组的,因此只用统计这样的格子为白色的情况,这样就能考虑到所有三元组了。
考虑按列 (DP),设 (f(i,j)) 表示考虑前 (i) 列,已经有 (j) 行至少有一个黑色格子的行的方案数,最终答案为 (sum inom{n}{i}f(m,i))。转移就是每次新增一列,考虑新增的有黑色格子的行数 (k),即 (f(i,j)) 转移到 (f(i+1,j+k))。
若 (k=0),那么这一列对三元组的贡献就只有这一列的最小行标和最大行标,相当于从已有的 (j) 行中选出不超过 (2) 个,贡献为 (1+j+inom{j}{2})。
若 (k>1),考虑原有的行是否在这新加的一列中放黑格子,因为黑格子左、上、下一定有一个方向没有黑格子,所以若是原有的行放黑格子,最多有两行放,且必须是在最小行标或者最大行标的位置。讨论一下最小行标和最大行标是来自新加入的行还是来自原有的行,得贡献为 (inom{j+k}{k}+2inom{j+k}{k+1}+inom{j+k}{k+2}=inom{j+k+2}{k+2}),这个组合数也可以直接用组合意义来说明。
发现转移是卷积的形式,用 (NTT) 优化后可以做到 (O(nmlog n))。
#include<bits/stdc++.h>
#define maxn 64010
#define p 998244353
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int n,m,lim=1,inv,ans;
int rev[maxn];
ll f[maxn],g[maxn],h[maxn],fac[maxn],ifac[maxn];
ll qp(ll x,ll y)
{
ll v=1;
while(y)
{
if(y&1) v=v*x%p;
x=x*x%p,y>>=1;
}
return v;
}
void NTT(ll *a,int type)
{
for(int i=0;i<lim;++i)
if(i<rev[i])
swap(a[i],a[rev[i]]);
for(int len=1;len<lim;len<<=1)
{
ll wn=qp(3,(p-1)/(len<<1));
for(int i=0;i<lim;i+=len<<1)
{
ll w=1;
for(int j=i;j<i+len;++j,w=w*wn%p)
{
ll x=a[j],y=w*a[j+len]%p;
a[j]=(x+y)%p,a[j+len]=(x-y+p)%p;
}
}
}
if(type==1) return;
for(int i=0;i<lim;++i) a[i]=a[i]*inv%p;
reverse(a+1,a+lim);
}
void init()
{
fac[0]=ifac[0]=f[0]=1;
for(int i=1;i<=n+2;++i) fac[i]=fac[i-1]*i%p;
ifac[n+2]=qp(fac[n+2],p-2);
for(int i=n+1;i;--i) ifac[i]=ifac[i+1]*(i+1)%p;
for(int i=1;i<=n;++i) h[i]=ifac[i+2];
while(lim<=(n<<1)) lim<<=1;
for(int i=0;i<lim;++i) rev[i]=(rev[i>>1]>>1)|((i&1)?lim>>1:0);
inv=qp(lim,p-2),NTT(h,1);
}
int main()
{
read(n),read(m),init();
while(m--)
{
for(int i=0;i<=n;++i) g[i]=f[i]*ifac[i]%p,f[i]=f[i]*(((ll)i*i+i+2)/2%p)%p;
for(int i=n+1;i<lim;++i) g[i]=0;
NTT(g,1);
for(int i=0;i<lim;++i) g[i]=g[i]*h[i]%p;
NTT(g,-1);
for(int i=1;i<=n;++i) f[i]=(f[i]+g[i]*fac[i+2]%p)%p;
}
for(int i=0;i<=n;++i) ans=(ans+f[i]*fac[n]%p*ifac[i]%p*ifac[n-i]%p)%p;
printf("%d",ans);
return 0;
}