The torch.narrow() function in PyTorch

torch.narrow(input, dim, start, length) → Tensor

Returns a new tensor that is a narrowed version of the input tensor. Along dimension dim, the result keeps the elements of input from index start up to, but not including, start + length. The returned tensor and the input tensor share the same underlying storage.

Parameters
  • input (Tensor) – the tensor to narrow

  • dim (int) – the dimension along which to narrow

  • start (int) – the index along dim at which the narrowed tensor starts

  • length (int) – the number of elements to keep along dim, counted from start

Example:

>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> torch.narrow(x, 0, 0, 2)
tensor([[ 1,  2,  3],
        [ 4,  5,  6]])
>>> torch.narrow(x, 1, 1, 2)
tensor([[ 2,  3],
        [ 5,  6],
        [ 8,  9]])
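
The shared-storage statement in the description can be checked directly. The following is a minimal sketch, assuming the same 3x3 tensor as in the example; the variable names (view, copy, shares_memory) are chosen here for illustration only.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# narrow returns a view onto x's storage; no data is copied
view = torch.narrow(x, 0, 0, 2)

# the view starts at row 0, so it points at the same memory address as x
shares_memory = view.data_ptr() == x.data_ptr()

# writing through the view also changes x
view[0, 0] = 100

# clone() gives an independent copy; writing to it leaves x untouched
copy = torch.narrow(x, 0, 0, 2).clone()
copy[0, 1] = -1
```

After running this, x[0, 0] is 100 because the write went through the view, while x[0, 1] is still 2 because the clone has its own storage.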

By definition, this function returns the slice start : start+length of the tensor along dimension dim. For the example above:

x.size() = (3, 3)

torch.narrow(x, 0, 0, 2) == x[0:0+2, :]

torch.narrow(x, 1, 1, 2) == x[:, 1:1+2]
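
The equivalence with slicing can be verified in code. This is a small sketch, assuming the same tensor x as in the example. The helper narrow_by_slicing is a hypothetical name introduced here, not part of PyTorch; it only shows how the same result could be built with ordinary indexing for an arbitrary dim.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# the two calls from the example, compared against plain slicing
same_rows = torch.equal(torch.narrow(x, 0, 0, 2), x[0:0 + 2, :])
same_cols = torch.equal(torch.narrow(x, 1, 1, 2), x[:, 1:1 + 2])

# the method form on the tensor gives the same result
same_method = torch.equal(x.narrow(1, 1, 2), torch.narrow(x, 1, 1, 2))


def narrow_by_slicing(t, dim, start, length):
    """Hypothetical helper: reproduce narrow with an index tuple."""
    index = [slice(None)] * t.dim()            # ":" for every dimension
    index[dim] = slice(start, start + length)  # start : start+length on dim
    return t[tuple(index)]


same_generic = torch.equal(narrow_by_slicing(x, 1, 1, 2), torch.narrow(x, 1, 1, 2))
```

One practical reason to prefer torch.narrow over slicing is that dim is an ordinary argument, so the call stays the same when the dimension is only known at runtime.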

Original article: https://www.cnblogs.com/qinduanyinghua/p/11862641.html