LG5577 算力训练 k进制FWT

题意说人话就是给出一个长度为(n)的数列(a_1,a_2,...,a_n),求(prodlimits_{i=1}^n (1+x^{a_i})),其中卷积的下标加法定义为(k)进制不进位加法。


(k)进制不进位加法不难想到(k)进制FWT,所以我们需要快速求出(prodlimits_{i=1}^n mathrm{DWT}_k(1 + x^{a_i})),这里的乘法是点积,最后IDWT回来即可。

因为要用形式幂级数做到乘单位根,所以我们获得了一个(O(nk^mmk^2))的优秀算法,然而并没有什么×用。


考虑优化。可以发现任意一个(mathrm{DWT}_k(1 + x^{a_i}))中所有数构成的集合都是集合(W = {w_k^i+1|i in [0,k-1] cap Z})的子集,而(|W|=k)并不大,所以我们可以考虑设

[x_{i,j} = sumlimits_{p=1}^n [[x^i]mathrm{DWT}_k(1+x^{a_p}) = w_k^j + 1] ]

那么这(n)个幂级数的卷积的第(i)项的值就是

[prodlimits_{p=0}^{k-1} (w_k^p+1)^{x_{i,p}} ]

可以快速计算。


接下来考虑对于一个确定的(i)如何求出(x_{i,j})的值。毫无疑问需要解方程。

首先我们显然有一个式子是

[x_{i,0} + x_{i,1} + x_{i,2}+...+x_{i,k-1}=n ]

设幂级数(A)满足([x^j]A = sumlimits_{i=1}^n [a_i = j]),考虑([x^i] mathrm{DWT}_k(A))。对于一个(p)和一个满足([x^i]mathrm{DWT}_k(1+x^{a_j}) = w_k^p + 1)(j),可以发现其对([x^i]mathrm{DWT}_k(A))的贡献是(w_k^p)。所以我们有

[w_k^0x_{i,0}+w_k^1x_{i,1}+...+w_k^{k-1}x_{i,k-1} = [x^i]mathrm{DWT}_k(A) ]


观察两个方程的系数向量:

[1,1,1,1,...,1 ]

[w_k^0,w_k^1,w_k^2,...,w_k^{k-1} ]

有些范德蒙德矩阵的Feeling。那么我们能不能求出

[w_k^0x_{i,0}+w_k^2x_{i,1}+...+w_k^{2(k-1)}x_{i,k-1} ]

也就是原来对某个位置贡献为(w_k^p)的数组,在变换之后它的贡献变为(w^{2p}_k)

可以发现这相当于将FWT过程中的单位根平方一下。所以我们只需要把所有(w_k)都变为(w_k^2)就可以了。具体来说,以(k=5)为例,(k)进制FWT使用下面的位矩阵:

[left( egin{array}{cccc} 1 & 1 & 1 & 1 & 1 \ 1 & w_5^1 & w_5^2 & w_5^3 & w_5^4 \ 1 & w_5^2 & w_5^4 & w_5^1 & w_5^3 \ 1 & w_5^3 & w_5^1 & w_5^4 & w_5^2 \ 1 & w_5^4 & w_5^3 &w_5^2 & w_5^1 end{array} ight) ]

现在把单位根平方,位矩阵就变成下面这样:

[left( egin{array}{cccc} 1 & 1 & 1 & 1 & 1 \ 1 & w_5^2 & w_5^4 & w_5^1 & w_5^3 \ 1 & w_5^4 & w_5^3 & w_5^2 & w_5^1 \ 1 & w_5^1 & w_5^2 & w_5^3 & w_5^4 \ 1 & w_5^3 & w_5^1 &w_5^4 & w_5^2 end{array} ight) ]

使用这一个位矩阵进行DWT,则([x^i]mathrm{DWT}_k(A)=w_k^0x_{i,0}+w_k^2x_{i,1}+...+w_k^{2(k-1)}x_{i,k-1})

值得注意的是我们只是求值,所以这个矩阵就算不是合法的位矩阵也可以这么做(毕竟你不需要进行逆操作)。

照葫芦画瓢地可以得到(k)个方程,其系数矩阵是DWT使用的范德蒙德矩阵。所以只要对于得到的所有结果IDWT一下就可以得到所有(x_{i,j})的值。


