Shifted Window算法详解

Swin Transformer


作者:elfin   资料来源:Swin


论文地址:https://arxiv.org/abs/2103.14030

项目地址:https://github.com/microsoft/Swin-Transformer


Top  ---  Bottom

摘要

​ 本文提出了一种新的 vision Transformer,称为Swin Transformer,它可以作为计算机视觉的通用backboneTransformer从语言到视觉的转换面临很大的挑战,它主要来自于两个领域之间的差异,例如视觉实体的规模变化很大,图像中的像素与文本中的单词相比分辨率很高。为了解决这些差异,我们提出了一个 hierarchical Transformer(层次 Transformer),其表征是用shifted window计算的。滑动窗口方案通过将自注意计算限制在非重叠的局部窗口上,同时允许跨窗口连接,从而提高了效率。这种hierarchical体系结构具有在各种尺度下建模的灵活性,并且具有与图像大小相关的线性计算复杂性。Swin-Transformer的这些特性使其能够兼容广泛的视觉任务,包括图像分类(ImageNet-1K上的准确率为86.4 top-1)和密集预测任务,如目标检测(COCO test-dev上的58.7 box AP51.1 mask AP)和语义分割(ADE20K val上的53.5 mIoU)。它的性能超过了以前的先进水平,COCO上的box-AP和mask-AP分别为+2.7和+2.6,ADE20K上的mask-AP和+320万,显示了基于Transformer的模型作为视觉backbone的潜力。


Top  ---  Bottom

1、介绍

​ 在计算机视觉建模过程中,CNN网络取得了优良的性能表现,过去几年基于CNN网络做了大量的工作;而在NLP领域,发展至今,Transformer越来越成为baseline,它对于处理长期依赖有较好的表现。它在语言领域的巨大成功促使研究人员研究它对计算机视觉的适应性,最近它在某些任务上显示了有希望的结果,特别是图像分类[19]和联合视觉语言建模[46]

​ 在本文中,我们试图扩展Transformer的适用性,使它可以像CNNs一样作为计算机视觉的通用backbone。我们观察到,将其在语言领域的高性能转移到视觉领域的重大挑战可以解释为两种模式之间的差异。其中一个差异涉及规模。与作为语言Transformer中处理的基本元素的单词tokens不同,视觉元素在规模上可能有很大的差异,这是一个在目标检测等任务中受到关注的问题[41,52,53]。目前基于Transformer的方法,tokens是固定大小的,这个特性并不能适用于视觉。另一个差异主要是,相比于文本,视觉的像素具有更高的分辨率。如实例分割我们需要在像素级处理、计算,这样 self-attention 的计算复杂度就非常高了。为了克服这个问题,我们提出了通用Transformer backbone: (Swin Transformer),该算法构造层次化特征映射,计算复杂度与图像大小呈线性关系。如下图所示:

Swin-Transformer构造了一个层次表示,从小尺寸的像素块(用灰色表示)开始,逐渐合并更深层次的像素块。有了这些分层特征映射,Swin-Transformer模型可以方便地利用先进的技术进行密集预测,如特征金字塔网络(FPN)[41]或U-Net[50]。线性计算复杂度是通过在分割图像的非重叠窗口(红色轮廓)内局部计算自我注意来实现的。每个窗口中的像素块数是固定的,因此复杂度与图像大小成线性关系。这使得Swin-Transformer 作为backbone可以适应各种视觉任务。而之前用于视觉的Transformer技术只使用了单层特征图,且拥有二次复杂度。

​ Swin Transformer的一个关键设计元素是它在连续的自我关注层之间的窗口分区移动,如图2所示。

移动窗口(shifted window)桥接了前一层的窗口,提供了它们之间的连接,显著增强了建模能力(见表4)。这种策略对于延迟也是有效的:一个窗口中的所有查询像素块共享相同的key,这有助于硬件中的内存访问。我们的实验表明,所提出的移动窗方法比滑动窗方法有更低的延迟,但在建模能力上是相似的。

Swin Transformer实现了更强的性能表现:在延迟相似的前提下,它比ResNe(X)t models、ViT / DeiT 都要好!实现了(58.7\%)的box AP,(51.1\%)的mask AP,测试数据 COCO test-dev set,相比于之前的SOTA模型分别提高了(2.7 P)(2.6P);mIoU值提高了3.2,在ImageNet-1K上达到(86.4\%)的Top-1 accuracy。

