Интерпретация кода лонгформера pytorch (обновление)

искусственный интеллект NLP

longformer

Поскольку трансформатор тока является самостоятельным вниманием ко всей последовательности, это приводит кO(n2)O(n^2)временная сложность. Поэтому модель преобразователя ограничена 512 измерениями длины последовательности. В реальных документах количество слов часто превышает 512. Подход преобразователя состоит в том, чтобы разделить его на несколько частично перекрывающихся 512-мерных выборок и ввести их в модель для обучения. Это, очевидно, вносит погрешность в получение контекстной информации и расчет функции потерь при детализации документа.

По этой причине автор был вдохновлен CNN, чтобы предложить механизм внимания скользящего окна, longformer, В отличие от преобразователя, каждое слово должно обращать внимание на 512 слов в последовательности. Слово longformer обращает внимание только на левую и правую стороны.wwСлово окна, то есть окно каждого слова2w+12w+1, сделать самостоятельное внимание в окне. Это снижает временную сложность доO(wn)O(wn)

Логика основного кода — обработка окон

Код автора:GitHub.com/Аллен А.И./ Вы…

Эта статья основана на коде pytorch.Автор утверждает, что функции pytorch поддерживают операции с фрагментами, а tensorflow — нет. По этой причине для изучения того, как работает окно внимания в статье автора, необходимо знать, какую структуру данных автор обработал тензором.

Сначала обратите внимание на qk: автор преобразует признаки последовательности [bs, seqlen, dim] в форму [bs, seqlen//w-1, 2w, dim], используя форму chunk. Используемый метод заключается в чтении значения вектора из памяти путем перехода, то есть адрес будет считывать количество окон.wwВторосортный:

# input;tensor q, tensor k, int w
bsz, seqlen, num_heads, head_dim = q.size()
assert seqlen % (w * 2) == 0
assert q.size() == k.size()

chunks_count = seqlen // w - 1

# group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)

chunk_q = _chunk(q, w)
chunk_k = _chunk(k, w)

# matrix multipication
# bcxd: bsz*num_heads x chunks x 2w x head_dim
# bcyd: bsz*num_heads x chunks x 2w x head_dim
# bcxy: bsz*num_heads x chunks x 2w x 2w
chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k))  # multiply

Рабочая логика метода Chank показана на следующем рисунке:image.pngОпуститься до уровня кода — идеальная идея автора.

def _chunk(x, w):
    '''convert into overlapping chunkings. Chunk size = 2w, overlap size = w'''

    # non-overlapping chunks of size = 2w
    x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))

    # use `as_strided` to make the chunks overlap with an overlap size = w
    chunk_size = list(x.size())
    chunk_size[1] = chunk_size[1] * 2 - 1

    chunk_stride = list(x.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    return x.as_strided(size=chunk_size, stride=chunk_stride)

Зачем читать из памяти, а не использовать механизм копирования? Поскольку обновление градиента на самом деле является обновлением базовой памяти с точки зрения нижнего уровня, несколько вызовов одного и того же адреса памяти могут получить несколько обновлений. Копирование приведет к ошибкам обновления градиента, авторский прием действительно хитрый.

В этом методе несколько окон могут быть сгенерированы с избыточностью путем обработки признаков последовательности. Нарисуйте в качестве примера картинку, на которой bs и dim опущены для облегчения понимания:

image.pngКод работает:

image.png