tensorflow 2.0 学习(三)MNIST训练

用tensorflow2.0 版回顾了一下mnist的学习

代码如下,感觉这个版本下的mnist学习更简洁,更方便

关于tensorflow的基础知识,这里就不更新了,用到什么就到网上搜索相关的知识

# encoding: utf-8

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

#加载下载好的mnist数据库 60000张训练 10000张测试 每一张维度(28,28)
path = r'G:2019pythonmnist.npz'
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
f.close()

#预处理输入数据
x = 2*tf.convert_to_tensor(x_train, dtype = tf.float32)/255. - 1
x = tf.reshape(x, [-1, 28*28])
y = tf.convert_to_tensor(y_train, dtype=tf.int32)
y = tf.one_hot(y, depth=10)

#第一层输入256, 第二次输出128, 第三层输出10
#第一,二,三层参数w,b
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))    #正态分布的一种
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))

#将60000组数据切分为600组,每组100个数据
train_db = tf.data.Dataset.from_tensor_slices((x, y)).batch(100)
lr = 0.001      #学习率
losses = []     #储存每epoch的loss值,便于观察学习情况

for epoch in range(20):
    #一次性处理100组(x, y)数据
    for step, (x, y) in enumerate(train_db):    #遍历切分好的数据step:0->599
        with tf.GradientTape() as tape:
            #向前传播第一,二,三层
            h1 = x@w1 + tf.broadcast_to(b1, [x.shape[0], 256])  #可以直接写成 +b1
            h1 = tf.nn.relu(h1)
            h2 = h1@w2 + b2
            h2 = tf.nn.relu(h2)
            out = h2@w3 + b3
            #计算mse
            loss = tf.square(y - out)
            loss = tf.reduce_mean(loss)
        #计算参数的梯度,tape.gradient为自动求导函数,loss为目标数据,目的使它越来越接近真实值
        grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
        #更新w,b
        w1.assign_sub(lr*grads[0])  #原地减去给定的值,实现参数的自我更新
        b1.assign_sub(lr*grads[1])
        w2.assign_sub(lr*grads[2])
        b2.assign_sub(lr*grads[3])
        w3.assign_sub(lr*grads[4])
        b3.assign_sub(lr*grads[5])
        #观察学习情况
        if step%500 == 0:
            print(epoch, step, 'loss:', float(loss))
    #将每epoch的loss情况储存起来,最后观察
    losses.append(float(loss))

plt.plot(losses, marker='s', label='training')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.legend()
plt.savefig('exam_mnist_forward.png') plt.show()

观察结果:

可由注释理解代码的含义!下一次更新mnist数据集训练的进阶!

原文地址:https://www.cnblogs.com/heze/p/12076792.html