Understanding return super().forward() in PyTorch

https://github.com/pytorch/pytorch/issues/42885

import torch
import torch.nn as nn


class Foo(nn.Conv1d):
    def forward(self, x):
        # Delegate to the parent's forward; Foo behaves exactly like nn.Conv1d.
        return super().forward(x)

How should `return super().forward(x)` be understood here?

It invokes the forward() method defined in the parent class (nn.Conv1d here) and returns its result.
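
That is, Foo above behaves exactly like an ordinary nn.Conv1d; the override merely delegates. A quick check (the constructor arguments here are illustrative):

foo = Foo(in_channels=3, out_channels=8, kernel_size=3)  # Foo inherits nn.Conv1d's __init__
x = torch.randn(2, 3, 10)                                # (batch, channels, length)
print(foo(x).shape)  # torch.Size([2, 8, 8]), same as an equivalent nn.Conv1d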

Reference: https://stackoverflow.com/questions/54752983/calling-supers-forward-method

import torch


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        return super(Child, self).forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `4`
module.register_forward_hook(lambda module, input, output: output + 1)
print(module(torch.tensor(1)))  # and it is 4 indeed
print(module.forward(torch.tensor(1)))  # still 3: calling forward() directly bypasses the hook
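
Why the difference? Forward hooks run inside nn.Module.__call__, which wraps forward; calling module.forward(...) directly skips that machinery entirely. Conceptually, it works like this simplified sketch (not PyTorch's actual implementation):

class MiniModule:
    # Simplified sketch of the hook mechanism in nn.Module.__call__;
    # the real implementation handles many more cases.
    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, *args):
        output = self.forward(*args)      # run the user-defined forward
        for hook in self._forward_hooks:  # then run each registered hook
            result = hook(self, args, output)
            if result is not None:        # a hook may replace the output
                output = result
        return output

The second snippet from the same Stack Overflow answer registers an extra hook inside forward itself: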


def increment_by_one(module, input, output):
    return output + 1


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        # Increment by `1` from Parent
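        # Note: super() proxies self here, so this registers the hook on this
        # very instance, and a new hook is appended on every forward call.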
        super().register_forward_hook(increment_by_one)
        return super().forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `5` in total
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1)))  # and it is 5 indeed
print(module.forward(torch.tensor(1)))  # still 3: no hooks run on a direct forward() call


A similar pattern appears in DenseNet, for example:

Defining DenseLayer (here it seems only the network layers are defined, while the forward behavior simply returns super().forward(x)):

class DenseLayer(nn.Sequential):
    def __init__(self, in_channels, growth_rate):
        super().__init__()
        self.add_module('norm', nn.BatchNorm1d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv1d(in_channels, growth_rate, kernel_size=3,
                                           stride=1, padding=1, bias=False))
        self.add_module('drop', nn.Dropout1d(p=0.2))

    def forward(self, x):
        return super().forward(x)

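In fact, since DenseLayer subclasses nn.Sequential, this forward override is redundant: nn.Sequential already defines a forward that applies each submodule in registration order, roughly like this (simplified sketch, not the exact source):

# Roughly what nn.Sequential.forward already does:
def forward(self, input):
    for module in self:       # submodules in the order they were added
        input = module(input)
    return input
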
Assembling a DenseBlock from DenseLayers:

class DenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([DenseLayer(in_channels + i*growth_rate, growth_rate) for i in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:
            out = layer(x)
            x = torch.cat([x, out], 1)  # 1 = channel axis

        return x
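
A quick shape check (all sizes here are illustrative): each DenseLayer contributes growth_rate new channels, so after n_layers layers the output has in_channels + n_layers * growth_rate channels, while kernel_size=3 with padding=1 preserves the length:

block = DenseBlock(in_channels=16, growth_rate=8, n_layers=4)
x = torch.randn(2, 16, 100)  # (batch, channels, length)
print(block(x).shape)        # torch.Size([2, 48, 100]); 48 = 16 + 4 * 8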


Original article: https://www.cnblogs.com/jiangkejie/p/14347153.html