Prime Distance on a Tree（点分治 + FFT，即 centroid decomposition + FFT）

最裸的点分治 + FFT，调了好久，太菜了……

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
// ---- shared global state (all zero-initialised at program start) ----
typedef long long ll;          // 64-bit type for pair counts

const int maxn = 200010;       // capacity of every global array below
const int inf = 1e9;           // not referenced in the code shown
const double pi = acos(-1);    // pi, used for the FFT twiddle factors

// Tree stored as linked edge lists: t = number of edges added so far,
// last[x] = newest edge leaving x, pre[e] = next edge in the same list,
// other[e] = far endpoint of edge e.
int t;
int last[maxn], pre[maxn], other[maxn];

// Centroid search: siz[] = subtree sizes, f[] = largest piece left if the node
// were removed, vis[] = 1 once a node has been used as a centroid,
// mi / root = best f value found so far and the node achieving it.
int siz[maxn], f[maxn], vis[maxn];
int mi, root;

// Bit-reversal permutation used by the FFT.
int rev[maxn];

// dep: not referenced in the code shown; N: working FFT length;
// n: number of tree nodes; p[1..tot]: primes from the sieve;
// is[k]: 1 if the sieve marked k as composite.
int dep, N, n;
int p[maxn], tot, is[maxn];

// sum[d]: accumulated count of ordered node pairs at distance d;
// c[] / cnt[]: depth histograms (one child subtree / whole component).
ll sum[maxn], c[maxn], cnt[maxn];
void add(int x,int y){++t;pre[t]=last[x];last[x]=t;other[t]=y;}
// Walk the current component (nodes with vis==0) below x, not returning to fa.
// f[x] becomes the size of the largest piece that would remain if x were
// deleted; the node with the smallest such value is kept in root / mi.
// siz[] must already hold subtree sizes computed by Siz() rooted at ac.
void getroot(int x,int fa,int ac){
    int worst=0;
    for(int e=last[x];e;e=pre[e]){
        const int to=other[e];
        if(to==fa||vis[to])continue;
        getroot(to,x,ac);
        if(siz[to]>worst)worst=siz[to];
    }
    // The piece "above" x has siz[ac]-siz[x] nodes. Use siz[ac], not n,
    // because only the current component matters.
    const int above=siz[ac]-siz[x];
    if(above>worst)worst=above;
    f[x]=worst;
    if(worst<mi){
        mi=worst;
        root=x;
    }
}
// Add every node of x's subtree (unvisited nodes only, never going back
// through fa) to the depth histogram c[]; x itself sits at depth d.
void dfs(int x,int fa,int d){
    ++c[d];
    for(int e=last[x];e!=0;e=pre[e]){
        const int to=other[e];
        if(to!=fa&&!vis[to])dfs(to,x,d+1);
    }
}
// Minimal complex number for the FFT (r = real part, i = imaginary part).
// The operators take const references and are const members, so they also
// work on temporaries (e.g. a*(b*c)).
// Globals declared with the type: A / B are FFT work buffers, tmp is the
// bit-reversal scratch buffer, and wn / w / x / y are butterfly temporaries
// written by fft().
struct cp{
    double r,i;
    cp operator+(const cp&t)const{cp tp;tp.r=r+t.r;tp.i=i+t.i;return tp;}
    cp operator-(const cp&t)const{cp tp;tp.r=r-t.r;tp.i=i-t.i;return tp;}
    cp operator*(const cp&t)const{cp tp;tp.r=r*t.r-i*t.i;tp.i=t.r*i+t.i*r;return tp;}
}A[maxn],B[maxn],tmp[maxn],wn,w,x,y;
// In-place iterative radix-2 FFT of a[0..n).
// The bit-reversal step assumes n is a power of two (callers in this file
// always pass one).
// flag = 1: forward transform.
// flag = -1: inverse transform; afterwards only the real parts are divided
// by n, and the imaginary parts are left unscaled.
// Side effects: overwrites the globals rev[], tmp[], wn, w, x and y.
void fft(cp a[],int n,int flag){
    // Bit-reversal permutation: rev[i] is i with its log2(n) bits reversed,
    // built from rev[i>>1], which was computed earlier in this loop.
    for(int i=0;i<n;++i){
        rev[i]=rev[i>>1]>>1;
        if(i&1)rev[i]|=(n>>1);
    }
    // Reorder a[] into bit-reversed order through the tmp[] scratch buffer.
    for(int i=0;i<n;++i)tmp[i]=a[rev[i]];
    for(int i=0;i<n;++i)a[i]=tmp[i];
    // Butterfly passes over blocks of length i = 2, 4, ..., n.
    for(int i=2;i<=n;i<<=1){
        // Primitive i-th root of unity; the sign of the imaginary part
        // (flag) selects the forward or inverse transform.
        wn.r=cos(2*pi/i);wn.i=flag*sin(2*pi/i);
        for(int j=0;j<n;j+=i){
            w.r=1;w.i=0;
            for(int k=j;k<j+i/2;++k){
                x=a[k];y=a[k+i/2]*w;
                a[k]=x+y;a[k+i/2]=x-y;
                // Advance the twiddle factor by repeated multiplication.
                w=w*wn;
            }
        }
    }
    // Inverse transform: scale by 1/n (real part only; calc() reads only .r).
    if(flag==-1)for(int i=0;i<n;++i)a[i].r/=n;
}
// Set siz[x] to the number of nodes in x's subtree, rooted away from fa,
// counting only nodes not yet removed as centroids (vis==0).
void Siz(int x,int fa){
    int total=1;
    for(int e=last[x];e;e=pre[e]){
        const int to=other[e];
        if(to!=fa&&!vis[to]){
            Siz(to,x);
            total+=siz[to];
        }
    }
    siz[x]=total;
}
// Self-convolution of the histogram a[0..n).
// For every index k < n this adds flag * sum_{i+j=k} a[i]*a[j] to sum[k].
// flag is +1 or -1 (add or subtract the counts).
// n must be a power of two, large enough that no pair with i+j >= n has a
// non-zero product; otherwise the cyclic FFT convolution wraps around.
// Callers in this file size n accordingly.
// Side effects: overwrites the global buffer A (and fft()'s temporaries).
void calc(ll a[],int n,int flag){
    for(int i=0;i<n;++i){A[i].r=a[i];A[i].i=0;}
    fft(A,n,1);
    // Both factors are the same sequence, so a single forward transform is
    // enough: squaring the spectrum yields the self-convolution.
    for(int i=0;i<n;++i)A[i]=A[i]*A[i];
    fft(A,n,-1);
    // The exact results are non-negative integers. Round to nearest so an FFT
    // error of either sign up to 0.5 is tolerated.
    for(int i=0;i<n;++i)sum[i]+=flag*(ll)(A[i].r+0.5);
}
// For the component whose centroid is x, count every path through x and add
// it to sum[], indexed by path length.
// Pairs are ordered, so each unordered pair is counted twice.
// Pairs whose endpoints lie in the same child subtree do not pass through x.
// They are removed by subtracting that child's own self-convolution.
void solve(int x){
    Siz(x,0);
    // FFT length for the whole component. x is a centroid, so every depth is
    // at most siz[x]/2. The longest path through x is therefore <= siz[x] < len,
    // and the cyclic convolution does not wrap around.
    int len=1;
    while(len<=siz[x])len<<=1;
    for(int i=0;i<len;++i)cnt[i]=0;
    cnt[0]=1;                           // the centroid itself, at depth 0
    for(int i=last[x];i;i=pre[i]){
        int v=other[i];
        if(vis[v])continue;
        // Depth histogram of this child's subtree (depths 1..siz[v]). Its
        // length must exceed 2*siz[v] so the self-convolution does not wrap.
        int clen=1;
        while(clen<=2*siz[v])clen<<=1;
        for(int j=0;j<clen;++j)c[j]=0;
        dfs(v,x,1);
        calc(c,clen,-1);                // drop pairs inside the same child subtree
        for(int j=0;j<clen;++j)cnt[j]+=c[j];
    }
    calc(cnt,len,1);                    // all pairs in the component, combined via x
    sum[0]=0;                           // discard the (x, x) pair at distance 0
}
// Centroid decomposition of the component containing x:
// 1. find its centroid;
// 2. count the paths through the centroid;
// 3. mark the centroid as removed;
// 4. recurse into each remaining neighbouring component.
void divont(int x){
    mi=1e9;
    Siz(x,0);
    getroot(x,0,x);
    // Copy root into a local: the recursive calls below overwrite the global.
    const int centroid=root;
    solve(centroid);
    vis[centroid]=1;
    for(int e=last[centroid];e!=0;e=pre[e]){
        const int nxt=other[e];
        if(vis[nxt])continue;
        divont(nxt);
    }
}
int main(){
    cin>>n;
    int x,y;
    for(int i=1;i<n;++i){
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
    }
    divont(1);
    for(int i=2;i<=50010;++i){
        if(!is[i]){p[++tot]=i;}
        for(int j=1;j<=tot&&i*p[j]<=50010;++j){
            is[i*p[j]]=1;
            if(i%p[j]==0)break;
        }
    }
    double mu=(double)n*(n-1)/2,res=0;
    for(int i=1;i<=tot&&p[i]<=n;++i){
        res+=sum[p[i]];
    }
    res/=2;
    printf("%.7lf",double(res)/double(mu));
    return 0;
} 
原文地址：https://www.cnblogs.com/dibaotianxing/p/8591762.html