关于Pytorch的二维tensor的gather和scatter_操作用法分析

看得不明不白(我在下一篇中写了如何理解gather的用法)

gather是一个比较复杂的操作,对一个2维tensor,输出的每个元素如下:

out[i][j] = input[index[i][j]][j]  # dim=0
out[i][j] = input[i][index[i][j]]  # dim=1

二维tensor的gather操作

针对0轴

注意index此时的值

输入

index = t.LongTensor([[0,1,2,3]])
print("index = 
", index)      #index是2维
print("index的形状: ",index.shape)  #index形状是(1,4)  

输出

index = 
 tensor([[0, 1, 2, 3]])
index的形状:  torch.Size([1, 4])

分割线============

针对1轴

注意index此时的值

输入

index = t.LongTensor([[0,1,2,3]]).t()  #index是2维
print("index = 
", index)    #index形状是(4,1)
print("index的形状: ",index.shape)

输出

index = 
 tensor([[0],
        [1],
        [2],
        [3]])
index的形状:  torch.Size([4, 1])

分割线===========

再来看看几个例子

注意index在以0轴和1轴为标准时的表达式是不一样的。
b.gather()中取0维时,输出的结果是行形式,取1维时,输出的结果是列形式。

  • b是一个 $ 3 imes4 $ 型的
>>> import torch as t
>>> b = t.arange(0,12).view(3,4)
>>> b
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>> index = t.LongTensor([[0,1,2]])

>>> index
tensor([[0, 1, 2]])

>>> b.gather(0,index)     #运行失败了
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 3], src [3 x 4] and index [1 x 3] to have the same size apart from dimension 0 at c:
ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620    

>>> index2 = t.LongTensor([[0,1,2]]).t()

>>> b.gather(1,index2)  #运行成功了
tensor([[ 0],
        [ 5],
        [10]])

>>> index3 = t.LongTensor([[0,1,2,3]]).t()

>>> b.gather(1,index3)  #运行失败了
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [4 x 1], src [3 x 4] and index [4 x 1] to have the same size apart from dimension 1 at c:
ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620
  • b是一个 $ 6 imes6 $ 型的
>>> import torch as t
>>> b = t.arange(0,36).view(6,6)
>>> b
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]])

>>> index = t.LongTensor([[0,1,2,3,4,5,6]])
>>> b.gather(0,index)     #运行失败了
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 7], src [6 x 6] and index [1 x 7] to have the same size apart from dimension 0 at c:
ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620  

>>> index = t.LongTensor([[0,1,2,3,4,5]])
>>> b.gather(0,index)    #运行成功了
tensor([[ 0,  7, 14, 21, 28, 35]])
>>> b.gather(1,index)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 6], src [6 x 6] and index [1 x 6] to have the same size apart from dimension 1 at c:
ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620

>>> index2 = t.LongTensor([[0,1,2,3,4,5]]).t()  
>>> b.gather(1,index2)     #运行成功了
tensor([[ 0],     
        [ 7],
        [14],
        [21],
        [28],
        [35]])

>>> index3 = t.LongTensor([[0,1,2,3,4]]).t()
>>> b.gather(1,index3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [5 x 1], src [6 x 6] and index [5 x 1] to have the same size apart from dimension 1 at c:
ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620  

>>> index4 = t.LongTensor([[0,1,2,3,4]])
>>> b.gather(0,index4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 5], src [6 x 6] and index [1 x 5] to have the same size apart from dimension 0 at c:
ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620

与gather相对应的逆操作是scatter_,gather把数据从input中按index取出,而scatter_是把取出的数据再放回去。注意scatter_函数是inplace操作。



与gather相对应的逆操作是scatter_,gather把数据从input中按index取出,而scatter_是把取出的数据再放回去。注意scatter_函数是inplace操作。

out = input.gather(dim, index)
-->近似逆操作
out = Tensor()
out.scatter_(dim, index)

根据StackOverflow上的问题修改代码如下:
输入

# 把两个对角线元素放回去到指定位置
c = t.zeros(4,4)
c.scatter_(1, index, b.float())

输出

tensor([[ 0.,  0.,  0.,  3.],
        [ 0.,  5.,  6.,  0.],
        [ 0.,  9., 10.,  0.],
        [12.,  0.,  0., 15.]])
原文地址:https://www.cnblogs.com/HongjianChen/p/9450987.html