Мало знаний, большой вызов! Эта статья участвует в "Необходимые знания для программистов«Творческая деятельность.
Keras — это фреймворк глубокого обучения на основе теано/тензорного потока, написанный на чистом питоне. Keras — это высокоуровневый API нейронной сети, который поддерживает быстрые эксперименты и может быстро превратить ваши идеи в результаты. Когда мы понимаем сетевую структуру модели через код, понять более сложную структуру непросто, но если эта структура отображается в виде картинок, мы можем понять ее более интуитивно и быстро.В этой статье используется фреймворк Keras. рисует сетевую структуру модели Bi-LSTM.
1. Предварительные приготовления
1. Установите Пидо
pip install pydot
2. Установите графвиз
graphviz необходимо установить на официальном сайте:graphvizgraphviz.org/
После установки нужно добавить папку bin каталога где находится программа для добавления системных переменных
2. Напишите код
1. Импорт связанных пакетов
load_model: используется для загрузки сетевой модели
CRF: в сетевой модели есть слой модели CRF.
plot_model: создать структуру сетевой модели и сохранить ее как изображение.
pyplot: загрузить изображение структуры модели сети
from keras.models import load_model
from keras_contrib.layers import CRF
from keras.utils.vis_utils import plot_model
import matplotlib.pyplot as plt
2. Создайте структуру сетевой модели
параметры интерфейса plot_model:
to_file: путь и имя хранения изображения структуры сетевой модели
show_shapes: отображать ли фигуры (ввод и вывод нейронного слоя)
show_layer_names: показывать ли имена нейронных слоев
rankdir: направление между нейронными слоями, «TB» обозначает вверх и вниз, «LR» обозначает лево и право.
model_path = "./model/ch_ner_model.h5"
# 模型文件
model = load_model(model_path, custom_objects={'CRF': CRF}, compile=False)
plot_model(model,to_file='./model/nerbilstm.png',show_shapes=True,show_layer_names='False',rankdir="TB")
3. Загрузите структуру сетевой модели
Используйте метод pyplot в пакете matplotlib, чтобы загрузить сгенерированное изображение структуры сетевой модели.
plt.figure(figsize=(10,10))
img = plt.imread("./model/nerbilstm.png")
plt.imshow(img)
plt.axis("off")
plt.show()