PyTorch——模型搭建——VGG(一)

模型搭建

0、VGG模型

1、torchvision自带的VGG模型

import torch
import torchvision
from torchsummary import summary

# Stock VGG-13 from torchvision with an 8-class head; print its layer summary
# for a 3x224x224 input. Use the GPU when one is available and fall back to
# the CPU otherwise, instead of crashing in .cuda() on CUDA-less machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torchvision.models.vgg13(num_classes=8).to(device)
summary(model, (3, 224, 224), device=device)

2、自己搭建VGG模型

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

 6 class VGG(nn.Module):
 7 
 8     def __init__(self, arch, num_classes=1000):
 9         super(VGG, self).__init__()
10         self.in_channels = 3
11         self.conv3_64 = self._make_layer(64, arch[0])
12         self.conv3_128 = self._make_layer(128, arch[1])
13         self.conv3_256 = self._make_layer(256, arch[2])
14         self.conv3_512a = self._make_layer(512, arch[3])
15         self.conv3_512b = self._make_layer(512, arch[4])
16         self.flatten = nn.Flatten()
17         self.fc1 = nn.Linear(7*7*512, 4096)
18         self.bn1 = nn.BatchNorm1d(4096)
19         self.fc2 = nn.Linear(4096, 4096)
20         self.fc3 = nn.Linear(4096, num_classes)
21 
22     def _make_layer(self, channels, num):
23         layers = []
24         for i in range(num):
25             layers.append(nn.Conv2d(self.in_channels, channels, 3, stride=1, padding=1, bias=False))
26             layers.append(nn.BatchNorm2d(channels))
27             layers.append(nn.ReLU())
28             self.in_channels = channels
29         return nn.Sequential(*layers)
30 
31     def forward(self, x):
32         x = self.conv3_64(x)
33         x = F.max_pool2d(x, 2)
34         x = self.conv3_128(x)
35         x = F.max_pool2d(x, 2)
36         x = self.conv3_256(x)
37         x = F.max_pool2d(x, 2)
38         x = self.conv3_512a(x)
39         x = F.max_pool2d(x, 2)
40         x = self.conv3_512b(x)
41         x = F.max_pool2d(x, 2)
42         #x = x.view(x.size(0), -1)
43         x = self.flatten(x)
44         x = self.fc1(x)
45         x = self.bn1(x)
46         x = F.relu(x)
47         x = self.fc2(x)
48         x = self.bn1(x)
49         x = F.relu(x)
50         x = self.fc3(x)
51         return x
52 
def VGG_11(num_classes=6):
    """Build an 11-layer VGG (8 conv + 3 fully connected layers).

    ``num_classes`` sets the output size; it defaults to 6 as in the
    original example.
    """
    return VGG([1, 1, 2, 2, 2], num_classes=num_classes)

def VGG_13(num_classes=7):
    """Build a 13-layer VGG (10 conv + 3 fully connected layers).

    ``num_classes`` sets the output size; it defaults to 7 as in the
    original example.
    """
    return VGG([2, 2, 2, 2, 2], num_classes=num_classes)

def VGG_16(num_classes=8):
    """Build a 16-layer VGG (13 conv + 3 fully connected layers).

    ``num_classes`` sets the output size; it defaults to 8 as in the
    original example.
    """
    return VGG([2, 2, 3, 3, 3], num_classes=num_classes)

def VGG_19(num_classes=9):
    """Build a 19-layer VGG (16 conv + 3 fully connected layers).

    ``num_classes`` sets the output size; it defaults to 9 as in the
    original example.
    """
    return VGG([2, 2, 4, 4, 4], num_classes=num_classes)

# Build the hand-written VGG-16 and print a layer summary for a 3x224x224
# input. Use the GPU when available and fall back to the CPU otherwise,
# instead of crashing in .cuda() on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
net = VGG_16().to(device)
summary(net, (3, 224, 224), device=device)
原文地址：https://www.cnblogs.com/timelesszxl/p/14555192.html