tf.train.batch and tf.train.shuffle_batch

这俩方法都是从队列中批量获取元素,常用于样本的批量获取;

这俩 API 非常反人类,有些参数我还没搞懂,时间关系,先学习常规用法吧

batch

从队列中获取指定个数的元素组成一个 batch

def batch(tensors, batch_size, num_threads=1, capacity=32,
          enqueue_many=False, shapes=None, dynamic_pad=False,
          allow_smaller_final_batch=False, shared_name=None, name=None):
  """Creates batches of tensors in `tensors`."""

tensors:队列

batch_size:获取元素个数

capacity:队列容量    【没搞懂有啥用】

label = np.asarray(range(0, 100))
# label = tf.cast(label, tf.int32)
input_queue = tf.train.slice_input_producer([label], shuffle=False)
label_batch = tf.train.batch(input_queue, batch_size=19, num_threads=1, capacity=5)

with tf.Session() as sess:
    coord = tf.train.Coordinator()      # 线程的协调器
    threads = tf.train.start_queue_runners(sess, coord)     # 开始在图表中收集队列运行器
    for j in range(8):
        out = sess.run([label_batch])
        print(out)
    coord.request_stop()
    coord.join(threads)
# [array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18])]
# [array([19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37])]
# [array([38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56])]
# [array([57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75])]
# [array([76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94])]
# [array([95, 96, 97, 98, 99,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13])]
# [array([14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32])]
# [array([33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51])]

shuffle_batch

从队列中随机获取指定个数的元素组成一个 batch

def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
                  num_threads=1, seed=None, enqueue_many=False, shapes=None,
                  allow_smaller_final_batch=False, shared_name=None, name=None):
  """Creates batches by randomly shuffling tensors."""

capacity:队列容量,这个参数一定要比 min_after_dequeue 大

推荐值为

  • capacit(min_after_dequeu(num_threada small safety margi∗ batcize

min_after_dequeue当一次出列操作完成后,队列中元素的最小数量,往往用于定义元素的混合级别;

定义了随机取样的缓冲区大小,此参数越大表示 更大级别的混合 但是 会导致启动更加缓慢,并且会占用更多的内存

images = np.random.random([100,2])
label = np.asarray(range(0, 100))
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=10, num_threads=1, capacity=64, min_after_dequeue=10)

with tf.Session() as sess:
    coord = tf.train.Coordinator()      # 线程的协调器
    threads = tf.train.start_queue_runners(sess, coord)     # 开始在图表中收集队列运行器
    for _ in range(5):
        image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
        print(image_batch_v, label_batch_v)

    coord.request_stop()
    coord.join(threads)

注意,这俩 API 是持续获取数据的,也就是说可以在循环中重复执行,每次获取不同数据

参考资料:

https://blog.csdn.net/akadiao/article/details/79645221

https://blog.csdn.net/u013555719/article/details/77679964

原文地址:https://www.cnblogs.com/yanshw/p/12467753.html