Графический свинг-трансформер

Идентификация изображения
Графический свинг-трансформер

введение

В настоящее время при применении Transformer к полю изображения возникают две основные проблемы:

  • Визуальный объект сильно меняется, и производительность визуального преобразователя может быть не очень хорошей в разных сценариях.

  • Разрешение изображения высокое и много пикселей Расчет Трансформера, основанный на глобальном внутреннем внимании, приводит к большому объему вычислений.

В ответ на две вышеуказанные проблемы мы предлагаемВключая управление скользящим окном с иерархическим дизайномСвин Трансформер.

Операция скользящего окна включает в себяНеперекрывающиеся локальные окна и перекрывающиеся поперечные окна. ограничение вычисления внимания окном,С одной стороны, это может ввести локальность операции свертки CNN, а с другой стороны, это может сэкономить объем вычислений..

图片

Свин-Т и ВиТ

На всех основных задачах с изображениями Swin Transformer показал хорошую производительность.

Эта статья относительно длинная и будет основана на официальном открытом исходном коде (GitHub.com/Microsoft/S…

Общая структура

Давайте сначала посмотрим на общую архитектуру Swin Transformer.

图片

Общая архитектура Swin Transformer

Вся модель имеет иерархический дизайн, включающий в общей сложности 4 этапа, каждый этап будет уменьшать разрешение входной карты объектов и расширять рецептивное поле слой за слоем, как CNN.

  • В начале ввода сделайтеPatch Embedding, разрежьте изображение на плитки и вставьте их вEmbedding.

  • На каждом этапе поPatch Mergingи несколько блоков.

  • вPatch MergingМодуль в основном уменьшает разрешение изображения в начале каждого этапа.

  • Конкретная структура блока показана на рисунке справа, в основном:LayerNorm,MLP,Window AttentionиShifted Window AttentionСостав (некоторые параметры опущу для удобства пояснения)

class SwinTransformer(nn.Module):
    def __init__(...):
        super().__init__()
        ...
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            
        self.pos_drop = nn.Dropout(p=drop_rate)

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(...)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

Есть несколько мест, где подход отличается от ViT:

  • ViT будет позиционно кодировать вложение на входе. И Swin-T здесь какпо желанию(self.ape), Swin-T сделал相对位置编码

  • ViT добавит обучаемый параметр отдельно в качестве токена для классификации. А Swin-T естьсредний напрямую, выходная классификация, несколько похожая на последний глобальный средний слой пула CNN.

Далее рассмотрим состав каждого компонента

Patch Embedding

Перед входом в блок нам нужно разрезать изображение на патчи, а затем встроить вектор.

Конкретный метод заключается в том, чтобы разрезать исходное изображение на части одну за другой.window_size * window_sizeразмер окна, а затем встроить.

Здесь, через двумерный сверточный слой,Установите шаг, размер ядра равным размеру окна.. Установите выходной канал, чтобы определить размер вектора внедрения. Наконец, разверните размеры H и W и перейдите к первому измерению.

import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size) # -> (img_size, img_size)
        patch_size = to_2tuple(patch_size) # -> (patch_size, patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        # 假设采取默认参数
        x = self.proj(x) # 出来的是(N, 96, 224/4, 224/4) 
        x = torch.flatten(x, 2) # 把HW维展开,(N, 96, 56*56)
        x = torch.transpose(x, 1, 2)  # 把通道维放到最后 (N, 56*56, 96)
        if self.norm is not None:
            x = self.norm(x)
        return x

Patch Merging

Функция этого модуля состоит в том, чтобы делать даунсэмплинг перед началом каждого этапа для уменьшения разрешения, регулировать количество каналов для формирования иерархического дизайна, а также экономить определенное количество вычислений.

В CNN он используется перед началом каждого этапа сstride=2Слой свертки/объединения для уменьшения разрешения.

Каждая даунсэмплинг выполняется дважды, поэтомуВ направлении строки и столбца интервал 2 выбирает элементы.

Затем склеил как целый тензор и, наконец, расширил.В это время размер канала станет в 4 раза больше исходного(поскольку H и W уменьшены в 2 раза), то передайтеПолносвязный слой корректирует размер канала до удвоенного исходного размера.

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

Ниже приведена схематическая диаграмма (входные тензоры N=1, H=W=8, C=1, исключая окончательную корректировку полносвязного слоя)

图片

Patch Merge

Лично мне кажется, что это обратная операция PixelShuffle.

Window Partition/Reverse

