labelImg 标注数据转 TFRecord 数据

一、labelImg 的具体使用方法请自行搜索;注意保存格式要选择 PascalVOC(即 xml),下面的脚本只解析这种格式。

二、xml 转 csv

用 labelImg 标注完图片后,每张图片对应一个 xml 文件(共 N 个)。下面的脚本遍历整个 xml 目录,把所有标注框汇总到一个 csv 文件中(每个标注框一行)。

import glob
import os
import xml.etree.ElementTree as ET

import pandas as pd


def _coord_to_int(text):
    """Parse a numeric XML field as an int, also accepting float text like '12.0'.

    labelImg normally writes integers, but other VOC-style exporters may write
    floats; the fractional part is truncated, exactly as int() would do.
    """
    return int(float(text))


def xml_to_csv(xml_dir):
    """Collect every bounding box from a directory of Pascal VOC (labelImg) XML files.

    Each ``<object>`` element of each ``*.xml`` file directly inside
    ``xml_dir`` (not recursive) becomes one row. Each processed file path is
    printed as progress output.

    Args:
        xml_dir: directory containing the ``.xml`` annotation files.

    Returns:
        pandas.DataFrame with the columns
        ``filename, width, height, class, xmin, ymin, xmax, ymax``.
        It is empty (but still has those columns) when no boxes are found.
    """
    xml_list = []
    # glob's order depends on the filesystem; sort it so the CSV is reproducible.
    for xml_file in sorted(glob.glob(os.path.join(xml_dir, '*.xml'))):
        print(xml_file)
        root = ET.parse(xml_file).getroot()
        objects = root.findall('object')
        if not objects:
            continue
        # Look children up by tag rather than by position. The element order
        # inside <size>/<object> is not fixed across annotation tools, which
        # breaks member[4]-style indexing.
        filename = root.find('filename').text
        width = _coord_to_int(root.find('size/width').text)
        height = _coord_to_int(root.find('size/height').text)
        for member in objects:
            bndbox = member.find('bndbox')
            xml_list.append((
                filename,
                width,
                height,
                member.find('name').text,
                _coord_to_int(bndbox.find('xmin').text),
                _coord_to_int(bndbox.find('ymin').text),
                _coord_to_int(bndbox.find('xmax').text),
                _coord_to_int(bndbox.find('ymax').text),
            ))

    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    return pd.DataFrame(xml_list, columns=column_name)


if __name__ == '__main__':
    # Input: directory holding the .xml files produced by labelImg.
    xml_dir = r'/path_to_xml'
    xml_df = xml_to_csv(xml_dir)
    # Output: where the generated .csv file is written.
    csv_path = r'/path_to_output/xxx.csv'
    # index=False keeps the DataFrame's row index out of the CSV
    # (pandas documents `index` as a bool).
    xml_df.to_csv(csv_path, index=False)
    print('Successfully converted xml to csv')
    print(csv_path)

三、csv 转 tfrecord

注:需要先安装好 TensorFlow Object Detection API(即 object_detection 包),安装教程:https://www.cnblogs.com/tujia/p/13952108.html

import os
import json
import pandas as pd
from object_detection.dataset_tools import create_coco_tf_record


def _build_coco_groundtruth(examples, category_index):
    """Turn the CSV annotation rows into a COCO-style groundtruth dict.

    Args:
        examples: DataFrame with the columns written by ``xml_to_csv``
            (filename, width, height, class, xmin, ymin, xmax, ymax).
        category_index: mapping ``{id: {'id': id, 'name': name}}``.

    Returns:
        dict with ``images``, ``annotations`` and ``categories`` lists.
        Image and annotation ids start at 1 (COCO convention; pycocotools'
        COCOeval treats a groundtruth annotation id of 0 as "unmatched").

    Raises:
        KeyError: with several categories, when a row's ``class`` is not
            one of the category names.
    """
    categories = list(category_index.values())
    name_to_id = {category['name']: category['id'] for category in categories}
    # With a single category every box gets its id whatever the 'class' text
    # says (the original single-class behaviour). With several categories,
    # the 'class' column selects the id.
    single_id = categories[0]['id'] if len(categories) == 1 else None

    images, annotations = [], []
    image_ids = {}  # filename -> image id; O(1) dedup, and correct even if rows are not grouped by file
    for _, row in examples.iterrows():
        filename = row['filename']
        if filename not in image_ids:
            image_ids[filename] = len(image_ids) + 1
            images.append({
                'id': image_ids[filename],
                'file_name': filename,
                'width': row['width'],
                'height': row['height'],
            })

        box_width = row['xmax'] - row['xmin']
        box_height = row['ymax'] - row['ymin']
        category_id = single_id if single_id is not None else name_to_id[row['class']]
        annotations.append({
            'area': box_width * box_height,
            'iscrowd': 0,
            'image_id': image_ids[filename],
            # COCO boxes are [x, y, width, height] in pixels.
            'bbox': [row['xmin'], row['ymin'], box_width, box_height],
            'category_id': category_id,
            'id': len(annotations) + 1,
        })

    return {'images': images, 'annotations': annotations, 'categories': categories}


