pytorch中的unsqueeze函数和squeeze函数

在pytorch中,我们经常对张量Tensor的维度进行压缩或者扩充(压缩或者扩充的维度为1)。其中经常使用的是squeeze()函数和unsqueeze函数;
squeeze在英文中的意思就是“挤、压”,所以故名思议,squeeze()函数就是对张量的维度进行减少的操作,话不多说,我们直接看下例子:

import torch
#定义两个整型的张量a,b
a = torch.IntTensor([[1,2,3],[4,5,6]])
b = torch.IntTensor([[[1,2,3],[4,5,6]]])
#看一下a,b的形状
print(a.shape)
print(b.shape)
'''
===output===
torch.Size([2, 3])
torch.Size([1, 2, 3])
'''

#我们看到张量b比较膨胀,有三个维度:1*2*3,所以我们要挤压一下张量b的第0个维度(因为是1才能挤压,否则没有效果)
c = torch.squeeze(b,0)  # 对应的维度为第0维
print(c.shape)
'''
===output===
torch.Size([2, 3])
'''
#那如果想想张量a膨胀一下,怎么办
c = torch.unsqueeze(a,0)
print(c.shape)
'''
===output===
torch.Size([1, 2, 3])
'''
#可以看到张量a在第0维也膨胀了, 如果你看不惯的话,再压缩一下它。

另外,squeeze()函数和unsqueeze()函数还有另一种写法,直接用张量类型的变量来调用这两个函数:

c = a.unsqueeze(0)
print(c.shape)
'''
===output===
torch.Size([1, 2, 3])
'''

你看出差别了么?这里直接用张量变量a来调用了unsqueeze()函数,当然squeeze()也是一样的,不信你可以试试^_^

原文地址:https://www.cnblogs.com/datasnail/p/13086803.html