PyTorch 深度学习：卷积神经网络

LeNet 堪称 CNN 中的“果蝇”（最经典的入门模型），哈哈。

 1 import torch
 2 from torchvision import datasets,transforms
 3 from torch import nn,optim
 4 import torch.nn.functional as F
 5 
 6 trans=(transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,)))
 7 trainset=datasets.MNIST('data',train=True,download=True,transform=trans)
 8 testset=datasets.MNIST('data',train=False,download=True,transform=trans)
 9 
class LeNet(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    The fc1 input size (16 * 4 * 4) corresponds to a 1-channel 28x28 input:
    28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4 (no padding, default
    strides). Other spatial sizes will hit a shape mismatch at fc1.
    """

    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv2d(1, 6, (5, 5))  # 1 -> 6 channels, 5x5 kernel
        self.c3 = nn.Conv2d(6, 16, 5)      # 6 -> 16 channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map an (N, 1, 28, 28) batch to (N, 10) unnormalized scores.

        No softmax is applied to the output of fc3.
        """
        x = F.max_pool2d(F.relu(self.c1(x)), 2)
        x = F.max_pool2d(F.relu(self.c3(x)), 2)
        # Flatten everything except the batch dimension before the linear layers.
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` except the first.

        An input with only one dimension yields 1 (empty product).
        """
        return x.size()[1:].numel()

    # The original class defined only this name while forward() called
    # num_flat_features, which raised AttributeError; keep the old name
    # as an alias so any external callers still work.
    sum_flat_features = num_flat_features
原文地址:https://www.cnblogs.com/St-Lovaer/p/13726669.html