Удалите Softmax Attention, и сложность уменьшится до O(n)

искусственный интеллект NLP

Как мы все знаем, хотя модель Transformer, основанная на механизме Attention, имеет хорошую параллельную производительность, ее пространственная и временная сложность невелики.O(n2)\mathcal{O}(n^2)уровень,nnдлина последовательности, поэтому, когдаnnКогда он относительно велик, вычислительная нагрузка модели Transformer невыносима. В последнее время много работы было посвящено уменьшению вычислительной нагрузки модели Transformer, например, сокращению модели, квантованию, дистилляции и другим методам оптимизации, или модификации структуры Attention, чтобы ее сложность можно было уменьшить доO(nlogn)\mathcal{O}(nlog⁡n)четноеO(n)\mathcal{O}(n)

бумага«Трансформаторы - это RNN: быстрые авторегрессионные преобразователи с линейным вниманием»В нем упоминался метод линеаризации внимания (Linear Attention), который вызвал у меня интерес, а затем я прочитал несколько связанных блогов и дал хорошие результаты, и, наконец, обобщил свое понимание линеаризованного внимания в этой статье.

Attention

Наиболее популярным механизмом внимания являетсяScaled-Dot Attention,Сейчас

Attention(Q,K,V)=softmax(QK)V(1)\begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}\boldsymbol{K}^{\top}\right)\boldsymbol{V}\tag{1}\end{aligned}

здесьQеRn×dk,KеRm×dk,VеRm×dv\boldsymbol{Q}\in \mathbb{R}^{n\times d_k}, \boldsymbol{K}\in \mathbb{R}^{m\times d_k}, \boldsymbol{V}\in \mathbb{R}^{m\times d_v}, для простоты я не показывал коэффициент масштабирования Внимания1d\frac{1}{\sqrt{d}}. В этой статье нас в основном интересует сцена «Самовнимание», поэтому для удобства введения мы установилиQ,K,VеRn×d\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}\in \mathbb{R}^{n\times d}

Удалить Софтмакс

Читатели могут не подумать, что ключевым фактором, ограничивающим производительность Attention, на самом деле является Softmax в определении! На самом деле, простой вывод приводит к такому выводу.QKTQK^TНа этом шаге мы получаемn×nn\times nМатрица, тогда нам нужно сделать Softmax

к одному1×n1\times nSoftmax вектор-строка , временная сложностьO(n)O(n), но дляn×nn\times nСделайте Softmax для каждой строки матрицы, и временная сложностьO(n2)O(n^2)

Если нет Софтмакса, то формула Внимания становится произведением трех матрицQKV\boldsymbol{QK^{\top}V}, а умножение матриц удовлетворяет ассоциативной скорости, поэтому мы можем сначала вычислитьKV\boldsymbol{K^{\top}V}, получитьd×dd\times dматрица (временная сложность этого шагаO(d2n)O(d^2n)), затем используйтеQQУмножьте его налево (временная сложность этого шагаO(d2n)O(d^2n)),так какdnd \ll n, поэтому приблизительная временная сложность равнаO(n)O(n)

Для базы BERT,d=64d=64вместо 768, почему? Поскольку 768 на самом деле получается путем сращивания нескольких головок, и каждая головкаd=64d=64

То есть удаление сложности «Внимание» из Softmax может быть сведено к самому идеальному линейному уровню.O(n)\mathcal{O}(n)! Очевидно, это наша конечная цель: линейное внимание.

общее определение

Вопрос в том, можно ли считать удаление Softmax Внимание? Может ли он по-прежнему иметь стандартный эффект внимания? Чтобы ответить на этот вопрос, мы сначала перепишем определение Scaled-Dot Attention эквивалентным образом (все векторы в этой статье являются векторами-столбцами)

Attention(Q,K,V)i=j=1neqikjvjj=1neqikj(2)\begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\boldsymbol{v}_j}{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}}\tag{2}\end{aligned}

Вот небольшое объяснение, прежде всего мы знаемQ,KеRn×d\boldsymbol{Q},\boldsymbol{K}\in \mathbb{R}^{n\times d},сделатьM=Q×K\boldsymbol{M} = \boldsymbol{Q}\times \boldsymbol{K^{\top}}, согласно правилу умножения матриц,M\boldsymbol{M}Первая строкаQ\boldsymbol{Q}Умножьте первую строкуK\boldsymbol{K^{\top}}Все столбцы полученного

