MNIST手写字母识别(二)

增加隐藏层

  1 import tensorflow as tf
  2 import matplotlib.pyplot as plt
  3 import numpy as np
  4 import tensorflow.examples.tutorials.mnist.input_data as input_data
  5 mnist=input_data.read_data_sets("MNIST_data",one_hot=True)
  6 import os
  7 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  8 
  9 print('训练集 train 数量:',mnist.train.num_examples,
 10       ',验证集 validation 数量:',mnist.validation.num_examples,
 11       ',测试集 test 数量:',mnist.test.num_examples)
 12 
 13 # print('train images shape:',mnist.train.images.shape,
 14 #       'labels shaple:',mnist.train.labels.shape)
 15 
 16 # print((mnist.train.images[0].reshape(28,28)))
 17 # print(len(mnist.train.images[0].shape))
 18 
 19 # def plot_image(image):
 20 #     plt.imshow(image.reshape(28,28))
 21 #     plt.show()
 22 #
 23 # plot_image(mnist.train.images[1])
 24 # plt.imshow(mnist.train.images[20000].reshape(14,56))
 25 # plt.show()
 26 
 27 # print(mnist.train.labels[1])
 28 # print(np.argmax(mnist.train.labels[1]))
 29 # mnist_no_one_hot=input_data.read_data_sets("MNIST_data",one_hot=False)
 30 # print(mnist_no_one_hot.train.labels[0:10])
 31 #
 32 # print('validation images:',mnist.validation.images.shape,'labels:',mnist.validation.labels.shape)
 33 #
 34 # print('test images:',mnist.test.images.shape,'labels:',mnist.test.labels.shape)
 35 #
 36 # batch_images_xs,batch_labels_ys=mnist.train.next_batch(batch_size=10)
 37 # print(mnist.train.labels[0:10])
 38 # # print(batch_labels_ys)
 39 
 40 # mnist中每张图片共有28*28=784个像素点
 41 x=tf.placeholder(tf.float32,[None,784])
 42 # 0-9一共10个数字->10个类别
 43 y=tf.placeholder(tf.float32,[None,10])
 44 
 45 # 定义模型变量(以正态分布的随机数初始化权重W,以常数0初始化偏置b)
 46 W=tf.Variable(tf.random_normal([784,10],mean=0.0,stddev=1.0))
 47 b=tf.Variable(tf.zeros([10]))
 48 
 49 # 前向计算
 50 forward=tf.matmul(x,W)+b
 51 #softmax分类
 52 pred=tf.nn.softmax(forward)
 53 # 定义交叉熵损失函数
 54 loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
 55 
 56 # 设置训练参数
 57 train_epochs=150 # 训练轮数
 58 batch_size=50 # 单次训练样本数(批次大小)
 59 total_batch=int(mnist.train.num_examples/batch_size) # 一轮训练的批次数
 60 display_step=1 # 显示粒度
 61 learning_rate=0.04 # 学习率
 62 
 63 # 分类模型构建与训练实践
 64 #选择优化器,梯度下降
 65 optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
 66 
 67 # 定义准确率,检查预测类别tf.argmax(pred,1)与实际类别tf.argmax(y,1)的匹配情况
 68 correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
 69 # 准确率,将布尔值转化为浮点数,并计算平均值
 70 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
 71 
 72 # 声明会话
 73 sess=tf.Session()
 74 init=tf.global_variables_initializer()
 75 sess.run(init)
 76 
 77 # 训练模型
 78 for epoch in range(train_epochs):
 79       for batch in range(total_batch):
 80             xs, ys = mnist.train.next_batch(batch_size)  # 读取批次数据
 81             sess.run(optimizer, feed_dict={x: xs, y: ys})  # 执行批次训练
 82 
 83       # total_batch批次训练完成之后,使用验证数据计算误差与准确率,验证集没有分批。
 84       loss, acc = sess.run([loss_function, accuracy],feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
 85 
 86       # 打印训练过程中的详细信息
 87       if (epoch + 1)% display_step==0:
 88             print("train_epoch:", '%02d' % (epoch + 1), "loss=", "{:.9f}".format(loss),"accuracy=", '{:.4f}'.format(acc))
 89 print("train finished!")
 90 
 91 # 在测试集上评估模型准确率
 92 accu_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
 93 print("test accuracy:",accu_test)
 94 
 95 # 在验证集上评估模型准确率
 96 accu_validation=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
 97 print("validatin accuracy:",accu_validation)
 98 
 99 # 在训练集上评估模型准确率
100 accu_train=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
101 print("tarin accuracy:",accu_train)
102 
103 # 由于pred预测结果是one-hot编码格式,所以需要转化为0~9数字
104 prediction_result=sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})
105 
106 # 查看结果中的前十项
107 prediction_result[0:10]
108 # 定义可视化函数
109 
110 def plt_images_labels_prediction(images,  # 图像列表
111                                   labels,  # 标签列表
112                                   prediction,  # 预测值列表
113                                   index,  # 从第index个开始显示
114                                   num=10):  # 缺省依次显示10副
115       fig = plt.gcf()  # 获取当前图表,get current figure
116       fig.set_size_inches(10, 12)  # 1英寸等于2.54cm
117       if num > 25:
118             num = 25  # 最多显示25个子图
119       for i in range(0, num):
120             ax = plt.subplot(5, 5, i + 1)  # 获取当前要处理的子图
121 
122             ax.imshow(np.reshape(images[index], (28, 28)),
123                       cmap='binary')  # 显示第index个图像
124             title = "labels=" + str(np.argmax(labels[index]))  # 构建该图上要显示的title信息
125             if len(prediction) > 0:
126                   title += ",predict=" + str(prediction[index])
127 
128             ax.set_title(title)  # 显示图上的title
129             ax.set_xticks([])  # 不显示坐标轴
130             ax.set_yticks([])
131             index += 1
132       plt.show()
133 # 可视化预测结果
134 plt_images_labels_prediction(mnist.test.images,
135                              mnist.test.labels,
136                              prediction_result,10,25)
原文地址:https://www.cnblogs.com/hly97/p/12871079.html