pytorch实现mnist手写识别

直接上代码,还有可视化的结果

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,transforms
from visdom import Visdom
from torch.utils.data import DataLoader

# Hyperparameters
batch_size = 512
learning_rate = 0.01
epoches = 20

# Split MNIST 5:1:1 -> 50k train / 10k validation / 10k test.
# BUG FIX: download=True so a fresh machine fetches the dataset instead of
# raising RuntimeError when ./data is empty. The single-transform Compose
# wrapper was redundant and is dropped.
train_db = datasets.MNIST('data', train=True, download=True,
                          transform=transforms.ToTensor())
train_db, validation_db = torch.utils.data.random_split(train_db, [50000, 10000])
test_db = datasets.MNIST('data', train=False, download=True,
                         transform=transforms.ToTensor())
train_loader = DataLoader(train_db, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_db, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_db, batch_size=batch_size, shuffle=False)
class Flatten(nn.Module):
    """Collapse all non-batch dimensions into one: (N, 1, 28, 28) -> (N, 784)."""

    def forward(self, inputs):
        batch = inputs.size(0)
        return inputs.view(batch, -1)
# Network architecture
class Mnist(nn.Module):
    """Four-layer MLP classifier for 28x28 MNIST digits.

    Input:  float tensor of shape (N, 1, 28, 28) (anything flattening to
            784 features per sample works).
    Output: raw, unnormalized class scores (logits) of shape (N, 10).
    """

    def __init__(self):
        super(Mnist, self).__init__()
        self.net = nn.Sequential(
            # Built-in nn.Flatten (start_dim=1 by default) replaces the
            # hand-rolled Flatten module.
            nn.Flatten(),
            nn.Linear(784, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            # BUG FIX: no activation after the output layer —
            # nn.CrossEntropyLoss expects raw logits; the original trailing
            # ReLU clamped all logits to >= 0, zeroing gradients for
            # negative scores and crippling training.
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.net(x)

def main():
    """Train the MLP on MNIST, validating once per epoch.

    Side effects: streams loss/accuracy curves to a running Visdom server
    and prints progress to stdout. Reads the module-level loaders and
    hyperparameters.
    """
    mod = Mnist()
    optimizer = optim.SGD(mod.parameters(), lr=learning_rate)
    loss_fun = nn.CrossEntropyLoss()
    vis = Visdom()
    # BUG FIX: chart title typo 'trai_loss' -> 'train_loss'.
    vis.line([0.], [0.], win='train_loss', opts=dict(title='train_loss'))
    vis.line([0.], [0.], win='accuracy', opts=dict(title='acc'))
    global_step = 0
    # BUG FIX: honor the configured epoch count (was a hard-coded
    # range(5000) that ignored `epoches`).
    for epoch in range(epoches):
        # BUG FIX: re-enable train mode each epoch — the previous epoch's
        # validation left the model in eval mode permanently.
        mod.train()
        # BUG FIX: reset the accuracy counters per epoch so the curve shows
        # the current epoch's accuracy, not a running average over the
        # whole run.
        correct = 0
        total_num = 0
        for batch_index, (x, y) in enumerate(train_loader):
            logits = mod(x)
            train_loss = loss_fun(logits, y)
            pred = logits.argmax(dim=1)
            # .item() keeps the counter a plain float (and avoids handing a
            # tensor to vis.line below).
            correct += torch.eq(y, pred).float().sum().item()
            total_num += x.size(0)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
            global_step += 1
            acc = 100. * correct / total_num
            vis.line([train_loss.item()], [global_step], win='train_loss', update='append')
            vis.line([acc], [global_step], win='accuracy', update='append')
            print('the loss of {:d} step is {:.3f},the accuracy is {:.3f}%'.format(
                global_step, train_loss.item(), acc))

        mod.eval()
        with torch.no_grad():
            val_correct = 0
            val_total = 0
            for validation_images, validation_label in validation_loader:
                val_logits = mod(validation_images)
                pred = val_logits.argmax(dim=1)
                val_correct += torch.eq(pred, validation_label).float().sum().item()
                val_total += validation_images.size(0)
            # Show the last validation batch and its predictions in Visdom.
            vis.images(validation_images.view(-1, 1, 28, 28), win='x')
            vis.text(str(pred.detach().cpu().numpy()), win='pred',
                     opts=dict(title='pred'))
            val_acc = 100. * val_correct / val_total
            print('the val acc of {:d} epoch is {:.3f}%'.format(epoch, val_acc))

原文地址:https://www.cnblogs.com/daremosiranaihana/p/12591884.html