​ 我们相信,一个跨计算机视觉和自然语言处理的统一体系结构可以使这两个领域都受益,因为它将促进视觉和文本信号的联合建模,并且来自这两个领域的建模知识可以更深入地共享。我们希望Swin Transformer在各种视觉问题上的出色表现能够推动社区加深这种信念,并鼓励视觉和语言信号的统一建模。


Top  ---  Bottom

2、相关工作

2.1 CNN及其变体

AlexNet ---> VGG、GoogleNet、ResNet、 DenseNet、 HRNet、EffificientNet。

基于这些演进产生了著名的:

  • 深度可分离卷积
  • 形变卷积

2.2 基于backbone结构的自注意力机制

​ self-attention layers目前被学者热衷与替换ResNet中的某个卷积,这里主要是基于局部窗口优化,它们确实是提高了性能。但是提高性能的同时,也增加了计算复杂度。我们使用shift windows替换原始的滑动窗口,它允许在一般硬件中更有效地实现。

2.3 Self-attention/Transformers 作为 CNNs 的补充

​ 见名知义:在传统的backbone后面添加Self-attention/Transformers结构。我们的工作探索了Transformer对基本视觉特征提取的适应,是对这些工作的补充。

2.4 基于Transformer的backbone

​ 与Swin Transformer最相关的工作是ViT和它的继承者。ViT的开创性工作是在不重叠的中等尺寸图像块上直接应用一种Transformer结构进行图像分类。与卷积网络相比,它在图像分类上实现了令人印象深刻的速度精度折衷。但是ViT需要大量的图片才能训练好网络,DeiT改进了训练策略,使得需要的图片集 变小 。虽然ViT在图片分类上有所提高,但是它不适合于高分辨率图片,因为它的复杂度是图片大小的二次方。将ViT模型应用于目标检测和语义分割等稠密视觉任务中的直接上采样或解卷积算法,效果相对较差。而我们的工作也是魔改了ViT,使其在图片分类任务上进一步提升。根据经验,我们发现我们的Swin-Transformer架构可以在这些图像分类方法中实现最佳的速度-精度折衷,尽管我们的工作侧重于通用性能,而不是专门针对分类。有其他学习和也在做多尺度分辨率的融合工作,但是其复杂度还是二次,而我们的复杂度是线性复杂度。我们的模型是兼顾了模型性能与速度,在COCO目标检测和ADE20K语义分割上达到新的SOTA。


Top  ---  Bottom

3、方法

3.1 整体架构

​ 图3给出了Swin-Transformer体系结构的概述,说明了微型版本(Swin-T)。

它首先通过像ViT一样的分片模块将输入的RGB图像分片成不重叠的像素块。每个像素块被视为一个“token”,其特征被设置为原始像素RGB值的串联。我们使用的像素块是(4 imes 4)的size,所以其特征维度为(4 imes 4 imes 3 = 48)。在这个原始值特征上应用一个线性嵌入层,将其投影到任意维(表示为(C))。

​ 在stage1中,几个Swin Transformer blocks算子被应用于这些像素块上。这些 Transformer blocks保持了(frac{H}{4} imes frac{W}{4})的tokens数量,并且伴随线性的嵌入层。

stage2中,为了产生一个层次化的表示,由于像素块的合并使得tokens的数量减少了。第一次patch merging layer合并了(2 imes 2)领域内的像素块,并且使用一个线性层在(4C)的特征上进行合并。这个操作减少了(2 imes 2 = 4)倍的tokens,并设置输出的维度为(2C)。这里的Transformer blocks应用于特征变换后,tokens的数量变为(frac{H}{8} imes frac{W}{8})。这第一个像素块融合和特征变换被称为stage2。这种操作进行叠加产生了stage3stage4,如图所示,tokens的数量分别为:(frac{H}{16} imes frac{W}{16})(frac{H}{32} imes frac{W}{32})。这些阶段共同产生一个层次表示,具有与典型卷积网络相同的特征图分辨率,如VGG [51] and ResNet [29]。结果表明,该体系结构可以很方便地取代现有方法中的backbone,用于各种视觉任务。


Top  ---  Bottom

3.1.1 Swin Transformer block

​ Swin Transformer块使用了shifted windows替换了传统的多头注意力机制MSA,如上图3(b)。Swin Transformer block是由基于MSA的shifted windows组成,它的前面有LN(LayerNorm)层,后面有LN + MLP包围,且有残差进行连接。


Top  ---  Bottom

