Pytorch模型定义的三要

首先,必须继承nn.Module这个类,要让Pytorch知道这个类是一个Module。

其次,在_init_(self)中设置需要的组件,比如(Conv,Pooling,Linear,BatchNorm等)

最后,在forward(self,x)中用定义好的组件进行组装,就像搭积木,把网络结构搭建出来,这样一个网络模型就定义好了!!!

原文地址:https://www.cnblogs.com/Terrypython/p/11543900.html