Адрес блога на Github:GitHub.com/message731/no…
Недавно была проведена некоторая работа по борьбе со спамом.Помимо использования общих правил, таких как сопоставление и фильтрация, для прогнозирования классификации также используются некоторые методы машинного обучения. Мы используем TensorFlow для обучения модели.Обученную модель нужно сохранить.На этапе прогнозирования нам нужно загрузить и восстановить модель, что включает в себя сохранение и восстановление модели TensorFlow.
Кратко опишите наиболее часто используемые методы сохранения моделей Tensorflow.
Сохраните файл модели контрольной точки (.ckpt).
Во-первых, TensorFlow предоставляет очень удобный API,tf.train.Saver()
для сохранения и восстановления модели машинного обучения.
сохранить модель
использоватьtf.train.Saver()
Сохранять файлы модели очень удобно, вот простой пример:
import tensorflow as tf
import os
def save_model_ckpt(ckpt_file_path):
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
op = tf.add(xy, b, name='op_to_store')
sess = tf.Session()
sess.run(tf.global_variables_initializer())
path = os.path.dirname(os.path.abspath(ckpt_file_path))
if os.path.isdir(path) is False:
os.makedirs(path)
tf.train.Saver().save(sess, ckpt_file_path)
# test
feed_dict = {x: 2, y: 3}
print(sess.run(op, feed_dict))
Программа генерирует и сохраняет четыре файла (до версии 0.11 формировалось только три файла: checkpoint, model.ckpt, model.ckpt.meta)
- текстовый файл контрольной точки, который записывает список информации о пути к файлу модели
- model.ckpt.data-00000-of-00001 информация о весе сети
- Два файла model.ckpt.index .data и .index — это двоичные файлы, которые сохраняют информацию о переменных параметрах (весе) в модели.
- Двоичный файл model.ckpt.meta, который сохраняет информацию о структуре вычислительного графа модели (сетевая структура модели) protobuf
Вышеtf.train.Saver().save()
базовое использование,save()
Метод также имеет ряд настраиваемых параметров:
tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)
Добавление параметра global_step означает сохранение модели после каждой 1000 итераций и добавление «-1000» после файла модели, model.ckpt-1000.index, model.ckpt-1000.meta, model.ckpt.data-1000-00000- из-00001
Модель сохраняется каждые 1000 итераций, но файл информации о структуре модели не изменится, поэтому он сохраняется только при 1000 итерациях, а не каждые 1000 раз, поэтому, когда нам не нужно сохранять метафайл, мы можем добавитьwrite_meta_graph=False
параметры следующим образом:
tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)
Если вы хотите сохранять модель каждые два часа и сохранять только последние 4 модели, вы можете добавить использованиеmax_to_keep
(Значение по умолчанию — 5. Если вы хотите сохранять один раз каждую эпоху обучения, вы можете установить для него значение «Нет» или «0», но это бесполезно и не рекомендуется).keep_checkpoint_every_n_hours
параметры следующим образом:
tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)
в то же времяtf.train.Saver()
В классе, если мы не укажем никакой информации, вся информация о параметрах будет сохранена, и мы также можем указать часть содержимого, которое мы хотим сохранить, например, только сохранить параметры x, y (список параметров или dict может быть прошел):
tf.train.Saver([x, y]).save(sess, ckpt_file_path)
ps.В процессе обучения модели имя атрибута имени переменной или параметра, которое необходимо получить после сохранения, нельзя терять, иначе модель не может пройти после восстановления.get_tensor_by_name()
Получать.
Восстановление загрузки модели
Для приведенного выше примера сохранения модели процесс восстановления модели выглядит следующим образом:
import tensorflow as tf
def restore_model_ckpt(ckpt_file_path):
sess = tf.Session()
saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构
saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息
# 直接获取保存的变量
print(sess.run('b:0'))
# 获取placeholder变量
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
# 获取需要进行计算的operator
op = sess.graph.get_tensor_by_name('op_to_store:0')
# 加入新的操作
add_on_op = tf.multiply(op, 2)
ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
print(ret)
Сначала восстановить структуру модели, затем восстановить информацию о переменных (параметрах), и, наконец, мы можем получить различную информацию в обученной модели (сохраненные переменные, переменные-заполнители, оператор и т. д.), а также можем добавить различные новые переменные к полученным переменным. операции (см. комментарии к коду выше).
Кроме того, мы также можем загрузить некоторые модели и добавить другие операции на этой основе.Подробности см. в официальной документации и демоверсии.
Для сохранения и восстановления файлов модели ckpt существует stackoverflowотвечатьОбъяснение относительно ясное, вы можете сослаться на него.
При этом модель TensorFlow на cv-tricks.com сохраняется и восстанавливается.руководствоТакже очень хорошо для справки.
«Обучение Tensorflow 1.0: сохранение и восстановление модели (Saver)»Есть несколько советов по использованию Saver.
Сохраните один файл модели (.pb)
Я сам запустил демо-версию Tensorflow Inception-v3 и обнаружил, что после операции будет сгенерирован файл модели .pb. Этот файл используется для последующего прогнозирования или обучения переносу. Это всего лишь один файл, что очень круто и очень удобный.
Основная идея этого процесса заключается в том, что файл graph_def не содержит значения переменной в сети (обычно хранится вес), но содержит постоянное значение, поэтому, если мы можем преобразовать переменную в константу (используяgraph_util.convert_variables_to_constants()
функция), вы можете достичь цели использования файла для одновременного хранения сетевой архитектуры и весов.
ps: Здесь .pb — это суффикс файла модели, конечно, мы можем использовать и другие суффиксы (используйте .pb, чтобы соответствовать Google ╮(╯▽╰)╭)
сохранить модель
Также в соответствии с приведенным выше примером простая демонстрация:
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
def save_mode_pb(pb_file_path):
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
# 这里的输出需要加上name属性
op = tf.add(xy, b, name='op_to_store')
sess = tf.Session()
sess.run(tf.global_variables_initializer())
path = os.path.dirname(os.path.abspath(pb_file_path))
if os.path.isdir(path) is False:
os.makedirs(path)
# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
f.write(constant_graph.SerializeToString())
# test
feed_dict = {x: 2, y: 3}
print(sess.run(op, feed_dict))
Программа создает и сохраняет файл
- Двоичный файл model.pb, который также сохраняет структуру сети модели и информацию о параметрах (весе).
Восстановление загрузки модели
Для приведенного выше примера сохранения модели процесс восстановления модели выглядит следующим образом:
import tensorflow as tf
from tensorflow.python.platform import gfile
def restore_mode_pb(pb_file_path):
sess = tf.Session()
with gfile.FastGFile(pb_file_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
print(sess.run('b:0'))
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
op = sess.graph.get_tensor_by_name('op_to_store:0')
ret = sess.run(op, {input_x: 5, input_y: 5})
print(ret)
Процесс восстановления модели почти такой же, как и у чекпойнтинга.
CSDN«Экспорт сети TensorFlow в виде одного файла»Выше описано, как TensorFlow сохраняет один файл модели, что похоже, вы можете посмотреть.
считать
Сохранение и загрузка модели — это только одна из самых основных частей TensorFlow.Хотя это просто, но также важно.На практике необходимо обращать внимание на то, когда сохраняется модель, какие переменные нужно сохранять и как для разработки и загрузки для достижения передачи обучения.
При этом функции и классы TensorFlow постоянно меняются и обновляются, и в будущем могут появиться более богатые методы сохранения и восстановления моделей.
Выбор сохранения в виде контрольной точки или отдельного pb-файла зависит от ситуации в бизнесе, и особой разницы нет. Хранилище контрольных точек кажется более гибким, а файл pb больше подходит для онлайн-развертывания (личное мнение).
Полный код выше:github
2017/11/25 done
Эта статья также синхронизирована сЛичный блог на Github