pytorch one hot

print(torch.nn.functional.one_hot(t, num_classes=7))

有个坑,使用的时候必须转换为 torch.int64 类型,不然会报错

t = t.to(torch.int64)
原文地址:https://www.cnblogs.com/consolexinhun/p/14303734.html