введение
В настоящее время при применении Transformer к полю изображения возникают две основные проблемы:
-
Визуальный объект сильно меняется, и производительность визуального преобразователя может быть не очень хорошей в разных сценариях.
-
Разрешение изображения высокое и много пикселей Расчет Трансформера, основанный на глобальном внутреннем внимании, приводит к большому объему вычислений.
В ответ на две вышеуказанные проблемы мы предлагаемВключая управление скользящим окном с иерархическим дизайномСвин Трансформер.
Операция скользящего окна включает в себяНеперекрывающиеся локальные окна и перекрывающиеся поперечные окна. ограничение вычисления внимания окном,С одной стороны, это может ввести локальность операции свертки CNN, а с другой стороны, это может сэкономить объем вычислений..
Свин-Т и ВиТ
На всех основных задачах с изображениями Swin Transformer показал хорошую производительность.
Эта статья относительно длинная и будет основана на официальном открытом исходном коде (GitHub.com/Microsoft/S…
Общая структура
Давайте сначала посмотрим на общую архитектуру Swin Transformer.
Общая архитектура Swin Transformer
Вся модель имеет иерархический дизайн, включающий в общей сложности 4 этапа, каждый этап будет уменьшать разрешение входной карты объектов и расширять рецептивное поле слой за слоем, как CNN.
-
В начале ввода сделайте
Patch Embedding
, разрежьте изображение на плитки и вставьте их вEmbedding
. -
На каждом этапе по
Patch Merging
и несколько блоков. -
в
Patch Merging
Модуль в основном уменьшает разрешение изображения в начале каждого этапа. -
Конкретная структура блока показана на рисунке справа, в основном:
LayerNorm
,MLP
,Window Attention
иShifted Window Attention
Состав (некоторые параметры опущу для удобства пояснения)
class SwinTransformer(nn.Module):
def __init__(...):
super().__init__()
...
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(...)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
Есть несколько мест, где подход отличается от ViT:
-
ViT будет позиционно кодировать вложение на входе. И Swin-T здесь какпо желанию(
self.ape
), Swin-T сделал相对位置编码
-
ViT добавит обучаемый параметр отдельно в качестве токена для классификации. А Swin-T естьсредний напрямую, выходная классификация, несколько похожая на последний глобальный средний слой пула CNN.
Далее рассмотрим состав каждого компонента
Patch Embedding
Перед входом в блок нам нужно разрезать изображение на патчи, а затем встроить вектор.
Конкретный метод заключается в том, чтобы разрезать исходное изображение на части одну за другой.window_size * window_size
размер окна, а затем встроить.
Здесь, через двумерный сверточный слой,Установите шаг, размер ядра равным размеру окна.. Установите выходной канал, чтобы определить размер вектора внедрения. Наконец, разверните размеры H и W и перейдите к первому измерению.
import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size) # -> (img_size, img_size)
patch_size = to_2tuple(patch_size) # -> (patch_size, patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
# 假设采取默认参数
x = self.proj(x) # 出来的是(N, 96, 224/4, 224/4)
x = torch.flatten(x, 2) # 把HW维展开,(N, 96, 56*56)
x = torch.transpose(x, 1, 2) # 把通道维放到最后 (N, 56*56, 96)
if self.norm is not None:
x = self.norm(x)
return x
Patch Merging
Функция этого модуля состоит в том, чтобы делать даунсэмплинг перед началом каждого этапа для уменьшения разрешения, регулировать количество каналов для формирования иерархического дизайна, а также экономить определенное количество вычислений.
В CNN он используется перед началом каждого этапа с
stride=2
Слой свертки/объединения для уменьшения разрешения.
Каждая даунсэмплинг выполняется дважды, поэтомуВ направлении строки и столбца интервал 2 выбирает элементы.
Затем склеил как целый тензор и, наконец, расширил.В это время размер канала станет в 4 раза больше исходного(поскольку H и W уменьшены в 2 раза), то передайтеПолносвязный слой корректирует размер канала до удвоенного исходного размера.
class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
Ниже приведена схематическая диаграмма (входные тензоры N=1, H=W=8, C=1, исключая окончательную корректировку полносвязного слоя)
Patch Merge
Лично мне кажется, что это обратная операция PixelShuffle.
Window Partition/Reverse
window partition
Функция используется для разделения тензора на окно и указания размера окна. Преобразуйте исходный тензор изN H W C
, разделен наnum_windows*B, window_size, window_size, C
,вnum_windows = H*W / window_size
, то есть количество окон. иwindow reverse
Функция является соответствующим обратным процессом. Эти две функции будут позжеWindow Attention
использовал.
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
Window Attention
Это суть этой статьи. Традиционные трансформерыВычислить внимание на основе глобального, поэтому вычислительная сложность очень высока. Swin Transformer будетВычисление внимания ограничено каждым окном, тем самым уменьшая количество вычислений.
Кратко рассмотрим формулу
Основное отличие состоит в том, что Q, K в исходной формуле расчета ВниманиеДобавлено кодирование относительного положения. Последующие эксперименты показали, что добавление кодирования относительного положения улучшает производительность модели.
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads # nH
head_dim = dim // num_heads # 每个注意力头对应的通道数
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 设置一个形状为(2*(Wh-1) * 2*(Ww-1), nH)的可学习变量,用于后续的位置编码
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
# 相关位置编码...
Далее я вынесу логику, связанную с релевантным кодированием позиции отдельно, эта часть более сложная.
Во-первых, форма тензора внимания, рассчитанного с помощью QK, имеет вид(numWindows*B, num_heads, window_size*window_size, window_size*window_size)
.
А для тензора вниманияПринимая разные элементы за начало координат, координаты других элементов также различны.,отwindow_size=2
Например, кодировка относительного положения показана на следующем рисунке.
Пример кодирования относительного положения
Сначала мы используемtorch.arange
иtorch.meshgrid
Функция генерирует соответствующие координаты, здесь мы используемwindowsize=2
Например
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.meshgrid([coords_h, coords_w]) # -> 2*(wh, ww)
"""
(tensor([[0, 0],
[1, 1]]),
tensor([[0, 1],
[0, 1]]))
"""
Затем сложите и разверните в двумерный вектор
coords = torch.stack(coords) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
"""
tensor([[0, 0, 1, 1],
[0, 1, 0, 1]])
"""
Используя механизм трансляции, вставьте измерение в первое измерение и второе измерение и выполните широковещательное вычитание, чтобы получить2, wh*ww, wh*ww
тензор
relative_coords_first = coords_flatten[:, :, None] # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 形状的张量
Поскольку выполняется вычитание, результирующий индекс начинается с отрицательного числа,Мы добавляем смещение, чтобы начать с 0.
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
Позже нам нужно расширить его до одномерного смещения. И для двух координат (1, 2) и (2, 1). различаются в двух измерениях,Но при преобразовании в одномерное смещение путем добавления координат x и y его смещение равно.
Расширить до смещения 1D
Итак, в конце мы сделали операцию умножения, чтобы отличить
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
offset multiply
Затем просуммируйте по последнему измерению, разверните в одномерную координату и зарегистрируйте как переменную, которая не участвует в обучении сети.
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
Затем мы смотрим на прямой код
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0) # (1, num_heads, windowsize, windowsize)
if mask is not None: # 下文会分析到
...
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
-
Сначала форма входного тензора
numWindows*B, window_size * window_size, C
(будет объяснено позже) -
затем после
self.qkv
После этого полностью связанного слоя выполните изменение формы, отрегулируйте порядок осей и получите форму как3, numWindows*B, num_heads, window_size*window_size, c//num_heads
, и присвоеноq,k,v
. -
По формуле имеем
q
умножить на одинscale
коэффициент масштабирования, то сk
(Чтобы удовлетворить требованиям умножения матриц, необходимо поменять местами два последних измерения) для умножения. получить форму(numWindows*B, num_heads, window_size*window_size, window_size*window_size)
изattn
Тензор -
Ранее мы установили форму для кодирования положения как
(2*window_size-1*2*window_size-1, numHeads)
обучаемые переменные. Мы используем рассчитанный индекс относительной позиции кодированияself.relative_position_index
Выберите, чтобы получить форму(window_size*window_size, window_size*window_size, numHeads)
код, добавить вattn
на тензоре -
Независимо от ситуации с маской, остальное — это тот же softmax, дропаут и трансформер, что и трансформер.
V
Умножение матриц, за которым следует полносвязный слой и отсев
Shifted Window Attention
Предыдущее Window Attention вычисляет внимание под каждым окном.Чтобы лучше взаимодействовать с другими окнами, Swin Transformer также вводит операцию смещения окна.
Shift Window
Слева — Window Attention без перекрытия, а справа — Shift Window Attention, который сдвигает окно. Видно, что сдвинутое окно содержит элементы исходного соседнего окна. Но это также вводит новую проблему, а именноКоличество окон увеличилось вдвое., из исходных четырех окон в девять окон.
В реальном коде мыКосвенно достигается за счет смещения карты объектов и установки маски для внимания.. может быть вСохранить исходное количество окон, окончательный результат расчета эквивалентен.
Операция сдвига карты объектов
Сдвиг карты функций в коде выполняется с помощьюtorch.roll
для достижения, ниже приведена схематическая диаграмма
сменная операция
Если нужно
reverse cyclic shift
затем просто введите параметрshifts
Установите соответствующее положительное значение.
Attention Mask
Я думаю, что это суть Swin Transformer, установив разумную маску, пустьShifted Window Attention
в сWindow Attention
При одинаковом количестве окон достигаются эквивалентные результаты расчета.
Во-первых, мы даем индекс каждому окну после Shift Window и делаемroll
Операция (window_size=2, shift_size=1)
Shift window index
Мы надеемся, что при расчете Внимание,Разрешить вычисление QK с одним и тем же индексом и игнорировать результаты расчета QK с разными индексами..
Окончательный правильный результат показан на рисунке ниже.
Shift Attention
Чтобы получить правильный результат в исходных четырех окнах, мы должны добавить маску к результату Attention (как показано в крайнем правом углу рисунка выше).
Соответствующий код выглядит следующим образом:
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
С настройками на рисунке выше мы получим вот такую маску с этим кодом
tensor([[[[[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]]],
[[[ 0., -100., 0., -100.],
[-100., 0., -100., 0.],
[ 0., -100., 0., -100.],
[-100., 0., -100., 0.]]],
[[[ 0., 0., -100., -100.],
[ 0., 0., -100., -100.],
[-100., -100., 0., 0.],
[-100., -100., 0., 0.]]],
[[[ 0., -100., -100., -100.],
[-100., 0., -100., -100.],
[-100., -100., 0., -100.],
[-100., -100., -100., 0.]]]]])
В прямом коде предыдущего модуля внимания окна есть такой абзац
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
Добавьте маску к результату расчета внимания и выполните softmax. Значение маски установлено на -100, и соответствующее значение будет игнорироваться после softmax.
Общая архитектура трансформаторного блока
Трансформаторная блочная архитектура
Две последовательные архитектуры блоков показаны на рисунке выше.Следует отметить, что количество блоков, содержащихся в этапе, должно быть четным числом, потому что он должен попеременно содержать блок, содержащийWindow Attention
Слой и содержитShifted Window Attention
Слой.
Давайте посмотрим на прямой код блока
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
Общий процесс выглядит следующим образом
-
Сначала выполните LayerNorm на карте объектов.
-
пройти через
self.shift_size
Решите, следует ли сместить карту объектов -
Затем разрежьте карту объектов на окна одно за другим.
-
Рассчитать внимание, по
self.attn_mask
отличитьWindow Attention
все ещеShift Window Attention
-
Объединить отдельные окна обратно
-
Если вы выполняли операцию сдвига раньше, сделайте это сейчас
reverse shift
, восстановить предыдущую операцию смены -
Делайте отсев и остаточные соединения
-
Затем передаем слой LayerNorm+ полносвязный слой, а также отсева и остаточное соединение
Результаты экспериментов
Результаты экспериментов
В наборе данных ImageNet22K уровень точности может достигать поразительных 86,4%. Кроме того, он очень хорошо справляется с такими задачами, как обнаружение и сегментация.Если вам интересно, вы можете прочитать экспериментальную часть в конце статьи.
Суммировать
Инновация этой статьи великолепна: она вводит понятие окна, вводит локальность CNN, а также контролирует общий объем вычислений модели. В разделе Shift Window Attention, используя маску и операцию сдвига, очень умно добиться вычислительной эквивалентности. Код автора тоже очень радует глаз, рекомендуется к прочтению!
Добро пожаловать в GiantPandaCV, где вы увидите эксклюзивный обмен глубокими знаниями, настаиваете на оригинальности и делитесь свежими знаниями, которые мы изучаем каждый день. ( • ̀ω•́ )✧
Если у вас есть вопросы по статье или вы хотите присоединиться к группе обмена, добавьте BBuf WeChat:
QR код
В этой статье используетсяПомощник по синхронизации статейСинхронизировать