Training Detectron on your own dataset (with notes on the YAML config file)

1. Create the following directory structure

2. Convert the XML files in train_annotations to a JSON file

# coding=utf-8
import xml.etree.ElementTree as ET
import os
import json
import collections

coco = dict()
coco['images'] = []
coco['type'] = 'instances'
coco['annotations'] = []
coco['categories'] = []

#category_set = dict()
image_set = set()
image_id = 2019100001  # train:2018xxx; val:2019xxx; test:2020xxx
category_item_id = 1
annotation_id = 1
category_set = ['people', 'bicycle', 'electric bicycle']  # list the classes to detect here
'''
def addCatItem(name):
    global category_item_id
    category_item = dict()
    category_item['supercategory'] = 'none'
    category_item_id += 1
    category_item['id'] = category_item_id
    category_item['name'] = name
    coco['categories'].append(category_item)
    category_set[name] = category_item_id
    return category_item_id
'''


def addCatItem(name):
    '''
    Append one entry to the 'categories' section of the COCO-style JSON.
    '''
    global category_item_id
    category_item = collections.OrderedDict()
    category_item['supercategory'] = 'none'
    category_item['id'] = category_item_id
    category_item['name'] = name
    coco['categories'].append(category_item)
    category_item_id += 1


def addImgItem(file_name, size):
    global image_id
    if file_name is None:
        raise Exception('Could not find filename tag in xml file.')
    if size['width'] is None:
        raise Exception('Could not find width tag in xml file.')
    if size['height'] is None:
        raise Exception('Could not find height tag in xml file.')
    # image_item = dict()    # use collections.OrderedDict() instead to keep the keys in a fixed order
    image_item = collections.OrderedDict()
    jpg_name = os.path.splitext(file_name)[0] + '.jpg'
    image_item['file_name'] = jpg_name
    image_item['width'] = size['width']
    image_item['height'] = size['height']
    image_item['id'] = image_id
    coco['images'].append(image_item)
    image_set.add(jpg_name)
    image_id = image_id + 1
    return image_id


def addAnnoItem(object_name, image_id, category_id, bbox):
    global annotation_id
    #annotation_item = dict()
    annotation_item = collections.OrderedDict()
    annotation_item['segmentation'] = []
    seg = []
    # bbox[] is x,y,w,h
    # left_top
    seg.append(bbox[0])
    seg.append(bbox[1])
    # left_bottom
    seg.append(bbox[0])
    seg.append(bbox[1] + bbox[3])
    # right_bottom
    seg.append(bbox[0] + bbox[2])
    seg.append(bbox[1] + bbox[3])
    # right_top
    seg.append(bbox[0] + bbox[2])
    seg.append(bbox[1])
    annotation_item['segmentation'].append(seg)
    annotation_item['area'] = bbox[2] * bbox[3]
    annotation_item['iscrowd'] = 0
    annotation_item['image_id'] = image_id
    annotation_item['bbox'] = bbox
    annotation_item['category_id'] = category_id
    annotation_item['id'] = annotation_id
    annotation_item['ignore'] = 0
    annotation_id += 1
    coco['annotations'].append(annotation_item)


def parseXmlFiles(xml_path):
    xmllist = os.listdir(xml_path)
    xmllist.sort()
    for f in xmllist:
        if not f.endswith('.xml'):
            continue

        bndbox = dict()
        size = dict()
        current_image_id = None
        current_category_id = None
        file_name = None
        size['width'] = None
        size['height'] = None
        size['depth'] = None

        xml_file = os.path.join(xml_path, f)
        print(xml_file)

        tree = ET.parse(xml_file)
        root = tree.getroot()  # get the root element

        if root.tag != 'annotation':  # check the root tag
            raise Exception(
                'pascal voc xml root element should be annotation, rather than {}'.format(root.tag))

        # elem is <folder>, <filename>, <size>, <object>
        for elem in root:
            current_parent = elem.tag
            current_sub = None
            object_name = None

            # elem.tag, elem.attrib,elem.text
            if elem.tag == 'folder':
                continue

            if elem.tag == 'filename':
                file_name = elem.text
                if file_name in image_set:
                    raise Exception('file_name duplicated')

            # add img item only after parse <size> tag
            elif current_image_id is None and file_name is not None and size['width'] is not None:
                if file_name not in image_set:
                    current_image_id = addImgItem(file_name, size)  # image info
                    print('add image with {} and {}'.format(file_name, size))
                else:
                    raise Exception('duplicated image: {}'.format(file_name))
                    # subelem is <width>, <height>, <depth>, <name>, <bndbox>
            for subelem in elem:
                bndbox['xmin'] = None
                bndbox['xmax'] = None
                bndbox['ymin'] = None
                bndbox['ymax'] = None

                current_sub = subelem.tag
                if current_parent == 'object' and subelem.tag == 'name':
                    object_name = subelem.text
                    # if object_name not in category_set:
                    #    current_category_id = addCatItem(object_name)
                    # else:
                    #current_category_id = category_set[object_name]
                    current_category_id = category_set.index(
                        object_name) + 1  # list.index() is 0-based, but category ids in the JSON start at 1, hence +1
                elif current_parent == 'size':
                    if size[subelem.tag] is not None:
                        raise Exception('xml structure broken at size tag.')
                    size[subelem.tag] = int(subelem.text)

                # option is <xmin>, <ymin>, <xmax>, <ymax>, when subelem is <bndbox>
                for option in subelem:
                    if current_sub == 'bndbox':
                        if bndbox[option.tag] is not None:
                            raise Exception(
                                'xml structure corrupted at bndbox tag.')
                        bndbox[option.tag] = int(option.text)

                # only after parse the <object> tag
                if bndbox['xmin'] is not None:
                    if object_name is None:
                        raise Exception('xml structure broken at bndbox tag')
                    if current_image_id is None:
                        raise Exception('xml structure broken at bndbox tag')
                    if current_category_id is None:
                        raise Exception('xml structure broken at bndbox tag')
                    bbox = []
                    # x
                    bbox.append(bndbox['xmin'])
                    # y
                    bbox.append(bndbox['ymin'])
                    # w
                    bbox.append(bndbox['xmax'] - bndbox['xmin'])
                    # h
                    bbox.append(bndbox['ymax'] - bndbox['ymin'])
                    print(
                        'add annotation with {},{},{},{}'.format(object_name, current_image_id - 1, current_category_id, bbox))
                    addAnnoItem(object_name, current_image_id -
                                1, current_category_id, bbox)
    # the categories section
    for categoryname in category_set:
        addCatItem(categoryname)


if __name__ == '__main__':
    xml_path = 'dataset/train_anatations'
    json_file = 'VOC2007/anatations/voc_2007_train.json'
    # xml_path = 'dataset/test_anatation'
    # json_file = 'dataset/test.json'
    parseXmlFiles(xml_path)
    # use a context manager so the output file is flushed and closed
    with open(json_file, 'w') as fp:
        json.dump(coco, fp)
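
Before training, it can help to sanity-check the generated JSON. The following is a minimal sketch, not part of the original script: `check_coco` is a hypothetical helper that only looks at the fields the converter above writes (`images`, `categories`, `annotations`, `bbox`), and the checks chosen are assumptions about what is worth verifying.

```python
def check_coco(coco):
    """Return a list of human-readable problems found in a COCO-style dict."""
    problems = []
    image_ids = {img['id'] for img in coco['images']}
    category_ids = {cat['id'] for cat in coco['categories']}

    # every image id should be unique
    if len(image_ids) != len(coco['images']):
        problems.append('duplicate image ids')

    for ann in coco['annotations']:
        # each annotation must point at an image and a category that exist
        if ann['image_id'] not in image_ids:
            problems.append('annotation {} refers to unknown image {}'.format(
                ann['id'], ann['image_id']))
        if ann['category_id'] not in category_ids:
            problems.append('annotation {} refers to unknown category {}'.format(
                ann['id'], ann['category_id']))
        # bbox is [x, y, w, h]; width and height must be positive
        _, _, w, h = ann['bbox']
        if w <= 0 or h <= 0:
            problems.append('annotation {} has an empty bbox'.format(ann['id']))
    return problems

# Usage (path taken from the script above):
#   import json
#   with open('VOC2007/anatations/voc_2007_train.json') as fp:
#       print(check_coco(json.load(fp)))
```

An empty list means none of these particular checks failed; it does not prove the annotations are correct.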

3. Generate the txt file

# -*- coding: utf-8 -*-
# @Author: zhiwei
# @Date:   2019-01-31 09:38:58
# @Last Modified by:   zhiwei
# @Last Modified time: 2019-01-31 10:21:58
#
import os
import re

fp1_path = "JPEGImages"
f = open('VOCdevkit2020/VOC2020/ImageSets/Main/train.txt', 'w')
s = ""
i = 0
for filename in os.listdir(fp1_path):
    # bicycle_train.txt
    # if re.match('^bicycle.+', filename) != None:
    #     s1 = os.path.splitext(os.path.basename(filename))[0]
    #     s = s + s1 + ' ' + str(1) + "\n"
    # else:
    #     s1 = os.path.splitext(os.path.basename(filename))[0]
    #     s = s + s1 + ' ' + str(0) + "\n"

    # electric bicycle_train.txt
    # if "electric_bicycle" in filename:
    #     s1 = os.path.splitext(os.path.basename(filename))[0]
    #     s = s + s1 + ' ' + str(1) + "\n"
    # else:
    #     s1 = os.path.splitext(os.path.basename(filename))[0]
    #     s = s + s1 + ' ' + str(0) + "\n"

    # train.txt
    s1 = os.path.splitext(os.path.basename(filename))[0]
    s = s + s1 + "\n"

print(s)

f.write(s)
f.close()  # close the file so everything is flushed to disk
print(len([name for name in os.listdir(fp1_path)]))
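
The same step can be written more compactly. This is only a sketch under a few assumptions that are not in the original: the list of image extensions, the sorted output order, and the helper names `image_stems` and `write_image_set` are all illustrative.

```python
import os

# Assumption: only files with these extensions count as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def image_stems(filenames):
    """Return the sorted file names without extension, skipping non-image files."""
    stems = [os.path.splitext(name)[0]
             for name in filenames
             if name.lower().endswith(IMAGE_EXTENSIONS)]
    return sorted(stems)


def write_image_set(image_dir, out_path):
    """Write one image name per line to out_path and return how many were written."""
    stems = image_stems(os.listdir(image_dir))
    with open(out_path, 'w') as fp:
        fp.writelines(stem + '\n' for stem in stems)
    return len(stems)

# Usage (paths taken from the script above):
#   write_image_set('JPEGImages', 'VOCdevkit2020/VOC2020/ImageSets/Main/train.txt')
```

Filtering by extension keeps stray files such as editor backups out of train.txt, which the loop above would otherwise include.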

4. Set up the YAML file

In the Detectron/configs/12_2017_baselines directory, copy the file retinanet_R-50-FPN_1x.yaml into the Detectron/configs/my directory and rename it to retinanet_R-50-FPN_1x1.0.yaml

5. Edit the YAML file (retinanet_R-50-FPN_1x1.0.yaml)

MODEL:
  TYPE: retinanet
  CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
  NUM_CLASSES: 4
NUM_GPUS: 1
SOLVER:
  WEIGHT_DECAY: 0.0001
  LR_POLICY: steps_with_decay
  BASE_LR: 0.001
  GAMMA: 0.1
  MAX_ITER: 1000
  STEPS: [0, 600, 800]
FPN:
  FPN_ON: True
  MULTILEVEL_RPN: True
  RPN_MAX_LEVEL: 7
  RPN_MIN_LEVEL: 3
  COARSEST_STRIDE: 128
  EXTRA_CONV_LEVELS: True
RETINANET:
  RETINANET_ON: True
  NUM_CONVS: 4
  ASPECT_RATIOS: (1.0, 2.0, 0.5)
  SCALES_PER_OCTAVE: 3
  ANCHOR_SCALE: 4
  LOSS_GAMMA: 2.0
  LOSS_ALPHA: 0.25
TRAIN:
  WEIGHTS: /home/Desktop/test/trainMOdel/R-50.pkl
  DATASETS: ('voc_2007_train',)
  SCALES: (800,)
  MAX_SIZE: 1333
  RPN_STRADDLE_THRESH: -1  # default 0
TEST:
  DATASETS: ('coco_2014_minival',)
  SCALE: 800
  MAX_SIZE: 1333
  NMS: 0.5
  RPN_PRE_NMS_TOP_N: 10000  # Per FPN level
  RPN_POST_NMS_TOP_N: 2000
OUTPUT_DIR: .

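Only a few of the parameters are explained here.

1) cfg is the configuration file; all config files live under the configs directory.
In the MODEL section: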

MODEL:
    TYPE: generalized_rcnn
    CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
    NUM_CLASSES: 81
    FASTER_RCNN: True

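Beginners should pay particular attention to NUM_CLASSES: for a custom dataset its value is the number of classes + 1 (the extra one is the background class), so for COCO it is 80 + 1 = 81.
For a Mask network, the MODEL section should also include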

MASK_ON: True
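
To make the NUM_CLASSES rule concrete, here is a tiny sketch using the category list from the conversion script in step 2. The helper name `num_classes` is hypothetical.

```python
# The classes from the XML-to-JSON script in step 2.
category_set = ['people', 'bicycle', 'electric bicycle']


def num_classes(categories):
    """Foreground classes plus one extra slot for the background class."""
    return len(categories) + 1


print(num_classes(category_set))  # 4, the NUM_CLASSES value in the RetinaNet config of step 5
```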

Setting the number of GPUs:

NUM_GPUS: 1

SOLVER settings:

SOLVER:
    WEIGHT_DECAY: 0.0001
    LR_POLICY: steps_with_decay
    BASE_LR: 0.0025
    GAMMA: 0.1
    MAX_ITER: 60000
    STEPS: [0, 30000, 40000]

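First, the number of training iterations: for a small dataset a few thousand iterations are enough, while for a large dataset such as COCO tens of thousands are needed. The default for a single GPU is 60000. Note that with multiple GPUs, MAX_ITER is inversely proportional to the number of GPUs. The other parameter worth discussing is BASE_LR. The initial learning rate matters a great deal for training: if it is too large the network struggles to converge to a minimum, and if it is too small convergence is too slow. As mentioned in an earlier post, a value between 1e-3 and 1e-4 is usually safe. The value 0.0025 used here is presumably a good setting that the authors found through repeated experiments with this network and data. With multiple GPUs, BASE_LR is proportional to the number of GPUs.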
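
The two proportionality rules above can be sketched as a small helper. `scale_schedule` is a hypothetical name, and scaling STEPS together with MAX_ITER is an assumption: the text only mentions MAX_ITER and BASE_LR, but STEPS are iteration counts, so they are divided here as well.

```python
def scale_schedule(base_lr, max_iter, steps, num_gpus):
    """Scale a single-GPU schedule to num_gpus GPUs.

    BASE_LR grows with the GPU count; MAX_ITER shrinks with it.
    Assumption: STEPS shrink by the same factor as MAX_ITER.
    """
    return {
        'BASE_LR': base_lr * num_gpus,
        'MAX_ITER': max_iter // num_gpus,
        'STEPS': [step // num_gpus for step in steps],
    }


# The single-GPU SOLVER values shown above, scaled to 2 GPUs.
print(scale_schedule(0.0025, 60000, [0, 30000, 40000], 2))
```

Treat the result as a starting point; the learning rate may still need tuning for a small custom dataset.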

FPN and FAST_RCNN settings:

FPN:
    FPN_ON: True
    MULTILEVEL_ROIS: True
    MULTILEVEL_RPN: True
FAST_RCNN:
    ROI_BOX_HEAD: fast_rcnn_heads.add_roi_2mlp_head
    ROI_XFORM_METHOD: RoIAlign
    ROI_XFORM_RESOLUTION: 7
    ROI_XFORM_SAMPLING_RATIO: 2

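I once tried changing the FPN settings, but the MODEL section then has to be changed along with them. Normally neither of these two sections needs to be changed for a custom dataset.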

TRAIN settings:

TRAIN:
    WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/MSRA/R-50.pkl
    DATASETS: ('coco_2014_train',)
    SCALES: (500,)
    MAX_SIZE: 833
    BATCH_SIZE_PER_IM: 256
    RPN_PRE_NMS_TOP_N: 2000  # Per FPN level

Note that any custom dataset named in DATASETS must be registered in dataset_catalog.py; how to prepare a dataset is also described in detail on GitHub.

TEST settings:

TEST:
    DATASETS: ('coco_2014_minival',)
    SCALE: 500
    MAX_SIZE: 833
    NMS: 0.5
    RPN_PRE_NMS_TOP_N: 1000  # Per FPN level
    RPN_POST_NMS_TOP_N: 1000

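The parameters to watch here include SCALE and MAX_SIZE. From other people's experience, one way to improve accuracy at inference time is to use a higher input resolution, in which case these are the two parameters to change.
In addition, the NMS setting can reduce duplicate boxes in scenes that are not crowded.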
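
To see how SCALE and MAX_SIZE interact, the sketch below computes the resize factor for an image. It reflects my understanding of the usual convention (resize the shorter side to SCALE unless that would push the longer side past MAX_SIZE); the function name `resize_scale` is hypothetical and Detectron's own code may round slightly differently.

```python
def resize_scale(height, width, target_size, max_size):
    """Return the factor by which an image of height x width would be resized.

    Assumption: the shorter side is scaled to target_size (SCALE), but the
    factor is reduced if the longer side would exceed max_size (MAX_SIZE).
    """
    short_side = min(height, width)
    long_side = max(height, width)
    scale = float(target_size) / short_side
    if scale * long_side > max_size:
        scale = float(max_size) / long_side
    return scale


# A 480x640 image with SCALE 800 and MAX_SIZE 1333: the short side becomes 800.
print(resize_scale(480, 640, 800, 1333))
# A 500x1000 image: the long side would exceed 1333, so MAX_SIZE limits the factor.
print(resize_scale(500, 1000, 800, 1333))
```

Raising SCALE without raising MAX_SIZE therefore has little effect on elongated images, which is why the two values are usually changed together.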

6. Start training

cd into the $Detectron/tools/ directory and run

python train_net.py --cfg ../configs/my2020/retinanet_R-50-FPN_1x.0.yaml OUTPUT_DIR /home/gaomh/Desktop/test/trainMOdel
  • --cfg: path to the config file
  • OUTPUT_DIR: output directory for the training run

From here on, training runs by itself.

7. Problems encountered

INFO loader.py: 126: Stopping enqueue thread
INFO loader.py: 113: Stopping mini-batch loading thread
INFO loader.py: 113: Stopping mini-batch loading thread
INFO loader.py: 113: Stopping mini-batch loading thread
INFO loader.py: 113: Stopping mini-batch loading thread
Traceback (most recent call last):
  File "tools/train_net.py", line 132, in <module>
    main()
  File "tools/train_net.py", line 114, in main
    checkpoints = detectron.utils.train.train_model()
  File "/home/learner/github/detectron/detectron/utils/train.py", line 86, in train_model
    handle_critical_error(model, 'Loss is NaN')

Fix for the "Loss is NaN" error above: lower BASE_LR from 0.01 to 0.001

Original post: https://www.cnblogs.com/answerThe/p/12120913.html