BZOJ4623 : Styx

$g$是积性函数,可以通过分解质因数在$O(nlog n loglog n)$的时间内求出。

对于$((A imes B) imes C) imes D$,可以转化为$D imes (C imes (B imes A))$,并视向量个数的奇偶性取反答案。

对于$D imes (C imes (B imes A))$,可以将$D imes$,$C imes$,$B imes$用$3$个$3 imes 3$的矩阵表示,然后对树进行点分治即可。

时间复杂度$O(nlog n loglog n)$。

#include<cstdio>
const int N=100010,E=N*20,M=1000010,P=1000000007;
int n,m,i,j,x,y,a[N],mv;
int g[N],v[N<<1],nxt[N<<1],ok[N<<1],ed;
int all,f[N],son[N],now,pos[N];
int G[N],V[E],NXT[E],ED,p[N],d[N],cnt,ans[N][3];
struct Q{int x,y;}q[N];
inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';}
inline void add(int x,int y){v[++ed]=y;nxt[ed]=g[x];ok[ed]=1;g[x]=ed;}
inline void ADD(int x,int y){V[++ED]=y;NXT[ED]=G[x];G[x]=ED;}
namespace Num{
int inv[M],p[N],tot,mul,m[M],pm[M],g[M];bool v[M];
void init(){
  for(i=2;i<=mv;i++){
    if(!v[i])p[tot++]=i;
    for(j=0;j<tot;j++){
      if(i*p[j]>mv)break;
      v[i*p[j]]=1;
      if(i%p[j]==0)break;
    }
  }
  for(inv[0]=inv[1]=1,i=2;i<=mv;i++)inv[i]=1LL*(P-inv[P%i])*(P/i)%P;
  for(mul=1,i=2;i<=mv;i++)pm[i]=g[i]=1;
}
int exgcd(int a,int b){
  if(!b)return x=1,y=0,a;
  int d=exgcd(b,a%b),t=x;
  return x=y,y=t-a/b*y,d;
}
inline int rev(int a){exgcd(a,P);while(x<0)x+=P;return x%P;}
inline void up(int&a,int b){a+=b;if(a>=P)a-=P;}
inline void add(int x){
  m[x]++;
  g[x]=(1LL*(x-1)*pm[x]%P*m[x]+g[x])%P;
  pm[x]=1LL*pm[x]*x%P;
  up(g[x],pm[x]);
}
inline void del(int x){
  up(g[x],P-pm[x]);
  pm[x]=1LL*pm[x]*inv[x]%P;
  g[x]=(1LL*(P-x+1)*pm[x]%P*m[x]+g[x])%P;
  m[x]--;
}
inline void divadd(int n){
  for(int i=0;p[i]*p[i]<=n;i++)if(n%p[i]==0){
    int j=p[i],old=rev(g[j]);
    while(n%j==0)n/=j,add(j);
    mul=1LL*mul*old%P*g[j]%P;
  }
  if(n==1)return;
  int old=rev(g[n]);
  add(n);
  mul=1LL*mul*old%P*g[n]%P;
}
inline void divdel(int n){
  for(int i=0;p[i]*p[i]<=n;i++)if(n%p[i]==0){
    int j=p[i],old=rev(g[j]);
    while(n%j==0)n/=j,del(j);
    mul=1LL*mul*old%P*g[j]%P;
  }
  if(n==1)return;
  int old=rev(g[n]);
  del(n);
  mul=1LL*mul*old%P*g[n]%P;
}
}
int mat[N][3][3],vec[N][3],A[N][3][3],B[N][3][3],C[3][3];
inline void mul(int a[][3],int b[][3],int c[][3]){
  c[0][0]=(1LL*a[0][0]*b[0][0]+1LL*a[0][1]*b[1][0]+1LL*a[0][2]*b[2][0])%P;
  c[0][1]=(1LL*a[0][0]*b[0][1]+1LL*a[0][1]*b[1][1]+1LL*a[0][2]*b[2][1])%P;
  c[0][2]=(1LL*a[0][0]*b[0][2]+1LL*a[0][1]*b[1][2]+1LL*a[0][2]*b[2][2])%P;
  c[1][0]=(1LL*a[1][0]*b[0][0]+1LL*a[1][1]*b[1][0]+1LL*a[1][2]*b[2][0])%P;
  c[1][1]=(1LL*a[1][0]*b[0][1]+1LL*a[1][1]*b[1][1]+1LL*a[1][2]*b[2][1])%P;
  c[1][2]=(1LL*a[1][0]*b[0][2]+1LL*a[1][1]*b[1][2]+1LL*a[1][2]*b[2][2])%P;
  c[2][0]=(1LL*a[2][0]*b[0][0]+1LL*a[2][1]*b[1][0]+1LL*a[2][2]*b[2][0])%P;
  c[2][1]=(1LL*a[2][0]*b[0][1]+1LL*a[2][1]*b[1][1]+1LL*a[2][2]*b[2][1])%P;
  c[2][2]=(1LL*a[2][0]*b[0][2]+1LL*a[2][1]*b[1][2]+1LL*a[2][2]*b[2][2])%P;
}
void dfs(int x,int y){
  Num::divadd(a[x]);
  vec[x][0]=Num::mul;
  vec[x][1]=x<<2;
  vec[x][2]=1;
  for(int i=g[x];i;i=nxt[i])if(v[i]!=y)dfs(v[i],x),vec[x][2]+=vec[v[i]][2];
  mat[x][0][1]=P-vec[x][2];
  mat[x][0][2]=vec[x][1];
  mat[x][1][0]=vec[x][2];
  mat[x][1][2]=P-vec[x][0];
  mat[x][2][0]=P-vec[x][1];
  mat[x][2][1]=vec[x][0];
  Num::divdel(a[x]);
}
void findroot(int x,int y){
  son[x]=1;f[x]=0;
  for(int i=g[x];i;i=nxt[i])if(ok[i]&&v[i]!=y){
    findroot(v[i],x);
    son[x]+=son[v[i]];
    if(son[v[i]]>f[x])f[x]=son[v[i]];
  }
  if(all-son[x]>f[x])f[x]=all-son[x];
  if(f[x]<f[now])now=x;
}
void dfs1(int x,int y,int z){
  pos[x]=z;f[x]=y;d[x]=d[y]^1;
  mul(mat[x],A[y],A[x]);
  for(int i=g[x];i;i=nxt[i])if(ok[i]&&v[i]!=y)dfs1(v[i],x,z);
}
void dfs2(int x,int y){
  mul(B[y],mat[x],B[x]);
  for(int i=g[x];i;i=nxt[i])if(ok[i]&&v[i]!=y)dfs2(v[i],x);
}
void solve(int x){
  if(!G[x])return;
  f[0]=all=son[x],findroot(x,now=0);
  int i;
  pos[now]=now;
  for(int i=0;i<3;i++)for(int j=0;j<3;j++)A[now][i][j]=0;
  A[now][0][0]=A[now][1][1]=A[now][2][2]=1;
  mul(A[now],mat[now],B[now]);
  f[now]=d[now]=0;
  for(i=g[now];i;i=nxt[i])if(ok[i])dfs1(v[i],now,v[i]),dfs2(v[i],now);
  for(cnt=0,i=G[x];i;i=NXT[i])p[++cnt]=V[i];G[x]=0;
  for(i=1;i<=cnt;i++)if(pos[q[p[i]].x]==pos[q[p[i]].y])ADD(pos[q[p[i]].x],p[i]);
  else{
    int X=q[p[i]].x,Y=q[p[i]].y;
    mul(A[Y],B[f[X]],C);
    ans[p[i]][0]=(1LL*C[0][0]*vec[X][0]+1LL*C[0][1]*vec[X][1]+1LL*C[0][2]*vec[X][2])%P;
    ans[p[i]][1]=(1LL*C[1][0]*vec[X][0]+1LL*C[1][1]*vec[X][1]+1LL*C[1][2]*vec[X][2])%P;
    ans[p[i]][2]=(1LL*C[2][0]*vec[X][0]+1LL*C[2][1]*vec[X][1]+1LL*C[2][2]*vec[X][2])%P;
    if(d[X]^d[Y]){
      ans[p[i]][0]=P-ans[p[i]][0];
      ans[p[i]][1]=P-ans[p[i]][1];
      ans[p[i]][2]=P-ans[p[i]][2];
    }
  }
  for(i=g[now];i;i=nxt[i])if(ok[i])ok[i^1]=0,solve(v[i]);
}
int main(){
  B[0][0][0]=B[0][1][1]=B[0][2][2]=1;
  read(n),read(m);
  for(ed=i=1;i<=n;i++){
    read(a[i]);
    if(a[i]>mv)mv=a[i];
  }
  for(i=1;i<n;i++)read(x),read(y),add(x,y),add(y,x);
  Num::init();
  dfs(1,0);
  for(i=1;i<=m;i++){
    read(q[i].x),read(q[i].y);
    if(q[i].x==q[i].y){
      ans[i][0]=vec[q[i].x][0];
      ans[i][1]=vec[q[i].x][1];
      ans[i][2]=vec[q[i].x][2];
    }else ADD(1,i);
  }
  son[1]=n;solve(1);
  for(i=1;i<=m;i++)printf("%d %d %d
",ans[i][0],ans[i][1],ans[i][2]);
  return 0;
}

  

原文地址:https://www.cnblogs.com/clrs97/p/5568834.html