Как мы все знаем, хотя модель Transformer, основанная на механизме Attention, имеет хорошую параллельную производительность, ее пространственная и временная сложность невелики.
бумага«Трансформаторы - это RNN: быстрые авторегрессионные преобразователи с линейным вниманием»В нем упоминался метод линеаризации внимания (Linear Attention), который вызвал у меня интерес, а затем я прочитал несколько связанных блогов и дал хорошие результаты, и, наконец, обобщил свое понимание линеаризованного внимания в этой статье.
Attention
Наиболее популярным механизмом внимания являетсяScaled-Dot Attention,Сейчас
здесь
Удалить Софтмакс
Читатели могут не подумать, что ключевым фактором, ограничивающим производительность Attention, на самом деле является Softmax в определении! На самом деле, простой вывод приводит к такому выводу.
к одному
Softmax вектор-строка , временная сложность , но для Сделайте Softmax для каждой строки матрицы, и временная сложность
Если нет Софтмакса, то формула Внимания становится произведением трех матриц
Для базы BERT,
вместо 768, почему? Поскольку 768 на самом деле получается путем сращивания нескольких головок, и каждая головка
То есть удаление сложности «Внимание» из Softmax может быть сведено к самому идеальному линейному уровню.
общее определение
Вопрос в том, можно ли считать удаление Softmax Внимание? Может ли он по-прежнему иметь стандартный эффект внимания? Чтобы ответить на этот вопрос, мы сначала перепишем определение Scaled-Dot Attention эквивалентным образом (все векторы в этой статье являются векторами-столбцами)
Вот небольшое объяснение, прежде всего мы знаем
,сделать , согласно правилу умножения матриц, Первая строка Умножьте первую строку Все столбцы полученного
представляет первую из окончательной выходной матрицы Ряд
выражать первая из матрицы строка (вектор-строка)
выражать первая из матрицы столбец (вектор-столбец)
выражать первая из матрицы столбец (вектор-столбец)
Таким образом, Scaled-Dot Attention на самом деле основано на
то есть поставить
Эта общая форма внимания также известна как нелокальная сеть в CV из статьи«Нелокальные нейронные сети»
несколько примеров
Если вы напрямую удалите Softmax, то это
Стоит отметить, что первые два вида линейного внимания, представленные ниже, относятся к полю CV, а третий — кСу ЦзяньлиньБольшой парень задумал (помимо следующего введения, есть ещеEMANetУлучшения внимания в области резюме и т. д.)
Форма функции ядра
Естественная мысль: если
в
Общий
ценность
Если вам нужно рассказать историю, уравнение (4) можно связать с «ядерным методом», особенно
Волшебное использование Softmax
Еще одна более ранняя статья«Эффективное внимание: внимание с линейными сложностями»дает более интересный вариант. он замечает
в
Фактически можно доказать, что эта форма также является частным случаем формулы (4), которая соответствует
идея Сушена
Здесь Су Шэнь подал идею. Отправной точкой этой идеи является уже не уравнение (4), а расширение Тейлора исходного определения (2). В расширенном по Тейлору мы имеем
если
как
,но
Это отличается от уравнения (4), но теоретически оно ближе к исходному масштабируемому точечному вниманию.
выполнить
Это в основном для реализации метода, предложенного Су Шеном, но из-за ограниченного уровня автора, в финальном реализованном коде на самом деле есть некоторые проблемы, в основном:
- По результатам тестов улучшенная скорость расчета не улучшилась
- Не могу суммировать до 1
Реализация кода предназначена в основном дляPyTorch реализация BERTКод этой статьи, точнее, фактически только измененScaledDotProductAttention
Это функция, поэтому ниже выложена только эта часть кода
class ScaledDotProductAttention(nn.Module):
def __init__(self):
super(ScaledDotProductAttention, self).__init__()
def forward(self, Q, K, V, attn_mask):
Q = F.normalize(Q, dim=3)
K = F.normalize(K, dim=3)
M = (torch.ones(Q.shape[0], Q.shape[1], Q.shape[2], K.shape[2]) + torch.matmul(Q, K.transpose(-1, -2))) # scores : [batch_size, n_heads, seq_len, seq_len]
M_sum = torch.sum(M, dim=3)
M = M / M_sum.unsqueeze(3).repeat(1, 1, 1, M.shape[3])
attn = M.masked_fill(attn_mask, 0) # Fills elements of self tensor with value where mask is one.
context = torch.matmul(attn, V)
return context
Если у вас есть лучший метод реализации, пожалуйста, дайте мне знать