# PyTorch：使用 .apply 和 init.normal_() 模拟 net 网络的参数初始化过程
# (PyTorch: simulating a network's parameter initialization with .apply and init.normal_())

# 构建apply函数体
from torch.nn import init
import torch
class A:
    """Minimal stand-in for ``torch.nn.Module`` used to demonstrate ``apply``.

    Attributes:
        weight: float tensor ``[0.0, 0.0]``; the target of the init demo.
        bias: plain integer ``0`` (not a tensor, so ``torch.nn.init``
            functions cannot be applied to it).
    """

    def __init__(self):
        self.weight = torch.tensor([0.0, 0.0])
        self.bias = 0

    def apply(self, func):
        """Call ``func(self)`` and return ``self``.

        Mirrors the contract of ``torch.nn.Module.apply``, which also returns
        the module so calls can be chained (e.g. ``m = A().apply(fn)``).
        Unlike the real method there are no child modules to recurse into,
        so ``func`` is invoked exactly once, on this object.

        Args:
            func: Callable taking a single module-like argument.

        Returns:
            This instance.
        """
        func(self)
        return self


# Module-level instance of the toy class; init_weight(B) is called on it further down.
B = A()


def init_weight(B, mean=0.0, std=0.02):
    """Re-initialize every ``weight`` tensor reachable through ``B.apply``.

    ``B`` is anything exposing an ``apply(fn)`` method, e.g. a
    ``torch.nn.Module`` (whose ``apply`` visits every submodule and then the
    module itself) or the toy ``A`` class. Each visited object whose
    ``weight`` attribute is a tensor has it filled in place with samples from
    N(mean, std**2) via ``torch.nn.init.normal_``. Other attributes (such as
    ``bias``) are left untouched.

    Args:
        B: Object with an ``apply(fn)`` method.
        mean: Mean of the normal distribution. Defaults to 0.0.
        std: Standard deviation of the normal distribution. Defaults to 0.02.
    """

    def init_value(m):
        weight = getattr(m, "weight", None)
        # Some modules (e.g. nn.LayerNorm(elementwise_affine=False)) register
        # ``weight`` as None; a bare ``hasattr`` check would let that through
        # and ``init.normal_`` would then fail on a non-tensor.
        if isinstance(weight, torch.Tensor):
            init.normal_(weight, mean, std)

    B.apply(init_value)


# Fill B.weight in place with normally distributed samples, then display it.
# NOTE(review): no RNG seed is set here, so the printed values differ per run.
init_weight(B)
print(B.weight)
# 原文地址 (source article): https://www.cnblogs.com/my-love-is-python/p/12738488.html