Building a ResNet

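The basic unit of a ResNet (the basic block) is shown in the figure below. The input tensor passes through two paths inside the block: one path consists of two convolution layers, and the other is a skip connection. The outputs of the two paths are added together.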
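In formula form, the two paths can be written roughly as follows. This is only a sketch: the symbols $\mathcal{F}$ (the convolution path) and $W_s$ (an optional projection on the skip connection, used when the two paths would otherwise have different shapes) are notation assumed here and do not appear in the original post.

```latex
y = \mathcal{F}(x) + x
\qquad \text{or, with a projection on the skip path,} \qquad
y = \mathcal{F}(x) + W_s\, x,
\qquad \text{where } \mathcal{F}(x) = \mathrm{conv}_2\bigl(\mathrm{relu}(\mathrm{conv}_1(x))\bigr)
```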

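For the overall structure of a ResNet, resnet34 is a good reference: it is built by chaining the basic blocks described above. A few points are worth noting:

(1) In some basic blocks, the first convolution layer changes the shape of the input: it halves the height and width and doubles the number of channels.

(2) If the convolution path changes the shape of the input, the corresponding skip connection must apply the same shape change to the input (using a convolution with a (1,1) kernel).

(3) In every basic block, the second convolution layer keeps the shape produced by the first convolution layer unchanged.

(4) In every basic block, the kernel size stays at (3,3); the height and width are changed by the convolution stride, not by the kernel size.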
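The four notes above can be checked with a little shape arithmetic. The sketch below is not part of the original post; the helper names `conv_same_output_shape` and `basic_block_shapes` are hypothetical, and it assumes every convolution uses `padding='same'`, in which case the output size is ceil(input size / stride) regardless of the kernel size.

```python
import math

def conv_same_output_shape(height, width, filters, stride):
    """Output (height, width, channels) of a convolution with 'same' padding.

    With 'same' padding the spatial size depends only on the stride,
    not on the kernel size.
    """
    return (math.ceil(height / stride), math.ceil(width / stride), filters)

def basic_block_shapes(input_shape, filters, stride):
    """Return the output shapes of the convolution path and the skip path."""
    h, w, _ = input_shape
    # First 3x3 convolution: the only layer on this path that may downsample
    after_conv1 = conv_same_output_shape(h, w, filters, stride)
    # Second 3x3 convolution: stride 1, so the shape is unchanged
    after_conv2 = conv_same_output_shape(after_conv1[0], after_conv1[1], filters, 1)
    if stride == 1:
        shortcut = input_shape  # identity skip connection
    else:
        # 1x1 convolution on the skip path, same stride as the first convolution
        shortcut = conv_same_output_shape(h, w, filters, stride)
    return after_conv2, shortcut

main_path, skip_path = basic_block_shapes((56, 56, 64), filters=128, stride=2)
print(main_path, skip_path)  # (28, 28, 128) (28, 28, 128)
```

Both paths end with the same shape, which is what makes the element-wise addition at the end of the block possible.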

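Next, we build a BasicBlock class with TensorFlow. The class uses its stride parameter to decide whether the skip connection is a "solid line" (an identity connection, when stride is 1) or a "dashed line" (a (1,1) convolution, when stride is not 1). To instantiate a BasicBlock you only need to specify the three basic properties of a convolution layer: kernel size, number of kernels, and stride. The code below defines the BasicBlock class and then uses it to build a simple ResNet.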

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers
from tensorflow.keras.callbacks import TensorBoard

class BasicBlock(Model):
    def __init__(self, filter_num, stride=1, filter_size=(3, 3)):
        super().__init__()
        # Convolution path: only conv1 may change the spatial size (through its stride)
        self.conv1 = layers.Conv2D(filters=filter_num, kernel_size=filter_size, strides=(stride, stride), padding='same')
        self.relu = layers.Activation('relu')
        self.conv2 = layers.Conv2D(filters=filter_num, kernel_size=filter_size, strides=(1, 1), padding='same')
        # Skip connection: identity when stride == 1 (this assumes the input already
        # has filter_num channels), otherwise a 1x1 convolution that matches the new shape
        if stride == 1:
            self.shortcut = lambda x: x
        else:
            self.shortcut = layers.Conv2D(filters=filter_num, kernel_size=(1, 1), strides=(stride, stride), padding='same')

    def call(self, inputs):
        out = self.conv1(inputs)
        out = self.relu(out)
        out = self.conv2(out)
        identity = self.shortcut(inputs)
        # Element-wise sum of the convolution path and the skip path
        out = layers.add([out, identity])
        return out

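# Build a simple network out of BasicBlocks.
# Each block is invoked by calling the object itself rather than its .call method,
# so that Keras runs its usual layer machinery.
inputs = layers.Input(shape=(28, 28, 3))

# Stage 1: downsample once (28x28 -> 14x14), 8 channels
out = BasicBlock(filter_num=8, stride=2)(inputs)
out = BasicBlock(filter_num=8, stride=1)(out)
out = BasicBlock(filter_num=8, stride=1)(out)

# Stage 2: downsample again (14x14 -> 7x7), 16 channels
out = BasicBlock(filter_num=16, stride=2)(out)
out = BasicBlock(filter_num=16, stride=1)(out)
out = BasicBlock(filter_num=16, stride=1)(out)

out = layers.Flatten()(out)
out = layers.Dense(1)(out)
model = Model(inputs=inputs, outputs=out)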

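model.compile(optimizer='adam', loss='mse')
# Write TensorBoard logs to ./logs
tensorboard_cb = TensorBoard(log_dir='./logs', histogram_freq=1)
# Random data, used only to check that the network trains end to end
x = np.random.rand(100, 28, 28, 3)
y = np.random.rand(100, 1)
model.fit(x, y, epochs=100, callbacks=[tensorboard_cb], verbose=2)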
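As a sanity check on the network above, the feature-map shape and the number of trainable parameters can be worked out by hand. The sketch below is not from the original post; `conv_params` and `trace_network` are hypothetical helpers, and it assumes 'same' padding, a bias term on every layer, and the six (filter_num, stride) pairs used above. Under those assumptions the total should agree with what `model.summary()` reports.

```python
import math

def conv_params(kernel, in_channels, filters):
    """Weights plus biases of a convolution layer with a square kernel."""
    return kernel * kernel * in_channels * filters + filters

def trace_network(input_shape, blocks, kernel=3):
    """Return the final feature-map shape and the total parameter count.

    `blocks` is a list of (filter_num, stride) pairs, one per basic block.
    """
    h, w, c = input_shape
    total = 0
    for filters, stride in blocks:
        total += conv_params(kernel, c, filters)        # first 3x3 convolution
        total += conv_params(kernel, filters, filters)  # second 3x3 convolution
        if stride != 1:
            total += conv_params(1, c, filters)         # 1x1 convolution on the skip path
        # 'same' padding: spatial size shrinks by the stride, channels become `filters`
        h, w, c = math.ceil(h / stride), math.ceil(w / stride), filters
    flat = h * w * c
    total += flat + 1                                   # Dense(1): one weight per input plus a bias
    return (h, w, c), total

blocks = [(8, 2), (8, 1), (8, 1), (16, 2), (16, 1), (16, 1)]
shape, params = trace_network((28, 28, 3), blocks)
print(shape, params)  # (7, 7, 16) 16873
```

The two stride-2 blocks take the 28x28 input down to 7x7, so the Flatten layer feeds 7 * 7 * 16 = 784 values into the final Dense layer.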

Two of the BasicBlocks from this ResNet are shown in the figure below.

Original article: https://www.cnblogs.com/bill-h/p/14139471.html