Study Notes 1: Implementing LeNet-5 in PyTorch

Reference: https://cuijiahua.com/blog/2018/01/dl_3.html

Code implementation (lenet5.py, which defines the network, followed by the training script):

# lenet5.py: LeNet-5 model definition
import torch
from torch import nn
from torch.nn import functional as F


class Lenet5(nn.Module):
    """
    LeNet-5 style network for the CIFAR-10 dataset
    """
    def __init__(self):
        super(Lenet5, self).__init__()
        self.conv_unit = nn.Sequential(
            # conv layer: x [b, 3, 32, 32] => [b, 6, 30, 30]
            nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=0),
            # pooling layer: [b, 6, 30, 30] => [b, 6, 15, 15]
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # conv layer: [b, 6, 15, 15] => [b, 16, 13, 13]
            nn.Conv2d(6, 16, kernel_size=3, stride=1, padding=0),
            # pooling layer: [b, 16, 13, 13] => [b, 16, 6, 6]
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )
        # flatten, then fully connected layers
        self.fc_unit = nn.Sequential(
            nn.Linear(16 * 6 * 6, 120),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84),
            nn.ReLU(inplace=True),
            nn.Linear(84, 10)
        )

        # cross-entropy loss is defined in the training script instead
        # self.criteon = nn.CrossEntropyLoss()

    def forward(self, x):
        """
        :param x: [b, 3, 32, 32]
        :return: logits [b, 10]
        """
        batchsz = x.size(0)
        # [b, 3, 32, 32] => [b, 16, 6, 6]
        x = self.conv_unit(x)
        # [b, 16, 6, 6] => [b, 16*6*6]
        x = x.view(batchsz, 16 * 6 * 6)
        # [b, 16*6*6] => [b, 10]
        logits = self.fc_unit(x)

        return logits


# quick test: run a dummy batch through the network to check the output shape
# (this is how the fc_unit input size 16*6*6 was determined)
def main():
    net = Lenet5()
    tmp = torch.randn(2, 3, 32, 32)
    out = net(tmp)
    print('lenet out:', out.shape)


if __name__ == '__main__':
    main()
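The shape annotations in the comments above follow the standard output-size rule for convolutions and pooling, out = floor((in - kernel + 2*padding) / stride) + 1. Here is a minimal sketch (the helper name conv_out is just for illustration) that reproduces the 32 → 30 → 15 → 13 → 6 trace and the 16*6*6 = 576 input size of fc_unit:

# Sketch: reproduce the shape trace in the Lenet5 comments.
def conv_out(size, kernel, stride=1, padding=0):
    return (size - kernel + 2 * padding) // stride + 1

s = 32
s = conv_out(s, kernel=3, stride=1, padding=0)   # Conv2d(3, 6)   -> 30
s = conv_out(s, kernel=2, stride=2, padding=0)   # MaxPool2d      -> 15
s = conv_out(s, kernel=3, stride=1, padding=0)   # Conv2d(6, 16)  -> 13
s = conv_out(s, kernel=2, stride=2, padding=0)   # MaxPool2d      -> 6
print(s, 16 * s * s)                             # 6 576, matches nn.Linear(16*6*6, 120)

The training script below imports this Lenet5 class and trains it on CIFAR-10.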
# training script: Lenet5 on CIFAR-10
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim

from lenet5 import Lenet5


def main():
    # batch size
    batchsz = 32
    # load the CIFAR-10 training set; resize each image to 32x32 and convert it to a tensor
    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    # wrap it in a DataLoader that yields shuffled batches of size batchsz
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    # the CIFAR-10 test set, preprocessed the same way
    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # sanity check of the data loading
    # x, label = next(iter(cifar_train))
    # print('x:', x.shape, 'label:', label.shape)

    # select the device (fall back to CPU if no GPU is available)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = Lenet5().to(device)
    # model = ResNet18().to(device)

    # cross-entropy loss
    criteon = nn.CrossEntropyLoss().to(device)
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # print the model structure
    print(model)

    # train
    for epoch in range(1000):
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            # forward pass
            logits = model(x)
            # logits: [b, 10], label: [b], loss: scalar tensor
            loss = criteon(logits, label)

            # backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        # test
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32], label: [b]
                x, label = x.to(device), label.to(device)

                # [b, 10]
                logits = model(x)
                # [b]: take the class with the highest score as the prediction
                pred = logits.argmax(dim=1)
                # count the correct predictions in this batch
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'acc:', acc)


if __name__ == '__main__':
    main()
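After training, predicting the class of a single image uses the same argmax over the logits as the test loop. Below is a minimal inference sketch; the checkpoint path is hypothetical, since the script above does not save the model (a torch.save(model.state_dict(), ...) call at the end of training would provide one):

# Sketch: classify one CIFAR-10 test image with a trained Lenet5.
import torch
from torchvision import datasets, transforms
from lenet5 import Lenet5

# CIFAR-10 class names, in label order
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Lenet5().to(device)
# model.load_state_dict(torch.load('lenet5.pth', map_location=device))  # hypothetical checkpoint
model.eval()

test_set = datasets.CIFAR10('cifar', False, transform=transforms.ToTensor(), download=True)
x, label = test_set[0]                                 # x: [3, 32, 32]
with torch.no_grad():
    logits = model(x.unsqueeze(0).to(device))          # add batch dim: [1, 3, 32, 32]
    pred = logits.argmax(dim=1).item()                 # same argmax as in the test loop
print('predicted:', classes[pred], 'ground truth:', classes[label])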

Results:

Original post: https://www.cnblogs.com/sclu/p/11947643.html