[洛谷P3321][SDOI2015]序列统计

题目大意:给定正整数$n,m,x$和集合$S$($0\leqslant S_i<m$,$m\leqslant 8000$且$m$为质数,$1\leqslant x<m$,$n\leqslant10^9$),求满足$Q_i\in S$且$\prod\limits_{i=1}^nQ_i\equiv x\pmod m$的长度为$n$的序列$Q$的个数,答案对$1004535809$取模。

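题解:乘法不好直接处理,可以取离散对数把乘法转化为加法。先暴力$O(m^2)$求出$m$的原根$g$,则每个$a\in[1,m)$都可以唯一地写成$a\equiv g^{\operatorname{ind}(a)}\pmod m$,其中$0\leqslant\operatorname{ind}(a)<m-1$,于是$\prod\limits_{i=1}^nQ_i\equiv x\pmod m$等价于$\sum\limits_{i=1}^n\operatorname{ind}(Q_i)\equiv\operatorname{ind}(x)\pmod{m-1}$($0$没有离散对数,又因为$x\neq0$且$m$为质数,$0$不可能出现在合法序列中,直接忽略)。设生成函数$F(y)=\sum\limits_{s\in S,s\neq0}y^{\operatorname{ind}(s)}$,答案就是$F^n(y)\bmod(y^{m-1}-1)$中$y^{\operatorname{ind}(x)}$的系数。$n$较大,用多项式快速幂即可:每次乘法用 NTT 计算,再把次数$\geqslant m-1$的项折回到次数模$m-1$的位置,总复杂度$O(m\log m\log n)$。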

卡点:

C++ Code:

#include <cstdio>
#include <algorithm>
#include <cstring>
// Array capacity for polynomials / NTT tables. The largest transform length used
// is the smallest power of two >= 2 * m, i.e. 16384 for m <= 8000.
// Parenthesized so that expressions such as `maxn + 1` or `2 * maxn` expand as
// intended (`|` binds more loosely than `+` and `*`).
#define maxn (16384 | 3)
// Capacity of the discrete-log table `vis`, indexed by residues in [0, m).
#define maxm 8010
// NTT modulus (also the answer modulus) and the generator used to build roots of unity.
const int mod = 1004535809, G = 3;
// Problem inputs n, m, x, |S|; g receives the primitive root found by get().
int n, m, x, S, g;
// Discrete-log table filled by get(): vis[r] = j when r == g^j (mod m), -1 otherwise.
int vis[maxm];
// Brute-force search (O(m^2) worst case) for the smallest candidate in [2, m)
// whose powers cand^1, ..., cand^(m-2) mod m are pairwise distinct and never
// come back to 1. When m is prime (assumed), such a candidate is a primitive root.
//
// Side effect relied on by main(): vis[] holds the discrete-log table of the
// last candidate tried. For the returned root r, vis[r^j mod m] = j for
// 0 <= j < m - 1, and every other entry is -1.
//
// Returns 20040826 if no candidate qualifies; for m <= 2 the loop never runs
// and vis[] is left untouched.
int get(int m) {
	for (int cand = 2; cand < m; ++cand) {
		memset(vis, -1, sizeof vis);
		vis[1] = 0;
		int power = 1;
		int exponent = 1;
		for (; exponent < m - 1; ++exponent) {
			power = power * cand % m;
			// A repeated residue (1 included) means the order of cand is below m - 1.
			if (vis[power] != -1) break;
			vis[power] = exponent;
		}
		if (exponent == m - 1) return cand;
	}
	return 20040826;
}
// NTT state, filled in by init(): transform length, lim^-1 modulo `mod`,
// and s = log2(lim) - 1 (the shift used to build rev).
int lim, ilim, s;
// Bit-reversal permutation of [0, lim).
int rev[maxn];
// Polynomials indexed by discrete log: base marks the (nonzero) elements of S,
// ans accumulates base^n.
int base[maxn], ans[maxn];
// Wn[i] = w^i for the primitive lim-th root of unity w chosen in init(), 0 <= i <= lim.
int Wn[maxn + 1];
// Returns base^p modulo `mod` by binary exponentiation (p is expected to be >= 0).
inline int pw(int base, int p) {
	long long result = 1;
	long long square = base;
	while (p) {
		if (p & 1) result = result * square % mod;
		square = square * square % mod;
		p >>= 1;
	}
	return static_cast<int>(result);
}
// Modular inverse of x modulo the prime `mod` via Fermat's little theorem
// (x^(p-2) == x^-1 mod p). Assumes x is not a multiple of mod.
inline int Inv(int x) {
	const int fermatExponent = mod - 2;
	return pw(x, fermatExponent);
}
// Prepares the NTT tables for transforms of length lim, the smallest power of
// two >= n. It sets:
//   - lim, s = log2(lim) - 1, and ilim = lim^-1 mod `mod`;
//   - the bit-reversal permutation rev[0..lim);
//   - the root powers Wn[0..lim].
// Assumes lim fits the rev/Wn arrays (lim <= 16384) and divides mod - 1.
inline void init(int n) {
	lim = 1;
	s = -1;
	while (lim < n) {
		lim <<= 1;
		++s;
	}
	ilim = Inv(lim);
	// rev[0] is always 0. Starting the recurrence at i = 1 means `(i & 1) << s`
	// is never evaluated when lim == 1 (s == -1); a negative shift count is
	// undefined behavior.
	rev[0] = 0;
	for (int i = 1; i < lim; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
	// Wn[i] = t^i, where t = G^((mod - 1) / lim) is a primitive lim-th root of unity.
	const int t = pw(G, (mod - 1) / lim);
	Wn[0] = 1;
	for (int i = 1; i <= lim; i++) Wn[i] = 1ll * Wn[i - 1] * t % mod;
}
// In-place modular addition a <- (a + b) mod `mod`. Assumes a + b < 2 * mod,
// so a single conditional subtraction suffices.
inline void up(int &a, int b) {
	a += b;
	if (a >= mod) a -= mod;
}
inline void NTT(int *A, int op) {
	for (int i = 0; i < lim; i++) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
	for (int mid = 1; mid < lim; mid <<= 1) {
		int t = lim / mid >> 1;
		for (int i = 0; i < lim; i += mid << 1) {
			for (int j = 0; j < mid; j++) {
				int W = op ? Wn[j * t] : Wn[lim - j * t];
				int X = A[i + j], Y = 1ll * A[i + j + mid] * W % mod;
				up(A[i + j], Y), up(A[i + j + mid] = X, mod - Y);
			}
		}
	}
	if (!op) for (int i = 0; i < lim; i++) A[i] = 1ll * A[i] * ilim % mod;
}
// Scratch buffers for MUL; copying into them lets MUL be called with A == B.
int C[maxn], D[maxn];
// A <- A * B modulo (y^(m-1) - 1), with coefficients modulo `mod`.
// Both operands are read over [0, lim). The product is an NTT-based
// convolution; every coefficient at index >= m - 1 is then folded onto
// index % (m - 1), because discrete-log exponents live in Z/(m-1).
// Assumes init() chose lim large enough that the linear product does not
// wrap around inside the cyclic transform.
inline void MUL(int *A, int *B) {
	std::copy(A, A + lim, C);
	std::copy(B, B + lim, D);
	NTT(C, 1);
	NTT(D, 1);
	for (int k = 0; k < lim; ++k) C[k] = 1ll * C[k] * D[k] % mod;
	NTT(C, 0);
	std::copy(C, C + lim, A);
	const int period = m - 1;
	for (int k = period; k < lim; ++k) {
		A[k % period] = (A[k % period] + A[k]) % mod;
		A[k] = 0;
	}
}
// Reads n, m, x, |S| and the elements of S. Prints the number of length-n
// sequences over S whose product is x (mod m), modulo `mod`.
int main() {
	if (scanf("%d%d%d%d", &n, &m, &x, &S) != 4) return 1;
	// get() fills vis[] with discrete logs base the primitive root of m;
	// only that table is used afterwards.
	g = get(m);
	for (int i = 0, tmp; i < S; i++) {
		if (scanf("%d", &tmp) != 1) return 1;
		// 0 has no discrete log. Assuming 1 <= x < m, a product containing 0
		// can never equal x, so 0 is simply skipped.
		if (tmp) base[vis[tmp]] = 1;
	}
	// Transform length >= 2m so the product of two degree < m-1 polynomials
	// does not wrap inside the cyclic NTT.
	init(m << 1);
	ans[0] = 1;
	// ans <- base^n by binary exponentiation on polynomials.
	while (n) {
		if (n & 1) MUL(ans, base);
		n >>= 1;
		// Skip the squaring after the last bit; its result would never be used.
		if (n) MUL(base, base);
	}
	// NOTE(review): assumes 1 <= x < m so that vis[x] is a valid index.
	printf("%d\n", ans[vis[x]]);
	return 0;
}

  

原文地址:https://www.cnblogs.com/Memory-of-winter/p/9719203.html