TensorFlow文件读取

一、文件读取并训练的原理图

1、文件队列构造

tf.train.string_input_producer(string_tensor, ,shuffle=True):将输出字符串(例如文件名)输入到管道队列

  • string_tensor:含有文件名的1阶张量
  • num_epochs:过几遍数据,默认无限过数据
  • return:具有输出字符串的队列

将文件名列表交给tf.train.string_input_producer函数。string_input_producer来生成一个先入先出的队列,文件阅读器会需要它们来取数据。string_input_producer提供的可配置参数来设置文件名乱序和最大的训练迭代数,QueueRunner会为每次迭代(epoch)将所有的文件名加入文件名队列中,如果shuffle=True的话,会对文件名进行乱序处理。一过程是比较均匀的,因此它可以产生均衡的文件名队列。

这个QueueRunner工作线程是独立于文件阅读器的线程,因此乱序和将文件名推入到文件名队列这些过程不会阻塞文件阅读器运行。根据你的文件格式,选择对应的文件阅读器,然后将文件名队列提供给阅读器的read方法。阅读器的read方法会输出一个键来表征输入的文件和其中纪录(对于调试非常有用),同时得到一个字符串标量,这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。

2、文件阅读器

根据文件格式,选择对应的文件阅读器

class tf.TextLineReader:阅读文本文件逗号分隔值(CSV)格式,默认按行读取

class tf.FixedLengthRecordReader(record_bytes):要读取每个记录是固定数量字节的二进制文件

  • record_bytes:整型,指定每次读取的字节数

tf.TFRecordReader:读取TfRecords文件

3、文件内容解码器

由于从文件中读取的是字符串,需要函数去解析这些字符串到张量

tf.decode_csv(records,record_defaults=None,field_delim = None,name = None):将CSV转换为张量,与tf.TextLineReader搭配使用

  • records:tensor型字符串,每个字符串是csv中的记录行
  • field_delim:默认分割符”,”
  • record_defaults:参数决定了所得张量的类型,并设置一个值在输入字符串中缺少使用默认值

tf.decode_raw(bytes,out_type,little_endian = None,name = None) :将字节转换为一个数字向量表示,字节为一字符串类型的张量,与函数tf.FixedLengthRecordReader搭配使用,二进制读取为uint8格式

4、开启线程操作

tf.train.start_queue_runners(sess=None,coord=None):收集所有图中的队列线程,并启动线程

  • sess:所在的会话中
  • coord:线程协调器
  • return:返回所有线程队列

5、管道读端批处理

tf.train.batch(tensors,batch_size,num_threads = 1,capacity = 32,name=None):读取指定大小(个数)的张量

  • tensors:可以是包含张量的列表
  • batch_size:从队列中读取的批处理大小
  • num_threads:进入队列的线程数
  • capacity:整数,队列中元素的最大数量
  • return:tensors

tf.train.shuffle_batch(tensors,batch_size,capacity,min_after_dequeue, num_threads=1,) :乱序读取指定大小(个数)的张量

  • min_after_dequeue:留下队列里的张量个数,能够保持随机打乱

6、CSV文件读取案例

import tensorflow as tf
import os

def readcsv(filelist):
    """
    读取csv文件
    """
    # 构造文件队列
    file_queue = tf.train.string_input_producer(filelist)

    # 构建阅读器
    reader = tf.TextLineReader()

    key, value = reader.read(file_queue)

    # 对每行内容进行解码
    records = [["None"], ["None"]]
    example, label = tf.decode_csv(value, record_defaults=records)

    # 批处理
    example_batch, label_batch = tf.train.batch([example, label], batch_size=10, num_threads=1, capacity=10)
    return example_batch, label_batch


if __name__ == '__main__':
    filelist = os.listdir("./data/csvdata")
    filelist = ["./data/csvdata/{}".format(i) for i in filelist]

    example_batch, label_batch = readcsv(filelist)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        # 线程协调器
        coord = tf.train.Coordinator()

        # 开启读取文件线程
        threads = tf.train.start_queue_runners(sess, coord=coord)

        # 打印数据
        print(sess.run([example_batch, label_batch]))

        coord.request_stop()
        coord.join()
View Code
原文地址:https://www.cnblogs.com/20183544-wangzhengshuai/p/14395207.html