生成one-hot的方法

================================①=========================================

def encode_onehot(labels):
    classes = set(labels)
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32)
    return labels_onehot

np.identity() 生成单位矩阵,所以每一行即代表了一个label的one-hot表示,通过字典的形式保存下来

map(classes_dict.get, labels) 对于lables中的每一个lable,都带入到字典中,即得到其对应的one-hot编码。

================================②=========================================

另外可以直接调用pytorch中的one-hot方法:

labels = torch.tensor([1,2,3,2,1])
[nn.functional.one_hot(labels[i], 4) for i in range(5)]

 

原文地址:https://www.cnblogs.com/zyb993963526/p/14501448.html