P79 自实现一个线性回归

https://www.bilibili.com/video/BV184411Q7Ng?p=79

 

 

 

 

 

 

 

 代码示例:

"""
写一个线性回归
运用分布式计算机集群进行训练
即:运用多台计算机的cpu和gpu进行参数保存、模型训练和预测
"""
import tensorflow as tf

"""
定义一个命令行参数
"""
FLAGS=tf.app.flags.FLAGS

tf.app.flags.DEFINE_string("job_name","",'启动服务的类型:ps or worker')

tf.app.flags.DEFINE_integer("task_index",0,'指定ps或者worker当中的哪一台服务器以task:0,task:1')


def main(argv):
    """
    定义全局技术的op,给钩子列表当中的训练步数使用
    """
    global_step = tf.train.get_or_create_global_step()



    """
    指定集群描述对象,哪些用于ps(parameter server),哪些用于worker server
    ps是参数服务器,后面是ip地址和端口号,端口号随便写的数字,
    worker 是计算服务器,后面跟的是电脑的网络ip地址和端口号,端口号随便写的数字
    返回一个集群对象
    """
    cluster=tf.train.ClusterSpec({"ps":["192.168.1.100:2223"], "worker":["192.168.1.101:2222"]})
    """
    创建不同的服务,ps,worker
    返回一个服务
    """
    server=tf.train.Server(cluster,job_name=FLAGS.job_name,task_index=FLAGS.task_index)

    """
    根据不同的服务做不同的事情
    ps:去更新保存参数
    worker:指定设备去运行模型计算
    """

    if FLAGS.job_name=="ps":
        """
        参数服务器什么都不用做,只需要等待worker传递参数就行
        """
        server.join()
    else:
        """
        可以指定设备去运行
        可以用cpu或者gpu设备去运行训练的工作
        在指定设备的上下文环境进行运算训练,上下文环境是指:with...
        """
        worker_device="/job:worker/task:0/cpu:0/"
        with tf.device(tf.train.replica_device_setter(
            worker_device=worker_device,
            cluster=cluster
        )):
            """
            简单做一个矩阵乘法的运算
            x是一个1行4列的矩阵
            w是一个4行1列的矩阵
            """
            x=tf.Variable([[1,2,3,4]])
            w=tf.Variable([[2],[2],[2],[2]])
            mat=tf.matmul(x,w)
            """
            创建分布式会话
            """
            with tf.train.MonitoredTrainingSession(
                master="grpc://192.168.1.101:2222",#指定主worker
                is_chief=(FLAGS.task_index==0),#判断是否是主worker
                config=tf.ConfigProto(log_device_placement=True),#打印设备信息
                hooks=[tf.train.StopAtStepHook(last_step=200)]#指定运行200次
            ) as mon_sess:
                while not mon_sess.should_stop():
                    mon_sess.run(mat)














if __name__=="__main__":
    """
    tf.app.run() 默认调用main()函数
    """
    tf.app.run()

运行结果:

 注解:

  • 这段程序的结果是在dell工作站上运行出来的。
  • 开启参数服务器:.....2333,这个代表的是用于更新和保存参数的联想笔记本工作注解:

 注解:

  • 初始化工作服务器:.....2222,这个代表的是用于更新和保存参数的联想笔记本工作站。
原文地址:https://www.cnblogs.com/yibeimingyue/p/14197970.html