LOJ #6059. 「2017 山东一轮集训 Day1」Sum

陈指导秒掉的题,不过确实好像挺显然的说

首先假设我们眼瞎没看见(nle 10^9),显然就是一个数位DP,设(f_{i,j,k})表示做了(i)位,(i)位的值模(p)(j),每位之和为(k)的方案数,转移枚举填哪个数即可

然后现在我们发现(n)很大,因此我们套路地选择倍增,只要考虑两种情况

  • (f_i o f_{i+1}),同上,直接枚举即可
  • (f_i o f_{2i}),稍微推一下转移方程为(f_{2i,(j+k)operatorname{mod} p,u+v}=sum_{u,v} f_{i,j,u} imes f_{i,k,v})

显然后面的那个转移形式是个卷积,用NTT维护下即可

注意一个坑点,后四个点(ple 16),但是前面的点(ple 50),因此我和陈指导都傻乎乎的RE了一发

#include<cstdio>
#include<iostream>
#define RI register int
#define CI const int&
using namespace std;
const int N=1005,mod=998244353;
int n,p,m,f[50][N],g[50][N],p10,ans;
inline int quick_pow(int x,int p=mod-2,int mul=1)
{
	for (;p;p>>=1,x=1LL*x*x%mod) if (p&1) mul=1LL*mul*x%mod; return mul;
}
namespace Poly
{
	int rev[N<<2],A[N<<2],B[N<<2],lim,p;
	inline void NTT(int* f,CI opt)
	{
		RI i,j,k; for (i=0;i<lim;++i) if (i<rev[i]) swap(f[i],f[rev[i]]);
		for (i=1;i<lim;i<<=1)
		{
			int D=quick_pow(3,opt==1?(mod-1)/(i<<1):mod-1-(mod-1)/(i<<1)),W;
			for (j=0;j<lim;j+=(i<<1)) for (W=1,k=0;k<i;++k,W=1LL*W*D%mod)
			{
				int x=f[j+k],y=1LL*f[i+j+k]*W%mod;
				f[j+k]=(x+y)%mod; f[i+j+k]=(x-y+mod)%mod;
			}	
		}
		if (!~opt)
		{
			int Inv=quick_pow(lim); for (i=0;i<lim;++i) f[i]=1LL*f[i]*Inv%mod;
		}
	}
	inline void init(CI n)
	{
		for (lim=1,p=0;lim<=(m<<1);lim<<=1,++p);
		for (RI i=0;i<lim;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<p-1);
	}
	inline void Convolution(int* a,int* b,int* c)
	{
		RI i; for (i=0;i<lim;++i) A[i]=B[i]=0;
		for (i=0;i<=m;++i) A[i]=a[i],B[i]=b[i];
		for (NTT(A,1),NTT(B,1),i=0;i<lim;++i) A[i]=1LL*A[i]*B[i]%mod;
		for (NTT(A,-1),i=0;i<=m;++i) (c[i]+=A[i])%=mod; 
	}
};
inline void solve(CI n)
{
	if (!n) return (void)(f[0][0]=p10=1); solve(n>>1);
	RI i,j,k; for (i=0;i<p;++i) for (j=0;j<=m;++j) g[i][j]=0;
	for (i=0;i<p;++i) for (j=0;j<p;++j) Poly::Convolution(f[i],f[j],g[(i*p10+j)%p]);
	for (i=0;i<p;++i) for (j=0;j<=m;++j) f[i][j]=g[i][j];
	if (n&1)
	{
		for (i=0;i<p;++i) for (j=0;j<=m;++j) g[i][j]=0;
		for (i=0;i<p;++i) for (j=0;j<=m;++j)
		for (k=0;k<=9&&j+k<=m;++k) (g[(i*10+k)%p][j+k]+=f[i][j])%=mod;
		for (i=0;i<p;++i) for (j=0;j<=m;++j) f[i][j]=g[i][j];
	}
	p10=p10*p10%p; if (n&1) (p10*=10)%=p;
}
int main()
{
	scanf("%d%d%d",&n,&p,&m); Poly::init(m); solve(n);
	for (RI i=0;i<=m;++i) (ans+=f[0][i])%=mod,printf("%d ",ans);
	return 0;
}
原文地址:https://www.cnblogs.com/cjjsb/p/13378320.html