TFRecord 使用

tfrecord生成

import os
import xmltodict
import tensorflow as tf
import numpy as np

# NOTE(review): the directory separators in dir_path / imgs_dir look like they
# were lost when this snippet was published (something like
# 'F:\数据存储\VOCdevkit\VOC2012\Annotations' was probably intended) — confirm
# the real paths before running.
dir_path = 'F:数据存储VOCdevkitVOC2012Annotations'
dirs = os.listdir(dir_path)
imgs_dir = "F:数据存储VOCdevkitVOC2012JPEGImages"
# Raw string: in an ordinary literal '\v' is a vertical-tab escape, which would
# silently corrupt the output file name.
out_path = r'F:数据存储VOCdevkit\voc2012.tfrecord'

# Class names; each object's label is stored as its index in this list, so
# index 0 is "background" and the 20 VOC classes occupy indices 1..20.
classes = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]
# Single session shared by get_and_resize_img() for all decode/resize work.
sess = tf.Session()


def get_and_resize_img(img_file):
    """Load a JPEG from ``imgs_dir`` and resize it to 224x224.

    The read/decode/resize ops are built once, on the first call, and then
    reused through a string placeholder. Building fresh ops on every call
    would keep adding nodes to the default graph, so memory use and run time
    would grow with every image processed.

    Args:
        img_file: image file name, relative to ``imgs_dir``.

    Returns:
        ``(resized_img_bytes, w_scale, h_scale, shape_new)``:
        the raw uint8 pixel bytes of the resized image; the width and height
        scale factors (new size / original size) used to map box coordinates
        into the resized frame; and the resized array shape.
    """
    ops = getattr(get_and_resize_img, '_graph_ops', None)
    if ops is None:
        path_ph = tf.placeholder(tf.string, shape=[])
        decoded = tf.image.decode_jpeg(tf.read_file(path_ph))
        # method=0 selects bilinear interpolation.
        resized = tf.image.resize_images(decoded, [224, 224], method=0)
        ops = (path_ph, decoded, resized)
        get_and_resize_img._graph_ops = ops
    path_ph, decoded, resized = ops

    # Fetch both tensors in one run so the JPEG is decoded only once.
    img_old, img_new = sess.run(
        [decoded, resized], feed_dict={path_ph: imgs_dir + '/' + img_file})
    shape_old = img_old.shape
    # resize_images returns float32; store the pixels as uint8.
    resized_img = np.asarray(img_new, dtype='uint8')
    shape_new = resized_img.shape

    # Image arrays are (height, width, channels): axis 1 is width, axis 0 is height.
    w_scale = shape_new[1] / shape_old[1]
    h_scale = shape_new[0] / shape_old[0]

    return resized_img.tobytes(), w_scale, h_scale, shape_new


# Write one tf.train.Example per VOC annotation file. Box coordinates are
# rescaled into the 224x224 frame of the resized image stored in 'encoded'.
# The writer is a context manager, so the file is flushed and closed even if
# an annotation raises mid-loop.
with tf.python_io.TFRecordWriter(out_path) as writer:
    for file in dirs:
        # Only annotation XML files are processed.
        if not file.endswith('.xml'):
            continue
        with open(dir_path + '/' + file) as xml_txt:
            doc = xmltodict.parse(xml_txt.read())
        img_file_name = os.path.splitext(file)[0]
        resized_img_str, w_scale, h_scale, shape = get_and_resize_img(img_file_name + '.jpg')

        # xmltodict yields a single mapping for one <object>, a list for
        # several, and no key at all for none. Normalize to a list; checking
        # type(...).__name__ == 'OrderedDict' breaks on xmltodict versions
        # that return plain dicts.
        objects = doc['annotation'].get('object', [])
        if not isinstance(objects, list):
            objects = [objects]

        img_obtain_classes = []
        y_mins, x_mins, y_maxes, x_maxes = [], [], [], []
        for one_object in objects:
            if one_object['name'] not in classes:
                continue
            bndbox = one_object['bndbox']
            img_obtain_classes.append(classes.index(one_object['name']))
            # float() accepts both '48' and '48.5'; int() would reject decimals.
            y_mins.append(h_scale * float(bndbox['ymin']))
            x_mins.append(w_scale * float(bndbox['xmin']))
            y_maxes.append(h_scale * float(bndbox['ymax']))
            x_maxes.append(w_scale * float(bndbox['xmax']))

        example = tf.train.Example(features=tf.train.Features(feature={
            'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name.encode('utf8')])),
            'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
            'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=img_obtain_classes)),
            # Per-object box edges, in resized-image pixel coordinates.
            'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)),
            'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
            'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
            'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
            'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
        }))
        writer.write(example.SerializeToString())
