TF利用分布式队列控制线程

假设分布式任务包含n个ps节点, m个worker节点. m, n>0. 希望所有worker的任务结束后,所有节点才终止。

  • 方法: 借助队列tf.FIFOQueue实现。
  • 原理: tf.FIFOQueue 是个全局的的队列, 出队函数dequeue有这个特点:
    If the queue is empty when this operation executes, it will block until there is an element to dequeue.
    利用这个性质, 设置ps服务器的停止条件:
    1. ps端执行m个出队列操作。 队列初始都是空队列, 因此,一开始出队操作都被阻塞。
    2. 每个worker完成任务后, 往ps的队列中放入一个元素,使得ps端的一个出队操作能执行完成。
  • 参考: https://github.com/hn826/distributed-tensorflow/blob/master/distributed-deep-mnist-with-queue.py

更新

  • 实际中, 可以定义全局变量, 通过判断全局变量状态控制终止条件。
* class GlobalStatus(object):
    def __init__(self):
        with tf.variable_scope("global_status", reuse=tf.AUTO_REUSE):
            self.status = tf.get_variable("status", (), trainable=False, 
                    dtype=tf.int32, initializer=tf.constant_initializer(0))
        self.send_op = self.status.assign(1)

    def change_status(self, sess):
        sess.run(self.send_op)

    def is_done(self, sess):
        z = sess.run(self.status)
        return z>0
原文地址:https://www.cnblogs.com/bregman/p/10735937.html