[SDOI2015]序列统计

题目链接


题目大意
有一个不重集合,每个元素都小于$m$,$m$是质数.要从集合中选$n$个元素,元素乘积除以$m$余$x$.求方案数,答案对$1004535809$取模.
$mleq8000,nleq10^9$

前置芝士
原根
若$m$为质数,$G^0,G^1, G^2,G^3,G^4...G^{m-2}$模$m$的余数各不相同,则称$G$是$m$的一个原根.
如何求原根呢?就是把$m-1$分解质因数,假设结果为$p_1^{k_1}*p_2^{k_2}*...*p_t^{k_t}$,其中$t$是$m-1$质因数数量.
因为原根都较小,可以暴力求.
从$2$开始枚举素数$x$,对于每个素数枚举所有的$p_i$,若对于所有的$p_i$,$x^frac{m-1}{p_i} otequiv1 (mod m)$,则$x$为$m$的一个原根.
代码如下:

int getG(int x){
    int t=x-1,cnt=0,pr[N];
    for(int i=2;i<=t;i++)
    if(t%i==0){
        pr[++cnt]=(x-1)/i;
        while(t%i==0)t/=i;
    }
    if(t>1)pr[++cnt]=(x-1)/t;
    for(int i=2;;i++){
        bool flag=1;
        for(int j=1;j<=cnt;j++)
        if(ksm(1ll*i,pr[j],x)==1){flag=0;break;}
        if(flag)return i;
    }
}

分析
考虑这样一个问题,把题面中的乘积换成和,怎么做呢?
很简单,令$f(x)=a_0*x^0+a_1*x^1+...+a_m*x^m$,其中,$a_x=[xin S]$
然后直接多项式快速幂,用$NTT$优化时间复杂度即可.
注意到多项式的次数会大于$m$,则只要令$c[x \%m]+=c[x]$即可.
相当于每次乘完取一次模

现在是乘积,怎么做呢?
假设我们已经求出了原根$G$
令数组$c[x]$,使得$G^{c[x]}equiv x (mod m)$

由于一个显然的性质:$G^x*G^y=G^{x+y}$,而且$G$的任何次方除以$m$的余数均不同,因此,若$x in S$,只要令$a_{c[x]}=1$即可.这样,我们将乘法问题转变成加法问题了.

然后就是和加法一样了,用多项式快速幂和$NTT$优化时间复杂度即可.
时间复杂度$O(m*log m*log n)$


代码如下

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#define P (1004535809)
#define N (35001)
#define inf (0x7f7f7f7f)
#define rg register int
#define Label puts("NAIVE")
typedef long double ld;
typedef long long LL;
typedef unsigned long long ull;
using namespace std;
inline char read(){
    static const int IN_LEN=1000000;
    static char buf[IN_LEN],*s,*t;
    return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x){
    static bool iosig;
    static char c;
    for(iosig=false,c=read();!isdigit(c);c=read()){
        if(c=='-')iosig=true;
        if(c==-1)return;
    }
    for(x=0;isdigit(c);c=read())x=((x+(x<<2))<<1)+(c^'0');
    if(iosig)x=-x;
}
inline char readchar(){
    static char c;
    for(c=read();!isalpha(c);c=read())
    if(c==-1)return 0;
    return c;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN],*ooh=obuf;
inline void print(char c) {
    if(ooh==obuf+OUT_LEN)fwrite(obuf,1,OUT_LEN,stdout),ooh=obuf;
    *ooh++=c;
}
template<class T>
inline void print(T x){
    static int buf[30],cnt;
    if(x==0)print('0');
    else{
        if(x<0)print('-'),x=-x;
        for(cnt=0;x;x/=10)buf[++cnt]=x%10+48;
        while(cnt)print((char)buf[cnt--]);
    }
}
inline void flush(){fwrite(obuf,1,ooh-obuf,stdout);}
int n,m,X,S,w[N],G,Lim,len,rev[N],c[N];
LL a[N],res[N];
LL gg=1,inv;
LL ksm(LL a,int p,int mo){
    LL res=1;
    while(p){
        if(p&1)res=(res*a)%mo;
        a=(a*a)%mo,p>>=1;
    }
    return res;
}
LL ksm(LL a,int p){
    LL res=1;
    while(p){
        if(p&1)res=(res*a)%P;
        a=(a*a)%P,p>>=1;
    }
    return res;
}
void NTT(LL *a,int tp){
    for(int i=0;i<Lim;i++)
    if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int pos=1;pos<Lim;pos<<=1){
        LL w=ksm(3,(P-1)/(pos<<1));
        if(tp==-1)w=ksm(w,P-2);
        for(int R=pos<<1,j=0;j<Lim;j+=R){
            LL p=1;
            for(int k=j;k<j+pos;p=(p*w)%P,k++){
                LL x=a[k],y=(p*a[k+pos])%P;
                a[k]=(x+y)%P,a[k+pos]=(x-y+P)%P;
            }
        }
    }
    if(tp==-1)
    for(int i=0;i<Lim;i++)a[i]=(a[i]*inv)%P;
}
void qpow(int p){
    res[c[1]]=1;
    while(p){
        NTT(a,1);
        if(p&1){
            NTT(res,1);
            for(int i=0;i<Lim;i++)(res[i]*=a[i])%=P;
            NTT(res,-1);
            for(int i=Lim-1;i>m-1;i--)(res[i-m+1]+=res[i])%=P,res[i]=0;
        }
        for(int i=0;i<Lim;i++)(a[i]*=a[i])%=P;
        NTT(a,-1),p>>=1;
        for(int i=Lim-1;i>m-1;i--)(a[i-m+1]+=a[i])%=P,a[i]=0;
    }
}
int getG(int x){
    int t=x-1,cnt=0,pr[N];
    for(int i=2;i<=t;i++)
    if(t%i==0){
        pr[++cnt]=(x-1)/i;
        while(t%i==0)t/=i;
    }
    if(t>1)pr[++cnt]=(x-1)/t;
    for(int i=2;;i++){
        bool flag=1;
        for(int j=1;j<=cnt;j++)
        if(ksm(1ll*i,pr[j],x)==1){flag=0;break;}
        if(flag)return i;
    }
}
void init(){
    Lim=1;
    while(Lim<=m+m)Lim*=2,len++;
    for(int i=0;i<Lim;i++)
    rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
    inv=ksm(Lim,P-2);
}
int main(){
    read(n),read(m),read(X),read(S);
    G=getG(m); 
    for(int i=1;i<=m-1;i++)
    (gg*=G)%=m,c[gg]=i;
    for(int i=1,x;i<=S;i++){
        read(x);
        if(x)a[c[x]]=1;
    }
    init(),qpow(n);
    printf("%lld
",res[c[X]]);
}
原文地址:https://www.cnblogs.com/Romeolong/p/10005553.html