Attention(Q,K,V)iAttention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_iпредставляет первую из окончательной выходной матрицыiiРяд

qi\boldsymbol{q}_i^{\top}выражатьQеRn×d\boldsymbol{Q}\in \mathbb{R}^{n\times d}первая из матрицыiiстрока (вектор-строка)

kj\boldsymbol{k}_jвыражатьKеRd×n\boldsymbol{K^{\top}}\in \mathbb{R}^{d\times n}первая из матрицыjjстолбец (вектор-столбец)

vj\boldsymbol{v}_jвыражатьVеRd×nV^{\top}\in \mathbb{R}^{d\times n}первая из матрицыjjстолбец (вектор-столбец)

Таким образом, Scaled-Dot Attention на самом деле основано наeqikje^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}для весовой парыvj\boldsymbol{v}_jСделайте средневзвешенное значение. Таким образом, мы можем предложить обобщенное определение внимания.

Attention(Q,K,V)i=j=1nsim(qi,kj)vjj=1nsim(qi,kj)(3)\begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\tag{3}\end{aligned}

то есть поставитьeqikje^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}заменитьqi,ki\boldsymbol{q}_i,\boldsymbol{k}_iобщая функцияsim(qi,kj)\text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j), чтобы сохранить аналогичные характеристики распределения внимания, мы требуемsim(qi,kj)0\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0Хэн установил. То есть, если мы хотим определить новое внимание, мы должны сохранить форму уравнения (3) и удовлетворитьsim(qi,kj)0\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0

Эта общая форма внимания также известна как нелокальная сеть в CV из статьи«Нелокальные нейронные сети»

несколько примеров

Если вы напрямую удалите Softmax, то этоsim(qi,kj)=qikj\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \boldsymbol{q}_i^{\top}\boldsymbol{k}_j, проблема в том, что внутренний продукт не гарантирует неотрицательность, так что это пока не разумный выбор. Ниже мы приводим несколько возможных вариантов

Стоит отметить, что первые два вида линейного внимания, представленные ниже, относятся к полю CV, а третий — кСу ЦзяньлиньБольшой парень задумал (помимо следующего введения, есть ещеEMANetУлучшения внимания в области резюме и т. д.)

Форма функции ядра

Естественная мысль: еслиqi,kj\boldsymbol{q}_i, \boldsymbol{k}_jКаждый элемент неотрицательный, поэтому скалярный продукт, естественно, неотрицательный. Для этого мы можем датьqi,kj\boldsymbol{q}_i, \boldsymbol{k}_jДобавьте функцию активации к каждомуф,ф\phi,\varphi,Сейчас

sim(qi,kj)=ф(qi)ф(kj)(4)\begin{aligned}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\tag{4}\end{aligned}

вф(),ф()\phi(\cdot), \varphi(\cdot)является функцией активации с неотрицательным диапазоном. Документы, упомянутые в начале статьиТрансформеры - это RNNвыбраноф(x)=ф(x)=elu(x)+1\phi(x)=\varphi(x)=\text{elu}(x)+1

