pytorch/Python的一些函数用法(日常更新)

torch.nn.Embedding(num_embeddings: int, embedding_dim: int)
是用来embed词成为word embedding的
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector

例如:self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)




Python内置函数:getattr(object, name[, default])
  • object -- 对象。
  • name -- 字符串,对象属性。
  • default -- 默认返回值,如果不提供该参数,在没有对应属性时,将触发 AttributeError。

X.expand(size)

用来将X原样复制size遍,成为size的shape,例如

t
Out[22]: tensor([0, 1, 2, 3])
t.expand((2,3,-1))
Out[23]:
tensor([[[0, 1, 2, 3],
         [0, 1, 2, 3],
         [0, 1, 2, 3]],
        [[0, 1, 2, 3],
         [0, 1, 2, 3],
         [0, 1, 2, 3]]])

register_buffer

应该就是在内存中定义一个常量,同时,模型保存和加载的时候可以写入和读出。

例如:self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

tensor.permute()

维度转换

比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。

利用这个函数permute(0,2,1)可以把Tensor([[[1,2,3],[4,5,6]]]) 转换成

  1. tensor([[[1., 4.],
  2. [2., 5.],
  3. [3., 6.]]])

如果使用view,可以得到

  1. tensor([[[1., 2.],
  2. [3., 4.],
  3. [5., 6.]]])

tensor.view(-1)

把tensor中所有数字放置成一个list返回

import torch
a = torch.randn(3,5,2)
print(a)
print(a.view(-1))

运行结果:

tensor([[[-0.6887,  0.2203],
         [-1.6103, -0.7423],
         [ 0.3097, -2.9694],
         [ 1.2073, -0.3370],
         [-0.5506,  0.4753]],

        [[-1.3605,  1.9303],
         [-1.5382, -1.0865],
         [-0.9208, -0.1754],
         [ 0.1476, -0.8866],
         [ 0.4519,  0.2771]],

        [[ 0.6662,  1.1027],
         [-0.0912, -0.6284],
         [-1.0253, -0.3542],
         [ 0.6909, -1.3905],
         [-2.1140,  1.3426]]])
tensor([-0.6887,  0.2203, -1.6103, -0.7423,  0.3097, -2.9694,  1.2073, -0.3370,
        -0.5506,  0.4753, -1.3605,  1.9303, -1.5382, -1.0865, -0.9208, -0.1754,
         0.1476, -0.8866,  0.4519,  0.2771,  0.6662,  1.1027, -0.0912, -0.6284,
        -1.0253, -0.3542,  0.6909, -1.3905, -2.1140,  1.3426])

Optional[X]

等价于Union[X, None]

from typing import Optional

def foo_v2(a: int, b: Optional[int] = None):
    if b:
        print(a + b)
    else:
        print("parameter b is a NoneType!")

#只传入a位置的实参
foo_v2(2)

# 输出
>>> parameter b is a NoneType!

d

原文地址:https://www.cnblogs.com/gagaein/p/14391853.html