Tensorflow v2 创建网络模型且保存参数至本地(非keras)

//20201030

写在前面:最近几天在在学习Tensorflow v2框架搭建网络,今天在这里做一下summary,主要简述一下搭建的大致流程以及需要的要素,最后就是如何存储以及读取存储恢复网络

1.导包

(此处因为做了可视化以及使用mnist当做数据集,所以使用了matplotlib和keras)

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras

2.数据准备

此处直接使用mnist数据集

(x_train,y_train),(x_test,y_test) = keras.datasets.mnist.load_data()
x_train,x_test = np.array(x_train,np.float32),np.array(x_test,np.float32)

x_train = x_train/255.0# 将数据缩小,减小计算量
x_test = x_test/255.0

training_data = tf.data.Dataset.from_tensor_slices((x_train,y_train))
training_data = training_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)# 此行的意思是将数据变成数据池,且打乱顺序,一次取出batch_size个数据,且在去除是准备下次所需数据————类似流

如果因为墙的原因下载不了数据集,可单独下载数据集,然后使用如下方法解析

数据集链接(使用迅雷下载会很快):https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz

解析方法:

import numpy as np

def load_data(path):
    data_set = np.load(path)
    file_name_list = data_set.files
    dict = {}
    for each in file_name_list:
        dict[each] = data_set[each]
    return dict

3.设置网络(此处使用卷积网络)所需参数

num_classes = 10# 类别数,因为此网络用于分类0-9十个数字,所以共有10个类别
num_features = 28*28# 特性数,文件集中图片为28*28灰度图像,所以特征数为28*28

learning_rate = 0.001# 学习率,网络将使用Adam优化器,此处设置学习率为0.001(可改)
training_step = 1000# 使用多少数据训练,此处设置1000,数据集总共有60000张图片(可自行打印shape查看)
display_step = 100# 手动打印损失函数与准确率频率参数
batch_size = 256# 数据预处理时使用
fc_utils = 1024# fully connection layer单元数

4.继承keras.Model重写网络模型

class ConvNet(keras.Model):
    def __init__(self):
        super(ConvNet,self).__init__()
        self.conv1 = keras.layers.Conv2D(32,kernel_size=5,activation='relu')
        self.mp1 = keras.layers.MaxPool2D(2,strides=2)
        self.conv2 = keras.layers.Conv2D(64,kernel_size = 3,activation='relu')
        self.mp2 = keras.layers.MaxPool2D(2,strides=2)
        self.flatten = keras.layers.Flatten()
        self.fc = keras.layers.Dense(fc_utils)
        self.dropout = keras.layers.Dropout(rate = 0.5)
        self.out = keras.layers.Dense(num_classes)

    def call(self,x,training = False):
        x =  tf.reshape(x,[-1,28,28,1])
        x = self.conv1(x)
        x = self.mp1(x)
        x = self.conv2(x)
        x = self.mp2(x)
        x = self.flatten(x)
        x = self.fc(x)
        x = self.dropout(x,training = training)
        x = self.out(x)
        if not training:
            x = tf.nn.softmax(x)
        return x

此处网络模型为 [ 卷积层_32个过滤器——5个核心---->池化层_2x2——步长为2----->卷积层_64个过滤器_3个核心----->池化层_2x2步长为2----->flatten层_将数据拉平----->拥有1024个单元的全连接层----->dropout层(提升网络稀疏性,提升速度以及准确性)----->out层(输出)------>(如果不是在训练而是在预测,需要加一个softmax层来输出预测值)]

5.定义交叉熵函数(用于计算损失函数)

def cross_entropy(y_pred,y_true):
    y_true = tf.cast(y_true,tf.int64)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true,logits=y_pred)
    return tf.reduce_mean(loss)

6.定义准确值函数(用于计算acc)

def accuracy(y_pred,y_true):
    correct_prediction = tf.equal(tf.argmax(y_pred,1),tf.cast(y_true,tf.int64))
    return tf.reduce_mean(tf.cast(correct_prediction,tf.float32),axis = -1)

7.定义运行优化器函数

optimizer = tf.optimizers.Adam(learning_rate)

def run_optimization(x,y):
    with tf.GradientTape() as g:
        pred = conv(x,True)
        loss = cross_entropy(pred,y)
        gradients = g.gradient(loss,conv.trainable_variables)
        optimizer.apply_gradients(zip(gradients,conv.trainable_variables))

8.开始优化

for steps,(batch_x,batch_y) in enumerate(training_data.take(training_step),1):
    run_optimization(batch_x,batch_y)
    if steps%display_step==0:
        pred = conv(batch_x,training=False)
        loss = cross_entropy(pred,batch_y)
        acc = accuracy(pred,batch_y)
        print("step:{}---->loss:{}---->acc:{}".format(steps,loss,acc))

输出如下图

 9.可视化

此处使用测试数据集中前25个数据进行可视化

test_data  = x_test[:25]
label = y_test[:25]
fig,ax = plt.subplots(5,5)
plt.subplots_adjust(wspace=1,hspace=1)
ax = ax.flatten()
pred = conv(test_data,False)

for i in range(25):
    ax[i].imshow(test_data[i],cmap='Greys')
    ax[i].set_title("pred:{},true:{}".format(np.argmax(pred[i]),label[i]))

plt.show()

可视化结果如下

10.存储模型权重

conv.save_weights('./tfmodel.ckpt')

路径可自定义,但必须是.ckpt文件

11.读取权重参数恢复网络

conv = ConvNet()# 定义一个空网络(此网络必须和恢复权重网络相同)
conv.load_weights('./tfmodel.ckpt')

ps:恢复权重参数后,网络就是训练后的状态,可以直接用于预测或者进一步训练

以上

希望对大家有所帮助

原文地址:https://www.cnblogs.com/lavender-pansy/p/13902383.html