[SDOI2015]序列统计

[SDOI2015]序列统计

题意:

小C有一个集合(S),里面的元素都是小于(m)的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为(n)的数列,数列中的每个数都属于集合(S)

小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数(x),求所有可以生成出的,且满足数列中所有数的乘积%(m)的值等于(x)的不同的数列的有多少个。

小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案对(1004535809)取模的值就可以了。

输入格式:

一行,四个整数(n,m,x,∣S∣)其中(∣S∣)为集合(S)中元素个数。
第二行,(∣S∣)个整数,表示集合(S)中的所有元素。

输出格式:

一行一个整数表示答案。

输入样例:

4 3 1 2
1 2

输出样例:

8

Solution:

首先定义数组(f[i][j])表示生成到了第(i)个数,答案是(j)的方案数
(n)的大小为(1e9),首先可以ksm优化
(f[i*2][j] = sum_{a*b mod m=j} f[i][a]*f[i][b])
(8000 imes8000 imeslog_{1e9})的复杂度还是过不了题
可以发现后面那一坨有点像fft式子,但是条件是乘号。
想想只有对数可以吧加法和乘法联系在一起
这方面的知识点可以参考博主的博客 取模意义下的对数&生成元的查找
然后当我们把第二维全部换成对数时,可以得到式子
(f[i*2][j] = sum_{a+b mod {m-1}=j}f[i][a]*f[i][b])
这个式子就直接ntt就行,中间为什么是(mod{m-1})参考博客即可
代码:

#include<bits/stdc++.h>
#define ll long long
#define R register
using namespace std;
template<class T>
void rea(T &x)
{
	char ch=getchar();int f(0);x = 0;
	while(!isdigit(ch)) {f|=ch=='-';ch=getchar();}
	while(isdigit(ch)) {x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
	x = f?-x:x;
}
int ksm(int x, int k, int mod)
{
    int ret = 1;
    while(k)
    {
        if(k&1) ret = 1ll*ret*x%mod;
        x = 1ll*x*x%mod;
        k >>= 1;
    }
    return ret;
}
int getroot(int mod)
{
    int prime[10000], tot = 0;
    int num = mod-1;
    for(R int i = 2; i*i <= num; ++i)
        if(num%i == 0)
        {
            prime[++tot] = i;
            while(num%i == 0) num /= i;
        }
    if(num > 1) prime[++tot] = num;
    num = mod-1;
    for(R int i = 2; i <= num; ++i)
    {
        bool ban = 0;
        for(R int j = 1; j <= tot; ++j) 
            if(ksm(i, num/prime[j], mod) == 1) { ban = 1; break; }
        if(!ban) return i;
    }
    return false;
}
const int N = 10000, mod = 1004535809, G = 3, Gi = ksm(3, mod-2, mod);
int n, m, x, s, base[N<<2], ans[N<<2], pos[N<<2];
map<int, int>Log;
void prepos(int k)
{
	int len = (1<<k);
	for(R int i = 0; i < len; ++i)
		pos[i] = (pos[i>>1]>>1)|((i&1)<<(k-1));
}
void NTT(int *a, int len, int flag)
{
	for(R int i = 0; i < len; ++i) if(pos[i] > i) swap(a[pos[i]], a[i]);
	for(R int mid = 1; mid < len; mid*=2)
	{
		int wx = ksm(flag==1?G:Gi, (mod-1)/(mid*2), mod);
		for(R int i = 0; i < len; i += mid*2)
		{
			int w = 1;
			for(R int j = i; j < i+mid; ++j)
			{
				int x = a[j], y = 1ll*a[j+mid]*w%mod;
				a[j] = (x+y)%mod, a[j+mid] = (x-y+mod)%mod;
				w = 1ll*w*wx%mod;
			}
		}
	}
	if(flag == -1)
	{
		int inv = ksm(len, mod-2, mod);
		for(R int i = 0; i < len; ++i) a[i] = 1ll*a[i]*inv%mod;
	}
}
void X(int *a, int *b, int len)
{
	int A[N<<2], B[N<<2];
	for(R int i = 0; i < len; ++i) A[i] = a[i], B[i] = b[i];
	NTT(A, len, 1); NTT(B, len, 1);
	for(R int i = 0; i < len; ++i) A[i] = 1ll*A[i]*B[i]%mod;
	NTT(A, len, -1);
	for(R int i = 0; i < m-1; ++i) A[i] = (A[i]+A[m+i-1])%mod;
	for(R int i = 0; i < m-1; ++i) a[i] = A[i];
}
int main()
{
	rea(n), rea(m), rea(x), rea(s);
	int g = getroot(m); for(R int i = 0; i < m-1; ++i) Log[ksm(g, i, m)] = i;
	for(R int i = 1; i <= s; ++i) {rea(g);if(g%m) base[Log[g%m]]++;}
	int limit = 2, k = 1;
	while(limit < m*2) limit <<= 1, k++;
	prepos(k);
	ans[0] = 1;
	while(n)
	{
		if(n&1) X(ans, base, limit);
		X(base, base, limit);
		n >>= 1;
	}
	printf("%d
", ans[Log[x]]);
	return 0;
}
原文地址:https://www.cnblogs.com/heanda/p/12397393.html