Структура сети рисования Кераса

машинное обучение
Структура сети рисования Кераса

Мало знаний, большой вызов! Эта статья участвует в "Необходимые знания для программистов«Творческая деятельность.

Keras — это фреймворк глубокого обучения на основе теано/тензорного потока, написанный на чистом питоне. Keras — это высокоуровневый API нейронной сети, который поддерживает быстрые эксперименты и может быстро превратить ваши идеи в результаты. Когда мы понимаем сетевую структуру модели через код, понять более сложную структуру непросто, но если эта структура отображается в виде картинок, мы можем понять ее более интуитивно и быстро.В этой статье используется фреймворк Keras. рисует сетевую структуру модели Bi-LSTM.

1. Предварительные приготовления

1. Установите Пидо

pip install pydot

2. Установите графвиз

graphviz необходимо установить на официальном сайте:graphvizgraphviz.org/

image.png

После установки нужно добавить папку bin каталога где находится программа для добавления системных переменных

image.png

2. Напишите код

1. Импорт связанных пакетов

load_model: используется для загрузки сетевой модели

CRF: в сетевой модели есть слой модели CRF.

plot_model: создать структуру сетевой модели и сохранить ее как изображение.

pyplot: загрузить изображение структуры модели сети

from keras.models import load_model
from keras_contrib.layers import CRF
from keras.utils.vis_utils import plot_model
import matplotlib.pyplot as plt

2. Создайте структуру сетевой модели

параметры интерфейса plot_model:

to_file: путь и имя хранения изображения структуры сетевой модели

show_shapes: отображать ли фигуры (ввод и вывод нейронного слоя)

show_layer_names: показывать ли имена нейронных слоев

rankdir: направление между нейронными слоями, «TB» обозначает вверх и вниз, «LR» обозначает лево и право.

model_path = "./model/ch_ner_model.h5"
# 模型文件
model = load_model(model_path, custom_objects={'CRF': CRF}, compile=False)
plot_model(model,to_file='./model/nerbilstm.png',show_shapes=True,show_layer_names='False',rankdir="TB")

3. Загрузите структуру сетевой модели

Используйте метод pyplot в пакете matplotlib, чтобы загрузить сгенерированное изображение структуры сетевой модели.

plt.figure(figsize=(10,10))
img = plt.imread("./model/nerbilstm.png")
plt.imshow(img)
plt.axis("off")
plt.show()

4. Результат загрузки изображения

image.png