Резюме:Овладение навыками глубокого обучения: автор подробно рассказывает, как сохранить, а затем загрузить модель глубокого обучения во время обучения на простом примере. Благодаря этому небольшому трюку вы никогда больше не будете беспокоиться об ошибках при обучении модели.
Советы по глубокому обучению (2): как сохранить и восстановить модель, обученную scikit-learn
Если сложность модели глубокой нейронной сети очень высока, ее обучение может занять довольно много времени, в зависимости от объема имеющихся у вас данных, аппаратного обеспечения, на котором работает модель, и т. д. В большинстве случаев вы захотите гарантировать стабильность своих экспериментов, сохранив файл, чтобы в случае сбоя (или ошибки) вы могли продолжить с того места, где не было ошибки.
Что еще более важно, в любой среде глубокого обучения, такой как TensorFlow, после успешного обучения вам необходимо повторно использовать изученные параметры модели для выполнения прогнозов на новых данных.
В этом посте мы рассмотрим, как сохранять и восстанавливать модели TensorFlow, опишем некоторые из наиболее полезных методов и приведем несколько примеров.
1.Сначала мы дадим краткое введение в модель TensorFlow.
Основная функция TensorFlow заключается вТензордля передачи своей базовой структуры данных, такой как многомерный массив в NumPy, в то время какправило диаграммыПредставляет вычисления данных. этосимволбиблиотека, а это значит, что определение графиков и тензоров создаст только модель, а конкретные значения и операции для получения тензоров будутсеансв исполнении,сеансМеханизм выполнения операций моделирования на графах. Любые конкретные значения тензоров теряются при закрытии сессии, что является еще одной причиной сохранения модели в файл после запуска сессии.
На примерах легче понять, поэтому давайте создадим простую модель TensorFlow для линейной регрессии на 2D-данных.
Во-первых, мы импортируем нашу библиотеку:
Следующим шагом является создание модели. Мы создадим модель, которая будет оценивать горизонтальное и вертикальное смещение квадратичной функции в виде:
где h — изменение по горизонтали, а v — изменение по вертикали.
Вот как генерируется модель (подробности см. в комментариях к коду):
В процессе создания модели нам необходимо иметьв сессииЗапустите модель и передайте ей реальные данные. Мы генерируем некоторые квадратичные данные и добавляем к ним шум.
2.The Saver class
Класс Saver — это класс, предоставляемый библиотекой TensorFlow, и это предпочтительный способ сохранения структур и переменных графа.
2.1сохранить модель
В следующих строках кода мы определяем объект Saver и в функции train_graph() минимизируем функцию стоимости за 100 итераций. Затем на каждой итерации и после завершения оптимизации сохраняйте модель на диск. Каждый созданный бинарный файл, сохраненный на диске, называется «контрольной точкой».
Теперь давайте обучим модель с помощью вышеуказанной функции и распечатаем обученные параметры.
Хорошо, параметры очень точные. Если мы проверим нашу файловую систему, последние 4 итерации сохранят файл вместе с окончательной моделью.
При сохранении модели вы заметите, что для сохранения требуются 4 типа файлов:
Файл «.meta»: содержит структуру графика.
Файл ".data": содержит значение переменной.
Файл ".index": Идентифицирует контрольные точки.
Файл "checkpoint": Буферы протокола со списком последних контрольных точек.
Вызовите метод tf.train.Saver(), как показано выше, чтобы сохранить все переменные в файл. Передавая их в качестве аргументов, выражения передаются в виде списков или словарей для сохранения подмножеств переменных, например: tf.train.Saver({'hor_estimate': h_est}).
Некоторые другие полезные параметры конструктора Saver, которые также контролируют весь процесс:
1.max_to_keep: Максимальное количество сохраняемых контрольных точек.
2.keep_checkpoint_every_n_hours: Интервал сохранения чекпоинтов.
Если вы хотите узнать больше, см.официальная документацияКласс Saver, который предоставляет другую полезную информацию, которую вы можете изучить и просмотреть.
3.Restoring Models
Первое, что нужно сделать при восстановлении модели TensorFlow, — это загрузить структуру графа из файла «.meta» в текущий граф.
Текущий график также можно просмотреть с помощью tf.get_default_graph(). Затем вторым шагом является загрузка значения переменной. Напоминание: значения существуют только в сеансе
Как упоминалось ранее, этот метод сохраняет только структуру графика и переменные, а это означает, что обучающие данные, введенные через заполнители «X» и «Y», не будут сохранены.
В любом случае, в этом примере мы будем использовать наши определенные обучающие данные tf и визуализировать соответствие модели.
Saver Этот класс позволяет использовать простой способ сохранения и восстановления вашей модели TensorFlow (график и переменные) в/из файла и сохранять несколько контрольных точек вашей работы, что может быть полезно для тонкой настройки вашей модели во время обучения.
4.SavedModelФормат
Новый способ сохранения и восстановления моделей в TensorFlow — использованиеSavedModel, конструктор и загрузчикФункции. Этот метод на самом деле представляет собой сериализацию более высокого уровня, предоставляемую Saver, которая больше подходит для коммерческих целей.
Хотя этот подход SavedModel, похоже, не полностью принят разработчиками, его создатели отмечают: за ним явно будущее. В отличие от классов Saver, которые в основном ориентированы на переменные, SavedModel пытается включить некоторые полезные функции в один пакет, например Signatures: позволяет сохранить график с набором входных и выходных данных, Assets: содержит внешние файлы, используемые при инициализации.
4.1Сохранение моделей с помощью SavedModel Builder
Далее мы пытаемся использовать класс SavedModelBuilder для сохранения модели. В нашем примере мы не используем никаких символов, но этого достаточно, чтобы проиллюстрировать процесс.
4.2Восстановление моделей с помощью программы SavedModel Loader
Восстановление модели использует tf.saved_model.loader и может восстанавливать переменные, символы, сохраненные в области сеанса.
В приведенном ниже примере мы загрузим модель и распечатаем значения наших двух коэффициентов (h_est и v_est). Значения соответствуют ожидаемым, и наша модель была успешно восстановлена.
5.в заключении
Сохранение и восстановление моделей TensorFlow — очень полезная функция, если вы знаете, что обучение вашей сети глубокого обучения может занять много времени. Тема слишком широка, чтобы подробно осветить ее в одном сообщении в блоге. В любом случае, в этом посте мы представляем два инструмента: Saver и SavedModelbuilder/loader, а также создаем файловую структуру для иллюстрации примера с использованием простой линейной регрессии. Надеюсь, это поможет вам лучше обучать модели нейронных сетей.
автор:Mihajlo Pavloski, энтузиаст науки о данных и машинного обучения, аспирант.
Эта статья написанаСообщество Alibaba Cloud YunqiОрганизация переводов.
Оригинальное название статьи «TensorFlow: сохранение и восстановление моделей».
автор:Mihajlo PavloskiПереводчик: Тигр говорит о восьми путях