# Vision transformer

import sys

import tensorflow as tf
from tensorflow.keras import Model, layers, initializers
import numpy as np
from tensorflow import keras


# Debug switch read by the layers below: when enabled, they dump the positional
# embedding ('pos1') and the attention weights ('att1') to disk with np.save.
# Set to True to enable the dumps.
save_attn_flag = False

class PatchEmbed(layers.Layer):
    """Split a 2D feature map into patches and embed each patch as a token.

    A single strided ``Conv2D`` (kernel size == stride == ``patch_size``)
    projects every patch to ``embed_dim`` channels; the resulting spatial grid
    is then flattened into a token sequence.

    Args:
        patch_size: Side length of a square patch; used as both the kernel
            size and the stride of the projection convolution.
        embed_dim: Number of output channels, i.e. the width of each token.
    """

    def __init__(self, patch_size, embed_dim):
        super(PatchEmbed, self).__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size

        self.proj = layers.Conv2D(filters=self.embed_dim, kernel_size=self.patch_size,
                                  strides=self.patch_size, padding='SAME',
                                  kernel_initializer=initializers.LecunNormal(),
                                  bias_initializer=initializers.Zeros())

    def call(self, inputs, **kwargs):
        """Embed ``inputs`` and flatten the patch grid.

        Args:
            inputs: Rank-4 tensor laid out as [B, H, W, C].

        Returns:
            Tensor of shape [B, num_patches, embed_dim].
        """
        x = self.proj(inputs)
        # With padding='SAME' the convolution emits ceil(H / patch_size) rows
        # (and likewise for columns), so take the grid size from the projected
        # tensor instead of assuming H // patch_size. The two agree whenever
        # the input size is a multiple of the patch size.
        _, grid_h, grid_w, _ = x.shape
        # [B, grid_h, grid_w, embed_dim] -> [B, grid_h * grid_w, embed_dim]
        x = tf.reshape(x, (-1, grid_h * grid_w, self.embed_dim))
        return x


class ConcatClassTokenAddPosEmbed(layers.Layer):
    """Add a learnable positional embedding to a token sequence.

    Despite the class name (kept so existing callers keep working), no class
    token is concatenated here: the layer only adds ``pos_embed`` to its input.

    Args:
        embed_dim: Width of each token; last dimension of ``pos_embed``.
        name: Optional layer name.
    """

    def __init__(self, embed_dim, name=None):
        super(ConcatClassTokenAddPosEmbed, self).__init__(name=name)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        """Create the positional embedding, one vector per input token.

        Args:
            input_shape: Shape of the incoming tensor; ``input_shape[1]`` is
                used as the number of tokens.
        """
        self.pos_embed = self.add_weight(name="pos_embed",
                                         shape=[1, input_shape[1], self.embed_dim],
                                         initializer=initializers.RandomNormal(stddev=0.02),
                                         trainable=True,
                                         dtype=tf.float32)

    def call(self, inputs, **kwargs):
        """Return ``inputs + pos_embed`` (broadcast over the batch axis).

        When the module-level ``save_attn_flag`` is enabled, the current
        positional embedding is also written to ``pos1.npy``.
        """
        x = inputs + self.pos_embed

        if save_attn_flag:
            # NOTE(review): .numpy() presumably requires eager execution, so
            # this debug dump is not expected to work inside a traced graph.
            np.save('pos1', self.pos_embed.numpy())

        return x


