使用预先训练好的模型来训练小训练集tf-keras实现

> //20201104

> 写在前面:最近在练手keras的相关项目,今天做了一个使用预训练模型训练小数据集的项目,在此记录总结一下

> ps:本门采用markdown语法,可用markdown文档编辑器打开

### 首先解释一下名词:

- 预训练模型(pre-trained model):在网络上别人(或者团队)预先花费很多时间在大数据集上训练并上传权重参数的网络(模型)

- 小数据集合(small data set):现实生活中,更多的时候并没有那么多数据(经过标注)用于训练,这个时候如何在小数据集上做出高效的预测就显得异常重要

> 需要注意的是:别人训练过的模型并不能直接拿来就重新训练的,使用他人训练好的的模型(此处拿CNN举例),一般使用网络中用于提取特征的层(layer),往往最后的全连接层以及密集层并不使用(因为最后几层所代表的信息可能只针对该网络当时训练的数据集——不具有普适性,而浅层的卷积层用于提取特征,具有更普遍的用途;另,如果是非常深的网络,高层的卷积层提取的特征会更抽象,往往在训练自己的小数据集也不使用——比如猫狗分类任务中,高层卷积层提取的特征可能是耳朵眼睛鼻子之类很具体的东西,而浅层则是提取颜色纹理之类普遍的特征,这样的预训练网络层用在大象分类任务中明显就不合适[举个栗子~])

### 本文使用的数据集是猫狗数据集,下载链接为(来源_kaggle):https://www.kaggle.com/c/dogs-vs-cats/data

- 本文有三种使用预训练模型的方法

  - 通过卷积基底提取特征然后保存为numpy矩阵喂给后续自定义层次

  - 将卷积基底层冻结,然后在其后拼接自定义层次

  - 微调:将高层基底层解冻,跟随训练集重新训练

### 第一种方法——特征提取并保存为numpy矩阵_缺点:不能使用图像增强,优点:快

#### 1.数据准备阶段

- 首先在目录下创建一个data文件夹,将下载的压缩包在此文件夹下解压,然后执行以下代码(代码目的是将源数据中随机选择需要训练数目的数据,并将其分类到新的目录)来准备数据

import os, shutil

# 專案的根目錄路徑
ROOT_DIR = os.getcwd()

# 置放coco圖像資料與標註資料的目錄
DATA_PATH = os.path.join(ROOT_DIR, "data")

# 原始數據集的路徑
original_dataset_dir = os.path.join(DATA_PATH, "train")

# 存儲小數據集的目錄
base_dir = os.path.join(DATA_PATH, "cats_and_dogs_small")
if not os.path.exists(base_dir): 
    os.mkdir(base_dir)

# 我們的訓練資料的目錄
train_dir = os.path.join(base_dir, 'train')
if not os.path.exists(train_dir): 
    os.mkdir(train_dir)

# 我們的驗證資料的目錄
validation_dir = os.path.join(base_dir, 'validation')
if not os.path.exists(validation_dir): 
    os.mkdir(validation_dir)

# 我們的測試資料的目錄
test_dir = os.path.join(base_dir, 'test')
if not os.path.exists(test_dir):
    os.mkdir(test_dir)    

# 貓的圖片的訓練資料目錄
train_cats_dir = os.path.join(train_dir, 'cats')
if not os.path.exists(train_cats_dir):
    os.mkdir(train_cats_dir)

# 狗的圖片的訓練資料目錄
train_dogs_dir = os.path.join(train_dir, 'dogs')
if not os.path.exists(train_dogs_dir):
    os.mkdir(train_dogs_dir)

# 貓的圖片的驗證資料目錄
validation_cats_dir = os.path.join(validation_dir, 'cats')
if not os.path.exists(validation_cats_dir):
    os.mkdir(validation_cats_dir)

# 狗的圖片的驗證資料目錄
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
if not os.path.exists(validation_dogs_dir):
    os.mkdir(validation_dogs_dir)

# 貓的圖片的測試資料目錄
test_cats_dir = os.path.join(test_dir, 'cats')
if not os.path.exists(test_cats_dir):
    os.mkdir(test_cats_dir)

# 狗的圖片的測試資料目錄
test_dogs_dir = os.path.join(test_dir, 'dogs')
if not os.path.exists(test_dogs_dir):
    os.mkdir(test_dogs_dir)
    
