Example: how backpropagating multiple losses affects parameters in PyTorch

I wrote the following piece of code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.fc1 = nn.Linear(5, 4)  # shared layer
        self.fc2 = nn.Linear(4, 3)  # branch 1 -> out1
        self.fc3 = nn.Linear(4, 3)  # branch 2 -> out2

    def forward(self, x):
        mid = self.fc1(x)
        out1 = self.fc2(mid)
        out2 = self.fc3(mid)
        return out1, out2


x = torch.randn((3, 5))
y = torch.randint(3, (3,), dtype=torch.int64)  # class labels shared by both heads
model = Test()
model.train()
optim = torch.optim.RMSprop(model.parameters(), lr=0.001)

# branch-layer weights before training
print(model.fc2.weight)
print(model.fc3.weight)
for i in range(5):
    out1, out2 = model(x)
    loss1 = F.cross_entropy(out1, y)
    loss2 = F.cross_entropy(out2, y)
    loss = loss1 + loss2  # combined loss; gradients from both terms accumulate
    optim.zero_grad()
    loss.backward()
    optim.step()
print("-------------after-----------")
print(model.fc2.weight)
print(model.fc3.weight)

At the line loss.backward(), swap in loss1.backward() and loss2.backward() in turn, and observe how the parameters of the fc2 and fc3 layers change.
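
To make that comparison concrete, here is a minimal sketch (my own addition, reusing the Test model, x, and y from above) that snapshots every parameter, trains one step on loss1 alone, and reports which layers actually changed:

model = Test()
optim = torch.optim.RMSprop(model.parameters(), lr=0.001)
# snapshot all parameters before the update
before = {name: p.detach().clone() for name, p in model.named_parameters()}

out1, out2 = model(x)
loss1 = F.cross_entropy(out1, y)
optim.zero_grad()
loss1.backward()  # only loss1's gradients flow back
optim.step()

for name, p in model.named_parameters():
    status = "unchanged" if torch.equal(before[name], p.detach()) else "changed"
    print(name, status)
# expected: fc1.* and fc2.* change, fc3.* stays put
# (fc3 receives no gradient, so the optimizer skips it)

Swapping loss1 for loss2 flips the result: fc2 stays put and fc3 changes.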

The conclusion: loss1 only affects fc2's parameters and loss2 only affects fc3's parameters; neither touches the other branch. Strictly speaking, both losses also update the shared fc1 layer, since fc1 lies on the backward path of each loss.
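
The per-layer gradients show why. A minimal sketch (again my own addition, same model and data as above) that backpropagates loss1 and inspects .grad directly:

model = Test()
out1, out2 = model(x)
loss1 = F.cross_entropy(out1, y)
loss1.backward()
print(model.fc1.weight.grad is None)  # False: fc1 is on loss1's backward path
print(model.fc2.weight.grad is None)  # False: fc2 produced out1
print(model.fc3.weight.grad is None)  # True: fc3 took no part in computing loss1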

(A rough analysis, offered in the hope of drawing out better ones.)
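
One further sanity check (my own addition, not part of the original experiment): backpropagating the summed loss is equivalent to backpropagating the two losses one after another, because backward() accumulates into .grad. retain_graph=True keeps the shared graph alive for the second call:

model = Test()

out1, out2 = model(x)
(F.cross_entropy(out1, y) + F.cross_entropy(out2, y)).backward()
g_sum = model.fc1.weight.grad.clone()  # gradient from the combined loss

model.zero_grad()
out1, out2 = model(x)
F.cross_entropy(out1, y).backward(retain_graph=True)
F.cross_entropy(out2, y).backward()  # accumulates on top of loss1's gradients
print(torch.allclose(g_sum, model.fc1.weight.grad))  # True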

Original post: https://www.cnblogs.com/peony-jing/p/14462289.html