【tensorflow】神经网络:断点续训

断点续训,即在一次训练结束后,可以先将得到的最优训练参数保存起来,待到下次训练时,直接读取最优参数,在此基础上继续训练。

读取模型参数:

存储模型参数的文件格式为 ckpt(checkpoint)。

生成 ckpt 文件时,会同步生成索引表,所以可通过判断是否存在索引表来判断是否存在模型参数。

# 模型参数保存路径
checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt"  
if os.path.exists(checkpoint_save_path + ".index"): model.load_weights(checkpoint_save_path)

保存模型参数:

# 定义回调函数,在模型训练时,回调函数会被执行,完成保留参数操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(
  # 文件保存路径
  filepath=checkpoint_save_path,

  # 是否只保留模型参数
  save_weights_only=True,

  # 是否只保留最优结果
  save_best_only=True
)

# 执行训练过程,保存新的训练参数
history = model.fit(x_train, y_train,
            batch_size=32, epochs=5,
            validation_data=(x_test, y_test),
            validation_freq=1,
            callbacks=[cp_callback])

代码:

import tensorflow as tf
import os

# 读取输入特征和标签
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 数据归一化,减小计算量,方便神经网络吸收
x_train, x_test = x_train/255.0, x_test/255.0

# 声明网络结构
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax")
])

# 配置训练方法
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=[tf.keras.metrics.sparse_categorical_accuracy])

# 如果存在参数文件,直接读取,在此基础上继续训练
checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt"  # 模型参数保存路径
if os.path.exists(checkpoint_save_path + ".index"):
    model.load_weights(checkpoint_save_path)

# 定义回调函数,在模型训练时,完成保留参数操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

# 执行训练过程,保存新的训练参数
history = model.fit(x_train, y_train,
                    batch_size=32, epochs=5,
                    validation_data=(x_test, y_test),
                    validation_freq=1,
                    callbacks=[cp_callback])

# 打印网络结构和参数
model.summary()
原文地址:https://www.cnblogs.com/bjxqmy/p/13538364.html