CF809E Surprise me!

一、题目

点此看题

二、解法

哼哼哼,(color{orange}{C202044zxy}) 又切黑题,其实他就是只会做套路题而且不带脑子的伞兵

利用欧拉函数的不完全积性,可以发现 (varphi(xcdot y)=frac{varphi(x)varphi(y)}{varphi(gcd)}cdot gcd),所以我们直接枚举 (gcd)

[egin{aligned}&sum_{g=1}^nfrac{g}{varphi(g)}sum_{g|a_i}sum_{g|a_j}varphi(a_i)varphi(a_j)dis(i,j)cdot[(i,j)=1]\=&sum_{g=1}^nfrac{g}{varphi(g)}sum_{g|i}sum_{g|j}varphi(i)varphi(j)dis(p_i,p_j)cdot[(i,j)=1]\=&sum_{g=1}^nfrac{g}{varphi(g)}sum_{d=1}^{n/d}mu(d)sum_{gd|i}sum_{gd|j}varphi(i)varphi(j)dis(p_i,p_j)\=&sum_{T=1}^nsum_{d|T}mu(d)cdot frac{T/d}{varphi(T/d)}sum_{T|i}sum_{T|j}varphi(i)varphi(j)dis(p_i,p_j)\end{aligned} ]

前面那东西可以 (O(nlog n)) 卷出来,计算后面的式子考虑上树,我们把 (T|u) 的所有点 (p_u) 建出虚树。然后考虑每一条虚树边的贡献,就是子树内关键点的 (varphi) 之和乘上子树外关键点的 (varphi) 之和,因为链在虚树中被压缩成了边所以还要乘上链的长度。

点数是调和级数级别的,所以时间复杂度 (O(nlog^2n))

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int M = 200005;
const int MOD = 1e9+7;
#define int long long
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,m,ans,s[M],p[M],a[M<<1],dep[M],sum[M],in[M],out[M];
int Ind,cnt,phi[M],vis[M],mu[M],o[M],inv[M],fa[M][20];
vector<int> g[M];
void init()
{
	phi[1]=mu[1]=inv[0]=inv[1]=1;
	for(int i=2;i<=n;i++)
	{
		inv[i]=inv[MOD%i]*(MOD-MOD/i)%MOD;
		if(!vis[i])
		{
			p[++cnt]=i;
			mu[i]=-1;phi[i]=i-1;
		}
		for(int j=1;j<=cnt && i*p[j]<=n;j++)
		{
			vis[i*p[j]]=1;
			if(i%p[j]==0)
			{
				phi[i*p[j]]=phi[i]*p[j];
				break;
			}
			mu[i*p[j]]=-mu[i];
			phi[i*p[j]]=phi[i]*(p[j]-1);
		}
	}
	for(int i=1;i<=n;i++) p[i]=vis[i]=0;
	for(int i=1;i<=n;i++)
		for(int j=i;j<=n;j+=i)
			o[j]=(o[j]+mu[i]*(j/i)*inv[phi[j/i]])%MOD;
}
void pre(int u,int p)
{
	dep[u]=dep[p]+1;
	in[u]=++Ind;fa[u][0]=p;
	for(int i=1;i<20;i++)
		fa[u][i]=fa[fa[u][i-1]][i-1];
	for(auto v:g[u]) if(v^p)
		pre(v,u);
	out[u]=++Ind;
}
int cmp(int x,int y)
{
	int t1=x>0?in[x]:out[-x];
	int t2=y>0?in[y]:out[-y];
	return t1<t2;
}
int lca(int u,int v)
{
	if(dep[u]<dep[v]) swap(u,v);
	for(int i=19;i>=0;i--)
		if(dep[fa[u][i]]>=dep[v])
			u=fa[u][i];
	if(u==v) return u;
	for(int i=19;i>=0;i--)
		if(fa[u][i]^fa[v][i])
			u=fa[u][i],v=fa[v][i];
	return fa[u][0];
}
void dfs(int u,int fa,int &res)
{
	for(auto v:g[u])
		dfs(v,u,res),sum[u]=(sum[u]+sum[v])%MOD;
	int len=dep[u]-dep[fa];
	res=(res+(m-sum[u])*sum[u]%MOD*len)%MOD;
}
signed main()
{
	n=read();init();
	for(int i=1;i<=n;i++) p[read()]=i;
	for(int i=1;i<n;i++)
	{
		int u=read(),v=read();
		g[u].push_back(v);
		g[v].push_back(u);
	}
	pre(1,0);
	for(int i=1;i<=n;i++) g[i].clear();
	for(int T=1;T<=n;T++)
	{
		int k=0,k2=0,top=0;m=0;
		for(int i=T;i<=n;i+=T)
		{
			a[++k]=p[i],vis[p[i]]=1;
			sum[p[i]]=phi[i];m=(m+phi[i])%MOD;
		}
		sort(a+1,a+1+k,cmp);k2=k;
		for(int i=1;i<k2;i++)
		{
			int tmp=lca(a[i],a[i+1]);
			if(!vis[tmp]) a[++k]=tmp,vis[tmp]=1;
		}
		if(!vis[1]) a[++k]=1,vis[1]=1;k2=k;
		for(int i=1;i<=k2;i++) a[++k]=-a[i];
		sort(a+1,a+1+k,cmp);
		for(int i=1;i<=k;i++)
		{
			if(a[i]>0) s[++top]=a[i];
			else
			{
				int t=s[top--];
				if(t==1) break;
				g[s[top]].push_back(t);
			}
		}
		int res=0;dfs(1,0,res);
		ans=(ans+2*res*o[T])%MOD;
		for(int i=1;i<=k;i++) if(a[i]>0)
			sum[a[i]]=vis[a[i]]=0,g[a[i]].clear();
	}
	ans=ans*inv[n]%MOD*inv[n-1]%MOD;
	printf("%lld
",(ans+MOD)%MOD);
}
原文地址:https://www.cnblogs.com/C202044zxy/p/15545733.html