Сохранить и восстановить загрузку моделей Tensorflow

машинное обучение искусственный интеллект TensorFlow GitHub

Адрес блога на Github:GitHub.com/message731/no…


Недавно была проведена некоторая работа по борьбе со спамом.Помимо использования общих правил, таких как сопоставление и фильтрация, для прогнозирования классификации также используются некоторые методы машинного обучения. Мы используем TensorFlow для обучения модели.Обученную модель нужно сохранить.На этапе прогнозирования нам нужно загрузить и восстановить модель, что включает в себя сохранение и восстановление модели TensorFlow.

Кратко опишите наиболее часто используемые методы сохранения моделей Tensorflow.

Сохраните файл модели контрольной точки (.ckpt).

Во-первых, TensorFlow предоставляет очень удобный API,tf.train.Saver()для сохранения и восстановления модели машинного обучения.

сохранить модель

использоватьtf.train.Saver()Сохранять файлы модели очень удобно, вот простой пример:

import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    op = tf.add(xy, b, name='op_to_store')

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    path = os.path.dirname(os.path.abspath(ckpt_file_path))
    if os.path.isdir(path) is False:
        os.makedirs(path)

    tf.train.Saver().save(sess, ckpt_file_path)
    
    # test
    feed_dict = {x: 2, y: 3}
    print(sess.run(op, feed_dict))

Программа генерирует и сохраняет четыре файла (до версии 0.11 формировалось только три файла: checkpoint, model.ckpt, model.ckpt.meta)

  • текстовый файл контрольной точки, который записывает список информации о пути к файлу модели
  • model.ckpt.data-00000-of-00001 информация о весе сети
  • Два файла model.ckpt.index .data и .index — это двоичные файлы, которые сохраняют информацию о переменных параметрах (весе) в модели.
  • Двоичный файл model.ckpt.meta, который сохраняет информацию о структуре вычислительного графа модели (сетевая структура модели) protobuf

Вышеtf.train.Saver().save()базовое использование,save()Метод также имеет ряд настраиваемых параметров:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)

Добавление параметра global_step означает сохранение модели после каждой 1000 итераций и добавление «-1000» после файла модели, model.ckpt-1000.index, model.ckpt-1000.meta, model.ckpt.data-1000-00000- из-00001

Модель сохраняется каждые 1000 итераций, но файл информации о структуре модели не изменится, поэтому он сохраняется только при 1000 итерациях, а не каждые 1000 раз, поэтому, когда нам не нужно сохранять метафайл, мы можем добавитьwrite_meta_graph=Falseпараметры следующим образом:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)

Если вы хотите сохранять модель каждые два часа и сохранять только последние 4 модели, вы можете добавить использованиеmax_to_keep(Значение по умолчанию — 5. Если вы хотите сохранять один раз каждую эпоху обучения, вы можете установить для него значение «Нет» или «0», но это бесполезно и не рекомендуется).keep_checkpoint_every_n_hoursпараметры следующим образом:

tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)

в то же времяtf.train.Saver()В классе, если мы не укажем никакой информации, вся информация о параметрах будет сохранена, и мы также можем указать часть содержимого, которое мы хотим сохранить, например, только сохранить параметры x, y (список параметров или dict может быть прошел):

tf.train.Saver([x, y]).save(sess, ckpt_file_path)

ps.В процессе обучения модели имя атрибута имени переменной или параметра, которое необходимо получить после сохранения, нельзя терять, иначе модель не может пройти после восстановления.get_tensor_by_name()Получать.

Восстановление загрузки модели

Для приведенного выше примера сохранения модели процесс восстановления модели выглядит следующим образом:

import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
    sess = tf.Session()
    saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta')  # 加载模型结构
    saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))  # 只需要指定目录就可以恢复所有变量信息

    # 直接获取保存的变量
    print(sess.run('b:0'))

    # 获取placeholder变量
    input_x = sess.graph.get_tensor_by_name('x:0')
    input_y = sess.graph.get_tensor_by_name('y:0')
    # 获取需要进行计算的operator
    op = sess.graph.get_tensor_by_name('op_to_store:0')

    # 加入新的操作
    add_on_op = tf.multiply(op, 2)

    ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
    print(ret)

