南京网络赛 E K Sum

K Sum

终于过了这玩意啊啊啊====

莫比乌斯反演,杜教筛,各种分块,积性函数怎么线性递推还很迷==,得继续研究研究

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define maxn 1000000+10
int P[maxn],g[maxn];
bool vis[maxn];

unordered_map<int,int> mp;
int T,n,k;
const int mod =(1e9+7);
int cnt=0;
void init()
{
    g[1]=1;
    for(int i=2;i<maxn;i++){
        g[i]=1;
    }
    for(int i=2; i<maxn; i++)
    {
        //g[i]=1;
        if(!vis[i])
        {
            P[cnt++]=i;
            g[i]=(i*i-1)%mod;
        }
        for(int j=0; j<cnt&&(P[j]*i)<maxn; j++)
        {
            vis[i*P[j]]=1;
            g[P[j]*i]=(g[P[j]]*g[i]%mod);
            if(i%P[j]==0)
            {
                g[P[j]*i]=(g[i]*(P[j]*P[j])%mod)%mod;
                break;
            }

        }
    }
    for(int i=1; i<maxn; i++)
    {
        g[i]=(g[i]+g[i-1]+mod)%mod;
    }
}
int qp(int x,int n)
{
    int ans=1;
    
    while(n)
    {
        if(n&1)
        {
            ans=(ans*x)%mod;
        }
        x=(x*x)%mod;
        n>>=1;
    }
    return ans%mod;
}
int _k;
int Sum(int x,int n)
{
    if(x==1)
    {
        return (_k-1+mod)%mod;
    }
    else
    {
        return (((x*(qp(x,n)-1+mod)%mod)%mod*qp((x-1)%mod,mod-2)%mod)%mod-x+mod)%mod;
    }
}
int Sum2(int n)
{
    int ans=qp(6,mod-2);
    ans=(ans*((n*(n+1)%mod)%mod*(2*n%mod+1)%mod)%mod)%mod;
    return ans;
}
int G(int n)
{
    int ans=0;
    int r;
    for(int i=2; i<=n; i=r+1)
    {
         r=n/(n/i);
        int x=n/i;
        if(x<maxn)
        {
            ans=(ans+g[x]*(r-i+1))%mod;
        }
        else if(mp[x])
        {
            ans=(ans+mp[x]*(r-i+1))%mod;
        }
        else ans=(ans+G(x)*(r-i+1))%mod;
    }

    mp[n]=(Sum2(n)-ans+mod)%mod;
    return mp[n];
}
int cal(int x)
{
    if(x<maxn)return g[x];
    if(mp[x])return mp[x];
    return G(x);
}
char s[maxn];
signed main()
{
    init();
    int ans=0;
    scanf("%lld",&T);
    //string s;
    while(T--)
    {
        ans=0;
        scanf("%lld",&n);
        scanf("%s",s);
        k=0;
        _k=0;
        int _n=strlen(s);
        for(int i=0; i<_n; i++)
        {
             _k=(_k*10+s[i]-'0')%(mod);
            k=((k*10)+s[i]-'0')%(mod-1);
        }
        int r;

        for(int i=1; i<=n; i=r+1) ///i
        {
            r=n/(n/i);
            ans=(ans+((Sum((n/i),k))%mod*(cal(r)-cal(i-1)+mod)%mod)%mod)%mod;
        }
        cout<<ans<<'
';
    }


}
原文地址:https://www.cnblogs.com/liulex/p/11457654.html