朴素实现复杂度大概是(O(k^{m+4}m))的((k)次FWT,每一次(k^mmk)次乘法,乘法复杂度(k^2))比较慢。下面是一些实现细节:

  1. 可以发现将单位根乘方之后对幂级数进行FWT得到的每一位的值构成的集合一定是对其进行(k)进制FWT得到的值的集合的子集。可以实现一个函数计算对单位根进行乘方后FWT得到的每一位的值分别对应进行(k)进制FWT后哪一位的值。这样可以把一个(k)摘掉,复杂度变为(O(k^{m+3}m))。这一部分实现可以参考代码中的getid函数。

  2. 单位根在模意义下不存在所以要扩域,即将一个数表示为(a_0w^0+a_1w^1+...+a_{k-1}w^{k-1}),可以发现它是封闭的。如果直接这样做有一个非常大的好处是所有单位根都只有一个位置有值,可以做到(O(k))乘单位根。有一个bug是最后的答案并不是(a_0)所以并没有这样实现。

  3. 延续点2中的问题,我们最后的答案不是(a_0)的原因是有一些单位根它们的和为(0),比如说(sumlimits_{i=0}^{k-1} w_k^{k-1}=0),或者在(2 mid k)(w_k^i = w_k^{i+frac{k}{2}}),这意味着一个整数在这个域上的表示不是唯一的。这就是为什么我写成了二合一:

  • 对于(k=5)的情况将(w^4 = -(w^0+w^1+w^2+w^3))代入,将一个数表示为(a_0w^0+a_1w^1+a_2w^2+a_3w^3)的形式进行求解。
  • 对于(k=6)的情况先使用(w_1=-w_4,w_3=-w_0,w_5=-w_2),这样就只剩下(w_0,w_2,w_4),然后代入(w_4=-w_0-w_2),这样我们可以只用将数表示为(a_0w_0+a_1w_2)的形式就可以求解了。

以这样的形式求解最后得到的(a_0)就是答案。

然而我很想知道为什么这么消了之后一个整数就一定能被唯一表示……


code:

#include<bits/stdc++.h>
using namespace std;

int read(){
	int a = 0; char c = getchar(); while(!isdigit(c)) c = getchar();
	while(isdigit(c)){a = a * 10 + c - 48; c = getchar();} return a;
}

const int MOD = 998244353;
int upd(int x){return x + (x >> 31 & MOD);}
int add(int x , int y){return upd(x + y - MOD);}
int sub(int x , int y){return upd(x - y);}
int mul(int x , int y){return x <= 1 || y <= 1 ? x * y : 1ll * x * y % MOD;}

int N , K , M , arr[1000003] , pwK[10]; long long IVK;
int id[100003];
void getid(int L , int pw){
	if(L == 1) return (void)(id[0] = 0);
	int p = L / K; getid(p , pw);
	for(int i = 0 ; i < p ; ++i){
		int t = id[i];
		for(int j = 0 ; j < K ; ++j , t = (t + p * pw) % L)
			id[i + j * p] = t;
	}
}

template < typename op >
void FWT(op *now , op *tmp , op *w , int tp){
	for(int i = 0 ; i < M ; ++i){
		for(int j = 0 ; j < pwK[M] ; j += pwK[i + 1])
			for(int k = 0 ; k < pwK[i] ; ++k){
				for(int l = 0 ; l < K ; ++l) tmp[j + k + pwK[i] * l] = op();
				for(int l = 0 ; l < K ; ++l)
					for(int p = 0 ; p < K ; ++p)
						tmp[j + k + pwK[i] * l] = tmp[j + k + pwK[i] * l] +
							(!l || !p ? now[j + k + pwK[i] * p] : now[j + k + pwK[i] * p] * w[(tp == 1 ? l : K - l) * p % K]);
			}
		for(int j = 0 ; j < pwK[M] ; ++j) now[j] = tmp[j];
	}
}

namespace solve1{
	struct op{
		int arr[4];
		op(int _a = 0 , int _b = 0 , int _c = 0 , int _d = 0){arr[0] = _a; arr[1] = _b; arr[2] = _c; arr[3] = _d;}
		int& operator [](int x){return arr[x];}
		friend op operator +(op x , op y){op t; for(int i = 0 ; i < 4 ; ++i) t[i] = add(x[i] , y[i]); return t;}
		friend op operator -(op x , op y){op t; for(int i = 0 ; i < 4 ; ++i) t[i] = sub(x[i] , y[i]); return t;}
		friend op operator *(op x , op y){
			op t; int sum4 = 0;
			for(int j = 0 ; j < 4 ; ++j)
				if(y[j])
					for(int i = 0 ; i < 4 ; ++i){
						int id = i + j >= 5 ? i + j - 5 : i + j;
						(id == 4 ? sum4 : t[id]) = add(id == 4 ? sum4 : t[id] , mul(x[i] , y[j]));
					}
			if(sum4) for(int i = 0 ; i < 4 ; ++i) t[i] = sub(t[i] , sum4);
			return t;
		}
	}now[100003] , tmp[100003] , val[5][100003] , w[5]{op(1),op(0,1),op(0,0,1),op(0,0,0,1),op(MOD-1,MOD-1,MOD-1,MOD-1)} , pww1[5][1003] , pww2[5][1003];

	void init_pww(){
		for(int i = 0 ; i < K ; ++i){
			pww1[i][0] = pww2[i][0] = op(1); pww1[i][1] = w[i] + op(1);
			for(int j = 2 ; j <= 1000 ; ++j) pww1[i][j] = pww1[i][j - 1] * pww1[i][1];
			pww2[i][1] = pww1[i][1000];
			for(int j = 2 ; j <= 1000 ; ++j) pww2[i][j] = pww2[i][j - 1] * pww2[i][1];
		}
	}

