(2)tf.data.Dataset喂数据给模型

一般而言把数据喂给模型的方式有三种:

1.建立placeholder,然后使用feed_dict将数据feed进placeholder进行使用。使用这种方法十分灵活,可以一下子将所有数据读入内存,然后分batch进行feed;也可以建立一个Python的generator,一个batch一个batch的将数据读入,并将其feed进placeholder。这种方法很直观,用起来也比较方便灵活,但是这种方法的效率较低,难以满足高速计算的需求。

 

2.使用TensorFlow的QueueRunner,通过一系列的Tensor操作,将磁盘上的数据分批次读入并送入模型进行使用。这种方法效率很高,但因为其牵涉到Tensor操作,不够直观,也不方便调试,所有有时候会显得比较困难。使用这种方法时,常用的一些操作包括tf.TextLineReader,tf.FixedLengthRecordReader以及tf.decode_raw等等。如果需要循环,条件操作,还需要使用TensorFlow的tf.while_loop,tf.case等操作,更是难上加难。

 1 import tensorflow as tf
 2 filename_queue=tf.train.string_input_producer(["./data/all_c_dev.en"])
 3 
 4 reader=tf.TextLineReader()
 5 key,value=reader.read(filename_queue)
 6 
 7 
 8 with tf.Session() as sess:
 9     tf.train.start_queue_runners()
10     for i in range(10):
11         print(sess.run([key,value]))

3.自1.x版本开始,逐步开发引入了tf.data.Dataset模块,使其数据读入的操作变得更为方便,而支持多线程(进程)的操作,也在效率上获得了一定程度的提高。

  a.一次将所有数据读入内存

 1 images = ...                                                 #图像数据images读入内存;
 2 labels = ...                                                 #对应的标签数据labels读入内存;
 3 data = tf.data.Dataset.from_tensor_slices((images, labels))  #使用读入内存的数据images、labels构建Dataset;
 4 data = data.batch(batch_size)                                #设置batchsize大小
 5 iterator=tf.data.Iterator.from_structure(data.output_types,
 6                         data.output_shapes)  #基于此前构建的Dataset的数据类型和结构,构建一个可重新初始化iterator
 7 init_op = iterator.make_initializer(data)                    #基于此前构建的Dataset构建一个iterator初始化op。
 8 with tf.Session()  as sess:                                  #展开会话
 9     sess.run(init_op)                                        #初始化iterator
10     try:
11         images, labels = iterator.get_next()                 #获取一个batchsize的数据
12     except tf.errors.OutOfRangeError:                        #iterator中的元素取完之后,会抛出OutOfRangeError异常,TensorFlow没有对这个异常进行处理,我们需要对其进行捕捉和处理。
13         sess.run(init_op)

  b.包装一个generator

def gen():                                                            #定义一个生成器函数
    with  open('train.csv')  as f:
        lines = [line.strip().split(',')  for line in f.readlines()]
        index = 0
        while  True:
            image = cv2.imread(lines[index][0])
            image = cv2.resize(image, (224, 224))
            label = lines[index][1]
            yield  (image, label)
            index += 1
            if index == len(lines):
           index = 0

batch_size = 2
data = tf.data.Dataset.from_generator( gen,                           #指定通过gen构建Dataset
                       (tf.float32, tf.int32), #指定数据类型                        (tf.TensorShape([
224, 224, 3]),tf.TensorShape([]))) #指定shape
                     
data = data.batch(batch_size)                          #设置batchsize
iter = data.make_one_shot_iterator()                                 #创建单次迭代器,是最简单的迭代器形式,仅支持对数据集进行一次迭代,不需要显式初始化。 
with tf.Session()  as sess:
    images, labels = iter.get_next()

  c.使用tensor读取数据

def _parse_function(filename, label):                  #定义解析数据的函数
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.cast(image_decoded, tf.float32)
    image = tf.image.resize_images(image, [224, 224])
    return image, filename, label

images = tf.constant(image_names)                    #转化为tensor
labels = tf.constant(labels)
images = tf.random_shuffle(images, seed=0)
labels = tf.random_shuffle(labels, seed=0)
data = tf.data.Dataset.from_tensor_slices((images, labels))    #利用tensor构建dataset
data = data.map(_parse_function, num_parallel_calls=4)       #利用map函数处理tensor得到新的dataset,num_parallel_calls表示并行处理
data
= data.prefetch(buffer_size=batch_size * 10)          #prefetch可以充分利用时间,预准备 data = data.batch(batch_size)                      #设置batchsize iterator = tf.data.Iterator.from_structure(data.output_types,data.output_shapes) #构建iterator init_op = iterator.make_initializer(data)               #初始化 with tf.Session() as sess: sess.run(init_op) try: images, filenames, labels = iterator.get_next() except tf.errors.OutOfRangeError: sess.run(init_op)

使用tf,data是一种管道pipeline机制,他有很多的特色,比如prefetch和map,能够充分利用cpu的时间,这篇博客介绍的很好。

 tf.data.Dataset.from_generator
原文地址:https://www.cnblogs.com/super-zheng/p/13215425.html