class Attention(layers.Layer):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        dim: Total embedding width of the input tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection uses a bias.
        qk_scale: Optional override for the logit scale; when falsy,
            ``head_dim ** -0.5`` is used.
        attn_drop_ratio: Dropout rate applied to the attention weights.
        proj_drop_ratio: Dropout rate applied after the output projection.
        name: Optional layer name.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    # Prototype initializers. They are cloned per Dense layer (see
    # _new_initializer) rather than passed directly.
    k_ini = initializers.GlorotUniform()
    b_ini = initializers.Zeros()

    @staticmethod
    def _new_initializer(prototype):
        """Return a fresh initializer with the same configuration as *prototype*.

        Recent Keras versions warn that reusing one unseeded initializer
        instance returns identical values on every call, which would start
        same-shaped kernels in different layers from the same weights.
        """
        return type(prototype).from_config(prototype.get_config())

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.,
                 name=None):
        super(Attention, self).__init__(name=name)
        if dim % num_heads != 0:
            # Fail early: the per-head reshape in call() cannot work otherwise.
            raise ValueError(
                "dim ({}) must be divisible by num_heads ({})".format(dim, num_heads))
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias, name="qkv",
                                kernel_initializer=self._new_initializer(self.k_ini),
                                bias_initializer=self._new_initializer(self.b_ini))
        self.attn_drop = layers.Dropout(attn_drop_ratio)
        self.proj = layers.Dense(dim, name="out",
                                 kernel_initializer=self._new_initializer(self.k_ini),
                                 bias_initializer=self._new_initializer(self.b_ini))
        self.proj_drop = layers.Dropout(proj_drop_ratio)

    def call(self, inputs, training=None):
        """Apply self-attention.

        Args:
            inputs: Rank-3 tensor [batch_size, num_tokens, total_embed_dim].
            training: Forwarded to the dropout layers.

        Returns:
            Tensor with the same [batch_size, num_tokens, total_embed_dim]
            layout as ``inputs``.
        """
        _, N, C = inputs.shape

        # qkv(): -> [batch_size, num_tokens, 3 * total_embed_dim]
        qkv = self.qkv(inputs)
        # reshape: -> [batch_size, num_tokens, 3, num_heads, embed_dim_per_head]
        qkv = tf.reshape(qkv, [-1, N, 3, self.num_heads, C // self.num_heads])
        # transpose: -> [3, batch_size, num_heads, num_tokens, embed_dim_per_head]
        qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])
        # each: [batch_size, num_heads, num_tokens, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # q @ k^T -> [batch_size, num_heads, num_tokens, num_tokens]
        attn = tf.matmul(a=q, b=k, transpose_b=True) * self.scale
        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)

        if save_attn_flag:
            # Debug dump of the attention weights (after dropout).
            # NOTE(review): .numpy() presumably requires eager execution.
            attn_np = attn.numpy()
            np.save('att1', attn_np)

        # attn @ v -> [batch_size, num_heads, num_tokens, embed_dim_per_head]
        x = tf.matmul(attn, v)
        # transpose: -> [batch_size, num_tokens, num_heads, embed_dim_per_head]
        x = tf.transpose(x, [0, 2, 1, 3])
        # reshape: -> [batch_size, num_tokens, total_embed_dim]
        x = tf.reshape(x, [-1, N, C])

        x = self.proj(x)
        x = self.proj_drop(x, training=training)
        return x


class MLP(layers.Layer):
    """Two-layer feed-forward block: Dense -> swish -> dropout -> Dense -> dropout.

    Follows the layout of the MLP used in Vision Transformer / MLP-Mixer
    style networks, but with a swish activation in place of GELU.

    Args:
        in_features: Width of the input (and output) tokens.
        mlp_ratio: Hidden width as a multiple of ``in_features``.
        drop: Dropout rate applied after each Dense layer.
        name: Optional layer name.
    """

    # Prototype initializers. They are cloned per Dense layer (see
    # _new_initializer) rather than passed directly.
    k_ini = initializers.GlorotUniform()
    b_ini = initializers.RandomNormal(stddev=1e-6)

    @staticmethod
    def _new_initializer(prototype):
        """Return a fresh initializer with the same configuration as *prototype*.

        Recent Keras versions warn that reusing one unseeded initializer
        instance returns identical values on every call, which would start
        same-shaped weights in different layers from the same values.
        """
        return type(prototype).from_config(prototype.get_config())

    def __init__(self, in_features, mlp_ratio=4.0, drop=0., name=None):
        super(MLP, self).__init__(name=name)
        self.fc1 = layers.Dense(int(in_features * mlp_ratio), name="Dense_0",
                                kernel_initializer=self._new_initializer(self.k_ini),
                                bias_initializer=self._new_initializer(self.b_ini))
        self.act = tf.nn.swish
        self.fc2 = layers.Dense(in_features, name="Dense_1",
                                kernel_initializer=self._new_initializer(self.k_ini),
                                bias_initializer=self._new_initializer(self.b_ini))
        # One Dropout layer is applied at both dropout points in call().
        self.drop = layers.Dropout(drop)

    def call(self, inputs, training=None):
        """Run the feed-forward block.

        Args:
            inputs: Tensor whose last dimension is ``in_features``.
            training: Forwarded to the dropout layer.

        Returns:
            Tensor with the same last dimension as ``inputs``.
        """
        x = self.fc1(inputs)
        x = self.act(x)
        x = self.drop(x, training=training)
        x = self.fc2(x)
        x = self.drop(x, training=training)
        return x


