Swin

method

(Figure: overall Swin Transformer architecture)

Overall, Swin's structure resembles PVT's: both take the form of a pyramid architecture.

Why is the pyramid structure so important? The reasoning behind FPN is that feature maps at different scales have different receptive fields, and together with the pooling operations this handles objects of widely varying sizes well (see also U-Net). Models built this way are therefore better suited to dense prediction tasks.
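To make the pyramid concrete, here is a minimal sketch of how spatial resolution and channel width evolve across the four stages of a Swin-T-sized model. The 224×224 input, patch size 4, embed dim 96 and the "halve resolution, double channels" rule are assumptions chosen to match the shapes printed later in this post, not output of the actual model:

```python
# Illustrative only: stage-wise feature-map shapes for a Swin-T-like config.
img_size, patch_size, embed_dim, num_stages = 224, 4, 96, 4

res = img_size // patch_size   # 56: resolution after patch embedding
dim = embed_dim                # 96
for stage in range(num_stages):
    print(f"stage {stage}: {res}x{res} tokens, {dim} channels")
    if stage < num_stages - 1:  # patch merging halves resolution, doubles channels
        res //= 2
        dim *= 2
# stage 0: 56x56 tokens, 96 channels
# stage 1: 28x28 tokens, 192 channels
# stage 2: 14x14 tokens, 384 channels
# stage 3: 7x7 tokens, 768 channels
```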

code

Patch Merging

def forward(self, x):
    """
    x: B, H*W, C
    """
    # x.shape: [B, H*W, C], e.g. [2, 3136, 96] in the first stage
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"
    assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

    x = x.view(B, H, W, C)

    x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
    x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
    x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
    x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
    x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
    x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

    x = self.norm(x)
    x = self.reduction(x)

    return x
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

This step first gathers the values at four interleaved spatial positions of the feature map to form 4 new feature maps, then concatenates them along the channel dimension so that the channel count grows to 4 times the original, and finally applies a linear projection that halves it again, i.e. 4C → 2C. As shown in the figure below:

(Figure: Patch Merging — 2× spatial downsampling with channel expansion)

The sampling pattern is as follows:

(Figure: the interleaved 2×2 sampling pattern)
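As a quick sanity check, the following standalone sketch (variable names are illustrative, not taken from the repo) reproduces the same interleaved sampling and shows how the shapes change from B×H×W×C to B×(H/2·W/2)×2C:

```python
import torch
import torch.nn as nn

B, H, W, C = 2, 56, 56, 96
x = torch.randn(B, H, W, C)

# take every other row/column, offset by 0 or 1 -> four H/2 x W/2 maps
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
merged = torch.cat([x0, x1, x2, x3], dim=-1).view(B, -1, 4 * C)  # [2, 784, 384]

reduction = nn.Linear(4 * C, 2 * C, bias=False)
out = reduction(merged)                                          # [2, 784, 192]
print(merged.shape, out.shape)
```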

Swin Transformer Block

(Figure: structure of two consecutive Swin Transformer blocks)

First, it should be pointed out that two consecutive Transformer layers together form one basic unit of Swin Transformer. The two attention operations act on different window positions:

(Figure: regular windows in the first block vs. shifted windows in the second block)
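In the notation of the Swin paper, such a pair of consecutive blocks computes (W-MSA in the first block, SW-MSA in the second):

$$
\begin{aligned}
\hat{z}^{l} &= \text{W-MSA}\big(\text{LN}(z^{l-1})\big) + z^{l-1}, &
z^{l} &= \text{MLP}\big(\text{LN}(\hat{z}^{l})\big) + \hat{z}^{l},\\
\hat{z}^{l+1} &= \text{SW-MSA}\big(\text{LN}(z^{l})\big) + z^{l}, &
z^{l+1} &= \text{MLP}\big(\text{LN}(\hat{z}^{l+1})\big) + \hat{z}^{l+1}.
\end{aligned}
$$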

Window Based Attention

def forward(self, x, mask=None):
    """
    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
    """
    # [128, 49, 96]
    B_, N, C = x.shape
    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))

    if mask is not None:
        nW = mask.shape[0]
        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
    else:
        attn = self.softmax(attn)

    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

The first thing to notice is that the middle dimension of the input is 49. This is because Swin's window size is 7, so a single window contains 7 × 7 = 49 patches.
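The leading dimension 128 is simply num_windows × batch: with a 56 × 56 feature map, a window size of 7 and a batch of 2, each image yields (56/7)² = 64 windows. A quick check, using the numbers from the shape comments above purely for illustration:

```python
B, H, W, C = 2, 56, 56, 96
window_size = 7

num_windows = (H // window_size) * (W // window_size)   # 64 windows per image
print(num_windows * B)                                  # 128 -> first dim of x_windows
print(window_size * window_size)                        # 49  -> tokens per window
```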

def forward(self, x):
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"

    shortcut = x
    x = self.norm1(x)
    x = x.view(B, H, W, C)
    # [2, 56, 56, 96]

    # cyclic shift
    if self.shift_size > 0:  # shift_size = window_size // 2 = 3 when shifting
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    else:
        shifted_x = x
    # [2, 56, 56, 96]

    # partition windows
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
    # [128, 49, 96]

    # W-MSA/SW-MSA
    attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
    # [128, 49, 96]

    # merge windows
    attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)	
    # [128, 7, 7, 96]
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
    
    # reverse cyclic shift
    if self.shift_size > 0:
        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    else:
        x = shifted_x
    x = x.view(B, H * W, C)
    # [2, 3136, 96]

    # FFN
    x = shortcut + self.drop_path(x)
    x = x + self.drop_path(self.mlp(self.norm2(x)))

    return x

Through the window_partition operation, the code splits each image in a batch into a number of length-49 sequences and computes attention only within each sequence, which reduces the computational cost.
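window_partition and window_reverse themselves are not shown above; roughly, they are just view/permute reshapes. The sketch below follows the shape conventions used in the forward pass above and should be read as an illustration rather than the exact reference code:

```python
import torch

def window_partition(x, window_size):
    """Split (B, H, W, C) into non-overlapping windows of shape
    (num_windows*B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: back to (B, H, W, C)."""
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.randn(2, 56, 56, 96)
w = window_partition(x, 7)                          # [128, 7, 7, 96]
assert torch.equal(window_reverse(w, 7, 56, 56), x)  # round trip is exact
```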

Shift Window

Swin Transformer does not shift the windows in odd-numbered blocks but does shift them in even-numbered blocks, which can also be seen from how the blocks are configured:

SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                     num_heads=num_heads, window_size=window_size,
                     shift_size=0 if (i % 2 == 0) else window_size // 2,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop, attn_drop=attn_drop,
                     drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                     norm_layer=norm_layer)
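For window_size = 7, the expression shift_size = 0 if (i % 2 == 0) else window_size // 2 alternates between 0 and 3 across the blocks of a stage; a quick check (depth = 4 is just an example value):

```python
window_size, depth = 7, 4  # depth chosen only for illustration
shift_sizes = [0 if i % 2 == 0 else window_size // 2 for i in range(depth)]
print(shift_sizes)  # [0, 3, 0, 3]
```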

(Figure: window partitioning in consecutive blocks, without and with the shift)

An intuitive illustration of the cyclic shift:

(Figures: cyclic shift of the feature map and the resulting window layout)

Then comes the mask mechanism. After the windows are shifted, the (cyclically shifted) feature map is still partitioned into 4 large windows so that the computation stays parallel. However, to keep each region from attending to unrelated content that was not adjacent before the shift, the following mask is introduced:

(Figures: window labels after the cyclic shift and the resulting attention mask)
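For reference, here is a sketch of how such a mask can be built for a single image, reusing the window_partition sketch from above. A small 14 × 14 resolution is assumed purely for illustration; token pairs that came from different regions before the shift get a large negative value so softmax zeroes them out:

```python
import torch

H = W = 14
window_size, shift_size = 7, 3

# label each region of the image according to where it came from before the shift
img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# tokens carrying different labels inside the same window must not attend to each other
mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)   # (nW, 49, 49)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)  # torch.Size([4, 49, 49])
```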

Ablation experiments demonstrate the effectiveness of the shifted windows and the position embedding on detection tasks:

(Table: ablation of shifted windows and position embedding on detection)

Swin Transformer

def forward_features(self, x):
    x = self.patch_embed(x)
    if self.ape:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)

    for layer in self.layers:
        x = layer(x)

    x = self.norm(x)  # B L C # [2, 49, 768]
    x = self.avgpool(x.transpose(1, 2))  # B C 1 # [2, 768, 1]
    x = torch.flatten(x, 1) # [2, 768]
    return x

def forward(self, x):
    x = self.forward_features(x)
    x = self.head(x)
    return x

Similar to ResNet, Swin pools the feature map of the last stage (a global average over the tokens) instead of introducing a cls token.
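A minimal sketch of that last step, using the shapes printed in the comments above (the tensors are random stand-ins and the 1000-class head is an assumption, not the real model output):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 49, 768)                           # B, L, C: tokens from the last stage
pooled = nn.AdaptiveAvgPool1d(1)(x.transpose(1, 2))   # B, C, 1: average over all tokens
feat = torch.flatten(pooled, 1)                       # B, C
logits = nn.Linear(768, 1000)(feat)                   # classification head (1000 classes assumed)
print(feat.shape, logits.shape)                       # torch.Size([2, 768]) torch.Size([2, 1000])
```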
