Tensorflow Dataset.from_generator使用示例

shapes = (tf.TensorShape([None, None]), tf.TensorShape([10, 10]))
# 传入的是一个generator,即返回字段为yield的函数,不可传入嵌套生成器
# dataSet output_types参数必选,output_shapes参数可选,不选会直接适配数据的shape
# 参数就是一个元组
data_set = tf.data.Dataset.from_generator(gen_epochs,
                                          output_types=(tf.int32, tf.int32),
                                          output_shapes=shapes,
                                          args=(n, batch_size, 10))

之前的一篇博文(https://blog.csdn.net/foreseerwang/article/details/80170210)介绍了使用Tensorflow Dataset进行数据导入的方法及其优势。最近在实际使用中越发感觉到这个方式非常好用,尤其是发现了.from_generator这个method。

关于Dataset.from_generator的简单介绍,请参见如下两个链接:

https://tensorflow.google.cn/versions/master/api_docs/python/tf/data/Dataset#repeat

https://blog.csdn.net/dqcfkyqdxym3f8rb0/article/details/79342369

注意,Dataset.from_generator在旧版Tensorflow中没有,起码在1.3版本tf.contrib.data.Dataset中还没有,后来用的1.7版本就有了。

我们知道,tensorflow的基本原理是先构造一个计算图,最后再统一计算。为此,tf重写了几乎所有常见函数,用于构造计算图,而且tensorflow不支持循环、选择等普通编程语言的常见操作。这就给编程使用带来比较大的麻烦。具体到data feeding上,也是如此。虽然设计了placeholder、train.slice_input_producer系列、Dataset等多种方式,但使用中仍有各种不便,尤其是在输入形式复杂、需要多重变换的时候更是如此。而Dataset.from_generator可以在一定程度上解决这个问题。

简单的说,Dataset.from_generator可以使用普通编程语言编写的外部子函数生成Dataset,这样几乎不受tensorflow编程不便的影响。先举一个最简单的示例:

'''
import pickle
fr=open('/media/dell/D/qcc/RandLA-Net/data/semantic_kitti/dataset/sequences_0.06/00/KDTree/000001.pkl','rb')
inf = pickle.load(fr)
doc = open('1.txt', 'a')
print(inf, file=doc)
print(inf)
'''

# demo of Dataset.from_generator
# blog.csdn.net/foreseerwang
# QQ: 50834

"""
Expected outputs:
Batch No. 0:
[0 1 2 3]
Batch No. 1:
[4 0 1 2]
Batch No. 2:
[3 4 0 1]
Batch No. 3:
[2 3 4]
end!
"""

import numpy as np
import tensorflow as tf


def data_generator():
    dataset = np.array(range(5))
    for d in dataset:
        #print(d)
        yield d


dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32), (tf.TensorShape([])))
dataset = dataset.repeat(3) #3==epoch
dataset = dataset.batch(4) #4==batchsize

iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()

with tf.Session() as sess:
    try:
        batch_num = 0
        while True:
            one_batch = sess.run(one_element)
            print('Batch No. %d:' % batch_num)
            print(one_batch)
            print('')
            batch_num += 1

    except tf.errors.OutOfRangeError:
        print('end!')
        
        

很显然,这个的输出如下:

  1. Batch No. 0:
  2. [0 1 2 3]
  3.  
  4. Batch No. 1:
  5. [4 0 1 2]
  6.  
  7. Batch No. 2:
  8. [3 4 0 1]
  9.  
  10. Batch No. 3:
  11. [2 3 4]
  12.  
  13. end!

下面给出一个复杂的问题。假设需要输入如下序列:

A B

A C B

C

其中A/B/C分别代表一个文件,例如一张图片或是一个文本文件。每一行是一条记录,按行读入,并聚集多行形成batch,譬如每4行形成一个batch。这里有两个难点:1.每一行/每一条记录的元素长度不一样;2.读入元素A/B/C之后还要以之作为文件名读入文件内容。现有各种data feeding方式似乎很难同时解决这两个难点,除了Dataset.from_generator。

