Reading and writing training data in TFRecord format

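A TFRecord file is a binary file that stores data such as images together with their labels. It makes better use of memory, and TensorFlow can copy, move, read, and write it quickly. With very large datasets, the usual approach of loading everything into memory becomes impractical. TFRecord files are read as a stream of records instead, which is far more convenient, and combined with shuffling they can greatly improve model training efficiency.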
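To make "binary file" concrete: according to the TensorFlow documentation, a TFRecord file is a sequence of records. Each record consists of an 8-byte little-endian length, a masked CRC-32C of that length, the payload bytes (usually a serialized `tf.train.Example`), and a masked CRC-32C of the payload. The pure-Python sketch below writes and reads that framing without TensorFlow. It is only an illustration of the format. The names `crc32c`, `masked_crc`, `write_record` and `read_records` are hypothetical helpers, not TensorFlow API.

```python
import io
import struct


def _make_crc32c_table():
    # Table for CRC-32C (Castagnoli), reflected polynomial 0x82F63B78
    table = []
    for i in range(256):
        c = i
        for _ in range(8):
            c = (c >> 1) ^ 0x82F63B78 if c & 1 else c >> 1
        table.append(c)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    """Plain table-driven CRC-32C of `data`."""
    crc = 0xFFFFFFFF
    for b in data:
        crc = _CRC32C_TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord-style mask: rotate the CRC right by 15 bits, add a constant."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f, payload: bytes) -> None:
    """Append one framed record to the binary file object `f`."""
    header = struct.pack('<Q', len(payload))          # uint64 length
    f.write(header)
    f.write(struct.pack('<I', masked_crc(header)))    # uint32 CRC of length
    f.write(payload)                                  # the record bytes
    f.write(struct.pack('<I', masked_crc(payload)))   # uint32 CRC of payload


def _read_exact(f, n: int) -> bytes:
    chunk = f.read(n)
    if len(chunk) != n:
        raise ValueError('truncated TFRecord file')
    return chunk


def read_records(f):
    """Yield payloads one at a time, verifying both checksums."""
    while True:
        header = f.read(8)
        if not header:          # clean end of file
            return
        if len(header) != 8:
            raise ValueError('truncated TFRecord file')
        (length,) = struct.unpack('<Q', header)
        (len_crc,) = struct.unpack('<I', _read_exact(f, 4))
        if len_crc != masked_crc(header):
            raise ValueError('corrupted length field')
        payload = _read_exact(f, length)
        (data_crc,) = struct.unpack('<I', _read_exact(f, 4))
        if data_crc != masked_crc(payload):
            raise ValueError('corrupted payload')
        yield payload


if __name__ == '__main__':
    buf = io.BytesIO()
    for rec in [b'hello', b'', b'x' * 300]:
        write_record(buf, rec)
    buf.seek(0)
    print(list(read_records(buf))[:2])
```

Because both the length and the payload carry a checksum, a corrupted byte is detected at read time instead of silently producing a bad sample, and the reader only ever holds one record in memory. A file written this way should also be readable by `tf.data.TFRecordDataset`, but in practice `tf.io.TFRecordWriter` is the usual tool for writing.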

Example code:

```python
import os

import tensorflow as tf

# The reading half uses the TF 1.x queue-based input pipeline,
# which only runs in graph mode (needed when running under TF 2.x).
tf.compat.v1.disable_eager_execution()


def convert_tfrecord(data, label):
    """Save the data in TFRecord format, then read it back in shuffled batches.

    :param data: tuple (samples, scores); each sample is a list of 9 ints,
                 each score a list of 9 floats
    :param label: list of int labels, one per sample
    :return: None
    """
    record_path = './resources/train.tfrecord'
    os.makedirs(os.path.dirname(record_path), exist_ok=True)

    # Wrap each record in tf.train.Example / tf.train.Features and serialize it
    cnt = 0
    with tf.io.TFRecordWriter(record_path) as writer:
        for d, s, l in zip(data[0], data[1], label):
            if cnt % 100 == 0:
                print('write example {}'.format(cnt))
            cnt += 1
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=d)),
                        'score': tf.train.Feature(float_list=tf.train.FloatList(value=s)),
                        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[l]))
                    }
                )
            )
            writer.write(example.SerializeToString())
    print('write ok')

    # Read the file back and fetch it batch by batch
    filename_queue = tf.compat.v1.train.string_input_producer([record_path])
    reader = tf.compat.v1.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # The shapes here must match what was written: 9 samples, 9 scores, 1 label
    features = tf.io.parse_single_example(
        serialized_example,
        features={
            'sample': tf.io.FixedLenFeature([9], tf.int64),
            'score': tf.io.FixedLenFeature([9], tf.float32),
            'label': tf.io.FixedLenFeature([1], tf.int64),
        })

    is_batch = True
    if is_batch:
        batch_size = 3
        min_after_dequeue = 10
        capacity = min_after_dequeue + 3 * batch_size
        samples, scores, labels = tf.compat.v1.train.shuffle_batch(
            [features['sample'], features['score'], features['label']],
            batch_size=batch_size,
            num_threads=3,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue)

        with tf.compat.v1.Session() as sess:
            # tf.initialize_all_variables() is deprecated; use the replacement
            sess.run(tf.compat.v1.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
            for _ in range(1000):
                # Fetch one batch from the session
                sample, score, label = sess.run([samples, scores, labels])
                print(sample)
                print(score)
                print('###########')
            coord.request_stop()
            coord.join(threads)
    print('ok')
```
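The `min_after_dequeue` and `capacity` arguments of `shuffle_batch` control how well the data is mixed: each batch is drawn at random from a buffer that keeps at least `min_after_dequeue` elements, so a larger value gives better shuffling at the cost of memory and a slower start. With the values above, `capacity = 10 + 3 * 3 = 19`. The pure-Python sketch below imitates that bounded shuffle buffer to show the idea; it is not TensorFlow's implementation, and `shuffle_batches` is a hypothetical name.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar('T')


def shuffle_batches(items: Iterable[T], batch_size: int,
                    min_after_dequeue: int, seed: int = 0) -> Iterator[List[T]]:
    """Yield random batches drawn from a bounded shuffle buffer.

    A batch is only taken once the buffer holds min_after_dequeue + batch_size
    elements, so at least min_after_dequeue elements stay behind to mix with
    later input. At the end of the input the remainder is drained, so the
    final batch may be smaller than batch_size (a choice made by this sketch).
    """
    rng = random.Random(seed)
    buffer: List[T] = []
    threshold = min_after_dequeue + batch_size

    def take_batch(n: int) -> List[T]:
        batch = []
        for _ in range(n):
            # Pick a random slot, swap it to the end, and pop it in O(1)
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            batch.append(buffer.pop())
        return batch

    for item in items:
        buffer.append(item)
        if len(buffer) >= threshold:
            yield take_batch(batch_size)
    while buffer:
        yield take_batch(min(batch_size, len(buffer)))


if __name__ == '__main__':
    for batch in shuffle_batches(range(20), batch_size=3, min_after_dequeue=10):
        print(batch)
```

Note that the first batch can only contain elements from the first `min_after_dequeue + batch_size` inputs; this is why a small `min_after_dequeue` gives weak shuffling when the records on disk are sorted, for example by label.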
Original article: https://www.cnblogs.com/demo-deng/p/13789061.html