Tensorflow中使用TFRecords高效读取数据--结合Attention-over-Attention Neural Network for Reading Comprehension

原文链接:https://arxiv.org/pdf/1607.04423.pdf 本片论文主要讲了Attention Model在完形填空类的阅读理解上的应用。

转载:https://blog.csdn.net/liuchonge/article/details/73649251

在进行论文仿真的时候用到了TFRecords进行数据的读取操作,所以进行深入学习。这两天看了一下相关博客,结合该代码记录一下TFRecords的相关操作。 
首先说一下为什么要使用TFRecords来进行文件的读写,在TF中数据的传入方式主要包含以下几种:

    1. 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
    2. 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
    3. 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。

之前都是使用1和3进行数据的操作,但是当我们遇到数据集比较大的情况时,这两种方法会及其占用内存,效率很差。那么为甚么使用TFRecords会比较快呢?因为其使用二进制存储文件,也就是将数据存储在一个内存块中,相比其它文件格式要快很多,特别是如果你使用hdd而不是ssd,因为它涉及移动磁盘阅读器头并且需要相当长的时间。总体而言,通过使用二进制文件,您可以更轻松地分发数据,使数据更好地对齐,以实现高效的读取。接下来我们看一下具体的操作。

这里可以参见官网给的建议:

Another approach is to convert whatever data you have into a supported format. This approach makes it easier to mix and match data sets and network architectures. The recommended format for TensorFlow is a TFRecords file containing tf.train.Example protocol buffers (which contain Features as a field). You write a little program that gets your data, stuffs it in an Example protocol buffer, serializes the protocol buffer to a string, and then writes the string to a TFRecords file using the tf.python_io.TFRecordWriter. For example, tensorflow/examples/how_tos/reading_data/convert_to_records.py converts MNIST data to this format.

To read a file of TFRecords, use tf.TFRecordReader with the tf.parse_single_example decoder. The parse_single_example op decodes the example protocol buffers into tensors. An MNIST example using the data produced by convert_to_records can be found in tensorflow/examples/how_tos/reading_data/fully_connected_reader.py, which you can compare with the fully_connected_feed version.

  个人感觉可以分成两部分,一是使用tf.train.Example协议流将文件保存成TFRecords格式的.tfrecords文件,这里主要涉及到使用tf.python_io.TFRecordWriter("train.tfrecords")tf.train.Example以及tf.train.Features三个函数,第一个是生成需要对应格式的文件,后面两个函数主要是将我们要传入的数据按照一定的格式进行规范化。这里还要提到一点就是使用TFRecords可以避免多个文件的使用,比如说我们一般会将一次要传入的数据的不同部分分别存放在不同文件夹中,question一个,answer一个,query一个等等,但是使用TFRecords之后,我们可以将一批数据同时保存在一个文件之中,这样方便我们在后续程序中的使用。

另一部分就是在训练模型时将我们生成的.tfrecords文件读入并传到模型中进行使用。这部分主要涉及到使用tf.TFRecordReader("train.tfrecords")tf.parse_single_example两个函数。第一个函数是将我们的二进制文件读入,第二个则是进行解析然后得到我们想要的数据。

接下来我们结合代码进行理解:

生成TFRecords文件

这里关于要使用的数据集的介绍可以参考我的下一篇,主要是QA任务的数据集。代码如下所示:

 1 def tokenize(index, word):
 2 #index是每个单词对应词袋子之中的索引值,word是所有出现的单词
 3   directories = ['cnn/questions/training/', 'cnn/questions/validation/', 'cnn/questions/test/']
 4   for directory in directories:
 5   #分别读取训练测试验证集的数据
 6     out_name = directory.split('/')[-2] + '.tfrecords'
 7     #生成对应.tfrecords文件
 8     writer = tf.python_io.TFRecordWriter(out_name)
 9     #每个文件夹下面都有若干文件,每个文件代表一个QA队,也就是一条训练数据
10     files = map(lambda file_name: directory + file_name, os.listdir(directory))
11     for file_name in files:
12       with open(file_name, 'r') as f:
13         lines = f.readlines()
14         #对每条数据分别获得文档,问题,答案三个值,并将相应单词转化为索引
15         document = [index[token] for token in lines[2].split()]
16         query = [index[token] for token in lines[4].split()]
17         answer = [index[token] for token in lines[6].split()]
18         #调用Example和Features函数将数据格式化保存起来。注意Features传入的参数应该是一个字典,方便后续读数据时的操作
19         example = tf.train.Example(
20            features = tf.train.Features(
21              feature = {
22                'document': tf.train.Feature(
23                  int64_list=tf.train.Int64List(value=document)),
24                'query': tf.train.Feature(
25                  int64_list=tf.train.Int64List(value=query)),
26                'answer': tf.train.Feature(
27                  int64_list=tf.train.Int64List(value=answer))
28                }))
29     #写数据
30       serialized = example.SerializeToString()
31       writer.write(serialized)
View Code

读取.tfrecords文件

