1.入门:pytorch基本数据类型/加法/取数/改变形状

pytorch里面处理的最基本的操作对象就是Tensor(张量),它就是一个多维矩阵。它和numpy唯一的不同就是,pytorch可以在GPU上运行,而numpy不可以。

Tensor的基本数据类型有五种:

  • 32位浮点型:torch.FloatTensor。pyorch.Tensor()默认的就是这种类型。
  • 64位整型:torch.LongTensor。
  • 32位整型:torch.IntTensor。
  • 16位整型:torch.ShortTensor。
  • 64位浮点型:torch.DoubleTensor。
import torch
#x=torch.rand((5,3))#构造一个随机初始化的矩阵,有没有括号都可
#x = torch.empty((5, 3))#构造一个5x3矩阵,不初始化,有没有括号都可
#x=torch.zeros((5, 3), dtype=torch.long)#构造一个矩阵全为 0,而且数据类型是 long.,有没有括号都可
#x = torch.tensor([[5.5,3],[3,4]])#直接使用数据构造一个张量,注意必须用中括号括好

x=torch.zeros((5, 3), dtype=torch.long)#基于已经存在的 tensor创建一个tensor
#y = x.new_ones((5, 3), dtype=torch.double)#基于已经存在的 tensor创建一个tensor,全部是1
#y = torch.randn_like(x, dtype=torch.float)#基于已经存在的 tensor创建一个tensor,随机初始化
print(x.size())#获取维度信息
print(x)
#print(y)

两种加法:

x=torch.zeros((5, 3), dtype=torch.long)
y = torch.rand((5, 3))
print(y)
#print(x+y)#加法方式1
print(torch.add(x, y))#加法方式2

tensor([[0.3135, 0.3338, 0.1780],
[0.0653, 0.5707, 0.1647],
[0.9389, 0.4725, 0.2722],
[0.6090, 0.1351, 0.6424],
[0.9465, 0.4593, 0.1661]])
tensor([[0.3135, 0.3338, 0.1780],
[0.0653, 0.5707, 0.1647],
[0.9389, 0.4725, 0.2722],
[0.6090, 0.1351, 0.6424],
[0.9465, 0.4593, 0.1661]])

附:

x=torch.zeros((5, 3), dtype=torch.long)
y = torch.rand((5, 3))

#print(x+y)#加法方式1
#print(torch.add(x, y))#加法方式2
result = torch.empty(5, 3)
torch.add(x, y, out=result)#提供一个输出 tensor 作为参数,加法方式3
print(result)

y.add_(x)#带有_会使张量发生变化
print(y)

取数像numpy一样

x = torch.rand((5, 3))
print(x)
print(x[:, 1])#所有行的第一列

tensor([[0.1158, 0.1455, 0.6182],
[0.5442, 0.9168, 0.2065],
[0.6274, 0.6169, 0.8121],
[0.2683, 0.1492, 0.3545],
[0.2574, 0.4477, 0.1026]])
tensor([0.1455, 0.9168, 0.6169, 0.1492, 0.4477])

改变张量的形状

x = torch.rand((4, 4))
print(x)
#print(x[:, 1])#所有行的第一列
y = x.view((16))#改变张量形状,有没有括号都可
print(y)

tensor([[0.9212, 0.9648, 0.3745, 0.6670],
[0.7536, 0.5412, 0.6765, 0.6377],
[0.8996, 0.9950, 0.1279, 0.5086],
[0.4452, 0.0982, 0.8753, 0.8960]])
tensor([0.9212, 0.9648, 0.3745, 0.6670, 0.7536, 0.5412, 0.6765, 0.6377, 0.8996,
0.9950, 0.1279, 0.5086, 0.4452, 0.0982, 0.8753, 0.8960])

原文地址:https://www.cnblogs.com/liuxiangyan/p/14631198.html