Сохранение и восстановление моделей

TensorFlow

Прогресс модели можно сохранить во время и после обучения. Это означает, что вы можете продолжить обучение модели с того места, на котором остановились, и не тратить слишком много времени на обучение. Кроме того, возможность сохранения означает, что вы можете делиться моделями, а другие могут воссоздавать вашу работу. При публикации исследовательских моделей и связанных с ними технологий большинство специалистов по машинному обучению делятся следующим:

  • код, используемый для создания модели, и
  • Тренировочные веса или параметры модели

Совместное использование данных этого типа помогает другим понять, как работает модель, и опробовать ее на новых данных.

Примечание. Будьте осторожны с ненадежным кодом — модели TensorFlow — это код. Дополнительные сведения см. в разделе «Безопасное использование TensorFlow».

опции

Вы можете сохранить модель TensorFlow несколькими способами, в зависимости от используемого вами API. В этом руководстве используется tf.keras, высокоуровневый API для построения и обучения моделей в TensorFlow. Чтобы узнать о других методах, см. руководство по сохранению и восстановлению TensorFlow или сохраните в Eager.

настраивать

установить и импортировать

Установите и импортируйте TensorFlow и зависимости:

In [1]:
!pip install -q h5py pyyaml

Получить образец набора данных

Мы будем обучать модель с использованием набора данных MNIST, чтобы продемонстрировать, как сохранять веса. Чтобы ускорить демонстрационный запуск, используйте только первые 1000 образцов:

In [2]:
from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras

tf.__version__
Out[2]:
'1.13.1'
In [3]:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

Определите модель

Давайте построим простую модель, чтобы продемонстрировать, как сохранять и загружать веса.

In [4]:
# Returns a short sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation=tf.nn.softmax)
  ])

  model.compile(optimizer=tf.keras.optimizers.Adam(),
                loss=tf.keras.losses.sparse_categorical_crossentropy,
                metrics=['accuracy'])

  return model

# Create a basic model instance
model = create_model()
model.summary()
WARNING:tensorflow:From e:\program files\python37\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From e:\program files\python37\lib\site-packages\tensorflow\python\keras\layers\core.py:143: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

Сохраняйте контрольные точки во время обучения

Основной вариант использования — автоматическое сохранение контрольных точек во время или в конце обучения. Это позволяет использовать обученную модель без переобучения или возобновлять обучение с того места, где оно было остановлено, в случае прерывания процесса обучения.

tf.keras.callbacks.ModelCheckpoint — это обратный вызов, выполняющий эту задачу. Обратный вызов принимает несколько параметров для настройки контрольной точки.

Использование обратного вызова контрольной точки

Обучите модель и передайте модель обратного вызова ModelCheckpoint:

In [5]:
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create checkpoint callback
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

model = create_model()

model.fit(train_images, train_labels,  epochs = 10,
          validation_data = (test_images,test_labels),
          callbacks = [cp_callback])  # pass callback to training
Train on 1000 samples, validate on 1000 samples
Epoch 1/10
 864/1000 [========================>.....] - ETA: 0s - loss: 1.2590 - acc: 0.6354
