『PyTorch』第十三弹_torch.nn.init参数初始化

初始化参数的方法

nn.Module模块对于参数进行了内置的较为合理的初始化方式,当我们使用nn.Parameter时,初始化就很重要,而且我们也可以指定代替内置初始化的方式对nn.Module模块进行补充。

除了之前的.data进行赋值,或者.data.初始化方式外,我们可以使用torch.nn.init进行初始化参数。

from torch.nn import init

linear = nn.Linear(3, 4)

t.manual_seed(1)

init.xavier_normal(linear.weight)
print(linear.weight.data)

import math

std = math.sqrt(2)/math.sqrt(7.)
linear.weight.data.normal_(0, std)

不同层类型定制化初始化

除此之外,我们可以使用如下的方式对不同的类型的层(卷积层、全连接层……)进行不同的赋值方式,

for name, params in net.named_parameters():
    if name.find('linear') != -1:
        params[0]  # weights
        params[1]  # bias
    elif name.find('conv') != -1:
        pass
    elif name.find('norm') != -1:
        pass

这里使用了str.find()方法,如下:

'asda'.find('a')
Out[3]:
0

即返回第一个find参数在原str中的位置索引。

原文地址:https://www.cnblogs.com/hellcat/p/8496956.html