[SDOI2015]序列统计

嘟嘟嘟


此题很可做。


首先从一个暴力的dp入手:令(dp[i][j])表示第(i)个数为(j)时的数列个数,于是有(dp[i][j *a[k] \% M] += dp[i - 1][j])
但这个似乎只能拿10分。


一个显然的优化是改成倍增快速幂的形式,上述dp方程显然是可以合并的,即(dp[x + y][i * j \% M] += dp[x][i] * dp[y][j])
乘的时候暴力乘,这样(O(k ^ 2logn))能拿到60分。


现在瓶颈在于多项式的乘法。这东西和卷积很像,只不过卷积是加,他却是乘。那么怎么才能把乘变成加呢?取log啊!
然后我就卡在了这里:取完log下标都不是整数,那不gg了……


最后问了衡水的巨佬,他说你对原根取对数啊!问了一大顿才想起来,原根的定义是一个数(g),满足(g ^ 0, g ^ 1, g ^ 2 ldots g ^ {p - 2})刚好能凑出([1, p - 1])的所有整数。于是这题就完事了啊!
看来还是自己原根学的不好,一会儿赶快复习一下。


数据范围挺可爱的,规定了(x)不可以取(0),要不然还得分来讨论把(0)单出来算。
然而集合(S)中却有(0)……这得特判一下……

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<queue>
#include<vector>
#include<ctime>
#include<assert.h>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
#define forE(i, x, y) for(int i = head[x], y; (y = e[i].to) && ~i; i = e[i].nxt)
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e3 + 5;
const int maxM = 4e4 + 5;
const ll mod = 1004535809;
const ll G = 3;
In ll read()
{
	ll ans = 0;
	char ch = getchar(), las = ' ';
	while(!isdigit(ch)) las = ch, ch = getchar();
	while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
	if(las == '-') ans = -ans;
	return ans;
}
In void write(ll x)
{
	if(x < 0) putchar('-'), x = -x;
	if(x >= 10) write(x / 10);
	putchar(x % 10 + '0');
}
In void MYFILE()
{
#ifndef mrclr
	freopen("ha.in", "r", stdin);
	freopen("bf.out", "w", stdout);
#endif
}

int n, M, X, S, a[maxM], pos[maxM];
int len = 1, lim = 0, rev[maxM];

In ll inc(ll a, ll b) {return a + b < mod ? a + b : a + b - mod;}
In ll quickpow(ll a, ll b, ll mod)
{
	ll ret = 1;
	for(; b; b >>= 1, a = a * a % mod)
		if(b & 1) ret = ret * a % mod;
	return ret;
}

In ll phi(ll n)
{
	ll ret = n;
	for(int i = 2; i * i <= n; ++i)
	{
		if(n % i) continue;
		ret = ret / i * (i - 1);
		while(n % i == 0) n /= i;
	}
	if(n > 1) ret = ret / n * (n - 1);
	return ret;
}
int p[1000], pcnt = 0;
In ll getRoot(ll m)
{
	ll Phi = phi(m); pcnt = 0;
	for(int i = 2; i * i <= Phi; ++i) if(Phi % i == 0)
	{
		p[++pcnt] = i;
		if(Phi / i != i) p[++pcnt] = Phi / i;
	}
	for(int g = 2; g <= Phi; ++g)
	{
		bool flg = 1;
		if(quickpow(g, Phi, m) ^ 1) continue;
		for(int i = 1; i <= pcnt && flg; ++i)
			if(quickpow(g, p[i], m) == 1) flg = 0;
		if(flg) return g;
	}
	return -1;
}
In void init()
{
	int g = getRoot(M), tp = 1;
	for(int i = 0; i < M - 1; ++i, tp = tp * g % M) pos[tp] = i;
	while(len < M + M) len <<= 1, ++lim;
	for(int i = 0; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lim - 1));
}

In void ntt(ll* a, int len, bool flg)
{
	for(int i = 0; i < len; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]);
	for(int i = 1; i < len; i <<= 1)
	{
		ll gn = quickpow(G, (mod - 1) / (i << 1), mod);
		for(int j = 0; j < len; j += (i << 1))
		{
			ll g = 1;
			for(int k = 0; k < i; ++k, g = g * gn % mod)
			{
				ll tp1 = a[j + k], tp2 = a[j + k + i] * g % mod;
				a[j + k] = (tp1 + tp2) % mod, a[j + k + i] = (tp1 - tp2 + mod) % mod;
			}
		}
	}
	if(flg) return;
	reverse(a + 1, a + len); ll inv = quickpow(len, mod - 2, mod);
	for(int i = 0; i < len; ++i) a[i] = a[i] * inv % mod;
}

ll c[maxM], A[maxM], B[maxM];
In void mul(ll* a, ll* b)
{
	for(int i = 0; i < len; ++i)    //一定要复制到另一个数组再NTT!因为传过来的数组a和b可能是同一个!(debug到头秃)
	{
		A[i] = i < M - 1 ? a[i] : 0;
		B[i] = i < M - 1 ? b[i] : 0;
	}
	ntt(A, len, 1), ntt(B, len, 1);
	for(int i = 0; i < len; ++i) a[i] = A[i] * B[i] % mod;
	ntt(a, len, 0);
	for(int i = 0; i < M - 1; ++i) a[i] = inc(a[i], a[i + M - 1]);
}

ll f[maxM], g[maxM];
In ll Quickpow(int n)
{
	f[pos[1]] = 1;
	for(int i = 1; i <= S; ++i) if(a[i]) g[pos[a[i]]] = 1;
	for(; n; n >>= 1, mul(g, g)) 
		if(n & 1) mul(f, g);
	return f[pos[X]];
}

int main()
{
//	MYFILE();
	n = read(), M = read(), X = read(), S = read();
	for(int i = 1; i <= S; ++i) a[i] = read();
	init();
	write(Quickpow(n)), enter;
	return 0;
}
原文地址:https://www.cnblogs.com/mrclr/p/11122723.html