Task4.用PyTorch实现多层网络

1.引入模块,读取数据 

2.构建计算图(构建网络模型)

3.损失函数与优化器

4.开始训练模型

5.对训练的模型预测结果进行评估

 1 import torch.nn.functional as F
 2 import torch.nn.init as init
 3 import torch
 4 from torch.autograd import Variable
 5 import matplotlib.pyplot as  plt
 6 import numpy as np
 7 import math
 8 %matplotlib inline
 9 #%matplotlib inline 可以在Ipython编译器里直接使用
10 #功能是可以内嵌绘图,并且可以省略掉plt.show()这一步。
11 
12 xy=np.loadtxt('./data/diabetes.csv.gz',delimiter=',',dtype=np.float32)
13 x_data=torch.from_numpy(xy[:,0:-1])#取除了最后一列的数据
14 y_data=torch.from_numpy(xy[:,[-1]])#取最后一列的数据,[-1]加中括号是为了keepdim
15 
16 print(x_data.size(),y_data.size())
17 #print(x_data.shape,y_data.shape)
18 
19 #建立网络模型
20 class Model(torch.nn.Module):
21     
22     def __init__(self):
23         super(Model,self).__init__()
24         self.l1=torch.nn.Linear(8,6)
25         self.l2=torch.nn.Linear(6,4)
26         self.l3=torch.nn.Linear(4,1)
27         
28     def forward(self,x):
29         out1=F.relu(self.l1(x))
30         out2=F.dropout(out1,p=0.5)
31         out3=F.relu(self.l2(out2))
32         out4=F.dropout(out3,p=0.5)
33         y_pred=F.sigmoid(self.l3(out3))
34         return y_pred
35     
36 def weights_init(m):
37     classname=m.__class__.__name__
38     if classname.find('Linear')!=-1:
39         m.weight.data=torch.randn(m.weight.data.size()[0],m.weight.data.size()[1])
40         m.bias.data=torch.randn(m.bias.data.size()[0])
41         
42 #our model
43 model=Model()
44 model.apply(weights_init)
45 criterion=torch.nn.BCELoss()
46 optimizer=torch.optim.SGD(model.parameters(),lr=0.1)
47 
48 #training loop
49 Loss=[]
50 for epoch in range(2000):
51     y_pred=model(x_data)
52     loss=criterion(y_pred,y_data)
53     if epoch%100 == 0:
54         print("epoch = ",epoch," loss = ",loss.data)
55         Loss.append(loss.data)
56         optimizer.zero_grad()
57         loss.backward()
58         optimizer.step()
59         
60 hour_var = Variable(torch.randn(1,8))
61 print("predict",model(hour_var).data[0]>0.5)
62 plt.plot(Loss)

这里说明一下,这个dataset不是自带的,需要大家自己去下载,我找的时候费了不少功夫,这里提供一个网址给大家下载https://github.com/LianHaiMiao/pytorch-lesson-zh/blob/master/dataSet/diabetes.csv.gz
参考:https://blog.csdn.net/qq_35547281/article/details/89285980

原文地址:https://www.cnblogs.com/NPC-assange/p/11348338.html