Классификация китайского текста с использованием CNN-RNN на основе тензорного потока

искусственный интеллект TensorFlow Нейронные сети модульный тест
Классификация китайского текста с использованием CNN-RNN на основе тензорного потока

Классификация китайского текста с использованием сверточных нейронных сетей, а также рекуррентных нейронных сетей. Документ CNN о классификации предложений см. в разделе «Сверточные нейронные сети для классификации предложений»; вы также можете прочитать блог Деннибрица: «Реализация CNN для классификации текста в TensorFlow» и документ CNN на уровне символов: «Сверточные сети на уровне символов для классификации текста».

Эта статья основана на упрощенной реализации TensorFlow для китайского набора данных с использованием CNN и RNN на уровне символов для классификации китайского текста и дает хорошие результаты.

окрестности

  • Python 3.5
  • TensorFlow 1.3
  • numpy
  • scikit-learn

набор данных

Используйте подмножество THUCNews для обучения и тестирования. Пожалуйста, загрузите набор данных из THUCTC: Набор инструментов для эффективной классификации китайского текста. Пожалуйста, следуйте протоколу с открытым исходным кодом поставщика данных.

В этом обучении используется 10 из них, по 6500 единиц данных в каждом.

Категории следующие:

Спорт, Финансы, Недвижимость, Мебель для дома, Образование, Технологии, Мода, Текущие события, Игры, Развлечения
Этот подмножество можно скачать здесь:
Ссылка на сайт:pan.baidu.com/s/1bpq9Eub
Пароль: ycyw

Набор данных разделен следующим образом:

  • Тренировочный набор: 5000*10
  • Набор для проверки: 500*10
  • Набор тестов: 1000*10

См. два сценария в разделе helper для процесса создания подмножеств из исходного набора данных. Среди них copy_data.sh используется для копирования 6500 файлов из каждой категории, а cnews_group.py используется для объединения нескольких файлов в один файл. После выполнения этого файла получаются три файла данных:

  • cnews.train.txt: обучающий набор (50000)
  • cnews.val.txt: проверочный набор (5000)
  • cnews.test.txt: набор тестов (10000)

предварительная обработка

data / cnews_loader.py - это файл предварительной обработки данных.

  • read_file(): прочитать данные файла;
  • build_vocab(): Создайте словарь, используя представление на уровне символов, эта функция сохранит словарь, чтобы избежать повторной обработки каждый раз;
  • read_vocab(): прочитать словарь, сохраненный на предыдущем шаге, и преобразовать его в представление {word: id};
  • read_category(): исправить категорию и преобразовать ее в представление {category: id};
  • to_words(): преобразовать часть данных, представленных идентификатором, в текст;
  • preocess_file(): преобразует набор данных из текста в представление последовательности идентификаторов фиксированной длины;
  • batch_iter(): подготавливает перемешанные пакеты данных для обучения нейронной сети.

После предварительной обработки данных формат данных выглядит следующим образом:

Data Shape Data Shape
x_train [50000, 600] y_train [50000, 10]
x_val [5000, 600] y_val [5000, 10]
x_test [10000, 600] y_test [10000, 10]

Сверточная нейронная сеть CNN

элемент конфигурации

Настраиваемые параметры CNN показаны ниже в файле cnn_model.py.

class TCNNConfig(object):    """CNN配置参数"""

    embedding_dim = 64      # 词向量维度
    seq_length = 600        # 序列长度
    num_classes = 10        # 类别数
    num_filters = 128        # 卷积核数目
    kernel_size = 5         # 卷积核尺寸
    vocab_size = 5000       # 词汇表达小

    hidden_dim = 128        # 全连接层神经元

    dropout_keep_prob = 0.5 # dropout保留比例
    learning_rate = 1e-3    # 学习率

    batch_size = 64         # 每批训练大小
    num_epochs = 10         # 总迭代轮次

    print_per_batch = 100    # 每多少轮输出一次结果
    save_per_batch = 10      # 每多少轮存入tensorboard

модель CNN

Подробнее см. в реализации cnn_model.py.

Общая структура выглядит следующим образом:

обучение и проверка

Запустите python run_cnn.py train, чтобы начать обучение.

Если вы тренировались раньше, удалите tensorboard/textcnn, чтобы результаты обучения TensorBoard не перекрывались.

