[SDOI2015]序列统计

题面在这里

题意

求长度为(n),元素取值集合为(S(forall xin S in [1,m-1]))的数列乘积为(x)的方案数

sol

首先我们可以想到一个(DP):设(f[i][j])表示长度为(i)的数列乘积为(j)的方案数,那么有:

[f[i][j]=sum_{k|j}(f[i-1][k] imes sum[frac{j}{k}]) ]

其中(sum[i])表示集合(S)中与(i) (mod) (m)同余的数的个数
由于每(f[i-1]->f[i])的转移是固定的,那么可以使用矩乘优化

考虑如何进一步优化:我们发现这很像一个卷积的形式,因此考虑NTT快速幂
但卷积的形式为$$a_i imes b_j=c_{i+j}$$
而这里是$$a_i imes b_j=c_{i*j}$$
怎么办呢?
注意到题目中给出的(m)一定是个质数
因此我们可以把每个(<m)的数当作(m)的原根(g)对应的次幂
于是当(i=g^{M_i})(这里(M)是一个映射)时就可以对应的求出其原值和对应原根的次幂

快速幂的时候注意需要对(ge m)次幂的系数取模

#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<iomanip>
#include<cstring>
#include<complex>
#include<vector>
#include<cstdio>
#include<string>
#include<bitset>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<set>
#define mp make_pair
#define pb push_back
#define RG register
#define il inline
using namespace std;
typedef unsigned long long ull;
typedef vector<int>VI;
typedef long long ll;
typedef double dd;
const dd eps=1e-10;
const int mod=1004535809;
const int g=3;
const int inv=334845270;
const int N=40010;
il ll read(){
    RG ll data=0,w=1;RG char ch=getchar();
    while(ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
    if(ch=='-')w=-1,ch=getchar();
    while(ch<='9'&&ch>='0')data=data*10+ch-48,ch=getchar();
    return data*w;
}
il void file(){
    //freopen("a.in","r",stdin);
    //freopen("b.out","w",stdout);
}

il int getg(int p){
    RG int ret,b;
    for(RG int i=2;;i++){
        ret=b=1;
        for(RG int j=1;j<p-1;j++){
            ret=1ll*ret*i%p;if(ret==1){b=0;break;}
        }
        if(b)return i;
    }
}

il ll poww(ll a,ll b){
    RG ll ret=1;
    for(a%=mod;b;b>>=1,a=a*a%mod)if(b&1)ret=ret*a%mod;
    return ret;
}

namespace NTT{
    int n,l,r[N],rev;
    il void NTT(int *A,int n,bool b){
        for(RG int i=0;i<n;i++)if(i<r[i])swap(A[i],A[r[i]]);
        for(RG int i=2;i<=n;i<<=1){
            RG int w=poww((b?g:inv),(mod-1)/i);
            for(RG int j=0;j<n;j+=i){
                RG int wn=1;
                for(RG int k=j;k<j+(i>>1);k++,wn=1ll*wn*w%mod){
                    RG int x=1ll*wn*A[k+(i>>1)]%mod;
                    A[k+(i>>1)]=(A[k]-x+mod)%mod;
                    A[k]=(A[k]+x)%mod;
                }
            }
        }
        if(!b)for(RG int i=0;i<n;i++)A[i]=1ll*A[i]*rev%mod;
    }
    il void conv(int *A,int *B,int n,int len){
        NTT(A,n,1);if(A!=B)NTT(B,n,1);
        for(RG int i=0;i<n;i++)A[i]=1ll*A[i]*B[i]%mod;
        NTT(A,n,0);if(A!=B)NTT(B,n,0);
        for(RG int i=0;i<len;i++)
            if(A[i+len])A[i]=(A[i]+A[i+len])%mod,A[i+len]=0;
    }
    il void fastpow(int* A,int* B,int t,int len){
        for(n=1;n<=len*2;n<<=1)l++;rev=poww(n,mod-2);
        for(RG int i=0;i<n;i++)r[i]=(r[i>>1]>>1)|((1&i)<<(l-1));
        while(t){if(t&1)conv(A,B,n,len);conv(B,B,n,len);t>>=1;}
    }
}

int n,m,x,S,s[N],f[N];
map<int,int>M;
il void init(){
    RG int G=getg(m),ret=1;
    for(RG int i=0;i<m-1;i++){M[ret]=i;ret=1ll*ret*G%m;}
}

int main()
{
    file();n=read();m=read();x=read();S=read();init();
    for(RG int i=1,k;i<=S;i++)if(k=read())s[M[k]]++;f[0]=1;
    NTT::fastpow(f,s,n,m-1);
    printf("%d
",f[M[x]]);
    return 0;
}

原文地址:https://www.cnblogs.com/cjfdf/p/8436113.html