bzoj4818 [Sdoi2017]序列计数(矩阵)

Description

Alice想要得到一个长度为n的序列,序列中的数都是不超过m的正整数,而且这n个数的和是p的倍数。Alice还希望,这n个数中,至少有一个数是质数。Alice想知道,有多少个序列满足她的要求。
Input
一行三个数,n,m,p。
1<=n<=10^9,1<=m<=2×10^7,1<=p<=100
Output
一行一个数,满足Alice的要求的序列数量,答案对20170408取模。

Sample Input
3 5 3

Sample Output
33

分析:
至少有一个是质数=所有-没有质数
记f[i][j]为只考虑i个数,前i个数的和在模p意义下为j的方案数
f[i+1][k]+=f[i][j]*num ((j+x)%p=k,符合这个条件的数的个数是num)
考虑矩阵加速
观察矩阵的特点:

% f[i][0] f[i][1] f[i][2] f[i][3]
f[i-1][0] a b c d
f[i-1][1] d a b c
f[i-1][2] c d a b
f[i-1][3] b c d a
p=4
a:%p=0的数的个数
b:%p=1的数的个数
c:%p=2的数的个数
d:%p=3的数的个数

这种矩阵称作循环矩阵,
循环矩阵的乘积还是循环矩阵,所以做矩阵乘法时候只需算第一行,
然后按循环矩阵性质填出其他行即可

看了一下网上的程序
f和矩阵的初始化竟然O(mm)即可,震惊(ΩДΩ)
不太理解为什么矩阵的初始化要写成:
m.m[0][(-i%p+p)%p]++;

tip

注意矩阵的下标是0~p-1
1不是素数
最后的答案:
ans=(ans+f[i]*an.m[0][i]%mod)%mod;
开ll
最后的答案:(f1-f2+mod)%mod //+mod

这里写代码片
#include<cstdio>
#include<cstring>
#include<iostream>
#define ll long long

using namespace std;

const ll mod=20170408;
int n,mm,p;
int tot=0,f[101];
int sshu[20000010];
bool no[20000010];
struct node{
    ll m[101][101];
    node operator *(const node &a) const
    {
        node ans;
        for (int j=0;j<p;j++)  //只计算第一行 
        {
            ans.m[0][j]=0;
            for (int k=0;k<p;k++)
                ans.m[0][j]=(ans.m[0][j]+m[0][k]*a.m[k][j]%mod)%mod;
        }                                                     
        for (int i=1;i<p;i++)
            for (int j=0;j<p;j++)
            {
                int t=j-1;
                if (t==-1) t=p-1;
                ans.m[i][j]=ans.m[i-1][t];
            }
        return ans;
    }
    void clear()
    {
        memset(m,0,sizeof(m));
    }
    node KSM(ll pp)
    {
        pp--;
        node tt=(* this);
        node a=(* this);
        while (pp)
        {
            if (pp&1) 
               tt=tt*a;
            a=a*a;
            pp>>=1;
        }
        return tt;
    }
};
node m;

void cl()  //求素数 
{
    memset(no,0,sizeof(no));
    no[1]=1;   ///
    for (int i=2;i<=mm;i++)
    {
        if (!no[i])
           sshu[++tot]=i;
        for (int j=1;i*sshu[j]<=mm&&j<=tot;j++)
        {
            no[i*sshu[j]]=1;
            if (i%sshu[j]==0) break;
        }
    }
}

ll solve1()
{
    for (int i=1;i<=mm;i++) f[i%p]++;  //
    for (int i=1;i<=mm;i++) m.m[0][(-i%p+p)%p]++;  //
    for (int i=1;i<p;i++)
        for (int j=0;j<p;j++)
        {
            int t=j-1;
            if (t==-1) t=p-1;
            m.m[i][j]=m.m[i-1][t];
        }
    node an=m.KSM(n-1);
    ll ans=0;
    for (int i=0;i<p;i++) ans=(ans+(ll)f[i]*an.m[0][i]%mod)%mod;  //
    return ans; 
}

ll solve2()
{
    memset(f,0,sizeof(f));
    for (int i=1;i<=mm;i++) if (no[i]) f[i%p]++;
    m.clear();
    for (int i=1;i<=mm;i++) if (no[i]) m.m[0][(-i%p+p)%p]++;
    for (int i=1;i<p;i++)
        for (int j=0;j<p;j++)
        {
            int t=j-1;
            if (t==-1) t=p-1;
            m.m[i][j]=m.m[i-1][t];
        }
    node an=m.KSM(n-1);
    ll ans=0;
    for (int i=0;i<p;i++) ans=(ans+(ll)f[i]*an.m[0][i]%mod)%mod;  //
    return ans; 
}

int main()
{
    scanf("%d%d%d",&n,&mm,&p);
    cl();
    ll f1=solve1();
    ll f2=solve2();
    printf("%lld",(ll)(f1-f2+mod)%mod);   //+mod
    return 0;
}
原文地址:https://www.cnblogs.com/wutongtong3117/p/7673490.html