3.2 基于自注意力的Shifted Window

标准的Transformer架构适应于图像分类,主要采用了相对位置编码的全局自注意力机制,而全局计算的复杂度是二次的,表现为tokens的数量。这在很多视觉任务中都会带来速度损失,且在高分辨率下表现出很强的不适应性。

3.2.1 非重叠窗口的自注意力

​ 为了高效地计算,我们采用局部窗口。这些窗口是均匀排列,且相互不重叠。假设窗口包含(M imes M)个像素块,则 global MSA 在(h imes w)的图像上的计算复杂度为:

[egin{align} Omega left ( MSA ight )&=4heC^{2} + 2left ( hw ight )^{2}C,\ Omega left ( Wmbox{-}MSA ight )&=4heC^{2} + 2M^{2} hwC, end{align} ]

这里MSA的复杂度是hw的二次方,而(M^{2})是远小于(hw)的,所以它是(hw)的一次复杂度。


Top  ---  Bottom

3.2.2 连续块的移位窗口划分

​ 基于窗口的自注意模块缺乏跨窗口的连接,这限制了它的建模能力。为了在保持非重叠窗口计算效率的同时引入跨窗口连接,我们提出了一种 shifted window 划分方法,该方法在连续的Swin-Transformer块中交替使用两种划分配置。

​ 如图2所示,第一个模型使用了正则窗口划分策略。从左上角开始,将(8 imes 8)的像素块划分为(M imes M (M=4))(2 imes 2)的像素块。下一个模型的策略是shifted。对上一层划分的窗口进行移动:向左上角移动(left ( left lfloor frac{M}{2} ight floor , left lfloor frac{M}{2} ight floor ight ))(2 imes 2)的像素块。利用移窗划分方法,连续的Swin-Transformer块的计算公式为:

[egin{align} hat{Z}^{l} &= Wmbox{-}MSAleft ( LNleft ( Z^{l-1} ight ) ight ) + Z^{l-1},\ Z^{l} &= MLPleft ( LNleft ( hat{Z}^{l} ight ) ight ) + hat{Z}^{l},\ hat{Z}^{l+1} &= SWmbox{-}MSAleft ( LNleft ( Z^{l} ight ) ight ) + Z^{l},\ Z^{l+1} &= MLPleft ( LNleft ( hat{Z}^{l+1} ight ) ight ) + hat{Z}^{l+1} end{align} ]

上面的公式对应了图3(b)所示的结构。

移动窗口划分策略,实现了相邻的非重叠窗口的连接,经过实验我们发现它对于图像分类、目标检测、语义分割是高效的!参考Table 4.

注意: W-MSA中的像素块特征数是一致的,但是SW-MSA它可不是一致的,这个怎么计算呢?


Top  ---  Bottom

3.2.3 shifted策略的高效batch计算

