MNIST 手写数字识别——神经网络实现(TensorFlow 框架)

MNIST 手写数字识别——神经网络实现(TensorFlow 框架)

这段代码可以直接复制运行。先贴出代码,有空时再补充注释和解析。这是我大三时写的代码,很适合新手入门;不过现在大家基本都改用 PyTorch 了。

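运行环境为 TensorFlow 1.13.1,用 CPU 跑也挺快。之前用 GPU 训练了约半小时,准确率能达到 98% 左右。注意:代码使用了 `tf.placeholder`、`tf.Session` 和 `tensorflow.examples.tutorials`,这些都是 TensorFlow 1.x 的接口,在 TensorFlow 2.x 中无法直接运行。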

代码

# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot
import matplotlib.pyplot as plt
import numpy as np

# Load MNIST with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Seed both random number generators. numpy drives getBatch's shuffling
# (np.random), while TensorFlow draws the initial weights (tf.random_normal
# in xavier_init). Seeding numpy alone left the weight initialisation
# unseeded. The graph-level TF seed must be set before any random op is
# created, which is why it is set here, ahead of the model definition.
seed = 547
np.random.seed(seed)
tf.set_random_seed(seed)

epoch_time = 20   # number of training epochs per experiment
ALPHY = 0.5       # learning rate for GradientDescentOptimizer
batch_size = 10   # examples per SGD step

# SGD steps per epoch: over the whole training set, and over the
# 1000-example subset used by train_1000.
n_batch_all = mnist.train.num_examples // batch_size
n_batch = 1000 // batch_size

# Inputs: 784-value image vectors and 10-way one-hot labels.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

def xavier_init(size):
    """Return a random-normal tensor used to initialise a weight matrix.

    ``size`` is the weight shape; its first entry is treated as the fan-in.
    Samples have standard deviation ``1 / sqrt(fan_in / 2)``, i.e.
    ``sqrt(2 / fan_in)``. That is the He/Kaiming scale rather than the
    classic Glorot/Xavier one, despite the function's name.
    """
    return tf.random_normal(shape=size, stddev=1. / tf.sqrt(size[0] / 2.))

# Hidden layer: 784 inputs -> 30 sigmoid units.
W1 = tf.Variable(xavier_init([784, 30]))
B1 = tf.Variable(tf.zeros([30]))
L1 = tf.nn.sigmoid(tf.matmul(x, W1) + B1)

# Output layer: 30 hidden units -> 10 logits. `prediction` applies a sigmoid
# to each logit independently. It is only used through argmax below, and
# because sigmoid is monotonic that picks the same class as the raw logits.
W2 = tf.Variable(xavier_init([30, 10]))
B2 = tf.Variable(tf.zeros([10]))
logit_prediction = tf.matmul(L1, W2) + B2
prediction = tf.nn.sigmoid(logit_prediction)

# Alternative that was tried: MSE loss on the sigmoid outputs
# loss = tf.reduce_mean(tf.square(y - prediction))

# Cross-entropy loss, computed as an independent sigmoid cross-entropy per
# output unit. The op returns one value per (example, class), i.e. a
# [batch, 10] tensor. Reduce it explicitly to a scalar. reduce_sum is what
# minimize() already did implicitly with the unreduced tensor: gradients of
# a non-scalar target are taken with respect to its sum. Training and the
# effective learning rate are therefore unchanged, but `loss` can now be
# fetched and logged as a single number.
loss = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=logit_prediction, labels=y))

train_setup = tf.train.GradientDescentOptimizer(ALPHY).minimize(loss)

init = tf.global_variables_initializer()

# Accuracy: fraction of rows whose highest-scoring class matches the label.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

def getBatch(inputs, size=10, n_features=784, n_classes=10):
    """Sample a random mini-batch from a packed feature/label matrix.

    Args:
        inputs: 2-D numpy array with one example per row. Columns
            ``[0, n_features)`` hold the features, and the next
            ``n_classes`` columns hold the one-hot label.
        size: number of distinct rows to sample. Defaults to 10, the value
            that used to be hard-coded; pass ``batch_size`` to keep the two
            in sync.
        n_features: number of leading feature columns (784 pixels here).
        n_classes: number of label columns that follow the features.

    Returns:
        ``(batch_x, batch_y)`` with ``min(size, len(inputs))`` rows each.

    ``inputs`` is not modified. Only a row-index permutation is drawn, so the
    full matrix is not shuffled in place on every call.
    """
    # A prefix of a random permutation is a uniform sample without replacement.
    rows = np.random.permutation(len(inputs))[:size]
    batch = inputs[rows]
    return batch[:, :n_features], batch[:, n_features:n_features + n_classes]

def draw(train, text, filename='accuracy.jpg', dpi=900):
    """Plot the per-epoch test accuracy of both training runs and save it.

    Args:
        train: per-epoch test accuracies of the run trained on the
            1000-example subset (legend ``train_1000``).
        text: per-epoch test accuracies of the run trained on the full
            training set (legend ``train_all``). Must have the same length
            as ``train``.
        filename: path of the saved image.
        dpi: resolution passed to ``savefig``.

    A new figure is created and closed after saving. Repeated calls therefore
    do not draw on top of a previous plot or leave figures open in pyplot.
    """
    # One x position per recorded epoch, labelled "0", "1", ...
    epochs = range(len(train))
    labels = [str(e) for e in epochs]

    fig = plt.figure()
    try:
        plt.plot(epochs, train, marker='o', mec='r', mfc='w', label='train_1000')
        plt.plot(epochs, text, marker='*', ms=10, label='train_all')
        plt.legend()
        plt.xticks(epochs, labels, rotation=1)
        plt.margins(0)
        plt.subplots_adjust(bottom=0.10)
        plt.xlabel('epoch')
        plt.ylabel("accuracy")
        plt.yticks([0, 0.5, 1])
        plt.savefig(filename, dpi=dpi)
    finally:
        plt.close(fig)

def train_1000():
    """Train on a fixed 1000-example subset, tracking test accuracy per epoch.

    Re-initialises all variables and takes 1000 examples from
    ``mnist.train`` once. Then, for each of ``epoch_time`` epochs, it runs
    ``n_batch`` SGD steps on random mini-batches drawn from that subset by
    ``getBatch``, followed by an evaluation on the full test set.

    Returns:
        numpy float32 array of length ``epoch_time`` holding the test
        accuracy after each epoch.
    """
    sess.run(init)
    # Pack the subset and allocate the result buffer in numpy. The previous
    # tf.concat / tf.zeros + eval added new ops to the default graph on every
    # call and baked the 1000x794 subset into the graph as a constant.
    X_mb, Y_mb = mnist.train.next_batch(1000)
    # Cast the labels to float32 so the packed matrix stays float32
    # (assumes the images are already float32, as tf.concat required).
    inputs = np.concatenate([X_mb, Y_mb.astype(np.float32)], axis=1)
    train = np.zeros(epoch_time, dtype=np.float32)
    for epoch in range(epoch_time):
        for _ in range(n_batch):
            fina_x, fina_y = getBatch(inputs)
            sess.run(train_setup, feed_dict={x: fina_x, y: fina_y})
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        train[epoch] = acc
        print("Iter" + str(epoch) + ", Testing Accuracy=" + str(acc))
    return train

def train_all():
    """Train on the full training set, tracking test accuracy per epoch.

    Re-initialises all variables. Then, for each of ``epoch_time`` epochs,
    it runs ``n_batch_all`` SGD steps on ``batch_size`` mini-batches taken
    from ``mnist.train``, followed by an evaluation on the full test set.

    Returns:
        numpy float32 array of length ``epoch_time`` holding the test
        accuracy after each epoch.
    """
    sess.run(init)
    # Plain numpy buffer: tf.zeros(...).eval added a graph op on every call.
    text = np.zeros(epoch_time, dtype=np.float32)
    for epoch in range(epoch_time):
        for _ in range(n_batch_all):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_setup, feed_dict={x: batch_xs, y: batch_ys})
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        text[epoch] = acc
        print("Iter" + str(epoch) + ", Testing Accuracy=" + str(acc))
    # Return the curve so the caller can plot it. Without this return the
    # function yielded None and draw() received no data for 'train_all'.
    return text

with tf.Session() as sess:
    # Run the subset experiment first and the full-data experiment second
    # (tuple items are evaluated left to right), then plot both curves.
    # `sess` is the session the training functions use.
    p1, p2 = train_1000(), train_all()
    draw(p1, p2)

结果(前 20 行来自 train_1000,即只用 1000 个训练样本训练;后 20 行来自 train_all,即使用全部训练集训练)

Iter0, Testing Accuracy=0.0982 
Iter1, Testing Accuracy=0.2913 
Iter2, Testing Accuracy=0.2973 
Iter3, Testing Accuracy=0.3493 
Iter4, Testing Accuracy=0.4311 
Iter5, Testing Accuracy=0.3789 
Iter6, Testing Accuracy=0.49   
Iter7, Testing Accuracy=0.4547 
Iter8, Testing Accuracy=0.4079 
Iter9, Testing Accuracy=0.4748 
Iter10, Testing Accuracy=0.564 
Iter11, Testing Accuracy=0.5026
Iter12, Testing Accuracy=0.6053
Iter13, Testing Accuracy=0.6379
Iter14, Testing Accuracy=0.5863
Iter15, Testing Accuracy=0.6443
Iter16, Testing Accuracy=0.6487
Iter17, Testing Accuracy=0.5809
Iter18, Testing Accuracy=0.6616
Iter19, Testing Accuracy=0.6465
Iter0, Testing Accuracy=0.7625 
Iter1, Testing Accuracy=0.864  
Iter2, Testing Accuracy=0.8596 
Iter3, Testing Accuracy=0.8694 
Iter4, Testing Accuracy=0.9028 
Iter5, Testing Accuracy=0.9046 
Iter6, Testing Accuracy=0.902  
Iter7, Testing Accuracy=0.9021 
Iter8, Testing Accuracy=0.8874 
Iter9, Testing Accuracy=0.9192 
Iter10, Testing Accuracy=0.9175
Iter11, Testing Accuracy=0.9226
Iter12, Testing Accuracy=0.9233
Iter13, Testing Accuracy=0.9156
Iter14, Testing Accuracy=0.93  
Iter15, Testing Accuracy=0.9251
Iter16, Testing Accuracy=0.9232
Iter17, Testing Accuracy=0.9176
Iter18, Testing Accuracy=0.9287
Iter19, Testing Accuracy=0.9273
原文地址:https://www.cnblogs.com/lwp-nicol/p/15262544.html