torch.unsqueeze(tensor, dim)函数

  • 功能: 返回一个新的tensor,新tensor增加了一个纬度,新的纬度的大小是1。
  • 参数
    • tensor: 在哪个tensor上增加纬度
    • dim:在哪个纬度之前增加一个纬度
  • 示例:
    • target1: 在第一个纬度之前增加一个纬度(原来是[4],增加以后变为[1, 4])
    • target2: 在第二个纬度之前(也就是第一个纬度之后)增加一个纬度(原来是[4],增加以后变为[4, 1])
target = torch.arange(1,5)
print(target.shape)
target1 = target.unsqueeze(0)
print(target1.shape)
print(target1)

target2 = target.unsqueeze(1)
print(target2.shape)
print(target2)

输出结果如下:


torch.Size([4])
torch.Size([1, 4])
tensor([[1, 2, 3, 4]])
torch.Size([4, 1])
tensor([[1],
        [2],
        [3],
        [4]])

原文地址:https://www.cnblogs.com/jizhao/p/15501475.html