青云的机房组网方案(中等) 计蒜客

分析:树形dp

dp[i][v]代表以i为根的子树中权值为v的节点到i的路径和
c[i][v]代表以i为根的子树中权值为v的节点个数
 
对于一个节点 有很多孩子分支,从最左边的分支开始,算跨过当前根节点的
每次分支算一次
最后再算当前根节点i与它子树互质的

每次算的方法就是先算出不互质的,然后总的减
复杂度O(n*m*sqrt(m)*k*2^k) m=max(a[i]),i属于1到n
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <vector>
#include <cmath>
using namespace std;
typedef long long LL;
const int N=1e4+5;
struct Edge
{
    int v,next;
} edge[N<<1];
int head[N],tot;
void add(int u,int v)
{
    edge[tot].v=v;
    edge[tot].next=head[u];
    head[u]=tot++;
}
vector<int>fac[505];
LL dp[N][505],c[N][505],a[N],n,ret,cnt[505],mx,cnt2[505];
void cal(int rt)
{
    memset(cnt,0,sizeof(cnt));
    memset(cnt2,0,sizeof(cnt2));
    for(int i=1; i<=mx; ++i)
    {
        for(int j=1; j*j<=i; ++j)
        {
            if(i%j)continue;
            cnt[j]+=dp[rt][i];
            cnt2[j]+=c[rt][i];
            if(i/j!=j){
              cnt2[i/j]+=c[rt][i];
              cnt[i/j]+=dp[rt][i];
            }
        }
    }
}
void treedp(int u,int f)
{
    for(int i=head[u]; ~i; i=edge[i].next)
    {
        int v=edge[i].v;
        if(v==f)continue;
        treedp(v,u);
        LL sum=0,sum2=0;
        for(int j=1; j<=mx; ++j)sum+=dp[u][j],sum2+=c[u][j];
        cal(u);
        for(int j=1; j<=mx; ++j)
        {
            if(!c[v][j])continue;
            int l=(1<<(fac[j].size()));
            LL tot=0,tot2=0;
            for(int k=1; k<l; ++k)
            {
                int tmp=1,t1=0;
                for(int p=0; p<fac[j].size(); ++p)
                    if(k&(1<<p))++t1,tmp*=fac[j][p];
                if(t1&1)tot+=cnt[tmp],tot2+=cnt2[tmp];
                else tot-=cnt[tmp],tot2-=cnt2[tmp];
            }
            ret+=(sum-tot)*c[v][j]+(sum2-tot2)*(dp[v][j]+c[v][j]);
        }
        for(int j=1; j<=mx; ++j)
        {
            dp[u][j]+=dp[v][j]+c[v][j];
            c[u][j]+=c[v][j];
        }
    }
    cal(u);
    LL sum=0,tot=0;
    for(int j=1; j<=mx; ++j)sum+=dp[u][j];
    int l=(1<<(fac[a[u]].size()));
    for(int i=1; i<l; ++i)
    {
        int tmp=1,t1=0;
        for(int j=0; j<fac[a[u]].size(); ++j)
            if(i&(1<<j))++t1,tmp*=fac[a[u]][j];
        if(t1&1)tot+=cnt[tmp];
        else tot-=cnt[tmp];
    }
    ret+=(sum-tot);
    c[u][a[u]]++;
}
int main()
{
    for(int i=1; i<=500; ++i)fac[i].clear();
    for(int i=2; i<=500; ++i)
    {
        int t=i;
        for(int j=2; j<=t; ++j)
        {
            if(t%j)continue;
            fac[i].push_back(j);
            while(t%j==0)t/=j;
        }
    }
    while(~scanf("%lld",&n))
    {
        mx=0;
        for(int i=1; i<=n; ++i)scanf("%lld",&a[i]),mx=max(mx,a[i]);
        memset(head,-1,sizeof(head));
        tot=0;
        memset(dp,0,sizeof(dp));
        memset(c,0,sizeof(c));
        for(int i=1; i<n; ++i)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            add(u,v),add(v,u);
        }
        ret=0;
        treedp(1,0);
        printf("%lld
",ret);
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/shuguangzw/p/5571312.html