beam_search 和 viterbi算法的区别

1 beam search

beam search 在每次预测的时候是选择概率最高的top_k个路径。

要点:

  • 是基于贪心算法的思想,当k = 1时就是贪心算法
  • 常用于搜索空间非常大的情况,如语言生成任务,每一步选择一个词,而词表非常大,beam search可以大大减少计算量
  • beam search 将概率较低的分支删除,大大减少了搜索空间,其得到的解是一个近似解,不是全局最优解。
  • 时间复杂度为O(TKN)

python实现一个简单的beam search

序列长度为10,词典大小为5的单词

# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
from numpy import *
def beam_search(data,k):
    """
    k:beam_size的大小
    """
    sequences = [[list(),1.0]]
    for row in data:   # 这个row_data可以看作时间t
        print("
")
        print("row的值为:",row)
        print(sequences,len(sequences))
        all_candidates = list()
        for i in range(len(sequences)):
            seq,score = sequences[i]
            print("seq的值为{},score的值为{}".format(seq,score))
            for j in range(len(row)):
                candidates = [seq +[j],score*-log(row[j])]
                all_candidates.append(candidates)
            print("all_candidates的值为:",all_candidates)
        ordered = sorted(all_candidates,key = lambda tup:tup[1]) # 之前取了副对数,所以这里为升序排列的
        print("排序之后的顺序为:",ordered)
        
        sequences = ordered[:k]
    return sequences
beam_search(array(data),3)
row的值为: [0.1 0.2 0.3 0.4 0.5]
[[[], 1.0]] 1
seq的值为[],score的值为1.0
all_candidates的值为: [[[0], 2.3025850929940455], [[1], 1.6094379124341003], [[2], 1.2039728043259361], [[3], 0.916290731874155], [[4], 0.6931471805599453]]
排序之后的顺序为: [[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361], [[1], 1.6094379124341003], [[0], 2.3025850929940455]]


row的值为: [0.5 0.4 0.3 0.2 0.1]
[[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]] 3
seq的值为[4],score的值为0.6931471805599453
all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182]]
seq的值为[3],score的值为0.916290731874155
all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182], [[3, 0], 0.6351243373717793], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[3, 3], 1.474713042690254], [[3, 4], 2.109837380062033]]
seq的值为[2],score的值为1.2039728043259361
all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182], [[3, 0], 0.6351243373717793], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[3, 3], 1.474713042690254], [[3, 4], 2.109837380062033], [[2, 0], 0.8345303547893733], [[2, 1], 1.1031891220323908], [[2, 2], 1.4495505135564588], [[2, 3], 1.937719476821764], [[2, 4], 2.7722498316111372]]
排序之后的顺序为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[2, 0], 0.8345303547893733], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[2, 1], 1.1031891220323908], [[4, 3], 1.1155773512899807], [[2, 2], 1.4495505135564588], [[3, 3], 1.474713042690254], [[4, 4], 1.596030365208182], [[2, 3], 1.937719476821764], [[3, 4], 2.109837380062033], [[2, 4], 2.7722498316111372]]


row的值为: [0.1 0.2 0.3 0.4 0.5]
[[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793]] 3
seq的值为[4, 0],score的值为0.4804530139182014
all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944]]
seq的值为[4, 1],score的值为0.6351243373717793
all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944], [[4, 1, 0], 1.46242783142998], [[4, 1, 1], 1.0221931876757278], [[4, 1, 2], 0.7646724295611531], [[4, 1, 3], 0.5819585439214754], [[4, 1, 4], 0.4402346437542523]]
seq的值为[3, 0],score的值为0.6351243373717793
all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944], [[4, 1, 0], 1.46242783142998], [[4, 1, 1], 1.0221931876757278], [[4, 1, 2], 0.7646724295611531], [[4, 1, 3], 0.5819585439214754], [[4, 1, 4], 0.4402346437542523], [[3, 0, 0], 1.46242783142998], [[3, 0, 1], 1.0221931876757278], [[3, 0, 2], 0.7646724295611531], [[3, 0, 3], 0.5819585439214754], [[3, 0, 4], 0.4402346437542523]]
排序之后的顺序为: [[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523], [[3, 0, 4], 0.4402346437542523], [[4, 0, 2], 0.5784523625139449], [[4, 1, 3], 0.5819585439214754], [[3, 0, 3], 0.5819585439214754], [[4, 1, 2], 0.7646724295611531], [[3, 0, 2], 0.7646724295611531], [[4, 0, 1], 0.7732592957431818], [[4, 1, 1], 1.0221931876757278], [[3, 0, 1], 1.0221931876757278], [[4, 0, 0], 1.1062839477321111], [[4, 1, 0], 1.46242783142998], [[3, 0, 0], 1.46242783142998]]


