контрольно-пропускной пункт
В этой статье описывается, как сохранять и восстанавливать модели TensorFlow, скомпилированные с помощью Estimators. TensorFlow предоставляет два формата моделей:
- Контрольные точки: это формат, основанный на создании кода модели.
- SavedModel: это формат, независимый от кода для создания модели.
образец кода
git clone https://github.com/tensorflow/models/
cd models/samples/core/get_started
Большинство фрагментов кода в этой статье находятся вpremade_estimator.py
Немного модифицированная версия на основе.
Сохраните необученную модель
Оценщики автоматически записывают на диск следующее:
- контрольно-пропускной пункт: Различные версии модели, созданные во время обучения.
- файл событий: содержит некоторыеTensorBoardвизуализированная информация
Указывает каталог верхнего уровня, в котором оценщик хранит информацию, присваивая ее любому необязательному параметру конструктора оценщика.model_dir
. Например, следующий код будетmodel_dir
параметр установлен наmodels/iris
содержание:
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
Предположим, вы вызываете Estimator'strain
метод. Например:
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
Как показано на следующей диаграмме, первый вызовtrain
Добавьте контрольные точки и другие файлы вmodel_dir
В каталоге:
В UNIX-подобной системе доступные командыls
смотретьmodel_dir
Объекты в каталоге:
$ ls -1 models/iris
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta
вышеls
Команда показывает, что этот Estimator сгенерировал контрольные точки на шаге 1 (в начале обучения) и на шаге 200 (в конце обучения).
Каталог контрольных точек по умолчанию
Если вы укажете в конструкторе Estimatormodel_dir
параметр, этот Estimator записывает файлы контрольных точек во временный каталог, который используется программным обеспечением Python.tempfile.mkdtempуказана функция. Например, конструктор Estimator ниже не указываетmodel_dir
параметр:
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3)
print(classifier.model_dir)
tempfile.mkdtemp
Функция выберет для вас безопасный временный каталог в операционной системе. Например, в операционной системе macOS типичный временный каталог:
/var/folders/0s/5q9kfzfj3gx2knj0vj8p68yc00dhcr/T/tmpYm1Rwa
Как часто сохраняются контрольные точки
По умолчанию оценщик будетmodel_dir
сохранить в каталогеконтрольно-пропускной пункт, и принять следующую стратегию:
- Сохраняйте контрольную точку каждые 10 минут (т.е. 600 секунд).
- когда
train
Контрольная точка сохраняется, когда метод начинает выполняться (то есть первый цикл) и когда выполнение заканчивается (последний цикл). - Сохраните последние 5 контрольных точек в каталоге.
Вы можете изменить указанную выше политику по умолчанию, выполнив следующие действия:
tf.estimator.RunConfig
- При создании экземпляра Estimator используйте этот
RunConfig
Объект передан оценщикуconfig
параметр.
Например, следующий код изменяет политику сохранения контрольных точек, чтобы сохранять каждые 20 минут и сохранять последние 10 контрольных точек:
my_checkpointing_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60, # Save checkpoints every 20 minutes.
keep_checkpoint_max = 10, # Retain the 10 most recent checkpoints.
)
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris',
config=my_checkpointing_config)
восстановить свою модель
При вызове оценщика в первый разtrain
метод, TensorFlow будетmodel_dir
Контрольная точка сохраняется в каталоге. Каждый последующий вызов Estimatortrain
,evaluate
илиpredict
метод, произойдет следующее поведение:
- Создание пользовательских оценщиков
- Этот оценщик восстанавливает данные из самой последней контрольной точки и используется для инициализации весов для новой модели.
Другими словами, как показано на рисунке ниже, когда файл контрольной точки существует, TensorFlow всегда будет вызыватьtrain()
,evaluation()
илиpredict()
при восстановлении модели.
Избегайте плохого восстановления
Только если модель совместима с контрольной точкой, мы можем восстановить состояние модели с этой контрольной точки. Например, предположим, что вы обучаете программу под названиемDNNClassifier
Оценщик, который содержит два скрытых слоя, каждый с 10 узлами:
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
После тренировки (разумеется, тоже одновременноmodels/iris
каталог), если вы измените 10 узлов в каждом скрытом слое на 20, а затем попытаетесь восстановить модель:
classifier2 = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[20, 20], # 修改模型中的神经元个数
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
Поскольку состояние контрольной точки такое же, какclassifier2
Состояния описанных моделей несовместимы, и восстановление модели завершится ошибкой со следующим сообщением:
...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]
Когда вы обучаете и сравниваете слегка отличающиеся версии моделей во время экспериментов, не забывайте сохранять и создавать каждуюmodel_dir
код. Например, вы можете создать отдельную ветку git для каждого релиза. Такое разделение гарантирует возможность восстановления ваших контрольных точек.
Суммировать
Контрольные точки предоставляют автоматизированный механизм для простого сохранения и восстановления моделей, созданных Estimator.
- Используйте низкоуровневый API TensorFlow для сохранения и восстановления моделей.
- Экспорт и импорт моделей в режиме SavedModel — независимом от языка, восстанавливаемом и сериализуемом формате.