匈牙利匹配和最大权值匹配算法

在使用多目标跟踪算法时,接触到了匈牙利匹配算法,一直没时间好好总结下,现在来填坑。。

1. 基础概念

1.1 二分图

我们之前了解过图(Graph)的概念,图一般可以用G(V, E)来表示,V表示图中的顶点,E表示图中的边。如下面,这个图中有四个顶点,五条边。

二分图(Bipartite graph)是一类特殊的图,它可以被划分为两个部分,每个部分内的点互不相连,如下面是一个典型的二分图,图中的点可分为X,Y两部分,X内部的点互补相连,Y内部的点也互不相连。我们也可以发现二分图中一定不存在环。(二分图又称为二部图,偶图)

1.2 二分图匹配

二分图的匹配可以看成是二分图的一个子图,该子图满足以下条件:子图中不存在有任意的两条边依附于同一个顶点

如下面左图是一个二分图,右图就是它的一个匹配,右图中每条边都没有公共端点,可以看出其是二分图的一个子集。概念上有点绕,我们通俗点理解:有一个班级的学生要结成男女两两一组,但每个学生只想自己喜欢的异性结成一组,于是这就会有冲突,而匹配就是要找出这样的男女组成,保证一个男生只和一个女生组合。

二分图的匹配问题在有限资源分配时经常会用到,主要是为了保证某一个资源分且只分到某一个用户的手中

1.3 二分图最大匹配

二分图最大匹配,就是在二分图的所有匹配中,找出边数最大的匹配。还是以上面的情景来理解:有一个班级的学生要结成男女两两一组,但每个学生只想自己喜欢的异性结成一组,匹配是保证一个男生只和一个女生组合,而最大匹配则是尽量保证没有人落单,即二分图最大匹配就是要给出一个最优方案,使得结成的组数最多

匈牙利算法就是寻找二分图最大匹配方案的经典算法

1.4 二分图最大权完美匹配

首先说二分图完美匹配,如果一个二分图的所有点都是匹配点(匹配边中某一条边的端点),则称这个匹配是完美匹配。回到上面的情景,完美匹配就是可以得到一个方案,使得所有男女同学都可以结成两两一组。

  • 完美匹配要求二分图两部分的点数相等,因为若X中包括4个点,Y中包含5个点,则Y中必然会有一个点不会被匹配
  • 完美匹配一定是最大匹配,最大匹配不一定是完美匹配

二分图最大权完美匹配:假定有一个二分图 G,每条边有一个权值(可为负数),权值和最大的完美匹配是二分图最大权完美匹配。

还有一些概念,二分图最优匹配,二分图最大权值匹配,二分图最小权值匹配(将权值转化为负数,即转为最大权值匹配),都是指二分图最大权完美匹配。

求解二分图最大权完美匹配一般采用KM(Kuhn-Munkres)匹配算法

2. 匈牙利匹配算法

参考:https://zhuanlan.zhihu.com/p/105212518, https://zhuanlan.zhihu.com/p/104901134?utm_source=wechat_session

2.1 匈牙利算法解析

匈牙利算法(Hungary Algorithm)是由Edmonds在1965年提出的,是求解二分图最大匹配的经典算法,算法的核心就是根据一个初始匹配不停的找增广路,直到没有增广路为止。几个概念如下:

  • 交替路:从任意一个未匹配点出发,依次经过未匹配边-匹配边-非匹配边-匹配边-未匹配边……所得到的路径被称为交替路。(即未匹配边和匹配边交替出现)
  • 增广路:如果一条交替路的终点是一个未匹配点,那么这条路径是增广路,由于从未匹配点出发,又在未匹配点结束,未匹配边比匹配边多一条。
  • 增广路定理:如果可以找到一条增广路,那么将匹配边与未匹配边互换,这个匹配就可以多一条边,否则当前匹配就是最大匹配。即任意一个匹配是最大匹配的充分必要条件是不存在增广路。

