POJ 2449 Remmarguts' Date(单源最短路径 + A*)

题意:

找到从 T 到 S 的第 k 最短路。

思路:

1. 先把 T 点当作源点用 SPFA 算法求到各个点的最短路径,注意此时要用反向图才能求得正确结果;

2. 然后再求从 S -> T 的第 k 最短路,此时要用到启发式搜索里面的技巧,距离 T 的理想距离就是 1 中求的最短路径长度;

3. 由于 A* 算法中用的是优先队列,所以每次最先出队列的一定是小的,设置一个标记数组,当 T 第 k 次出队列即是到 T 的第 k 最短路;

#include <iostream>
#include <algorithm>
#include <queue>
#include <vector>
using namespace std;

const int MAXN = 1010;
const int INFS = 0x3fffffff;

struct ST {
    int v, cost, f;
    ST(int _v, int _cost, int _f) : v(_v), cost(_cost), f(_f) {}
    bool operator < (const ST& other) const { return f > other.f; }
};

struct edge {
    int v, cost;
    edge(int _v, int _cost) : v(_v), cost(_cost) {}
};

int N, M, S, T, K, dis[MAXN];

vector<edge> G[MAXN];
vector<edge> P[MAXN];

void spfa(int s) {
    bool vis[MAXN];
    for (int i = 1; i <= N; i++)
        dis[i] = INFS, vis[i] = false;

    dis[s] = 0;
    vis[s] = true;
    queue<int> Q;
    Q.push(s);

    while (!Q.empty()) {
        int u = Q.front();
        Q.pop();

        for (int i = 0; i < P[u].size(); i++) {
            int v = P[u][i].v;
            int cost = P[u][i].cost;
            if (dis[v] > dis[u] + cost) {
                dis[v] = dis[u] + cost;
                if (!vis[v]) {
                    vis[v] = true; Q.push(v);
                }
            }
        }
        vis[u] = false;
    }
}

int bfs(int s, int t, int k) {
    if (dis[s] == INFS)
        return -1;

    if (s == t) 
        k += 1;

    priority_queue<ST> Q;
    Q.push(ST(s, 0, dis[s]));

    int count[MAXN];
    memset(count, 0, sizeof(count));

    while (!Q.empty()) {
        ST u = Q.top();
        Q.pop();

        count[u.v] += 1;
        if (count[t] == k)
            return u.cost;
        if (count[u.v] > k)
            continue;

        for (int i = 0; i < G[u.v].size(); i++)
            Q.push(ST(G[u.v][i].v, u.cost + G[u.v][i].cost, u.cost + G[u.v][i].cost + dis[G[u.v][i].v]));
    }
    return -1;
}

int main() {
    scanf("%d%d", &N, &M);
    for (int i = 0; i < M; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        G[u].push_back(edge(v, w));
        P[v].push_back(edge(u, w));
    }
    scanf("%d%d%d", &S, &T, &K);
    spfa(T);
    printf("%d\n", bfs(S, T, K));
    return 0;
}
原文地址:https://www.cnblogs.com/kedebug/p/2982283.html