『TensorFlow』第十弹_队列&多线程_道路多坎坷

一、基本队列:

队列有两个基本操作,对应在tf中就是enqueue&dequeue

 tf.FIFOQueue(2,'int32')

import tensorflow as tf

'''FIFO队列操作'''

# 创建队列
# 队列有两个int32的元素
q = tf.FIFOQueue(2,'int32')
# 初始化队列
init= q.enqueue_many(([0,10],))
# 出队
x = q.dequeue()
y = x + 1
# 入队
q_inc = q.enqueue([y])

with tf.Session() as sess:
    init.run()
    for _ in range(5):
        v,_ = sess.run([x,q_inc])
        print(v)

 tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes='float')

'''随机队列操作'''

# 最大长度10,最小长度2,类型float的随机队列
q = tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes='float')

sess = tf.Session()
for i in range(0,10):
    sess.run(q.enqueue(i))
for i in range(0,8): # 在输出8次后会被阻塞
    print(sess.run(q.dequeue()))

#run_option = tf.RunOptions(timeout_in_ms = 10000) # 等待时间10s
#for i in range(0,7): # 在输出8次后会被阻塞
#    # 超时报错继续,不会退出
#    try:
#        print(sess.run(q.dequeue(),options=run_option))
#    except tf.errors.DeadlineExceededError:
#        print('out of range')

print('-----'*5)

二、队列管理:

tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*2)

'''队列管理器'''

# 队列管理器使用线程管理队列

q = tf.FIFOQueue(1000,'float')
counter = tf.Variable(0.0)  # 计数器
increment_op = tf.assign_add(counter, tf.constant(1.0))   # 计数器加一
enqueue_op = q.enqueue(counter)                           # 入队

# 线程面向队列q,启动2个线程,每个线程中是[in,en]两个操作
qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*2)


sess.run(tf.global_variables_initializer())
enqueue_threads = qr.create_threads(sess,start=True)      # 启动入队线程
for i in range(10):
    print(sess.run(q.dequeue()))
    # 由于主线程和入队线程异步,所以输出不是自然数序列

出队操作还有Queu.dequeue_many(batch_size),如果入队时采用enqueue([image, label]),则可以实现队列数据参与训练。

tf.train.Coordinator()

'''协调器'''

q = tf.FIFOQueue(1000,'float')
counter = tf.Variable(0.0)  # 计数器
increment_op = tf.assign_add(counter, tf.constant(1.0))   # 计数器加一
enqueue_op = q.enqueue(counter)                           # 入队

# 线程面向队列q,启动2个线程,每个线程中是[in,en]两个操作
qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*2)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

coord = tf.train.Coordinator()
# 线程管理器启动线程,接收协调器管理
enqueue_thread = qr.create_threads(sess,coord=coord,start=True)

for i in range(0,10):
    print(sess.run(q.dequeue()))

coord.request_stop()            # 向各个线程发终止信号
coord.join(enqueue_thread)      # 等待各个线程成功结束
原文地址:https://www.cnblogs.com/hellcat/p/6941367.html