机器学习——TensorFlow训练Y=2*X

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt


def moving_average(a,w=10):
    if(len(a)<w):
        return a[:]
    return [val if idx <w else sum(a[idx-w:idx])/w for idx,val in enumerate(a)]
X_train=np.linspace(-1,1,100)
Y_train=2*X_train+np.random.randn(*X_train.shape)*0.3
# 训练数据

X=tf.placeholder('float')
Y=tf.placeholder('float')
# 占位符

W=tf.Variable(tf.random_normal([1]),name='weight')
b=tf.Variable(tf.zeros([1]),name='bias')
# 定义权重和偏置

z=tf.multiply(X,W)+b
# 定义前向结构

# 反向模型的搭建即反向优化
loss=tf.reduce_mean(tf.square(Y-z))

# 定义学习率:代表调整参数的速度,这个值一般小于一,这个值越大,表明调整幅度的速度越大,但不精确
# 这个值越小,调整幅度越小,但是速度慢
learn_rate=0.01

# 定义优化器:GridientDescentOptimizer梯度下降算法
optimizer=tf.train.GradientDescentOptimizer(learn_rate).minimize(loss)

# 迭代训练模型,初始化所有变量
init=tf.global_variables_initializer()

# 定义训练次数
training_epochs=20

# 定义显示信息
display_step=2

with tf.Session() as sess:
    sess.run(init)
    plot_data={'batch_size':[],'loss_value':[]}
    for epoch in range(training_epochs):
        for x,y in zip(X_train,Y_train):
            sess.run(optimizer,feed_dict={X:x,Y:y})
        if epoch % display_step == 0:
            loss_value = sess.run(loss, feed_dict={X:x, Y:y})
            print('Epoch:', epoch + 1, 'Loss=', loss_value, 'w=',sess.run(W),'b=',sess.run(b))
            if not (loss == 'NA'):
                plot_data['batch_size'].append(epoch)
                plot_data['loss_value'].append(loss_value)
    print('Finished!')
    # 可视化模型
    plt.plot(X_train, Y_train, 'ro', label='Origin data')
    plt.plot(X_train, sess.run(W) * X_train + sess.run(b),label='FittedLine')
    plt.legend()
    plt.show()
    plot_data['avgloss']=moving_average(plot_data['loss_value'])
    plt.figure(1)
    plt.subplot(211)
    plt.plot(plot_data['batch_size'],plot_data['avgloss'],'b',linewidth=1.5)
    plt.xlabel('Minbatch number')
    plt.ylabel('Loss')
    plt.title('Minibatch run vs Training loss')
    plt.show()





在这里插入图片描述

在这里插入图片描述

原文地址:https://www.cnblogs.com/hzcya1995/p/13309446.html