luogu P5577

简要题意:给一个序列,对每个 (i)(k) 进制意义下不进位加法和为 (i) 的方案数。

显然可以暴力多维FFT。弱化一点的版本是异或,即(k=2)。(参考UNR#2黎明前的巧克力


考虑怎么优化。考虑 (1+x^a_i) 对应的多项式,高维FFT后可以发现每一位上的值形如 (w_k^i+1)

(k) 很小,考虑对每一位求出(w_k^i+1)有多少个。

可以先把 (1) 去掉,考虑每一个位置上的多项式是什么,可以发现答案我们先将所有多项式加起来,然后FFT之后,扩域意义下的 (w_k^i) 系数就是最后的 (w_k^i+1) 在这一位上有多少个。

这个 FFT 相当于做了一次 dp,状态是 FFT 操作进行到现在的时候每一位有多少个多项式会变成 (w_k^i)

然后再 IDFT 回来就行了。

需要轻微卡常。

代码

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 998244353;
inline int add(int a,int b){a+=b;return a>=mod?a-mod:a;}
inline int sub(int a,int b){a-=b;return a<0?a+mod:a;}
inline int mul(int a,int b){return 1ll*a*b%mod;}
inline int qpow(int a,int b){int ret=1;for(;b;b>>=1,a=mul(a,a))if(b&1)ret=mul(ret,a);return ret;}
/* math */
const int N = 1e6+5;
int k,n,invk;

struct dat{
	int a[6];
	dat(int a_1=0,int a_2=0,int a_3=0,int a_4=0,int a_5=0,int a_6=0) {
		a[0]=a_1,a[1]=a_2,a[2]=a_3,a[3]=a_4,a[4]=a_5,a[5]=a_6;
	}
	dat operator + (const dat b){
		return dat(a[0]+b.a[0],a[1]+b.a[1],a[2]+b.a[2],a[3]+b.a[3],a[4]+b.a[4],a[5]+b.a[5]);
	}
	int val(){
		return sub(add(a[0],a[1]),add(a[2],a[3]));
	}
	void print(){
		cout << "(";for(int i=0;i<k;i++)cout << a[i] << ",";cout << ")";
	}
};

dat add(dat a,dat b){
	dat ret=a+b;for(int i=0;i<k;i++)ret.a[i]=ret.a[i]>=mod?ret.a[i]-mod:ret.a[i];
	return ret;
}
dat mul(dat x,int d){
	dat ret;
	for(int i=0;i<k;i++)ret.a[(i+d)%k]=x.a[i];
	return ret;
}
dat div(dat x,int inver){
	for(int i=0;i<k;i++)x.a[i]=mul(x.a[i],inver);
	return x;
}
dat ret[10];
void fft(dat *a,int flg){
	for(int i=0;i<k;i++){
		ret[i]=dat();
		for(int j=0;j<k;j++){
			ret[i]=add(ret[i],mul(a[j],(i*j)%k));
		}
	}
	for(int i=0;i<k;i++)a[i]=ret[i];
	if(flg==1)return;
	for(int i=1;i<k-i;i++)swap(a[i],a[k-i]);
	for(int i=0;i<k;i++)a[i]=div(a[i],invk);
}
dat a[10];
inline void Transfer(dat *f,int len,int flg){
	for(int Step=1;Step<len;Step*=k){
		int D = Step*k;
		for(int i=0;i<len;i+=D){
			for(int j=0;j<Step;j++){
				for(int t=0;t<k;t++)a[t]=f[i+j+t*Step];
				fft(a,flg);
				for(int t=0;t<k;t++)f[i+j+t*Step]=a[t];
			}
		}
	}
}
inline dat mul(dat x,dat y){
	dat ret;
	for(int i=0;i<k;i++)for(int j=0;j<k;j++){
		ret.a[(i+j)%k]=add(ret.a[(i+j)%k],mul(x.a[i],y.a[j]));
	}
	return ret;
}
dat f[100010];
dat g[100010];
int len,m;

dat pw[1010],pw2[1010];
int readk(){
	char x=0;int ans=0;while(x<'0'||x>'9')x=getchar();
	while(x>='0'&&x<='9')ans=ans*k+x-'0',x=getchar();
	return ans;
}

inline dat qpow(dat x,int k){
	dat ret=dat(1);
	for(;k;k>>=1,x=mul(x,x))if(k&1)ret=mul(ret,x);
	return ret;
}

int main()
{
	cin >> n >> k >> m;
	int len=1;
	for(int i=1;i<=m;i++)len=len*k;
	invk=qpow(k,mod-2);
	for(int i=1;i<=n;i++){
		int a=readk();
		f[a].a[0]++;
	}
	Transfer(f,len,1);
	for(int i=0;i<len;i++)g[i]=dat(1);
	for(int j=0;j<k;j++){
		dat ml=dat(1);
		ml.a[j]++;
		pw[0]=pw2[0]=dat(1);
		dat ml2=qpow(ml,1000);
		for(int i=1;i<=1000;i++)pw[i]=mul(pw[i-1],ml2);
		for(int i=1;i<=1000;i++)pw2[i]=mul(pw2[i-1],ml);
		for(int i=0;i<len;i++){
			g[i]=mul(g[i],pw[f[i].a[j]/1000]);
			g[i]=mul(g[i],pw2[f[i].a[j]%1000]);
		}
	}
	for(int i=0;i<len;i++)f[i]=g[i];
	Transfer(f,len,-1);
	for(int i=0;i<len;i++)printf("%d
",f[i].val());
}
原文地址:https://www.cnblogs.com/weiyanpeng/p/11845387.html