单变量线性回归问题(TensorFlow实战)

简单例子介绍Tensorflow实现机器学习的思路,重点步骤:

  1. 生成人工数据集及其可视化
  2. 构建线性模型
  3. 定义损失函数
  4. 定义优化器、最小损失函数
  5. 训练结果的可视化
  6. 利用学习到的模型进行预测
 1 import tensorflow as tf
 2 import numpy as np
 3 import matplotlib.pyplot as plt
 4 
 5 np.random.seed(5)
 6 # 采用np生成等差数列,范围在-1~1之间生成100个点
 7 x_data=np.linspace(-1,1,100)
 8 # y=2x+1+噪声,其中噪声的维度和x_data一致
 9 y_data=2*x_data+1.0+np.random.randn(*x_data.shape)*0.4
10 
11 # 画图 y=2x+1
12 plt.scatter(x_data,y_data)
13 plt.plot(x_data,2*x_data+1,color='red',linewidth=3)
14 # plt.show()
15 
16 # 定义训练数据的占位符,x是特征值,y是标签值
17 x=tf.placeholder("float",name="x")
18 y=tf.placeholder("float",name="y")
19 
20 # 定义模型函数
21 def model(x,w,b):
22     return tf.multiply(x,w)+b
23 
24 # tf.Variable的作用时保存和更新函数
25 w=tf.Variable(1.0)      #斜率
26 b=tf.Variable(0.0)      #截距
27 pred=model(x,w,b)       #预测值
28 
29 # 迭代次数和学习率
30 train_epochs=10
31 learning_rate=0.05
32 
33 # 损失函数用来描述预测值与真实值之间的误差,从而指导模型收敛方向,均方差MSE
34 loss=tf.reduce_mean(tf.square(y-pred))
35 # 定义优化器optimizer,初始化一个GradientDescentOptimizer,设置学习率和优化目标,最小值损失
36 optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
37 
38 sess=tf.Session()
39 init=tf.global_variables_initializer()
40 sess.run(init)
41 
42 # 模型训练阶段,设置迭代轮次,每次通过样本逐个输入模型,进行梯度下降优化操作
43 # 每次迭代,绘制出模型曲线
44 
45 # 开始训练,轮次为epoch,采用SGD随机梯度下降优化方法
46 step=0  #记录训练步数
47 loss_list=[]    #用于保存loss值的列表
48 
49 for epoch in range(train_epochs):
50     for xs,ys in zip(x_data,y_data):
51         _,losss=sess.run([optimizer,loss],feed_dict={x:xs,y:ys})
52     # 显示损失值loss
53     # display_step:控制报告的粒度
54     # 例如,如果display_step设为2,则将每训练2个样本输出一次损失值
55     # 与超参数不同,修改display_step 不会更改模型所学习的规律
56         loss_list.append(losss)
57         step=step+1
58         display_step=10
59         if step%display_step==0:
60             print("Train Epoch:",'%02d'%(epoch+1),"Step:%03d"%(step),"loss=","{:.9f}".format(losss))
61     b0temp=b.eval(session=sess)
62     w0temp=w.eval(session=sess)
63     plt.plot(x_data,w0temp*x_data+b0temp)
64 plt.plot(loss_list,'r+')
65 plt.show()
66 
67 # 训练完成后,打印查看参数
68 print("w:",sess.run(w))
69 print("b",sess.run(b))
70 
71 plt.scatter(x_data,y_data,label='Original data')
72 plt.plot(x_data,x_data*sess.run(w)+sess.run(b),label='Fitted line',color='r',linewidth=3)
73 plt.legend(loc=2)   #通过参数loc指定图例位置
74 # plt.show()
75 
76 x_test=3.21
77 predict=sess.run(pred,feed_dict={x:x_test})
78 print("预测值:%f"%predict)
79 
80 target=2*x_test+1.0
81 print("目标值%f"%target)
原文地址:https://www.cnblogs.com/hly97/p/12815086.html