LOJ6059「2017 山东一轮集训 Day1」Sum

https://loj.ac/p/6059

倍增 dp
考虑看到 (m,p) 很小,dp 的话转移 (n) 遍但是 (n) 很大,于是想矩阵快速幂但发现并不太行
具体的,设 (f(i,u,k)) 表示考虑到第 (i) 位,当前数字 (mod p=u),当前数位和为 (m)
考虑转移,有:(f(i+1,(10u+v)mod p,k+v)=sum_{v=0}^9 f(i,u,k)),这一遍是 (O(10pm))

不能矩乘就考虑倍增 dp,从 (f(i)) 直接转移到 (f(2i)),那么式子是:

[f(2i,(u imes 10^i+v)mod p,k)=sum_{x=0}^k f(i,u,x)f(i,v,k-x) ]

发现这样直接转移是 (O(p^2m^2)) 的,但 (sum) 里是卷积的形式,直接给 (f(i,u)) 都 ntt 一波,就变成每次转移 (O(p^2m))
倍增转移一共 (O(log n)) 次,所以总复杂度是 (O(p^2mlog n))

#define G 114514
#define mod 998244353
#define N 6006
long long power(long long a,long long b,long long o=mod){
	long long ans=1;
	while(b){
		if(b&1) ans=ans*a%o;
		a=a*a%o;b>>=1;
	}
	return ans;
}
inline void add(long long &a,long long b){a=(a+b>=mod)?(a+b-mod):(a+b);}
inline long long Mod(long long a){return a>=mod?(a-mod):a;}
int rev[N];
inline int init(int n){
	int max=1;while(max<=n) max<<=1;
	for(int i=0;i<max;i++) rev[i]=rev[i>>1]>>1,rev[i]|=(i&1)?(max>>1):0;
	return max;
}
inline void ntt(int n,long long *a,int type){
	for(int i=0;i<n;i++)if(rev[i]<i) std::swap(a[i],a[rev[i]]);
	for(int h=1;h<n;h<<=1){
		long long gn=power(G,(mod-1)/(h<<1)),g,o;
		gn=type?gn:power(gn,mod-2);
		for(int i=0;i<n;i+=h<<1){
			g=1;
			for(int j=i;j<i+h;j++,g=g*gn%mod){
				o=g*a[j+h]%mod;
				a[j+h]=Mod(a[j]-o+mod);add(a[j],o);
			}
		}
	}
	if(!type){
		long long inv=power(n,mod-2);
		for(int i=0;i<n;i++) a[i]=a[i]*inv%mod;
	}
}
long long f[55][N],g[55][N];
inline void work(int n,int p,int m){
	int len=init(m*2+2);
	int pos=30;
	while(!(n&(1<<pos))) pos--;
	f[0][0]=1;
	int i=0;
	for(;~pos;pos--){
		if(!i) goto ADD;
		for(int u=0;u<p;u++) ntt(len,f[u],1);
		for(int u=0;u<p;u++)for(int v=0;v<p;v++){
			long long pp=(u*power(10,i,p)+v)%p;
			for(int x=0;x<len;x++) add(g[pp][x],f[u][x]*f[v][x]%mod);
//			for(int x=0;x<=m;x++)for(int k=0;k<=x;k++)
//				add(g[(u*pp+v)%p][x],f[u][k]*f[v][x-k]%mod);
		}
		for(int u=0;u<p;u++){
			ntt(len,g[u],0);
			for(int x=m+1;x<len;x++) g[u][x]=0;
		}
		std::memcpy(f,g,sizeof f);std::memset(g,0,sizeof g);
		i<<=1;
ADD:	if(n&(1<<pos)){
			for(int u=0;u<p;u++)for(int k=0;k<=m;k++){
				for(int v=0;v<10&&k+v<=m;v++) add(g[(10*u+v)%p][k+v],f[u][k]);
			}
			std::memcpy(f,g,sizeof f);std::memset(g,0,sizeof g);
			i++;
		}
	}
	assert(i==n);
}
int main(){
	int n=read(),p=read(),m=read();
	work(n,p,m);
	printf("%lld ",f[0][0]);
	for(int i=1;i<=m;i++) add(f[0][i],f[0][i-1]),printf("%lld ",f[0][i]);
	return 0;
}
原文地址:https://www.cnblogs.com/suxxsfe/p/15411591.html