window partitionФункция используется для разделения тензора на окно и указания размера окна. Преобразуйте исходный тензор изN H W C, разделен наnum_windows*B, window_size, window_size, Cnum_windows = H*W / window_size, то есть количество окон. иwindow reverseФункция является соответствующим обратным процессом. Эти две функции будут позжеWindow Attentionиспользовал.

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

Window Attention

Это суть этой статьи. Традиционные трансформерыВычислить внимание на основе глобального, поэтому вычислительная сложность очень высока. Swin Transformer будетВычисление внимания ограничено каждым окном, тем самым уменьшая количество вычислений.

Кратко рассмотрим формулу

Основное отличие состоит в том, что Q, K в исходной формуле расчета ВниманиеДобавлено кодирование относительного положения. Последующие эксперименты показали, что добавление кодирования относительного положения улучшает производительность модели.

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads # nH
        head_dim = dim // num_heads # 每个注意力头对应的通道数
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 设置一个形状为(2*(Wh-1) * 2*(Ww-1), nH)的可学习变量,用于后续的位置编码
  
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
     # 相关位置编码...

Далее я вынесу логику, связанную с релевантным кодированием позиции отдельно, эта часть более сложная.

Во-первых, форма тензора внимания, рассчитанного с помощью QK, имеет вид(numWindows*B, num_heads, window_size*window_size, window_size*window_size).

А для тензора вниманияПринимая разные элементы за начало координат, координаты других элементов также различны.,отwindow_size=2Например, кодировка относительного положения показана на следующем рисунке.

图片

Пример кодирования относительного положения

Сначала мы используемtorch.arangeиtorch.meshgridФункция генерирует соответствующие координаты, здесь мы используемwindowsize=2Например

coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.meshgrid([coords_h, coords_w]) # -> 2*(wh, ww)
"""
  (tensor([[0, 0],
           [1, 1]]), 
   tensor([[0, 1],
           [0, 1]]))
"""

Затем сложите и разверните в двумерный вектор

coords = torch.stack(coords)  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
"""
tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]])
"""

Используя механизм трансляции, вставьте измерение в первое измерение и второе измерение и выполните широковещательное вычитание, чтобы получить2, wh*ww, wh*wwтензор

relative_coords_first = coords_flatten[:, :, None]  # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 形状的张量

Поскольку выполняется вычитание, результирующий индекс начинается с отрицательного числа,Мы добавляем смещение, чтобы начать с 0.

relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1

Позже нам нужно расширить его до одномерного смещения. И для двух координат (1, 2) и (2, 1). различаются в двух измерениях,Но при преобразовании в одномерное смещение путем добавления координат x и y его смещение равно.

图片

Расширить до смещения 1D

Итак, в конце мы сделали операцию умножения, чтобы отличить

relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1

图片

offset multiply

Затем просуммируйте по последнему измерению, разверните в одномерную координату и зарегистрируйте как переменную, которая не участвует в обучении сети.

relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

Затем мы смотрим на прямой код

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0) # (1, num_heads, windowsize, windowsize)

        if mask is not None: # 下文会分析到
            ...
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

  • Сначала форма входного тензораnumWindows*B, window_size * window_size, C(будет объяснено позже)

  • затем послеself.qkvПосле этого полностью связанного слоя выполните изменение формы, отрегулируйте порядок осей и получите форму как3, numWindows*B, num_heads, window_size*window_size, c//num_heads, и присвоеноq,k,v.

  • По формуле имеемqумножить на одинscaleкоэффициент масштабирования, то сk(Чтобы удовлетворить требованиям умножения матриц, необходимо поменять местами два последних измерения) для умножения. получить форму(numWindows*B, num_heads, window_size*window_size, window_size*window_size)изattnТензор

  • Ранее мы установили форму для кодирования положения как(2*window_size-1*2*window_size-1, numHeads)обучаемые переменные. Мы используем рассчитанный индекс относительной позиции кодированияself.relative_position_indexВыберите, чтобы получить форму(window_size*window_size, window_size*window_size, numHeads)код, добавить вattnна тензоре

  • Независимо от ситуации с маской, остальное — это тот же softmax, дропаут и трансформер, что и трансформер.VУмножение матриц, за которым следует полносвязный слой и отсев

Shifted Window Attention

Предыдущее Window Attention вычисляет внимание под каждым окном.Чтобы лучше взаимодействовать с другими окнами, Swin Transformer также вводит операцию смещения окна.

图片