Configuring CNN model...
Configuring TensorBoard and Saver...
Loading training and validation data...
Time usage: 0:00:14
Training and evaluating...
Epoch: 1
Iter:      0, Train Loss:    2.3, Train Acc:  10.94%, Val Loss:    2.3, Val Acc:   8.92%, Time: 0:00:01 *
Iter:    100, Train Loss:   0.88, Train Acc:  73.44%, Val Loss:    1.2, Val Acc:  68.46%, Time: 0:00:04 *
Iter:    200, Train Loss:   0.38, Train Acc:  92.19%, Val Loss:   0.75, Val Acc:  77.32%, Time: 0:00:07 *
Iter:    300, Train Loss:   0.22, Train Acc:  92.19%, Val Loss:   0.46, Val Acc:  87.08%, Time: 0:00:09 *
Iter:    400, Train Loss:   0.24, Train Acc:  90.62%, Val Loss:    0.4, Val Acc:  88.62%, Time: 0:00:12 *
Iter:    500, Train Loss:   0.16, Train Acc:  96.88%, Val Loss:   0.36, Val Acc:  90.38%, Time: 0:00:15 *
Iter:    600, Train Loss:  0.084, Train Acc:  96.88%, Val Loss:   0.35, Val Acc:  91.36%, Time: 0:00:17 *
Iter:    700, Train Loss:   0.21, Train Acc:  93.75%, Val Loss:   0.26, Val Acc:  92.58%, Time: 0:00:20 *
Epoch: 2
Iter:    800, Train Loss:   0.07, Train Acc:  98.44%, Val Loss:   0.24, Val Acc:  94.12%, Time: 0:00:23 *
Iter:    900, Train Loss:  0.092, Train Acc:  96.88%, Val Loss:   0.27, Val Acc:  92.86%, Time: 0:00:25
Iter:   1000, Train Loss:   0.17, Train Acc:  95.31%, Val Loss:   0.28, Val Acc:  92.82%, Time: 0:00:28
Iter:   1100, Train Loss:    0.2, Train Acc:  93.75%, Val Loss:   0.23, Val Acc:  93.26%, Time: 0:00:31
Iter:   1200, Train Loss:  0.081, Train Acc:  98.44%, Val Loss:   0.25, Val Acc:  92.96%, Time: 0:00:33
Iter:   1300, Train Loss:  0.052, Train Acc: 100.00%, Val Loss:   0.24, Val Acc:  93.58%, Time: 0:00:36
Iter:   1400, Train Loss:    0.1, Train Acc:  95.31%, Val Loss:   0.22, Val Acc:  94.12%, Time: 0:00:39
Iter:   1500, Train Loss:   0.12, Train Acc:  98.44%, Val Loss:   0.23, Val Acc:  93.58%, Time: 0:00:41
Epoch: 3
Iter:   1600, Train Loss:    0.1, Train Acc:  96.88%, Val Loss:   0.26, Val Acc:  92.34%, Time: 0:00:44
Iter:   1700, Train Loss:  0.018, Train Acc: 100.00%, Val Loss:   0.22, Val Acc:  93.46%, Time: 0:00:47
Iter:   1800, Train Loss:  0.036, Train Acc: 100.00%, Val Loss:   0.28, Val Acc:  92.72%, Time: 0:00:50
No optimization for a long time, auto-stopping...

Наилучший результат на проверочном наборе — 94,12%, и он останавливается всего после 3 итераций.

Точность и погрешность показаны на рисунке:

контрольная работа

Запустите тест python run_cnn.py, чтобы протестировать набор тестов.

Configuring CNN model...
Loading test data...
Testing...
Test Loss:   0.14, Test Acc:  96.04%
Precision, Recall and F1-Score...
             precision    recall  f1-score   support

         体育       0.99      0.99      0.99      1000
         财经       0.96      0.99      0.97      1000
         房产       1.00      1.00      1.00      1000
         家居       0.95      0.91      0.93      1000
         教育       0.95      0.89      0.92      1000
         科技       0.94      0.97      0.95      1000
         时尚       0.95      0.97      0.96      1000
         时政       0.94      0.94      0.94      1000
         游戏       0.97      0.96      0.97      1000
         娱乐       0.95      0.98      0.97      1000

avg / total       0.96      0.96      0.96     10000

Confusion Matrix...
[[991   0   0   0   2   1   0   4   1   1]
 [  0 992   0   0   2   1   0   5   0   0]
 [  0   1 996   0   1   1   0   0   0   1]
 [  0  14   0 912   7  15   9  29   3  11]
 [  2   9   0  12 892  22  18  21  10  14]
 [  0   0   0  10   1 968   4   3  12   2]
 [  1   0   0   9   4   4 971   0   2   9]
 [  1  16   0   4  18  12   1 941   1   6]
 [  2   4   1   5   4   5  10   1 962   6]
 [  1   0   1   6   4   3   5   0   1 979]]