增广路互换的实质可以这么考虑,如下图:从未匹配点 A 出发,A 想与 B 匹配,于是通过未匹配边找到 B,然而 B 已经是匹配点,于是只能经过匹配边去问 C 能不能与别人匹配,C 经过未匹配边找到 D,由于 D 是未匹配点,所以 C 成功与 D 匹配。CD 之间的边变为匹配边;BC 之间解除关系,变为未匹配边;AB 之间建立关系,变为匹配边。这便是增广路互换的实质。

因此,总结下匈牙利算法的思想:就是不断的寻找增广路,如果找到,就互换匹配边和非匹配边,让匹配边增加一条,如果找不到匹配边了,就表示已经是最大匹配了。

2.2 匈牙利算法代码实现

python实现如下:

import math
import numpy as np

# 匈牙利匹配算法
class HungaryMatch(object):

    def __init__(self, graph):
        assert isinstance(graph, np.ndarray), print("二分图的必须采用numpy array 格式")
        assert graph.ndim == 2, print("二分图的维度必须为2")
        self.garph = graph
        rows, cols = graph.shape
        self.rows = rows
        self.cols = cols

        # self.vx = np.zeros(cols, dtype=np.int32)   # visit flag, 横向结点的访问标志
        # self.vy = np.zeros(rows, dtype=np.int32)  # visit flag, 竖向结点的访问标志

        self.match_index = np.ones(cols, dtype=np.int32) * -1  # 横向结点匹配的竖向结点的index (默认-1,表示未匹配任何竖向结点)
        self.match_count = 0  # 总共有多少条匹配边

    def match(self):
        for y in range(self.rows):  # 从每一竖向结点开始,寻找增广路
            self.vx = np.zeros(self.cols, dtype=np.int32)  # visit flag, 横向结点的访问标志置0
            self.vy = np.zeros(self.rows, dtype=np.int32)  # visit flag, 竖向结点的访问标志置0
            if self.dfs(y):
                self.match_count += 1  # 采用dfs寻找增广路,如果找到,匹配边加1
        return self.match_index, self.match_count

    def dfs(self, y):  # 递归版深度优先搜索
        self.vy[y] = 1
        for x in range(self.cols):
            if self.vx[x] == 0 and self.garph[y][x] == 1:  # 横向结点x没有访问过,而且竖向结点y和横向结点x有边连接
                self.vx[x] = 1
                # 两种情况:一是结点x没有匹配,那么找到一条增广路;二是X结点已经匹配,采用DFS,沿着X继续往下走,最后若以未匹配点结束,则也是一条增广路
                if self.match_index[x] == -1 or self.dfs(self.match_index[x]):
                    self.match_index[x] = y  # 未匹配边变成匹配边
                    print(y, x, self.match_index)
                    return True
        return False
if __name__ == '__main__':
    graph = np.array([[0, 1, 0, 1], [0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 1, 0]])
    hungary = HungaryMatch(graph)
    index, count = hungary.match()
    print(index)  # [-1  1  2  0]:三组匹配边(x, y): (1, 1), (2, 2), (3, 0)
    print(count)  # 3:共有三条匹配边        

cpp实现如下:

参考:https://zhuanlan.zhihu.com/p/104901134?utm_source=wechat_session

bool dfs(int x){
   for(int i=0; i<m; i++){
      if (edge[x][i]==0 || vis[i]) continue;
      vis[i] = true;
      if (y_match[i]==-1 || dfs(y_match[i]))
           return true;
   }
   return false;
}

int cnt = 0;
for (int i=0; i<n; i++){
    memset(vis, false, sizeof(vis));
    if (dfs(i))
         cnt++;
}

3. KM算法(Kuhn-Munkres Algorithm)

参考:https://blog.sengxian.com/algorithms/km,https://piggerzzm.github.io/2020/03/28/Kuhn-Munkres/

3.1 可行顶标和相等子图

二分图最优匹配(最大权值匹配)的经典算法是由Kuhn和Munkres独立提出的KM算法,值得一提的是最初的KM算法是在1955年和1957年提出的,因此当时的KM算法是以矩阵为基础的,随着匈牙利算法被Edmonds提出之后,现有的KM算法利用匈牙利树可以得到更漂亮的实现。