class Block(layers.Layer):
    """Transformer encoder block with post-normalisation residual branches.

    Each branch computes ``LayerNorm(x + drop_path(sublayer(x)))``, first with
    multi-head self-attention and then with the MLP.

    Args:
        dim: Token embedding width.
        num_heads: Number of attention heads.
        qkv_bias: Whether the attention qkv projection uses a bias.
        qk_scale: Optional override for the attention logit scale.
        drop_ratio: Dropout rate for the attention output projection and MLP.
        attn_drop_ratio: Dropout rate on the attention weights.
        drop_path_ratio: Rate for dropping a whole residual branch per sample;
            when not greater than 0 the branch is passed through unchanged.
        name: Optional layer name.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 name=None):
        super(Block, self).__init__(name=name)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6, name="LayerNorm_0")
        self.attn = Attention(dim, num_heads=num_heads,
                              qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio,
                              name="MultiHeadAttention")
        # Drop path (stochastic depth): noise_shape=(None, 1, 1) makes Dropout
        # zero a sample's entire branch output instead of individual elements.
        # The conditional is parenthesised so it can span two lines; without
        # the parentheses this statement is a syntax error.
        self.drop_path = (
            layers.Dropout(rate=drop_path_ratio, noise_shape=(None, 1, 1))
            if drop_path_ratio > 0.
            else layers.Activation("linear")
        )
        self.norm2 = layers.LayerNormalization(epsilon=1e-6, name="LayerNorm_1")
        self.mlp = MLP(dim, drop=drop_ratio, name="MlpBlock")

    def call(self, inputs, training=None):
        """Apply the attention and MLP branches.

        Args:
            inputs: Token tensor accepted by ``Attention``.
            training: Forwarded to every sub-layer that uses dropout.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Author's own post-norm variant: normalisation is applied after the
        # residual sum, not before the sub-layer as in the reference ViT block.
        attn_out = self.attn(inputs, training=training)
        x = self.norm1(inputs + self.drop_path(attn_out, training=training))
        mlp_out = self.mlp(x, training=training)
        x = self.norm2(x + self.drop_path(mlp_out, training=training))
        return x