Shift Window

Слева — Window Attention без перекрытия, а справа — Shift Window Attention, который сдвигает окно. Видно, что сдвинутое окно содержит элементы исходного соседнего окна. Но это также вводит новую проблему, а именноКоличество окон увеличилось вдвое., из исходных четырех окон в девять окон.

В реальном коде мыКосвенно достигается за счет смещения карты объектов и установки маски для внимания.. может быть вСохранить исходное количество окон, окончательный результат расчета эквивалентен.图片

Операция сдвига карты объектов

Сдвиг карты функций в коде выполняется с помощьюtorch.rollдля достижения, ниже приведена схематическая диаграмма

图片

сменная операция

Если нужноreverse cyclic shiftзатем просто введите параметрshiftsУстановите соответствующее положительное значение.

Attention Mask

Я думаю, что это суть Swin Transformer, установив разумную маску, пустьShifted Window Attentionв сWindow AttentionПри одинаковом количестве окон достигаются эквивалентные результаты расчета.

Во-первых, мы даем индекс каждому окну после Shift Window и делаемrollОперация (window_size=2, shift_size=1)

图片

Shift window index

Мы надеемся, что при расчете Внимание,Разрешить вычисление QK с одним и тем же индексом и игнорировать результаты расчета QK с разными индексами..

Окончательный правильный результат показан на рисунке ниже.

图片

Shift Attention

Чтобы получить правильный результат в исходных четырех окнах, мы должны добавить маску к результату Attention (как показано в крайнем правом углу рисунка выше).

Соответствующий код выглядит следующим образом:

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

С настройками на рисунке выше мы получим вот такую ​​маску с этим кодом

tensor([[[[[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]]],


         [[[   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.],
           [   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.]]],


         [[[   0.,    0., -100., -100.],
           [   0.,    0., -100., -100.],
           [-100., -100.,    0.,    0.],
           [-100., -100.,    0.,    0.]]],


         [[[   0., -100., -100., -100.],
           [-100.,    0., -100., -100.],
           [-100., -100.,    0., -100.],
           [-100., -100., -100.,    0.]]]]])

В прямом коде предыдущего модуля внимания окна есть такой абзац

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)

Добавьте маску к результату расчета внимания и выполните softmax. Значение маски установлено на -100, и соответствующее значение будет игнорироваться после softmax.

Общая архитектура трансформаторного блока

图片

Трансформаторная блочная архитектура

Две последовательные архитектуры блоков показаны на рисунке выше.Следует отметить, что количество блоков, содержащихся в этапе, должно быть четным числом, потому что он должен попеременно содержать блок, содержащийWindow AttentionСлой и содержитShifted Window AttentionСлой.

Давайте посмотрим на прямой код блока

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

Общий процесс выглядит следующим образом

  • Сначала выполните LayerNorm на карте объектов.

  • пройти черезself.shift_sizeРешите, следует ли сместить карту объектов

  • Затем разрежьте карту объектов на окна одно за другим.

  • Рассчитать внимание, поself.attn_maskотличитьWindow Attentionвсе ещеShift Window Attention

  • Объединить отдельные окна обратно

  • Если вы выполняли операцию сдвига раньше, сделайте это сейчасreverse shift, восстановить предыдущую операцию смены

  • Делайте отсев и остаточные соединения

  • Затем передаем слой LayerNorm+ полносвязный слой, а также отсева и остаточное соединение

Результаты экспериментов

图片

Результаты экспериментов

В наборе данных ImageNet22K уровень точности может достигать поразительных 86,4%. Кроме того, он очень хорошо справляется с такими задачами, как обнаружение и сегментация.Если вам интересно, вы можете прочитать экспериментальную часть в конце статьи.

Суммировать

Инновация этой статьи великолепна: она вводит понятие окна, вводит локальность CNN, а также контролирует общий объем вычислений модели. В разделе Shift Window Attention, используя маску и операцию сдвига, очень умно добиться вычислительной эквивалентности. Код автора тоже очень радует глаз, рекомендуется к прочтению!


Добро пожаловать в GiantPandaCV, где вы увидите эксклюзивный обмен глубокими знаниями, настаиваете на оригинальности и делитесь свежими знаниями, которые мы изучаем каждый день. ( • ̀ω•́ )✧

Если у вас есть вопросы по статье или вы хотите присоединиться к группе обмена, добавьте BBuf WeChat:

图片

QR код

В этой статье используетсяПомощник по синхронизации статейСинхронизировать