【BZOJ 3992】【SDOI 2015】序列统计

http://www.lydsy.com/JudgeOnline/problem.php?id=3992
这道题好难啊。
第一眼谁都能看出来是个 dp,设 $f(i,j)$ 表示考虑到第 $i$ 位时,前 $i$ 位的乘积模 $m$ 等于 $j$ 的方案数。
转移很显然啊:$f(i,j)=\sum_{x,y\in[0,m)}[xy \bmod m=j]\,f(i-1,x)\cdot f(1,y)$,其中 $f(1,y)=[y\in S]$。
这个下标是乘积取模的转移根本无法优化啊。
但注意到题目最下方说m是一个质数。。。
把 $x=0$ 特判掉,剩下 $x\in[1,m)$ 时把 $x$ 转化为 $m$ 的原根的幂次。
设 $m$ 的原根为 $g_m$。
那么 $f(i,g_m^j)=\sum_{x,y\in[0,m-1)}[(x+y)\bmod (m-1)=j]\,f(i-1,g_m^x)\cdot f(1,g_m^y)$(由费马小定理 $g_m^{m-1}\equiv 1\pmod m$,指数要模 $m-1$)。
这样通过原根把指数 $[0,m-1)$ 与 $[1,m)$ 上的数不重不漏地一一对应,乘积取模就变成了指数相加模 $m-1$,化成了一个循环卷积的形式。
(话说看模数也知道是NTT啊qwq)循环卷积直接用NTT做就可以了。
但要做 $N$ 次循环卷积($N\le 10^9$)。。。
在外面套层快速幂就可以了O(∩_∩)O~~
快速幂套循环卷积的正确性?先不循环卷积然后再压成循环卷积就很好证明啊。不过也可以把快速幂看成一个倍增,每次合并两个dp数组之类的,正确性都显然啊qwq
注意数组不要开小!用于NTT的数组要开到2的幂次qwq
时间复杂度 $O(m^2+m\log m\log N)$。
(看了Menci大大的博客,“把原根的幂次看成多项式的幂次,dp数组记录在系数里”这个东西还叫生成函数?)

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;

// Capacity bound for value-indexed arrays; the NTT buffers in this file are
// sized M << 1 so they can hold a zero-padded linear product.
constexpr int M = 8193;
// Seed root of unity stored into WN[14] by init(); presumably a primitive
// 2^14-th root of unity modulo p — NOTE(review): confirm.
constexpr int g = 267198;
// NTT-friendly prime 479 * 2^21 + 1; every count is reported modulo p.
constexpr int p = 1004535809;

// Modular inverse of the current transform length n (filled in by init()).
int nin = 0;

// Square-and-multiply modular exponentiation: returns a^b modulo p.
// The exponent is consumed from its least significant bit upwards; b == 0 yields 1.
int ipow(int a, int b) {
	int result = 1;
	for (int base = a; b != 0; b >>= 1) {
		if (b & 1) result = 1ll * result * base % p;
		base = 1ll * base * base % p;
	}
	return result;
}

// Transform-domain scratch buffers used by NTT()/NTTsqr() and the pointwise product.
int da[M << 1], db[M << 1], dc[M << 1];
// Bit-reversal permutation for the current transform length (built by init()).
int rev[M << 1];
// Per-level roots: DNT uses WN[k] (forward) or nWN[k] (inverse) for butterfly span 2^k.
int nWN[15], WN[15];
// n: transform length; m: the prime modulus; x: target residue;
// s: number of elements in S; C: exponent of the repeated convolution (sequence length).
int n, m, x, s, C;

void DNT(int *a, int *A, int flag) {
	int tmp = 1;
	for (int i = 0; i < n; ++i) A[rev[i]] = a[i];
	for (int len = 2; len <= n; len <<= 1, ++tmp) {
		int mid = len >> 1, wn = flag == 1 ? WN[tmp] : nWN[tmp];
		for (int i = 0; i < n; i += len) {
			int w = 1;
			for (int j = 0; j < mid; ++j) {
				int t = A[i + j], u = 1ll * A[i + j + mid] * w % p;
				A[i + j] = (t + u) % p;
				A[i + j + mid] = (t - u + p) % p;
				w = 1ll * w * wn % p;
			}
		}
	}
	
	if (flag == -1)
		for (int i = 0; i < n; ++i)
			A[i] = 1ll * A[i] * nin % p;
}

