Hackerrank > 101 Hack 49 > Summing the Path Weights Between Nodes

类似拓扑排序的思想：每次从当前的叶子节点（剩余度数为 1 的节点）向内收缩。收缩过程中记录每个节点已合并部分中红点和黑点的数量，由此算出每条边被多少对异色点的路径经过（一侧红点数 × 另一侧黑点数 + 一侧黑点数 × 另一侧红点数），乘以边权后累加即为答案。时间复杂度 O(N)。

#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<string>
#include<vector>
// Shorthand type and macros kept from the contest template.
#define ll long long
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)>(y)?(y):(x))
#define cls(name,x) memset(name,x,sizeof(name))
using namespace std;
// Template constants; of these only maxn is referenced by the code in this file.
const int inf=1<<28;
const int maxn=1e5;   // capacity of the per-node arrays below; n must not exceed it
const int maxm=110;
const int mod=1e9+7;
const double pi=acos(-1.0);

// Number of nodes in the current test case; nodes are indexed 0..n-1.
int n;

// One half of an undirected weighted edge, stored in the adjacency list of
// the other endpoint.
struct node
{
    int to;   // neighbouring node index
    ll val;   // edge weight
};
vector<node> edge[maxn];   // adjacency list: every edge appears in both endpoints' lists

// Total number of red and black nodes in the whole tree.
int sum_red,sum_black;

// Per-node state for the leaf-peeling pass in solve().
struct vex
{
    int red,black;  // red/black nodes merged into this node so far (initially just the node itself)
    int inc;        // incident edges not yet peeled; starts as the node's degree
}v[maxn];

// Sum, over every pair of differently coloured nodes, of the weight of the
// path between them.
ll ans;

// Adds to `ans` the contribution of every edge of the tree described by
// n / edge / v / sum_red / sum_black.
//
// Nodes are peeled from the leaves inward, similar to a topological sort on
// degrees: a node is queued once only one of its edges remains unpeeled
// (inc == 1). When it is popped, everything already merged into it lies on
// one side of that remaining edge, so the edge is crossed by
//     red_inside * black_outside + black_inside * red_outside
// differently coloured pairs, each of which pays the edge weight once. The
// node's colour counts are then folded into that neighbour. Every node is
// queued at most once and every adjacency list is scanned once, so the pass
// is O(n).
void solve()
{
    // Heap-allocated rather than a 100000-int stack array; same index range.
    vector<char> vis(maxn,0);
    queue<int> Q;
    for(int i=0;i<n;i++)
        if(v[i].inc==1)
            Q.push(i);
    while(!Q.empty())
    {
        int cur=Q.front();
        Q.pop();
        vis[cur]=1;
        for(size_t k=0;k<edge[cur].size();k++)
        {
            const node &e=edge[cur][k];
            int x=e.to;
            if(vis[x])
                continue;
            // Count in long long: with up to 1e5 nodes the number of split
            // pairs can reach 5e4*5e4 = 2.5e9, which overflows int.
            ll in_red=v[cur].red;
            ll in_black=v[cur].black;
            ll crossing=(sum_red-in_red)*in_black+in_red*(sum_black-in_black);
            ans+=e.val*crossing;
            v[x].red+=v[cur].red;
            v[x].black+=v[cur].black;
            v[x].inc--;
            if(v[x].inc==1)
                Q.push(x);
        }
    }
}
int main()
{
    //freopen("in.txt","r",stdin);
    while(~scanf("%d",&n))
    {
        ans=0;sum_red=sum_black=0;
        for(int i=0;i<n;i++)
            {v[i].inc=0;}
        for(int i=0;i<n;i++)
            edge[i].clear();
        for(int i=0;i<n;i++)
        {
            int a;
            scanf("%d",&a);
            if(a==0)
            {
                v[i].red=1;
                v[i].black=0;
            }
            else
            {
                v[i].red=0;
                v[i].black=1;
            }
            sum_red+=v[i].red;
            sum_black+=v[i].black;
        }
        for(int i=0;i<n-1;i++)
        {
            int a,b,c;
            scanf("%d %d %d",&a,&b,&c);
            a--;b--;
            v[a].inc++;
            v[b].inc++;
            node t;
            t.to=b; t.val=c;
            edge[a].push_back(t);
            t.to=a;
            edge[b].push_back(t);
        }
        solve();
        printf("%lld
",ans);
    }
}

  

原文地址:https://www.cnblogs.com/mgz-/p/6908253.html