pytorch faster_rcnn

代码地址:https://github.com/jwyang/faster-rcnn.pytorch

1.fasterRCNN.train():这个不是让网络进行训练,而是让module in training mode,有些module在traing model和testing model下不同,比如bn

即self.training这个成员变量为true(这个成员变量属于nn.Module,fasterRCNN继承了这个成员变量),以下是train成员函数的源码

2.bn的train和test不同,train的时候应该是要学习参数的,test的时候关闭,pytorch的用法如下:

pytorch的batchnorm使用时需要小心,training和track_running_stats可以组合出三种behavior,很容易掉坑里(我刚发现我对track_running_stats的理解错了)。

  1. training=True, track_running_stats=True, 这是常用的training时期待的行为,running_mean 和running_var会跟踪不同batch数据的mean和variance。
  2. training=True, track_running_stats=False, 这时候batchnorm不跟踪跨batch数据的statistics了,而是用每个batch的mean和variance做normalization。
  3. training=False, track_running_stats=True, 这是我们期待的test时候的行为,即使用training阶段估计的running_mean 和running_var.
  4. training=False, track_running_stats=False,同2(!!!).
https://www.zhihu.com/question/282672547/answer/529154567李韶华的回答
3.class_agnostic == true就是所有类别回归同一个坐标,也就是一个框回归一个坐标
        == false是每个类别单独回归4个坐标
    if self.class_agnostic:
      self.RCNN_bbox_pred = nn.Linear(4096, 4)
    else:
      self.RCNN_bbox_pred = nn.Linear(4096, 4 * self.n_classes)
4.真正开始训练的代码不是fasterRCNN.train(),而是下面这段代码:
      rois, cls_prob, bbox_pred, 
      rpn_loss_cls, rpn_loss_box, 
      RCNN_loss_cls, RCNN_loss_bbox, 
      rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)

fasterRCNN是一个实例,应该是没办法进行调用的,但实际上这段代码执行的是forward函数。为什么?其实就是python的括号重载。fasterRCNN这个实例继承于nn.Module类,这个类定义了forward成员函数,nn.Module类使用了__call__进行了重载,让实例能够调用,并且调用的函数是forward函数,具体代码见下面的源码:

python中__call__函数的作用是使实例能够像函数一样被调用https://blog.csdn.net/Yaokai_AssultMaster/article/details/70256621,也称之为括号重载,即‘()’

    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            hook(self, input)
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}'"
                    "didn't return None".format(hook))
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

nn.Module定义了一个forward的成员函数,这个函数在基类中没有实现,而是在各个子类自己实现的,每个子类都必须实现forward函数:

    def forward(self, *input):
        r"""Defines the computation performed at every call.
        Should be overridden by all subclasses.
        .. note::
            Although the recipe for forward pass needs to be defined within
            this function, one should call the :class:`Module` instance afterwards
            instead of this since the former takes care of running the
            registered hooks while the latter silently ignores them.
        """
        raise NotImplementedError

子类调用forward函数不能直接用calss.forward(),而是用实例的函数调用,具体的原因好像是hook,这个在上面__call__函数中也看到调用forward使用了跟hook有关的input

 

原文地址:https://www.cnblogs.com/ymjyqsx/p/10084507.html