7. pytorch 现有网络模型的使用与修改和模型的保存与加载

  PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。他提供了大量的模型供我们所使用,如下图所示:


下面,我们选择其中一个网络进行使用,介绍如何使用、并修改 pytorch 本身为我们提供的现有网络。最后介绍一下模型的保存和修改。

pytorch 现有网络的使用与修改

  下面我们以 VGG(Very Deep Convolutional Networks for Large-Scale Image Recognition)的使用为例,进行介绍该网络。

     VGG 16 简介

  VGG16网络是14年牛津大学计算机视觉组和Google DeepMind公司研究员一起研发的深度网络模型。该网络一共有16个训练参数的网络,该网络的具体网络结构如下所示:


  不难看出,该网络主要用于对 224 x 224 的图像进行 1000 分类。下面我们查看 VGG 在 pytorch 上的官方文档。

     VGG 16 doc

  从帮助文档中,我们可以清楚的看到 pytorch 为我们提供了各种版本的 VGG,我们选择 VGG 16 进行查看。


     VGG16 的简单使用

  从 vgg 16的帮助文档可以得知,该模型训练的数据是 ImageNet,我们进入 torchvision.datasets 查看 ImageNet


但是该数据集实在是太大了,根本下不了,还是不搞了。建立一个该网络的模型查看参数: ```python import torch import torchvision import torch.nn as nn # import torchvision.models

vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
vgg_model_original = torchvision.models.vgg16(pretrained=False, progress=True)

print(vgg_model_original)
print(vgg_model_pretrained)

vgg_model_pretrained.add_module()



<p align="center">
	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113091424359-823205952.png" style="zoom:100%"/>
</p>
<br/>


<p align="center">
	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113091441147-673036772.png" style="zoom:100%"/>
</p>
<br/>

仔细查看这个网络的组成,你可以发现,组成该网络的一个个小 module 就是我们之前所介绍过的`Conv2d`, `ReLU`, `MaxPool2d`, `Linear`, `Dropout` 等等函数,


### &nbsp;&nbsp;&nbsp;&nbsp; VGG16 模型修改
&nbsp;&nbsp;经过上面的代码,我们可以较为轻松的看到 VGG16 神经网络的结构框架,那么我们如何修改别人已经写好的模型呢?
&nbsp;&nbsp;想要修改别人写好的模型,主要有一下这几种操作


<p align="center">
	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113092149496-1776250931.png" style="zoom:100%"/>
</p>
<br/>

选中模型,进行 add_module() 或者是直接对模型进行修改

<p align="center">
	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113094536565-291308585.png" style="zoom:100%"/>
</p>
<br/>


<p align="center">
	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113094752626-1549811574.png" style="zoom:100%"/>
</p>
<br/>

```python
import torch
import torchvision
import torch.nn as nn
# import torchvision.models

vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
vgg_model_original = torchvision.models.vgg16(pretrained=False, progress=True)

print(vgg_model_original)
print(vgg_model_pretrained)
# vgg_model_pretrained.add_module()
vgg_model_original.classifier.add_module('15', nn.Linear(in_features=1000, out_features=10, bias=True))
print(vgg_model_original)
vgg_model_original.classifier[7] = nn.Linear(in_features=1000, out_features=15, bias=True)
print(vgg_model_original)

根据上诉代码,我们就将 1000 分类问题的网络修改成了 10 分类或者是 15 分类问题的网络了。

模型的保存和加载

  当我们利用数据将模型训练好之后,往往需要保存模型。同时,当我们创建模型的时候,也可能需要加载我们之前已经训练好的参数,下面我来介绍一下操作方法。

     保留模型结构和模型参数

通过 torch.save() 和 torch.load() 进行保存模型和参数

import torch
import torchvision
import torch.nn as nn
# import torchvision.models

vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)

torch.save(vgg_model_pretrained, "../../models_param/vgg_model_pretrained.pth")

vgg_model_load = torch.load(f="../../models_param/vgg_model_pretrained.pth")
print(111)

打一个断点,查看保存模型和加载模型的参数情况



     仅保留模型参数

  同样是使用 save 和 load 参数,但是用法有所不同,他所保存的是一个模型参数,以字典dict 的形式保存

import torch
import torchvision
import torch.nn as nn
# import torchvision.models

vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)

torch.save(vgg_model_pretrained.state_dict(), "../../models_param/vgg_model_pretrained_method2.pth")

vgg_model_load_method2 = torchvision.models.vgg16()
vgg_model_load_method2.load_state_dict(torch.load("../../models_param/vgg_model_pretrained_method2.pth"))
print("this is a breakpoint!")

断点查看 save 和 load 模型的参数情况



一模一样,没有问题。

Author:luckylight(xyg)
Date:2021/11/13
原文地址:https://www.cnblogs.com/lucky-light/p/15547358.html