pytorch学习-1

读文献1. 的faster rcnn的rpn loss计算部分,遇到问题,比如某些函数,找的资料整理:

1、tensor.view(-1)

把原先tensor中的数据按照行优先的顺序排成一个一维的数据,然后按照参数组合成其他维度的tensor。参数只能有一个(-1)用做推理。所以view(-1)的输出是1*?。如果要一列数据,有permute函数,将tensor的维度换位。

2、unsqueeze()函数

增加一个维度,squeeze()函数将指定的一维去掉,注意这个去掉的必须是一维(不损失数据,只是降维)

3、ne()函数

torch.ne(input, other, out=Tensor) -> Tensor:

逐元素比较inputother, 即是否input != other。第二个参数可以为一个数或与第一个参数相同形状和类型的张量。
返回值:一个torch.ByteTensor张量,包含了每个位置的比较结果(如果tensor != other 为True,返回1)。

4、contiguous()
返回一个内存连续的有相同数据的tensor,如果原tensor内存连续,则返回原tensor;

   contiguous一般与transpose,permute,view搭配使用:使用transpose或permute进行维度变换后,调用contiguous,然后方可使用view对维度进行变形(如:tensor_var.contiguous().view() )

rpn loss里是:rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(-1, 2)

contiguous:view只能用在contiguous的variable上。如果在view之前用了transpose, permute等,需要用contiguous()来返回一个contiguous copy。

5、torch.index_select()

选择indices的数据

参数说明:index_select(x, 1, indices)

1代表维度1,即列,indices是筛选的索引序号

6、torch.nonzero( )

返回一个包含输入 input 中非零元素索引的张量.输出张量中的每行包含 input 中非零元素的索引。

def build_loss(self, rpn_cls_score_reshape, rpn_bbox_pred, rpn_data):
  # classification loss
  rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(-1, 2)
  rpn_label = rpn_data[0].view(-1)

  rpn_keep = Variable(rpn_label.data.ne(-1).nonzero().squeeze()).cuda()
  rpn_cls_score = torch.index_select(rpn_cls_score, 0, rpn_keep)
  rpn_label = torch.index_select(rpn_label, 0, rpn_keep)

  fg_cnt = torch.sum(rpn_label.data.ne(0))

  rpn_cross_entropy = F.cross_entropy(rpn_cls_score, rpn_label)

  # box loss
  rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]
  rpn_bbox_targets = torch.mul(rpn_bbox_targets, rpn_bbox_inside_weights)
  rpn_bbox_pred = torch.mul(rpn_bbox_pred, rpn_bbox_inside_weights)

  rpn_loss_box = F.smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, size_average=False) / (fg_cnt + 1e-4)

  return rpn_cross_entropy, rpn_loss_box
————————————————

参考资料:

https://blog.csdn.net/admintan/article/details/91366551 

同样解读:https://www.cnblogs.com/kerwins-AC/p/9728731.html

https://www.cnblogs.com/wind-chaser/p/11359948.html代码备注写的很好

view和permute

https://blog.csdn.net/york1996/article/details/81949843

https://blog.csdn.net/zkq_1986/article/details/100319146?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase

ne()函数,其他torch函数

https://www.jianshu.com/p/d678c5e44a6b

https://pytorch.org/docs/master/generated/torch.ne.html?highlight=ne#torch.ne

contiguous( )

https://zhuanlan.zhihu.com/p/64376950

torch.nonzero( )

https://blog.csdn.net/monchin/article/details/79750216

 

 

原文地址:https://www.cnblogs.com/haiyanli/p/12940588.html