机器学习——TensorFlow之数字体识别流程

import tensorflow as tf
# 导入mnist数据集
# 分析mnist样本特点以及定义变量
# 构建模型
# 训练模型并输出中间状态参数
# 测试模型
# 保存模型
# 读取模型


# 导入mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)

# 分析图片的特点,定义变量
x=tf.placeholder(tf.float32,shape=[None,784])
y=tf.placeholder(tf.float32,shape=[None,10])

# 构建模型
W=tf.Variable(tf.zeros([784,10]))

b=tf.Variable(tf.zeros([10]))

# z表示证据
z=tf.matmul(x,W)+b
# pred表示是每个数字的可能
pred=tf.nn.softmax(z)
# 损失函数,交叉熵,定义反向传播的结构
loss=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))

learn_rate=0.01

# 优化器,梯度下降法
optimizer=tf.train.GradientDescentOptimizer(learn_rate).minimize(loss)

# 训练次数
epochs=25

# 批次大小
batch_size=100

# 把中间具体信息显示出来
display_step=1

with tf.Session() as sess:
    # 初始化全局变量
    sess.run(tf.global_variables_initializer())
    # 开始训练
    for epoch in range(epochs):
        # 取值大小
        avg_loss=0
        total_loss=0
        total_batch=int(mnist.train.images.shape[0]/batch_size)
        for i in range(total_batch):
            # 从数据集中按照batch_size大小取值
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            # 运行优化器
            _,c=sess.run([optimizer,loss],feed_dict={x:batch_xs,y:batch_ys})
            # 计算损失值得平均值
            total_loss+=c
        avg_loss=total_loss/total_batch
        if((epoch+1)%display_step==0):
            print('Epoch:','%04d'%(epoch+1),'cost=','{:.9f}'.format(avg_loss))
    print('########################Finished!#############################
')
    # 测试模型
    print('########################Begin Test############################
')
    correct_predict=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
    accuracy=tf.reduce_mean(tf.cast(correct_predict,tf.float32))
    print('Accuracy:',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
    print('########################Save Model############################
')
    saver= tf.train.Saver()
    save_path='log/'
    saver.save(sess,save_path)
    print('saved Successfully at :',save_path)
# 保存模型


在这里插入图片描述
在这里插入图片描述

原文地址:https://www.cnblogs.com/hzcya1995/p/13309445.html