Введение
Используйте Keras для реализации сиамской сети и расчета сходства предложений.
принцип
Сиамская сеть относится к двум или более идентичным подсетям в сети, которые в основном используются в таких задачах, как расчет схожести предложений, сопоставление лиц, идентификация подписи и т. д.
- Расчет схожести предложений: введите два предложения, чтобы определить, имеют ли они одинаковое значение.
- Сопоставление лиц: введите два лица, чтобы определить, являются ли они одним и тем же человеком.
- Идентификация подписи: введите две подписи, чтобы определить, написаны ли они одним и тем же лицом.
Взяв в качестве примера расчет схожести предложений, подсети с обеих сторон абсолютно одинаковы от уровня внедрения до уровня LSTM.Вся модель называется MaLSTM (Manhattan LSTM).
Представление двух предложений с фиксированной длиной получается через окончательный вывод слоя LSTM, и следующая формула используется для вычисления сходства между ними, и сходство находится между 0 и 1.
данные
Используйте вопросы Quora на Kaggle для сопоставления данных, Quora соответствует иностранному Zhihu,woohoo.cardreform.com/from/Quora-но…
Учебный набор и тестовый набор содержат 404290 и 3563475 фрагментов данных соответственно, каждый фрагмент данных включает следующие поля, но тестовый набор не включает поле is_duplicate.
- id: идентификатор пары вопросов
- qid1: идентификатор вопроса 1
- qid2: идентификатор вопроса 2
- вопрос1: текст вопроса 1
- вопрос2: текст вопроса 2
- is_duplicate: два вопроса означают одно и то же, 0 или 1
выполнить
загрузить библиотеку
# -*- coding: utf-8 -*-
from keras.preprocessing.sequence import pad_sequences
from keras.models import Model
from keras.layers import Input, Embedding, LSTM, Lambda
import keras.backend as K
from keras.optimizers import Adam
import pandas as pd
import numpy as np
from gensim.models import KeyedVectors
from nltk.corpus import stopwords
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline
import re
from tqdm import tqdm
import pickle
Загрузка обучающих и тестовых наборов
train_df = pd.read_csv('train.csv')
test_df = pd.read_csv('test.csv')
print(len(train_df), len(test_df))
train_df.head()
Загрузите стоп-слова в nltk (Natural Language Toolkit) и определите функцию предварительной обработки текста.
# 如果报错nltk没有stopwords则下载
# import nltk
# nltk.download('stopwords')
stops = set(stopwords.words('english'))
def preprocess(text):
# input: 'Hello are you ok?'
# output: ['Hello', 'are', 'you', 'ok', '?']
text = str(text)
text = text.lower()
text = re.sub(r"[^A-Za-z0-9^,!.\/'+-=]", " ", text) # 去掉其他符号
text = re.sub(r"what's", "what is ", text) # 缩写
text = re.sub(r"\'s", " is ", text) # 缩写
text = re.sub(r"\'ve", " have ", text) # 缩写
text = re.sub(r"can't", "cannot ", text) # 缩写
text = re.sub(r"n't", " not ", text) # 缩写
text = re.sub(r"i'm", "i am ", text) # 缩写
text = re.sub(r"\'re", " are ", text) # 缩写
text = re.sub(r"\'d", " would ", text) # 缩写
text = re.sub(r"\'ll", " will ", text) # 缩写
text = re.sub(r",", " ", text) # 去除逗号
text = re.sub(r"\.", " ", text) # 去除句号
text = re.sub(r"!", " ! ", text) # 保留感叹号
text = re.sub(r"\/", " ", text) # 去掉右斜杠
text = re.sub(r"\^", " ^ ", text) # 其他符号
text = re.sub(r"\+", " + ", text) # 其他符号
text = re.sub(r"\-", " - ", text) # 其他符号
text = re.sub(r"\=", " = ", text) # 其他符号
text = re.sub(r"\'", " ", text) # 去掉单引号
text = re.sub(r"(\d+)(k)", r"\g<1>000", text) # 把30k等替换成30000
text = re.sub(r":", " : ", text) # 其他符号
text = re.sub(r" e g ", " eg ", text) # 其他词
text = re.sub(r" b g ", " bg ", text) # 其他词
text = re.sub(r" u s ", " american ", text) # 其他词
text = re.sub(r"\0s", "0", text) # 其他词
text = re.sub(r" 9 11 ", " 911 ", text) # 其他词
text = re.sub(r"e - mail", "email", text) # 其他词
text = re.sub(r"j k", "jk", text) # 其他词
text = re.sub(r"\s{2,}", " ", text) # 将多个空白符替换成一个空格
return text.split()
Загрузите предварительно обученный Google 300-мерный вектор слов
word2vec = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)
Расположите словарь, всего 58564 слова, замените текст представлением целочисленной последовательности и получите матрицу векторного отображения слов.
vocabulary = []
word2id = {}
id2word = {}
for df in [train_df, test_df]:
for i in tqdm(range(len(df))):
row = df.iloc[i]
for column in ['question1', 'question2']:
q2n = []
for word in preprocess(row[column]):
if word in stops or word not in word2vec.vocab:
continue
if word not in vocabulary:
word2id[word] = len(vocabulary) + 1
id2word[len(vocabulary) + 1] = word
vocabulary.append(word)
q2n.append(word2id[word])
else:
q2n.append(word2id[word])
df.at[i, column] = q2n
embedding_dim = 300
embeddings = np.random.randn(len(vocabulary) + 1, embedding_dim)
embeddings[0] = 0 # 零填充对应的词向量
for index, word in enumerate(vocabulary):
embeddings[index] = word2vec.word_vec(word)
del word2vec
print(len(vocabulary))
Разделите набор для обучения и набор для проверки, дополнив целочисленную последовательность до одинаковой длины.
maxlen = max(train_df.question1.map(lambda x: len(x)).max(),
train_df.question2.map(lambda x: len(x)).max(),
test_df.question1.map(lambda x: len(x)).max(),
test_df.question2.map(lambda x: len(x)).max())
valid_size = 40000
train_size = len(train_df) - valid_size
X = train_df[['question1', 'question2']]
Y = train_df['is_duplicate']
X_train, X_valid, Y_train, Y_valid = train_test_split(X, Y, test_size=valid_size)
X_train = {'left': X_train.question1.values, 'right': X_train.question2.values}
X_valid = {'left': X_valid.question1.values, 'right': X_valid.question2.values}
Y_train = np.expand_dims(Y_train.values, axis=-1)
Y_valid = np.expand_dims(Y_valid.values, axis=-1)
# 前向填充或截断
X_train['left'] = np.array(pad_sequences(X_train['left'], maxlen=maxlen))
X_train['right'] = np.array(pad_sequences(X_train['right'], maxlen=maxlen))
X_valid['left'] = np.array(pad_sequences(X_valid['left'], maxlen=maxlen))
X_valid['right'] = np.array(pad_sequences(X_valid['right'], maxlen=maxlen))
print(X_train['left'].shape, X_train['right'].shape)
print(X_valid['left'].shape, X_valid['right'].shape)
print(Y_train.shape, Y_valid.shape)
Определите модель и обучение
hidden_size = 128
gradient_clipping_norm = 1.25
batch_size = 64
epochs = 20
def exponent_neg_manhattan_distance(args):
left, right = args
return K.exp(-K.sum(K.abs(left - right), axis=1, keepdims=True))
left_input = Input(shape=(None,), dtype='int32')
right_input = Input(shape=(None,), dtype='int32')
embedding_layer = Embedding(len(embeddings), embedding_dim, weights=[embeddings], input_length=maxlen, trainable=False)
embedded_left = embedding_layer(left_input)
embedded_right = embedding_layer(right_input)
shared_lstm = LSTM(hidden_size)
left_output = shared_lstm(embedded_left)
right_output = shared_lstm(embedded_right)
malstm_distance = Lambda(exponent_neg_manhattan_distance, output_shape=(1,))([left_output, right_output])
malstm = Model([left_input, right_input], malstm_distance)
optimizer = Adam(clipnorm=gradient_clipping_norm)
malstm.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])
history = malstm.fit([X_train['left'], X_train['right']], Y_train, batch_size=batch_size, epochs=epochs,
validation_data=([X_valid['left'], X_valid['right']], Y_valid))
Постройте кривую точности и кривую функции потерь во время обучения
# Plot Accuracy
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()
# Plot Loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper right')
plt.show()
Потери обучающего набора продолжают уменьшаться, но потери проверочного набора имеют тенденцию быть плоскими, что указывает на то, что способности модели к обобщению недостаточно.
Правильная скорость обучающего набора увеличилась до более чем 86%, в то время как правильная скорость набора проверки осталась на уровне около 80%, и модель нуждается в дальнейшем улучшении.
Сохраните модель для последующего использования
malstm.save('malstm.h5')
with open('data.pkl', 'wb') as fw:
pickle.dump({'word2id': word2id, 'id2word': id2word}, fw)
Используйте обученную модель на одной машине, чтобы выполнить простой тест, случайным образом возьмите несколько образцов из обучающего набора и посмотрите, соответствует ли результат классификации модели метке, в основном, чтобы знать, как применять модель для вывода.
# -*- coding: utf-8 -*-
from keras.preprocessing.sequence import pad_sequences
from keras.models import Model, load_model
import pandas as pd
import numpy as np
from nltk.corpus import stopwords
import re
import pickle
with open('data.pkl', 'rb') as fr:
data = pickle.load(fr)
word2id = data['word2id']
id2word = data['id2word']
train_df = pd.read_csv('train.csv')
stops = set(stopwords.words('english'))
def preprocess(text):
# input: 'Hello are you ok?'
# output: ['Hello', 'are', 'you', 'ok', '?']
text = str(text)
text = text.lower()
text = re.sub(r"[^A-Za-z0-9^,!.\/'+-=]", " ", text) # 去掉其他符号
text = re.sub(r"what's", "what is ", text) # 缩写
text = re.sub(r"\'s", " is ", text) # 缩写
text = re.sub(r"\'ve", " have ", text) # 缩写
text = re.sub(r"can't", "cannot ", text) # 缩写
text = re.sub(r"n't", " not ", text) # 缩写
text = re.sub(r"i'm", "i am ", text) # 缩写
text = re.sub(r"\'re", " are ", text) # 缩写
text = re.sub(r"\'d", " would ", text) # 缩写
text = re.sub(r"\'ll", " will ", text) # 缩写
text = re.sub(r",", " ", text) # 去除逗号
text = re.sub(r"\.", " ", text) # 去除句号
text = re.sub(r"!", " ! ", text) # 保留感叹号
text = re.sub(r"\/", " ", text) # 去掉右斜杠
text = re.sub(r"\^", " ^ ", text) # 其他符号
text = re.sub(r"\+", " + ", text) # 其他符号
text = re.sub(r"\-", " - ", text) # 其他符号
text = re.sub(r"\=", " = ", text) # 其他符号
text = re.sub(r"\'", " ", text) # 去掉单引号
text = re.sub(r"(\d+)(k)", r"\g<1>000", text) # 把30k等替换成30000
text = re.sub(r":", " : ", text) # 其他符号
text = re.sub(r" e g ", " eg ", text) # 其他词
text = re.sub(r" b g ", " bg ", text) # 其他词
text = re.sub(r" u s ", " american ", text) # 其他词
text = re.sub(r"\0s", "0", text) # 其他词
text = re.sub(r" 9 11 ", " 911 ", text) # 其他词
text = re.sub(r"e - mail", "email", text) # 其他词
text = re.sub(r"j k", "jk", text) # 其他词
text = re.sub(r"\s{2,}", " ", text) # 将多个空白符替换成一个空格
return text.split()
malstm = load_model('malstm.h5')
correct = 0
for i in range(5):
print('Testing Case:', i + 1)
random_sample = dict(train_df.iloc[np.random.randint(len(train_df))])
left = random_sample['question1']
right = random_sample['question2']
print('Origin Questions...')
print('==', left)
print('==', right)
left = preprocess(left)
right = preprocess(right)
print('Preprocessing...')
print('==', left)
print('==', right)
left = [word2id[w] for w in left if w in word2id]
right = [word2id[w] for w in right if w in word2id]
print('To ids...')
print('==', left, [id2word[i] for i in left])
print('==', right, [id2word[i] for i in right])
left = np.expand_dims(left, 0)
right = np.expand_dims(right, 0)
maxlen = max(left.shape[-1], right.shape[-1])
left = pad_sequences(left, maxlen=maxlen)
right = pad_sequences(right, maxlen=maxlen)
print('Padding...')
print('==', left.shape)
print('==', right.shape)
pred = malstm.predict([left, right])
pred = 1 if pred[0][0] > 0.5 else 0
print('True:', random_sample['is_duplicate'])
print('Pred:', pred)
if pred == random_sample['is_duplicate']:
correct += 1
print(correct / 5)
Ссылаться на
- Как предсказать пары вопросов Quora с помощью Siamese Manhattan LSTM:GitHub.com/Eli или Sing/me…
- Сиамские рекуррентные архитектуры для изучения сходства предложений:Woohoo.Peach.Quota/~Jonas M/inf…