Epoch 00001: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
WARNING:tensorflow:From e:\program files\python37\lib\site-packages\tensorflow\python\keras\engine\network.py:1436: update_checkpoint_state (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.train.CheckpointManager to manage checkpoints rather than manually editing the Checkpoint proto.
1000/1000 [==============================] - 1s 791us/sample - loss: 1.1675 - acc: 0.6650 - val_loss: 0.7683 - val_acc: 0.7550
Epoch 2/10
 896/1000 [=========================>....] - ETA: 0s - loss: 0.4623 - acc: 0.8750
Epoch 00002: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 351us/sample - loss: 0.4515 - acc: 0.8750 - val_loss: 0.5316 - val_acc: 0.8340
Epoch 3/10
 800/1000 [=======================>......] - ETA: 0s - loss: 0.2790 - acc: 0.9287
Epoch 00003: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 358us/sample - loss: 0.2834 - acc: 0.9270 - val_loss: 0.4607 - val_acc: 0.8520
Epoch 4/10
 928/1000 [==========================>...] - ETA: 0s - loss: 0.2077 - acc: 0.9515
Epoch 00004: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 339us/sample - loss: 0.2046 - acc: 0.9530 - val_loss: 0.4370 - val_acc: 0.8540
Epoch 5/10
 896/1000 [=========================>....] - ETA: 0s - loss: 0.1578 - acc: 0.9710
Epoch 00005: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 350us/sample - loss: 0.1526 - acc: 0.9720 - val_loss: 0.4047 - val_acc: 0.8670
Epoch 6/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.1055 - acc: 0.9815
Epoch 00006: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 380us/sample - loss: 0.1062 - acc: 0.9830 - val_loss: 0.4201 - val_acc: 0.8560
Epoch 7/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0826 - acc: 0.9850
Epoch 00007: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 351us/sample - loss: 0.0824 - acc: 0.9850 - val_loss: 0.4168 - val_acc: 0.8660
Epoch 8/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0662 - acc: 0.9919
Epoch 00008: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 357us/sample - loss: 0.0655 - acc: 0.9910 - val_loss: 0.4021 - val_acc: 0.8700
Epoch 9/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0495 - acc: 0.9954
Epoch 00009: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 358us/sample - loss: 0.0491 - acc: 0.9950 - val_loss: 0.4168 - val_acc: 0.8640
Epoch 10/10
 896/1000 [=========================>....] - ETA: 0s - loss: 0.0401 - acc: 1.0000
Epoch 00010: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 354us/sample - loss: 0.0397 - acc: 1.0000 - val_loss: 0.4091 - val_acc: 0.8770
Out[5]:
<tensorflow.python.keras.callbacks.History at 0x1403b5f8>

Приведенный выше код создаст коллекцию файлов контрольных точек TensorFlow, которые обновляются в конце каждого цикла:

In [7]:
!dir {checkpoint_dir}
 驱动器 C 中的卷没有标签。
 卷的序列号是 CE2F-63AD

 C:\Users\Administrator\JupyterProject\training_1 的目录

2019/04/28  11:23    <DIR>          .
2019/04/28  11:23    <DIR>          ..
2019/04/28  11:23                71 checkpoint
2019/04/28  11:23         1,631,508 cp.ckpt.data-00000-of-00001
2019/04/28  11:23               648 cp.ckpt.index
               3 个文件      1,632,227 字节
               2 个目录 23,484,948,480 可用字节

Создайте новую необученную модель. При восстановлении модели только по весам у вас должна быть модель с той же архитектурой, что и оригинал. Поскольку архитектура модели одинакова, мы можем разделить веса (хотя и с разными экземплярами модели).

Теперь перестройте совершенно новую необученную модель и оцените ее с помощью тестового набора. Производительность необученной модели во многом случайна (точность около 10%):

In [8]:
model = create_model()

loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 81us/sample - loss: 2.3694 - acc: 0.0610
Untrained model, accuracy:  6.10%

Затем загрузите веса из контрольной точки и переоцените:

In [9]:
model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 46us/sample - loss: 0.4091 - acc: 0.8770
Restored model, accuracy: 87.70%

Параметры обратного вызова контрольной точки

Этот обратный вызов предоставляет несколько вариантов присвоения сгенерированной контрольной точке уникального имени, а также настройки частоты создания контрольной точки.

Обучите новую модель, сохраните контрольные точки каждые 5 эпох и задайте уникальное имя:

In [10]:
# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    period=5)

model = create_model()
model.fit(train_images, train_labels,
          epochs = 50, callbacks = [cp_callback],
          validation_data = (test_images,test_labels),
          verbose=0)
Epoch 00005: saving model to training_2/cp-0005.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00010: saving model to training_2/cp-0010.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00015: saving model to training_2/cp-0015.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00020: saving model to training_2/cp-0020.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00025: saving model to training_2/cp-0025.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00030: saving model to training_2/cp-0030.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00035: saving model to training_2/cp-0035.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00040: saving model to training_2/cp-0040.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00045: saving model to training_2/cp-0045.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00050: saving model to training_2/cp-0050.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
Out[10]:
<tensorflow.python.keras.callbacks.History at 0x16054e48>

Теперь взгляните на сгенерированные контрольные точки и выберите последнюю контрольную точку:

In [11]:
! dir {checkpoint_dir}
 驱动器 C 中的卷没有标签。
 卷的序列号是 CE2F-63AD

 C:\Users\Administrator\JupyterProject\training_2 的目录

2019/04/28  11:24    <DIR>          .
2019/04/28  11:24    <DIR>          ..
2019/04/28  11:24                81 checkpoint
2019/04/28  11:24         1,631,508 cp-0005.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0005.ckpt.index
2019/04/28  11:24         1,631,508 cp-0010.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0010.ckpt.index
2019/04/28  11:24         1,631,508 cp-0015.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0015.ckpt.index
2019/04/28  11:24         1,631,508 cp-0020.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0020.ckpt.index
2019/04/28  11:24         1,631,508 cp-0025.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0025.ckpt.index
2019/04/28  11:24         1,631,508 cp-0030.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0030.ckpt.index
2019/04/28  11:24         1,631,508 cp-0035.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0035.ckpt.index
2019/04/28  11:24         1,631,508 cp-0040.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0040.ckpt.index
2019/04/28  11:24         1,631,508 cp-0045.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0045.ckpt.index
2019/04/28  11:24         1,631,508 cp-0050.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0050.ckpt.index
              21 个文件     16,321,641 字节
               2 个目录 23,468,404,736 可用字节
In [13]:
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
Out[13]:
'training_2\\cp-0050.ckpt'

Примечание. Формат TensorFlow по умолчанию сохраняет только 5 последних контрольных точек.

Для тестирования сбросьте модель и загрузите последнюю контрольную точку:

In [14]:
model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 86us/sample - loss: 0.4830 - acc: 0.8770
Restored model, accuracy: 87.70%

Что это за файлы?

Приведенный выше код хранит веса в отформатированной контрольной точке коллекции файлов, которые содержат только обученные веса (в двоичном формате). К контрольно-пропускным пунктам относятся:

Один или несколько осколков, содержащих вес модели.
Индексный файл, указывающий, какие веса хранятся в каких осколках.

Если вы обучаете модель только на одной машине, у вас будет 1 сегмент с суффиксом .data-00000-of-00001.

Вручную сохранить веса

Выше вы видели, как загрузить веса в модель.

Вручную сохранить веса также легко, просто используйте метод Model.save_weights.

In [15]:
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000018D9D080>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 88us/sample - loss: 0.4830 - acc: 0.8770
Restored model, accuracy: 87.70%

сохранить всю модель

Всю модель можно сохранить в файл, содержащий значения веса, конфигурацию модели и даже конфигурацию оптимизатора. Таким образом, вы можете проверить модель и возобновить обучение позже из того же состояния, не обращаясь к исходному коду.

Полнофункциональные модели полезно сохранять в Keras, которые можно загрузить в TensorFlow.js, а затем обучить и запустить в веб-браузере.

Keras предоставляет базовый формат сохранения с использованием стандарта HDF5. Для нас сохраненная модель может рассматриваться как бинарный блоб.

In [16]:
model = create_model()

model.fit(train_images, train_labels, epochs=5)

# Save entire model to a HDF5 file
model.save('my_model.h5')
Epoch 1/5
1000/1000 [==============================] - 0s 322us/sample - loss: 1.1511 - acc: 0.6830
Epoch 2/5
1000/1000 [==============================] - 0s 235us/sample - loss: 0.4189 - acc: 0.8840s - loss: 0.4545 - acc: 0.8
Epoch 3/5
1000/1000 [==============================] - 0s 235us/sample - loss: 0.2864 - acc: 0.9230
Epoch 4/5
1000/1000 [==============================] - 0s 233us/sample - loss: 0.2147 - acc: 0.9410
Epoch 5/5
1000/1000 [==============================] - 0s 224us/sample - loss: 0.1642 - acc: 0.9660

Теперь воссоздайте модель из этого файла:

In [17]:
# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_14 (Dense)             (None, 512)               401920    
_________________________________________________________________
dropout_7 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_15 (Dense)             (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

Проверьте его точность:

In [18]:
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 99us/sample - loss: 0.4258 - acc: 0.8530
Restored model, accuracy: 85.30%

Этот трюк сохраняет все следующее:

  • Веса
  • Конфигурация модели (архитектура)
  • Конфигурация оптимизатора

Keras сохраняет модель, проверяя схему. В настоящее время он не может сохранять оптимизаторы TensorFlow (из tf.train). При использовании такого оптимизатора вам необходимо перекомпилировать модель после ее загрузки, чтобы ослабить состояние оптимизатора.

План последующего обучения¶

Это краткое руководство по сохранению и загрузке моделей с помощью tf.keras.

  • В руководстве по tf.keras подробно описано, как сохранять и загружать модели с помощью tf.keras.

  • См. Сохранение в Eager, чтобы узнать, как сохранить модель во время Eager Execution.

  • Руководство по сохранению и восстановлению содержит низкоуровневые сведения о сохранении TensorFlow.