# 複製前1000個貓的圖片到train_cats_dir
fnames = ['cat.{}.jpg'.format(i) for i in range(1000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(train_cats_dir, fname)
    if not os.path.exists(dst):
        shutil.copyfile(src, dst)

print('Copy first 1000 cat images to train_cats_dir complete!')

# 複製下500個貓的圖片到validation_cats_dir
fnames = ['cat.{}.jpg'.format(i) for i in range(1000, 1500)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(validation_cats_dir, fname)
    if not os.path.exists(dst):
        shutil.copyfile(src, dst)

print('Copy next 500 cat images to validation_cats_dir complete!')

# 複製下500個貓的圖片到test_cats_dir
fnames = ['cat.{}.jpg'.format(i) for i in range(1500, 2000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(test_cats_dir, fname)
    if not os.path.exists(dst):
        shutil.copyfile(src, dst)

print('Copy next 500 cat images to test_cats_dir complete!')

# 複製前1000個狗的圖片到train_dogs_dir
fnames = ['dog.{}.jpg'.format(i) for i in range(1000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(train_dogs_dir, fname)
    if not os.path.exists(dst):
        shutil.copyfile(src, dst)

print('Copy first 1000 dog images to train_dogs_dir complete!')


# 複製下500個狗的圖片到validation_dogs_dir
fnames = ['dog.{}.jpg'.format(i) for i in range(1000, 1500)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(validation_dogs_dir, fname)
    if not os.path.exists(dst):
        shutil.copyfile(src, dst)

print('Copy next 500 dog images to validation_dogs_dir complete!')

# C複製下500個狗的圖片到test_dogs_dir
fnames = ['dog.{}.jpg'.format(i) for i in range(1500, 2000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(test_dogs_dir, fname)
    if not os.path.exists(dst):
        shutil.copyfile(src, dst)
    
print('Copy next 500 dog images to test_dogs_dir complete!')

#### 2.导入相应的包

import os
import tensorflow  as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from IPython.display import Image
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras import optimizers

#### 3.从网络上加载keras内置VGG16模型(需kexue上网)

conv_base = keras.applications.VGG16(
    weights='imagenet',
    include_top = False,# 这里告诉keras我们只需要卷积基底的参数
    input_shape=(150,150,3)
)

print(conv_base.summary())

#### 4.创建训练数据生成器并使用卷及基层提取特征

datagen = ImageDataGenerator(rescale=1./255)

batch_size = 20

def extract_features(directory,sample_count):
    features = np.zeros(shape=(sample_count,4,4,512))
    labels = np.zeros(shape=(sample_count))

    generator = datagen.flow_from_directory(
        directory,
        target_size=(150,150),
        batch_size=batch_size,
        class_mode='binary'# 因为此次训练目标是二分类问题
    )

    i = 0
    for inputs_batch,labels_batch in generator:
        features_batch = conv_base.predict(inputs_batch)# 此处将卷积层输出的矩阵保存至本地,让需要分类的图片通过已经训练过的卷积基底层
        features[i*batch_size:(i+1)*batch_size] = features_batch
        labels [i*batch_size:(i+1)*batch_size] = labels_batch
        i += 1
        if i*batch_size >= sample_count:
            break

    print('extract_features complete!')
    return features,labels

'''
如果执行过之前数据准备代码,则以下路径配置代码可以注释掉(属于重复代码)
''' base_dir
= 'data/cats_and_dogs_small' train_dir = os.path.join(base_dir, 'train') validation_dir = os.path.join(base_dir, 'validation') test_dir = os.path.join(base_dir, 'test') train_features,train_labels = extract_features(train_dir,2000) validation_features,validation_labels = extract_features(validation_dir,1000) test_features,test_labels = extract_features(test_dir,1000) train_features = np.reshape(train_features,(2000,4*4*512)) validation_features = np.reshape(validation_features,(1000,4*4*512)) test_features = np.reshape(test_features,(1000,4*4*512))

#### 5.使用keras序列模型创建卷积基底层后的层(用于输出)

model = models.Sequential([
    layers.Dense(256,activation='relu',input_dim=4*4*512),
    layers.Dropout(rate=0.5),
    layers.Dense(1,activation='sigmoid')
])

#### 6.编译&训练模型

model.compile(optimizer=optimizers.RMSprop(lr = 2e-5),
              loss = 'binary_crossentropy',
              metrics=['acc'])

history = model.fit(
    train_features,train_labels,
    epochs = 30,
    batch_size=20,
    validation_data=(validation_features,validation_labels)
)

#### 7.可视化(此处可视化使用子图方式将本文三个模块_第一种方法、第二种方法、微调、平滑后图像集合在一张最后的图上了,可以自行改为每运行一个模块展示一次图片(提示:如果没有gpu,后两个模块运行的会非常的慢))

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epoches = range(len(acc))

fig,ax = plt.subplots(2,4)
plt.subplots_adjust(wspace=1,hspace=1)
ax = ax.flatten()
ax[0].plot(epoches,acc,label='Training acc')
ax[0].plot(epoches,val_acc,label = 'Validation acc')
ax[0].set_title('Training and validation accuracy')
ax[0].legend()


ax[1].plot(epoches,loss,label='Training loss')
ax[1].plot(epoches,val_loss,label='Validation loss')
ax[1].set_title('Training and validation loss')
ax[1].legend()

### 第二种方法(网络拼接,统一训练_需要冻结基底层)_优点:可以使用图像增强,缺点:慢

#### 1.使用kera序列模型创建model:

model = models.Sequential([
    conv_base,
    layers.Flatten(),
    layers.Dense(256,activation='relu'),
    layers.Dense(1,activation='sigmoid')
])
#输出冻结之前需要训练的参数
print('This is the number of trainable weights before freezing the conv base:',len(model.trainable_weights))
# 输出冻结之后需要训练的参数
conv_base.trainable = False
print('this is the number of trainable wights after freezing the conv base',len(model.trainable_weights))

#### 2.定义训练、测试、交叉数据集&数据流

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=40,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)
# 测试数据集不用图像增强
test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(150,150),
    batch_size=20,
    class_mode='binary'
)

validation_genetator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(150,150),
    batch_size=20,
    class_mode='binary'
)

#### 3.编译&训练&保存模型

model.compile(optimizer=optimizers.RMSprop(lr = 2e-5),
              loss='binary_crossentropy',
              metrics=['acc'])

model.fit_generator(
    train_generator,
    steps_per_epoch=100,
    epochs=30,
    validation_data=validation_genetator,
    validation_steps=50,
    verbose=2
)
# 已在目录下创建一个weight目录用于保存权重
model.save('./weight/cats_anddogs_small_3.h5')

#### 4.可视化

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

ax[2].plot(epoches,acc,label='Training acc')
ax[2].plot(epoches,val_acc,label = 'Validation acc')
ax[2].set_title('Training and validation accuracy')
ax[2].legend()

ax[3].plot(epoches,loss,label='Training loss')
ax[3].plot(epoches,val_loss,label='Validation loss')
ax[3].set_title('Training and validation loss')
ax[3].legend()

### 微调(fine-tune)

- 在使用前两种方法训练完网络一次之后(保证自定义密集、输出层的参数误差不会很大),将高层卷积层(提取抽象特征)解冻,重新跟随训练集微调参数(使用小学习率来保证“微调”)

#### 1.选择并启动需要解冻的层次(使用层次名称作为索引)

conv_base.trainable = True

layers_frozen = ['block5_conv1','block5_conv2','block5_conv3','block5_pool']# go get the layer will be frozen
for layer in conv_base.layers:
    if layer.name in layers_frozen:
        layer.trainable = True
    else:
        layer.trainable = False
for layer in conv_base.layers:
    print("{}:{}".format(layer.name,layer.trainable))

#### 2.编译&训练模型(使用第二种方法训练后的模型重新对数据集进行训练)

model.compile(loss='binart_crossentropy',
              optimizer=optimizers.RMSprop(lr=1e-5),
              metrics=['acc'])

history = model.fit_generator(
    train_generator,
    steps_per_epoch=100,
    epochs = 100,
    validation_data = validation_genetator,
    validation_steps = 50
)

model.save('cats_and_dogs_small_4.h5')

#### 3.可视化

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

ax[4].plot(epoches,acc,label='Training acc')
ax[4].plot(epoches,val_acc,label = 'Validation acc')
ax[4].set_title('Training and validation accuracy')
ax[4].legend()

ax[5].plot(epoches,loss,label='Training loss')
ax[5].plot(epoches,val_loss,label='Validation loss')
ax[5].set_title('Training and validation loss')
ax[5].legend()

#### 4.步骤3中可视化图像汇成锯齿状,此处使用一个平滑方法来是图像更整洁(原理没有搞懂,公式是0.8previous+0.2point,来源于github项目

def smooth_curve(points,factor = 0.8):
    smoothed_points = []
    for point in points:
        if smoothed_points:
            previous = smoothed_points[-1]
            smoothed_points.append(previous*factor + point*(1-factor))
        else:
            smoothed_points.append(point)
    return smoothed_points

ax[6].plot(epoches,smooth_curve(acc),label='Training acc')
ax[6].plot(epoches,smooth_curve(val_acc),label = 'Validation acc')
ax[6].set_title('Training and validation accuracy')
ax[6].legend()

ax[7].plot(epoches,smooth_curve(loss),label='Training loss')
ax[7].plot(epoches,smooth_curve(val_loss),label='Validation loss')
ax[7].set_title('Training and validation loss')
ax[7].legend()

plt.show()

#### 5.输出测试数据集准确率

test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='binary')

test_loss, test_acc = model.evaluate_generator(test_generator, steps=50)
print('test acc:', test_acc)

### 整个项目图像汇总:

另:由于图像尚未保存,明天更新加上

以上

希望对大家有所帮助

原文地址:https://www.cnblogs.com/lavender-pansy/p/13928983.html