refinedet一些pytorch,python语法学习

官方链接:
https://github.com/luuuyi/RefineDet.PyTorch

product

for k, f in enumerate([10, 8, 5, 3]):
    print("f:=====",f)
    for i, j in product(range(f), repeat=2):
        print(i,j)
f:===== 3
0 0
0 1
0 2
1 0
1 1
1 2
2 0
2 1
2 2
f:===== 5
0 0
0 1
0 2
0 3
0 4
1 0
1 1
1 2
1 3
1 4
2 0
2 1
2 2
2 3
2 4
3 0
3 1
3 2
3 3
3 4
4 0
4 1
4 2
4 3
4 4

解析voc xml

根据代码,写的测试样例:
例如xml里面内容如下:voc格式

<annotation>
   <folder>VOC2007</folder>
   <filename>seat_190530_623.jpg</filename>
   <source>
       <database>The VOC2007 Database</database>
       <annotation>PASCAL VOC2007</annotation>
       <image>flickr</image>
       <flickrid>329145082</flickrid>
   </source>
   <owner>&gt;
       <flickrid>hiromori2</flickrid>
       <name>Hiroyuki Mori</name>
   </owner>&gt;
   <size>
       <width>1024</width>
       <height>768</height>
       <depth>3</depth>
   </size>
   <segmented>0</segmented>
   <object>
       <name>zuoyianquandai</name>
       <pose>Unspecified</pose>
       <truncated>0</truncated>
       <difficult>0</difficult>
       <bndbox>
           <xmin>476</xmin>
           <ymin>276</ymin>
           <xmax>562</xmax>
           <ymax>372</ymax>
       </bndbox>
   </object>
   <object>
       <name>zuoyianquandai</name>
       <pose>Unspecified</pose>
       <truncated>0</truncated>
       <difficult>0</difficult>
       <bndbox>
           <xmin>440</xmin>
           <ymin>271</ymin>
           <xmax>506</xmax>
           <ymax>372</ymax>
       </bndbox>
   </object>
   <object>
       <name>zuoyianquandai</name>
       <pose>Unspecified</pose>
       <truncated>0</truncated>
       <difficult>0</difficult>
       <bndbox>
           <xmin>622</xmin>
           <ymin>616</ymin>
           <xmax>726</xmax>
           <ymax>717</ymax>
       </bndbox>
   </object>
   <object>
       <name>zuoyianquandai</name>
       <pose>Unspecified</pose>
       <truncated>0</truncated>
       <difficult>0</difficult>
       <bndbox>
           <xmin>348</xmin>
           <ymin>598</ymin>
           <xmax>456</xmax>
           <ymax>720</ymax>
       </bndbox>
   </object>
   <object>
       <name>seat</name>
       <pose>Unspecified</pose>
       <truncated>0</truncated>
       <difficult>0</difficult>
       <bndbox>
           <xmin>270</xmin>
           <ymin>15</ymin>
           <xmax>825</xmax>
           <ymax>367</ymax>
       </bndbox>
   </object>
   <object>
       <name>seat</name>
       <pose>Unspecified</pose>
       <truncated>0</truncated>
       <difficult>0</difficult>
       <bndbox>
           <xmin>72</xmin>
           <ymin>66</ymin>
           <xmax>492</xmax>
           <ymax>683</ymax>
       </bndbox>
   </object>
   <object>
       <name>seat</name>
       <pose>Unspecified</pose>
       <truncated>0</truncated>
       <difficult>0</difficult>
       <bndbox>
           <xmin>612</xmin>
           <ymin>0</ymin>
           <xmax>1024</xmax>
           <ymax>704</ymax>
       </bndbox>
   </object>
</annotation>

代码如下:

import os
import xml.etree.ElementTree as ET


root_dir = "/data_2/project_2021/refinedet/pytorch_refinedet/data/VOCdevkit/VOC2007/Annotations/"

list_xml = os.listdir(root_dir)
for cnt, name in enumerate(list_xml):
    print(cnt,name)
    path_xml = root_dir + name
    target = ET.parse(path_xml).getroot()

    res = []
    for obj in target.iter('object'):
        difficult = int(obj.find('difficult').text) == 1
        if difficult:
            continue
        name = obj.find('name').text.lower().strip()
        bbox = obj.find('bndbox')

        pts = ['xmin', 'ymin', 'xmax', 'ymax']
        bndbox = []
        for i, pt in enumerate(pts):
            cur_pt = int(float((bbox.find(pt).text)) + 0.5) - 1
            # scale height or width
            # cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
            bndbox.append(cur_pt)
        # label_idx = self.class_to_ind[name]
        # bndbox.append(label_idx)

        # label_idx = self.class_to_ind[name]
        bndbox.append(name)
        res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]

    #a = 0