shifted操作使得像素块patches的个数(left lceil frac{h}{M} ight ceil imes left lceil frac{w}{M} ight ceil)从变为(left (left lceil frac{h}{M} ight ceil + 1 ight ) imes left (left lceil frac{w}{M} ight ceil + 1 ight ))。如图2所示。且这里部分窗口的大小不是(M imes M)。最简单的方法是直接将所有窗口padding到同样的大小。如果正则策略划分的较小,如上面的(2 imes 2),那么将增加计算量。然而我们提出了一种批量的高效计算方法:循环向左上角移动(cyclic-shifting,如下图所示:

cyclic-shifting实际上就是将移动造成的非(M imes M)像素块合并为(M imes M)像素块,或者你可以理解为之前是窗口在移动,而现在是特征图在移动,超过左上角window的部分在右下脚进行填充!经过cyclic-shifting的调整,实际上每个像素块的大小又一致了,同时我们还实现了不同的patch之间的信息融合,且patches的个数没有发生变化。


Top  ---  Bottom

3.2.4 相对位置偏置

在自注意力模块的计算中,我们引入相对位置偏置到每一个头的计算:

[mathbf{Attention} left(Q,K,V ight) = mathbf{SoftMax} left(QK^{T} / sqrt{d} + B ight)V ]

其中(Q,K,V in mathbb{R}^{M^{2} imes d})分别标识查询矩阵、键矩阵、值矩阵;(d)是查询矩阵或键矩阵的dimension,(M^{2})是窗口内的像素块个数。

Note: 这里的像素块是基本像素块单元,即上文的(2 imes 2)像素块,(M^{2})即窗口内的(2 imes 2)像素块个数。

因为相对位置的范围是(left[ -M +1, M-1 ight]),我们参数化了一个小尺寸的bias矩阵(hat{B} in mathbb{R}^{left( 2M-1 ight) imes left( 2M-1 ight)}),并且我们的(B)是从(hat{B})中取的一个token。

​ 如表4所示,我们观察到与没有此偏差项或使用绝对位置嵌入的对应项相比有显著改进。如[19]中所述,进一步向输入中添加绝对位置嵌入会略微降低性能,因此在我们的实现中不采用这种方法。

​ 预训练中学习到的相对位置bias矩阵也可用于初始化模型(hat{B}),以便通过双三次插值以不同的窗口大小进行微调[19,60]。


Top  ---  Bottom

3.3 结构变体

作者设置的模型架构有:

  • Swin-T(C=96 quad layer \, numbers={2,2,6,2})
  • Swin-S(C=96 quad layer \, numbers={2,2,18,2})
  • Swin-B(C=128 \,\, layer \, numbers={2,2,18,2})
  • Swin-L(C=192 \,\, layer \, numbers={2,2,18,2})

这里的最基础模型是Swin-B模型,它和ViT-B/DeiT-B模型的计算复杂度一样。Swin-T, Swin-S and Swin-L分别是继承模型size的(0.25 imes,0.5 imes,2 imes)的放缩。要注意的是Swin-T, Swin-S的计算复杂度分别与ResNet-50 (DeiT-S) and ResNet-101相当。所有实验的配置中,默认窗口的大小设置为7;每个头的查询矩阵维度(d=32),且MLP扩张层的(alpha=4)。其中(C)是第一阶段中隐藏层的通道数。ImageNet图像分类模型变量的模型大小、理论计算复杂度(FLOPs)和吞吐量如表1所示。


Top  ---  Bottom

4、构建模型接口

接口: Swin-Transformer.models.build,build_model

这里主要对要构建的模型进行识别,Swin项目当然只支持Swin项目,如果有其他配置需要添加,可以在else部分修改。

主要的代码为:

model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
                        patch_size=config.MODEL.SWIN.PATCH_SIZE,
                        in_chans=config.MODEL.SWIN.IN_CHANS,
                        num_classes=config.MODEL.NUM_CLASSES,
                        embed_dim=config.MODEL.SWIN.EMBED_DIM,
                        depths=config.MODEL.SWIN.DEPTHS,
                        num_heads=config.MODEL.SWIN.NUM_HEADS,
                        window_size=config.MODEL.SWIN.WINDOW_SIZE,
                        mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
                        qkv_bias=config.MODEL.SWIN.QKV_BIAS,
                        qk_scale=config.MODEL.SWIN.QK_SCALE,
                        drop_rate=config.MODEL.DROP_RATE,
                        drop_path_rate=config.MODEL.DROP_PATH_RATE,
                        ape=config.MODEL.SWIN.APE,
                        patch_norm=config.MODEL.SWIN.PATCH_NORM,
                        use_checkpoint=config.TRAIN.USE_CHECKPOINT)

Note: 这里作者使用yacs.config进行配置,相关内容可以参考Swin-Transformer.config.py

SwinTransformer类的参数几乎是见名知义,关于参数的具体含义为:

