图论之堆优化的Prim

本题模板,最小生成树,洛谷P3366

题目描述

如题,给出一个无向图,求出最小生成树,如果该图不连通,则输出orz

输入输出格式

输入格式:

第一行包含两个整数N、M,表示该图共有N个结点和M条无向边。(N<=5000,M<=200000)

接下来M行每行包含三个整数Xi、Yi、Zi,表示有一条长度为Zi的无向边连接结点Xi、Yi

输出格式:

输出包含一个数,即最小生成树的各边的长度之和;如果该图不连通则输出orz

输入输出样例

输入样例#1: 
4 5
1 2 2
1 3 2
1 4 3
2 3 4
3 4 3
输出样例#1: 
7

说明

时空限制:1000ms,128M

数据规模:

对于20%的数据:N<=5,M<=20

对于40%的数据:N<=50,M<=2500

对于70%的数据:N<=500,M<=10000

对于100%的数据:N<=5000,M<=200000

 

下面介绍堆优化后的Prim

首先清楚prim的过程:

给出一个适用于这类问题的推论:给定一张无向图G=(V,E)。n=V的大小,m=E的大小。从E中选出k<n-1条边构成G的一个生成森林。若再从剩余的m-k条边选n-1-k条添加到生成森林中,使其成为G的生成树,并且选出的边的权值之和最小,则该生成树一定包含这m-k条边中连接两个森林的不连通节点的最小边

无论是Prim还是Kruskal都是基于这个推论,但Prim略微有一些改动。

Prim算法总是维护最小生成树的一部分。最初,Prim算法仅确定1号节点属于最小生成树。在任意的时刻,设已经确定属于最小生成树的节点集合为T,剩余节点集合为S。Prim算法找到min(x属于S,y属于T){z},即两个端点分别属于集合S,T的权值最小的边,然后把点x从集合S中删除,加入到集合T中去,并把Z累计到答案(最后答案就是最小生成树的边权值和)。

具体来说,可以维护数组d:对于x属于S,则把d[x]表示节点x与集合T中的节点之间权值最小的边的权值。若x属于T,则d[x]就等于x被加入T选出的最小边的权值

用一个数组标记节点是否属于T。每次从未标记的节点中选出d值最小的,把它标记(加入T),同时
扫描所有出边,更新另一个端点的d值。最后得出答案

可用二叉堆将上述的d数组优化,但其实都不如Kruskal方便。因此Prim主要用于稠密图,尤其是完全图的最下生成树的求解

 

那么所谓的二叉堆优化实际上就是对于每一次拓展的边加入到一个小根堆中,下面笔者的代码实现用的是priority_queue(我懒)


具体实现看代码

注意不仅要判断堆是否为空还要统计已经维护的点的个数,确保还是小于等于n的

#include<bits/stdc++.h>
#define ll long long
using namespace std;

const int maxn=2e5+15;
const int mxn=5e3+15;
struct node
{
    int t;int d;
    bool operator < (const node &a) const
    {
        return d>a.d;
    }
};
int n,m;
int vis[mxn];
vector <node> e[mxn];
priority_queue <node> q;
inline int read()
{
    char ch=getchar();
    int s=0,f=1;
    while (!(ch>='0'&&ch<='9')) {if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9') {s=(s<<3)+(s<<1)+ch-'0';ch=getchar();}
    return s*f;
}
ll prim()
{
    ll ans=0;
    int cnt=0;
    q.push((node){1,0});
    while (!q.empty()&&cnt<=n)
    {
        node k=q.top();q.pop();
        if (vis[k.t]) continue;
        vis[k.t]=1;
        ans+=k.d;
        cnt++;
        for (int i=0;i<e[k.t].size();i++)
        if (!vis[e[k.t][i].t]){
            q.push((node){e[k.t][i].t,e[k.t][i].d});
        }
    }
    return ans;
}
int main()
{
    n=read();m=read();
    for (int i=1;i<=m;i++)
    {
        int x=read(),y=read(),z=read();
        e[x].push_back((node){y,z});e[y].push_back((node){x,z});
    }
    printf("%lld",prim());
    return 0;
}
原文地址:https://www.cnblogs.com/xxzh/p/9201355.html