res里面值如下:
<class 'list'>: [[475, 275, 561, 371, 'zuoyianquandai'], [439, 270, 505, 371, 'zuoyianquandai'], [621, 615, 725, 716, 'zuoyianquandai'], [347, 597, 455, 719, 'zuoyianquandai'], [269, 14, 824, 366, 'seat'], [71, 65, 491, 682, 'seat'], [611, -1, 1023, 703, 'seat']]

np.hstack() np.vstack() target = np.hstack((boxes, np.expand_dims(labels, axis=1)))

np.vstack():在竖直方向上堆叠
np.hstack():在水平方向上平铺

import numpy as np

arr1=np.array([1,2,3])
arr2=np.array([4,5,6])
print(np.vstack)
print (np.vstack((arr1,arr2)))
print(np.hstack)
print (np.hstack((arr1,arr2)))

打印如下:
<function vstack at 0x7ff6e333d0e0>
[[1 2 3]
[4 5 6]]
<function hstack at 0x7ff6e333d290>
[1 2 3 4 5 6]
Process finished with exit code 0

target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
boxes[5,4]
label:[5] -- >np.expand_dims(labels, axis=1) -->>>>>[5,1]
==>target[5,5]

a[::-1]

a = [1,2,3,4,5]

b = a[::-1]

print(a)
print(b)
#[1, 2, 3, 4, 5]
#[5, 4, 3, 2, 1]

zip 例如:for (x, l, c) in zip(sources, self.arm_loc, self.arm_conf):

sources, self.arm_loc, self.arm_conf都是长度相同的列表,sources是数据,arm_loc和arm_conf是conv2d之类的操作方法

for (x, l, c) in zip(sources, self.arm_loc, self.arm_conf):
    arm_loc.append(l(x).permute(0, 2, 3, 1).contiguous())
    arm_conf.append(c(x).permute(0, 2, 3, 1).contiguous())

torch.max() | tensor([[6, 3, 0, ..., 6, 0, 2]]) | best_truth_overlap, best_truth_idx = overlap.max(0, keepdim=True)

torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
按维度dim 返回最大值,并且返回索引。
torch.max(a,0)返回每一列中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引)。返回的最大值和索引各是一个tensor,一起构成元组(Tensor, LongTensor)
torch.max(a,1)返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)

import torch
a = torch.rand(3,5)
print(a)
print("========================")
print("a.max(0)")
print(a.max(0))
print("========================")
print("a.max(1)")
print(a.max(1))
Connected to pydev debugger (build 182.4505.26)
tensor([[0.2695, 0.3127, 0.5122, 0.4659, 0.8935],
        [0.8419, 0.1534, 0.4232, 0.7792, 0.4795],
        [0.9919, 0.9686, 0.1972, 0.2406, 0.4112]])
========================
a.max(0)
torch.return_types.max(
values=tensor([0.9919, 0.9686, 0.5122, 0.7792, 0.8935]),
indices=tensor([2, 2, 0, 1, 0]))
========================
a.max(1)
torch.return_types.max(
values=tensor([0.8935, 0.8419, 0.9919]),
indices=tensor([4, 0, 0]))

这里我有点儿迷糊,max(0),max(1)分的不清,0代表列?1代表行?
原本shape[3,5]的tensor经过max(0)就得到[1,5]
在refinedet里面,下面的代码:

overlap = torch.rand(7,6375)
best_prior_overlap, best_prior_idx = overlap.max(1, keepdim=True)
best_truth_overlap, best_truth_idx = overlap.max(0, keepdim=True)

overlap的含义是7个groundtruth与6375个prior的交并比,所以best_prior_overlap的维度知道是什么样子的吗?代表的含义又是啥?
best_prior_overlap的shape[7,1]
best_prior_idx的shape[7,1],取值范围是[0,6375)
每个groundtruth与哪个prior的iou最大,最大的prior是多少。

best_truth_overlap的shape是[1,6375]
best_truth_idx的shape是[1,6375],取值范围是[0,7)
每个prior与哪个groundtruth的iou最大

index_fill_(dim,index,val) |||| best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior

x = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = torch.LongTensor([0, 2])
x.index_fill_(1, index, 8)#([[8., 2., 8.],
                           # [8., 5., 8.],
                           # [8., 8., 8.]])

refinedet代码中:

best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # ensure best prior
aaa = best_truth_overlap[best_prior_idx[0].type(torch.LongTensor)] ##==2?   yes!

这个就有点儿意思了,首先best_truth_overlap里面存放的都是交并比0到1的值,best_truth_overlap是竖直的[6375]找的最大,即每个prior与groundtruth找的最大值。
best_prior_idx的shape[7,1],取值范围是[0,6375)。best_prior_idx是横向找到的最大值的位置。
代码best_truth_overlap.index_fill_(0, best_prior_idx, 2) 意思就是在best_prior_idx的位置上把best_truth_overlap对应位置赋值2。感觉就是best_truth_overlap[best_prior_idx]=2类似的操作。
总的来说好像就是代码注释的这句# ensure best prior

好记性不如烂键盘---点滴、积累、进步!
原文地址:https://www.cnblogs.com/yanghailin/p/14378830.html