Implementing LeNet on the Fashion-MNIST dataset

import torch
from torch import nn, optim
import torchvision
import torchvision.transforms as transforms
import sys

# hyperparameters
batch_size = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 10

# dataset: download Fashion-MNIST and wrap it in DataLoaders
mnist_train = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
# multi-process data loading is unreliable on Windows, so fall back to 0 workers there
if sys.platform.startswith('win'):
    num_workers = 0
else:
    num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
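
# Optional sanity check (a quick sketch; X_sample/y_sample are illustrative names):
# confirm the loader yields image batches of shape (batch_size, 1, 28, 28) and integer labels.
X_sample, y_sample = next(iter(train_iter))
print(X_sample.shape, y_sample.shape)   # torch.Size([256, 1, 28, 28]) torch.Size([256])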

# network: classic LeNet with sigmoid activations
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5),    # (1, 28, 28) -> (6, 24, 24)
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2),    # -> (6, 12, 12)
            nn.Conv2d(6, 16, 5),   # -> (16, 8, 8)
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2)     # -> (16, 4, 4)
        )

        self.fc = nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )

    def forward(self, img):
        feature = self.conv(img)
        # flatten to (batch_size, 16*4*4) before the fully connected layers
        output = self.fc(feature.view(img.shape[0], -1))
        return output

net = LeNet().to(device)
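
# Optional shape trace (a minimal sketch) to verify the flattened feature size expected by
# the first fully connected layer: 28x28 -> 24x24 -> 12x12 -> 8x8 -> 4x4, i.e. 16*4*4.
with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28, device=device)
    print(net.conv(dummy).shape)   # expected: torch.Size([1, 16, 4, 4])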

def evaluate_accuracy(data_iter, net, device):
    """Compute the classification accuracy of `net` over `data_iter`."""
    acc_sum, n = 0.0, 0
    net.eval()   # switch to evaluation mode
    with torch.no_grad():
        for X, y in data_iter:
            acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
            n += y.shape[0]
    net.train()  # switch back to training mode
    return acc_sum / n

loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# training loop
for epoch in range(num_epochs):
    train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
    for X, y in train_iter:
        X, y = X.to(device), y.to(device)
        y_hat = net(X)
        l = loss(y_hat, y)   # CrossEntropyLoss already averages over the batch

        optimizer.zero_grad()
        l.backward()
        optimizer.step()

        train_l_sum += l.cpu().item() * y.shape[0]   # scale back to a per-sample sum
        train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
        n += y.shape[0]

    test_acc = evaluate_accuracy(test_iter, net, device)
    print('epoch %d, loss %.4f, train_acc %.3f, test_acc %.3f'
          % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))
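
# A minimal inference sketch after training, assuming torchvision's standard
# Fashion-MNIST class order: compare predicted and true labels for a few test images.
class_names = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
net.eval()
with torch.no_grad():
    X, y = next(iter(test_iter))
    preds = net(X.to(device)).argmax(dim=1).cpu()
print('pred:', [class_names[i] for i in preds[:10]])
print('true:', [class_names[i] for i in y[:10]])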

  

Original post: https://www.cnblogs.com/liutianrui1/p/13839229.html