Подробное объяснение алгоритма BPE

искусственный интеллект NLP

Byte Pair Encoding

В моделях НЛП ввод обычно представляет собой предложение, например."I went to New York last week.", предложение содержит много слов (токенов). Традиционная практика заключается в разделении этих слов пробелами, например.['i', 'went', 'to', 'New', 'York', 'last', 'week']. Однако у этого подхода есть много проблем, например, модель не может пройтиold, older, oldestузнал отношения междуsmart, smarter, smartestОтношения между. Если мы можем разделить токен на несколько подтокенов, вышеуказанная проблема может быть решена очень хорошо. В этой статье будет подробно описан наиболее часто используемый алгоритм субтокенов — BPE (Byte-Pair Encoding).

Теперь модели НЛП с большей производительностью, такие как GPT, BERT, RoBERTa и т. д., будут иметь процесс WordPiece при предварительной обработке данных.Основным методом реализации является BPE (Byte-Pair Encoding). Конкретно, например['loved', 'loving', 'loves']эти три слова. На самом деле семантика сама по себе и есть значение слова «любовь», но если использовать слова как единицы, то это не одни и те же слова.В английском языке много слов с разными суффиксами, что сделает словарный запас очень большим. Скорость становится медленнее, а тренировочный эффект не очень хороший. Посредством обучения алгоритм BPE может разбить приведенные выше 3 слова на["lov","ed","ing","es"]Несколько частей, которые могут разделять значение и время самого слова, эффективно сокращая количество словарного запаса. Поток алгоритма выглядит следующим образом:

  1. Установите максимальное количество подсловVV
  2. Разделить все слова на отдельные символы и добавить в конце точку</w>, при этом отмечая количество вхождений слова. Например,"low"Слово появляется 5 раз, затем оно будет обработано как{'l o w </w>': 5}
  3. считать каждый последующийпара байтовЧастота встречаемости, выберите самую высокую частоту для объединения в новое подслово
  4. Повторяйте шаг 3, пока не достигнете размера словарного запаса подслов, установленного в шаге 1, или до следующей максимальной частоты.пара байтовЧастота встречаемости 1

Например

{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

Наиболее часто встречающаяся пара байтов **eиs**, есть 6+3=9 раз, так что объедините их

{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}

Наиболее часто встречающаяся пара байтов **esиt**, есть 6+3=9 раз, так что объедините их

{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}

Наиболее часто встречающаяся пара байтов **estи</w>**, есть 6+3=9 раз, так что объедините их

{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}

Наиболее часто встречающаяся пара байтов **lиo**, есть 5+2=7 раз, поэтому объедините их

{'lo w </w>': 5, 'lo w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}

Наиболее часто встречающаяся пара байтов **loиw**, есть 5+2=7 раз, поэтому объедините их

{'low </w>': 5, 'low e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}

...продолжать итерацию до тех пор, пока не будет достигнут заданный размер словаря подслов или следующая по частоте пара байтов не появится с частотой 1. Таким образом, мы получаем более подходящий словарь, который может иметь некоторые сочетания, не являющиеся словами, а форму, имеющую смысл сам по себе.

стоп-символ</w>Смысл в том, что подслово является суффиксом слова. Например:stне добавлять</w>может стоять в начале слова, напримерst ar;добавлять</w>Указывает, что измененное слово находится в конце слова, напримерwide st</w>, они имеют разные значения

Внедрение BPE

import re, collections

def get_vocab(filename):
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding='utf-8') as fhand:
        for line in fhand:
            words = line.strip().split()
            for word in words:
                vocab[' '.join(list(word)) + ' </w>'] += 1
    return vocab

def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

def get_tokens(vocab):
    tokens = collections.defaultdict(int)
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens[token] += freq
    return tokens

vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

# Get free book from Gutenberg
# wget http://www.gutenberg.org/cache/epub/16457/pg16457.txt
# vocab = get_vocab('pg16457.txt')

print('==========')
print('Tokens Before BPE')
tokens = get_tokens(vocab)
print('Tokens: {}'.format(tokens))
print('Number of tokens: {}'.format(len(tokens)))
print('==========')

num_merges = 5
for i in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print('Iter: {}'.format(i))
    print('Best pair: {}'.format(best))
    tokens = get_tokens(vocab)
    print('Tokens: {}'.format(tokens))
    print('Number of tokens: {}'.format(len(tokens)))
    print('==========')

Вывод выглядит следующим образом

==========
Tokens Before BPE
Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 17, 'r': 2, 'n': 6, 's': 9, 't': 9, 'i': 3, 'd': 3})
Number of tokens: 11
==========
Iter: 0
Best pair: ('e', 's')
Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 8, 'r': 2, 'n': 6, 'es': 9, 't': 9, 'i': 3, 'd': 3})
Number of tokens: 11
==========
Iter: 1
Best pair: ('es', 't')
Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 8, 'r': 2, 'n': 6, 'est': 9, 'i': 3, 'd': 3})
Number of tokens: 10
==========
Iter: 2
Best pair: ('est', '</w>')
Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'est</w>': 9, 'i': 3, 'd': 3})
Number of tokens: 10
==========
Iter: 3
Best pair: ('l', 'o')
Tokens: defaultdict(<class 'int'>, {'lo': 7, 'w': 16, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'est</w>': 9, 'i': 3, 'd': 3})
Number of tokens: 9
==========
Iter: 4
Best pair: ('lo', 'w')
Tokens: defaultdict(<class 'int'>, {'low': 7, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'w': 9, 'est</w>': 9, 'i': 3, 'd': 3})
Number of tokens: 9
==========

