tensorflow 加载预训练模型进行 finetune 的操作解析

这是一篇需要仔细思考的博客;

预训练模型

tensorflow 在 1.0 之后移除了 models 模块,这个模块实现了很多模型,并提供了部分预训练模型的权重;

图像识别模型的权重下载地址  https://github.com/tensorflow/models/tree/master/research/slim

模型加载

首先需要了解模型保存的形式,包含了 checkpoint、data、meta 等文件;

模型加载不仅可以从 data 加载训练好的权重,还可以从 meta 加载计算图,

加载计算图我们可以理解为引入了 计算节点和变量,引入变量很重要,这样我们无需自己去创造变量,

加载计算图返回的是个 Saver 对象,如果没有通过 加载图引入变量,也没有自己创造变量,是无法创建 Saver 对象的,没有 Saver 对象,就无法加载预训练权重;

下面我用代码解释上面的逻辑;

首先做个数据准备:写个最简单的计算图,然后保存

with tf.name_scope('scope1'):
    w1 = tf.Variable(1, name='w1')
    v1 = tf.Variable(2, name='v1')

with tf.name_scope('scope2'):
    w2 = tf.Variable(3, name='w2')
    v2 = tf.Variable(4, name='v2')

out = tf.add(w1*v1, w2*v2)
init = tf.global_variables_initializer()

saver = tf.train.Saver()

sess = tf.Session()
sess.run(init)
print(sess.run(out))
saver.save(sess, 'data/test.ckpt')

记住这里 w1 的值为 1,后面有用

加载预训练权重有两种思路

1. 先加载计算图,再加载权重

2. 自己创建计算图,再加载权重

总结一句话就是 先有变量(创建图就必须创建变量),然后创建 Saver 对象,通过 Saver 对象加载权重

先加载计算图,再加载权重

加载计算图用下面的方法

def import_meta_graph(meta_graph_or_file,
                      clear_devices=False,
                      import_scope=None,
                      **kwargs):
  """Recreates a Graph saved in a `MetaGraphDef` proto."""

加载计算图后,通过 sess.graph 获取图,通过 get_tensor_by_name 等方法获取保存的变量

简单举例

## 加载了图中的各个节点,相当于引入变量
saver = tf.train.import_meta_graph('data/test.ckpt.meta')       # 从 meta 文件直接加载图,返回一个存储器
print(type(saver))      # <class 'tensorflow.python.training.saver.Saver'>

### 如果没有变量,直接创建存储器,会报错的
# saver = tf.train.Saver()      ### 报错 ValueError: No variables to save

sess = tf.Session()
saver.restore(sess, 'data/test.ckpt')   ### 获取图权重
# print(sess.run('w1'))           ### 直接这样报错

graph = sess.graph  ### 获取加载的图
print(sess.run(graph.get_tensor_by_name('scope1/w1:0')))        # 1     ### 获取图的 Tensor
# print(sess.run(graph.get_tensor_by_name('w1:0')))       ### 报错 KeyError: "The name 'w1:0' refers to a Tensor which does not exist. The operation, 'w1', does not exist in the graph."
sess.close()

自己创建计算图,再加载权重

注意,自己创建的计算图要与保存的计算图一致,包括 网络结构、作用域、变量名等

##### 上面是直接加载图,引入变量,从而创建存储器
##### 这里我们不加载,自己创建一个图,从而创建存储器
### 注意,图的结构要与 checkpoint 中的一致,作用域、变量名 都要一样
with tf.name_scope('scope1'):
    w1 = tf.Variable(333, name='w1')
    v1 = tf.Variable(2, name='v1')

with tf.name_scope('scope2'):
    w2 = tf.Variable(3, name='w2')
    v2 = tf.Variable(4, name='v2')

init = tf.global_variables_initializer()

saver = tf.train.Saver()

sess = tf.Session()
sess.run(init)
saver.restore(sess, 'data/test.ckpt')

graph = tf.get_default_graph()      ### 获取默认图
print(sess.run(graph.get_tensor_by_name('scope1/w1:0')))        # 1

这里我们创建计算图是把 w1 赋值 333,而加载的 w1 仍然是保存时的 1,说明加载成功了

加载预训练权重进行 finetune

finetune 我会单独写一篇博客,这里不细讲,它大致可以分两种:

1. 加载部分权重,其他权重正常初始化,然后优化所有参数;

2. 加载部分权重,其他权重正常初始化,然后优化非加载的参数,相当于固定部分权重,优化另一部分权重;

需要加载的权重 肯定通过 预训练模型获取,而其他权重则既可以通过预训练的模型获取,也可以自己创建,这点在上一章已经讲清楚了,下面我们为了方便,就直接加载计算图了;

固定部分权重是个难点,它的思路有两点:

1. 先加载这部分权重,然后把这部分权重经过前向计算,得到一个新的 Input_new,然后把这个 Input 作为后续网络的输入,反向传播时传到 Input 肯定就停止了;

