PyTorch——模型搭建——ResNet18(一)

试一试

 1 import torch
 2 import torch.nn as nn
 3 import torch.nn.functional as F
 4 from torchsummary import summary
 5 
 6 class ResBlock(nn.Module):
 7     def __init__(self, inchannel, outchannel, stride=1):
 8         super(ResBlock, self).__init__()
 9         self.left = nn.Sequential(
10             nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
11             nn.BatchNorm2d(outchannel),
12             nn.ReLU(inplace=True),
13             nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
14             nn.BatchNorm2d(outchannel)
15         )
16         self.shortcut = nn.Sequential()
17         if stride != 1 or inchannel != outchannel:
18             self.shortcut = nn.Sequential(
19                 nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
20                 nn.BatchNorm2d(outchannel)
21             )
22 
23     def forward(self, x):
24         out = self.left(x)
25         out = out + self.shortcut(x)
26         out = F.relu(out)
27         return out
28 
class ResNet(nn.Module):
    """ResNet backbone: a conv stem, four stages of residual blocks, a linear head.

    The stem is a single 3x3 stride-1 convolution with no max-pool. Spatial
    size is therefore reduced only by the stride-2 first blocks of
    layer2-layer4, an overall factor of 8.

    Args:
        ResBlock: block class or factory, called as
            ``ResBlock(inchannel, outchannel, stride)``. The name shadows the
            module-level class of the same name; it is kept for compatibility.
        num_classes: number of logits produced by the final linear layer.
    """

    def __init__(self, ResBlock, num_classes=10):
        super(ResNet, self).__init__()
        # Channel count fed to the next block built by make_layer.
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks with ``channels`` outputs.

        Only the first block uses ``stride``; the remaining blocks use 1.
        Side effect: ``self.inchannel`` is set to ``channels``, so the next
        stage starts from this stage's width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.inchannel, channels, block_stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of 3-channel images to ``num_classes`` logits."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling to 1x1 for any input resolution. The former
        # fixed kernel of 28 only matched 224x224 inputs (224 / 8). It raised
        # for smaller inputs such as 32x32, and for larger ones it pooled
        # only part of the map or broke the input size expected by self.fc.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.fc(out)
61 
def ResNet18(num_classes=40):
    """Build the ResNet-18 variant defined in this file.

    Args:
        num_classes: number of output classes. Defaults to 40, the value
            previously hard-coded here, so existing ``ResNet18()`` calls are
            unchanged.

    Returns:
        A ``ResNet`` using ``ResBlock`` with two blocks per stage.
    """
    return ResNet(ResBlock, num_classes=num_classes)
64 
if __name__ == "__main__":
    # Use the GPU when one is present; otherwise run on the CPU instead of
    # failing inside .cuda().
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ResNet18().to(device)
    # torchsummary's summary() defaults to device="cuda"; pass the actual one.
    summary(model, (3, 224, 224), device=device)
原文地址:https://www.cnblogs.com/timelesszxl/p/14611411.html