row的值为: [0.5 0.4 0.3 0.2 0.1]
[[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523]] 3
seq的值为[4, 0, 4],score的值为0.33302465198892944
all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388]]
seq的值为[4, 0, 3],score的值为0.4402346437542523
all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 0], 0.3051474021030719], [[4, 0, 3, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 0, 3, 3], 0.7085303260250136], [[4, 0, 3, 4], 1.0136777281280855]]
seq的值为[4, 1, 4],score的值为0.4402346437542523
all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 0], 0.3051474021030719], [[4, 0, 3, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 0, 3, 3], 0.7085303260250136], [[4, 0, 3, 4], 1.0136777281280855], [[4, 1, 4, 0], 0.3051474021030719], [[4, 1, 4, 1], 0.40338292392194175], [[4, 1, 4, 2], 0.5300305386022366], [[4, 1, 4, 3], 0.7085303260250136], [[4, 1, 4, 4], 1.0136777281280855]]
排序之后的顺序为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 3, 0], 0.3051474021030719], [[4, 1, 4, 0], 0.3051474021030719], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 3, 1], 0.40338292392194175], [[4, 1, 4, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 1, 4, 2], 0.5300305386022366], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 3, 3], 0.7085303260250136], [[4, 1, 4, 3], 0.7085303260250136], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 4], 1.0136777281280855], [[4, 1, 4, 4], 1.0136777281280855]]


row的值为: [0.1 0.2 0.3 0.4 0.5]
[[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 3, 0], 0.3051474021030719], [[4, 1, 4, 0], 0.3051474021030719]] 3
seq的值为[4, 0, 4, 0],score的值为0.23083509858308343
all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413]]
seq的值为[4, 0, 3, 0],score的值为0.3051474021030719
all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 4], 0.21151206142293622]]
seq的值为[4, 1, 4, 0],score的值为0.3051474021030719
all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 0], 0.7026278592483932], [[4, 1, 4, 0, 1], 0.491115797825457], [[4, 1, 4, 0, 2], 0.3673891734428095], [[4, 1, 4, 0, 3], 0.2796037364025208], [[4, 1, 4, 0, 4], 0.21151206142293622]]
排序之后的顺序为: [[[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 4], 0.21151206142293622], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 1, 4, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 1, 4, 0, 2], 0.3673891734428095], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 1, 4, 0, 1], 0.491115797825457], [[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 1, 4, 0, 0], 0.7026278592483932]]


row的值为: [0.5 0.4 0.3 0.2 0.1]
[[[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 4], 0.21151206142293622]] 3
seq的值为[4, 0, 4, 0, 4],score的值为0.1600026977571413
all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536]]
seq的值为[4, 0, 3, 0, 4],score的值为0.21151206142293622
all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 0, 3, 0, 4, 4], 0.4870245196208938]]
seq的值为[4, 1, 4, 0, 4],score的值为0.21151206142293622
all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 0, 3, 0, 4, 4], 0.4870245196208938], [[4, 1, 4, 0, 4, 0], 0.1466089890297302], [[4, 1, 4, 0, 4, 1], 0.19380654156143345], [[4, 1, 4, 0, 4, 2], 0.2546547697401322], [[4, 1, 4, 0, 4, 3], 0.34041553059116364], [[4, 1, 4, 0, 4, 4], 0.4870245196208938]]
排序之后的顺序为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 1, 4, 0, 4, 0], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 1, 4, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 1, 4, 0, 4, 2], 0.2546547697401322], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 1, 4, 0, 4, 3], 0.34041553059116364], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 4], 0.4870245196208938], [[4, 1, 4, 0, 4, 4], 0.4870245196208938]]


row的值为: [0.1 0.2 0.3 0.4 0.5]
[[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 3, 0, 4, 0], 0.1466089890297302]] 3
seq的值为[4, 0, 4, 0, 4, 0],score的值为0.11090541883234757
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158]]
seq的值为[4, 0, 4, 0, 4, 1],score的值为0.1466089890297302
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]]
seq的值为[4, 0, 3, 0, 4, 0],score的值为0.1466089890297302
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145], [[4, 0, 3, 0, 4, 0, 0], 0.33757967263878436], [[4, 0, 3, 0, 4, 0, 1], 0.2359580652480829], [[4, 0, 3, 0, 4, 0, 2], 0.1765132356615147], [[4, 0, 3, 0, 4, 0, 3], 0.13433645785738146], [[4, 0, 3, 0, 4, 0, 4], 0.10162160739070145]]
排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145], [[4, 0, 3, 0, 4, 0, 4], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 3, 0, 4, 0, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 3, 0, 4, 0, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 3, 0, 4, 0, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 3, 0, 4, 0, 0], 0.33757967263878436]]


