PyTorch入门:使用PyTorch搭建神经网络LeNet5

前言

在本文中,我们基于PyTorch构建一个简单的神经网络LeNet5。

在阅读本文之前,建议您了解一些卷积神经网络的前置知识,比如卷积、Max Pooling和全连接层等等,可以看我写的相关文章:李宏毅机器学习课程笔记-7.1CNN入门详解

通过阅读本文,您可以学习到如何使用PyTorch构建神经网络LeNet5。

模型说明

在本例中,我们使用如下图所示的神经网络模型:LeNet5。

img

该模型有1个输入层、2个卷积层、2次Max Pooling、2个全连接层和1个输出层。

  • 输入层INPUT

    1个channel,图片size是32×32。

  • 卷积层C1

    6个channel,特征图的size是28×28,即每个卷积核的size为(5,5),stride为1。

  • 下采样操作S2

    6个channel,特征图的size是14×14,即Max Pooling窗口size为(2,2)。

  • 卷积层C3

    16个channel,特征图的size是10×10,即每个卷积核的size为(5,5),stride为1。

  • 下采样操作S4

    16个channel,特征图的size是5×5,即Max Pooling窗口size为(2,2)。

  • 全连接层F5

    120个神经元。

  • 全连接层F6

    84个神经元。

  • 输出层OUTPUT

    10个神经元。

另外,除了输入层和输出层,剩下的卷积层、最大池化操作和全连接层后面都要加上Relu激活函数,下采样操作S4之后需要进行Flatten以和全连接层F5衔接起来。

代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # 卷积层
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv3 = nn.Conv2d(6, 16, 5)
        # 全连接层
        self.fc5 = nn.Linear(in_features=16*5*5, out_features=120)
        self.fc6 = nn.Linear(120, 84)
        self.OUTPUT = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, (2, 2)) # Max pooling over a (2, 2) window
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2) # If the size is a square you can only specify a single number
        x = x.view(-1, 16*5*5) # Flatten
        x = F.relu(self.fc5(x))
        x = F.relu(self.fc6(x))
        x = self.OUTPUT(x)
        return x

net = LeNet5()
output = net(torch.rand(1, 1, 32, 32))
# print(output)

参考链接

https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html

其实本文内容主要是PyTorch的官方教程。

PyTorch官方教程中代码实现与图片所示的LeNet5不符(PyTorch官方教程代码中是3×3的卷积核,而图片中LeNet5是5×5的卷积核),本文中我是按照图片所示模型结构实现的。

其实PyTorch开发者和其他开发者也注意到了这一问题,详见:

https://github.com/pytorch/tutorials/pull/515

https://github.com/pytorch/tutorials/commit/630802450c13c78f02f744af1c47d1033b6fe206

https://github.com/pytorch/tutorials/pull/1257


Github(github.com):@chouxianyu

Github Pages(github.io):@臭咸鱼

知乎(zhihu.com):@臭咸鱼

博客园(cnblogs.com):@臭咸鱼

B站(bilibili.com):@绝版臭咸鱼

微信公众号:@臭咸鱼

转载请注明出处,欢迎讨论和交流!


原文地址:https://www.cnblogs.com/chouxianyu/p/14613460.html