hdoj6446(树形DP)

题目链接:https://vjudge.net/problem/HDU-6446

题意:简化题意后就是求距离和的2*(n-1)!倍。

思路:

  简单的树形dp,通过求每条边的贡献计算距离和,边(u,v)的贡献为sz[v]*(n-sz[v])。

#include<cstdio>
#include<algorithm>
using namespace std;

typedef long long LL;
const int maxn=1e5+5;
const int MOD=1e9+7;
int n,cnt,head[maxn],sz[maxn];
LL ans;

struct node{
    int v,nex;
    LL w;
}edge[maxn<<1];

void adde(int u,int v,LL w){
    edge[++cnt].v=v;
    edge[cnt].w=w;
    edge[cnt].nex=head[u];
    head[u]=cnt;
}

void dfs(int u,int fa){
    sz[u]=1;
    for(int i=head[u];i;i=edge[i].nex){
        int v=edge[i].v;
        if(v==fa) continue;
        dfs(v,u);
        sz[u]+=sz[v];
        ans=(ans+edge[i].w*sz[v]%MOD*(n-sz[v])%MOD)%MOD;
    }
}

int main(){
    while(~scanf("%d",&n)){
        ans=0;
        cnt=0;
        for(int i=1;i<=n;++i)
            head[i]=0;
        for(int i=1;i<n;++i){
            int u,v;LL w;
            scanf("%d%d%lld",&u,&v,&w);
            adde(u,v,w);
            adde(v,u,w);
        }
        dfs(1,0);
        for(int i=1;i<n;++i)
            ans=ans*i%MOD;
        ans=ans*2%MOD;
        printf("%lld
",ans);
    }
    return 0;
}

  另外因为前几天学点分治,看到这题想到可以用点分治求距离和。具体做法是,通过求得重心u后求所有点到重心的距离dis[i],然后采用点分治的第二种写法,遍历u的所有子结点v,用sum表示前面计算过的距离总和,num表示前面的结点数,那么对当前遍历的dis[i],其贡献为num*dis[i]+sum。但要注意不要漏了以重心为端点的边,所以将num初始化1,而不是0。

  点分治代码:

#include<cstdio>
#include<algorithm>
using namespace std;

typedef long long LL;
const int maxn=1e5+5;
const int MOD=1e9+7;
const int inf=0x3f3f3f3f;
int n,cnt,head[maxn],sz[maxn],mson[maxn],Min,size,root;
int vis[maxn],t,num;
LL dis[maxn],tmp,sum,ans;

struct node{
    int v,nex;
    LL w;
}edge[maxn<<1];

void adde(int u,int v,LL w){
    edge[++cnt].v=v;
    edge[cnt].w=w;
    edge[cnt].nex=head[u];
    head[u]=cnt;
}

void getroot(int u,int fa){
    sz[u]=1,mson[u]=0;
    for(int i=head[u];i;i=edge[i].nex){
        int v=edge[i].v;
        if(vis[v]||v==fa) continue;
        getroot(v,u);
        sz[u]+=sz[v];
        mson[u]=max(mson[u],sz[v]);
    }
    mson[u]=max(mson[u],size-sz[u]);
    if(mson[u]<Min) Min=mson[u],root=u;
}

void getdis(int u,int fa,LL len){
    dis[++t]=len;
    for(int i=head[u];i;i=edge[i].nex){
        int v=edge[i].v;
        if(vis[v]||v==fa) continue;
        getdis(v,u,(len+edge[i].w)%MOD);
    }
}

void solve(int u){
    sum=0;
    num=1;
    for(int i=head[u];i;i=edge[i].nex){
        int v=edge[i].v;
        if(vis[v]) continue;
        t=0;
        tmp=0;
        getdis(v,u,edge[i].w);
        for(int i=1;i<=t;++i){
            ans=(ans+num*dis[i]%MOD+sum)%MOD;
            tmp=(tmp+dis[i])%MOD;
        }
        sum=(sum+tmp)%MOD;
        num+=t;
    }
}

void fenzhi(int u,int ssize){
    vis[u]=1;
    solve(u);
    for(int i=head[u];i;i=edge[i].nex){
        int v=edge[i].v;
        if(vis[v]) continue;
        Min=inf,root=0;
        size=sz[v]<sz[u]?sz[v]:(ssize-sz[u]);
        getroot(v,0);
        fenzhi(root,size);
    }
}

int main(){
    while(~scanf("%d",&n)){
        cnt=0;
        ans=0;
        for(int i=1;i<=n;++i)
            head[i]=vis[i]=0;
        for(int i=1;i<n;++i){
            int u,v;LL w;
            scanf("%d%d%lld",&u,&v,&w);
            adde(u,v,w);
            adde(v,u,w);
        }
        Min=inf,root=0,size=n;
        getroot(1,0);
        fenzhi(root,n);
        for(int i=1;i<n;++i)
            ans=ans*i%MOD;
        ans=ans*2%MOD;
        printf("%lld
",ans);
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/FrankChen831X/p/11420582.html