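Splitting data into batches in TensorFlow: a tutorial with examples

1. Introduction

A dataset can be split into batches with the functions in tf.train or with the methods of tf.data.Dataset.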

1.1 tf.train

tf.train.slice_input_producer(tensor_list,shuffle=True,seed=None,capacity=32)

tf.train.batch(tensors,batch_size,num_threads=1,capacity=32,allow_smaller_final_batch=False)

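Parameters:

shuffle: when True, the slices are randomly shuffled within each epoch

allow_smaller_final_batch: when True, the final batch may contain fewer than batch_size elements if too few remain in the queue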

-------------------------------------------------------------------------------------------------------------------------

1.2 tf.data.Dataset

tf.data.Dataset is a class that provides the following methods:

from_tensor_slices(tensors)

batch(batch_size,drop_remainder=False)

shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)

repeat(count=None)

make_one_shot_iterator() / get_next()

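Note: make_one_shot_iterator() creates an iterator over the Dataset, and get_next() returns the tensors that yield its next element.

Parameters:

tensors: may be a list, dict, tuple, etc.

drop_remainder: when False (the default), a final batch with fewer than batch_size elements is kept; when True, it is dropped

buffer_size: the size of the buffer used for shuffling

count: the number of times the dataset is repeated, i.e. the number of epochs; None repeats the sequence indefinitely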
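The effect of drop_remainder and count is easy to see on a small list. The following is a plain-Python sketch of the semantics only, not TensorFlow code: the helpers batch and repeat are hypothetical stand-ins written for this illustration, and they take ordinary Python iterables rather than Dataset objects.

```python
from itertools import islice


def batch(items, batch_size, drop_remainder=False):
    """Group consecutive items into lists of length batch_size (illustrative helper)."""
    it = iter(items)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        if len(chunk) < batch_size and drop_remainder:
            return  # discard the short final batch
        yield chunk


def repeat(batches, count=None):
    """Replay the batches count times, or forever when count is None (illustrative helper)."""
    batches = list(batches)
    if not batches:
        return
    epoch = 0
    while count is None or epoch < count:
        yield from batches
        epoch += 1


labels = list(range(15))

kept = list(batch(labels, 6))                          # short final batch kept
dropped = list(batch(labels, 6, drop_remainder=True))  # short final batch dropped
two_epochs = list(repeat(batch(labels, 6), count=2))   # 2 epochs of 3 batches each
first_four = list(islice(repeat(batch(labels, 6)), 4)) # endless stream, first 4 batches

print(kept[-1])        # [12, 13, 14]
print(len(dropped))    # 2
print(len(two_epochs)) # 6
print(first_four[3])   # [0, 1, 2, 3, 4, 5]
```

With 15 elements and batch_size=6, keeping the remainder gives 3 batches per epoch and dropping it gives 2.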

2. Examples

2.1 Using tf.train.slice_input_producer and tf.train.batch

# Written against the TensorFlow 1.x graph API (tf.Session, queue runners).
import math

import numpy as np
import tensorflow as tf


# Generate a small sample dataset: 15 random "images" with labels 0..14.
def generate_data():
    num = 15
    labels = np.asarray(range(num))
    images = np.random.random([num, 5, 5, 3])
    return images, labels


# Print the shapes of the sample data.
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))

# Number of epochs, batch size, dataset size, and the number of
# iterations needed to go through the whole dataset once.
n_epochs = 3
batch_size = 6
train_nums = len(labels)
iterations = math.ceil(train_nums / batch_size)

# slice_input_producer puts the individual samples into a queue;
# tf.train.batch groups the queued samples into batches.
input_queue = tf.train.slice_input_producer([images, labels], shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=1, capacity=32)
print('image_batch.shape={0}, label_batch.shape={1}'.format(image_batch.shape, label_batch.shape))

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # Start the queue-runner threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # Fetch and print one batch per iteration.
    for epoch in range(n_epochs):
        for iteration in range(iterations):
            cu_image_batch, cu_label_batch = sess.run([image_batch, label_batch])
            print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1, iteration+1, cu_label_batch))
    # Ask the threads to stop and wait for them to finish.
    coord.request_stop()
    coord.join(threads)

The output is:

images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14  0  1  2]
The 2 epoch, the 1 iteration, current batch is [3 4 5 6 7 8]
The 2 epoch, the 2 iteration, current batch is [ 9 10 11 12 13 14]
The 2 epoch, the 3 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 1 iteration, current batch is [ 6  7  8  9 10 11]
The 3 epoch, the 2 iteration, current batch is [12 13 14  0  1  2]
The 3 epoch, the 3 iteration, current batch is [3 4 5 6 7 8]
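The third batch above wraps around to [12 13 14 0 1 2] because the queue keeps cycling through the data and every batch always takes exactly batch_size elements. The wrap-around can be reproduced with a plain-Python model. This is only a sketch of the behaviour seen in the output, not how TensorFlow implements its queues; queue_batches is a hypothetical helper name.

```python
from itertools import cycle, islice


def queue_batches(items, batch_size, num_batches):
    """Take fixed-size batches from an endlessly cycling stream (illustrative helper)."""
    stream = cycle(items)
    return [list(islice(stream, batch_size)) for _ in range(num_batches)]


labels = list(range(15))
batches = queue_batches(labels, batch_size=6, num_batches=9)
for b in batches:
    print(b)

# The third batch crosses the end of the data:
# batches[2] == [12, 13, 14, 0, 1, 2]
```

Since 15 is not a multiple of 6, the batch boundaries drift relative to the data, and the pattern only repeats every 5 batches (30 elements, the least common multiple of 15 and 6).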

With tf.train.slice_input_producer(..., shuffle=True), the samples come out in random order:

images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [ 2  5  8 11  3 10]
The 1 epoch, the 2 iteration, current batch is [ 9 12  7  1 14 13]
The 1 epoch, the 3 iteration, current batch is [0 6 4 2 3 6]
The 2 epoch, the 1 iteration, current batch is [11 10 12 14 13  5]
The 2 epoch, the 2 iteration, current batch is [8 1 0 9 4 7]
The 2 epoch, the 3 iteration, current batch is [10 13  1  4 12  3]
The 3 epoch, the 1 iteration, current batch is [ 2  8  5  9 14  7]
The 3 epoch, the 2 iteration, current batch is [ 0 11  6  1 14  9]
The 3 epoch, the 3 iteration, current batch is [11  6 12  7  0 13]
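In the shuffled output, the third batch contains the label 6 twice. That is consistent with the producer emitting one fresh permutation of the data per pass, while batches are still cut every batch_size elements regardless of where a pass ends. A plain-Python sketch under that assumption follows; shuffled_epochs and take_batches are hypothetical helpers, and the random order will not match TensorFlow's.

```python
import random
from itertools import islice


def shuffled_epochs(items, seed=None):
    """Endless stream: one fresh random permutation of items per pass (illustrative helper)."""
    rng = random.Random(seed)
    while True:
        epoch = list(items)
        rng.shuffle(epoch)
        yield from epoch


def take_batches(stream, batch_size, num_batches):
    """Cut the stream into consecutive fixed-size batches (illustrative helper)."""
    return [list(islice(stream, batch_size)) for _ in range(num_batches)]


labels = list(range(15))
batches = take_batches(shuffled_epochs(labels, seed=0), batch_size=6, num_batches=5)
flat = [x for b in batches for x in b]

# Every run of 15 consecutive labels is a complete permutation ...
print(sorted(flat[:15]) == labels)    # True
print(sorted(flat[15:30]) == labels)  # True
# ... but the third batch mixes the end of one pass with the start of the
# next, so the same label may appear in it twice.
print(batches[2])
```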

With tf.train.batch(..., allow_smaller_final_batch=True), a final batch with fewer than batch_size elements is returned. The result is as follows:

images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]

2.2 Using the tf.data.Dataset class

# Written against the TensorFlow 1.x graph API (tf.Session, make_one_shot_iterator).
import math

import numpy as np
import tensorflow as tf


# Generate a small sample dataset: 15 random "images" with labels 0..14.
def generate_data():
    num = 15
    labels = np.asarray(range(num))
    images = np.random.random([num, 5, 5, 3])
    return images, labels


# Print the shapes of the sample data.
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))

# Number of epochs, batch size, dataset size, and the number of
# iterations needed to go through the whole dataset once.
n_epochs = 3
batch_size = 6
train_nums = len(labels)
iterations = math.ceil(train_nums / batch_size)

# Build a dataset of (image, label) pairs, split it into batches,
# and repeat the batched sequence indefinitely.
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.batch(batch_size).repeat()

# Create an iterator and the op that fetches the next batch.
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(n_epochs):
        for iteration in range(iterations):
            cu_image_batch, cu_label_batch = sess.run(next_batch)
            print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1, iteration+1, cu_label_batch))

The output is:

images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]

To shuffle the data, change the line that builds the batched dataset to dataset = dataset.shuffle(100).batch(batch_size).repeat(). The output becomes:

images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [ 7  4 10  8  3 11]
The 1 epoch, the 2 iteration, current batch is [ 0  2 12 13 14  5]
The 1 epoch, the 3 iteration, current batch is [6 9 1]
The 2 epoch, the 1 iteration, current batch is [ 6 14  7  9  3  8]
The 2 epoch, the 2 iteration, current batch is [13  5 12  1 11  2]
The 2 epoch, the 3 iteration, current batch is [ 0  4 10]
The 3 epoch, the 1 iteration, current batch is [10  8 13 12  3 14]
The 3 epoch, the 2 iteration, current batch is [ 6  9  2  5  1 11]
The 3 epoch, the 3 iteration, current batch is [0 4 7]
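The buffer_size argument controls how thorough the shuffle is. The idea is that a buffer is filled with the first buffer_size elements, one element is picked from it at random, and the freed slot is refilled with the next input element. Below is a plain-Python sketch of that idea. It models the documented behaviour and is not TensorFlow's implementation; buffered_shuffle is a hypothetical helper name.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Shuffle a stream using a fixed-size buffer (illustrative helper)."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: emit what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


labels = list(range(15))

full = list(buffered_shuffle(labels, buffer_size=100, seed=0))  # buffer >= dataset
local = list(buffered_shuffle(labels, buffer_size=3, seed=0))   # small buffer
none = list(buffered_shuffle(labels, buffer_size=1, seed=0))    # no shuffling at all

print(sorted(full) == labels)  # True: a permutation of the whole dataset
print(local[0] in (0, 1, 2))   # True: the first output comes from the first 3 inputs
print(none == labels)          # True: buffer_size=1 keeps the original order
```

A buffer at least as large as the dataset, such as the 100 used above for 15 samples, shuffles the whole dataset uniformly; a smaller buffer only mixes elements that are close together in the input.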

Original post: https://www.cnblogs.com/jfl-xx/p/9945967.html