(Continued) Using MindSpore Hub to Load Models for Inference or Transfer Learning

Continuing from the previous post:

https://www.cnblogs.com/devilmaycry812839668/p/15005959.html

==========================================================

In the previous post, after freezing the parameters of the low-level feature-extraction layers and training only the final fully connected layer, we obtained the following test result:

59 epoch,      metric:
 {'Loss': 0.7698193432237858, 'Top1-Acc': 0.7275641025641025, 'Top5-Acc': 0.9834735576923077}

What if, when training the network, we do not freeze the parameters of the low-level network downloaded from MindSpore Hub, but instead train them together with our own fully connected layer? What does the final result look like?

Training

The code is given below:

import os
import mindspore_hub as mshub
import mindspore
from mindspore import context, Tensor, nn
from mindspore.nn import Momentum
from mindspore.train.serialization import save_checkpoint, load_checkpoint,load_param_into_net
from mindspore import ops
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C
from mindspore import dtype as mstype
from mindspore import Model


# Define the new network structure
class ReduceMeanFlatten(nn.Cell):
    def __init__(self):
        super(ReduceMeanFlatten, self).__init__()
        self.mean = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.mean(x, (2, 3))
        x = self.flatten(x)
        return x
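
# Note: ReduceMeanFlatten applies global average pooling over the spatial
# dimensions, turning the backbone's (N, 1280, H, W) feature map into
# (N, 1280, 1, 1), and then flattens it to (N, 1280) for the Dense head below.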


# Generate the per-step learning rate schedule
def generate_steps_lr(lr_init, steps_per_epoch, total_epochs):
    total_steps = total_epochs * steps_per_epoch
    decay_epoch_index = [0.3*total_steps, 0.6*total_steps, 0.8*total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < decay_epoch_index[0]:
            lr = lr_init
        elif i < decay_epoch_index[1]:
            lr = lr_init * 0.1
        elif i < decay_epoch_index[2]:
            lr = lr_init * 0.01
        else:
            lr = lr_init * 0.001
        lr_each_step.append(lr)
    return lr_each_step
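
# For example, with lr_init=0.01 and 100 total steps, this schedule holds 0.01
# for steps 0-29, drops to 0.001 for steps 30-59, to 0.0001 for steps 60-79,
# and to 0.00001 thereafter (decay points at 30%, 60% and 80% of total steps).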


# Create the dataset
def create_cifar10dataset(dataset_path, batch_size, do_train):
    if do_train:
        usage, shuffle = "train", True
    else:
        usage, shuffle = "test", False

    data_set = ds.Cifar10Dataset(dataset_dir=dataset_path, usage=usage, shuffle=shuffle)

    # define map operations
    trans = [C.Resize((256, 256))]
    if do_train:
        trans += [
            C.RandomHorizontalFlip(prob=0.5),
        ]

    trans += [
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set


# Create Dataset
dataset_path = "datasets/cifar-10-batches-bin/train"
dataset = create_cifar10dataset(dataset_path, batch_size=32, do_train=True)

# Build the overall network
model = "mindspore/ascend/1.0/mobilenetv2_v1.0_openimage"
network = mshub.load(model, num_classes=500, include_top=False, activation="Sigmoid")
network.set_train(True)

# Per the MindSpore Hub model page, the backbone's last feature dimension is 1280.
last_channel = 1280
# The number of classes in the target task is 10.
num_classes = 10

reduce_mean_flatten = ReduceMeanFlatten()

classification_layer = nn.Dense(last_channel, num_classes)
classification_layer.set_train(True)
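
# (Optional) Resume the classification head from a checkpoint produced by the
# previous post's frozen-backbone run; kept commented out here so that the
# head starts from random initialization: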

"""
ckpt_path = "./cifar10_finetune_epoch59.ckpt"
trained_ckpt = load_checkpoint(ckpt_path)
load_param_into_net(classification_layer, trained_ckpt)
"""

train_network = nn.SequentialCell([network, reduce_mean_flatten, classification_layer])


# Training settings
# Set epoch size
epoch_size = 60

# Wrap the backbone network with loss.
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
loss_net = nn.WithLossCell(train_network, loss_fn)


# Create an optimizer.
steps_per_epoch = dataset.get_dataset_size()
lr = generate_steps_lr(lr_init=0.01, steps_per_epoch=steps_per_epoch, total_epochs=epoch_size)
# optim = Momentum(filter(lambda x: x.requires_grad, classification_layer.get_parameters()), Tensor(lr, mindspore.float32), 0.9, 4e-5)
optim = Momentum(filter(lambda x: x.requires_grad, train_network.get_parameters()), Tensor(lr, mindspore.float32), 0.9, 4e-5)
# Build the training network
train_net = nn.TrainOneStepCell(loss_net, optim)
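# TrainOneStepCell bundles the loss network with the optimizer: each call runs
# one forward and backward pass on a batch and applies the parameter update,
# returning the loss value.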


# for name, para in optim.parameters_and_names():
#     print(name)


for epoch in range(epoch_size):
    for i, items in enumerate(dataset):
        data, label = items
        data = mindspore.Tensor(data)
        label = mindspore.Tensor(label)

        loss = train_net(data, label)
        print(f"epoch: {epoch}/{epoch_size}, loss: {loss}")
    # Save the ckpt file for each epoch.
    if not os.path.exists('ckpt'):
        os.mkdir('ckpt')
    ckpt_path = f"./ckpt/cifar10_finetune_epoch{epoch}.ckpt"
    save_checkpoint(train_network, ckpt_path)


"""
dataset_path = "datasets/cifar-10-batches-bin/test"
# Define loss and create model.
eval_dataset = create_cifar10dataset(dataset_path, batch_size=32, do_train=False)
eval_metrics = {'Loss': nn.Loss(),
                 'Top1-Acc': nn.Top1CategoricalAccuracy(),
                 'Top5-Acc': nn.Top5CategoricalAccuracy()}
model = Model(train_network, loss_fn=loss_fn, optimizer=None, metrics=eval_metrics)

for i in range(1):
    metrics = model.eval(eval_dataset)
    print("{} epoch, 	 metric: 
".format(i), metrics)

# 0 epoch,      metric:
#  {'Loss': 0.7694976472128661, 'Top1-Acc': 0.7278645833333334, 'Top5-Acc': 0.9834735576923077}


# 0 epoch,      metric: 
#  {'Loss': 0.7698339091088527, 'Top1-Acc': 0.7276642628205128, 'Top5-Acc': 0.9834735576923077}
"""

The only differences from the previous code:

Make the low-level network trainable:

# Build the overall network
model = "mindspore/ascend/1.0/mobilenetv2_v1.0_openimage"
network = mshub.load(model, num_classes=500, include_top=False, activation="Sigmoid")
network.set_train(True)

Have the optimizer update the low-level network parameters as well:

# optim = Momentum(filter(lambda x: x.requires_grad, classification_layer.get_parameters()), Tensor(lr, mindspore.float32), 0.9, 4e-5)
optim = Momentum(filter(lambda x: x.requires_grad, train_network.get_parameters()), Tensor(lr, mindspore.float32), 0.9, 4e-5)
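
For contrast, here is a minimal sketch of the frozen-backbone setup used in the previous post, reusing the names network, classification_layer and lr defined above (the requires_grad flag on mindspore.Parameter is assumed as the freezing mechanism; the commented-out optimizer line corresponds to this variant):

# Freeze the hub backbone and optimize only the classification head.
network.set_train(False)
for param in network.get_parameters():
    param.requires_grad = False  # exclude backbone weights from updates

optim = Momentum(filter(lambda x: x.requires_grad, classification_layer.get_parameters()),
                 Tensor(lr, mindspore.float32), 0.9, 4e-5)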

Finally, for testing, as in the previous code, we load the checkpoint saved after each of the 60 training epochs into the network and evaluate it. Note that since the entire network was trained this time, the 60 inference networks differ in all of their parameters, not only in the final classification_layer.

Because we instantiate 60 such models, we report 60 evaluation results.

Code:

import os
import mindspore_hub as mshub
import mindspore
from mindspore import context, export, Tensor, nn
from mindspore.nn import Momentum
from mindspore.train.serialization import save_checkpoint, load_checkpoint,load_param_into_net
from mindspore import ops
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C
from mindspore import dtype as mstype
from mindspore import Model


# Define the new network structure
class ReduceMeanFlatten(nn.Cell):
    def __init__(self):
        super(ReduceMeanFlatten, self).__init__()
        self.mean = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.mean(x, (2, 3))
        x = self.flatten(x)
        return x


# Generate the per-step learning rate schedule
def generate_steps_lr(lr_init, steps_per_epoch, total_epochs):
    total_steps = total_epochs * steps_per_epoch
    decay_epoch_index = [0.3*total_steps, 0.6*total_steps, 0.8*total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < decay_epoch_index[0]:
            lr = lr_init
        elif i < decay_epoch_index[1]:
            lr = lr_init * 0.1
        elif i < decay_epoch_index[2]:
            lr = lr_init * 0.01
        else:
            lr = lr_init * 0.001
        lr_each_step.append(lr)
    return lr_each_step


# Create the dataset
def create_cifar10dataset(dataset_path, batch_size, do_train):
    if do_train:
        usage, shuffle = "train", True
    else:
        usage, shuffle = "test", False

    data_set = ds.Cifar10Dataset(dataset_dir=dataset_path, usage=usage, shuffle=shuffle)

    # define map operations
    trans = [C.Resize((256, 256))]
    if do_train:
        trans += [
            C.RandomHorizontalFlip(prob=0.5),
        ]

    trans += [
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set


# Build the overall network
model = "mindspore/ascend/1.0/mobilenetv2_v1.0_openimage"
network = mshub.load(model, num_classes=500, include_top=False, activation="Sigmoid")
network.set_train(False)

# Per the MindSpore Hub model page, the backbone's last feature dimension is 1280.
last_channel = 1280
# The number of classes in the target task is 10.
num_classes = 10

reduce_mean_flatten = ReduceMeanFlatten()

classification_layer = nn.Dense(last_channel, num_classes)
classification_layer.set_train(True)

train_network = nn.SequentialCell([network, reduce_mean_flatten, classification_layer])

# Wrap the backbone network with loss.
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")


dataset_path = "datasets/cifar-10-batches-bin/test"
# Define loss and create model.
eval_dataset = create_cifar10dataset(dataset_path, batch_size=32, do_train=False)
eval_metrics = {'Loss': nn.Loss(),
                 'Top1-Acc': nn.Top1CategoricalAccuracy(),
                 'Top5-Acc': nn.Top5CategoricalAccuracy()}
model = Model(train_network, loss_fn=loss_fn, optimizer=None, metrics=eval_metrics)


for i in range(60):
    # Load a pre-trained ckpt file.
    ckpt_path = "./ckpt/cifar10_finetune_epoch{}.ckpt".format(i)
    trained_ckpt = load_checkpoint(ckpt_path)
    load_param_into_net(train_network, trained_ckpt)

    metrics = model.eval(eval_dataset)
    print("{} epoch, 	 metric: 
".format(i), metrics)

Results:

 0 epoch,      metric:
 {'Loss': 0.6577172723527138, 'Top1-Acc': 0.7723357371794872, 'Top5-Acc': 0.9895833333333334}
1 epoch,      metric:
 {'Loss': 0.4597269055696252, 'Top1-Acc': 0.8423477564102564, 'Top5-Acc': 0.992988782051282}
2 epoch,      metric:
 {'Loss': 0.38787451790024835, 'Top1-Acc': 0.8724959935897436, 'Top5-Acc': 0.9943910256410257}
3 epoch,      metric:
 {'Loss': 0.3605625171405383, 'Top1-Acc': 0.8798076923076923, 'Top5-Acc': 0.9945913461538461}
4 epoch,      metric:
 {'Loss': 0.31980608275327355, 'Top1-Acc': 0.893729967948718, 'Top5-Acc': 0.9974959935897436}
5 epoch,      metric:
 {'Loss': 0.309864604099391, 'Top1-Acc': 0.8940304487179487, 'Top5-Acc': 0.9970953525641025}
6 epoch,      metric:
 {'Loss': 0.30217035547591364, 'Top1-Acc': 0.8982371794871795, 'Top5-Acc': 0.9966947115384616}
7 epoch,      metric:
 {'Loss': 0.2860589229191343, 'Top1-Acc': 0.9030448717948718, 'Top5-Acc': 0.9974959935897436}
8 epoch,      metric:
 {'Loss': 0.27449251447493833, 'Top1-Acc': 0.9097556089743589, 'Top5-Acc': 0.9978966346153846}
9 epoch,      metric:
 {'Loss': 0.2682889519327989, 'Top1-Acc': 0.9139623397435898, 'Top5-Acc': 0.9977964743589743}
10 epoch,      metric:
 {'Loss': 0.2563578961476779, 'Top1-Acc': 0.9139623397435898, 'Top5-Acc': 0.9977964743589743}
11 epoch,      metric:
 {'Loss': 0.2866767762372127, 'Top1-Acc': 0.9080528846153846, 'Top5-Acc': 0.9975961538461539}
12 epoch,      metric:
 {'Loss': 0.28505400239000434, 'Top1-Acc': 0.9091546474358975, 'Top5-Acc': 0.9978966346153846}
13 epoch,      metric:
 {'Loss': 0.2673245904883609, 'Top1-Acc': 0.9168669871794872, 'Top5-Acc': 0.9976963141025641}
14 epoch,      metric:
 {'Loss': 0.2491214134497568, 'Top1-Acc': 0.9205729166666666, 'Top5-Acc': 0.9977964743589743}
15 epoch,      metric:
 {'Loss': 0.24817989028405207, 'Top1-Acc': 0.921875, 'Top5-Acc': 0.9981971153846154}
16 epoch,      metric:
 {'Loss': 0.2570310448511289, 'Top1-Acc': 0.9211738782051282, 'Top5-Acc': 0.9976963141025641}
17 epoch,      metric:
 {'Loss': 0.3002154855169535, 'Top1-Acc': 0.9123597756410257, 'Top5-Acc': 0.9969951923076923}
18 epoch,      metric:
 {'Loss': 0.2195321206392971, 'Top1-Acc': 0.9382011217948718, 'Top5-Acc': 0.9983974358974359}
19 epoch,      metric:
 {'Loss': 0.22502830106951968, 'Top1-Acc': 0.9382011217948718, 'Top5-Acc': 0.9984975961538461}
20 epoch,      metric:
 {'Loss': 0.2301286602689801, 'Top1-Acc': 0.9394030448717948, 'Top5-Acc': 0.9986979166666666}
21 epoch,      metric:
 {'Loss': 0.2384060412140314, 'Top1-Acc': 0.9394030448717948, 'Top5-Acc': 0.9984975961538461}
22 epoch,      metric:
 {'Loss': 0.24342964330214903, 'Top1-Acc': 0.9413060897435898, 'Top5-Acc': 0.9985977564102564}
23 epoch,      metric:
 {'Loss': 0.2535825404521837, 'Top1-Acc': 0.9415064102564102, 'Top5-Acc': 0.9985977564102564}
24 epoch,      metric:
 {'Loss': 0.25143755575036497, 'Top1-Acc': 0.9411057692307693, 'Top5-Acc': 0.9985977564102564}
25 epoch,      metric:
 {'Loss': 0.2594438399058349, 'Top1-Acc': 0.9423076923076923, 'Top5-Acc': 0.9986979166666666}
26 epoch,      metric:
 {'Loss': 0.2635905724521133, 'Top1-Acc': 0.9426081730769231, 'Top5-Acc': 0.9986979166666666}
27 epoch,      metric:
 {'Loss': 0.2674773707296178, 'Top1-Acc': 0.9432091346153846, 'Top5-Acc': 0.9986979166666666}
28 epoch,      metric:
 {'Loss': 0.2769584987499911, 'Top1-Acc': 0.9410056089743589, 'Top5-Acc': 0.9986979166666666}
29 epoch,      metric:
 {'Loss': 0.27183630618879434, 'Top1-Acc': 0.9413060897435898, 'Top5-Acc': 0.9986979166666666}
30 epoch,      metric:
 {'Loss': 0.2784294540134187, 'Top1-Acc': 0.9425080128205128, 'Top5-Acc': 0.9986979166666666}
31 epoch,      metric:
 {'Loss': 0.2748070613234414, 'Top1-Acc': 0.9420072115384616, 'Top5-Acc': 0.9987980769230769}
32 epoch,      metric:
 {'Loss': 0.2730912099633338, 'Top1-Acc': 0.9438100961538461, 'Top5-Acc': 0.9987980769230769}
33 epoch,      metric:
 {'Loss': 0.28566566075326777, 'Top1-Acc': 0.9416065705128205, 'Top5-Acc': 0.9987980769230769}
34 epoch,      metric:
 {'Loss': 0.2814028470299887, 'Top1-Acc': 0.9425080128205128, 'Top5-Acc': 0.9986979166666666}
35 epoch,      metric:
 {'Loss': 0.2830203296055119, 'Top1-Acc': 0.9432091346153846, 'Top5-Acc': 0.9986979166666666}
36 epoch,      metric:
 {'Loss': 0.28531362023084467, 'Top1-Acc': 0.9421073717948718, 'Top5-Acc': 0.9983974358974359}
37 epoch,      metric:
 {'Loss': 0.2822426350837118, 'Top1-Acc': 0.9441105769230769, 'Top5-Acc': 0.9989983974358975}
38 epoch,      metric:
 {'Loss': 0.2795818664544367, 'Top1-Acc': 0.944511217948718, 'Top5-Acc': 0.9989983974358975}
39 epoch,      metric:
 {'Loss': 0.2843668075639159, 'Top1-Acc': 0.9446113782051282, 'Top5-Acc': 0.9986979166666666}
40 epoch,      metric:
 {'Loss': 0.27985764143700065, 'Top1-Acc': 0.9452123397435898, 'Top5-Acc': 0.9990985576923077}
41 epoch,      metric:
 {'Loss': 0.28603066254394294, 'Top1-Acc': 0.9455128205128205, 'Top5-Acc': 0.9988982371794872}
42 epoch,      metric:
 {'Loss': 0.2913823546581252, 'Top1-Acc': 0.9451121794871795, 'Top5-Acc': 0.9988982371794872}
43 epoch,      metric:
 {'Loss': 0.2839965824872158, 'Top1-Acc': 0.9441105769230769, 'Top5-Acc': 0.9990985576923077}
44 epoch,      metric:
 {'Loss': 0.2868479898847792, 'Top1-Acc': 0.9430088141025641, 'Top5-Acc': 0.9986979166666666}
45 epoch,      metric:
 {'Loss': 0.28786868360287327, 'Top1-Acc': 0.9428084935897436, 'Top5-Acc': 0.9986979166666666}
46 epoch,      metric:
 {'Loss': 0.286204165542227, 'Top1-Acc': 0.9434094551282052, 'Top5-Acc': 0.9986979166666666}
47 epoch,      metric:
 {'Loss': 0.2854417120347301, 'Top1-Acc': 0.9444110576923077, 'Top5-Acc': 0.9988982371794872}
48 epoch,      metric:
 {'Loss': 0.2862873625964335, 'Top1-Acc': 0.9439102564102564, 'Top5-Acc': 0.9988982371794872}
49 epoch,      metric:
 {'Loss': 0.2861092998081264, 'Top1-Acc': 0.9443108974358975, 'Top5-Acc': 0.9989983974358975}
50 epoch,      metric:
 {'Loss': 0.2889811675601372, 'Top1-Acc': 0.9430088141025641, 'Top5-Acc': 0.9986979166666666}
51 epoch,      metric:
 {'Loss': 0.28843501244916925, 'Top1-Acc': 0.9432091346153846, 'Top5-Acc': 0.9986979166666666}
52 epoch,      metric:
 {'Loss': 0.288954030366604, 'Top1-Acc': 0.9428084935897436, 'Top5-Acc': 0.9986979166666666}
53 epoch,      metric:
 {'Loss': 0.2873103237941858, 'Top1-Acc': 0.9436097756410257, 'Top5-Acc': 0.9985977564102564}
54 epoch,      metric:
 {'Loss': 0.2920099106223134, 'Top1-Acc': 0.9436097756410257, 'Top5-Acc': 0.9984975961538461}
55 epoch,      metric:
 {'Loss': 0.28650384773271026, 'Top1-Acc': 0.9436097756410257, 'Top5-Acc': 0.9985977564102564}
56 epoch,      metric:
 {'Loss': 0.2816329109621806, 'Top1-Acc': 0.9446113782051282, 'Top5-Acc': 0.9988982371794872}
57 epoch,      metric:
 {'Loss': 0.2840236019970013, 'Top1-Acc': 0.9454126602564102, 'Top5-Acc': 0.9986979166666666}
58 epoch,      metric:
 {'Loss': 0.28734656473935527, 'Top1-Acc': 0.9434094551282052, 'Top5-Acc': 0.9987980769230769}
59 epoch,      metric:
 {'Loss': 0.2908317415830372, 'Top1-Acc': 0.9436097756410257, 'Top5-Acc': 0.9985977564102564}

========================================================================================

What if we follow the approach from the previous post (https://www.cnblogs.com/devilmaycry812839668/p/15005959.html): first freeze the CNN-layer parameters and train only the fully connected layer for 60 epochs, then make the entire network trainable and train for another 60 epochs (now both the CNN layers and the fully connected layer are trained)?

Since the previous post already produced the fully connected layer's parameters (from freezing the CNN layers and training only the fully connected layer for 60 epochs), we only need to load those parameters into the network, unfreeze the CNN-layer parameters, and train for another 60 epochs.

The code is as follows:

import os
import mindspore_hub as mshub
import mindspore
from mindspore import context, Tensor, nn
from mindspore.nn import Momentum
from mindspore.train.serialization import save_checkpoint, load_checkpoint,load_param_into_net
from mindspore import ops
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C
from mindspore import dtype as mstype
from mindspore import Model


# Define the new network structure
class ReduceMeanFlatten(nn.Cell):
    def __init__(self):
        super(ReduceMeanFlatten, self).__init__()
        self.mean = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.mean(x, (2, 3))
        x = self.flatten(x)
        return x


# Generate the per-step learning rate schedule
def generate_steps_lr(lr_init, steps_per_epoch, total_epochs):
    total_steps = total_epochs * steps_per_epoch
    decay_epoch_index = [0.3*total_steps, 0.6*total_steps, 0.8*total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < decay_epoch_index[0]:
            lr = lr_init
        elif i < decay_epoch_index[1]:
            lr = lr_init * 0.1
        elif i < decay_epoch_index[2]:
            lr = lr_init * 0.01
        else:
            lr = lr_init * 0.001
        lr_each_step.append(lr)
    return lr_each_step


# Create the dataset
def create_cifar10dataset(dataset_path, batch_size, do_train):
    if do_train:
        usage, shuffle = "train", True
    else:
        usage, shuffle = "test", False

    data_set = ds.Cifar10Dataset(dataset_dir=dataset_path, usage=usage, shuffle=shuffle)

    # define map operations
    trans = [C.Resize((256, 256))]
    if do_train:
        trans += [
            C.RandomHorizontalFlip(prob=0.5),
        ]

    trans += [
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set


# Create Dataset
dataset_path = "datasets/cifar-10-batches-bin/train"
dataset = create_cifar10dataset(dataset_path, batch_size=32, do_train=True)

# Build the overall network
model = "mindspore/ascend/1.0/mobilenetv2_v1.0_openimage"
network = mshub.load(model, num_classes=500, include_top=False, activation="Sigmoid")
network.set_train(True)

# Per the MindSpore Hub model page, the backbone's last feature dimension is 1280.
last_channel = 1280
# The number of classes in the target task is 10.
num_classes = 10

reduce_mean_flatten = ReduceMeanFlatten()

classification_layer = nn.Dense(last_channel, num_classes)
classification_layer.set_train(True)


ckpt_path = "./cifar10_finetune_epoch59.ckpt"
trained_ckpt = load_checkpoint(ckpt_path)
load_param_into_net(classification_layer, trained_ckpt)


train_network = nn.SequentialCell([network, reduce_mean_flatten, classification_layer])


# Training settings
# Set epoch size
epoch_size = 60

# Wrap the backbone network with loss.
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
loss_net = nn.WithLossCell(train_network, loss_fn)


# Create an optimizer.
steps_per_epoch = dataset.get_dataset_size()
lr = generate_steps_lr(lr_init=0.01, steps_per_epoch=steps_per_epoch, total_epochs=epoch_size)
# optim = Momentum(filter(lambda x: x.requires_grad, classification_layer.get_parameters()), Tensor(lr, mindspore.float32), 0.9, 4e-5)
optim = Momentum(filter(lambda x: x.requires_grad, train_network.get_parameters()), Tensor(lr, mindspore.float32), 0.9, 4e-5)
# Build the training network
train_net = nn.TrainOneStepCell(loss_net, optim)


# for name, para in optim.parameters_and_names():
#     print(name)


for epoch in range(epoch_size):
    for i, items in enumerate(dataset):
        data, label = items
        data = mindspore.Tensor(data)
        label = mindspore.Tensor(label)

        loss = train_net(data, label)
        print(f"epoch: {epoch}/{epoch_size}, loss: {loss}")
    # Save the ckpt file for each epoch.
    if not os.path.exists('ckpt'):
        os.mkdir('ckpt')
    ckpt_path = f"./ckpt/cifar10_finetune_epoch{epoch}.ckpt"
    save_checkpoint(train_network, ckpt_path)

The test code is as follows:

import os
import mindspore_hub as mshub
import mindspore
from mindspore import context, export, Tensor, nn
from mindspore.nn import Momentum
from mindspore.train.serialization import save_checkpoint, load_checkpoint,load_param_into_net
from mindspore import ops
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C
from mindspore import dtype as mstype
from mindspore import Model


# Define the new network structure
class ReduceMeanFlatten(nn.Cell):
    def __init__(self):
        super(ReduceMeanFlatten, self).__init__()
        self.mean = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.mean(x, (2, 3))
        x = self.flatten(x)
        return x


# Generate the per-step learning rate schedule
def generate_steps_lr(lr_init, steps_per_epoch, total_epochs):
    total_steps = total_epochs * steps_per_epoch
    decay_epoch_index = [0.3*total_steps, 0.6*total_steps, 0.8*total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < decay_epoch_index[0]:
            lr = lr_init
        elif i < decay_epoch_index[1]:
            lr = lr_init * 0.1
        elif i < decay_epoch_index[2]:
            lr = lr_init * 0.01
        else:
            lr = lr_init * 0.001
        lr_each_step.append(lr)
    return lr_each_step


# Create the dataset
def create_cifar10dataset(dataset_path, batch_size, do_train):
    if do_train:
        usage, shuffle = "train", True
    else:
        usage, shuffle = "test", False

    data_set = ds.Cifar10Dataset(dataset_dir=dataset_path, usage=usage, shuffle=shuffle)

    # define map operations
    trans = [C.Resize((256, 256))]
    if do_train:
        trans += [
            C.RandomHorizontalFlip(prob=0.5),
        ]

    trans += [
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set


# Build the overall network
model = "mindspore/ascend/1.0/mobilenetv2_v1.0_openimage"
network = mshub.load(model, num_classes=500, include_top=False, activation="Sigmoid")
network.set_train(False)

# Per the MindSpore Hub model page, the backbone's last feature dimension is 1280.
last_channel = 1280
# The number of classes in the target task is 10.
num_classes = 10

reduce_mean_flatten = ReduceMeanFlatten()

classification_layer = nn.Dense(last_channel, num_classes)
classification_layer.set_train(True)

train_network = nn.SequentialCell([network, reduce_mean_flatten, classification_layer])

# Wrap the backbone network with loss.
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")


dataset_path = "datasets/cifar-10-batches-bin/test"
# Define loss and create model.
eval_dataset = create_cifar10dataset(dataset_path, batch_size=32, do_train=False)
eval_metrics = {'Loss': nn.Loss(),
                 'Top1-Acc': nn.Top1CategoricalAccuracy(),
                 'Top5-Acc': nn.Top5CategoricalAccuracy()}
model = Model(train_network, loss_fn=loss_fn, optimizer=None, metrics=eval_metrics)


for i in range(60):
    # Load a pre-trained ckpt file.
    ckpt_path = "./ckpt/cifar10_finetune_epoch{}.ckpt".format(i)
    trained_ckpt = load_checkpoint(ckpt_path)
    load_param_into_net(train_network, trained_ckpt)

    metrics = model.eval(eval_dataset)
    print("{} epoch, 	 metric: 
".format(i), metrics)

Test results:

0 epoch,      metric:
 {'Loss': 1.109902911270276, 'Top1-Acc': 0.5971554487179487, 'Top5-Acc': 0.9620392628205128}
1 epoch,      metric:
 {'Loss': 1.0624191101927023, 'Top1-Acc': 0.6339142628205128, 'Top5-Acc': 0.9628405448717948}
2 epoch,      metric:
 {'Loss': 0.7588460279198793, 'Top1-Acc': 0.7394831730769231, 'Top5-Acc': 0.9831730769230769}
3 epoch,      metric:
 {'Loss': 0.5849081557721664, 'Top1-Acc': 0.7970753205128205, 'Top5-Acc': 0.9881810897435898}
4 epoch,      metric:
 {'Loss': 0.6166368476473368, 'Top1-Acc': 0.7920673076923077, 'Top5-Acc': 0.9852764423076923}
5 epoch,      metric:
 {'Loss': 0.5292885322123766, 'Top1-Acc': 0.8171073717948718, 'Top5-Acc': 0.9931891025641025}
6 epoch,      metric:
 {'Loss': 0.4921705843641972, 'Top1-Acc': 0.8321314102564102, 'Top5-Acc': 0.991386217948718}
7 epoch,      metric:
 {'Loss': 0.4731407286121677, 'Top1-Acc': 0.8370392628205128, 'Top5-Acc': 0.9928886217948718}
8 epoch,      metric:
 {'Loss': 0.4385749581866922, 'Top1-Acc': 0.8488581730769231, 'Top5-Acc': 0.9932892628205128}
9 epoch,      metric:
 {'Loss': 0.4706742326514079, 'Top1-Acc': 0.8462540064102564, 'Top5-Acc': 0.9935897435897436}
10 epoch,      metric:
 {'Loss': 0.4497633205058101, 'Top1-Acc': 0.8553685897435898, 'Top5-Acc': 0.995292467948718}
11 epoch,      metric:
 {'Loss': 0.4042159711034634, 'Top1-Acc': 0.8647836538461539, 'Top5-Acc': 0.9959935897435898}
12 epoch,      metric:
 {'Loss': 0.38131014781836897, 'Top1-Acc': 0.8719951923076923, 'Top5-Acc': 0.9954927884615384}
13 epoch,      metric:
 {'Loss': 0.40497787275279945, 'Top1-Acc': 0.8701923076923077, 'Top5-Acc': 0.9950921474358975}
14 epoch,      metric:
 {'Loss': 0.3917648155624286, 'Top1-Acc': 0.8725961538461539, 'Top5-Acc': 0.9954927884615384}
15 epoch,      metric:
 {'Loss': 0.3893791334405064, 'Top1-Acc': 0.8768028846153846, 'Top5-Acc': 0.9961939102564102}
16 epoch,      metric:
 {'Loss': 0.37673575434690487, 'Top1-Acc': 0.8773036858974359, 'Top5-Acc': 0.9961939102564102}
17 epoch,      metric:
 {'Loss': 0.4286681634779924, 'Top1-Acc': 0.8668870192307693, 'Top5-Acc': 0.9941907051282052}
18 epoch,      metric:
 {'Loss': 0.34090835952128357, 'Top1-Acc': 0.9042467948717948, 'Top5-Acc': 0.9969951923076923}
19 epoch,      metric:
 {'Loss': 0.3469662770258788, 'Top1-Acc': 0.9025440705128205, 'Top5-Acc': 0.9971955128205128}
20 epoch,      metric:
 {'Loss': 0.36564507971380433, 'Top1-Acc': 0.9037459935897436, 'Top5-Acc': 0.9974959935897436}
21 epoch,      metric:
 {'Loss': 0.38988389251067135, 'Top1-Acc': 0.9051482371794872, 'Top5-Acc': 0.9971955128205128}
22 epoch,      metric:
 {'Loss': 0.4074996639616214, 'Top1-Acc': 0.9044471153846154, 'Top5-Acc': 0.9975961538461539}
23 epoch,      metric:
 {'Loss': 0.4216701938496091, 'Top1-Acc': 0.9061498397435898, 'Top5-Acc': 0.9976963141025641}
24 epoch,      metric:
 {'Loss': 0.4329471406541789, 'Top1-Acc': 0.9046474358974359, 'Top5-Acc': 0.9976963141025641}
25 epoch,      metric:
 {'Loss': 0.439997365129481, 'Top1-Acc': 0.9060496794871795, 'Top5-Acc': 0.9973958333333334}
26 epoch,      metric:
 {'Loss': 0.4489440012414259, 'Top1-Acc': 0.9069511217948718, 'Top5-Acc': 0.9970953525641025}
27 epoch,      metric:
 {'Loss': 0.466666470026375, 'Top1-Acc': 0.9060496794871795, 'Top5-Acc': 0.9973958333333334}
28 epoch,      metric:
 {'Loss': 0.47206670428642655, 'Top1-Acc': 0.9063501602564102, 'Top5-Acc': 0.9975961538461539}
29 epoch,      metric:
 {'Loss': 0.48405538867481457, 'Top1-Acc': 0.9067508012820513, 'Top5-Acc': 0.9973958333333334}
30 epoch,      metric:
 {'Loss': 0.4855282697421782, 'Top1-Acc': 0.9077524038461539, 'Top5-Acc': 0.9975961538461539}
31 epoch,      metric:
 {'Loss': 0.49811935061865, 'Top1-Acc': 0.9060496794871795, 'Top5-Acc': 0.9971955128205128}
32 epoch,      metric:
 {'Loss': 0.4986075199946451, 'Top1-Acc': 0.9077524038461539, 'Top5-Acc': 0.9971955128205128}
33 epoch,      metric:
 {'Loss': 0.5223908047570521, 'Top1-Acc': 0.905448717948718, 'Top5-Acc': 0.9967948717948718}
34 epoch,      metric:
 {'Loss': 0.5181410906087154, 'Top1-Acc': 0.9040464743589743, 'Top5-Acc': 0.9971955128205128}
35 epoch,      metric:
 {'Loss': 0.5221440061394764, 'Top1-Acc': 0.9067508012820513, 'Top5-Acc': 0.9972956730769231}
36 epoch,      metric:
 {'Loss': 0.5157031946561148, 'Top1-Acc': 0.9084535256410257, 'Top5-Acc': 0.9973958333333334}
37 epoch,      metric:
 {'Loss': 0.5131257220487803, 'Top1-Acc': 0.907051282051282, 'Top5-Acc': 0.9971955128205128}
38 epoch,      metric:
 {'Loss': 0.5138702635192133, 'Top1-Acc': 0.9075520833333334, 'Top5-Acc': 0.9971955128205128}
39 epoch,      metric:
 {'Loss': 0.5061611129858647, 'Top1-Acc': 0.9088541666666666, 'Top5-Acc': 0.9972956730769231}
40 epoch,      metric:
 {'Loss': 0.5194041342985702, 'Top1-Acc': 0.9077524038461539, 'Top5-Acc': 0.9972956730769231}
41 epoch,      metric:
 {'Loss': 0.5229644895124903, 'Top1-Acc': 0.9077524038461539, 'Top5-Acc': 0.9972956730769231}
42 epoch,      metric:
 {'Loss': 0.5114161455773263, 'Top1-Acc': 0.9089543269230769, 'Top5-Acc': 0.9970953525641025}
43 epoch,      metric:
 {'Loss': 0.5174470136573804, 'Top1-Acc': 0.9076522435897436, 'Top5-Acc': 0.9970953525641025}
44 epoch,      metric:
 {'Loss': 0.5201686953224942, 'Top1-Acc': 0.9085536858974359, 'Top5-Acc': 0.9972956730769231}
45 epoch,      metric:
 {'Loss': 0.5192232744206055, 'Top1-Acc': 0.9084535256410257, 'Top5-Acc': 0.9970953525641025}
46 epoch,      metric:
 {'Loss': 0.5211286343385907, 'Top1-Acc': 0.909354967948718, 'Top5-Acc': 0.9971955128205128}
47 epoch,      metric:
 {'Loss': 0.5148008207294109, 'Top1-Acc': 0.9085536858974359, 'Top5-Acc': 0.9973958333333334}
48 epoch,      metric:
 {'Loss': 0.5206076546476722, 'Top1-Acc': 0.9079527243589743, 'Top5-Acc': 0.9974959935897436}
49 epoch,      metric:
 {'Loss': 0.5179570562285312, 'Top1-Acc': 0.9087540064102564, 'Top5-Acc': 0.9970953525641025}
50 epoch,      metric:
 {'Loss': 0.5165543049891616, 'Top1-Acc': 0.9088541666666666, 'Top5-Acc': 0.9971955128205128}
51 epoch,      metric:
 {'Loss': 0.5153302571296734, 'Top1-Acc': 0.9091546474358975, 'Top5-Acc': 0.9970953525641025}
52 epoch,      metric:
 {'Loss': 0.5174539989204323, 'Top1-Acc': 0.9088541666666666, 'Top5-Acc': 0.996895032051282}
53 epoch,      metric:
 {'Loss': 0.5147456764461822, 'Top1-Acc': 0.9085536858974359, 'Top5-Acc': 0.9971955128205128}
54 epoch,      metric:
 {'Loss': 0.5186178396086175, 'Top1-Acc': 0.9087540064102564, 'Top5-Acc': 0.9973958333333334}
55 epoch,      metric:
 {'Loss': 0.5285426172086796, 'Top1-Acc': 0.9084535256410257, 'Top5-Acc': 0.9970953525641025}
56 epoch,      metric:
 {'Loss': 0.5223714007378686, 'Top1-Acc': 0.9087540064102564, 'Top5-Acc': 0.9973958333333334}
57 epoch,      metric:
 {'Loss': 0.5148643667284304, 'Top1-Acc': 0.9074519230769231, 'Top5-Acc': 0.9974959935897436}
58 epoch,      metric:
 {'Loss': 0.5172925066374143, 'Top1-Acc': 0.9076522435897436, 'Top5-Acc': 0.9970953525641025}
59 epoch,      metric:
 {'Loss': 0.5174664831160604, 'Top1-Acc': 0.9080528846153846, 'Top5-Acc': 0.9969951923076923}

================================================================

In the end, we have the final performance figures for three transfer-learning training methods:

Method 1 (the previous post's method):

Freeze the network's CNN layers, leave the fully connected layer trainable, load the mindspore_hub pretrained parameters into the CNN layers, and train for 60 epochs.

Test result after the 60th training epoch:

59 epoch,      metric:
 {'Loss': 0.7698193432237858, 'Top1-Acc': 0.7275641025641025, 'Top5-Acc': 0.9834735576923077}

Method 2 (the first method in this post):

Neither the CNN layers nor the fully connected layer is frozen; all of the network's parameters are trained. The mindspore_hub pretrained parameters are loaded into the CNN layers, and the network is trained for 60 epochs.

Test results:

Test result after the 1st training epoch:

 0 epoch,      metric:
 {'Loss': 0.6577172723527138, 'Top1-Acc': 0.7723357371794872, 'Top5-Acc': 0.9895833333333334} 

Test result after the 60th training epoch:

59 epoch,      metric:
 {'Loss': 0.2908317415830372, 'Top1-Acc': 0.9436097756410257, 'Top5-Acc': 0.9985977564102564}

Method 3 (the second method in this post):

Neither the CNN layers nor the fully connected layer is frozen; all of the network's parameters are trained. The mindspore_hub pretrained parameters are loaded into the CNN layers, the fully connected layer is initialized with the parameters obtained at the 60th epoch of Method 1, and the network is trained for another 60 epochs.

Test results:

Test result after the 1st training epoch:

0 epoch,      metric:
 {'Loss': 1.109902911270276, 'Top1-Acc': 0.5971554487179487, 'Top5-Acc': 0.9620392628205128}

Test result after the 60th training epoch:

59 epoch,      metric:
 {'Loss': 0.5174664831160604, 'Top1-Acc': 0.9080528846153846, 'Top5-Acc': 0.9969951923076923}
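
Side by side, the final (epoch 59) Top-1 accuracies are: Method 1: 0.7276, Method 2: 0.9436, Method 3: 0.9081.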

Analysis:

Method 1 freezes the CNN-layer parameters and trains only the fully connected layer; it ends up with the worst performance of the three. This suggests that in transfer learning the low-level feature layers should also receive some training, at least when the data is reasonably plentiful and the number of training epochs is large. (Note: few-shot learning tasks are outside the scope of this discussion.)

Methods 2 and 3 freeze none of the network's parameters and train everything, and both achieve good results. Notably, Method 3 is equivalent to continuing from Method 1 with the CNN layers unfrozen (in other words, the network starts from Method 1's final fully connected parameters), yet after its first epoch the accuracy does not improve; it actually drops below Method 1's final result. Method 2, like Methods 1 and 3, loads the mindspore_hub pretrained parameters into the CNN layers, and like Method 3 it leaves the CNN layers unfrozen; the difference is that its fully connected layer is randomly initialized rather than taken from Method 1.

This suggests that low-level layers such as the CNN backbone are more general, while high-level layers such as the fully connected head are tightly coupled to the specific low-level features beneath them. When the low-level network is about to change substantially, starting from a fully connected layer with previously trained parameters works out worse than starting from a randomly initialized one. Put differently, low-level parameters transfer well, but fully connected head parameters do not: reusing a head trained in one setting may well end up worse than random initialization.
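
Expressed in code, the takeaway is simply Method 2's head construction; a minimal sketch reusing the names defined in the listings above (nn.Dense draws fresh random weights by default):

# Following the analysis: when the backbone is left unfrozen, start the head
# from random initialization instead of loading a head checkpoint that was
# trained against a frozen backbone.
classification_layer = nn.Dense(last_channel, num_classes)  # fresh random init
# (no load_checkpoint / load_param_into_net call for the head)
train_network = nn.SequentialCell([network, reduce_mean_flatten, classification_layer])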
