pytorch函数

1、squeeze()函数和unsqueeze()函数
首先要知道tensor的维度,比如tensor([[0, 1, 2],[ 3, 4, 5]]),维度是(2, 3),当其维度变为(2, 1, 3)时,代表2个1行3列的矩阵,为tensor([[[0, 1, 2],[[3, 4, 5]]])。
squeeze()函数就是减少一个维度,unsqueeze()函数就是增加一个维度。比如上述将(2, 3)变为(2, 1, 3)就是unsqueeze()操作。而squeeze()只有维数为1时才能去掉,去掉只需在括号中写入要去掉的维度数即可,如(2, 1, 3)去掉第1维变成(2, 3)。
2、transpose()函数
作用是交换维度,比如:

x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893,  0.5809],
        [-0.1669,  0.7299,  0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],
        [-0.9893,  0.7299],
        [ 0.5809,  0.4942]])

3、expand()和expand_as()函数
expand()函数

>>> x = torch.Tensor([[1], [2], [3]])
>>> y = x.expand(3, 3)
>>> print(x)
tensor([[1.],
        [2.],
        [3.]])
>>> print(y)
tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]])
>>> print(x.shape)
torch.Size([3, 1])
>>> print(y.shape)
torch.Size([3, 3])

expand_as()函数:expand_as(tensor)将张量扩展为参数tensor的大小。

>>> x = torch.randn(1, 3, 1, 1)
>>> y = torch.randn(1, 3, 3, 3)
>>> z = x.expand_as(y)
>>> print(x)
tensor([[[[ 0.4383]],

         [[-1.5909]],

         [[ 0.0814]]]])
>>> print(z)
tensor([[[[ 0.4383,  0.4383,  0.4383],
          [ 0.4383,  0.4383,  0.4383],
          [ 0.4383,  0.4383,  0.4383]],

         [[-1.5909, -1.5909, -1.5909],
          [-1.5909, -1.5909, -1.5909],
          [-1.5909, -1.5909, -1.5909]],

         [[ 0.0814,  0.0814,  0.0814],
          [ 0.0814,  0.0814,  0.0814],
          [ 0.0814,  0.0814,  0.0814]]]])

4、permute()函数
将tensor的维度换位
比如图片img的size是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。

比如Tensor([[[1,2,3],[4,5,6]]]),使用permute(0,2,1)可以将其转换成tensor([[[1., 4.], [2., 5.], [3., 6.]]])。

5、argmax()函数和argmin()函数
获取张量在某个维度的最大值和最小值的位置。
argmax函数:torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号。dim代表该维度会消失,例如

import torch
t = torch.tensor([[1,2],[3,4],[2,8]])
print(torch.argmax(t,0))

g = torch.tensor([[[1,2,3],[2,3,4],[5,6,7]], [[3,4,5],[7,6,5],[5,4,3]], [[8,9,0],
                            [2,8,4],[7,5,3]]])
print(g)
print(torch.argmax(g,0))

对于二维张量t来说,大小为(3, 2),使dim为0,意思是求第0维的最大值的序号,则固定行,直接看列,比较结果为tensor([1, 2])。
对于三维张量g来说,大小为(3, 3, 3),使dim为0,则固定第一个维度,其余维度对应位置进行比较,得到结果为tensor([[2, 2, 1], [1, 2, 1], [2, 0, 0]])。
6、numel()函数
返回数组中元素的个数。例如:

params = sum(p.numel() for p in list(net.parameters())) / 1e6 # numel()
print('#Params: %.1fM' % (params))

net.parameters():是Pytorch用法,用来返回net网络中的参数,而params则用来返回net网络中的参数的总数目。
7、fit, transform, fit_transform

原文地址:https://www.cnblogs.com/zyr001/p/14539689.html