"""
Args:
    img_size (int | tuple(int)): 输入图像大小. Default 224
    patch_size (int | tuple(int)): 像素块的大小. Default: 4
    in_chans (int): 输入图像的通道数. Default: 3
    num_classes (int): 分类数. Default: 1000
    embed_dim (int): 像素块编码的维度. Default: 96
    depths (tuple(int)): 每个Swin Transformer层的深度.
    num_heads (tuple(int)): 不同层的attention头数量.
    window_size (int): shfited window的大小,即M. Default: 7
    mlp_ratio (float):  MLP hidden dim到embedding dim的比率. Default: 4
    qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
    qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
    drop_rate (float): Dropout rate. Default: 0
    attn_drop_rate (float): Attention dropout rate. Default: 0
    drop_path_rate (float): Stochastic depth rate. Default: 0.1
    norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
    ape (bool): 是否添加绝对位置编码. Default: False
    patch_norm (bool): 是否在patch embedding后添加normalization. Default: True
    use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""

Top  ---  Bottom

5、SwinTransformer

SwinTransformer的初始化参数前面已经介绍过,下面我们梳理其模型构建的组件----属性变量。

self.num_features

这个变量并没有参数直接传递,其计算为:int(embed_dim * 2 ** (len(depths) - 1))

self.apeself.absolute_pos_embed

self.ape是否使用绝对位置编码。self.absolute_pos_embed在初始化的时候,是一个shape为(1, num_patches, embed_dim)的全零张量,并使用trunc_normal_进行截断正态分布初始化,其标准差为0.02。

将PatchEmbed与绝对位置编码相加 [绝对位置编码为可选] 之后,再对合并后的特征图进行随机dropout。再根据self.layers生成每一个stage的前向传播。stage的构建是基于BasicLayer生成!可参考5.2。

SwinTransformer的主要由 PatchEmbed + layer1 + layer2 + layer3 + layer4 构成。

分类模型这里主要接:LN + AvgPool1d + Head: nn.Linear

5.1 PatchEmbed

self.patch_embed

将图像分割为不重叠的像素块。

self.patch_embed = PatchEmbed(
   img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
   norm_layer=norm_layer if self.patch_norm else None
)

PatchEmbed是基于torch实现的

参数:

  • img_size (int): Image size. Default: 224.
  • patch_size (int): Patch token size. Default: 4.
  • in_chans (int): Number of input image channels. Default: 3.
  • embed_dim (int): Number of linear projection output channels. Default: 96.
  • norm_layer (nn.Module, optional): Normalization layer. Default: None

其他属性:

  • self.patches_resolution:像素块分辨率,即高宽对应的像素块分割数

    [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
    
  • self.num_patches:像素块的数量,即 patches_resolution[0] * patches_resolution[1]

  • self.proj:这里使用卷积进行编码,卷积核的大小就是patch_size,步长也是patch_size。所以这个卷积处理后patch的shape变为(1 imes1 imes embed\_dim)

前向传播:

def forward(self, x):
    B, C, H, W = x.shape
    # FIXME look at relaxing size constraints
    assert H == self.img_size[0] and W == self.img_size[1], 
        f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
    x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
    if self.norm is not None:
        x = self.norm(x)
    return x

前向传播首先验证输入特征图的shape是否满足初始化中的image配置self.img_size。然后进行patch上的卷积,再进行Normalization layer标准化[可选]。如图所示:

总结PatchEmbed:

像素块编码主要是使用卷积在原features上进行卷积,然后再接一个可选的BN层。卷积的卷积核大小就是像素块的大小即 patch_size,步长也为 patch_size。这样实现了不同像素块之间不会有信息融合,也即只对像素块进行编码,这与传统的CNN滑窗相比,减少了大量的卷积操作!卷积后的features如图所示,在dim=2和dim=3的维度上将张量平铺,在将最后两个维度转置,得到patch卷积输出,其shape为:(batch_size, ph*pw, channels)。BN层就是pytorch中的BN层。


Top  ---  Bottom

5.2 BasicLayer简介

BasicLayer从图像上很好理解,它是论文中stage1~stage4的生成单元。

build layers:

self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
    layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                       input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                         patches_resolution[1] // (2 ** i_layer)),
                       depth=depths[i_layer],
                       num_heads=num_heads[i_layer],
                       window_size=window_size,
                       mlp_ratio=self.mlp_ratio,
                       qkv_bias=qkv_bias, qk_scale=qk_scale,
                       drop=drop_rate, attn_drop=attn_drop_rate,
                       drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                       norm_layer=norm_layer,
                       downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                       use_checkpoint=use_checkpoint)
    self.layers.append(layer)

BasicLayer参数:

"""
dim (int): 输入的通道数.
input_resolution (tuple[int]): 输入特征图的分辨率.
depth (int): blocks数量,即当前stage的深度.
num_heads (int): attention头的数量.
window_size (int): shifted局部窗口的大小.
mlp_ratio (float): mlp 隐含层的维度到embedding dim的比率.
qkv_bias (bool, optional): 给query, key, value添加一个可学习的偏置. 默认: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. 默认: 0.0
attn_drop (float, optional): Attention dropout rate. 默认: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. 默认: 0.0
norm_layer (nn.Module, optional): Normalization layer. 默认: nn.LayerNorm
downsample (nn.Module | None, optional): 在stage的最后进行下采样的Downsample layer. 默认: None
use_checkpoint (bool): 是否使用checkpoint对当前层进行保存. 默认: False.
"""

每一个stage都是由depth个SwinTransformerBlock组成,就如何残差神经网络中的残差块一样。在前向传播中,BasicLayer非常简单,就是使用SwinTransformerBlock构建基本框架再加一个可选的下采样层。

其中SwinTransformerBlock的引入为:

# build blocks
self.blocks = nn.ModuleList([
    SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                         num_heads=num_heads, window_size=window_size,
                         shift_size=0 if (i % 2 == 0) else window_size // 2,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                         drop=drop, attn_drop=attn_drop,
                         drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                         norm_layer=norm_layer)
    for i in range(depth)])

这里的self.blocks主要包含了当前stage的主要SwinTransformerBlock块。

下采样层主要是做像素的融合。

# patch merging layer
if downsample is not None:
    self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
    self.downsample = None

SwinTransformerBlock参考5.3。


Top  ---  Bottom

5.3 SwinTransformerBlock

参数:

"""
dim (int): 输入的通道数.
input_resolution (tuple[int]): 输入特征图的分辨率.
num_heads (int): attention头的数量.
window_size (int): 局部窗口的大小.
shift_size (int): SW-MSA的移动size.
mlp_ratio (float): mlp 隐含层的维度到embedding dim的比率.
qkv_bias (bool, optional): 给query, key, value添加一个可学习的偏置. 默认: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. 默认: 0.0
attn_drop (float, optional): Attention dropout rate. 默认: 0.0
drop_path (float, optional): Stochastic depth rate. 默认: 0.0
act_layer (nn.Module, optional): 激活层. 默认: nn.GELU
norm_layer (nn.Module, optional): Normalization layer.  默认: nn.LayerNorm
"""

attn_mask

当shift_size为0时,attn_mask为None。当它不为0时,那么那么由于窗口的移动,会让原来的特征图被划分为9个区域。这9个区域的划分原则是:新窗口来源是否是移动后合并产生的。以(input\_resolution=(8,8))为例,有如下区域划分:

# 生成全零张量
img_mask = torch.zeros((1, H, W, 1))
# 按区域划分mask
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

此时的img_mask.squeeze(dim=3)为:

tensor([[[0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [3., 3., 3., 3., 4., 4., 5., 5.],
         [3., 3., 3., 3., 4., 4., 5., 5.],
         [6., 6., 6., 6., 7., 7., 8., 8.],
         [6., 6., 6., 6., 7., 7., 8., 8.]]])

然后我们可以获取新生成的windows的mask:

mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)

新生成的每个窗口的mask为:

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

这里的attn_mask会传给WindowAttention用于窗口内的多头注意力计算。实际就是在WindowAttention中的softmax之前将添加偏置的(QK^{T} / sqrt{d} + B)再加一个mask信息。如最后依据所示,不等于0的那些点全部将mask值置为(-100)。这样实现了对移动拼接产生的window注意力输出产生一个偏置。

前向传播:

  • 第一步:检测定义的输入分辨率是否与输入的特征图x的L(序列长度)相同;

  • 第二步:使用self.norm1进行特征标准化,再将数据view(B, H, W, C);

  • 第三步:使用torch.roll移动特征图;

  • 第四步:使用window_partition划分窗口,这里是在shifted_x上面划分,得到(num_windows*B, window_size, window_size, C)的特征图,再view(-1, self.window_size * self.window_size, C)。

  • 第五步:实现W-MSA/SW-MSA 结构

    # num_windows*B, window_size*window_size, C
    attn_windows = self.attn(x_windows, mask=self.attn_mask)
    

    其中 x_windows 是 shifted_x 的窗口划分,self.attn 是WindowAttention的的实例。W-MSA/SW-MSA 的实现区别主要为是否使用shifted。

SwinTransformerBlock主要就是W-MSA/SW-MSA的实现,其结构为:(LN + (W-MSA/SW-MSA) + LN + MLP)。要注意的是shifted的特征图最后我们会还原。这里的LN为nn.LayerNorm;MLP为作者自己的实现。

MLP:

MLP是:全连接层 + 激活层 + Dropout层 + 全连接层 + Dropout层

self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)

Top  ---  Bottom

5.4 WindowAttention

基于带有相对位置偏置的多头注意力模型的移动/非移动窗口注意力模型

参数:

"""
dim (int): 输入通道数.
window_size (tuple[int]): 局部窗口的大小.
num_heads (int): attention头的数量.
qkv_bias (bool, optional):  给query, key, value添加一个可学习的偏置. 默认: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. 默认: 0.0
proj_drop (float, optional): Dropout ratio of output. 默认: 0.0
"""

窗口注意力层的初始化:

相对位置偏置

# 2*Wh-1 * 2*Ww-1, nH
self.relative_position_bias_table = nn.Parameter(
   torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  
trunc_normal_(self.relative_position_bias_table, std=.02)

self.relative_position_bias_table使用了截断正太分布进行初始化,标准差为0.02。

​ 如论文中所示,(M)标识窗口的大小,那么初始化的偏置矩阵是(hat{B} in mathbb{R}^{left( 2M-1 ight) imes left( 2M-1 ight)}),为什么是(left( 2M-1 ight) imes left( 2M-1 ight))?后面再说明这个问题!

WindowAttention层最主要的就是相对位置偏置的编码部分比较复杂,其他操作都是我们熟悉的torch层,所以,这里仔细研究其处理过程。

相对位置偏置(B)是从(hat{B})的一个token。所以(hat{B})存储了所有的偏置,(B)要通过索引获取。下面是索引的生成:

coords:记录了窗口的坐标,原点为窗口左上角;

coords_flatten:记录了坐标的平铺;sahpe为 ( 2, (M^{2}));

relative_coords:记录了窗口内的像素(像素块)的相对位置;如像素块 (patch_{a})(M^{2})个相对位置,因为窗口内有(M^{2})个像素块。

​ 作者在项目中的实现方式是:

relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()

因此,relative_coords的shape为 ((M^{2})(M^{2}),2)。此时的relative_coords[0, :, :]标识的是(h, w)=(0, 0) 到所有点的相对坐标。注意此时窗口内的任意两个坐标的相对位置我们都有了

但是作者又进行了下面的操作:

relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1

前两行做的事情是将相对左边都移动到从0开始。第三行是将高H的相对坐标乘以了(2M-1)

乘以(2M-1)是何用意?self.relative_position_bias_table是初始化的偏置表格,我们需要使用索引进行获取,而relative_coords是生成索引的关键,生成索引的代码为:

relative_position_index = relative_coords.sum(-1)

这里的索引本质上等同于偏置,如果索引相同则偏置也相同。首先我们探讨relative_position_index应该有什么样的性质:

  • 当像素块(patch_{a})与像素块(patch_{a+1})同行(或者同列);像素块(patch_{b})与像素块(patch_{b+1})同行(或者同列)时。像素块 (patch_{a}) 与像素块 (patch_{b}) 之间的偏置应该和像素块 (patch_{a+1}) 与像素块 (patch_{b+1}) 之间的偏置一样!即:

    relative_position_index[i,j] = relative_position_index[M-1-j,M-1-i]
    

    注意上面的等式应该满足绝大多数情况,但是在W的边界上不应该满足,因为我们的数据是行优先排列,索引+1即为下一个像素块,如果当前是宽的边界,那么下一个就换行了。因此我们的索引最好满足(j+1)除以(l)余数不为0,其中(l in left { M, 2M,cdots ,M^{2} ight })

    比较繁琐的是宽的右边界上,下一个换行了,但是这样的模式下,偏置应为也一样!如第(3M-1)像素块相对于第0个像素块的偏置 和 (4M-1)像素块相对于第(M)个像素块的偏置是一样的!

  • 那么问题来了,基于上面的准则,我们至少需要多少个偏置?由论文我们知道,需要(left (2M-1 ight ) imes left (2M-1 ight ))个。这是怎么计算的呢?

    首先,relative_position_index矩阵的shape为((M^{2})(M^{2}))。在主对角线方向上,我们共有(2M^{2}-1)条线,每条线都只有一种或两种偏置索引,原因参考上面的规则说明。那到底哪些是1,哪些是2呢?推导可以发现,每条线组成偏置索引数量的序列为:

    这个序列2的个数为:(2left(M-1 ight) imes left(M-1 ight));1的数量为:(2left(M-1 ight) + 2M - 1)

    所以我们需要的索引数量为:

    [egin{align} 2 imes 2&left(M-1 ight) imes left(M-1 ight) + 2left(M-1 ight) + 2M - 1\ &= 4 left( M^{2} - 2M + 1 ight) + 4M - 3 \ &=4M^{2} - 8M + 4 + 4M - 3 \ &=4M^{2} -4M + 1\ &=left( 2M - 1 ight)^{2} end{align} ]

    实际上到这里我们大概就知道乘以(2M-1),就是为了让索引满足上述需求,且索引最小值到最大值是连续的!小于(2M-1)时,索引矩阵就不能满足上面的规则;大于(2M-1)时,索引矩阵的值就不是连续的!那么为什么是(2M-1)

    解释:

    • (1)在高H对应的特征图上,每个(M imes M)的块是一样的,且,主对角线方向上是一样的,这样就会产生(2M-1)个不同(M imes M)的纵坐标H的索引块;

    • (2)上诉高的每一个(M imes M)块对应的宽的索引块是一样的;

    • (3)一个(M imes M)的宽索引块,其宽的索引取值范围是(left[ 0, 2M-2 ight])

    • (4)对于索引我们要使用:(H imes x + W)的形式获取最终的相对位置索引,那么对于每一行我们乘以了(x),我们仍然需要保持其相邻的(M imes M)块之间的大小关系,对于高H,相邻的高索引都是相差1,假设当前块的行索引为(m),那么有:

      [mx+0> left ( m-1 ight )x + 2M-2 Rightarrow x > 2M-2 ]

    • (5)对于relative_coords的左下角元素,合并高索引与宽索引后为:

      [left ( 2M-2 ight )x + 2M-2 leq left ( 2M-1 ight )^{2} Rightarrow x leq 2M-1 ]

    • (6)由不等式(6)、(7)可以得出乘的数只能是(2M-1)

至此我们得到相对位置偏置的索引了,比如M=4,我们可以得到如下的索引:

tensor([[24, 23, 22, 21, 17, 16, 15, 14, 10,  9,  8,  7,  3,  2,  1,  0],
        [25, 24, 23, 22, 18, 17, 16, 15, 11, 10,  9,  8,  4,  3,  2,  1],
        [26, 25, 24, 23, 19, 18, 17, 16, 12, 11, 10,  9,  5,  4,  3,  2],
        [27, 26, 25, 24, 20, 19, 18, 17, 13, 12, 11, 10,  6,  5,  4,  3],
        [31, 30, 29, 28, 24, 23, 22, 21, 17, 16, 15, 14, 10,  9,  8,  7],
        [32, 31, 30, 29, 25, 24, 23, 22, 18, 17, 16, 15, 11, 10,  9,  8],
        [33, 32, 31, 30, 26, 25, 24, 23, 19, 18, 17, 16, 12, 11, 10,  9],
        [34, 33, 32, 31, 27, 26, 25, 24, 20, 19, 18, 17, 13, 12, 11, 10],
        [38, 37, 36, 35, 31, 30, 29, 28, 24, 23, 22, 21, 17, 16, 15, 14],
        [39, 38, 37, 36, 32, 31, 30, 29, 25, 24, 23, 22, 18, 17, 16, 15],
        [40, 39, 38, 37, 33, 32, 31, 30, 26, 25, 24, 23, 19, 18, 17, 16],
        [41, 40, 39, 38, 34, 33, 32, 31, 27, 26, 25, 24, 20, 19, 18, 17],
        [45, 44, 43, 42, 38, 37, 36, 35, 31, 30, 29, 28, 24, 23, 22, 21],
        [46, 45, 44, 43, 39, 38, 37, 36, 32, 31, 30, 29, 25, 24, 23, 22],
        [47, 46, 45, 44, 40, 39, 38, 37, 33, 32, 31, 30, 26, 25, 24, 23],
        [48, 47, 46, 45, 41, 40, 39, 38, 34, 33, 32, 31, 27, 26, 25, 24]])

其他初始化

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

前向传播:

  • 首先获取q, k, v。

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    

    注意这里是使用线性变换将维度扩大到3倍,使其与q, k, v对应。

    生成q, k, v的代码块:

    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    

    注意这里的x的shape为 (num_windows*B, N, C) ,而上面的3可以理解为将新生成的通道分为q, k, v三份,再将每一份的通道数C拆为:self.num_heads 与 C // self.num_heads 两个维度,以实现多头机制。

  • 计算注意力:(QK^{T} / sqrt{d})

    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))
    
  • 给注意力添加偏置

    relative_position_bias = self.relative_position_bias_table[
        self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], 
        self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    attn = attn + relative_position_bias.unsqueeze(0)
    
  • 实现多头注意力

    [mathbf{Attention} left(Q,K,V ight) = mathbf{SoftMax} left(QK^{T} / sqrt{d} + B ight)V ]

Top  ---  Bottom

完!

清澈的爱,只为中国
原文地址:https://www.cnblogs.com/dan-baishucaizi/p/14661164.html