KM算法是通过给每个顶点一个标号(叫做顶标,或者节点函数)来把求最大权完美匹配的问题转化为求完美匹配的问题的。可以简单理解为节点函数就是节点的一个值。几个概念如下:

  • 顶标(节点函数):指的是图中的每个顶点,给它赋予一个值(就像边的权重值),这个值也称为节点函数值。
  • 可行顶标:对于所有顶点的函数值(l),使得对于任意边 (e(x ightarrow y)),都满足 (l_{x} + l_{y} ge W_{e}),(其中,(l_x)为顶点x的顶标,(l_y)为顶点y的顶标,(w_e)为边(e(x ightarrow y))的权值)
  • 相等子图:相等子图包含原图中所有的点,但只包含满足 (l_{x} + l_{y} = W_{e})的所有边 (e(x ightarrow y))。根据定义,这些边一定是当前权值最大的边(不等式已经取到等号),那么如果相等子图有完美匹配,那这个完美匹配一定是最大权值完美匹配。因为相等子图的权值和为所有点的顶标之和,而随便一个匹配中的边因为受到 (W_{e} le l_{x} + l_{y})的限制,不可能比所有点的顶标之和大。

3.2 KM算法步骤解析

KM算法的主要目标就在于寻找可行顶标,使得相等子图有完美匹配。可行顶标的修改过程中,每一步都运用了贪心的思想,这样我们的最终结果一定是最优的。下面是算法的叙述:

步骤一:顶标初始化

因为有 (l_{x} + l_{y} = W_{e})恒成立,我们设左侧(Y集)的所有节点顶标为 0,那么所有 X集的点的顶标就必须为从它出发所有的边的权值最大值。

步骤二:寻找完美匹配

寻找当前顶标条件下, 采用增广路定理对每个点进行匹配(匈牙利算法),若最大匹配就是完美匹配,结束算法,否则必须修改顶标,使得有更多的边能够参与进来。

步骤三:修改顶标,加入更多可行顶标及对应边

我们求当前相等子图的完美匹配失败,是因为对于某个未匹配顶点 u,我们找不到一条从它出发的增广路,这时我们只能获得一条交替路。我们把 X集中在交替路的点集叫做 S, X集中不在交替路的点集叫做 S',同理 Y集中在交替路的点集叫做 T, Y集中不在交替路的点集叫做 T'。如果我们把交替路中 X 集顶点的顶标(点集S中的点)全都减小某个值 d,Y集的顶标(点集T中的点)全都增加同一个值 d,那么我们会发现:

  • 两端都在交替路中的边 (e(i ightarrow j))(l_{i} + l_{j}) 的值没有变化。也就是说,它原来属于相等子图,现在仍属于相等子图。
  • 两端都不在交替路中的边 (e(i ightarrow j))(l_{i}, l_{j}) 都没有变化,(l_{i} + l_{j}) 的值没有变化。也就是说,它原来属于(或不属于)相等子图,现在仍属于(或不属于)相等子图。
  • X集一端在 S' 中, Y端在 T中的边 (e(i ightarrow j)),它的 (l_{i})不变, (l_{j})增加了d,(l_{i} + l_{j})的值有所增大。它原来不属于相等子图,现在仍不可能属于相等子图。
  • X集一端在 S中,Y 端在 T'中的边(e(i ightarrow j)),它的 (l_{i})减小了d, (l_{j})不变,(l_{i} + l_{j})的值有所减小。也就说,它原来不属于相等子图,现在可能进入了相等子图,因而使相等子图得到了扩大。

也就是说,只有 X集一端在 S 中,Y端在 T'中的边才有可能被选中。继续贪心,我们只能让满足条件的边权最大的边被选中,即满足(l_{x} + l_{y} = W_{e}),那么这个 d 值,就应该取 (d = min{l_{x} + l_{y} - W_{e(x ightarrow y)} vert x in S, y in T'})