	op getpw(int id , int val){return pww2[id][val / 1000] * pww1[id][val % 1000];}
	
	void work(){
		init_pww();
		for(int i = 1 ; i <= N ; ++i){
			int tmp = arr[i] , t = 0 , tms = 1; while(tmp){t += tmp % 10 * tms; tms *= K; tmp /= 10;} ++now[t][0];
		}
		FWT(now , tmp , w , 1); for(int i = 0 ; i < pwK[M] ; ++i) val[0][i][0] = N;
		for(int i = 1 ; i < K ; ++i){getid(pwK[M] , i); for(int j = 0 ; j < pwK[M] ; ++j) val[i][j] = now[id[j]];}
		for(int i = 0 ; i < pwK[M] ; ++i){
			for(int j = 0 ; j < K ; ++j) tmp[j] = op(0 , 0);
			for(int j = 0 ; j < K ; ++j) for(int l = 0 ; l < K ; ++l) tmp[j] = tmp[j] + val[l][i] * w[l * (K - j) % K];
			now[i] = op(1); for(int j = 0 ; j < K ; ++j) now[i] = now[i] * getpw(j , 1ll * tmp[j][0] * IVK % MOD);
		}
		FWT(now , tmp , w , -1); int iv = 1; for(int i = 0 ; i < M ; ++i) iv = 1ll * iv * IVK % MOD;
		for(int i = 0 ; i < pwK[M] ; ++i) printf("%lld
" , 1ll * now[i][0] * iv % MOD);
	}
}

namespace solve2{
	struct op{
		int x , y; op(int _x = 0 , int _y = 0) : x(_x) , y(_y){}
		friend op operator +(op x , op y){return op(add(x.x , y.x) , add(x.y , y.y));}
		friend op operator -(op x , op y){return op(sub(x.x , y.x) , sub(x.y , y.y));}
		friend op operator *(op x , op y){int t = mul(x.y , y.y); return op(sub(mul(x.x , y.x) , t) , sub(add(mul(x.y , y.x) , mul(x.x , y.y)) , t));}
	}now[100003] , tmp[100003] , val[6][100003] , w[6]{op(1),op(1,1),op(0,1),op(MOD-1),op(MOD-1,MOD-1),op(0,MOD-1)} , pww1[6][1003] , pww2[6][1003];

	void init_pww(){
		for(int i = 0 ; i < K ; ++i){
			pww1[i][0] = pww2[i][0] = op(1); pww1[i][1] = w[i] + op(1);
			for(int j = 2 ; j <= 1000 ; ++j) pww1[i][j] = pww1[i][j - 1] * pww1[i][1];
			pww2[i][1] = pww1[i][1000];
			for(int j = 2 ; j <= 1000 ; ++j) pww2[i][j] = pww2[i][j - 1] * pww2[i][1];
		}
	}

	op getpw(int id , int val){return pww2[id][val / 1000] * pww1[id][val % 1000];}
	
	void work(){
		init_pww();
		for(int i = 1 ; i <= N ; ++i){
			int tmp = arr[i] , t = 0 , tms = 1;
			while(tmp){t += tmp % 10 * tms; tms *= K; tmp /= 10;}
			++now[t].x;
		}
		FWT(now , tmp , w , 1); for(int i = 0 ; i < pwK[M] ; ++i) val[0][i].x = N;
		for(int i = 1 ; i < K ; ++i){getid(pwK[M] , i); for(int j = 0 ; j < pwK[M] ; ++j) val[i][j] = now[id[j]];}
		for(int i = 0 ; i < pwK[M] ; ++i){
			for(int j = 0 ; j < K ; ++j) tmp[j] = op(0 , 0);
			for(int j = 0 ; j < K ; ++j) for(int l = 0 ; l < K ; ++l) tmp[j] = tmp[j] + val[l][i] * w[l * (K - j) % K];
			now[i] = op(1 , 0); for(int j = 0 ; j < K ; ++j) now[i] = now[i] * getpw(j , 1ll * tmp[j].x * IVK % MOD);
		}
		FWT(now , tmp , w , -1); int iv = 1; for(int i = 0 ; i < M ; ++i) iv = 1ll * iv * IVK % MOD;
		for(int i = 0 ; i < pwK[M] ; ++i) printf("%lld
" , 1ll * now[i].x * iv % MOD);
	}
}

int main(){
	N = read(); K = read(); M = read(); for(int i = 1 ; i <= N ; ++i) arr[i] = read();
	IVK = MOD + 1; while(IVK % K) IVK += MOD;
	IVK /= K; pwK[0] = 1; for(int i = 1 ; i <= M ; ++i) pwK[i] = pwK[i - 1] * K;
	if(K == 6) solve2::work(); else solve1::work();
	return 0;
}
原文地址:https://www.cnblogs.com/Itst/p/12221839.html