bzoj4818 [Sdoi2017]序列计数

Description

Alice想要得到一个长度为n的序列,序列中的数都是不超过m的正整数,而且这n个数的和是p的倍数。Alice还希望,这n个数中,至少有一个数是质数。Alice想知道,有多少个序列满足她的要求。

Input

一行三个数,n,m,p。
1<=n<=10^9,1<=m<=2×10^7,1<=p<=100

Output

一行一个数,满足Alice的要求的序列数量,答案对20170408取模。

Sample Input

3 5 3

Sample Output

33

正解:矩阵快速幂/$FFT$+快速幂+中国剩余定理。

昨天晚上考这题,发现是一道$FFT$优化$DP$的裸题。然而数组开小了,所以只有$80$分。。

正解似乎是矩阵快速幂,不过为什么$FFT$跑得快一些。。并且当$p$很大时矩阵快速幂就没用了。。

首先我们可以想到一个$O(n^{3})$的暴力$DP$。设$f[i][j]$表示前$i$个数,模$p$为$j$的方案数,那么$f[i][(j+k) mod p]+=f[i-1][j]$,其中$k$为枚举选哪个数。

同正解一样,我们求出所有数的情况,然后减去没有质数的情况,最后得到的就是至少有一个质数的情况。

显然,这是一个卷积的形式,那么我们可以考虑用$FFT$来优化$DP$。

首先构造一个模$p$的多项式,$a[i]$表示模$p$为$i$的数有多少个。那么直接用$FFT$和快速幂算出$a^{n}$就行了。

然后线性筛求出所有质数,构造出除去所有质数的多项式$a$,再用一次$FFT$算出$a^{n}$。

第一个多项式的$a[0]$就是模$p$为$0$的情况,第二个多项式的$a[0]$就是没有任何质数且模$p$为$0$的情况,易知两个$a[0]$相减即为答案。

不过这个模数会炸精度,我们把这个模数拆成$8,1091,2311$,分别算出模这$3$个数的答案,最后用中国剩余定理合并就行了。

总复杂度$O(plogplogn)$,所以比矩阵快速幂的$O(p^{3}logn)$要快。不过我没加任何常数优化,所以还是很慢。。

  1 //It is made by wfj_2048~
  2 #include <algorithm>
  3 #include <iostream>
  4 #include <complex>
  5 #include <cstring>
  6 #include <cstdlib>
  7 #include <cstdio>
  8 #include <vector>
  9 #include <cmath>
 10 #include <queue>
 11 #include <stack>
 12 #include <map>
 13 #include <set>
 14 #define rhl (20170408)
 15 #define NN (20000010)
 16 #define pi acos(-1.0)
 17 #define inf (1<<30)
 18 #define il inline
 19 #define RG register
 20 #define ll long long
 21 #define C complex <long double>
 22 #define File(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
 23 
 24 using namespace std;
 25 
 26 C a[1010],b[1010],c[1010],ans[1010];
 27 
 28 int prime[NN],rev[1010],r64[5],N,n,m,p,lg,cnt;
 29 ll res[1010],Ans[5],ans1,ans2,aans;
 30 bool vis[NN];
 31 
 32 il int gi(){
 33     RG int x=0,q=1; RG char ch=getchar();
 34     while ((ch<'0' || ch>'9') && ch!='-') ch=getchar();
 35     if (ch=='-') q=-1,ch=getchar();
 36     while (ch>='0' && ch<='9') x=x*10+ch-48,ch=getchar();
 37     return q*x;
 38 }
 39 
 40 il void sieve(){
 41     vis[1]=1;
 42     for (RG int i=2;i<=m;++i){
 43     if (!vis[i]) prime[++cnt]=i;
 44     for (RG int j=1,k;j<=cnt;++j){
 45         k=i*prime[j]; if (k>m) break;
 46         vis[k]=1; if (i%prime[j]==0) break;
 47     }
 48     }
 49     return;
 50 }
 51 
 52 il void FFT(C *a,RG int n,RG int f){
 53     for (RG int i=0;i<n;++i)
 54     if (i<rev[i]) swap(a[i],a[rev[i]]);
 55     for (RG int i=1;i<n;i<<=1){
 56     C wn(cos(pi/i),sin(f*pi/i)),x,y;
 57     for (RG int j=0;j<n;j+=(i<<1)){
 58         C w(1,0);
 59         for (RG int k=0;k<i;++k,w*=wn){
 60         x=a[j+k],y=w*a[j+k+i];
 61         a[j+k]=x+y,a[j+k+i]=x-y;
 62         }
 63     }
 64     }
 65     return;
 66 }
 67 
 68 il void mul(C *a,C *b,RG int pp){
 69     for (RG int i=0;i<N;++i) c[i]=b[i]; FFT(a,N,1),FFT(c,N,1);
 70     for (RG int i=0;i<N;++i) a[i]*=c[i]; FFT(a,N,-1);
 71     memset(res,0,sizeof(res));
 72     for (RG int i=0;i<N;++i){
 73     res[i%p]+=(ll)(a[i].real()/N+0.5);
 74     res[i%p]%=pp,a[i]=0;
 75     }
 76     for (RG int i=0;i<p;++i) a[i]=res[i]; return;
 77 }
 78 
 79 il void qpow(C *a,RG int b,RG int pp){
 80     for (RG int i=0;i<N;++i) ans[i]=a[i]; b--;
 81     while (b){ if (b&1) mul(ans,a,pp); mul(a,a,pp),b>>=1; }
 82     memset(res,0,sizeof(res));
 83     for (RG int i=0;i<N;++i)
 84     res[i%p]+=(ll)ans[i].real(),res[i%p]%=pp;
 85     return;
 86 }
 87 
 88 il void exgcd(RG ll a,RG ll b,RG ll &x,RG ll &y){
 89     if (!b){ x=1,y=0; return; }
 90     exgcd(b,a%b,y,x); y-=(a/b)*x; return;
 91 }
 92 
 93 il void work(){
 94     n=gi(),m=gi(),p=gi(); for (N=1;N<=(p<<1);N<<=1) lg++;
 95     for (RG int i=0;i<N;++i) rev[i]=rev[i>>1]>>1|((i&1)<<(lg-1));
 96     r64[1]=8,r64[2]=1091,r64[3]=2311,sieve();
 97     for (RG int k=1;k<=3;++k){
 98     memset(a,0,sizeof(a)); for (RG int i=1;i<=m;++i) a[i%p].real()++;
 99     qpow(a,n,r64[k]); ans1=res[0],memset(a,0,sizeof(a));
100     for (RG int i=1;i<=m;++i) if (vis[i]) a[i%p].real()++;
101     qpow(a,n,r64[k]); ans2=res[0],Ans[k]=(ans1-ans2+r64[k])%r64[k];
102     RG ll u=rhl/r64[k],v=r64[k],x=0,y=0; exgcd(u,v,x,y);
103     aans+=u*x%rhl*Ans[k],aans%=rhl;
104     }
105     printf("%lld",aans); return;
106 }
107 
108 int main(){
109     File("count");
110     work();
111     return 0;
112 }
原文地址:https://www.cnblogs.com/wfj2048/p/6694471.html