于是有新的边加入相等子图,我们可以愉快的继续对于未匹配顶点 u寻找增广路,这样的修改最多进行n次,而一共有 n个点,所以除去修改顶标的时间,复杂度已经达到(O(n^{2}))。因此算法的复杂度主要取决于修改顶标的时间, 修改顶标主要两个思路:

  • 思路一:枚举所有(n^{2})条边,看是否满足条件,满足条件就更新d值。最直观清晰,然而总的复杂度飙升至(O(n^{4}))
  • 思路二:对于T'​的每个点v,定义松弛变量(slack(v) = min{l_{x}+l_{y} -W_{e(x ightarrow y)} vert xin S}),这个松弛变量在匹配的过程中就可以更新,修改顶标的过程中(d = min{slack(v) vert v in T'})。总复杂度(O(n^{3})),但不是严格的(想一想为什么)?

3.3 KM算法步骤总结

KM算法仅仅只适用于找二分图最佳完美匹配,如果无完美匹配,那么算法很可能陷入死循环(如果不存在的边为 -INF 的话就不会,但正确性就无法保证了),对于这种情况要小心处理。
最后回顾一下总的流程,理一下思路:

  1. 初始化可行顶标。
  2. 用增广路定理寻对每个点找匹配。
  3. 若点未找到匹配则修改可行顶标的值。
  4. 重复2、3步直到所有点均有匹配为止,即找到相等子图的完美匹配为止

3.4 KM代码实现

3.4.1 python实现

(O(n^{4}))版本:

# Kuhn-Munkres匹配算法, O(n^4)时间复杂度
class KMMatchOriginal(object):

    def __init__(self, graph):
        assert isinstance(graph, np.ndarray), print("二分图的必须采用numpy array 格式")
        assert graph.ndim == 2, print("二分图的维度必须为2")
        self.graph = graph

        rows, cols = graph.shape
        self.rows = rows
        self.cols = cols

        self.lx = np.zeros(self.cols, dtype=np.float32)  # 横向结点的顶标
        self.ly = np.zeros(self.rows, dtype=np.float32)  # 竖向结点的顶标

        self.match_index = np.ones(cols, dtype=np.int32) * -1  # 横向结点匹配的竖向结点的index (默认-1,表示未匹配任何竖向结点)
        self.match_weight = 0  # 匹配边的权值之和

    def match(self):
        # 初始化顶标, ly初始化为0,lx初始化为节点对应权值最大边的权值
        for y in range(self.rows):
            self.ly[y] = max(self.graph[y, :])

        for y in range(self.rows):  # 从每一竖向结点开始,寻找增广路
            while True:
                self.vx = np.zeros(self.cols, dtype=np.int32)  # 横向结点的匹配标志
                self.vy = np.zeros(self.rows, dtype=np.int32)  # 竖向结点的匹配标志
                if self.dfs(y):
                    break
                else:
                    self.update()
        return self.match_index

    # 更新顶标
    def update(self):
        d = np.inf
        # 寻找y中已匹配,x中未匹配,对应需要减小的最小权值
        for y in range(self.rows):
            if self.vy[y]:
                for x in range(self.cols):
                    if not self.vx[x]:
                        d = min(d, self.lx[x] + self.ly[y] - self.graph[y][x])

        for x in range(self.cols):  # x顶标初始化值为0,因此所有匹配点顶标+d
            if self.vx[x]:
                self.lx[x] += d
        for y in range(self.rows):  # y顶标初始化值为对应边的最大权值,因此所有匹配点顶标-d
            if self.vy[y]:
                self.ly[y] -= d

    def dfs(self, y):  # 递归版深度优先搜索
        self.vy[y] = 1
        for x in range(self.cols):
            if self.vx[x] == 0 and self.lx[x] + self.ly[y] == self.graph[y][x]:
                self.vx[x] = 1
                # 两种情况:一是结点x没有匹配,那么找到一条增广路;二是X结点已经匹配,采用DFS,沿着X继续往下走,最后若以未匹配点结束,则也是一条增广路
                if self.match_index[x] == -1 or self.dfs(self.match_index[x]):
                    self.match_index[x] = y  # 未匹配边变成匹配边
                    return True
        return False
