多层感知机的从零开始实现

一、前言

使用Fashion-MNIST图像分类数据集

import torch
from torch import nn
from d2l import torch as d2l

#批量大小等于256
batch_size = 256

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

  

二、初始化模型参数

1、对于每一层我们都要记录一个权重矩阵和一个偏置向量

2、要为这些参数的损失的梯度分配内存

# num_inputs:输入;num_outputs:输出;输入与输出是由数据决定的
#num_hiddens:人为决定,隐藏层的大小
num_inputs, num_outputs, num_hiddens = 784, 10, 256


# nn.Parameter()函数的目的就是让该变量在学习的过程中不断的修改其值以达到最优化。
# nn.Parameter()参考:https://www.jianshu.com/p/d8b77cc02410
# torch.randn():返回一个张量,包含了从标准正态分布(均值为0,方差为1,即高斯白噪声)中抽取的一组随机数

W1 = nn.Parameter(
    torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)

# randn():第一个参数是行数(输入数),第二个参数是列数(隐藏层大小)
# zeros():偏差,是一个向量-大小为隐藏层的大小,默认设为0
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))


W2 = nn.Parameter(
    torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)

# num_hiddens作为行数(隐藏层的大小),num_outputs作为列数(输出数)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))


# W1,b1:第一层
# W2, b2:第二层(隐藏层)
params = [W1, b1, W2, b2]

  

三、激活函数

确保熟悉流程,使用内置的relu函数

# relu()函数:将相应的激活值设为0来仅保留正元素并丢弃所有负元素

# 矩阵在relu()中的实现
def relu(X):
    
    # 使得生成的a的数据类型和形状与传入参数一致,但大小为zero
    a = torch.zeros_like(X)
    # 再将X与a作比较,返回出最大值
    return torch.max(X, a)

  

四、构建模型

# 构建模型-将三维图形构建为二维矩阵

def net(X):
    # 第一步,先把图片拉成一个矩阵
    # reshape((a,b)):a为行数,b为列数
    
    X = X.reshape((-1, num_inputs)) # num_inputs为输入数据=784
    
    # 模型的构建
    H = relu(X @ W1 + b1)  # 这里“@”代表矩阵乘法
    return (H @ W2 + b2)

  

五、损失函数

1、之前在softmax函数中已经定义了softmax和交叉熵损失

2、这里直接使用高级API中的内置函数

# 直接调用高级API中的内置函数计算softmax和交叉熵损失

loss = nn.CrossEntropyLoss()

  

六、训练

1、多层感知机的训练过程与softmax回归的训练过程完全相同

2、直接调用在softmax中定义的train_ch3函数

num_epochs, lr = 10, 0.1

# torch.optim.SGD:实现随机梯度下降,params: 待优化参数的iterable或者是定义了参数组的dict
updater = torch.optim.SGD(params, lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

#可以看见损失在下降,但精度没有怎么变化
#因为模型更大了,数据拟合性更好,所以损失在下降

七、评估

调用之前定义的predict_ch3函数来测试数据

# 因为内置的predict_ch3设置的是只显示6张图片,其实有256个预测结果
d2l.predict_ch3(net, test_iter)

原文地址:https://www.cnblogs.com/xiaoqing-ing/p/15069454.html