PyTorch-模型保存与加载

保存: 

model = LinearRegression()
# ......各种操作
model.eval()
#训练完成,保存状态字典到linear.pkl
torch.save(model.state_dict(), './linear.pkl')

加载:

model = LinearRegression()
model.load_state_dict(torch.load('linear.pth'))
#...各种使用,比如预测...
x_test=np.arrar([..............])
x_test = torch.from_numpy(x_test)
predict_y = model(Variable(x_test))
原文地址:https://www.cnblogs.com/onenoteone/p/12441710.html