Pytorch数据类型转换

Pytorch数据类型转换

载入模块生成数据

import torch
import numpy as np
a_numpy = np.array([1,2,3])

Numpy转换为Tensor

a_tensor = torch.from_numpy(a_numpy)
print(a_tensor)

Tensor转换为Numpy

a_numpy = a_tensor.numpy()
print(a_numpy)

Int, float 转换为tensor

c = torch.tensor(2)
print(c)

tensor 转换为int

c = c.item()
print(c)

Numpy转换为Variable

a_variable = Variable(torch.from_numpy(a_numpy))
print(a_variable)

Variable转换为Numpy

a_numpy = a_variable.data.numpy()
print(a_numpy)
原文地址:https://www.cnblogs.com/icodeworld/p/11448684.html