【tensorflow】神经网络:数据集增强

数据增强可以帮助扩展数据集。对图像的增强,就是对图像的简单形变,用来应对因拍照角度不同而引起的图片变形。

 

数据增强函数

image_gen_train = tf.keras.preprocessing.image.ImageDataGenerator(
    # 调整输入特征大小,每个输入特征将乘以该参数
    rescale=1.0 / 255,  # 归一化

# 图片将在[-45°, 45°]范围内做随机旋转 rotation_range=45,
# 图片将在[-0.15, 0.15]范围内做随机左右偏移,大小保持不变 width_shift_range=0.15,
# 图片将在[-0.15, 0.15]范围内做随机上下偏移,大小保持不变 height_shift_range=0.15,
# 是否做水平翻转操作 horizontal_flip=False,
# 图片将做[0.75, 1.25]范围内做随机缩放,大小保持不变 zoom_range=0.25 ) # x_train 需要是4维数据 x_train = x_train.reshape(x_train.shape[0], 28, 28, 1) image_gen_train.fit(x_train)

 

代码:

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 加载输入特征和标签
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 归一化处理,减小计算量,方便神经网络吸收
x_train, x_test = x_train/255.0, x_test/255.0

# 数据集增强
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
image_gen_train = ImageDataGenerator(
    rotation_range=45,        # 随机旋转45°
    width_shift_range=0.15,   # 宽度偏移
    height_shift_range=0.15,  # 高度偏移
    horizontal_flip=False,    # 不水平翻转
    zoom_range=0.5            # 随机缩放
)
image_gen_train.fit(x_train)

# 声明网络结构
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax")
])

# 配置训练方法
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=[tf.keras.metrics.sparse_categorical_accuracy])

# 执行训练过程
model.fit(x_train, y_train,
          batch_size=32, epochs=5,
          validation_data=(x_test, y_test),
          validation_freq=1)

# 打印网络结构和参数
model.summary()
原文地址:https://www.cnblogs.com/bjxqmy/p/13536660.html