PyTorch:全连接多分类小网络代码实现

import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
import torchvision
import torchvision.transforms as transforms

# Fashion-MNIST training and test splits, downloaded on first use into
# ./Datasets/FashionMNIST; every image is converted with ToTensor().
mnist_train = torchvision.datasets.FashionMNIST(
    root='./Datasets/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
mnist_test = torchvision.datasets.FashionMNIST(
    root='./Datasets/FashionMNIST',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# feature_train,label_train=mnist_train[0]
# feature_test,label_test=mnist_test[0]
# print(len(mnist_train))
# print(len(mnist_test))
# print(feature_train.size(),label_train)
# print(feature_test.size(),label_test)

# Mini-batch size, plus the input/output widths later handed to the linear layer.
batch_size = 256
num_inputs = 28 * 28
num_outputs = 10

# Load data in the main process on Windows; use 4 worker processes elsewhere.
# NOTE(review): other platforms whose default multiprocessing start method is
# "spawn" may also need 0 workers because this script has no __main__ guard
# -- confirm on the target platform.
num_workers = 0 if sys.platform.startswith('win') else 4

# Training batches are reshuffled on every pass; test batches keep dataset order.
train_iter = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_iter = torch.utils.data.DataLoader(
    mnist_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

class LinearNet(nn.Module):
    """A single fully connected layer mapping flattened inputs to output scores."""

    def __init__(self, num_inputs, num_outputs):
        """Create the layer and initialise its parameters.

        Args:
            num_inputs: Number of features per sample after flattening.
            num_outputs: Number of output scores per sample.
        """
        super(LinearNet, self).__init__()
        self.linear = nn.Linear(num_inputs, num_outputs)

        # Weights are drawn from N(0, 0.01^2); the bias starts at zero.
        weight, bias = self.linear.weight, self.linear.bias
        init.normal_(weight, mean=0, std=0.01)
        init.constant_(bias, val=0)

    def forward(self, x):
        """Flatten every sample of ``x`` to one row and apply the linear layer."""
        batch = x.shape[0]
        flat = x.view(batch, -1)
        return self.linear(flat)

# Number of full passes over the training data.
num_epochs = 20

# Model, cross-entropy criterion, and a plain SGD optimizer over the model parameters.
net = LinearNet(num_inputs=num_inputs, num_outputs=num_outputs)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.1)

def evaluate_accuracy(data_iter, net):
    """Return the fraction of samples in ``data_iter`` that ``net`` classifies correctly.

    Evaluation runs with autograd disabled, and when ``net`` is an
    ``nn.Module`` it is switched to eval mode for the duration of the call;
    its previous training mode is restored afterwards.

    Args:
        data_iter: Iterable yielding ``(X, y)`` batches, where ``y`` holds one
            class index per sample in ``X``.
        net: Callable mapping a batch ``X`` to per-class scores; the predicted
            class is the argmax along dim 1.

    Returns:
        Number of correct predictions divided by the number of samples seen.

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no samples.
    """
    is_module = isinstance(net, torch.nn.Module)
    was_training = is_module and net.training
    if is_module:
        # Layers such as dropout / batch norm behave differently in training mode.
        net.eval()

    acc_sum, n = 0., 0
    try:
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for X, y in data_iter:
                acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        # Leave the model in the mode the caller had it in.
        if was_training:
            net.train()
    return acc_sum / n

# Training loop: one SGD pass over train_iter per epoch, followed by a test-set
# accuracy check and a one-line progress report.
for epoch in range(num_epochs):
    # Per-epoch running totals: summed per-sample loss, correct predictions,
    # and number of samples seen.
    train_l_sum, train_acc_sum, n = 0., 0., 0
    # Make sure the model is in training mode at the start of every epoch.
    net.train()
    for X, y in train_iter:
        y_hat = net(X)
        # `loss` is nn.CrossEntropyLoss() with its default reduction, so this
        # is already the scalar mean loss of the batch.
        l = loss(y_hat, y)

        optimizer.zero_grad()
        l.backward()
        optimizer.step()

        batch_n = y.shape[0]
        # Weight the batch mean by the batch size so that train_l_sum / n is
        # the true per-sample average (adding bare batch means and dividing by
        # the sample count under-reports the loss by about the batch size).
        train_l_sum += l.item() * batch_n
        train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
        n += batch_n

    test_acc = evaluate_accuracy(test_iter, net)
    print('epoch %d, loss %.4f, train_acc %.3f, test_acc %.3f,'
          % (epoch, train_l_sum / n, train_acc_sum / n, test_acc))

  

原文地址:https://www.cnblogs.com/liutianrui1/p/13824863.html