【2019北京集训测试赛(十三)】函树 虚树

题目大意:给你一颗$n$个节点的树,定义$d(x,y)=$点$x$到点$y$最短路上经过的边数。

求$sumlimits_{i=1}^{n} sumlimits_{j=1}^{n} varphi(i imes j) imes d(i,j)$

答案对998244353$取模。

我们对这个式子做一些细微的处理,设最终的答案为$ans$:

$ans=sumlimits_{i=1}^{n} sumlimits_{j=1}^{n} varphi(i imes j) imes d(i,j)$

$=sumlimits_{i=1}^{n} sumlimits_{j=1}^{n} varphi(i)varphi(j)frac{gcd(i,j)}{varphi(gcd(i,j))} imes d(i,j)$

我们设$F(d)=sumlimits_{i=1}^{n} sumlimits_{j=1,d|gcd(i,j)}^{n} varphi(i)varphi(j) imes d(i,j)$

那么,$ans=sumlimits_{d=1}^{n} frac{d}{varphi(d)} sumlimits_{p|d} F(p) imes G(frac{d}{p})$

对于$G(x)$,设$x=prodlimits_{i=1}^{k} p_i$,$p_i$是质数,$G(x)=(-1)^k$

我们考虑如何求$F(d)$。

显然,我们只需要把所有点权能被$d$整除的点找出来,建一棵虚树,统计每条虚树边两端的$sum varphi(i)$,把它们乘起来,再乘上虚树边边长即可。

由于点权等于编号,所以n棵虚树的总点数是$O(nln n)$级别的,单次构建虚树的复杂度是$O(sizelog size)$的,所以并不会$T$掉。

然后就没有了,总复杂度是$O(nlog^2 n)$的。

  1 #include<bits/stdc++.h>
  2 #define M 100005
  3 #define MOD 998244353
  4 #define L long long
  5 using namespace std;
  6 
  7 int pri[M]={0},b[M]={0},phi[M]={0},zf[M]={0},Use=0;
  8 void init(){
  9     phi[1]=1; zf[1]=1;
 10     for(int i=2;i<M;i++){
 11         if(!b[i]) pri[++Use]=i,phi[i]=i-1,zf[i]=-1;
 12         for(int j=1;j<=Use&&i*pri[j]<M;j++){
 13             b[i*pri[j]]=1; zf[i*pri[j]]=-zf[i];
 14             if(i%pri[j]==0) {phi[i*pri[j]]=phi[i]*pri[j]; break;}
 15             phi[i*pri[j]]=phi[i]*(pri[j]-1);
 16         }
 17     }
 18 }
 19 
 20 L pow_mod(L x,L k){L ans=1; for(;k;k>>=1,x=x*x%MOD) if(k&1) ans=ans*x%MOD; return ans;}
 21 vector<int> G[M];
 22 
 23 struct edge{int u,v,next;}e[M*2]={0}; int head[M]={0},use=0;
 24 void add(int x,int y,int z){use++;e[use].u=y;e[use].v=z;e[use].next=head[x];head[x]=use;}
 25 int n,a[M]={0};
 26 
 27 int dep[M]={0},dfn[M]={0},low[M]={0},f[M][20]={0},t=0;
 28 void dfs(int x,int fa){
 29     dep[x]=dep[fa]+1; dfn[x]=++t; f[x][0]=fa;
 30     for(int i=1;i<20;i++) f[x][i]=f[f[x][i-1]][i-1];
 31     for(int i=0;i<G[x].size();i++) if(G[x][i]!=fa) dfs(G[x][i],x);
 32     low[x]=t;
 33 }
 34 int getlca(int x,int y){
 35     if(dep[x]<dep[y]) swap(x,y); int cha=dep[x]-dep[y];
 36     for(int i=19;~i;i--) if((1<<i)&cha) x=f[x][i];
 37     if(x==y) return x;
 38     for(int i=19;~i;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
 39     return f[x][0];
 40 }
 41 
 42 vector<int> D[M]; L F[M]={0};
 43 
 44 bool cmp(int x,int y){return dfn[x]<dfn[y];}
 45 int point[M]={0},stk[M]={0},is[M]={0},pcnt=0,cnt=0,nowt=0;
 46 void build(){
 47     pcnt=cnt; int siz=0; nowt=0;
 48     sort(point+1,point+cnt+1,cmp);
 49     for(int i=1;i<=cnt;i++){
 50         int last=0;
 51         while(siz&&getlca(stk[siz],point[i])!=stk[siz]) last=stk[siz],stk[siz--]=0;
 52         if(last){
 53             int lca=getlca(last,point[i]);
 54             if(lca!=stk[siz]){
 55                 stk[++siz]=lca;
 56                 point[++pcnt]=lca;
 57                 is[lca]=0;
 58             }
 59         }
 60         stk[++siz]=point[i]; is[point[i]]=1;
 61     }
 62     sort(point+1,point+pcnt+1,cmp);
 63     while(siz) stk[siz--]=0;
 64 }
 65 L sumphi[M]={0},sum=0;
 66 int dfs(int x){
 67     if(is[x]) sumphi[x]=phi[a[x]]; else sumphi[x]=0; 
 68     int v; nowt++;
 69     while(dfn[v=point[nowt]]<=low[x]&&nowt<=pcnt){
 70         dfs(v);
 71         sumphi[x]=(sumphi[x]+sumphi[v])%MOD;
 72     }
 73 }
 74 void getans(int x,L fsum){
 75     int v; nowt++; 
 76     while(dfn[v=point[nowt]]<=low[x]&&nowt<=pcnt){
 77         sum=(sum+1LL*(dep[v]-dep[x])*sumphi[v]%MOD*(fsum+sumphi[x]-sumphi[v]+MOD))%MOD;
 78         getans(v,(fsum+sumphi[x]-sumphi[v])%MOD);
 79     }
 80 }
 81 void solve(int x){
 82     while(pcnt)point[pcnt--]=0; cnt=sum=0;
 83     for(int i=0;i<D[x].size();i++){
 84         point[++cnt]=D[x][i];
 85     }
 86     build();
 87     nowt=1; dfs(point[1]);
 88     nowt=1; getans(point[1],0);
 89     F[x]=sum;
 90 }
 91 
 92 int main(){
 93 //    freopen("in.txt","r",stdin);
 94 //    freopen("out.txt","w",stdout);
 95     init();
 96     scanf("%d",&n);
 97     for(int i=1;i<=n;i++){
 98         a[i]=i; //scanf("%d",a+i);
 99         for(int j=1;j*j<=a[i];j++) if(a[i]%j==0){
100             D[j].push_back(i);
101             if(j*j!=a[i]) D[a[i]/j].push_back(i);
102         }
103     }
104     for(int i=1;i<n;i++){
105         int x,y; scanf("%d%d",&x,&y);
106         G[x].push_back(y); G[y].push_back(x);
107     }
108     dfs(1,0);
109     for(int i=1;i<=n;i++) solve(i);
110     for(int i=n;i;i--){
111         for(int j=i*2;j<=n;j+=i)
112         F[i]=(F[i]-F[j]+MOD)%MOD;
113     }
114     L ans=0;
115     for(int d=1;d<=n;d++)
116     ans=(ans+F[d]*d%MOD*pow_mod(phi[d],MOD-2))%MOD;
117     cout<<ans*2%MOD<<endl;
118     //cout<<ans*2*pow_mod(1LL*n*(n-1)%MOD,MOD-2)%MOD<<endl;
119 }
原文地址:https://www.cnblogs.com/xiefengze1/p/10740311.html