1、squeeze()函数和unsqueeze()函数
首先要知道tensor的维度,比如tensor([[0, 1, 2],[ 3, 4, 5]]),维度是(2, 3),当其维度变为(2, 1, 3)时,代表2个1行3列的矩阵,为tensor([[[0, 1, 2],[[3, 4, 5]]])。
squeeze()函数就是减少一个维度,unsqueeze()函数就是增加一个维度。比如上述将(2, 3)变为(2, 1, 3)就是unsqueeze()操作。而squeeze()只有维数为1时才能去掉,去掉只需在括号中写入要去掉的维度数即可,如(2, 1, 3)去掉第1维变成(2, 3)。
2、transpose()函数
作用是交换维度,比如:
x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893, 0.5809],
[-0.1669, 0.7299, 0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],
[-0.9893, 0.7299],
[ 0.5809, 0.4942]])
3、expand()和expand_as()函数
expand()函数
>>> x = torch.Tensor([[1], [2], [3]])
>>> y = x.expand(3, 3)
>>> print(x)
tensor([[1.],
[2.],
[3.]])
>>> print(y)
tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.]])
>>> print(x.shape)
torch.Size([3, 1])
>>> print(y.shape)
torch.Size([3, 3])
expand_as()函数:expand_as(tensor)将张量扩展为参数tensor的大小。
>>> x = torch.randn(1, 3, 1, 1)
>>> y = torch.randn(1, 3, 3, 3)
>>> z = x.expand_as(y)
>>> print(x)
tensor([[[[ 0.4383]],
[[-1.5909]],
[[ 0.0814]]]])
>>> print(z)
tensor([[[[ 0.4383, 0.4383, 0.4383],
[ 0.4383, 0.4383, 0.4383],
[ 0.4383, 0.4383, 0.4383]],
[[-1.5909, -1.5909, -1.5909],
[-1.5909, -1.5909, -1.5909],
[-1.5909, -1.5909, -1.5909]],
[[ 0.0814, 0.0814, 0.0814],
[ 0.0814, 0.0814, 0.0814],
[ 0.0814, 0.0814, 0.0814]]]])
4、permute()函数
将tensor的维度换位
比如图片img的size是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。
比如Tensor([[[1,2,3],[4,5,6]]]),使用permute(0,2,1)可以将其转换成tensor([[[1., 4.], [2., 5.], [3., 6.]]])。
5、argmax()函数和argmin()函数
获取张量在某个维度的最大值和最小值的位置。
argmax函数:torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号。dim代表该维度会消失,例如
import torch
t = torch.tensor([[1,2],[3,4],[2,8]])
print(torch.argmax(t,0))
g = torch.tensor([[[1,2,3],[2,3,4],[5,6,7]], [[3,4,5],[7,6,5],[5,4,3]], [[8,9,0],
[2,8,4],[7,5,3]]])
print(g)
print(torch.argmax(g,0))
对于二维张量t来说,大小为(3, 2),使dim为0,意思是求第0维的最大值的序号,则固定行,直接看列,比较结果为tensor([1, 2])。
对于三维张量g来说,大小为(3, 3, 3),使dim为0,则固定第一个维度,其余维度对应位置进行比较,得到结果为tensor([[2, 2, 1], [1, 2, 1], [2, 0, 0]])。
6、numel()函数
返回数组中元素的个数。例如:
params = sum(p.numel() for p in list(net.parameters())) / 1e6 # numel()
print('#Params: %.1fM' % (params))
net.parameters():是Pytorch用法,用来返回net网络中的参数,而params则用来返回net网络中的参数的总数目。
7、fit, transform, fit_transform