sess.close()
print('ok')

tfrecord读取

import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
# import sys
#
# sys.path.append("..")

# Must match the writer's list exactly: labels were stored as indices into a
# list whose index 0 is "background", so "aeroplane" is label 1, and so on.
classes = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]



# 'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name])),
# 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
# 'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=np.array(img_obtain_classes))),
# 'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)),  # 各个 object 的  ymin
# 'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
# 'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
# 'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
# 'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))

def _parse_record(example_proto):
    """Decode one serialized VOC tf.train.Example into a feature dict.

    'filename' and 'encoded' are scalar byte strings, 'shape' is three
    int64 values, and 'classes' plus the four box-edge lists are
    variable-length (returned as sparse tensors).
    """
    box_keys = ('y_mins', 'x_mins', 'y_maxes', 'x_maxes')
    spec = {
        'filename': tf.FixedLenFeature([], tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'classes': tf.VarLenFeature(tf.int64),
    }
    spec.update({key: tf.VarLenFeature(tf.float32) for key in box_keys})
    spec['encoded'] = tf.FixedLenFeature((), tf.string)
    return tf.parse_single_example(example_proto, features=spec)


def read_test(input_file, num_samples=2):
    """Print and display the first records of a VOC TFRecord file.

    Args:
        input_file: path to the TFRecord file.
        num_samples: how many records to read (default 2). Reading stops
            early if the file holds fewer records.
    """
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(_parse_record)
    iterator = dataset.make_initializable_iterator()
    # Build the get_next op once. Calling get_next() inside the loop would add
    # a new op to the graph on every iteration.
    next_element = iterator.get_next()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        for _ in range(num_samples):
            try:
                features = sess.run(next_element)
            except tf.errors.OutOfRangeError:
                break
            name = features['filename'].decode()
            shape = features['shape']
            # VarLenFeature values arrive as SparseTensorValue; .values holds the data.
            classes = features['classes']
            y_mins = features['y_mins']
            x_mins = features['x_mins']
            y_maxes = features['y_maxes']
            x_maxes = features['x_maxes']
            img_data = features['encoded']

            print(len(img_data))
            print('=======')
            print("shape", shape)
            print("name", name)
            print("classes", classes.values)
            print("y_mins", y_mins.values)
            print("x_mins", x_mins.values)
            print("y_maxes", y_maxes.values)
            print("x_maxes", x_maxes.values)
            # Rebuild the image array from the raw uint8 bytes using the stored shape.
            image_data = np.frombuffer(img_data, dtype=np.uint8).reshape(shape)
            print("img_data", image_data)
            plt.imshow(image_data)
            plt.show()


# Raw string: in an ordinary literal '\v' is a vertical-tab escape, which would corrupt the path.
read_test(r'F:数据存储VOCdevkit\voc2012.tfrecord')


尺寸不固定矩阵的存储和读取

import json
import jieba
import tensorflow as tf

# Load the saved vocabulary; the "all_words_word2id" entry maps word -> id.
with open('../data_save/words_info.txt', 'r', encoding='utf-8') as words_file:
    dic = json.load(words_file)
all_words_word2id = dic["all_words_word2id"]

# Load stop words, one per line. rstrip('\n') removes only the line
# terminator; slicing off the last character (line[:-1]) would also chop a
# real character from a final line that has no trailing newline.
with open('./stop_words.txt', encoding='utf-8') as f:
    stop_words = {line.rstrip('\n') for line in f}
print('停用词读取完毕,共{n}个单词'.format(n=len(stop_words)))

# Raw strings: in an ordinary literal the '\n' in '\news2016zh...' is a newline
# escape, which silently corrupts these Windows paths.
dir_path = r'F:\数据存储新闻语料\news2016zh_train.json'
dir_path_test = r'F:\数据存储新闻语料\news2016zh_valid.json'
out_path = r'F:\数据存储新闻语料\news2016zh_train_new.tfrecord'


def getCutSequnce(line):
    """Segment Chinese text with jieba and drop stop words and URL fragments.

    Args:
        line: text to segment (precise mode, ``cut_all=False``).

    Returns:
        List of the remaining tokens, in their original order.
    """
    url_fragments = ['www', 'com', 'http']
    return [
        token
        for token in jieba.cut(line, cut_all=False)
        if token not in stop_words and token not in url_fragments
    ]


# Convert the first 10000 lines of the JSON-lines news file into
# tf.train.Examples holding the vocabulary ids of each title and content.
# Both files are opened with `with` so the TFRecord is flushed and closed
# when the loop ends (the original never closed the writer).
i = 0
with tf.python_io.TFRecordWriter(out_path) as writer:
    with open(dir_path, encoding='utf-8') as txt:
        for i, one_dic in enumerate(txt, start=1):
            if i > 10000:
                break
            if i % 1000 == 0:
                print(i)
            one_dic_json = json.loads(one_dic)

            title = one_dic_json['title']
            content = one_dic_json['content']
            # Skip very long articles and records with an empty title or body.
            if len(content) > 3000 or not title or not content:
                continue

            # Words missing from the vocabulary are dropped.
            title_list_index = [all_words_word2id[w] for w in getCutSequnce(title)
                                if w in all_words_word2id]
            content_list_index = [all_words_word2id[w] for w in getCutSequnce(content)
                                  if w in all_words_word2id]

            example = tf.train.Example(features=tf.train.Features(feature={
                'title': tf.train.Feature(int64_list=tf.train.Int64List(value=title_list_index)),
                'content': tf.train.Feature(int64_list=tf.train.Int64List(value=content_list_index))
            }))
            writer.write(example.SerializeToString())






import tensorflow as tf
import numpy as np

def _parse_record(example_proto):
    """Decode one serialized news Example.

    'title' and 'content' are both variable-length int64 lists (the word ids
    written by the converter), returned as sparse tensors.
    """
    spec = {key: tf.VarLenFeature(dtype=tf.int64) for key in ('title', 'content')}
    return tf.parse_single_example(example_proto, features=spec)

def read_test(input_file, num_samples=5):
    """Print the 'content' id lists of the first records of a news TFRecord.

    Args:
        input_file: path to the TFRecord file.
        num_samples: how many records to read (default 5). Reading stops
            early if the file holds fewer records.
    """
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(_parse_record)
    iterator = dataset.make_initializable_iterator()
    # Build the get_next op once instead of adding a new op every iteration.
    next_element = iterator.get_next()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        for _ in range(num_samples):
            try:
                features = sess.run(next_element)
            except tf.errors.OutOfRangeError:
                break
            content = features['content']

            print("xx", content)
            # content is a SparseTensorValue. Report the shape of its id values;
            # np.array(content) would wrap the namedtuple, not the ids.
            print("xx", content.values.shape)
# Raw string: the '\n' in '\news2016zh...' would otherwise be a newline escape.
read_test(r'F:\数据存储新闻语料\news2016zh_train_new.tfrecord')


统计数据条数

import tensorflow as tf


def total_sample(file_name):
    """Return the number of records in the TFRecord file ``file_name``."""
    return sum(1 for _ in tf.python_io.tf_record_iterator(file_name))


# Raw string: the '\n' in '\news2016zh...' would otherwise be a newline escape.
result = total_sample(r'F:\数据存储新闻语料\news2016zh_train_new.tfrecord')
print(result)

原文地址:https://www.cnblogs.com/panfengde/p/11302960.html