4-6 TF之TFRecord数据打包案例

import numpy as np
import tensorflow as tf
import cv2
import numpy as np

classification = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck']

import glob  # 获取类别文件夹图片的获取

###读入图片的src,并且相应的在im_labels标注图片类别
idx = 0
im_data = []
im_labels = []
for path in classification:
    path = 'data/image/train/' + path
    im_list = glob.glob(path + '/*')  # get images url
    im_label = [idx for i in range(im_list.__len__())]  # 对于每一个i,都加入idx
    idx += 1
    im_data += im_list
    im_labels += im_label
print(im_labels)
print(im_data)


tfrecord_file = 'data/train.tfrecord'
writer = tf.python_io.TFRecordWriter(tfrecord_file)  # 定义写入实例

index=[i for i in range(im_data.__len__())]#打乱图片顺序
np.random.shuffle(index)#实际上是把数字打乱,然后根据数字来取图片,达到乱序取图

##循环把每张图片都改变储存结构
##value=的值需要转换为适当类型,在tf,train.BytesList是byte列表转换函数,因为value可能是多维
##cv2和tf都有读取图片的函数,区别在于:
##  tf的图片读取后就是byte型,所以value不需要转换类型,并且tf读取图片会重新编码,减小内存,但图片输出需要被解压
for i in range(im_data.__len__()):
    im_d = im_data[index[i]]
    im_l = im_labels[index[i]]
    data = cv2.imread(im_d)  # 从图片url获取到真实数据
    #tf.gfile.FastGFile(src,'rb').read()#tf的图片读取方式,优点是读取的图片本身就是byte型,下面就不需要类型转换
    ex = tf.train.Example(  #主要用在将数据处理成二进制方面
        features=tf.train.Features(
            feature={
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data.tobytes()])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[im_l])),
                #'height':tf.train.Feature(int64_list=tf.train.Int64List(value=[data.shape[1]])),
                #'width':tf.train.Feature(int64_list=tf.train.Int64List(value=[data.shape[2]])),
            ##这里也可以记录图片大小,因为cifar图像都是32*32,所以这里不记录
            ##对于图片尺寸可以在这里feature记录,也可以在opcv处理时归一化
            }
        )
    )

    writer.write(ex.SerializeToString())
writer.close()

ex = tf.train.Example
tf.train.Example有一个属性为features
tf.train.Example还有一个方法SerializeToString()需要说一下,这个方法的作用是把tf.train.Example对象序列化为字符串,因为我们写入文件的时候不能直接处理对象,需要将其转化为字符串才能处理。
当然,既然有对象序列化为字符串的方法,那么肯定有从字符串反序列化到对象的方法,该方法是FromString(),需要传递一个tf.train.Example对象序列化后的字符串进去做为参数才能得到反序列化的对象。

原文地址:https://www.cnblogs.com/thgpddl/p/12843446.html