Pytorch:Tensor 张量操作

张量操作

一、张量的拼接与切分

1.1 torch.cat()

功能:将张量按维度dim进行拼接

tensors:张量序列

dim:要拼接的维度

1.2 torch.stack()

功能:在新创建的维度的上进行拼接

tensors:张量序列

dim:要拼接的维度(如果dim为新的维度,则新增一个维度进行拼接,新维度只能高一维)

         

       

1.3 torch.chunk()

功能:将张量按维度进行平均切分

返回值:张量列表

注意事项:若不能整除,最后一份小于其他张量;整除时令商为向上取整的数,如7/3=2.333,取整为3

input:要切分的张量

chunks:要切分的份数

dim:要切分的维度

将张量a在第一维上的数据分成三份
运行​​​​结果

1.4 torch.split()

功能:将张量按维度进行平均切分

返回值:张量列表

input:要切分的张量

split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分(注意list的各元素之和需等于维度上的长度

dim:要切分的维度

 

二、张量索引

2.1 torch.index_select()

功能:在维度dim上,按index索引数据

返回值:依index索引数据拼接的张量

input:要索引的张量

dim:要索引的维度

index:要索引数据的序号(注意index的数据类型要为torch.long,float会报错)

2.2 torch.masked_select()

功能:按mask中的True进行索引

返回值:一维张量

input:要索引的张量

mask:与input同形状的布尔类型张量(mask的生成可以通过比较大小关系得出,le为小于等于,详见图英文注释)

 三、张量变换

3.1 torch.reshape()

功能:变换张量形状

注意事项:当张量在内存中是连续时,新张量与input共享数据内存

input:要变换的张量

size:新张量的形状(形状中若有-1,则该处的值有其他维数及总数来计算得出)

3.2 torch.transpose()

功能:交换两个张量的维度

input:要交换的张量

dim0:要交换的维度

dim1:要交换的维度

3.3 torch.t()

功能:2维张量转置,对矩阵而言,等价于torch.transpose(input,0,1)

3.4 torch.squeeze()

功能:压缩长度为1的维度(轴)

dim:若为None,移除所有长度为1的轴;如果指定维度,当且仅当该轴长度为1时,可以被移除

3.5 torch.unsqueeze()

功能:依据dim扩展维度

dim:扩展的维度

三、张量数学运算

主要可分为三类:

1.加减乘除   2. 对数、指数、幂函数  3.三角函数

其中加法比较特殊

torch.add()

功能:逐元素计算该式 input+alpha*other(为了简便于梯度下降的运算)

input:第一个张量

alpha:乘项因子

other:第二个张量

另外的拓展还有

原文地址:https://www.cnblogs.com/SakuraYuki/p/13341452.html