Pytorch的gather用法理解

先放一张表,可以看成是二维数组

行(列)索引 索引0 索引1 索引2 索引3
索引0 0 1 2 3
索引1 4 5 6 7
索引2 8 9 10 11
索引3 12 13 14 15

看一下下面例子代码:

针对0维(输出为行形式)

>>> import torch as t
>>> a = t.arange(0,16).view(4,4)
>>> a
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
 
#选取对角线的元素
>>> index = t.LongTensor([[0,1,2,3]])
>>> a.gather(0,index)
tensor([[ 0,  5, 10, 15]])

如何理解结果呢?其实很简单,就是a.gather(0,index)中第一个0已经表明输出结果是行形式(0维),如果第一个是1说明输出结果是列形式(1维),然后按照index = tensor([[0, 1, 2, 3]])顺序作用在行上索引依次为0,1,2,3

  • a[0][0] = 0
  • a[1][1] = 5
  • a[2][2] = 10
  • a[3][3] = 15

针对0维

# 选取反对角线上的元素,注意与上面的不同
>>> index2 = t.LongTensor([[3,2,1,0]])
>>> a.gather(0,index2)
tensor([[12,  9,  6,  3]])

如何理解结果呢?同理,按照index = tensor([[3, 2, 1, 0]])顺序作用在行上索引依次为3,2,1,0:

  • a[3][0] = 12
  • a[2][1] = 9
  • a[1][2] = 6
  • a[0][3] = 3

针对1维(输出为列形式)

选取对角线的元素

>>> index3 = t.LongTensor([[0,1,2,3]]).t()
>>> a.gather(1,index3)
tensor([[ 0],
        [ 5],
        [10],
        [15]])

如何理解结果呢?同理,按照index = tensor([[0, 1, 2, 3]])顺序作用在列上索引依次为0,1,2,3:

  • a[0][0] = 0
  • a[1][1] = 5
  • a[2][2] = 10
  • a[3][3] = 15

针对1维

选取反对角线上的元素

>>> index4 = t.LongTensor([[3,2,1,0]]).t()
>>> a.gather(1,index4)
tensor([[ 3],
        [ 6],
        [ 9],
        [12]])

如何理解结果呢?同理,按照index = tensor([[3, 2, 1, 0]])顺序作用在列上索引依次为3,2,1,0:

  • a[0][3] = 3
  • a[1][2] = 6
  • a[2][1] = 9
  • a[3][0] = 12
原文地址:https://www.cnblogs.com/HongjianChen/p/9451526.html