Official code: https://github.com/microsoft/Swin-Transformer/

References: 论文详解:Swin Transformer - 知乎 (zhihu.com)

【深度学习】详解 Swin Transformer (SwinT)-CSDN博客

Swin Transformer 论文详解及程序解读 - 知乎 (zhihu.com)

https://zhuanlan.zhihu.com/p/401661320

AI大模型系列之三:Swin Transformer 最强CV图解(深度好文)_cv大模型-CSDN博客

Overall architecture

class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
          https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,  # ImageNet-1K has 1000 classes
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)                     # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops

patch embedding

self.proj merges Patch Partition and Linear Embedding into a single 2D convolution whose kernel_size and stride are both set to patch_size, producing an output of (N, 96, 56, 56).

Before the tensor enters the Swin Transformer blocks it is flattened: x = torch.flatten(x, 2), x = torch.transpose(x, 1, 2), as sketched below.
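
A minimal sketch of that shape flow (assuming a 224×224 input and the default embed_dim=96; this is an illustration, not the library code itself):

import torch
import torch.nn as nn

# patch partition + linear embedding fused into one strided convolution (patch_size = 4)
proj = nn.Conv2d(3, 96, kernel_size=4, stride=4)

x = torch.randn(1, 3, 224, 224)  # (B, C, H, W)
x = proj(x)                      # (1, 96, 56, 56)
x = torch.flatten(x, 2)          # (1, 96, 3136): flatten H and W
x = torch.transpose(x, 1, 2)     # (1, 3136, 96): channels last, ready for the blocks
print(x.shape)                   # torch.Size([1, 3136, 96])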

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)  # patch embedding projection
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        '''
        # with the default parameters       # input (B, C, H, W) = (B, 3, 224, 224)
        x = self.proj(x)                    # output (B, 96, 224/4, 224/4) = (B, 96, 56, 56)
        x = torch.flatten(x, 2)             # flatten H and W, output (B, 96, 56*56)
        x = torch.transpose(x, 1, 2)        # move C to the last dim, output (B, 56*56, 96)
        '''
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # shape = (B, P_h*P_w, C)
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

SwinTransformer block

# window_size=7
# input_batch_image.shape=[128,3,224,224]
class SwinTransformer(nn.Module):
    def __init__(...):
        super().__init__()
        ...
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

        self.pos_drop = nn.Dropout(p=drop_rate)

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(...)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)  # Patch Partition
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)                     # Batch_size Tokens_num Channels
        x = self.avgpool(x.transpose(1, 2))  # Batch_size Channels 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)  # self.head => Linear(in=Channels, out=Classification_num)
        return x
  • In Swin-T the absolute position embedding is optional (self.ape); instead, Swin-T adds a relative position bias when computing attention.
  • ViT prepends a separate learnable class token for classification, whereas Swin-T simply averages over all tokens (avgpool) before the classification head, much like the global average pooling at the end of a CNN.

Window Partition

The window_partition function splits a tensor into windows of a given size: it reshapes the tensor from [B, H, W, C] into shape (B×H/M×W/M, M, M, C), where M is the window size and the number of windows is N = B×H/M×W/M. This function is used later in Window Attention.

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

Window Reverse

The window_reverse function is the inverse: it reshapes the window tensor of shape (B×H/M×W/M, M, M, C) back to [B, H, W, C]. This function is also used later in Window Attention.

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
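
As a quick sanity check (a small sketch assuming the two functions above are in scope), partition followed by reverse is a lossless round trip:

import torch

x = torch.randn(2, 56, 56, 96)               # (B, H, W, C)
windows = window_partition(x, 7)             # (2*8*8, 7, 7, 96) = (128, 7, 7, 96)
print(windows.shape)                         # torch.Size([128, 7, 7, 96])

x_back = window_reverse(windows, 7, 56, 56)  # back to (2, 56, 56, 96)
print(torch.equal(x, x_back))                # True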

Window Attention

  • Computing self-attention inside local windows rather than over the whole image reduces the complexity from quadratic to linear in the number of tokens (see the complexity comparison below).
  • Adding a relative position bias B when computing the Query-Key product of the original attention improves performance.
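
The paper quantifies this for a feature map of h×w tokens with channel dimension C and window size M (the window term is linear in hw because M is fixed, e.g. 7):

\Omega(\text{MSA}) = 4hwC^2 + 2(hw)^2C

\Omega(\text{W-MSA}) = 4hwC^2 + 2M^2hwC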

Attention(Q, K, V) = SoftMax(QK^T/\sqrt{d} + B)V

Inside one window, Q, K and V have shape [num_windows × Batch_size, num_heads, window_size * window_size, head_dim]

  • window_size * window_size is the number of tokens in a window
  • head_dim is the per-head dimension of each token

QK^T yields a tensor of shape [num_windows × Batch_size, num_heads, Q_tokens, K_tokens]

Q_tokens = K_tokens = window_size * window_size

Background: relative position encoding (RPE)

Absolute position encoding is the most common scheme: a position vector is added to every element of the input sequence to mark its position. The vector is usually generated by a fixed function that is independent of the data, typically sine and cosine functions, whose periodic structure allows positional information to be recovered from the encoding.
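
For reference, the fixed sinusoidal encoding from the original Transformer is

PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}}), \qquad PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})

(ViT below actually uses a learnable absolute embedding rather than this fixed form.)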

Relative position encoding does not assign a unique code to each position; it focuses on the relative offsets between elements of the sequence. Its core idea is to represent pairwise relations by the distances between elements. This is particularly suitable for tasks that need long-range dependencies, because it expresses the structure of the sequence more flexibly.

Absolute position encoding in ViT:

ViT uses a 1D absolute position embedding that is added directly to the patch embedding during the embedding step:

self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

# patch embed + pos_embed: patch embedding + position embedding
x = x + self.pos_embed

Relative position encoding in Swin Transformer:

In Swin Transformer, as the formula shows, the tensor QK^T/\sqrt{d} computed in the QKV step is augmented with the position bias selected by the relative position index, and only then is the softmax taken and multiplied with V.

Attention(Q, K, V) = SoftMax(QK^T/\sqrt{d} + B)V

Swin Transformer uses a 2D relative position encoding. Since QK^T has shape [num_windows × Batch_size, num_heads, window_size * window_size, window_size * window_size], B must have the same shape; leaving out num_windows × Batch_size, B inside one window has shape [num_heads, window_size * window_size, window_size * window_size].

From the paper, the entries of B are taken from \hat{B}, a parameterized (learnable) bias matrix \hat{B} ∈ R^{(2M−1)×(2M−1)}, where M = window_size.

How the relative position encoding is built

Take window_size = 2 × 2, i.e. 4 tokens per window (M = 2), as an example:

Note:

Absolute position index: when computing self-attention for every token, one and the same token is always used as the center.

Each token is represented by a 2D coordinate (x, y); here the coordinates are written with the first token as the center.

When computing self-attention for each token, every token keeps the same position index, as shown in the figure below.

(figure: absolute position indices of the four tokens in the 2×2 window)

  • Row i is the attention of token i's query over the keys of all tokens.
  • For the attention tensor, taking a different element as the origin changes the coordinates of all the other elements.

Relative position index: when computing self-attention for each token, the current token itself is the center.

When the token at position 1 computes self-attention, it needs the QK values between position 1 and positions (1, 2, 3, 4); position 1 is the center with coordinate (0, 0), and every other position is expressed as an offset from the current position.

Taking window_size = M = 2 as the example, first generate the grid coordinates:

coords_h = torch.arange(self.window_size[0])   # row coordinates
coords_w = torch.arange(self.window_size[1])   # column coordinates
coords = torch.meshgrid([coords_h, coords_w])  # -> two (Wh, Ww) tensors
"""
(tensor([[0, 0],
         [1, 1]]),   # row coordinate of each of the four tokens
 tensor([[0, 1],
         [0, 1]]))   # column coordinate of each of the four tokens
"""

Stack the two grids and flatten them into 2D coordinate vectors:

coords = torch.stack(coords)               # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
"""
tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]])
"""

Insert a new dimension at axis 1 and axis 2 respectively, and let broadcasting do the subtraction, giving a tensor of shape (2, Wh*Ww, Wh*Ww):

relative_coords_first = coords_flatten[:, :, None]  # 2, wh * ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh * ww
relative_coords = relative_coords_first - relative_coords_second # 2, wh * ww, wh * ww

PS: the line above is essential. I could not figure this out from several blog posts (some of them even show incorrect figures); it is exactly this line that builds the relative position indices whose shape matches QK^T!

relative_coords_first - relative_coords_second implements: relative index of a point = absolute index of the center point − absolute index of that point.

Row i of the first tensor holds the x coordinates of all tokens relative to token i as the center.

Row i of the second tensor holds the y coordinates of all tokens relative to token i as the center.

Before any further processing, entry (i, j) of the relative position index of QK^T is the relative position index of token j when token i is the center.

Since the subtraction produces indices that start from negative values, an offset is added so that they start from 0:

relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1

The added offsets are self.window_size[0] - 1 and self.window_size[1] - 1 because the smallest possible relative x and y indices are -(self.window_size[0] - 1) and -(self.window_size[1] - 1).

Next the 2D coordinates must be collapsed into 1D offsets. For two different coordinates such as (1, 2) and (2, 1) in row 0, simply summing (x, y) into x + y gives the same offset (1+2 = 2+1 = 3): different positions would map to the same offset, losing their distinctiveness.

To avoid such collisions, the coordinates (more precisely the x coordinate) are given a multiplicative transform (offset multiply) to increase distinctiveness: transform x into x', then recompute the now-unique offset x' + y for each coordinate pair.

relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1)  # multiply each x coordinate by (2 * 2 - 1) = 3

Then sum over the last dimension (x' + y) to collapse everything into a 1D coordinate, the relative position index, and register it as a non-learnable buffer relative_position_index; its role is to locate the corresponding learnable relative position bias from the final relative position index.

relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

What we have computed so far are relative position indices, not the relative position bias parameters. The actual trainable parameters \hat{B} are stored in the relative position bias table (relative_position_bias_table), whose length equals (2M−1) × (2M−1) (a consequence of the linear transform that multiplies the x coordinate by 2M−1). The relative position bias B in the formula above is then obtained by looking up relative_position_bias_table with the relative position indices computed above.

(Before the shift the largest index pair is (M−1, M−1); after the shift it is (2M−2, 2M−2); after the multiplicative transform it becomes ((2M−2)·(2M−1), 2M−2), so the largest 1D index is (2M−2)·(2M−1) + 2M−2. Since indices start at 0, the table length is (2M−2)·(2M−1) + (2M−2) + 1 = (2M−1)·(2M−2+1) = (2M−1) × (2M−1).)
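
A quick numeric check of that count (a small standalone sketch, trying M = 2 and M = 7):

for M in (2, 7):
    max_index = (2 * M - 2) * (2 * M - 1) + (2 * M - 2)  # largest possible 1D index
    table_len = (2 * M - 1) ** 2                         # length of the bias table
    print(M, max_index + 1 == table_len)                 # prints: 2 True / 7 True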

Relative position encoding: overall flow

Ignoring num_heads for the moment:

  • A window of M×M has M^2 tokens; each token computes self-attention with all M^2 tokens (including itself), so QK^T has size (M^2, M^2) and B must also have size (M^2, M^2).
  • Entry (i, j) of QK^T is the result of q_i k_j^T; it must be augmented with the relative position bias of token j when token i is the center.
  • Therefore the relative position encoding has to cover, for each of the M^2 tokens taken as the center, the relative positions of all other tokens.
  • Entry (i, j) of the relative position index matrix is the (unprocessed) relative (x, y) index of token j with token i as the center.
  • Add the offsets.
  • Apply the multiplicative transform.
  • Use the resulting total offset to look up the relative position bias table.
  • Obtain QK^T + B.

Relative position index matrix: each column gives one coordinate's relative position as "seen" from every other coordinate.

What is finally produced is the relative position index, relative_position_index, of shape (M^2, M^2), registered in the network as a non-learnable buffer. Its role is to fetch the corresponding learnable relative position bias from the final index values. The index values range from 0 to 8, i.e. over (2M−1) × (2M−1) values, so for M = 2 the relative position bias fits in a 3 × 3 table.

Taking the M = 2 window as an example, when the first token computes its four QK^T values, the relative_position_index of the four tokens is [4, 3, 1, 0]; the corresponding biases are the table entries at indices 4, 3, 1, 0, so relative_position_bias has shape (M^2, M^2). The snippet below reproduces this.
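
A minimal sketch that reproduces these indices for M = 2, repeating the steps above outside the class:

import torch

M = 2
coords = torch.stack(torch.meshgrid([torch.arange(M), torch.arange(M)]))  # (2, M, M)
coords_flatten = torch.flatten(coords, 1)                                 # (2, M*M)
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]             # (2, M*M, M*M)
rel = rel.permute(1, 2, 0).contiguous()
rel[:, :, 0] += M - 1
rel[:, :, 1] += M - 1
rel[:, :, 0] *= 2 * M - 1
relative_position_index = rel.sum(-1)
print(relative_position_index[0])  # tensor([4, 3, 1, 0]): indices used by the first token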

Forward pass

  • The input tensor has shape [num_windows * B, window_size * window_size, C].
  • It passes through the self.qkv linear layer, is reshaped and its axes permuted into shape [3, num_windows * B, num_heads, window_size*window_size, C//num_heads], then split into q, k, v.
  • Following the formula, q is multiplied by a scale factor and matrix-multiplied with k (whose last two dimensions are transposed to make the shapes compatible), giving the attn tensor of shape [num_windows*B, num_heads, window_size*window_size, window_size*window_size].
  • Earlier we created a learnable table of shape (2 * window_size - 1 * 2 * window_size - 1, num_heads). Indexing it with self.relative_position_index.view(-1) gives a bias of shape (window_size * window_size, window_size * window_size, num_heads), which is permuted with permute(2, 0, 1) and added to attn.
  • Ignoring the mask for now, the rest is the usual Transformer pipeline: softmax, dropout, matrix multiplication with V, then a linear layer and dropout.
def forward(self, x, mask=None):
    """
    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
    """
    B_, N, C = x.shape

    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))

    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    attn = attn + relative_position_bias.unsqueeze(0)  # (1, num_heads, Wh*Ww, Wh*Ww)

    if mask is not None:  # analysed further below
        ...
    else:
        attn = self.softmax(attn)

    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

Full code

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()

        self.dim = dim
        self.window_size = window_size             # (Wh, Ww), e.g. (7, 7)
        self.num_heads = num_heads                 # number of MHA heads
        head_dim = dim // num_heads                # dim is split evenly across the heads
        self.scale = qk_scale or head_dim ** -0.5  # attention scale: user-supplied qk_scale or 1/sqrt(d)

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])   # coordinates along the window height
        coords_w = torch.arange(self.window_size[1])   # coordinates along the window width
        # coordinate grid of the local window
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        # relative positions
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # (num_windows*B, N, C) = (num_windows*B, wh*ww, C)
        B_, N, C = x.shape

        # (num_windows*B, N, 3, num_heads, C//num_heads) -> (3, num_windows*B, num_heads, wh*ww, C//num_heads)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        # Query, Key, Value
        # (num_windows*B, num_heads, wh*ww, C//num_heads)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # scale the Query
        # (num_windows*B, num_heads, wh*ww, C//num_heads)
        q = q * self.scale

        # Query * Key
        # (num_windows*B, num_heads, wh*ww, C//num_heads) * (num_windows*B, num_heads, C//num_heads, wh*ww) = (num_windows*B, num_heads, wh*ww, wh*ww)
        attn = (q @ k.transpose(-2, -1))  # @ is matrix multiplication

        # relative position bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        # Attention Map = Softmax(Q * K / sqrt(d) + B)
        # (num_heads, wh*ww, wh*ww) -> (1, num_heads, wh*ww, wh*ww) -> (num_windows*B, num_heads, wh*ww, wh*ww)
        attn = attn + relative_position_bias.unsqueeze(0)
        # local-window attention-map mask + Softmax
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)  # final Attention Map
        else:
            attn = self.softmax(attn)  # final Attention Map

        # (num_windows*B, num_heads, wh*ww, wh*ww)
        attn = self.attn_drop(attn)

        # Attention Map * V
        # (num_windows*B, num_heads, wh*ww, wh*ww) * (num_windows*B, num_heads, wh*ww, C//num_heads) = (num_windows*B, num_heads, wh*ww, C//num_heads)
        # (num_windows*B, num_heads, wh*ww, C//num_heads) -> (num_windows*B, wh*ww, num_heads, C//num_heads) -> (num_windows*B, wh*ww, C)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # @ is matrix multiplication

        # linear projection FC
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    def extra_repr(self) -> str:
        # used for the print output
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
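
A minimal usage sketch (assuming the WindowAttention class above is in scope) that checks the shapes for a 7×7 window:

import torch

attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
x = torch.randn(128, 49, 96)  # (num_windows*B, window_size*window_size, C)
out = attn(x)                 # no mask: plain W-MSA
print(out.shape)              # torch.Size([128, 49, 96])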

Step-by-step shape changes for a 4×4 window

import torch
import torch.nn as nn


# take a 4×4 window as the example
window_size = (4, 4)

coords_h = torch.arange(window_size[0]) # wh
coords_w = torch.arange(window_size[1]) # ww
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # (2, wh, ww)
coords, coords.shape
(tensor([[[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]],

[[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3]]]),
torch.Size([2, 4, 4]))
coords_flatten = torch.flatten(coords, 1)  # (2, wh*ww)
coords_flatten, coords_flatten.shape
(tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
torch.Size([2, 16]))
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, wh*ww, wh*ww)
relative_coords, relative_coords.shape
(tensor([[[ 0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],
[ 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],
[ 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],
[ 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],
[ 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2],
[ 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2],
[ 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2],
[ 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1, -2, -2, -2, -2],
[ 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1],
[ 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1],
[ 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1],
[ 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1, -1],
[ 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0],
[ 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0],
[ 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0],
[ 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0]],

[[ 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3],
[ 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2],
[ 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1],
[ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0],
[ 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3],
[ 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2],
[ 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1],
[ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0],
[ 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3],
[ 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2],
[ 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1],
[ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0],
[ 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3, 0, -1, -2, -3],
[ 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2],
[ 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1, 2, 1, 0, -1],
[ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0]]]),
torch.Size([2, 16, 16]))
# display as (x, y) pairs: row and column offsets together
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # (wh*ww, wh*ww, 2)
relative_coords, relative_coords.shape
(tensor([[[ 0,  0],
[ 0, -1],
[ 0, -2],
[ 0, -3],
[-1, 0],
[-1, -1],
[-1, -2],
[-1, -3],
[-2, 0],
[-2, -1],
[-2, -2],
[-2, -3],
[-3, 0],
[-3, -1],
[-3, -2],
[-3, -3]],

[[ 0, 1],
[ 0, 0],
[ 0, -1],
[ 0, -2],
[-1, 1],
[-1, 0],
[-1, -1],
[-1, -2],
[-2, 1],
[-2, 0],
[-2, -1],
[-2, -2],
[-3, 1],
[-3, 0],
[-3, -1],
[-3, -2]],

[[ 0, 2],
[ 0, 1],
[ 0, 0],
[ 0, -1],
[-1, 2],
[-1, 1],
[-1, 0],
[-1, -1],
[-2, 2],
[-2, 1],
[-2, 0],
[-2, -1],
[-3, 2],
[-3, 1],
[-3, 0],
[-3, -1]],

[[ 0, 3],
[ 0, 2],
[ 0, 1],
[ 0, 0],
[-1, 3],
[-1, 2],
[-1, 1],
[-1, 0],
[-2, 3],
[-2, 2],
[-2, 1],
[-2, 0],
[-3, 3],
[-3, 2],
[-3, 1],
[-3, 0]],

[[ 1, 0],
[ 1, -1],
[ 1, -2],
[ 1, -3],
[ 0, 0],
[ 0, -1],
[ 0, -2],
[ 0, -3],
[-1, 0],
[-1, -1],
[-1, -2],
[-1, -3],
[-2, 0],
[-2, -1],
[-2, -2],
[-2, -3]],

[[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 1, -2],
[ 0, 1],
[ 0, 0],
[ 0, -1],
[ 0, -2],
[-1, 1],
[-1, 0],
[-1, -1],
[-1, -2],
[-2, 1],
[-2, 0],
[-2, -1],
[-2, -2]],

[[ 1, 2],
[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 0, 2],
[ 0, 1],
[ 0, 0],
[ 0, -1],
[-1, 2],
[-1, 1],
[-1, 0],
[-1, -1],
[-2, 2],
[-2, 1],
[-2, 0],
[-2, -1]],

[[ 1, 3],
[ 1, 2],
[ 1, 1],
[ 1, 0],
[ 0, 3],
[ 0, 2],
[ 0, 1],
[ 0, 0],
[-1, 3],
[-1, 2],
[-1, 1],
[-1, 0],
[-2, 3],
[-2, 2],
[-2, 1],
[-2, 0]],

[[ 2, 0],
[ 2, -1],
[ 2, -2],
[ 2, -3],
[ 1, 0],
[ 1, -1],
[ 1, -2],
[ 1, -3],
[ 0, 0],
[ 0, -1],
[ 0, -2],
[ 0, -3],
[-1, 0],
[-1, -1],
[-1, -2],
[-1, -3]],

[[ 2, 1],
[ 2, 0],
[ 2, -1],
[ 2, -2],
[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 1, -2],
[ 0, 1],
[ 0, 0],
[ 0, -1],
[ 0, -2],
[-1, 1],
[-1, 0],
[-1, -1],
[-1, -2]],

[[ 2, 2],
[ 2, 1],
[ 2, 0],
[ 2, -1],
[ 1, 2],
[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 0, 2],
[ 0, 1],
[ 0, 0],
[ 0, -1],
[-1, 2],
[-1, 1],
[-1, 0],
[-1, -1]],

[[ 2, 3],
[ 2, 2],
[ 2, 1],
[ 2, 0],
[ 1, 3],
[ 1, 2],
[ 1, 1],
[ 1, 0],
[ 0, 3],
[ 0, 2],
[ 0, 1],
[ 0, 0],
[-1, 3],
[-1, 2],
[-1, 1],
[-1, 0]],

[[ 3, 0],
[ 3, -1],
[ 3, -2],
[ 3, -3],
[ 2, 0],
[ 2, -1],
[ 2, -2],
[ 2, -3],
[ 1, 0],
[ 1, -1],
[ 1, -2],
[ 1, -3],
[ 0, 0],
[ 0, -1],
[ 0, -2],
[ 0, -3]],

[[ 3, 1],
[ 3, 0],
[ 3, -1],
[ 3, -2],
[ 2, 1],
[ 2, 0],
[ 2, -1],
[ 2, -2],
[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 1, -2],
[ 0, 1],
[ 0, 0],
[ 0, -1],
[ 0, -2]],

[[ 3, 2],
[ 3, 1],
[ 3, 0],
[ 3, -1],
[ 2, 2],
[ 2, 1],
[ 2, 0],
[ 2, -1],
[ 1, 2],
[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 0, 2],
[ 0, 1],
[ 0, 0],
[ 0, -1]],

[[ 3, 3],
[ 3, 2],
[ 3, 1],
[ 3, 0],
[ 2, 3],
[ 2, 2],
[ 2, 1],
[ 2, 0],
[ 1, 3],
[ 1, 2],
[ 1, 1],
[ 1, 0],
[ 0, 3],
[ 0, 2],
[ 0, 1],
[ 0, 0]]]),
torch.Size([16, 16, 2]))
# additive shift of the x (row) coordinate (+= 3)
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords
tensor([[[ 3,  0],
[ 3, -1],
[ 3, -2],
[ 3, -3],
[ 2, 0],
[ 2, -1],
[ 2, -2],
[ 2, -3],
[ 1, 0],
[ 1, -1],
[ 1, -2],
[ 1, -3],
[ 0, 0],
[ 0, -1],
[ 0, -2],
[ 0, -3]],

[[ 3, 1],
[ 3, 0],
[ 3, -1],
[ 3, -2],
[ 2, 1],
[ 2, 0],
[ 2, -1],
[ 2, -2],
[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 1, -2],
[ 0, 1],
[ 0, 0],
[ 0, -1],
[ 0, -2]],

[[ 3, 2],
[ 3, 1],
[ 3, 0],
[ 3, -1],
[ 2, 2],
[ 2, 1],
[ 2, 0],
[ 2, -1],
[ 1, 2],
[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 0, 2],
[ 0, 1],
[ 0, 0],
[ 0, -1]],

[[ 3, 3],
[ 3, 2],
[ 3, 1],
[ 3, 0],
[ 2, 3],
[ 2, 2],
[ 2, 1],
[ 2, 0],
[ 1, 3],
[ 1, 2],
[ 1, 1],
[ 1, 0],
[ 0, 3],
[ 0, 2],
[ 0, 1],
[ 0, 0]],

[[ 4, 0],
[ 4, -1],
[ 4, -2],
[ 4, -3],
[ 3, 0],
[ 3, -1],
[ 3, -2],
[ 3, -3],
[ 2, 0],
[ 2, -1],
[ 2, -2],
[ 2, -3],
[ 1, 0],
[ 1, -1],
[ 1, -2],
[ 1, -3]],

[[ 4, 1],
[ 4, 0],
[ 4, -1],
[ 4, -2],
[ 3, 1],
[ 3, 0],
[ 3, -1],
[ 3, -2],
[ 2, 1],
[ 2, 0],
[ 2, -1],
[ 2, -2],
[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 1, -2]],

[[ 4, 2],
[ 4, 1],
[ 4, 0],
[ 4, -1],
[ 3, 2],
[ 3, 1],
[ 3, 0],
[ 3, -1],
[ 2, 2],
[ 2, 1],
[ 2, 0],
[ 2, -1],
[ 1, 2],
[ 1, 1],
[ 1, 0],
[ 1, -1]],

[[ 4, 3],
[ 4, 2],
[ 4, 1],
[ 4, 0],
[ 3, 3],
[ 3, 2],
[ 3, 1],
[ 3, 0],
[ 2, 3],
[ 2, 2],
[ 2, 1],
[ 2, 0],
[ 1, 3],
[ 1, 2],
[ 1, 1],
[ 1, 0]],

[[ 5, 0],
[ 5, -1],
[ 5, -2],
[ 5, -3],
[ 4, 0],
[ 4, -1],
[ 4, -2],
[ 4, -3],
[ 3, 0],
[ 3, -1],
[ 3, -2],
[ 3, -3],
[ 2, 0],
[ 2, -1],
[ 2, -2],
[ 2, -3]],

[[ 5, 1],
[ 5, 0],
[ 5, -1],
[ 5, -2],
[ 4, 1],
[ 4, 0],
[ 4, -1],
[ 4, -2],
[ 3, 1],
[ 3, 0],
[ 3, -1],
[ 3, -2],
[ 2, 1],
[ 2, 0],
[ 2, -1],
[ 2, -2]],

[[ 5, 2],
[ 5, 1],
[ 5, 0],
[ 5, -1],
[ 4, 2],
[ 4, 1],
[ 4, 0],
[ 4, -1],
[ 3, 2],
[ 3, 1],
[ 3, 0],
[ 3, -1],
[ 2, 2],
[ 2, 1],
[ 2, 0],
[ 2, -1]],

[[ 5, 3],
[ 5, 2],
[ 5, 1],
[ 5, 0],
[ 4, 3],
[ 4, 2],
[ 4, 1],
[ 4, 0],
[ 3, 3],
[ 3, 2],
[ 3, 1],
[ 3, 0],
[ 2, 3],
[ 2, 2],
[ 2, 1],
[ 2, 0]],

[[ 6, 0],
[ 6, -1],
[ 6, -2],
[ 6, -3],
[ 5, 0],
[ 5, -1],
[ 5, -2],
[ 5, -3],
[ 4, 0],
[ 4, -1],
[ 4, -2],
[ 4, -3],
[ 3, 0],
[ 3, -1],
[ 3, -2],
[ 3, -3]],

[[ 6, 1],
[ 6, 0],
[ 6, -1],
[ 6, -2],
[ 5, 1],
[ 5, 0],
[ 5, -1],
[ 5, -2],
[ 4, 1],
[ 4, 0],
[ 4, -1],
[ 4, -2],
[ 3, 1],
[ 3, 0],
[ 3, -1],
[ 3, -2]],

[[ 6, 2],
[ 6, 1],
[ 6, 0],
[ 6, -1],
[ 5, 2],
[ 5, 1],
[ 5, 0],
[ 5, -1],
[ 4, 2],
[ 4, 1],
[ 4, 0],
[ 4, -1],
[ 3, 2],
[ 3, 1],
[ 3, 0],
[ 3, -1]],

[[ 6, 3],
[ 6, 2],
[ 6, 1],
[ 6, 0],
[ 5, 3],
[ 5, 2],
[ 5, 1],
[ 5, 0],
[ 4, 3],
[ 4, 2],
[ 4, 1],
[ 4, 0],
[ 3, 3],
[ 3, 2],
[ 3, 1],
[ 3, 0]]])
# additive shift of the y (column) coordinate (+= 3)
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords
tensor([[[3, 3],
[3, 2],
[3, 1],
[3, 0],
[2, 3],
[2, 2],
[2, 1],
[2, 0],
[1, 3],
[1, 2],
[1, 1],
[1, 0],
[0, 3],
[0, 2],
[0, 1],
[0, 0]],

[[3, 4],
[3, 3],
[3, 2],
[3, 1],
[2, 4],
[2, 3],
[2, 2],
[2, 1],
[1, 4],
[1, 3],
[1, 2],
[1, 1],
[0, 4],
[0, 3],
[0, 2],
[0, 1]],

[[3, 5],
[3, 4],
[3, 3],
[3, 2],
[2, 5],
[2, 4],
[2, 3],
[2, 2],
[1, 5],
[1, 4],
[1, 3],
[1, 2],
[0, 5],
[0, 4],
[0, 3],
[0, 2]],

[[3, 6],
[3, 5],
[3, 4],
[3, 3],
[2, 6],
[2, 5],
[2, 4],
[2, 3],
[1, 6],
[1, 5],
[1, 4],
[1, 3],
[0, 6],
[0, 5],
[0, 4],
[0, 3]],

[[4, 3],
[4, 2],
[4, 1],
[4, 0],
[3, 3],
[3, 2],
[3, 1],
[3, 0],
[2, 3],
[2, 2],
[2, 1],
[2, 0],
[1, 3],
[1, 2],
[1, 1],
[1, 0]],

[[4, 4],
[4, 3],
[4, 2],
[4, 1],
[3, 4],
[3, 3],
[3, 2],
[3, 1],
[2, 4],
[2, 3],
[2, 2],
[2, 1],
[1, 4],
[1, 3],
[1, 2],
[1, 1]],

[[4, 5],
[4, 4],
[4, 3],
[4, 2],
[3, 5],
[3, 4],
[3, 3],
[3, 2],
[2, 5],
[2, 4],
[2, 3],
[2, 2],
[1, 5],
[1, 4],
[1, 3],
[1, 2]],

[[4, 6],
[4, 5],
[4, 4],
[4, 3],
[3, 6],
[3, 5],
[3, 4],
[3, 3],
[2, 6],
[2, 5],
[2, 4],
[2, 3],
[1, 6],
[1, 5],
[1, 4],
[1, 3]],

[[5, 3],
[5, 2],
[5, 1],
[5, 0],
[4, 3],
[4, 2],
[4, 1],
[4, 0],
[3, 3],
[3, 2],
[3, 1],
[3, 0],
[2, 3],
[2, 2],
[2, 1],
[2, 0]],

[[5, 4],
[5, 3],
[5, 2],
[5, 1],
[4, 4],
[4, 3],
[4, 2],
[4, 1],
[3, 4],
[3, 3],
[3, 2],
[3, 1],
[2, 4],
[2, 3],
[2, 2],
[2, 1]],

[[5, 5],
[5, 4],
[5, 3],
[5, 2],
[4, 5],
[4, 4],
[4, 3],
[4, 2],
[3, 5],
[3, 4],
[3, 3],
[3, 2],
[2, 5],
[2, 4],
[2, 3],
[2, 2]],

[[5, 6],
[5, 5],
[5, 4],
[5, 3],
[4, 6],
[4, 5],
[4, 4],
[4, 3],
[3, 6],
[3, 5],
[3, 4],
[3, 3],
[2, 6],
[2, 5],
[2, 4],
[2, 3]],

[[6, 3],
[6, 2],
[6, 1],
[6, 0],
[5, 3],
[5, 2],
[5, 1],
[5, 0],
[4, 3],
[4, 2],
[4, 1],
[4, 0],
[3, 3],
[3, 2],
[3, 1],
[3, 0]],

[[6, 4],
[6, 3],
[6, 2],
[6, 1],
[5, 4],
[5, 3],
[5, 2],
[5, 1],
[4, 4],
[4, 3],
[4, 2],
[4, 1],
[3, 4],
[3, 3],
[3, 2],
[3, 1]],

[[6, 5],
[6, 4],
[6, 3],
[6, 2],
[5, 5],
[5, 4],
[5, 3],
[5, 2],
[4, 5],
[4, 4],
[4, 3],
[4, 2],
[3, 5],
[3, 4],
[3, 3],
[3, 2]],

[[6, 6],
[6, 5],
[6, 4],
[6, 3],
[5, 6],
[5, 5],
[5, 4],
[5, 3],
[4, 6],
[4, 5],
[4, 4],
[4, 3],
[3, 6],
[3, 5],
[3, 4],
[3, 3]]])
# multiplicative transform of the x (row) coordinate (*= 7)
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_coords
tensor([[[21,  3],
[21, 2],
[21, 1],
[21, 0],
[14, 3],
[14, 2],
[14, 1],
[14, 0],
[ 7, 3],
[ 7, 2],
[ 7, 1],
[ 7, 0],
[ 0, 3],
[ 0, 2],
[ 0, 1],
[ 0, 0]],

[[21, 4],
[21, 3],
[21, 2],
[21, 1],
[14, 4],
[14, 3],
[14, 2],
[14, 1],
[ 7, 4],
[ 7, 3],
[ 7, 2],
[ 7, 1],
[ 0, 4],
[ 0, 3],
[ 0, 2],
[ 0, 1]],

[[21, 5],
[21, 4],
[21, 3],
[21, 2],
[14, 5],
[14, 4],
[14, 3],
[14, 2],
[ 7, 5],
[ 7, 4],
[ 7, 3],
[ 7, 2],
[ 0, 5],
[ 0, 4],
[ 0, 3],
[ 0, 2]],

[[21, 6],
[21, 5],
[21, 4],
[21, 3],
[14, 6],
[14, 5],
[14, 4],
[14, 3],
[ 7, 6],
[ 7, 5],
[ 7, 4],
[ 7, 3],
[ 0, 6],
[ 0, 5],
[ 0, 4],
[ 0, 3]],

[[28, 3],
[28, 2],
[28, 1],
[28, 0],
[21, 3],
[21, 2],
[21, 1],
[21, 0],
[14, 3],
[14, 2],
[14, 1],
[14, 0],
[ 7, 3],
[ 7, 2],
[ 7, 1],
[ 7, 0]],

[[28, 4],
[28, 3],
[28, 2],
[28, 1],
[21, 4],
[21, 3],
[21, 2],
[21, 1],
[14, 4],
[14, 3],
[14, 2],
[14, 1],
[ 7, 4],
[ 7, 3],
[ 7, 2],
[ 7, 1]],

[[28, 5],
[28, 4],
[28, 3],
[28, 2],
[21, 5],
[21, 4],
[21, 3],
[21, 2],
[14, 5],
[14, 4],
[14, 3],
[14, 2],
[ 7, 5],
[ 7, 4],
[ 7, 3],
[ 7, 2]],

[[28, 6],
[28, 5],
[28, 4],
[28, 3],
[21, 6],
[21, 5],
[21, 4],
[21, 3],
[14, 6],
[14, 5],
[14, 4],
[14, 3],
[ 7, 6],
[ 7, 5],
[ 7, 4],
[ 7, 3]],

[[35, 3],
[35, 2],
[35, 1],
[35, 0],
[28, 3],
[28, 2],
[28, 1],
[28, 0],
[21, 3],
[21, 2],
[21, 1],
[21, 0],
[14, 3],
[14, 2],
[14, 1],
[14, 0]],

[[35, 4],
[35, 3],
[35, 2],
[35, 1],
[28, 4],
[28, 3],
[28, 2],
[28, 1],
[21, 4],
[21, 3],
[21, 2],
[21, 1],
[14, 4],
[14, 3],
[14, 2],
[14, 1]],

[[35, 5],
[35, 4],
[35, 3],
[35, 2],
[28, 5],
[28, 4],
[28, 3],
[28, 2],
[21, 5],
[21, 4],
[21, 3],
[21, 2],
[14, 5],
[14, 4],
[14, 3],
[14, 2]],

[[35, 6],
[35, 5],
[35, 4],
[35, 3],
[28, 6],
[28, 5],
[28, 4],
[28, 3],
[21, 6],
[21, 5],
[21, 4],
[21, 3],
[14, 6],
[14, 5],
[14, 4],
[14, 3]],

[[42, 3],
[42, 2],
[42, 1],
[42, 0],
[35, 3],
[35, 2],
[35, 1],
[35, 0],
[28, 3],
[28, 2],
[28, 1],
[28, 0],
[21, 3],
[21, 2],
[21, 1],
[21, 0]],

[[42, 4],
[42, 3],
[42, 2],
[42, 1],
[35, 4],
[35, 3],
[35, 2],
[35, 1],
[28, 4],
[28, 3],
[28, 2],
[28, 1],
[21, 4],
[21, 3],
[21, 2],
[21, 1]],

[[42, 5],
[42, 4],
[42, 3],
[42, 2],
[35, 5],
[35, 4],
[35, 3],
[35, 2],
[28, 5],
[28, 4],
[28, 3],
[28, 2],
[21, 5],
[21, 4],
[21, 3],
[21, 2]],

[[42, 6],
[42, 5],
[42, 4],
[42, 3],
[35, 6],
[35, 5],
[35, 4],
[35, 3],
[28, 6],
[28, 5],
[28, 4],
[28, 3],
[21, 6],
[21, 5],
[21, 4],
[21, 3]]])
# compute the 1D offsets (x + y)
relative_position_index = relative_coords.sum(-1)  # (wh*ww, wh*ww)
relative_position_index, relative_position_index.shape

# note how the offset values spread out perpendicular to the main diagonal
# the 16 columns correspond one-to-one to the 4×4 coordinate positions
(tensor([[24, 23, 22, 21, 17, 16, 15, 14, 10,  9,  8,  7,  3,  2,  1,  0],
[25, 24, 23, 22, 18, 17, 16, 15, 11, 10, 9, 8, 4, 3, 2, 1],
[26, 25, 24, 23, 19, 18, 17, 16, 12, 11, 10, 9, 5, 4, 3, 2],
[27, 26, 25, 24, 20, 19, 18, 17, 13, 12, 11, 10, 6, 5, 4, 3],
[31, 30, 29, 28, 24, 23, 22, 21, 17, 16, 15, 14, 10, 9, 8, 7],
[32, 31, 30, 29, 25, 24, 23, 22, 18, 17, 16, 15, 11, 10, 9, 8],
[33, 32, 31, 30, 26, 25, 24, 23, 19, 18, 17, 16, 12, 11, 10, 9],
[34, 33, 32, 31, 27, 26, 25, 24, 20, 19, 18, 17, 13, 12, 11, 10],
[38, 37, 36, 35, 31, 30, 29, 28, 24, 23, 22, 21, 17, 16, 15, 14],
[39, 38, 37, 36, 32, 31, 30, 29, 25, 24, 23, 22, 18, 17, 16, 15],
[40, 39, 38, 37, 33, 32, 31, 30, 26, 25, 24, 23, 19, 18, 17, 16],
[41, 40, 39, 38, 34, 33, 32, 31, 27, 26, 25, 24, 20, 19, 18, 17],
[45, 44, 43, 42, 38, 37, 36, 35, 31, 30, 29, 28, 24, 23, 22, 21],
[46, 45, 44, 43, 39, 38, 37, 36, 32, 31, 30, 29, 25, 24, 23, 22],
[47, 46, 45, 44, 40, 39, 38, 37, 33, 32, 31, 30, 26, 25, 24, 23],
[48, 47, 46, 45, 41, 40, 39, 38, 34, 33, 32, 31, 27, 26, 25, 24]]),
torch.Size([16, 16]))
# assume MHA uses 3 heads
num_heads = 3

relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
relative_position_bias_table, relative_position_bias_table.shape
(Parameter containing:
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], requires_grad=True),
torch.Size([49, 3]))
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)].view(
    window_size[0] * window_size[1], window_size[0] * window_size[1], -1)

relative_position_bias, relative_position_bias.shape
(tensor([[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],

[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]], grad_fn=<ViewBackward>),
torch.Size([16, 16, 3]))
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
relative_position_bias, relative_position_bias.shape
(tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<CopyBackwards>),
torch.Size([3, 16, 16]))

Shifted Window Attention

if self.shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = self.input_resolution
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
    attn_mask = None
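
To make the mask concrete, here is the same computation run standalone for a toy case (hypothetical sizes: an 8×8 feature map, window_size=4, shift_size=2; window_partition is the function defined earlier):

import torch

H, W, window_size, shift_size = 8, 8, 4, 2

img_mask = torch.zeros((1, H, W, 1))
h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)           # (4, 4, 4, 1): 4 windows
mask_windows = mask_windows.view(-1, window_size * window_size)  # (4, 16)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)  # torch.Size([4, 16, 16]); -100 marks token pairs from different sub-regions

After the cyclic shift, tokens that were not neighbours in the original image can land in the same window; the -100 entries suppress attention between such pairs once the softmax is applied.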

patch merging

Since a Swin Transformer block does not change the shape of its input, the output is still [B, H*W, C].

x0, x1, x2, x3 are taken by sampling every other row and column (a 2×2 interleaved sampling) and concatenated along the channel dimension; a fully connected layer then reduces the channel count from 4× back to 2×.

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        # reshape
        x = x.view(B, H, W, C)

        # sample rows and columns with stride 2, i.e. downsample the resolution by 1/2
        x0 = x[:, 0::2, 0::2, :]  # shape = (B, H/2, W/2, C)
        x1 = x[:, 1::2, 0::2, :]  # shape = (B, H/2, W/2, C)
        x2 = x[:, 0::2, 1::2, :]  # shape = (B, H/2, W/2, C)
        x3 = x[:, 1::2, 1::2, :]  # shape = (B, H/2, W/2, C)

        # concatenate: channels become 4x
        x = torch.cat([x0, x1, x2, x3], -1)  # shape = (B, H/2, W/2, 4*C)
        x = x.view(B, -1, 4 * C)             # shape = (B, H*W/4, 4*C)

        # FC reduces the channels to 2x
        x = self.norm(x)
        x = self.reduction(x)  # shape = (B, H*W/4, 2*C)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
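
A short usage sketch (assuming PatchMerging as defined above) showing the 2× spatial downsampling and channel doubling:

import torch

merge = PatchMerging(input_resolution=(56, 56), dim=96)
x = torch.randn(1, 56 * 56, 96)  # (B, H*W, C)
out = merge(x)
print(out.shape)                 # torch.Size([1, 784, 192]) = (B, H/2*W/2, 2*C)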

MLP

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

Basic Layer

BasicLayer is one stage of the Swin Transformer: it contains a number of Swin Transformer blocks plus the remaining layers of the stage (optionally a Patch Merging downsampling layer at the end).

Note that the number of Swin Transformer blocks in one stage must be even, because the stage alternates between a layer with Window Attention (W-MSA) and a layer with Shifted Window Attention (SW-MSA).

class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
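
Finally, a minimal end-to-end sketch (assuming all of the classes above are available, e.g. from the official repo's models/swin_transformer.py) that builds Swin-T and checks the classification output:

import torch

model = SwinTransformer(img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                        embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        window_size=7)
x = torch.randn(2, 3, 224, 224)
logits = model(x)
print(logits.shape)                         # torch.Size([2, 1000])
print(f"{model.flops() / 1e9:.1f} GFLOPs")  # roughly 4.5 GFLOPs for Swin-T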