luogu P3321 [SDOI2015]序列统计 FFT

首先数相同,位置不同的算作不同的方案,每多出一个位置就能多转移一次,所以我们可以写出这样的转移。

(displaystyle C[k]=sum_{i imes j \%m==k}A[i] imes B[j])

我们平时写的FFT/NTT都是加号,这里是乘号,想要把乘号变成加号就要取(log),又因为是在mod m的意义下,m又是一个质数,因为对于(m)的原根(g),(1le ile m-1),(g^i)两两互不相同,所以我们可以找到m的原根,将原根作为底数,就能做多项式乘法了。

因为(n)比较大,所以要用多项式快速幂,当时我比较菜,还不会ln和exp的(O(nlogn))快速幂,但实测(O(nlog^2n))快速幂可过。

另外需要注意一下给的元素里可能有0,需要忽略。从乘法上看,倘若有了0,结果一定是0。从原根角度看,对0取log无意义。

#include<iostream>
#include<cstdio>
//#define int long long
#define LL long long
using namespace std;
int n, m, x, s, tmp, k, tot;
const int N = 100010, mod = 1004535809, G = 3, Ginv = (mod + 1) / 3;
int r[N], mo[N], to[N];
LL F[N], ans[N];
int read() 
{
	char ch; int x = 0, f = 1;
	while (!isdigit(ch = getchar())) {(ch == '-')&&(f = -f);}
	while (isdigit(ch)) {x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}
LL ksm(LL a, LL b, LL mod) 
{
	LL res = 1;
	for (; b; b >>= 1, a = a * a % mod)
		if (b & 1)res = res * a % mod;
	return  res;
}
void NTT(LL *A, int lim, int opt) 
{
	for (int i = 0; i < lim; ++i)r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
	for (int i = 0; i < lim; ++i)
		if (i < r[i])swap(A[i], A[r[i]]);
	int len;
	LL wn, w, x, y;
	for (int mid = 1; mid < lim; mid <<= 1) 
	{
		len = mid << 1;
		wn = ksm(opt == 1 ? G : Ginv, (mod - 1) / len, mod);
		for (int j = 0; j < lim; j += len) 
		{
			w = 1;
			for (int k = j; k < j + mid; ++k, w = w * wn % mod) 
			{
				x = A[k]; y = A[k + mid] * w % mod;
				A[k] = (x + y) % mod;
				A[k + mid] = (x - y + mod) % mod;
			}
		}
	}
	if (opt == 1)return;
	int ni = ksm(lim, mod - 2, mod);
	for (int i = 0; i < lim; ++i)A[i] = A[i] * ni % mod;
}
void MUL(LL *A, int n, LL *B, int m, LL *C) 
{
	static LL X[N], Y[N];
	int lim = 1;
	while (lim <= (n + m))lim <<= 1;
	for (int i = 0; i <= n; ++i)X[i] = A[i];
	for(int i = n + 1;i <= lim;++ i)X[i]=0;
	for (int i = 0; i <= m; ++i)Y[i] = B[i];
	for(int i = m + 1;i <= lim;++ i)Y[i]=0;
	NTT(X, lim, 1); NTT(Y, lim, 1);
	for (int i = 0; i < lim; ++i)X[i] = X[i] * Y[i] % mod;
	NTT(X, lim, -1);
	for (int i = 0; i < m - 1; ++i)X[i] = (X[i] + X[i + m - 1]) % mod;
	for (int i = 0; i < m - 1; ++i)C[i] = X[i];
}
int GET_ROOT(int m) 
{
	tot = 0;
	int tmp = m - 1;
	for (int i = 2; i * i <= tmp; ++i)
		if (!(tmp % i)) 
		{
			mo[++tot] = i;
			while (!(tmp % i))tmp /= i;
		}
	if (tmp != 1)mo[++tot] = tmp;
	bool flag;
	for (int i = 2; i <= m; ++i) 
	{
		flag = 1;
		for (int j = 1; j <= tot; ++j)
			if (ksm(i, (m - 1) / mo[j], m) == 1) 
			{
				flag = 0;
				break;
			}
		if (flag)return i;
	}
	return -1;
}
void YYCH(LL *A, int b, LL *c) 
{
	c[to[1]] = 1;
	for (; b; b >>= 1, MUL(A, m, A, m, A))
		if (b & 1)MUL(A, m, c, m, c);
}
signed main() 
{
	cin >> n >> m >> x >> s;
	k = GET_ROOT(m);
	tmp = 1;
	for (int i = 0; i < m - 1; ++i, tmp = tmp * k % m)to[tmp] = i;
	for (int i = 1; i <= s; ++i) 
	{
		tmp = read();
		if (tmp)F[to[tmp]]++;
	}
	YYCH(F, n, ans);
	cout << ans[to[x]];
	return 0;
}
原文地址:https://www.cnblogs.com/wljss/p/12029009.html