题解[Zeta的数论题[加强版]]

题目链接

题意:

$f(t)=\sum\limits_{k=1}^{t}k\,[\gcd(k,t)=1]$

求 $\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{n}(i^2+j^2+ij)\,f(\gcd(i,j))$，其中 $n\leq 10^{10}$。

按照原题的做法，对原式进行莫比乌斯反演，得到

$\sum\limits_{T=1}^{n}\left(T^2\sum\limits_{d\mid T}f(d)\,\mu(T/d)\right)\sum\limits_{i=1}^{\left\lfloor\frac{n}{T}\right\rfloor}\sum\limits_{j=1}^{\left\lfloor\frac{n}{T}\right\rfloor}(i^2+j^2+ij)$

以及 $f(d)=\dfrac{1}{2}[d=1]+\dfrac{1}{2}d\varphi(d)$。

这仍然用数论分块求解，但由于数据范围扩大，需要快速求出 $g(T)=T^2\sum\limits_{d\mid T}f(d)\,\mu(T/d)$ 的前缀和。

将 $f$ 的通项代入 $g$，得 $g(T)=\dfrac{1}{2}T^2\left(\mu(T)+\sum\limits_{d\mid T}d\varphi(d)\,\mu(T/d)\right)$。

可以拆为两部分,分别求前缀和。

设 $r(T)=T^2\mu(T)$，$h(T)=T^2\sum\limits_{d\mid T}d\varphi(d)\,\mu(T/d)$，则有 $g(T)=\dfrac{1}{2}\left(r(T)+h(T)\right)$。

首先求相对简单的 $r(T)$ 的前缀和，这可以直接用杜教筛，取辅助函数 $n^2$：

$\sum\limits_{d\mid n}r(d)\left(\frac{n}{d}\right)^2=n^2\sum\limits_{d\mid n}\mu(d)=n^2[n=1]=[n=1]$

然后是 $h(T)$，仍然可以用杜教筛，但为了构造辅助函数，先看看其狄利克雷生成函数（$\text{dgf}$）：

$H(z)=\sum\limits_{n\geq 1}\dfrac{1}{n^{z-2}}\sum\limits_{d\mid n}d\varphi(d)\,\mu(n/d)=\left(\sum\limits_{n\geq 1}\dfrac{\varphi(n)}{n^{z-3}}\right)\left(\sum\limits_{n\geq 1}\dfrac{\mu(n)}{n^{z-2}}\right)=\dfrac{\zeta(z-4)}{\zeta(z-3)\,\zeta(z-2)}$

所以 $H(z)$ 乘上 $\zeta(z-3)\zeta(z-2)$ 即得 $\zeta(z-4)$，也就是说 $h$ 与 $[n^{-z}]\zeta(z-3)\zeta(z-2)$ 做狄利克雷卷积得到 $n^4$。

$[n^{-z}]\zeta(z-4)=n^4$，其前缀和 $\sum\limits_{i=1}^{n}i^4=\frac{1}{30}n(n+1)(2n+1)(3n^2+3n-1)$ 可以 $O(1)$ 求出。

剩下就是杜教筛时 (zeta(z-3)zeta(z-2)) 的前缀和。

$s(n)=[n^{-z}]\zeta(z-3)\zeta(z-2)=\sum\limits_{d\mid n}d^3\left(\frac{n}{d}\right)^2=n^2\sum\limits_{d\mid n}d$

$\sum\limits_{n=1}^{N}s(n)=\sum\limits_{n=1}^{N}n^2\sum\limits_{d\mid n}d=\sum\limits_{d=1}^{N}d\sum\limits_{1\leq n\leq N,\ d\mid n}n^2=\sum\limits_{d=1}^{N}d^3\sum\limits_{k=1}^{\left\lfloor\frac{N}{d}\right\rfloor}k^2$

这可以杜教筛时再套一个数论分块算出来。

时间复杂度仍然是 $O(n^{\frac{2}{3}})$，因为内层套的这个数论分块与外层杜教筛递归的复杂度同阶。

