PyTorch——模型搭建——MobileNetV1(一)

MobileNetV1模型

PyTorch自带的MobileNetV1

PyTorch（torchvision）没有内置 MobileNetV1 的实现，只提供了 MobileNetV2 和 MobileNetV3。

详情参考:https://pytorch.org/vision/stable/models.html

自己搭建

 1 import torch
 2 import torch.nn as nn
 3 import torch.nn.functional as F
 4 
 5 class Block(nn.Module):
 6     "Depthwise conv + Pointwise conv"
 7     def __init__(self, in_channels, out_channels, stride=1):
 8         super(Block, self).__init__()
 9         self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False)
10         self.bn1 = nn.BatchNorm2d(in_channels)
11         self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
12         self.bn2 = nn.BatchNorm2d(out_channels)
13 
14     def forward(self, x):
15         x = self.conv1(x)
16         x = self.bn1(x)
17         x = F.relu(x)
18         x = self.conv2(x)
19         x = self.bn2(x)
20         x = F.relu(x)
21         return x
22 
class MobileNet(nn.Module):
    """MobileNetV1 image classifier.

    Structure: a 3x3 stride-2 stem convolution (3 -> 32 channels), then the
    depthwise-separable ``Block`` stack described by ``cfg``, global average
    pooling, and a linear classifier over the final 1024 channels.

    Args:
        num_classes: Number of logits produced by the final linear layer.

    The stem plus the four stride-2 entries in ``cfg`` downsample the input by
    32 overall. Because pooling is global, the network accepts inputs other
    than 224x224 (e.g. 32x32 CIFAR-sized images), not only 224x224.
    """

    # One entry per Block. An int ``c`` means (out_channels=c, stride=1); a
    # tuple ``(c, s)`` means (out_channels=c, stride=s). The 13 entries follow
    # Table 1 of the MobileNetV1 paper; the stride-1 256-channel layer was
    # previously missing.
    cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2),
           512, 512, 512, 512, 512, (1024, 2), 1024]

    def __init__(self, num_classes=10):
        super(MobileNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        self.linear = nn.Linear(1024, num_classes)

    def _make_layers(self, in_planes):
        """Build the ``Block`` stack from ``cfg``, starting at ``in_planes`` channels."""
        layers = []
        for spec in self.cfg:
            out_planes, stride = (spec, 1) if isinstance(spec, int) else spec
            layers.append(Block(in_planes, out_planes, stride))
            in_planes = out_planes  # each block feeds the next
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        # Global average pooling. This matches avg_pool2d(out, 7) on the 7x7
        # map a 224x224 input produces. Unlike the fixed kernel, it also
        # handles other spatial sizes instead of failing or averaging only
        # the top-left 7x7 window.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)
50 
if __name__ == "__main__":
    # Smoke test: run a batch of 32 random 3x224x224 tensors through an
    # 8-class model and print the logits' shape.
    x = torch.randn(32, 3, 224, 224)  # not `input`, which would shadow the builtin
    net = MobileNet(8)
    net.eval()  # BatchNorm uses running stats instead of updating them
    with torch.no_grad():  # no autograd graph is needed just to inspect the output
        out = net(x)
    print(out.size())  # expected: torch.Size([32, 8])
原文地址:https://www.cnblogs.com/timelesszxl/p/14566107.html