Сначала восстановить структуру модели, затем восстановить информацию о переменных (параметрах), и, наконец, мы можем получить различную информацию в обученной модели (сохраненные переменные, переменные-заполнители, оператор и т. д.), а также можем добавить различные новые переменные к полученным переменным. операции (см. комментарии к коду выше).

Кроме того, мы также можем загрузить некоторые модели и добавить другие операции на этой основе.Подробности см. в официальной документации и демоверсии.

Для сохранения и восстановления файлов модели ckpt существует stackoverflowотвечатьОбъяснение относительно ясное, вы можете сослаться на него.

При этом модель TensorFlow на cv-tricks.com сохраняется и восстанавливается.руководствоТакже очень хорошо для справки.

«Обучение Tensorflow 1.0: сохранение и восстановление модели (Saver)»Есть несколько советов по использованию Saver.

Сохраните один файл модели (.pb)

Я сам запустил демо-версию Tensorflow Inception-v3 и обнаружил, что после операции будет сгенерирован файл модели .pb. Этот файл используется для последующего прогнозирования или обучения переносу. Это всего лишь один файл, что очень круто и очень удобный.

Основная идея этого процесса заключается в том, что файл graph_def не содержит значения переменной в сети (обычно хранится вес), но содержит постоянное значение, поэтому, если мы можем преобразовать переменную в константу (используяgraph_util.convert_variables_to_constants()функция), вы можете достичь цели использования файла для одновременного хранения сетевой архитектуры и весов.

ps: Здесь .pb — это суффикс файла модели, конечно, мы можем использовать и другие суффиксы (используйте .pb, чтобы соответствовать Google ╮(╯▽╰)╭)

сохранить модель

Также в соответствии с приведенным выше примером простая демонстрация:

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

def save_mode_pb(pb_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # 这里的输出需要加上name属性
    op = tf.add(xy, b, name='op_to_store')

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    path = os.path.dirname(os.path.abspath(pb_file_path))
    if os.path.isdir(path) is False:
        os.makedirs(path)

    # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
    with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    # test
    feed_dict = {x: 2, y: 3}
    print(sess.run(op, feed_dict))

Программа создает и сохраняет файл

  • Двоичный файл model.pb, который также сохраняет структуру сети модели и информацию о параметрах (весе).

Восстановление загрузки модели

Для приведенного выше примера сохранения модели процесс восстановления модели выглядит следующим образом:

import tensorflow as tf
from tensorflow.python.platform import gfile

def restore_mode_pb(pb_file_path):
    sess = tf.Session()
    with gfile.FastGFile(pb_file_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')

    print(sess.run('b:0'))

    input_x = sess.graph.get_tensor_by_name('x:0')
    input_y = sess.graph.get_tensor_by_name('y:0')

    op = sess.graph.get_tensor_by_name('op_to_store:0')

    ret = sess.run(op, {input_x: 5, input_y: 5})
    print(ret)

Процесс восстановления модели почти такой же, как и у чекпойнтинга.

CSDN«Экспорт сети TensorFlow в виде одного файла»Выше описано, как TensorFlow сохраняет один файл модели, что похоже, вы можете посмотреть.

считать

Сохранение и загрузка модели — это только одна из самых основных частей TensorFlow.Хотя это просто, но также важно.На практике необходимо обращать внимание на то, когда сохраняется модель, какие переменные нужно сохранять и как для разработки и загрузки для достижения передачи обучения.

При этом функции и классы TensorFlow постоянно меняются и обновляются, и в будущем могут появиться более богатые методы сохранения и восстановления моделей.

Выбор сохранения в виде контрольной точки или отдельного pb-файла зависит от ситуации в бизнесе, и особой разницы нет. Хранилище контрольных точек кажется более гибким, а файл pb больше подходит для онлайн-развертывания (личное мнение).

Полный код выше:github


2017/11/25 done

Эта статья также синхронизирована сЛичный блог на Github