针对这个问题,使用Dataset.from_generator的一个简化版示例如下:

  1. # demo of Dataset.from_generator
  2. # blog.csdn.net/foreseerwang
  3. # QQ: 50834
  4.  
  5. """
  6. Expected outputs:
  7.  
  8. Batch No. 0:
  9. [[ 1 2 3]
  10. [ 2 3 -1]]
  11.  
  12. Batch No. 1:
  13. [[ 3 -1 -1]
  14. [ 4 5 -1]]
  15.  
  16. Batch No. 2:
  17. [[ 6 7 8]
  18. [ 9 -1 -1]]
  19.  
  20. Batch No. 3:
  21. [[10 11 12]
  22. [13 14 -1]]
  23.  
  24. Batch No. 4:
  25. [[15 -1 -1]]
  26.  
  27. end!
  28. """
  29.  
  30. import io
  31. import numpy as np
  32. import tensorflow as tf
  33.  
  34. class DataFeeder:
  35.  
  36. def __init__(self, filenames):
  37. self.filenames = filenames
  38.  
  39. def file_readline(self):
  40. for filename in self.filenames:
  41. fr = io.open(filename, 'r', encoding='utf-8')
  42.  
  43. while True:
  44. file_line = fr.readline()
  45. if not file_line:
  46. break
  47.  
  48. datalist = file_line.split()
  49. # if datalist is a list of filename, file contents can
  50. # be read and appendded here.
  51. yield np.asarray(datalist, dtype='int32')
  52.  
  53. fr.close()
  54.  
  55. def generate_batch(self, batch_size, num_epochs=None):
  56. dataset = tf.data.Dataset.from_generator(self.file_readline,
  57. tf.int32,
  58. tf.TensorShape([None]))
  59.  
  60. dataset = dataset.repeat(num_epochs)
  61. dataset = dataset.padded_batch(
  62. batch_size,
  63. padded_shapes=tf.TensorShape([3]),
  64. padding_values=-1)
  65.  
  66. iterator = dataset.make_one_shot_iterator()
  67. out_batch = iterator.get_next()
  68.  
  69. return out_batch
  70.  
  71. filenames = ['a.txt', 'b.txt', 'c.txt']
  72. data_feeder = DataFeeder(filenames)
  73. one_batch = data_feeder.generate_batch(batch_size=2, num_epochs=1)
  74.  
  75. with tf.Session() as sess:
  76. try:
  77. batch_num = 0
  78. while True:
  79. data_batch = sess.run(one_batch)
  80. print('Batch No. %d:' % batch_num)
  81. print(data_batch)
  82. print('')
  83. batch_num+=1
  84.  
  85. except tf.errors.OutOfRangeError:
  86. print('end!')

其中三个文本文件a.txt/b.txt/c.txt的内容分别如下:

a.txt:

1 2 3
2 3
3

b.txt:

4 5
6 7 8
9

c.txt:

10 11 12
13 14
15

运行以上代码的输出为:

  1. Batch No. 0:
  2. [[ 1 2 3]
  3. [ 2 3 -1]]
  4.  
  5. Batch No. 1:
  6. [[ 3 -1 -1]
  7. [ 4 5 -1]]
  8.  
  9. Batch No. 2:
  10. [[ 6 7 8]
  11. [ 9 -1 -1]]
  12.  
  13. Batch No. 3:
  14. [[10 11 12]
  15. [13 14 -1]]
  16.  
  17. Batch No. 4:
  18. [[15 -1 -1]]
  19.  
  20. end!

目前的输出,每个batch是batch_size * 3的矩阵。实际上,1~15的数字可以是某个图片的文件名,在file_readline()函数中读出这些数字后,可以继续读出这些文件的内容,并形成更高维度的Dataset输出,譬如:batch_size * img_size * img_size * img_channel的Dataset。

最后,说几点注意事项(详见代码):

1. generator函数不能有输入参数,但如果是class内的一个函数,可以使用self参数,这也是传递参数的一个手段;

2. 上述class中,建议传递文件名,在generator中打开处理再关闭,而不应该在外面打开(fr=open(filename, ‘r’)),然后把fr传递给generator读取。实践表明:后面这种方法形成的dataset不能repeat;

3. 因为序列不等长,在形成dataset batch时需要使用Dataset.padded_batch方法。

原文地址:https://www.cnblogs.com/yibeimingyue/p/13805105.html