代码:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
// N ~ (10^10)^(2/3): arguments up to N are answered from the sieve tables,
// larger ones by the memoised Du's-sieve recursions. hmod is the bucket
// count of the memo hash tables; mod is the answer modulus.
const int N=4641588,mod=998244353,hmod=1e5+7;
// Modular inverses of 2, 4, 6, 12 and 30 modulo mod.
const ll inv2=499122177,inv4=748683265,inv6=166374059;
const ll inv12=582309206,inv30=432572553;
// Sieve state: primes p[1..prm], Moebius function mu[], composite flags b[],
// and a character buffer for input parsing.
int p[N+10],prm,mu[N+10];bool b[N+10];char ch;
// d, sum, sum_: per-index multiplicative values that pre_work turns into
// prefix sums; pp[i] = 1+q+...+q^a for the power q^a of i's smallest prime;
// p2[j] = p[j]^2 mod mod; ans: running answer; n: the input.
// (The unused array pw[N+10] was removed: ~37 MB of dead static storage.)
ll d[N+10],pp[N+10],sum[N+10],sum_[N+10],p2[N+10],ans,n;
// Prints x in decimal to stdout without a trailing newline.
// The low-order digits are buffered and the leading digit is printed first.
void write(ll x){
	char digits[20];
	int cnt=0;
	while(x>9){
		digits[cnt++]=char(48+x%10);
		x/=10;
	}
	putchar(48+x%10);
	while(cnt>0)putchar(digits[--cnt]);
}
// Linear (Euler) sieve over [1, N]. For every i <= N it first computes:
//   mu[i]  - the Moebius function;
//   d[i]   - sigma(i), the sum of divisors (exact: it stays far below mod,
//            which the exact division d[i]/pp[i] relies on);
//   pp[i]  - 1 + q + ... + q^a, where q^a is the full power of i's smallest
//            prime q (lets sigma be updated when that exponent grows);
//   sum[i] - F(i) = sum_{e|i} e*phi(e)*mu(i/e) mod mod, so that h(i) = i^2 F(i).
// Then it replaces the tables by prefix sums:
//   d[n]    = sum_{i<=n} i^2 * sigma(i)   (prefix sums of s),
//   sum_[n] = sum_{i<=n} i^2 * mu(i)      (prefix sums of r),
//   sum[n]  = sum_{i<=n} i^2 * F(i)       (prefix sums of h).
// Residues can be negative (C++ % keeps the sign); they are only ever
// combined modulo mod and the final answer is normalised by the caller.
void pre_work(){
	int i,j,k,h;
	mu[1]=d[1]=sum[1]=1;
	for(i=2;i<=N;++i){
		if(!b[i]){
			// i is prime: sigma(i)=i+1, F(i)=i^2-i-1, mu(i)=-1.
			p[++prm]=i;
			d[i]=pp[i]=i+1;
			p2[prm]=1ll*i*i%mod;
			sum[i]=(p2[prm]-i-1)%mod;
			mu[i]=-1;
		}
		for(j=1;j<=prm&&(k=i*p[j])<=N;++j){
			b[k]=1;
			if(i%p[j]==0){
				// p[j] is the smallest prime of i: its exponent grows by one.
				pp[k]=pp[i]*p[j]+1;
				d[k]=d[i]/pp[i]*pp[k];
				// F(q^a) = q^2 F(q^{a-1}) for a >= 3, while
				// F(q^2) = q^2 F(q) + q; the extra term applies when q^2 || k.
				sum[k]=sum[i]*p2[j]%mod;
				if((h=i/p[j])%p[j])sum[k]=(sum[k]+sum[h]*p[j]%mod)%mod;
				mu[k]=0;
				break;
			}
			// p[j] is coprime to i: every tracked function is multiplicative.
			mu[k]=-mu[i];
			d[k]=d[i]*d[p[j]]%mod;
			pp[k]=pp[p[j]];
			sum[k]=sum[i]*sum[p[j]]%mod;
		}
	}
	// Prefix sums, each term weighted by i^2.
	for(i=1;i<=N;++i)d[i]=(d[i-1]+d[i]*i%mod*i%mod)%mod;
	for(i=1;i<=N;++i)sum_[i]=(sum_[i-1]+1ll*i*i%mod*mu[i]%mod)%mod;
	for(i=1;i<=N;++i)sum[i]=(sum[i-1]+sum[i]*i%mod*i%mod)%mod;
}
// 1 + 2 + ... + x = x(x+1)/2, reduced modulo mod.
inline ll s_1(ll x){
	const ll m=x%mod;
	const ll prod=m*(m+1)%mod;
	return prod*inv2%mod;
}
// 1^2 + 2^2 + ... + x^2 = x(x+1)(2x+1)/6, reduced modulo mod.
inline ll s_2(ll x){
	const ll m=x%mod;
	ll res=m*(m+1)%mod;
	res=res*(2*m+1)%mod;
	return res*inv6%mod;
}
// 1^3 + 2^3 + ... + x^3 = (x(x+1))^2/4, reduced modulo mod.
inline ll s_3(ll x){
	const ll m=x%mod;
	const ll head=m*(m+1)%mod;
	const ll tail=m*m%mod+m;   // x(x+1) again, left unreduced as in the product below
	return head*tail%mod*inv4%mod;
}
// 1^4 + 2^4 + ... + x^4 = x(x+1)(2x+1)(3x^2+3x-1)/30, reduced modulo mod.
inline ll s_4(ll x){
	const ll m=x%mod;
	ll res=m*(m+1)%mod;
	res=res*(2*m+1)%mod;
	res=res*(3*m*m%mod+3*m-1)%mod;
	return res*inv30%mod;
}
// sum_{i=1}^{x} sum_{j=1}^{x} (i^2 + j^2 + ij) = x^2 (x+1)(11x+7)/12, modulo mod.
inline ll g(ll x){
	const ll m=x%mod;
	const ll quad=11*m%mod*m%mod;
	const ll lin=7*m%mod;
	ll res=m*(m+1)%mod;
	res=res*(quad+lin)%mod;
	return res*inv12%mod;
}
struct hash_table{
	int h[N+10],nextn[N+10],edg;ll from[N+10],to[N+10];
	void add(int s,ll x,ll y){to[++edg]=y;from[edg]=x;nextn[edg]=h[s];h[s]=edg;}
	ll inquiry(ll x){
		for(int i=h[x%hmod];i;i=nextn[i])if(from[i]==x)return to[i];
		return 0;
	}
}H,HH,H_;
// Prefix sum of s(i) = i^2 * sigma(i): returns sum_{i<=n} s(i) modulo mod.
// For n > N it uses sum_{i<=n} i^2 sigma(i) = sum_{e<=n} e^3 * sum_{k<=n/e} k^2,
// grouping the e that share floor(n/e); such results are memoised in HH.
ll getsum1(ll n){
	if(n<=N)return d[n];
	ll tmp=HH.inquiry(n);
	if(tmp)return tmp;   // 0 doubles as "not cached"; a true 0 is just recomputed
	tmp=0;
	for(ll l=1,r;l<=n;l=r+1){
		r=n/(n/l);       // last e with the same floor(n/e) as l
		tmp=(tmp+(s_3(r)-s_3(l-1))*s_2(n/l)%mod)%mod;
	}
	HH.add(n%hmod,n,tmp);
	return tmp;
}
// Prefix sum of h(i) = i^2 * sum_{e|i} e phi(e) mu(i/e), modulo mod, by Du's sieve.
// Because h * s = n^4 (Dirichlet convolution, with s(i) = i^2 sigma(i) and s(1) = 1):
//   S_h(n) = sum_{i<=n} i^4 - sum_{i=2}^{n} s(i) * S_h(floor(n/i)),
// where the s-prefix sums come from getsum1. Results for n > N are memoised in H.
ll getsum(ll n){
	if(n<=N)return sum[n];
	ll tmp=H.inquiry(n);
	if(tmp)return tmp;   // 0 doubles as "not cached"; a true 0 is just recomputed
	tmp=s_4(n);
	for(ll l=2,r;l<=n;l=r+1){
		r=n/(n/l);       // last i with the same floor(n/i) as l
		tmp=(tmp-(getsum1(r)-getsum1(l-1))*getsum(n/l)%mod)%mod;
	}
	H.add(n%hmod,n,tmp);
	return tmp;
}
// Prefix sum of r(i) = i^2 * mu(i), modulo mod, by Du's sieve.
// Because sum_{d|n} r(d) (n/d)^2 = [n = 1]:
//   S_r(n) = 1 - sum_{i=2}^{n} i^2 * S_r(floor(n/i)).
// Results for n > N are memoised in H_.
ll getsum_(ll n){
	if(n<=N)return sum_[n];
	ll tmp=H_.inquiry(n);
	if(tmp)return tmp;   // 0 doubles as "not cached"; a true 0 is just recomputed
	tmp=1;
	for(ll l=2,r;l<=n;l=r+1){
		r=n/(n/l);       // last i with the same floor(n/i) as l
		tmp=(tmp-(s_2(r)-s_2(l-1))*getsum_(n/l)%mod)%mod;
	}
	H_.add(n%hmod,n,tmp);
	return tmp;
}
main(){
	pre_work();ch=getchar();
	while(ch>47)n=(n<<1)+(n<<3)+(ch^48),ch=getchar();
	for(register ll l=1,r;l<=n;l=r+1){
		r=n/(n/l);
		ans=(ans+(getsum(r)-getsum(l-1)+getsum_(r)-getsum_(l-1))*g(n/l)%mod)%mod;
	}
	ans=(ans*inv2%mod+mod)%mod;
	write(ans);
}
原文地址:https://www.cnblogs.com/Y-B-X/p/15387909.html