torch.Tensor.repeat()


>>> import torch
>>> 
>>> a = torch.randn(33, 55)
>>> a.size()
torch.Size([33, 55])
# 对二维张量:a.repeat(repeat_counts_axis_0, repeat_counts_axis_1),每个参数表示在对应维度上重复的次数
>>> 
# 对二维张量调用 repeat(arg1, arg2, ..., n0, n1):最后两个参数 n0、n1 分别是在原有 axis 0、axis 1 上的重复次数(结果大小 = 原大小 × 重复次数);前面多出的每个参数都会在 tensor.size 之前追加一个新维度,该维度的大小等于该参数的值
>>> a.repeat(1, 1).size()
torch.Size([33, 55])
>>> 
>>> a.repeat(2,1).size()
torch.Size([66, 55])
>>> 
>>> a.repeat(1,2).size()
torch.Size([33, 110])
>>>
>>> a.repeat(1,1,1).size()
torch.Size([1, 33, 55])
>>>
>>> a.repeat(2,1,1).size()
torch.Size([2, 33, 55])
>>>
>>> a.repeat(1,2,1).size()
torch.Size([1, 66, 55])
>>>
>>> a.repeat(1,1,2).size()
torch.Size([1, 33, 110])
>>>
>>> a.repeat(1,1,1,1).size()
torch.Size([1, 1, 33, 55])
>>> 
>>> # repeat()的参数的个数,不能少于被操作的张量的维度的个数
#即repeat的参数最少是tensor的维度个数
>>> # 下面是一些错误示例
>>> a.repeat(2).size()  # 1D < 2D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b = torch.randn(5,6,7)
>>> b.size() # 3D
torch.Size([5, 6, 7])
>>> 
>>> b.repeat(2).size() # 1D < 3D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1).size() # 2D < 3D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1,1).size() # 3D = 3D, okay
torch.Size([10, 6, 7])
>>>

原文地址:https://www.cnblogs.com/Henry-ZHAO/p/13857361.html