def create_tfrecord(csv_path, data_dir, output_dir, class_name=None,
                    category_index=None, num_shards=2):
    """Convert a labelImg CSV into a COCO json file plus sharded TFRecord files.

    Writes ``<output_dir>/<class_name>_annotation.json``. It then hands that
    file to the Object Detection API's COCO converter, which writes the
    ``<output_dir>/<class_name>.record-*`` shards.

    Args:
        csv_path: CSV produced by ``xml_to_csv``.
        data_dir: directory holding the images named in the CSV's
            ``filename`` column.
        output_dir: destination for the json and TFRecord files (created if
            missing).
        class_name: prefix of the output file names. Defaults to the
            module-level ``class_name`` global, which is how the script
            originally supplied it.
        category_index: mapping ``{id: {'id': id, 'name': name}}``. Defaults
            to the module-level ``category_index`` global.
        num_shards: number of TFRecord shards to write (previously fixed at 2).

    Raises:
        ValueError: if ``class_name`` or ``category_index`` is neither passed
            nor defined at module level.
    """
    # Fall back to module globals so `create_tfrecord(CSV_PATH, DATA_DIR,
    # OUTPUT_DIR)` from the script's __main__ keeps working.
    module_globals = globals()
    if class_name is None:
        class_name = module_globals.get('class_name')
    if category_index is None:
        category_index = module_globals.get('category_index')
    if class_name is None or category_index is None:
        raise ValueError('class_name and category_index must be passed in or '
                         'defined at module level')

    examples = pd.read_csv(csv_path)
    groundtruth_data = _build_coco_groundtruth(examples, category_index)

    os.makedirs(output_dir, exist_ok=True)
    annotation_file = os.path.join(output_dir, class_name + '_annotation.json')
    with open(annotation_file, 'w', encoding='utf-8') as annotation_fid:
        json.dump(groundtruth_data, annotation_fid)

    output_path = os.path.join(output_dir, class_name + '.record')
    create_coco_tf_record._create_tf_record_from_coco_annotations(
        annotation_file,
        data_dir,
        output_path,
        False,  # include_masks: boxes only
        num_shards)
    print('Finish!!')
    # The converter adds a shard suffix to output_path (e.g. -00000-of-00002).
    print('%s-*-of-%05d' % (output_path, num_shards))


if __name__ == '__main__':
    # Single-class setup: category 1 is the only entry. create_tfrecord reads
    # `class_name` and `category_index` as module-level names, so both are
    # bound here at module scope.
    class_name = 'xxx'
    category_index = {1: dict(id=1, name=class_name)}

    CSV_PATH = f'/tf/datasets/{class_name}.csv'
    DATA_DIR = f'/tf/datasets/{class_name}'
    OUTPUT_DIR = '/tf/object_detection/data/'

    create_tfrecord(CSV_PATH, DATA_DIR, OUTPUT_DIR)

注:我这里只有一个类,所以 category_index 直接写死了;如果有多个类,请在 category_index 中列出所有类别,并保证每个标注的 category_id 与 csv 中 class 列的类别一一对应。

四、验证 tfrecord 数据准确性(可视化)

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from object_detection.utils import visualization_utils as viz_utils
from six import BytesIO
from PIL import Image

# IPython/Jupyter magic (not Python syntax): renders matplotlib figures inline
# in the notebook. Remove this line when running the snippet as a plain .py script.
%matplotlib inline


