[Sdoi2017]序列计数

4818: [Sdoi2017]序列计数

Time Limit: 30 Sec  Memory Limit: 128 MB
Submit: 317  Solved: 210

Description

Alice想要得到一个长度为n的序列,序列中的数都是不超过m的正整数,而且这n个数的和是p的倍数。Alice还希望
,这n个数中,至少有一个数是质数。Alice想知道,有多少个序列满足她的要求。

Input

一行三个数,n,m,p。
1<=n<=10^9,1<=m<=2×10^7,1<=p<=100

Output

一行一个数,满足Alice的要求的序列数量,答案对20170408取模。

Sample Input

3 5 3

Sample Output

33
 
先是容斥,总答案=所有的方案数-只有合数的方案数。
可以推出暴力的DP方程:
f[i][j]代表长度为imod
pj的方案数
f[i][(j+k)%p]+=f[i-1][j]
然后这样的复杂度是O(n*m*p)20分。
考虑优化,发现对于每个i的转移都是一样的,所以可以用矩阵快速幂。
A为转移矩阵,发现每一个j都可以转移到j+k这个位置,A[j][j+k]+1
这样暴力构矩阵复杂度为O(m*p),80分。
还可以优化,发现有很多地方可以记忆化,所以可以先预处理出1-m每个数中
p的每个剩余系的数量,然后直接加进去即可。
复杂度Om)。100分。
 1 #include <algorithm>
 2 #include <iostream>
 3 #include <cstdlib>
 4 #include <cstring>
 5 #include <cstdio>
 6 #include <cmath>
 7 #define maxn 20000010
 8 #define mod 20170408
 9 #define LL long long
10 using namespace std;
11 LL a[110][110][2],b[110][110][2],s[110][110];
12 int su[maxn/10],he[maxn],s1[110],s2[110];
13 LL n,m,p;
14 bool bj[maxn];
15 void mul(int a1,int b1){
16   for(int i=0;i<p;i++)
17     for(int j=0;j<p;j++)
18       for(int k=0;k<p;k++)
19     s[i][j]+=a[i][k][a1]*a[k][j][b1],s[i][j]%=mod;
20   for(int i=0;i<p;i++)
21     for(int j=0;j<p;j++)
22       a[i][j][a1]=s[i][j],s[i][j]=0;
23 }
24 void mul1(int a1,int b1){
25   for(int i=0;i<p;i++)
26     for(int j=0;j<p;j++)
27       for(int k=0;k<p;k++)
28     s[i][j]+=b[i][k][a1]*b[k][j][b1],s[i][j]%=mod;
29   for(int i=0;i<p;i++)
30     for(int j=0;j<p;j++)
31       b[i][j][a1]=s[i][j],s[i][j]=0;
32 }
33 int main()
34 {
35   freopen("count.in","r",stdin);
36   freopen("count.out","w",stdout);
37   LL tot=0,tot1=0;
38   scanf("%lld%lld%lld",&n,&m,&p);
39   for(int i=1;i<=m;i++) a[0][i%p][0]++;
40   for(int i=1;i<=m;i++) s1[i%p]++;
41   for(int i=0;i<p;i++)
42     for(int j=0;j<p;j++)
43       a[j][(j+i)%p][1]+=s1[i];
44   int mi=n-1;
45   while(mi){
46     if(mi%2) mul(0,1);
47     mi>>=1;
48     mul(1,1);
49   }
50   bj[1]=1;
51   for(int i=2;i<=m;i++){
52     if(!bj[i]) su[++tot]=i;
53     for(int j=1;j<=tot;j++){
54       if(su[j]*i>m) break;
55       bj[su[j]*i]=1;
56       if(i%su[j]==0) break;
57     }
58   }
59   for(int i=1;i<=m;i++) if(bj[i])b[0][i%p][0]++;
60   for(int i=1;i<=m;i++) if(bj[i])s2[i%p]++;
61   for(int i=0;i<p;i++)
62     for(int j=0;j<p;j++)
63       b[j][(j+i)%p][1]+=s2[i];
64   mi=n-1;
65   while(mi){
66     if(mi%2) mul1(0,1);
67     mi>>=1;
68     mul1(1,1);
69   }
70   printf("%lld",(a[0][0][0]-b[0][0][0]+mod)%mod);
71   return 0;
72 }





原文地址:https://www.cnblogs.com/pantakill/p/6708520.html