if __name__ == '__main__':
    graph = np.array([[2,1,1],[3,2,1],[1,1,1]])
    kmo = KMMatchOriginal(graph)
    print(kmo.match())

(O(n^{3}))版本:

# Kuhn-Munkres匹配算法
class KMMatch(object):

    def __init__(self, graph):
        assert isinstance(graph, np.ndarray), print("二分图的必须采用numpy array 格式")
        assert graph.ndim == 2, print("二分图的维度必须为2")
        self.graph = graph

        rows, cols = graph.shape
        self.rows = rows
        self.cols = cols

        self.lx = np.zeros(self.cols, dtype=np.float32)  # 横向结点的顶标
        self.ly = np.zeros(self.rows, dtype=np.float32)  # 竖向结点的顶标

        self.match_index = np.ones(cols, dtype=np.int32) * -1  # 横向结点匹配的竖向结点的index (默认-1,表示未匹配任何竖向结点)
        self.match_weight = 0  # 匹配边的权值之和

        self.inc = math.inf

    def match(self):
        # 初始化顶标, lx初始化为0,ly初始化为节点对应权值最大边的权值
        for y in range(self.rows):
            self.ly[y] = max(self.graph[y, :])

        for y in range(self.rows):  # 从每一竖向结点开始,寻找增广路
            while True:
                self.inc = np.inf
                self.vx = np.zeros(self.cols, dtype=np.int32)  # 横向结点的匹配标志
                self.vy = np.zeros(self.rows, dtype=np.int32)  # 竖向结点的匹配标志
                if self.dfs(y):
                    break
                else:
                    self.update()
                # print(y, self.lx, self.ly, self.vx, self.vy)
        return self.match_index

    # 更新顶标
    def update(self):
        for x in range(self.cols):
            if self.vx[x]:
                self.lx[x] += self.inc
        for y in range(self.rows):
            if self.vy[y]:
                self.ly[y] -= self.inc

    def dfs(self, y):  # 递归版深度优先搜索
        self.vy[y] = 1
        for x in range(self.cols):
            if self.vx[x] == 0:
                t = self.lx[x] + self.ly[y] - self.graph[y][x]
                if t == 0:
                    self.vx[x] = 1
                    # 两种情况:一是结点x没有匹配,那么找到一条增广路;二是X结点已经匹配,采用DFS,沿着X继续往下走,最后若以未匹配点结束,则也是一条增广路
                    if self.match_index[x] == -1 or self.dfs(self.match_index[x]):
                        self.match_index[x] = y  # 未匹配边变成匹配边
                        # print(y, x, self.match_index)
                        return True
                else:
                    if self.inc > t:
                        self.inc = t
        return False
if __name__ == '__main__':
    graph = np.array([[2, 1, 1], [3, 2, 1], [1, 1, 1]])
    # # graph = np.array([[3,4,6,4,9],[6,4,5,3,8],[7,5,3,4,2],[6,3,2,2,5],[8,4,5,4,7]])
    km = KMMatch(graph)
    print(km.match())