def load_image_into_numpy_array(img_data):
    """Decode encoded image bytes into an RGB uint8 numpy array.

    Args:
        img_data: raw encoded image bytes in any format PIL can open, e.g.
            the ``image/encoded`` feature read from a TFRecord example.

    Returns:
        numpy.ndarray of shape (height, width, 3) and dtype uint8.
    """
    image = Image.open(BytesIO(img_data))
    # convert('RGB') normalises grayscale ('L'), palette ('P') and RGBA
    # images to three channels; a getdata()/reshape(h, w, 3) approach raises
    # for anything not already RGB. np.array also avoids building a Python
    # list of per-pixel tuples, and returns a writable copy.
    return np.array(image.convert('RGB'), dtype=np.uint8)


def plot_detections(image_np,
                    boxes,
                    classes,
                    scores,
                    category_index,
                    figsize=(12, 16),
                    image_name=None,
                    min_score_thresh=0.8):
    """Draw boxes on a copy of an image, then display it or save it to a file.

    Args:
        image_np: image array. It is not modified; drawing happens on a copy.
        boxes: box array in normalized coordinates, ordered
            [ymin, xmin, ymax, xmax] as produced by ``get_boxes``.
        classes: class id for each box (keys of ``category_index``).
        scores: score for each box.
        category_index: mapping ``{id: {'id': id, 'name': name}}`` used for labels.
        figsize: matplotlib figure size in inches, used when displaying.
        image_name: if given, save the annotated image to this path instead
            of displaying it.
        min_score_thresh: score threshold forwarded to
            ``viz_utils.visualize_boxes_and_labels_on_image_array``
            (previously hard-coded to 0.8).
    """
    image_np_with_annotations = image_np.copy()
    viz_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_annotations,
        boxes,
        classes,
        scores,
        category_index,
        use_normalized_coordinates=True,
        min_score_thresh=min_score_thresh)
    if image_name:
        plt.imsave(image_name, image_np_with_annotations)
    else:
        # Apply the caller's figsize (it was accepted but never used before).
        plt.figure(figsize=figsize)
        plt.imshow(image_np_with_annotations)

def get_boxes(filenames, num_examples=2):
    """Read images and all ground-truth boxes from the first records of TFRecord files.

    Args:
        filenames: a TFRecord path or list of paths, as accepted by
            ``tf.data.TFRecordDataset``.
        num_examples: number of records to read (previously fixed at 2).

    Returns:
        (images_np, gt_boxes): two parallel lists. ``images_np[i]`` is the
        decoded image of record i. ``gt_boxes[i]`` is a float32 array of
        shape (N_i, 4) with every box of record i as [ymin, xmin, ymax, xmax],
        exactly as stored in the record. N_i may be 0.
    """
    raw_dataset = tf.data.TFRecordDataset(filenames)

    images_np = []
    gt_boxes = []
    for raw_record in raw_dataset.take(num_examples):
        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        feature = example.features.feature

        images_np.append(
            load_image_into_numpy_array(feature['image/encoded'].bytes_list.value[0]))

        # Collect every object, not just the first. The rows are stacked in the
        # [ymin, xmin, ymax, xmax] order the visualizer is given, then
        # transposed to one box per row.
        coords = [list(feature['image/object/bbox/' + key].float_list.value)
                  for key in ('ymin', 'xmin', 'ymax', 'xmax')]
        gt_boxes.append(np.array(coords, dtype=np.float32).T)
    return images_np, gt_boxes


if __name__ == '__main__':
    class_name = 'xxx'
    category_index = {1: {'id': 1, 'name': class_name}}
    # create_tfrecord asks the converter for 2 shards, so the files are named
    # <class>.record-0000N-of-00002 rather than ...-00000-of-00001.
    # Glob every shard instead of hard-coding a single name.
    filenames = sorted(tf.io.gfile.glob(
        '/tf/object_detection/data/%s.record-*' % class_name))
    if not filenames:
        raise FileNotFoundError(
            'no TFRecord shards found for class %r' % class_name)
    (images_np, gt_boxes) = get_boxes(filenames)

    # zip copes with fewer records than requested and pairs each image with its boxes.
    for image_np, boxes in zip(images_np, gt_boxes):
        num_boxes = boxes.shape[0]
        # Give every ground-truth box a score of 100%, one score per box.
        dummy_scores = np.ones(num_boxes, dtype=np.float32)
        plot_detections(
            image_np,
            boxes,
            np.ones(shape=[num_boxes], dtype=np.int32),
            dummy_scores, category_index)
原文地址:https://www.cnblogs.com/tujia/p/14085045.html