洛谷P3321「序列统计」

洛谷P3321「序列统计」

题目描述

(C) 有一个集合 (S),里面的元素都是小于 (m) 的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为 (n) 的数列,数列中的每个数都属于集合 (S)

(C) 用这个生成器生成了许多这样的数列。但是小 (C) 有一个问题需要你的帮助:给定整数 (x),求所有可以生成出的,且满足数列中所有数的乘积  (mod  m) 的值等于 (x) 的不同的数列的有多少个。

小C认为,两个数列 (A)(B) 不同,当且仅当 (exists i ; ext{s.t.} A_i eq B_i)。另外,小 (C) 认为这个问题的答案可能很大,因此他只需要你帮助他求出答案对 (1004535809) 取模的值就可以了。

输入格式

一行,四个整数,(n,m,x,left | S ight |),其中 (left | S ight |) 为集合 (S) 中元素个数。

第二行,(left | S ight |) 个整数,表示集合 (S) 中的所有元素。

输出格式

一行一个整数表示答案。

输入输出样例

输入

4 3 1 2
1 2

输出

8

说明/提示

【样例说明】

可以生成的满足要求的不同的数列有

((1,1,1,1),;(1,1,2,2),;(1,2,1,2),;(1,2,2,1),;(2,1,1,2),;(2,1,2,1),;(2,2,1,1),;(2,2,2,2))

【数据规模和约定】

对于 (10\%) 的数据,(1leq nleq 1000)

对于 (30\%) 的数据,(3leq mleq 100)

对于 (60\%) 的数据,(3leq mleq 800)

对于 (100\%) 的数据,(1leq nleq 10^9,3leq mleq 8000,1leq x <m)

(m) 为质数,输入数据保证集合 (S) 中元素不重复。

题解

(m) 的值域很小,考虑用Triple一题的思路,将集合中的数放到多项式上,问题就可以转化为一个式子:

[sum_{ijequiv xmod m}a_ia_j ]

这已经很像我们多项式乘法的式子了,但是唯一的不同是,这里的 (i)(j) 是相乘的

[sum_{i+jequiv xmod m}a_ia_j ]

考虑如何将乘法转化成加法?

可以利用高中数学里的对数

对数有个很好的性质:

[log^{ab}=log^{a}+log^{b} ]

不妨将 (i)(j) 都用一个数来取对数,得到 (log^i)(log^j)

在询问 (x) 地方的值的时候,相当于是询问 (log^x) 地方的值

现在问题转化为

[sum_{log^i+log^jequiv log^xmod log^m}a_ia_j ]

问题又来了,要选取哪个底数来对所有的值取对数呢?

利用原根

如果连原根都不知道是什么的小朋友,可以去百度百科初步了解一下,不会原根,你怎么学的 (NTT)

我们想把所有值取 (log) 要保证什么?

比如说我们取的底数是 (g)

我们需要 (1sim m-1)(log_g^imod log_g^m) 互不相同

(1sim m-1)(g^imod m) 互不相同

这不就是原根的第二个性质嘛

(1sim m-1)(g^i) 正好一一对应了 (1sim m - 1) 的所有值

我们就可以把原题中集合 (S) 的每个值去取对数了

然后我们就可以得到一个 (1sim m-1) 的多项式,由于没有常数项难以处理,直接变成 (0sim m-2) 的多项式来处理

对于取模,就很好处理了

两个长度为 (m-2) 的多项式相乘,对于得到的多项式的 (m-1) 次项以后,同时也对答案造成了贡献,次数模上模数加上贡献即可

我们要选 (n) 个数,而且没有像Triple一样“不能选重复的限制”,不用什么乱七八糟的容斥,所以直接 (f(x)^k) 即可

我们可以不用像多项式快速幂那么麻烦的快速幂,而且 (a_0) 不一定为 (1),也用不了

可以像普通实数快速幂一样,每次只留前 (m-2) 位,只不过效率是 (nlog^{2n})

代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

typedef long long ll;
typedef unsigned long long ull;

using namespace std;

const int maxn = 2e5 + 50, INF = 0x3f3f3f3f, mod = 1004535809, inv3 = 334845270;

inline int read () {
	register int x = 0, w = 1;
	register char ch = getchar ();
	for (; ch < '0' || ch > '9'; ch = getchar ()) if (ch == '-') w = -1;
	for (; ch >= '0' && ch <= '9'; ch = getchar ()) x = x * 10 + ch - '0';
	return x * w;
}

inline void write (register int x) {
	if (x / 10) write (x / 10);
	putchar (x % 10 + '0');
}

int n, m, g, X, s, len = 1, bit;
bool vis[maxn];
int res[maxn], tmp[maxn], ans[maxn];
int f[maxn], id[maxn], rev[maxn];

inline int gqpow (register int a, register int b, register int ans = 1) {
	for (; b; b >>= 1, a = 1ll * a * a % m) 
		if (b & 1) ans = 1ll * ans * a % m;
	return ans;
}

inline int Get_g (register int m) {
	for (register int i = 0; i < m; i ++) {
		memset (vis, 0, sizeof 4 * m);
		for (register int k = 1, tmp; k <= m - 1; k ++) {
			tmp = gqpow (i, k);
			if (vis[tmp]) goto end;
			else vis[tmp] = 1;
		}
		return i;
		end:;
	}
	return -1;
}

inline int qpow (register int a, register int b, register int ans = 1) {
	for (; b; b >>= 1, a = 1ll * a * a % mod) 
		if (b & 1) ans = 1ll * ans * a % mod;
	return ans;
}

inline void NTT (register int len, register int * a, register int opt) {
	for (register int i = 1; i < len; i ++) if (i < rev[i]) swap (a[i], a[rev[i]]);
	for (register int d = 1; d < len; d <<= 1) {
		register int w1 = qpow (opt, (mod - 1) / (d << 1));
		for (register int i = 0; i < len; i += d << 1) {
			register int w = 1;
			for (register int j = 0; j < d; j ++, w = 1ll * w * w1 % mod) {
				register int x = a[i + j], y = 1ll * w * a[i + j + d] % mod;
				a[i + j] = (x + y) % mod, a[i + j + d] = (x - y + mod) % mod;
			}
		}
	}
}

inline void Calc (register int * a, register int * b) {
	memset (res, 0, 4 * len), memset (tmp, 0, 4 * len);
	for (register int i = 0; i < m; i ++) res[i] = a[i], tmp[i] = b[i], a[i] = 0;
	NTT (len, res, 3), NTT (len, tmp, 3);
	for (register int i = 0; i < len; i ++) res[i] = 1ll * res[i] * tmp[i] % mod;
	NTT (len, res, inv3);
	register int inv = qpow (len, mod - 2);
	for (register int i = 0; i < len; i ++) res[i] = 1ll * res[i] * inv % mod, a[i % (m - 1)] = (a[i % (m - 1)] + res[i]) % mod;
}

int main () {
	n = read(), m = read(), X = read(), s = read(), g = Get_g (m);
	for (register int i = 0; i <= m - 2; i ++) id[gqpow (g, i)] = i;
	while (len < m << 1) len <<= 1, bit ++;
	for (register int i = 1; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit - 1);
	for (register int i = 1; i <= s; i ++) {
		register int x = read();
		if (x) f[id[x]] ++;
	}
	ans[0] = 1;
	while (n) {
		if (n & 1) Calc (ans, f);
		n >>= 1, Calc (f, f);
	}
	printf ("%d
", ans[id[X]]);
	return 0;
}
原文地址:https://www.cnblogs.com/Rubyonly233/p/14217953.html