SwinTransformer代码详解
官方代码:https://github.com/microsoft/Swin-Transformer/
参考:论文详解:Swin Transformer - 知乎 (zhihu.com)
【深度学习】详解 Swin Transformer (SwinT)-CSDN博客
Swin Transformer 论文详解及程序解读 - 知乎 (zhihu.com)
https://zhuanlan.zhihu.com/p/401661320
AI大模型系列之三:Swin Transformer 最强CV图解(深度好文)_cv大模型-CSDN博客
整体架构
1 | class SwinTransformer(nn.Module): |
patch embedding
self.proj
将patch partition与Linear Embedding合并了,利用二维卷积操作,将stride,kernel_size设置为window_size大小,输出(N, 96, 56, 56)
输入SwinTransformer块之前要把张量展平:x = torch.flatten(x, 2)
,x = torch.transpose(x, 1, 2)
1 | class PatchEmbed(nn.Module): |
SwinTransformer块
1 | # window_size=7 |
- Swin-T的位置编码是一个可选项(
self.ape
),Swin-T 是在计算 Attention 的时候做了一个相对位置编码 - ViT 会单独加上一个可学习参数,作为分类的 token。而 Swin-T 则是直接做平均(avgpool),输出分类,类似CNN最后的全局平均池化层。
Window Partition
window partition
函数是用于对张量划分窗口,指定窗口大小。将原本的张量从[B,H,W,C], 划分成 shape = (),为窗口的大小,窗口个数N=。这个函数会在后面的Window Attention
用到。
1 | def window_partition(x, window_size): |
Window Reverse
而window reverse
函数则是对应的逆过程,将shape = ()的窗口张量reshape回[B,H,W,C]。这个函数会在后面的Window Attention
用到。
1 | def window_reverse(windows, window_size, H, W): |
Window Attention
- 在窗口内而非全局图像内计算自注意力可将计算复杂度由二次降为线性。
- 在计算原 Attention 的 Query 和 Key 时,加入相对位置编码 B可改善性能
在一个窗口内Q,K,V的shape = [numWindows x Bacth_size, num_heads, window_size * window_size, head_dim]
- window_size * window_size 即一个窗口中
token
的个数- head_dim是token的维度
得到的张量形状为 [numWindows x Bacth_size, num_heads, Q_tokens, K_tokens]
Q_tokens = K_tokens = window_size * window_size
相对位置编码(RPE)介绍
绝对位置编码最常见的一种位置编码方法,其思想是在每个输入序列的元素上添加一个位置向量,以表示该元素在序列中的具体位置。这个位置向量通常通过固定的函数生成,与输入数据无关。通常使用的是正弦和余弦函数,这样生成的编码具有很强的周期性,能够捕捉序列中的相对位置信息。
相对位置编码不直接为每个位置分配一个唯一的编码,而是关注序列中各元素之间的相对位置。 相对位置编码的核心思想是通过计算序列中元素之间的距离,来表示它们之间的相对关系。这种方法尤其适合处理需要捕捉长距离依赖关系的任务,因为它能够更加灵活地表示序列中的结构信息。
VIT中的绝对位置编码:
在VIT中采用的是1D绝对位置编码,且在embeddding过程中直接与 patch embedding相加
1
2
3
4 self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
# patch emded + pos_embed :图像块嵌入 + 位置嵌入
x = x + self.pos_embed
Swin Transformer中的相对位置编码:
而在Swin Transformer中如公式所示,是在QKV的过程中将得到的张量根据相对位置索引加上对应的位置偏置,之后再与V算softmax。
在Swin Transformer中相对位置编码采用的是2d编码的方式,由于得到的张量形状为
[numWindows x Bacth_size, num_heads, window_size * window_size, window_size * window_size]
,B的形状应该相同,排除numWindows x Bacth_size
,一个窗口内的B形状应该为[num_heads, window_size * window_size, window_size * window_size]
由原论文中可以知道B中的值取自。是参数化得到的偏置矩阵 ,M = window_size
相对位置编码原理
以window_size = 2 * 2即每个窗口有4个token (M=2) 为例:
注意:
- 论文详解:Swin Transformer - 知乎 (zhihu.com)代码跟演示图是匹配的,
- 【深度学习】详解 Swin Transformer (SwinT)-CSDN博客,Swin Transformer 论文详解及程序解读 - 知乎 (zhihu.com)这两篇代码跟演示图不匹配
绝对位置索引:在计算每个token的自注意力时,都以唯一的一个token为中心点
每个token用二维坐标(x,y)表示,这里以第一个token为中心点进行二维坐标的表示。
当计算每个个token的自注意力时,各个token的位置索引都是下图
- 第 i 行表示第 i 个 token 的
query
对所有token的key
的attention。- 对于 Attention 张量来说,以不同元素为原点,其他元素的坐标也是不同的
相对位置索引:在计算每个token的自注意力时,都以当前token为中心点
当位置1的token计算self-attention时,要计算位置1与位置(1,2,3,4)的QK值,即以位置1的token为中心点,中心点位置坐标(0,0),其他位置计算与当前位置坐标的偏移量。
以 window_size = M = 2 为例,生成网格 grid 坐标:
1
2
3
4
5
6
7
8
9 coords_h = torch.arange(self.window_size[0]) # x坐标
coords_w = torch.arange(self.window_size[1]) # y坐标
coords = torch.meshgrid([coords_h, coords_w]) # -> 2个(wh, ww)
"""
(tensor([[0, 0],
[1, 1]]), 分别代表四个token的行坐标
tensor([[0, 1],
[0, 1]])),分别代表四个token中的纵坐标
"""堆叠并展开为 2D 向量:
1
2
3
4
5
6 coords = torch.stack(coords) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
"""
tensor([[0, 0, 1, 1],
[0, 1, 0, 1]])
"""分别在第 1 和 2 维处插入新维度,并利用广播机制做减法,得到 shape = (2, , ) 的张量:
1
2
3 relative_coords_first = coords_flatten[:, :, None] # 2, wh * ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh * ww
relative_coords = relative_coords_first - relative_coords_second # 2, wh * ww, wh * wwPS:上一行代码相当重要,我看了几篇博客一直没弄懂,发现原来是有的博客给的图有问题,正是这一行代码实现了与形状相对应的相对位置索引!
relative_coords_first - relative_coords_second
这一行代码实现的就是该点的相对位置索引=中心点的绝对位置索引-该点的绝对位置索引其中第一个张量第i行代表着以第i个token为中心点的各个token的x坐标
第二个张量第i行代表着以第i个token为中心点的各个token的y坐标
未处理前的相对位置编码索引:第i行第j列表示以第i个token为中心点时第j个token的相对位置索引
由于相减得到的索引是从负数开始的,故加上偏移量使之从 0 开始:
1
2
3 relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1之所以是加的偏移量是self.window_size[0] - 1与self.window_size[1] - 1是因为相对位置x,y索引值最小分别为-(self.window_size[0] - 1),-(self.window_size[1] - 1)
接着,需要将其展开成 1D 偏移量。对于诸如第 0 行上 (1, 2) 和 (2, 1) 这两个不同的坐标 (x, y),通过将 (x, y) 坐标求和得到 1D 偏移量 x+y 时,二者所表示的相对于原点的偏移量却是相等的 (1+2 = 2+1 = 3):不同的位置却具有相同的偏移量,降低了相对区分度/差异度
为避免这种 偏移量相等 的错误对应情况,还需对坐标 (准确地说是 x 坐标) 进行 乘法变换 (offset multiply),以提高区分度:对 x 坐标实施乘法变换得到 (x’, y),再重新计算得到具有差异度的各坐标位置的偏移量 x’+y
1 relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) # 每个 x 坐标乘 (2 * 2 - 1) = 3
接着在最后一维上求和 x+y,展开成一个 1D 坐标 (相对位置索引),并注册为一个不参与网络学习的变量 relative_position_index,其作用是 根据最终的相对位置索引 找到对应的可学习的相对位置编码:
1
2 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)之前计算的是相对位置索引,并不是相对位置偏置参数。真正使用到的可训练参数是保存在
relative position bias table
表里的,这个表的长度是等于 (2M−1) × (2M−1) (在二维位置坐标中线性变化乘以2M-1导致)的。那么上述公式中的相对位置偏执参数B是根据上面的相对位置索引表根据查relative position bias table
表得到的。(偏移前最大索引为(M-1,M-1),偏移后最大索引为(2M-2,2M-2),乘法变换后为((2M-2)*(2M-1),2M-2),索引最大值等于(2M-2)x(2M-1)+2M-2,由于索引包括0,故表的总长度为(2M-2)x(2M-1)+2M-2+1=(2M−1) × (2M−1))
相对位置编码总体流程
在不考虑num_heads的情况下:
- 对于一个M*M的窗口,有个token,每一个token要跟包括自身在内的个token计算自注意力,故的大小为(,),所以B的大小也为(,)。
- 的第i行第j列为第i个token与第j个token算的结果,它要加上以第i个token为中心点时第j个token的相对位置编码。
- 因此对于相对位置编码,要计算分别以个token为中心点时其他token的相对位置编码
- 相对位置索引矩阵中第i行第j列对应以第i个token为中心点时第j个token的相对位置xy索引(未做处理)
- 加偏移量
- 乘法变换
- 根据总偏移量在相对位置便宜表中得到对应的值
- 得到
相对位置编码矩阵:每一列代表每一个坐标在所有坐标 “眼中” 的相对位置 上图最后生成的是相对位置索引relative_position_index.shape = ( ∗ ) ,在网络中注册成为一个不可学习的变量,relative_position_index的作用就是根据最终的索引值找到对应的可学习的相对位置编码。relative_position_index的数值范围(0~8),即 (2M−1) ∗ (2M−1) ,所以当M=2时相对位置编码可以由一个3 * 3的矩阵表示
以 M=2 的窗口为例,当计算第一个token对应的4个值时,四个token对应的relative_position_index分为别[4, 3, 1, 0] ,对应的数据就是图中位置索引4, 3, 1, 0位置对应的数据,即relative_position.shape = (∗)
前向过程
- 首先输入张量形状为
[numWindows * B, window_size * window_size, C]
- 然后经过
self.qkv
这个全连接层后,进行 reshape,调整轴的顺序,得到形状为[3, numWindows * B, num_heads, window_size*window_size, c//num_heads]
,并分配给q,k,v
。 - 根据公式,我们对
q
乘以一个scale
缩放系数,然后与k
(为了满足矩阵乘要求,需要将最后两个维度调换)进行相乘。得到形状为[numWindows*B, num_heads, window_size*window_size, window_size*window_size]
的attn
张量 - 之前我们针对位置编码设置了个形状为
(2 * window_size-1 * 2 * window_size-1, numHeads)
的可学习变量。我们用计算得到的相对编码位置索引self.relative_position_index.view(-1)
选取,得到形状为(window_size * window_size, window_size * window_size, numHeads)
的编码,再permute(2,0,1)后加到attn
张量上 - 暂不考虑 mask 的情况,剩下就是跟 transformer 一样的 softmax,dropout,与
V
矩阵乘,再经过一层全连接层和 dropout
1 | def forward(self, x, mask=None): |
总体代码
1 | class WindowAttention(nn.Module): |
以4*4窗口为例的shape变化
1 | import torch |
1 | (tensor([[[0, 0, 0, 0], |
1 | coords_flatten = torch.flatten(coords, 1) # (2, wh*ww) |
1 | (tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], |
1 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # (2, wh*ww, wh*ww) |
1 | (tensor([[[ 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3], |
1 | # (x, y) 格式显示 横、纵坐标 |
1 | (tensor([[[ 0, 0], |
1 | # 横坐标加性偏移 (+= 3) |
1 | tensor([[[ 3, 0], |
1 | # 纵坐标加性偏移 (+= 3) |
1 | tensor([[[3, 3], |
1 | # 横坐标乘性变换 (*= 7) |
1 | tensor([[[21, 3], |
1 | # 计算 1D 偏移量 (x+y) |
1 | (tensor([[24, 23, 22, 21, 17, 16, 15, 14, 10, 9, 8, 7, 3, 2, 1, 0], |
1 | # 设 MHA 的 heads 数为 3 |
1 | (Parameter containing: |
1 | relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)].view( |
1 | (tensor([[[0., 0., 0.], |
1 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() |
1 | (tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], |
Shifted Window Attention
1 | if self.shift_size > 0: |
patch merging
由于SwinTransformer块不改变输入的形状,故输出还是[B, H*W, C]
x0,x1,x2,x3如上图所示隔一选一,在通道维度上连接,之后经过全连接层将通道维度由4倍调整为2倍。
1 | class PatchMerging(nn.Module): |
MLP
1 | class Mlp(nn.Module): |
Basic Layer
Basic Layer 即 Swin Transformer 的各 Stage,包含了若干 Swin Transformer Blocks 及 其他层。
注意,一个 Stage 包含的 Swin Transformer Blocks 的个数必须是 偶数,因为需交替包含一个含有 Window Attention (W-MSA) 的 Layer 和含有 Shifted Window Attention (SW-MSA) 的 Layer。
1 | class BasicLayer(nn.Module): |