class VisionTransformer(Model):
    """ConvLSTM front end, transformer encoder and transposed-conv decoder.

    Pipeline of ``call``: a bidirectional ConvLSTM collapses a 6-step sequence
    of 41x81 single-channel maps into one feature map, which is patch-embedded,
    passed through ``depth`` encoder blocks, and decoded back to a 2D map by
    two transposed convolutions.

    Args:
        patch_size: Patch side length for the embedding and the first
            transposed convolution.
        embed_dim: Token width; also the ConvLSTM filter count per direction.
        depth: Number of encoder blocks.
        num_heads: Attention heads per block.
        qkv_bias: Whether attention qkv projections use a bias.
        qk_scale: Optional override for the attention logit scale.
        drop_ratio: Dropout rate after the positional embedding and inside
            the blocks.
        attn_drop_ratio: Dropout rate on attention weights.
        drop_path_ratio: Maximum drop-path rate; block ``i`` receives the
            ``i``-th value of a linear ramp from 0 to this value.
        representation_size: Accepted for signature compatibility; not used.
        name: Model name.
    """

    def __init__(self, patch_size, embed_dim,
                 depth, num_heads, qkv_bias=True, qk_scale=None,
                 drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0.,
                 representation_size=None, name="ViT-B/16"):
        super(VisionTransformer, self).__init__(name=name)
        self.embed_dim = embed_dim
        self.depth = depth
        self.qkv_bias = qkv_bias
        self.patch_size = patch_size

        self.patch_embed = PatchEmbed(patch_size=self.patch_size, embed_dim=self.embed_dim)
        self.cls_token_pos_embed = ConcatClassTokenAddPosEmbed(embed_dim=self.embed_dim,
                                                               name="cls_pos")

        self.pos_drop = layers.Dropout(drop_ratio)

        # Stochastic depth decay rule: linearly increasing drop-path rate.
        dpr = np.linspace(0., drop_path_ratio, depth)
        self.blocks = [Block(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias,
                             qk_scale=qk_scale, drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio,
                             drop_path_ratio=dpr[i], name="encoderblock_{}".format(i))
                       for i in range(depth)]

        # Decoder stage 1: undo the patch stride (upsample by patch_size).
        self.T_cnn = layers.Conv2DTranspose(filters=self.embed_dim, kernel_size=patch_size,
                                            strides=patch_size, padding='SAME',
                                            kernel_initializer=initializers.LecunNormal(),
                                            bias_initializer=initializers.Zeros())
        # Decoder stage 2: 2x2 VALID transposed conv to a single channel; it
        # grows each spatial dimension by one, mirroring the 2x2 VALID ConvLSTM.
        self.T_cnn2 = layers.Conv2DTranspose(filters=1, kernel_size=2,
                                             strides=1, padding='VALID',
                                             kernel_initializer=initializers.LecunNormal(),
                                             bias_initializer=initializers.Zeros())
        # Front end: batch norm followed by a bidirectional ConvLSTM that
        # returns only the final state (return_sequences=False).
        self.convlstm = keras.Sequential([
            keras.layers.BatchNormalization(),
            keras.layers.Bidirectional(
                keras.layers.ConvLSTM2D(self.embed_dim, kernel_size=(2, 2), strides=1, padding='VALID',
                                        return_sequences=False, activation='swish')),
        ], name='convlstm')

    def call(self, inputs, training=None):
        """Run the full model.

        Args:
            inputs: Tensor whose entries from index 9 onward along axis 1 can
                be reshaped to [-1, 6, 41, 81, 1].
                NOTE(review): presumably [B, 15, 41, 81] - confirm with caller.
            training: Forwarded to all sub-layers.

        Returns:
            Rank-3 tensor [B, H_out, W_out] (the single output channel is
            squeezed away). When ``patch_size`` divides 40 and 80 this is
            [B, 41, 81].
        """
        # Discard the first 9 entries along axis 1 and rescale by 1000.
        # NOTE(review): the meaning of the 1000x factor is not visible here;
        # presumably a unit conversion - confirm.
        inputs = inputs[:, 9:] * 1000
        inputs = tf.reshape(inputs, (-1, 6, 41, 81, 1))
        inputs = self.convlstm(inputs, training=training)

        # Patch grid produced by the embedding conv. It uses padding='SAME',
        # which yields ceil(size / patch_size) patches per axis; this equals
        # size // patch_size whenever the size is a multiple of patch_size.
        c_H = -(-inputs.shape[1] // self.patch_size)
        c_W = -(-inputs.shape[2] // self.patch_size)

        # [B, H, W, C] -> [B, num_patches, embed_dim]
        x = self.patch_embed(inputs)
        # Adds the positional embedding (no class token is concatenated).
        x = self.cls_token_pos_embed(x)
        x = self.pos_drop(x, training=training)

        for block in self.blocks:
            x = block(x, training=training)

        # Tokens back to a 2D grid: [B, num_patches, D] -> [B, c_H, c_W, D]
        x = tf.reshape(x, (-1, c_H, c_W, self.embed_dim))
        x = self.T_cnn(x)
        x = self.T_cnn2(x)
        # Drop the single output channel: [B, H, W, 1] -> [B, H, W]
        x = tf.squeeze(x, axis=-1)

        # NOTE: an earlier revision added the last input frame to x as a
        # residual; that path is disabled, so the slice is no longer computed.
        return x
# Original article: https://www.cnblogs.com/cxhzy/p/15366814.html