CIFAR-10 image classification with PyTorch

CIFAR-10 (Canadian Institute for Advanced Research) is an image-recognition dataset collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. It contains 60,000 32*32 color images: 50,000 training images and 10,000 test images, spread over 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck), with 6,000 images per class. Compared with MNIST, the images are in color and noisier, and objects of the same class vary in size, viewing angle, and color.
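These figures are easy to check directly from torchvision before starting (a minimal sketch; it uses the same data1/ directory as the full code below):

import torchvision as tv
trainset = tv.datasets.CIFAR10(root='data1/', train=True, download=True)
testset = tv.datasets.CIFAR10(root='data1/', train=False, download=True)
print(len(trainset), len(testset))   # 50000 10000
print(trainset.classes)              # the 10 class names
img, label = trainset[0]
print(img.size, label)               # (32, 32) and an integer class index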

The task is to train a classifier on this dataset.

The steps are as follows:
1. Load and preprocess the CIFAR-10 dataset with torchvision
2. Define the network
3. Define the loss function and optimizer
4. Train the network and update its parameters
5. Test the network

import torchvision as tv            # torchvision bundles common vision datasets
import torch
import torchvision.transforms as transforms    # image transformation utilities
from torchvision.transforms import ToPILImage

# Load and preprocess the CIFAR-10 dataset with torchvision
show = ToPILImage()         # converts a Tensor back to a PIL Image for visualization
# ToTensor maps pixel values from [0, 255] to [0.0, 1.0]; Normalize then maps them to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
trainset = tv.datasets.CIFAR10(root='data1/', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)
testset = tv.datasets.CIFAR10('data1/', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
(data, label) = trainset[100]
print(classes[label])                       # prints "ship"
show((data + 1) / 2).resize((100, 100))     # undo the normalization before displaying
dataiter = iter(trainloader)
images, labels = next(dataiter)             # dataiter.next() no longer works in Python 3
print(' '.join('%11s' % classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid((images + 1) / 2)).resize((400, 100))   # make_grid stitches several images into one

# Define the network (a LeNet-style CNN)
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)        # 3 input channels, 6 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # after two conv+pool stages the feature map is 16x5x5
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # 10 output scores, one per class

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size()[0], -1)            # flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

# Define the loss function and optimizer
from torch import optim
criterion = nn.CrossEntropyLoss()   # cross-entropy loss for multi-class classification
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Train the network
from torch.autograd import Variable   # a no-op wrapper since PyTorch 0.4; kept to match the book
for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):   # enumerate yields (index, value); its second argument sets the starting index
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:                    # print the average loss every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print("----------finished training---------")
# Test the network
dataiter = iter(testloader)
images, labels = next(dataiter)
print('ground-truth labels: ', ' '.join('%8s' % classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid((images + 1) / 2)).resize((400, 100))   # undo the normalization, same as for the training images
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data, 1)   # torch.max returns (max values, indices); the indices are the predicted classes
print('predictions:', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
correct = 0
total = 0
for data in testloader:
    images, labels = data
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
print('Accuracy on the 10000 test images: %d %%' % (100 * correct / total))

# Run one batch on the GPU, if one is available
if torch.cuda.is_available():
    net.cuda()
    images = images.cuda()
    labels = labels.cuda()
    output = net(Variable(images))
    loss = criterion(output, Variable(labels))
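The if-block at the end only pushes a single batch through the network on the GPU. To actually train on the GPU, the model, the loss, and every batch have to live on the same device; a minimal sketch of that loop (not part of the original demo):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # recreate the optimizer after moving the model
for epoch in range(2):
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)   # move each batch to the same device as the model
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()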

If the learning rate is too large, the optimizer has trouble getting close to the optimum, so with a small dataset keep the learning rate small and train for more epochs than the two used above.
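One common way to keep the learning rate small while training longer is to decay it on a schedule; a sketch using torch.optim.lr_scheduler (the step size and decay factor here are illustrative, not from the book):

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # halve the learning rate every 5 epochs
for epoch in range(20):
    for inputs, labels in trainloader:
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()   # advance the schedule once per epoch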

This example is a demo from Chen Yun's book on the PyTorch deep-learning framework. When running it, watch out for the dataset download: the automatic download is often very slow or fails outright, so it is recommended to fetch the archive directly from the indicated URL with a download manager such as Thunder (Xunlei), which takes about half a minute.
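(For reference, the archive torchvision tries to fetch is https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz; if you download it manually, placing cifar-10-python.tar.gz directly under data1/ should let the CIFAR10(...) calls above verify it and skip the download.)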

Original post: https://www.cnblogs.com/henuliulei/p/11981109.html