hdu6706 huntian oy

hdu
好久没写数论函数题了,上一次写还是在纪中学min25筛的时候了,赶紧来一道补下手感

题面:求

[sum_{i=1}^nsum_{j=1}^igcd(i^a-j^a,i^b-j^b)[gcd(i,j)=1] ]

保证(n,a,bleq 10^9,gcd(a,b)=1)

知道((i^a-j^a))这个玩意因式分解后有((i-j)),不妨大力猜一下(gcd(i^a-j^a,i^b-j^b)=i-j),证明明天再补
那么

[egin{aligned} 原式=&sum_{i=1}^nsum_{j=1}^i(i-j)[gcd(i,j)=1]\ =&sum_{i=1}^nsum_{j=1}^i(i-j)sum_{d|gcd(i,j)}mu(d)\ =&sum_{d=1}^nmu(d)dsum_{i=1}^{lfloorfrac{n}{d} floor}sum_{j=1}^ii-j\ end{aligned} ]

后面那个你拿平方和公式算一下应该有(sum_{i=1}^nsum_{j=1}^ii-j=frac{n^3-n}{6}),前面那一部分就是(mu·id),卷上(id)之后就是(epsilon),杜教筛直接求

感觉我回忆杜教筛的时间都比前面推导的时间长。。。

#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<math.h>
#include<queue>
#include<set>
#include<map>
#include<unordered_map>
using namespace std;
typedef long long ll;
typedef long double db;
typedef pair<int,int> pii;
const int N=6000000;
const db pi=acos(-1.0);
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define fir first
#define sec second
#define mp(a,b) make_pair(a,b)
#define pb(a) push_back(a)
#define maxd 1000000007
#define inv6 166666668
#define eps 1e-8
int pri[N+10],tot=0,mu[N+10];
int sum[N+10];
bool vis[N+10];
unordered_map<int,ll> has;

int read()
{
    int x=0,f=1;char ch=getchar();
    while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
    while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
    return x*f;
}

void sieve()
{
	mu[1]=1;
	rep(i,2,N)
	{
		if (!vis[i]) {pri[++tot]=i;mu[i]=-1;}
		for (int j=1;j<=tot && i*pri[j]<=N;j++)
		{
			vis[i*pri[j]]=1;
			if (i%pri[j]==0) break;
			else mu[i*pri[j]]-=mu[i];
		}
	}
	rep(i,1,N) sum[i]=(sum[i-1]+i*mu[i]+maxd)%maxd;
}

ll calc(int n)
{
	ll ans=1ll*n*n%maxd*n%maxd;
	ans=(ans-n+maxd)%maxd;
	ans=ans*inv6%maxd;
	return ans;
}

ll query(int n)
{
	if (n<=N) return sum[n];
	if (has[n]) return has[n];
	ll ans=1;int l,r;
	for (l=2;l<=n;l=r+1)
	{
		r=n/(n/l);
		ll tmp=(1ll*(r+l)*(r-l+1)/2)%maxd;
		ans=(ans-tmp*query(n/l)%maxd+maxd)%maxd;
	}
	has[n]=ans;
	return ans;
}
		
ll f(int n,int a,int b)
{
	ll ans=0;
	int l,r;
	for (l=1;l<=n;l=r+1)
	{
		r=n/(n/l);
		ll nowsum=(query(r)-query(l-1)+maxd)%maxd;
		ans=(ans+nowsum*calc(n/l)%maxd)%maxd; 
	}
	return ans;
}

int main()
{
	int T=read();
	sieve();
	while (T--)
	{
		int n=read(),a=read(),b=read();
		printf("%lld
",f(n,a,b));
	}
	return 0;
}
原文地址:https://www.cnblogs.com/encodetalker/p/11406654.html