因为在读取数据之后我们可能还会进行一些额外的操作,使我们的数据格式满足模型输入,所以这里会引入一些额外的函数来实现我们的目的。这里介绍几个个人感觉较重要常用的函数。不过还是推荐到官网API去查,或者有某种需求的时候到Stack Overflow上面搜一搜,一般都能找到满足自己需求的函数。 
1,string_input_producer( 
string_tensor, 
num_epochs=None, 
shuffle=True, 
seed=None, 
capacity=32, 
shared_name=None, 
name=None, 
cancel_op=None 
)
其输出是一个输入管道的队列,这里需要注意的参数是num_epochs和shuffle。对于每个epoch其会将所有的文件添加到文件队列当中,如果设置shuffle,则会对文件顺序进行打乱。其对文件进行均匀采样,而不会导致上下采样。

2,shuffle_batch( 
tensors, 
batch_size, 
capacity, 
min_after_dequeue, 
num_threads=1, 
seed=None, 
enqueue_many=False, 
shapes=None, 
allow_smaller_final_batch=False, 
shared_name=None, 
name=None 
)
产生随机打乱之后的batch数据

3,sparse_ops.serialize_sparse(sp_input, name=None): 返回一个字符串的3-vector(1-D的tensor),分别表示索引、值、shape

4,deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): 将多个稀疏的serialized_sparse合并成一个

 1 def read_records(index=0):
 2 #生成读取数据的队列,要指定epoches
 3   train_queue = tf.train.string_input_producer(['training.tfrecords'], num_epochs=FLAGS.epochs)
 4   validation_queue = tf.train.string_input_producer(['validation.tfrecords'], num_epochs=FLAGS.epochs)
 5   test_queue = tf.train.string_input_producer(['test.tfrecords'], num_epochs=FLAGS.epochs)
 6 
 7   queue = tf.QueueBase.from_list(index, [train_queue, validation_queue, test_queue])
 8   #定义一个recordreader对象,用于数据的读取
 9   reader = tf.TFRecordReader()
10   #从之前的队列中读取数据到serialized_example
11   _, serialized_example = reader.read(queue)
12   #调用parse_single_example函数解析数据
13   features = tf.parse_single_example(
14       serialized_example,
15       features={
16         'document': tf.VarLenFeature(tf.int64),
17         'query': tf.VarLenFeature(tf.int64),
18         'answer': tf.FixedLenFeature([], tf.int64)
19       })
20 
21 #返回索引、值、shape的三元组信息
22   document = sparse_ops.serialize_sparse(features['document'])
23   query = sparse_ops.serialize_sparse(features['query'])
24   answer = features['answer']
25 
26 #生成batch切分数据
27   document_batch_serialized, query_batch_serialized, answer_batch = tf.train.shuffle_batch(
28       [document, query, answer], batch_size=FLAGS.batch_size,
29       capacity=2000,
30       min_after_dequeue=1000)
31 
32   sparse_document_batch = sparse_ops.deserialize_many_sparse(document_batch_serialized, dtype=tf.int64)
33   sparse_query_batch = sparse_ops.deserialize_many_sparse(query_batch_serialized, dtype=tf.int64)
34 
35   document_batch = tf.sparse_tensor_to_dense(sparse_document_batch)
36   document_weights = tf.sparse_to_dense(sparse_document_batch.indices, sparse_document_batch.shape, 1)
37 
38   query_batch = tf.sparse_tensor_to_dense(sparse_query_batch)
39   query_weights = tf.sparse_to_dense(sparse_query_batch.indices, sparse_query_batch.shape, 1)
40 
41   return document_batch, document_weights, query_batch, query_weights, answer_batch
View Code

最后,我们要在模型开始训练之前,执行下面两行代码:

1 with tf.Session() as sess:
2   # Start populating the filename queue.
3   coord = tf.train.Coordinator()
4   threads = tf.train.start_queue_runners(coord=coord)
View Code

这是填充队列的指令,如果不执行程序会等在队列文件的读取处无法运行。至此,我们就可以使用TFRecords来读写文件了。最后总结一下,大概格式如下,这里并未指定某种读写函数,而是可以自定义的方式用的伪代码来说一下整个流程:

 1 def read_my_file_format(filename_queue):
 2   reader = tf.SomeReader()
 3   key, record_string = reader.read(filename_queue)
 4   example, label = tf.some_decoder(record_string)
 5   processed_example = some_processing(example)
 6   return processed_example, label
 7 
 8 def input_pipeline(filenames, batch_size, num_epochs=None):
 9   filename_queue = tf.train.string_input_producer(
10       filenames, num_epochs=num_epochs, shuffle=True)
11   example, label = read_my_file_format(filename_queue)
12   # min_after_dequeue defines how big a buffer we will randomly sample
13   #   from -- bigger means better shuffling but slower start up and more
14   #   memory used.
15   # capacity must be larger than min_after_dequeue and the amount larger
16   #   determines the maximum we will prefetch.  Recommendation:
17   #   min_after_dequeue + (num_threads + a small safety margin) * batch_size
18   min_after_dequeue = 10000
19   capacity = min_after_dequeue + 3 * batch_size
20   example_batch, label_batch = tf.train.shuffle_batch(
21       [example, label], batch_size=batch_size, capacity=capacity,
22       min_after_dequeue=min_after_dequeue)
23   return example_batch, label_batch
View Code

 

 

原文地址:https://www.cnblogs.com/gaofighting/p/9628918.html