【洛谷P3321】序列统计

题目

题目链接:https://www.luogu.com.cn/problem/P3321
小C有一个集合 (S),里面的元素都是小于 (m) 的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为 (n) 的数列,数列中的每个数都属于集合 (S)
小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数 (x),求所有可以生成出的,且满足数列中所有数的乘积 (mod m) 的值等于 (x) 的不同的数列的有多少个。
小C认为,两个数列 (A)(B) 不同,当且仅当 (exists i ext{ s.t. } A_i eq B_i)。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案对 (1004535809) 取模的值就可以了。
(nleq 10^9,mleq 8000)

思路

之前在 GMOJ 这道题时限开 (5s) 被我 (O(m^2log n)) 艹过去了。
首先 (60)pts 的倍增 dp 就是设 (f[i][j]) 表示选了 (2^i) 个数,乘积 (mod p) 之后的结果为 (j) 的方案数。
转移为

[f[k][l]=sum^{}_{i imes jmod p=l}f[k-1][i] imes f[k-1][j] ]

然后二进制拆分即可。
如果这个乘号是加号的话,我们就可以 NTT 优化了。
考虑如何把乘号变为加号,因为 (log_ab+log_ac=log_a(bc)),所以可以用对数进行转化。
但是我们需要保证转化后对于任意两个 (x,yin [1,m))(x eq y),都有 (log_a x eq log_a y),由于 (m) 是质数,所以我们取 (m) 的原根即可。
接下来就和 (60)pts 的做法一样了。将每一数转化为对数之后扔进一个多项式里,然后倍增计算即可。
时间复杂度 (O(mlog nlog m))

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=18010,MOD=1004535809;
int n,m,s,l,G,lim,a[N],rev[N];
ll f[N],g[N],h[N];

ll fpow(ll x,ll k,ll mod=(ll)MOD)
{
	ll ans=1;
	for (;k;k>>=1,x=x*x%mod)
		if (k&1) ans=ans*x%mod;
	return ans;
}

int findg(int p)
{
	vector<int> d;
	for (int i=2;i<=p-1;i++)
		if ((p-1)%i==0) d.push_back(i);
	for (int i=1;i<=p;i++)
	{
		bool flag=1;
		for (int j=0;j<d.size();j++)
			if (fpow(i,(p-1)/d[j],p)==1) { flag=0; break; }
		if (flag) return i;
	}
}

void NTT(ll *f,bool tag)
{
	for (int i=0;i<lim;i++)
		if (i<rev[i]) swap(f[i],f[rev[i]]);
	for (int k=1;k<lim;k<<=1)
	{
		ll tmp=fpow((tag?3:334845270),(MOD-1)/(k<<1));
		for (int i=0;i<lim;i+=(k<<1))
		{
			ll w=1;
			for (int j=0;j<k;j++,w=w*tmp%MOD)
			{
				ll x=f[i+j],y=w*f[i+j+k]%MOD;
				f[i+j]=(x+y)%MOD; f[i+j+k]=(x-y)%MOD;
			}
		}
	}
}

int main()
{
	scanf("%d%d%d%d",&n,&m,&s,&l);  // 十分优雅的读入
	G=findg(m);
	for (int i=1;i<m;i++)
		a[fpow(G,i,m)]=i;
	for (int i=1,x;i<=l;i++)
	{
		scanf("%d",&x);
		if (x) f[a[x]]++;
	}
	g[0]=lim=1;
	while (lim<=2*m) lim<<=1;
	for (int i=0;i<lim;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)?(lim>>1):0);
	ll inv=fpow(lim,MOD-2);
	for (int k=0;k<=30;k++)
	{
		if (n&(1<<k))
		{
			memcpy(h,f,sizeof(f));
			NTT(g,1); NTT(h,1);
			for (int i=0;i<lim;i++) g[i]=g[i]*h[i]%MOD;
			NTT(g,0);
			for (int i=1;i<m;i++)
				g[i]=(g[i]+g[i+m-1])*inv%MOD;
			for (int i=m;i<lim;i++) g[i]=0;
		}
		NTT(f,1);
		for (int i=0;i<lim;i++) f[i]=f[i]*f[i]%MOD;
		NTT(f,0);
		for (int i=1;i<m;i++)
			f[i]=((f[i]+f[i+m-1])*inv%MOD+MOD)%MOD;
		for (int i=m;i<lim;i++) f[i]=0;
	}
	printf("%lld",(g[a[s]]%MOD+MOD)%MOD);
	return 0;
}
原文地址:https://www.cnblogs.com/stoorz/p/14248727.html