// Cyclic length m - 1 (set in main): exponents of the primitive root lie in [0, top).
int top = 0;

// Squares a in place as a cyclic convolution with period top:
// forward transform into da, square pointwise, inverse transform back into a,
// then fold the tail [top, 2*top) onto [0, top) and clear the tail.
// Assumes a[top..n) is zero on entry so the linear square fits in n slots — TODO confirm at callers.
void NTTsqr(int *a) {
	DNT(a, da, 1);
	for (int *it = da, *end = da + n; it != end; ++it)
		*it = 1ll * *it * *it % p;
	DNT(da, a, -1);
	for (int lo = 0, hi = top; lo < top; ++lo, ++hi) {
		a[lo] = (a[lo] + a[hi]) % p;
		a[hi] = 0;
	}
}

// Replaces a with the cyclic convolution (period top) of a and b; b is only read.
// Both inputs are transformed, multiplied pointwise into dc, inverse-transformed
// back into a, and the linear product's tail [top, 2*top) is wrapped onto [0, top).
void NTT(int *a, int *b) {
	DNT(a, da, 1);
	DNT(b, db, 1);
	for (int k = 0; k < n; ++k) dc[k] = 1ll * da[k] * db[k] % p;
	DNT(dc, a, -1);
	for (int k = top; k < 2 * top; ++k) {
		const int j = k - top;
		a[j] = (a[j] + a[k]) % p;
		a[k] = 0;
	}
}

void init() {
	int tot = 0, num = top << 1;
	while (num) {num >>= 1; ++tot;}
	n = 1 << tot;
	nin = ipow(n, p - 2);
	
	int res;
	for (int i = 0; i < n; ++i) {
		num = i; res = 0;
		for (int j = 0; j < tot; ++j) {
			res <<= 1;
			if (num & 1) res |= 1;
			num >>= 1;
		}
		rev[i] = res;
	}
	
	WN[14] = g, nWN[14] = ipow(g , p - 2);
	for (int i = 13; i >= 1; --i) {
		WN[i] = 1ll * WN[i + 1] * WN[i + 1] % p;
		nWN[i] = 1ll * nWN[i + 1] * nWN[i + 1] % p;
	}
}

// shown[v]: marks powers of the current candidate during main()'s primitive-root search.
bool shown[M];
// r: accumulated result polynomial; ww: the base polynomial being repeatedly squared.
int r[M << 1], ww[M << 1];
// c[v]: discrete logarithm of v with respect to the primitive root found in main().
int c[M];

int main() {
	scanf("%d%d%d%d", &C, &m, &x, &s); top = m - 1;
	if (x == 0) {printf("%d
", (ipow(m, n) - ipow(m - 1, n) + p) % p); return 0;}
	int num;
	for (int i = 2; i < m; ++i) {
		int ret = 1; bool flag = true;
		for (int j = 0; j < top; ++j) {
			ret = 1ll * ret * i % m;
			if (shown[ret]) {flag = false; break;}
			shown[ret] = true;
		}
		
		if (!flag || ret != 1) {
			ret = 1;
			for (int j = 0; j < top; ++j) {
				ret = 1ll * ret * i % m;
				if (shown[ret]) shown[ret] = false;
				else break;
			}
		} else {
			num = i;
			break;
		}
	}
	
	int ret = 1;
	for (int i = 0; i < top; ++i) {
		c[ret] = i;
		ret = 1ll * ret * num % m;
	}
	
	init();
	
	int tt;
	for (int i = 1; i <= s; ++i) {
		scanf("%d", &tt);
		if (tt != 0) ww[c[tt]] = 1;
	}
	
	r[0] = 1;
	while (C) {
		if (C & 1) NTT(r, ww);
		NTTsqr(ww);
		C >>= 1;
	}
	
	printf("%d
", r[c[x]]);
	
	return 0;
}
原文地址:https://www.cnblogs.com/abclzr/p/6413960.html