кодировать и декодировать

кодирование

В предыдущем алгоритме мы получили словарь подслова, и словарь отсортирован по количеству символов. При кодировании для каждого слова просмотрите отсортированный словарь подслов, чтобы найти, существует ли токен, который является подстрокой текущего слова, если да, то токен является одним из токенов, представляющих слово.

Мы итерируем от самой длинной лексемы к самой короткой лексеме, пытаясь заменить подстроку в каждом слове лексемой. В конце концов, мы переберем все токены и заменим все подстроки токенами. Если все еще есть подстроки, которые не были заменены, но все токены были перебраны, замените оставшиеся подстроки специальными токенами, такими как<unk>

Например

# 给定单词序列
["the</w>", "highest</w>", "mountain</w>"]

# 排好序的subword表
# 长度 6         5           4        4         4       4          2
["errrr</w>", "tain</w>", "moun", "est</w>", "high", "the</w>", "a</w>"]

# 迭代结果
"the</w>" -> ["the</w>"]
"highest</w>" -> ["high", "est</w>"]
"mountain</w>" -> ["moun", "tain</w>"]
расшифровка

Просто сложите все жетоны вместе, например

# 编码序列
["the</w>", "high", "est</w>", "moun", "tain</w>"]

# 解码序列
"the</w> highest</w> mountain</w>"

Реализация кодирования и декодирования

import re, collections

def get_vocab(filename):
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding='utf-8') as fhand:
        for line in fhand:
            words = line.strip().split()
            for word in words:
                vocab[' '.join(list(word)) + ' </w>'] += 1

    return vocab

def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

def get_tokens_from_vocab(vocab):
    tokens_frequencies = collections.defaultdict(int)
    vocab_tokenization = {}
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens_frequencies[token] += freq
        vocab_tokenization[''.join(word_tokens)] = word_tokens
    return tokens_frequencies, vocab_tokenization

def measure_token_length(token):
    if token[-4:] == '</w>':
        return len(token[:-4]) + 1
    else:
        return len(token)

def tokenize_word(string, sorted_tokens, unknown_token='</u>'):
    
    if string == '':
        return []
    if sorted_tokens == []:
        return [unknown_token]

    string_tokens = []
    for i in range(len(sorted_tokens)):
        token = sorted_tokens[i]
        token_reg = re.escape(token.replace('.', '[.]'))

        matched_positions = [(m.start(0), m.end(0)) for m in re.finditer(token_reg, string)]
        if len(matched_positions) == 0:
            continue
        substring_end_positions = [matched_position[0] for matched_position in matched_positions]

        substring_start_position = 0
        for substring_end_position in substring_end_positions:
            substring = string[substring_start_position:substring_end_position]
            string_tokens += tokenize_word(string=substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
            string_tokens += [token]
            substring_start_position = substring_end_position + len(token)
        remaining_substring = string[substring_start_position:]
        string_tokens += tokenize_word(string=remaining_substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
        break
    return string_tokens

# vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

vocab = get_vocab('pg16457.txt')

print('==========')
print('Tokens Before BPE')
tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
print('All tokens: {}'.format(tokens_frequencies.keys()))
print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))
print('==========')

num_merges = 10000
for i in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print('Iter: {}'.format(i))
    print('Best pair: {}'.format(best))
    tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
    print('All tokens: {}'.format(tokens_frequencies.keys()))
    print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))
    print('==========')

# Let's check how tokenization will be for a known word
word_given_known = 'mountains</w>'
word_given_unknown = 'Ilikeeatingapples!</w>'

sorted_tokens_tuple = sorted(tokens_frequencies.items(), key=lambda item: (measure_token_length(item[0]), item[1]), reverse=True)
sorted_tokens = [token for (token, freq) in sorted_tokens_tuple]

print(sorted_tokens)

word_given = word_given_known 

print('Tokenizing word: {}...'.format(word_given))
if word_given in vocab_tokenization:
    print('Tokenization of the known word:')
    print(vocab_tokenization[word_given])
    print('Tokenization treating the known word as unknown:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))
else:
    print('Tokenizating of the unknown word:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))

word_given = word_given_unknown 

print('Tokenizing word: {}...'.format(word_given))
if word_given in vocab_tokenization:
    print('Tokenization of the known word:')
    print(vocab_tokenization[word_given])
    print('Tokenization treating the known word as unknown:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))
else:
    print('Tokenizating of the unknown word:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))

Вывод выглядит следующим образом

Tokenizing word: mountains</w>...
Tokenization of the known word:
['mountains</w>']
Tokenization treating the known word as unknown:
['mountains</w>']
Tokenizing word: Ilikeeatingapples!</w>...
Tokenizating of the unknown word:
['I', 'like', 'ea', 'ting', 'app', 'l', 'es!</w>']

Reference