row的值为: [0.5 0.4 0.3 0.2 0.1]
[[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]] 3
seq的值为[4, 0, 4, 0, 4, 0, 4],score的值为0.07687377837246158
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808]]
seq的值为[4, 0, 4, 0, 4, 0, 3],score的值为0.10162160739070145
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266]]
seq的值为[4, 0, 4, 0, 4, 1, 4],score的值为0.10162160739070145
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266], [[4, 0, 4, 0, 4, 1, 4, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 1, 4, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 1, 4, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 1, 4, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 1, 4, 4], 0.23399239830392266]]
排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 1, 4, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 1, 4, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 1, 4, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 1, 4, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266], [[4, 0, 4, 0, 4, 1, 4, 4], 0.23399239830392266]]


row的值为: [0.1 0.2 0.3 0.4 0.5]
[[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441]] 3
seq的值为[4, 0, 4, 0, 4, 0, 4, 0],score的值为0.05328484273786184
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901]]
seq的值为[4, 0, 4, 0, 4, 0, 4, 1],score的值为0.07043873064683441
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]]
seq的值为[4, 0, 4, 0, 4, 0, 3, 0],score的值为0.07043873064683441
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 3, 0, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 3, 0, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 3, 0, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 3, 0, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 3, 0, 4], 0.04882440755007468]]
排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 3, 0, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 3, 0, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 3, 0, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 3, 0, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 3, 0, 0], 0.16219117115682374]]


row的值为: [0.5 0.4 0.3 0.2 0.1]
[[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]] 3
seq的值为[4, 0, 4, 0, 4, 0, 4, 0, 4],score的值为0.03693423851032901
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018]]
seq的值为[4, 0, 4, 0, 4, 0, 4, 0, 3],score的值为0.04882440755007468
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789]]
seq的值为[4, 0, 4, 0, 4, 0, 4, 1, 4],score的值为0.04882440755007468
all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 4], 0.11242235299906789]]
排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 4], 0.11242235299906789]]

[[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108],
 [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397],
 [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]]

1.1 贪心算法(k=1)

def greedy_search(data):
    return [argmax(row) for row in data]
greedy_search(data)
[4, 0, 4, 0, 4, 0, 4, 0, 4, 0]

2 viterbi算法

要点:

  • 基于动态规划的思想
  • 每一步是根据上一步全部可能选择的最高概率推测当前所有选择的最高概率,保证有全局最优解
  • 适合搜索宽度较小的图,即每一步选择较少的时候
  • 时间复杂度为O(TNN)
import numpy as np
def viterbi(data):
    row,col = data.shape
    sigma = np.zeros((row,col))  # 存储最大的概率
    phi = np.zeros((row,col))    # 存储概率最大的索引
    sigma[0] = data[0]
    
    for i in range(1,row):
        for k in range(len(sigma[i-1])):
            tmp = float("-inf")
            for j in range(col):
                if sigma[i-1][j]*data[i][k]>tmp:
                    tmp = sigma[i-1][j]*data[i][k]
                    index = j
#                 tmp = max(tmp,sigma[i-1][k]*data[i][j])
            sigma[i][k] = tmp
            phi[i][k] = index
    # 回溯
    print(sigma)
    ans = [0]*row
    i_T = argmax(sigma[-1])
    i_t = i_T
    for t in range(row-2,-1,-1):
        i_t = int(phi[t+1][i_t])
        ans[t] = i_t
    ans[-1] = i_T
    return ans
        
viterbi(data)
[[1.000000e-01 2.000000e-01 3.000000e-01 4.000000e-01 5.000000e-01]
 [2.500000e-01 2.000000e-01 1.500000e-01 1.000000e-01 5.000000e-02]
 [2.500000e-02 5.000000e-02 7.500000e-02 1.000000e-01 1.250000e-01]
 [6.250000e-02 5.000000e-02 3.750000e-02 2.500000e-02 1.250000e-02]
 [6.250000e-03 1.250000e-02 1.875000e-02 2.500000e-02 3.125000e-02]
 [1.562500e-02 1.250000e-02 9.375000e-03 6.250000e-03 3.125000e-03]
 [1.562500e-03 3.125000e-03 4.687500e-03 6.250000e-03 7.812500e-03]
 [3.906250e-03 3.125000e-03 2.343750e-03 1.562500e-03 7.812500e-04]
 [3.906250e-04 7.812500e-04 1.171875e-03 1.562500e-03 1.953125e-03]
 [9.765625e-04 7.812500e-04 5.859375e-04 3.906250e-04 1.953125e-04]]

[4, 0, 4, 0, 4, 0, 4, 0, 4, 0]
原文地址:https://www.cnblogs.com/zhou-lin/p/15016429.html