Pytorch02_GPU加速

GPU加速

1. 定义GPU设备

import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

 2. 将模型、张量等放在GPU设备上

# 
loss.to(device)
tensor.to(device)
model.to(device)

 3. 将数据等放回CPu

predict = model(data)
predict = predict.cpu().detach().numpy()  
# detach() 和 data效果相似,但detach是深拷贝,data是浅拷贝

--------------------------------

随有随更 2021.6.9

--------------------------------

我喜欢一致,可是世界并不一致
原文地址:https://www.cnblogs.com/Haozi-D17/p/14866447.html