Time usage: 0:00:05

Уровень точности в тестовом наборе достиг 96,04%, а точность, полнота и оценка f1 для различных типов превышали 0,9.

Из матрицы путаницы также видно, что эффект классификации очень хороший.

Рекуррентная нейронная сеть RNN

элемент конфигурации

Настраиваемые параметры RNN показаны ниже в файле rnn_model.py.

class TRNNConfig(object):    """RNN配置参数"""

    # 模型参数
    embedding_dim = 64      # 词向量维度
    seq_length = 600        # 序列长度
    num_classes = 10        # 类别数
    vocab_size = 5000       # 词汇表达小

    num_layers= 2           # 隐藏层层数
    hidden_dim = 128        # 隐藏层神经元
    rnn = 'gru'             # lstm 或 gru

    dropout_keep_prob = 0.8 # dropout保留比例
    learning_rate = 1e-3    # 学习率

    batch_size = 128         # 每批训练大小
    num_epochs = 10          # 总迭代轮次

    print_per_batch = 100    # 每多少轮输出一次结果
    save_per_batch = 10      # 每多少轮存入tensorboard

модель RNN

Подробнее см. в реализации rnn_model.py.

Общая структура выглядит следующим образом:

обучение и проверка

Код этой части очень похож на run_cnn.py, только модель и некоторые каталоги нужно немного изменить.

Запустите python run_rnn.py train, чтобы начать обучение.

Если вы тренировались раньше, удалите tensorboard/textrnn, чтобы избежать дублирования результатов обучения TensorBoard.

Configuring RNN model...
Configuring TensorBoard and Saver...
Loading training and validation data...
Time usage: 0:00:14
Training and evaluating...
Epoch: 1
Iter:      0, Train Loss:    2.3, Train Acc:   8.59%, Val Loss:    2.3, Val Acc:  11.96%, Time: 0:00:08 *
Iter:    100, Train Loss:   0.95, Train Acc:  64.06%, Val Loss:    1.3, Val Acc:  53.06%, Time: 0:01:15 *
Iter:    200, Train Loss:   0.61, Train Acc:  79.69%, Val Loss:   0.94, Val Acc:  69.88%, Time: 0:02:22 *
Iter:    300, Train Loss:   0.49, Train Acc:  85.16%, Val Loss:   0.63, Val Acc:  81.44%, Time: 0:03:29 *
Epoch: 2
Iter:    400, Train Loss:   0.23, Train Acc:  92.97%, Val Loss:    0.6, Val Acc:  82.86%, Time: 0:04:36 *
Iter:    500, Train Loss:   0.27, Train Acc:  92.97%, Val Loss:   0.47, Val Acc:  86.72%, Time: 0:05:43 *
Iter:    600, Train Loss:   0.13, Train Acc:  98.44%, Val Loss:   0.43, Val Acc:  87.46%, Time: 0:06:50 *
Iter:    700, Train Loss:   0.24, Train Acc:  91.41%, Val Loss:   0.46, Val Acc:  87.12%, Time: 0:07:57
Epoch: 3
Iter:    800, Train Loss:   0.11, Train Acc:  96.09%, Val Loss:   0.49, Val Acc:  87.02%, Time: 0:09:03
Iter:    900, Train Loss:   0.15, Train Acc:  96.09%, Val Loss:   0.55, Val Acc:  85.86%, Time: 0:10:10
Iter:   1000, Train Loss:   0.17, Train Acc:  96.09%, Val Loss:   0.43, Val Acc:  89.44%, Time: 0:11:18 *
Iter:   1100, Train Loss:   0.25, Train Acc:  93.75%, Val Loss:   0.42, Val Acc:  88.98%, Time: 0:12:25
Epoch: 4
Iter:   1200, Train Loss:   0.14, Train Acc:  96.09%, Val Loss:   0.39, Val Acc:  89.82%, Time: 0:13:32 *
Iter:   1300, Train Loss:    0.2, Train Acc:  96.09%, Val Loss:   0.43, Val Acc:  88.68%, Time: 0:14:38
Iter:   1400, Train Loss:  0.012, Train Acc: 100.00%, Val Loss:   0.37, Val Acc:  90.58%, Time: 0:15:45 *
Iter:   1500, Train Loss:   0.15, Train Acc:  96.88%, Val Loss:   0.39, Val Acc:  90.58%, Time: 0:16:52
Epoch: 5
Iter:   1600, Train Loss:  0.075, Train Acc:  97.66%, Val Loss:   0.41, Val Acc:  89.90%, Time: 0:17:59
Iter:   1700, Train Loss:  0.042, Train Acc:  98.44%, Val Loss:   0.41, Val Acc:  90.08%, Time: 0:19:06
Iter:   1800, Train Loss:   0.08, Train Acc:  97.66%, Val Loss:   0.38, Val Acc:  91.36%, Time: 0:20:13 *
Iter:   1900, Train Loss:  0.089, Train Acc:  98.44%, Val Loss:   0.39, Val Acc:  90.18%, Time: 0:21:20
Epoch: 6
Iter:   2000, Train Loss:  0.092, Train Acc:  96.88%, Val Loss:   0.36, Val Acc:  91.42%, Time: 0:22:27 *
Iter:   2100, Train Loss:  0.062, Train Acc:  98.44%, Val Loss:   0.39, Val Acc:  90.56%, Time: 0:23:34
Iter:   2200, Train Loss:  0.053, Train Acc:  98.44%, Val Loss:   0.39, Val Acc:  90.02%, Time: 0:24:41
Iter:   2300, Train Loss:   0.12, Train Acc:  96.09%, Val Loss:   0.37, Val Acc:  90.84%, Time: 0:25:48
Epoch: 7
Iter:   2400, Train Loss:  0.014, Train Acc: 100.00%, Val Loss:   0.41, Val Acc:  90.38%, Time: 0:26:55
Iter:   2500, Train Loss:   0.14, Train Acc:  96.88%, Val Loss:   0.37, Val Acc:  91.22%, Time: 0:28:01
Iter:   2600, Train Loss:   0.11, Train Acc:  96.88%, Val Loss:   0.43, Val Acc:  89.76%, Time: 0:29:08
Iter:   2700, Train Loss:  0.089, Train Acc:  97.66%, Val Loss:   0.37, Val Acc:  91.18%, Time: 0:30:15
Epoch: 8
Iter:   2800, Train Loss: 0.0081, Train Acc: 100.00%, Val Loss:   0.44, Val Acc:  90.66%, Time: 0:31:22
Iter:   2900, Train Loss:  0.017, Train Acc: 100.00%, Val Loss:   0.44, Val Acc:  89.62%, Time: 0:32:29
Iter:   3000, Train Loss:  0.061, Train Acc:  96.88%, Val Loss:   0.43, Val Acc:  90.04%, Time: 0:33:36
No optimization for a long time, auto-stopping...

