TensorFlow 2.0: ResNet Network Design Code

import tensorflow as tf


class BasicBlock(tf.keras.layers.Layer):
    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()

        # First 3x3 convolution; stride may be 2 to halve the feature map
        self.conv1 = tf.keras.layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.Activation('relu')

        # Second 3x3 convolution always keeps the spatial size
        self.conv2 = tf.keras.layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()

        # When the block downsamples, project the shortcut with a 1x1 convolution
        # so its shape matches the residual branch; otherwise use the identity
        if stride != 1:
            self.downsample = tf.keras.Sequential()
            self.downsample.add(tf.keras.layers.Conv2D(filter_num, (1, 1), strides=stride))
        else:
            self.downsample = lambda x: x

    def call(self, inputs, training=None):
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        # Residual connection: add the (possibly projected) shortcut, then ReLU
        identity = self.downsample(inputs)
        output = tf.keras.layers.add([out, identity])
        output = tf.nn.relu(output)

        return output
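
The residual block above adds the (possibly 1x1-projected) input back onto the output of two 3x3 convolutions. A quick way to check its shape behaviour is to run it on a random tensor; the following is only a minimal sketch, and the shapes in the comments assume an input of [4, 32, 32, 64].

# Shape sanity check for BasicBlock (run after the class definition above)
x = tf.random.normal([4, 32, 32, 64])      # [batch, height, width, channels]

block_keep = BasicBlock(64, stride=1)      # identity shortcut, spatial size unchanged
print(block_keep(x).shape)                 # (4, 32, 32, 64)

block_down = BasicBlock(128, stride=2)     # 1x1 projection shortcut, spatial size halved
print(block_down(x).shape)                 # (4, 16, 16, 128)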


class ResNet(tf.keras.Model):
    def __init__(self, layer_dims, num_classes=100):  # layer_dims=[2, 2, 2, 2] means 4 residual stages, each made of two BasicBlocks
                                                       # num_classes=100 is the number of output classes
        super(ResNet, self).__init__()
        # Stem: initial convolution + BN + ReLU + max pooling
        self.stem = tf.keras.Sequential([
            tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=[2, 2], padding='same')
        ])

        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)
        #   output [b, h, w, 512]: global average pooling collapses each h*w feature map to one value, giving [b, 512]
        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(num_classes)

    def call(self, inputs, training=None):
        x = self.stem(inputs, training=training)

        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)

        #   [b, 512]
        x = self.avgpool(x)
        #   [b, 100], where 100 is the num_classes set above
        x = self.fc(x)
        return x

    def build_resblock(self, filter_num, blocks, stride=1):
        res_blocks = tf.keras.Sequential()
        # Only the first block of a stage may downsample (stride=2);
        # the remaining blocks keep the spatial size
        res_blocks.add(BasicBlock(filter_num, stride=stride))

        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, stride=1))

        return res_blocks

def resnet18():
    return ResNet([2,2,2,2])

def resnet34():
    return ResNet([3,4,6,3])
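
To use these factory functions, a minimal sketch along the following lines builds the model and runs a dummy forward pass. The 32x32x3 input size and the CIFAR-100 training setup are assumptions chosen to match the default num_classes=100, not requirements of the code above.

# Minimal usage sketch (input size and dataset are illustrative assumptions)
model = resnet18()
model.build(input_shape=(None, 32, 32, 3))
model.summary()

# Dummy forward pass: logits of shape (4, 100)
logits = model(tf.random.normal([4, 32, 32, 3]), training=False)
print(logits.shape)

# The Dense head outputs raw logits, so use from_logits=True in the loss
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model.fit(x_train, y_train, batch_size=128, epochs=1,
          validation_data=(x_test, y_test))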
Original post: https://www.cnblogs.com/cxhzy/p/13758763.html