【tensorflow】搭建_手写数字识别_神经网络模型:Sequential() / 神经网络类class 两种方法

MNIST 数据集一共有 7 万张图片,都是 28x28 像素点的 0~9 手写数字,其中 6 万用于训练,1 万张用于测试。

f.keras + Sequential() 详解

代码:

import tensorflow as tf

# 读入训练所需的输入特征和标签
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 输入特征归一化,减小计算量,方便神经网络吸收
x_train, x_test = x_train/255.0, x_test/255.0

# 搭建网络
model = tf.keras.models.Sequential([
    # 将输入特征(28x28)拉直为一维数组(1x748)
    tf.keras.layers.Flatten(),
    # 定义第一层网络,有128个神经元
    tf.keras.layers.Dense(128, activation="relu"),
    # 定义第二层网络,有10个神经元
    tf.keras.layers.Dense(10, activation="softmax")
])

# 配置训练方法
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=[tf.keras.metrics.sparse_categorical_accuracy])

# 执行训练过程
model.fit(x_train, y_train,
          batch_size=32, epochs=5,
          validation_data=(x_test, y_test),
          validation_freq=1)

# 打印出网络结构和参数统计
model.summary()

tf.keras + 神经网络类class 详解

代码:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model

# 读取训练用的输入特征和标签
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 输入特征归一化,减小计算量,方便神经网络吸收
x_train, x_test = x_train/255.0, x_test/255.0

# 定义神经网络类
class MnistModel(Model):
    def __init__(self):
        super(MnistModel, self).__init__()
        # 定义拉直层
        self.flatten = Flatten()
        # 定义第一层神经网络
        self.d1 = Dense(128, activation="relu")
        # 定义第二层神经网络
        self.d2 = Dense(10, activation="softmax")

    def call(self, x):
        # 将输入特征拉直成一维数组
        x = self.flatten(x)
        # 调用剩下两层神经网络,实现前向传播
        x = self.d1(x)
        y = self.d2(x)
        return y

# 声明神经网络对象
model = MnistModel()

# 配置训练方法
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=[tf.keras.metrics.sparse_categorical_accuracy])

# 执行训练过程
model.fit(x_train, y_train,
          batch_size=32, epochs=5,
          validation_data=(x_test, y_test),
          validation_freq=1)

# 打印网络结构和参数统计
model.summary()
原文地址:https://www.cnblogs.com/bjxqmy/p/13524576.html