怎么仅加载一部分内容的预训练模型参数

在pytorch中提供了很多预训练好的模型,以分类为例,基本上都是用ImageNet数据集来训练的,分为1000类。

但是很多时候我们要实现的分类项目可能并没有这么简单,比如我们可能并不仅仅只是实现单分类,可能想实现双分类或者是多分类,这个时候就需要对模型进行一定的修改

修改的同时还希望该修改后的模型中与预训练模型相同的部分仍能够使用预训练的参数来初始化,这时候应该怎么做?

1.单分类

这是最简单的情况,就是将1000类更改为自己想要分的类别数即可。比如你想要对性别分类,分两类,使用pytorch中的预训练模型resnet18

#coding:utf-8
import torch
from torchvision import  models
from torch import nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 然后选择使用的模型
model_conv = models.resnet18(pretrained=True)

# resnet18仅有一个全连接层
# 得到该全连接层输入神经元数.in_features
fc_features = model_conv.fc.in_features

# 默认的输出神经元数为1000
# 这里修改为自己想进行的二分类,类别为2,即man和woman
model_conv.fc = nn.Linear(fc_features, 2)
model_conv.to(device)

这样模型就设置成功了

2.双分类或多分类

这里以双分类为例,在上面的单分类中,我们仅是在原有的模型上修改了参数值,并没有改变整个模型的结构

但是单我们要实现双分类,如同时进行性别和人种分类,这个时候就需要在原来代码的基础上添加一些新的层,构造一个新的模型

如下面代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None ,gender_classes=2, race_classes=4):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # 注释掉之前的全连接层
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        # 变成两个并行的全连接层
        self.gen_fc = nn.Linear(512 * block.expansion, gender_classes)
        self.race_fc = nn.Linear(512 * block.expansion, race_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

       #变成两个并行的全连接层
        gender = F.softmax(self.gen_fc(x), 1)
        race = F.softmax(self.race_fc(x), 1)

        return gender, race

def resnet18Owned(**kwargs):
    """Constructs a ResNet-18 model.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def test():
    net = resnet18Owned(gender_classes=2,race_classes=4)
    gender, race = net(Variable(torch.randn(2,3,224,224)))
    print('gender :', gender.size(),gender)
    print('race :', race.size(), race)

if __name__ == '__main__':
    test()

这里举的是一个比较简单的例子,仅是将一个全连接层的resnet18更改为了两个并行全连接层的resnet18,那么这个时候怎么使用之前训练的resnet18模型参数呢?

#coding:utf-8
    import torch
    from torchvision import models
    from torch import nn

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    #导入预训练模型,得到结构和参数
    pretrained_resnet18 = models.resnet18(pretrained=True)
    pretrained_resnet18_dict = pretrained_resnet18.state_dict()

    #调用自己设置的模型,也得到结构即相应参数
    model_conv = resnet18Owned(gender_classes=2, race_classes=3)
    model_conv_dict = model_conv.state_dict()
    
    #当模型中的某层是同时在两个模型中共有时才取出,即得到除了全连接层以外的所有层次对应的参数
    pretrained_resnet18_dict = {k:v for k,v in pretrained_resnet18_dict.items() if k in model_conv_dict}
    #然后用该新参数的值取更新你自己的模型的参数
    #这样,除了你修改的全连接层外,其他层次的参数就都是预训练模型的参数了
    model_conv_dict.update(pretrained_resnet18_dict)
    #然后将参数导入你的模型即可
    model_conv.load_state_dict(model_conv_dict)
model_conv.to(device)

后面了解到有一种更简单的方法:

就是当你设置好你自己的模型后,如果仅想使用预训练模型相同结构处的参数,即在加载的时候将参数strict设置为False即可。该参数值默认为True,表示预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度),否则无法加载,实现如下:

model_conv.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/resnet18-5c106cde.pth'), strict=False)

看看是否仅将strict设置为False即可

#coding:utf-8
import torch
from torchvision import models
from torch import nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 调用自己设置的模型,也得到结构即相应参数
model_conv = resnet18Owned(gender_classes=2, race_classes=3)
model_conv.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/resnet18-5c106cde.pth'), strict=False)

model_conv.to(device)

这个是官方给的预训练模型的下载地址 https://download.pytorch.org/models/resnet18-5c106cde.pth

⚠️如果你的torch版本是1.0.1及以下,那就使用torch.utils.model_zoo.load_url();如果是1.1.0及以上,可以使用新方法torch.hub.load_state_dict_from_url()

原文地址:https://www.cnblogs.com/wanghui-garcia/p/11232094.html