2. 分别加载需要训练的参数 train_var 和不需要训练的参数 fixed_var,然后在 优化器的 optimizer.minimize 方法中指定 var_list 为 train_var

def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP, aggregation_method=None,
               colocate_gradients_with_ops=False, name=None,
               grad_loss=None):
    """Add operations to minimize `loss` by updating `var_list`.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`.  Defaults to the list of variables collected in
        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`."""

创造计算图,先加载 fixed_var,再添加新的网络层

demo 如下:只是伪代码哦

tf.reset_default_graph()        ### 这句暂时可忽略

# 构建计算图
images = tf.placeholder(tf.float32,(None,224,224,3))
with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):
    logits, endpoints = mobilenet_v2.mobilenet(images,depth_multiplier=1.4)

with tf.variable_scope("finetune_layers"):
    mobilenet_tensor = tf.get_default_graph().get_tensor_by_name("MobilenetV2/expanded_conv_14/output:0")       # 获取目标张量,取出mobilenet中指定层的张量

    # 将张量作为新的 Input 向新层传递
    x = tf.layers.Conv2D(filters=256,kernel_size=3,name="conv2d_1")(mobilenet_tensor)
    x = tf.nn.relu(x,name="relu_1")
    x = tf.layers.Conv2D(filters=256,kernel_size=3,name="conv2d_2")(x)
    x = tf.layers.Conv2D(10,3,name="conv2d_3")(x)
    predictions = tf.reshape(x, (-1,10))

分别获取 train_var 和 fixed_var,在 minimize 中指定 var_list

demo 如下:只是伪代码哦

#### 不重要的函数
def get_var_list(target_tensor=None):
    '''获取指定变量列表 var_list 的函数;
       具体怎么干的,无需关心,只需要知道它的作用是获取一批权重
    '''
    if target_tensor==None:
        target_tensor = r"MobilenetV2/expanded_conv_14/output:0"
    target = target_tensor.split("/")[1]
    all_list = []
    all_var = []

    for var in tf.global_variables():
        if var != []:
            all_list.append(var.name)
            all_var.append(var)
    try:
        all_list = list(map(lambda x:x.split("/")[1],all_list))
        # 查找对应变量作用域的索引
        ind = all_list[::-1].index(target)
        ind = len(all_list) -  ind - 1
        print(ind)
        del all_list
        return all_var[:ind+1]
    except:
        print("target_tensor is not exist!")
        
#### 下面这一堆仔细看
x_train = np.random.random(size=(141,224,224,3))
y_train = to_categorical(label_fake,10)

y_label = tf.placeholder(tf.int32, (None,10))

### 收集变量作用域 finetune_layers 内的可训练变量,作为 train_var
train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="finetune_layers")

loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_label,logits=logits)
### 定义优化方法,用 var_list 指定需要更新的权重,此时仅更新 train_var 权重
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss,var_list=train_var)

epochs = 10
batch_size = 16

# 目标张量名称,要获取变量列表 fixed_var
target_tensor = "MobilenetV2/expanded_conv_14/output:0"
fixed_var = get_var_list(target_tensor)
saver = tf.train.Saver(var_list=fixed_var)

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    writer = tf.summary.FileWriter(r"./logs", sess.graph)
    ## 初始化 train_var, 使用初始化指定函数
    sess.run(tf.variables_initializer(var_list=train_var))
    saver.restore(sess,tf.train.latest_checkpoint("./model_ckpt/mobilenet_v2"))

    for i in range(2000):
        start = (i*batch_size) % x_train.shape[0]
        end = min(start+batch_size, x_train.shape[0])
        _, merge, losses = sess.run([train_step,merge_all,loss], feed_dict={images:x_train[start:end], y_label:y_train[start:end]})
        if i%100==0: writer.add_summary(merge, i)

重点就下面 3 句

train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss,var_list=train_var)
sess.run(tf.variables_initializer(var_list=train_var))    
saver.restore(sess,tf.train.latest_checkpoint("./model_ckpt/mobilenet_v2"))

深入理解一下:

1. 如果初始化了 train_var 后,又加载了所有预训练变量,也就是说 train_var 的初始值也是 预训练的,而不是平常的 全0、全1、高斯分布等,这是可以的;

2. 如果先 加载所有预训练变量,然后初始化 train_var,也是可以的,因为 fixed_var 没有重新初始化,还是 预训练的值,而 train_var 初始值是多少没那么重要;

3. 如果初始化 train_var,加载 fixed_var,则谁先谁后无所谓;

4. 如果是 先用 tf.global_variables_initializer() 初始化全部参数,再加载全部预训练参数,也是可以的;

5. 如果先加载全部预训练参数,在用 tf.global_variables_initializer() 初始化全部参数,是不可以的,因为 fixed_var 也被初始化了;

上述代码采用的是第 3 种,指定了只加载 fixed_var

saver = tf.train.Saver(var_list=fixed_var)

参考资料:

https://zhuanlan.zhihu.com/p/42183653

原文地址:https://www.cnblogs.com/yanshw/p/12432595.html