- Статья перенесена из публичного аккаунта WeChat [Machine Learning Alchemy]
- Автор заметки: Brother Alchemy (перепечатано с разрешения)
- Контактное лицо: WeChat cyx645016617
- Название диссертации: «Маскированные автоэнкодеры — масштабируемые обучающиеся зрительного восприятия»
0 резюме
В этой статье показано, что маскированный автоэнкодер (MAE) является масштабируемым самоконтролируемым обучающимся для компьютерного зрения. Наш метод MAE прост: мы маскируем случайные участки входного изображения и восстанавливаем недостающие пиксели.
В основе такой конструкции лежат два ядра:
- Мы разрабатываем асимметричную архитектуру кодер-декодер, в которой кодировщик работает только с видимым подмножеством патчей (без масок), а облегченный декодер выводится из скрытого представления, а токен маски восстанавливает исходное изображение.
- Во-вторых, мы обнаружили, что маскирование большой доли входных изображений (например, 75%) дает нетривиальную и содержательную самостоятельную задачу.
Объединение этих двух схем позволяет нам эффективно обучать большие модели: мы ускоряем обучение (в 3 раза и более) и повышаем точность.
1 способ
Как видно из картинки, модель на самом деле очень простая:
- Это структура-трансформер, похожая на VIT, изображение делится на патчи, и тогда модель может видеть только малую часть (25%) патчей, а остальные 75% патчей невидимы;
- Вход кодировщика — это видимая 25%-я заплата плюс маска положения 25%;
- После этого декодер используется для восстановления 25% информации о патчах ко всему изображению для реконструкции.
- После предварительной подготовки декодер отбрасывается, а кодировщик применяется к неповрежденным изображениям, чтобы создать представление для задачи распознавания.
2 Раздел кода — первый шаг
Поскольку это просто, просто посмотрите на код напрямую. Код воспроизведен самим большим парнем, а не официальным!
def pretrain_mae_small_patch16_224(pretrained=False, **kwargs):
model = PretrainVisionTransformer(
img_size=224,
patch_size=16,
encoder_embed_dim=384,
encoder_depth=12,
encoder_num_heads=6,
encoder_num_classes=0,
decoder_num_classes=768,
decoder_embed_dim=192,
decoder_depth=4,
decoder_num_heads=3,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.load(
kwargs["init_ckpt"], map_location="cpu"
)
model.load_state_dict(checkpoint["model"])
return model
из кода,patch_size
,encoder_embed_dim
Эти параметры несложно понять.Этот PretrainVisionTransformer представляет собой классическую структуру преобразователя VIT (сначала угадайте, потом проверьте).
3 Раздел кода — шаг второй
class PretrainVisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
img_size=224,
patch_size=16,
encoder_in_chans=3,
encoder_num_classes=0,
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
decoder_num_classes=768,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=8,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
init_values=0.,
use_learnable_pos_emb=False,
num_classes=0, # avoid the error from create_fn in timm
in_chans=0, # avoid the error from create_fn in timm
):
super().__init__()
self.encoder = PretrainVisionTransformerEncoder(
img_size=img_size,
patch_size=patch_size,
in_chans=encoder_in_chans,
num_classes=encoder_num_classes,
embed_dim=encoder_embed_dim,
depth=encoder_depth,
num_heads=encoder_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
init_values=init_values,
use_learnable_pos_emb=use_learnable_pos_emb)
self.decoder = PretrainVisionTransformerDecoder(
patch_size=patch_size,
num_patches=self.encoder.patch_embed.num_patches,
num_classes=decoder_num_classes,
embed_dim=decoder_embed_dim,
depth=decoder_depth,
num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
init_values=init_values)
self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)
trunc_normal_(self.mask_token, std=.02)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'mask_token'}
def forward(self, x, mask):
x_vis = self.encoder(x, mask) # [B, N_vis, C_e]
x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d]
B, N, C = x_vis.shape
# we don't unshuffle the correct visible token order,
# but shuffle the pos embedding accorddingly.
expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16]
return x
В целом, он состоит из кодировщика и декодера. Перечислим параметры:
-
img_size
=224 -
patch_size
=16 -
encoder_in_chans
=3 -
encoder_num_classes
=0 -
encoder_embed_dim
=768 -
encoder_depth
=12 -
encoder_num_heads
=12 -
decoder_num_classes
=768 -
decoder_embed_dim
=512 -
decoder_depth
=8 -
decoder_num_heads
=8 -
mlp_ratio
=4. -
qkv_bias
=False -
qk_scale
=None -
drop_rate
=0. -
attn_drop_rate
=0. -
drop_path_rate
=0. -
norm_layer
=nn.LayerNorm -
init_values
=0. -
use_learnable_pos_emb
=False -
num_classes
=0 # avoid the error from create_fn in timm -
in_chans
=0, # avoid the error from create_fn in timm
4 Кодовая часть - энкодер
class PretrainVisionTransformerEncoder(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
use_learnable_pos_emb=False):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
# TODO: Add the cls token
# self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_learnable_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
# sine-cosine positional embeddings
self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if use_learnable_pos_emb:
trunc_normal_(self.pos_embed, std=.02)
# trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x, mask):
x = self.patch_embed(x)
# cls_tokens = self.cls_token.expand(batch_size, -1, -1)
# x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
B, _, C = x.shape
x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible
for blk in self.blocks:
x_vis = blk(x_vis)
x_vis = self.norm(x_vis)
return x_vis
def forward(self, x, mask):
x = self.forward_features(x, mask)
x = self.head(x)
return x
При построении Encoder используются следующие модули:
- self.patch_embed: исправить изображение
- Блок с глубиной укладки, часть извлечения функций трансформатора
- self.head: Это слой идентичности, бессмысленный.
5 Раздел кода — patch_embed
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
Как вы можете видеть в приведенном здесь коде, на самом деле он содержит толькоself.proj(x)
Это просто сверточный слой Я сделал простую демонстрацию, чтобы изучить, как модуль patchembed влияет на форму изображения:
Входные данные представляют собой карту объектов размером 1x3x224x224, а выходная форма y имеет вид:
Здесь я понимаю процесс и значение двух параметров:
- 196 означает количество патчей в изображении, 224 — это ввод, а 16 — размер патча, поэтому изображение имеет (224/16) квадратных патчей, то есть 196 патчей;
- Каждый патч сверточно закодирован в 768-мерный вектор. 768 соответствующих гиперпараметров
embed_dim
- Здесь для kernel_size и шага заданы те же масштабы, что и для патча, что на самом деле математически эквивалентно полностью связанному слою для всех элементов патча. Патч содержит 14x14 пикселей, что составляет 196 пикселей. Такой сверточный слой эквивалентен полносвязному слою от 196 до 768.
6 Раздел кода - Блок
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
attn_head_dim=None):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values > 0:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x):
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
Этот блок содержит три модуля,Attention
,Mlp
иDropPath
.
Вход x сначала нормализуется по норме слоя, затем помещается в Attention, затем DropPath, затем нормализация нормы слоя, затем Mlp и затем DropPath.
6 Раздел кода - Внимание
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
Через эту линию полностью связанных слоев входные 768 признаков расширяются до 2304 измерений, соответствующих трем переменным q, k и v соответственно.
С помощью reshape измените [batch, 196, 2304] на [1, 196, 3, 8, 96], а затем транспонируйте его в [3, 1, 8, 196, 96]. Это 3 просто присваивается qkv. Затем, после двух матричных умножений, окончательный результат по-прежнему имеет размерность [batch, 196, 768].
[Сводка]: Attention на самом деле является модулем извлечения функций, входные данные — [пакет, 196, 768], а выходные данные — [пакет, 196, 768].
7 Раздел кода - млп
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commit this for the orignal BERT implement
x = self.fc2(x)
x = self.drop(x)
return x
Этот MLP представляет собой двухслойный полносвязный слой, который увеличивает 768 до размера 768x4, а затем становится 768.
7 Раздел кода — декодирование
class PretrainVisionTransformerDecoder(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, num_patches=196,
):
super().__init__()
self.num_classes = num_classes
assert num_classes == 3 * patch_size ** 2
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.patch_size = patch_size
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x, return_token_num):
for blk in self.blocks:
x = blk(x)
if return_token_num > 0:
x = self.head(self.norm(x[:, -return_token_num:])) # only return the mask tokens predict pixels
else:
x = self.head(self.norm(x)) # [B, N, 3*16^2]
return x
Однако в целом такое воспроизведение кода отличается от MAE в статье. Проблема с частью декодера. Тогда исправьте это сами.
Я думаю общая проблема в том, что в этом коде после энкодера и до декодера отсутствует восстановление положения изображения. Это шаги в красной рамке на следующем рисунке:
Однако наличие или отсутствие этого шага не влияет на обучение модели, а только на генерацию полного реконструированного графа.