在代码撰写过程中,踩了几个坑,也发现了一些问题,总结如下:

  • 在初始化顶标时,若行结点初始化为最大边权值,列结点初始化为0,则必须从行结点出发,遍历寻找满足条件的增广路,否则代码会陷入死循环。(即从初始化为最大边权值的结点开始遍历
  • KM算法要求行结点和列结点个数相同,如果不相同时,保证行结点个数少,列结点个数多,然后通过padding来使行结点和列结点个数相同
  • KM算法求最大权值匹配,若要求最小权值匹配,可以对权值矩阵进行转换,如采用一个很大值(如sys.maxint)减去权值矩阵
3.4.2 cpp代码实现

(O(n^{4}))版本:

int Weight[maxm][maxn];
int Lx[maxm], Ly[maxn]; // 顶标
int match[maxn];    // 记录匹配
bool S[maxm], T[maxn];  // 算法中的两个集合S和T

// 步骤 1: 初始化可行顶标和初始化匹配
void Init()
{
    // 将X集合的顶标设为最大边权,Y集合的顶标设为0
    for (int i = 1; i <= m; i++)
    {
        Lx[i] = 0;
        for (int j = 1; j <= n; j++)
        {
            match[j] = 0;   // match记录的是Y集合里的点与谁匹配
            Ly[j] = 0;
            Lx[i] = max(Lx[i], Weight[i][j]);
        }
    }
}
//步骤2:增广路定理寻找匹配点(匈牙利算法中的DFS)
bool findPath(int i)
{
    S[i] = true;
    for (int j = 1; j <= n; j++)
    {
        if (Lx[i] + Ly[j] == Weight[i][j] && !T[j]) // 找出在相等子图里又还未被标记的边
        {
            T[j] = true;
            if (!match[j] || findPath(match[j])) // 未被匹配,或者已经匹配又找到增广路
            {
                match[j] = i;
                return true;
            }
        }
    }
    return false;
}

//步骤 3: 更新顶标
void update() 
{
    // 计算a
    int a = 1 << 30;
    for (int i = 1; i <= m; i++)
        if (S[i])
            for (int j = 1; j <= n; j++)
                if (!T[j])
                    a = min(a, Lx[i] + Ly[j] - Weight[i][j]);

    // 修改顶标
    for (int i = 1; i <= m; i++)
        if (S[i])
            Lx[i] -= a;
    for (int j = 1; j <= n; j++)
        if (T[j]) 
            Ly[j] += a;
}
// 整体的KM算法
void KM()
{
    Init();

    for (int i = 1; i <= m; i++)
    {
        while (true)
        {
            for (int i = 1; i <= m; i++)
                S[i] = 0;
            for (int j = 1; j <= n; j++)
                T[j] = 0;
            if (!findPath(i))
                update();
            else
                break;
        }
    }
}

(O(n^{3}))版本:

const int maxn = 500 + 3, INF = 0x3f3f3f3f;
int n, W[maxn][maxn];
int mat[maxn];
int Lx[maxn], Ly[maxn], slack[maxn];
bool S[maxn], T[maxn];

inline void tension(int &a, const int b) {
    if(b < a) a = b;
}

inline bool match(int u) {
    S[u] = true;
    for(int v = 0; v < n; ++v) {
        if(T[v]) continue;
        int t = Lx[u] + Ly[v] - W[u][v];
        if(!t) {
            T[v] = true;
            if(mat[v] == -1 || match(mat[v])) {
                mat[v] = u;
                return true;
            }
        }else tension(slack[v], t);
    }
    return false;
}

inline void update() {
    int d = INF;
    for(int i = 0; i < n; ++i)
        if(!T[i]) tension(d, slack[i]);
    for(int i = 0; i < n; ++i) {
        if(S[i]) Lx[i] -= d;
        if(T[i]) Ly[i] += d;
    }
}

inline void KM() {
    for(int i = 0; i < n; ++i) {
        Lx[i] = Ly[i] = 0; mat[i] = -1;
        for(int j = 0; j < n; ++j) Lx[i] = max(Lx[i], W[i][j]);
    }
    for(int i = 0; i < n; ++i) {
        fill(slack, slack + n, INF);
        while(true) {
            for(int j = 0; j < n; ++j) S[j] = T[j] = false;
            if(match(i)) break;
            else update();
        }
    }
}

参考:https://nymrli.top/2019/12/05/KM-Kuhn-Munkres-算法/

https://piggerzzm.github.io/2020/03/28/Kuhn-Munkres/

https://www.cnblogs.com/xingnie/p/10395788.html

4. Kuhn-Munkres算法开源包

在实际项目中涉及到最大权值匹配问题时,可以采用开源包中的Kuhn-Munkres算法,如下面两个:

munkres

python有实现了munkres算法的安装包,可以直接安装:pip install munkres

官方使用文档:https://software.clapper.org/munkres/

scipy

scipy模块中scipy.optimize.linear_sum_assignment实现了KM匹配算法,可以直接调用。

原文地址:https://www.cnblogs.com/silence-cho/p/15112326.html