【tensorflow】使用_Fashion数据集_搭建神经网络模型:Sequential() / 神经网络类class 两种方法

FASHION 数据集一共有 7 万张图片,每张图片都是 28x28 像素点的灰度值数据,其中 6 万张用于训练,1 万张用于测试。

一共有 10 个分类:

0 T恤

1 裤子

2 帽头衫

3 连衣裙

4 外套

5 凉鞋

6 衬衫

7 运动鞋

8 包

9 靴子

f.keras + Sequential() 详解

代码:

import tensorflow as tf

# 读取训练用的输入特征和标签
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()

# 输入特征归一化,减小计算量,方便神经网络吸收
x_train, x_test = x_train/255.0, x_test/255.0

# 声明网络结构
model = tf.keras.models.Sequential([
    # 拉直层
    tf.keras.layers.Flatten(),
    # 两层全连接层
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax")
])

# 配置训练方法(优化器,损失函数,评测方法)
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=[tf.keras.metrics.sparse_categorical_accuracy])

# 执行训练过程
model.fit(x_train, y_train,
          batch_size=32, epochs=5,
          validation_data=(x_test, y_test),
          validation_freq=1)

# 打印网络结构和参数统计
model.summary()

tf.keras + 神经网络类class 详解

代码:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model

# 读取训练用的输入特征和标签
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()

# 输入特征归一化,减小计算量,方便神经网络吸收
x_train, x_test = x_train/255.0, x_test/255.0

# 定义神经网络类
class FashionModel(Model):
    # 定义网络结构
    def __init__(self):
        super(FashionModel, self).__init__()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation="relu")
        self.d2 = Dense(10, activation="softmax")

    # 调用网络结构,实现前向传播
    def call(self, inputs, training=None, mask=None):
        x = self.flatten(inputs)
        x = self.d1(x)
        y = self.d2(x)
        return y

# 声明神经网络对象
model = FashionModel()

# 配置训练方法(优化器,损失函数,评测指标)
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=[tf.keras.metrics.sparse_categorical_accuracy])

# 执行训练过程
model.fit(x_train, y_train,
          batch_size=32, epochs=5,
          validation_data=(x_test, y_test),
          validation_freq=1)

# 打印网络结构和参数
model.summary()
原文地址:https://www.cnblogs.com/bjxqmy/p/13527573.html