elu(x)={xеслиx>0альфа(ex1)еслиx<0\text{elu}(x)=\begin{cases}x& \text{if} \ x>0\\ \alpha (e^x-1) & \text{if}\ x<0\end{cases}

Общийальфа\alphaценность[0.1,0.3][0.1, 0.3]

Если вам нужно рассказать историю, уравнение (4) можно связать с «ядерным методом», особенноф=ф\phi=\varphiчас,ф\phiэквивалентна функции ядра, иф(qi),ф(kj)\langle \phi(\boldsymbol{q}_i), \phi(\boldsymbol{k}_j)\rangleЭто внутренний продукт, определяемый функцией ядра. Думая об этом, можно обратиться к статье«Рассечение трансформатора: единое понимание внимания трансформатора через призму ядра», не распространяйтесь слишком много здесь

Волшебное использование Softmax

Еще одна более ранняя статья«Эффективное внимание: внимание с линейными сложностями»дает более интересный вариант. он замечаетQK\boldsymbol{QK^{\top}}середина,Q,KеRn×d\boldsymbol{Q},\boldsymbol{K}\in \mathbb{R}^{n\times d},если"Q\boldsymbol{Q}существуетddчто одно измерение нормализовано, иK\boldsymbol{K}существуетnnчто одно измерение нормализовано", тоQK\boldsymbol{QK^{\top}}заключается в автоматическом удовлетворении нормализации, поэтому выбор, который она дает,

Attention(Q,K,V)=softmax2(Q)softmax1(K)V(5)\begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax_2\left(\boldsymbol{Q}\right)softmax_1(\boldsymbol{K})^{\top}\boldsymbol{V}\tag{5}\end{aligned}

вsoftmax1softmax_1,softmax2softmax_2соответственно в первом(n)(n), второе измерение(d)(d)Выполните операцию Softmax. То есть в это время каждый из нас даетQ,K\boldsymbol{Q},\boldsymbol{K}Добавить Softmax, а не закончитьQK\boldsymbol{QK^{\top}}Затем добавьте Софтмакс

Фактически можно доказать, что эта форма также является частным случаем формулы (4)​, которая соответствуетф(qi)=softmax(qi),ф(kj)=ekj\phi(\boldsymbol{q}_i)=softmax(\boldsymbol{q}_i),\varphi(\boldsymbol{k}_j)=e^{\boldsymbol{k}_j}, читатель может сделать вывод

идея Сушена

Здесь Су Шэнь подал идею. Отправной точкой этой идеи является уже не уравнение (4), а расширение Тейлора исходного определения (2). В расширенном по Тейлору мы имеем

eqikj1+qikj(6)\begin{aligned}e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \approx 1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\tag{6}\end{aligned}

еслиqikj1\boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1, то можно гарантировать неотрицательность правой части, так чтоsim(qi,kj)=1+qikj\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)=1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j. Читатели, возможно, уже подумали об этом здесь и хотят гарантироватьqikj1\boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1, необходимо толькоqi,kj\boldsymbol{q}_i,\boldsymbol{k}_jДелатьl2l_2Нормализованный. Поэтому окончательный план, предложенный Су Шэнем, таков:

sim(qi,kj)=1+(qiqi)(kjkj)(7)\begin{aligned}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = 1 + \left( \frac{\boldsymbol{q}_i}{\Vert \boldsymbol{q}_i\Vert}\right)^{\top}\left(\frac{\boldsymbol{k}_j}{\Vert \boldsymbol{k}_j\Vert}\right)\tag{7}\end{aligned}

какx=[x1,x2,...,xn]\boldsymbol{x}=[x_1,x_2,...,x_n],ноx=x12+x22++xn2\Vert x\Vert=\sqrt{x_1^2+x_2^2+...+x_n^2}

Это отличается от уравнения (4), но теоретически оно ближе к исходному масштабируемому точечному вниманию.

выполнить

Это в основном для реализации метода, предложенного Су Шеном, но из-за ограниченного уровня автора, в финальном реализованном коде на самом деле есть некоторые проблемы, в основном:

  1. По результатам тестов улучшенная скорость расчета не улучшилась
  2. Не могу суммировать до 1

Реализация кода предназначена в основном дляPyTorch реализация BERTКод этой статьи, точнее, фактически только измененScaledDotProductAttentionЭто функция, поэтому ниже выложена только эта часть кода

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        Q = F.normalize(Q, dim=3)
        K = F.normalize(K, dim=3)
        M = (torch.ones(Q.shape[0], Q.shape[1], Q.shape[2], K.shape[2]) + torch.matmul(Q, K.transpose(-1, -2))) # scores : [batch_size, n_heads, seq_len, seq_len]
        M_sum = torch.sum(M, dim=3)
        M = M / M_sum.unsqueeze(3).repeat(1, 1, 1, M.shape[3])
        attn = M.masked_fill(attn_mask, 0) # Fills elements of self tensor with value where mask is one.
        context = torch.matmul(attn, V)
        return context

Если у вас есть лучший метод реализации, пожалуйста, дайте мне знать

Reference