Лучший результат на проверочном наборе — 91,42%, и он останавливается после 8 итераций, что намного медленнее, чем CNN.

Точность и погрешность показаны на рисунке:

контрольная работа

Запустите тест python run_rnn.py, чтобы протестировать набор тестов.

Testing...
Test Loss:   0.21, Test Acc:  94.22%
Precision, Recall and F1-Score...
             precision    recall  f1-score   support

         体育       0.99      0.99      0.99      1000
         财经       0.91      0.99      0.95      1000
         房产       1.00      1.00      1.00      1000
         家居       0.97      0.73      0.83      1000
         教育       0.91      0.92      0.91      1000
         科技       0.93      0.96      0.94      1000
         时尚       0.89      0.97      0.93      1000
         时政       0.93      0.93      0.93      1000
         游戏       0.95      0.97      0.96      1000
         娱乐       0.97      0.96      0.97      1000

avg / total       0.94      0.94      0.94     10000

Confusion Matrix...
[[988   0   0   0   4   0   2   0   5   1]
 [  0 990   1   1   1   1   0   6   0   0]
 [  0   2 996   1   1   0   0   0   0   0]
 [  2  71   1 731  51  20  88  28   3   5]
 [  1   3   0   7 918  23   4  31   9   4]
 [  1   3   0   3   0 964   3   5  21   0]
 [  1   0   1   7   1   3 972   0   6   9]
 [  0  16   0   0  22  26   0 931   2   3]
 [  2   3   0   0   2   2  12   0 972   7]
 [  0   3   1   1   7   3  11   5   9 960]]
Time usage: 0:00:33

Уровень точности на тестовом наборе достиг 94,22%, а все типы точности, отзыва и оценки f1, за исключением домашней категории, превышали 0,9.

Из матрицы путаницы видно, что эффект классификации очень хороший.

Сравнивая две модели, можно увидеть, что производительность RNN не очень хороша в классификации домохозяйств, а остальные категории не сильно отличаются от CNN.

Вы также можете дополнительно настроить параметры для достижения лучших результатов.

адрес:GitHub.com/Gauss IC/Особенности…

Эта статья взята из публичного аккаунта глобального искусственного интеллекта WeChat.