第十五节 分布式系统

import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("job_name", " ", "启动服务的类型ps or worker")
tf.app.flags.DEFINE_integer("task_index", 0, "指定ps或者worker当中的哪一台服务器以task:0,task:1")

def main(argv):
    # 定义一个全局计数的op,给钩子列表中的训练步数使用
    global_step = tf.contrib.framework.get_or_create_global_step()

    # 指定集群描述对象,ps worker,多台worker或者ps的定位规则,第一台:/job:worker/task:0,第二台:/job:worker/task:1,ps也是如此
    cluster = tf.train.ClusterSpec({"ps":["192.168.0.4:2222",], "worker":["192.168.109.128:2323",]})

    # 创建不同的服务 ps worker,job_name指定是ps还是worker,task_index,指定启动哪台服务器
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

    # 根据不同的服务器做不同的事情,ps保存参数,worker指定设备运行模型计算
    if FLAGS.job_name == 'ps':
        # 参数服务器只需接受参数
        server.join()
    else:
        worker_device = "/job:worker/task:0/cpu:0/"
        # 指定设备去运行
        with tf.device(tf.train.replica_device_setter(worker_device=worker_device, cluster=cluster)):
            # 演示一个矩阵乘法运算
            x = tf.Variable([[1, 2, 3, 4]])
            w = tf.Variable([[2], [4], [5], [7]])
            mat = tf.matmul(x, w)

        # 创建分布式会话
        with tf.train.MonitoredTrainingSession(
                master="grpc://192.168.0.1:2222",  # 指定是否是主work
                is_chief=(FLAGS.task_index==0),  # 判断书否是主worker
                config=tf.ConfigProto(log_device_placement =True),  # 打印设备信息
                hooks=[tf.train.StopAtStepHook(last_step=1000)]  # 指定训练步数,指定步数需要定义一个全局计数的op
        ) as mon_sess:
            while not mon_sess.should_stop():
                # should_stops是否异常停止
                mon_sess.run(mat)

if __name__ == "__main__":
    tf.app.run()
原文地址:https://www.cnblogs.com/kogmaw/p/12602483.html