Я искал много информации о сохранении и загрузке модели Tensorflow за последние два дня и обнаружил, что много информации о формате модели контрольных точек, в то время как в последнем формате модели SavedModel информации меньше. как TensorFlow сохраняет модель SavedModel и загружает ее.
Зачем использовать формат SavedModel? Его главное преимущество в том, что SaveModel не зависит от языка, например, вы можете использовать язык python для обучения модели, а затем очень удобно загружать модель в Java. Конечно, это не означает, что формат модели КПП нельзя сделать, но хлопот больше, когда она кросс-языковая. Кроме того, если вы используете сервер Tensorflow Serving для развертывания модели, вы должны выбрать формат SavedModel.
Что содержит SavedModel?
Относительно полная модель SavedModel содержит следующее:
assets/
assets.extra/
variables/
variables.data-*****-of-*****
variables.index
saved_model.pb
save_model.pb — это MetaGraphDef, который содержит структуру графа. Папка переменных содержит веса, полученные во время обучения. В папку с ресурсами можно добавлять внешние файлы, которые могут понадобиться. В assets.extra библиотека может добавлять свои конкретные ресурсы.
MetaGraph — это граф потока данных, а также связанные с ним переменные, активы и подписи. MetaGraphDef — это протокольный буфер, представляющий MetaGraph.
assets и assets.extra необязательны.Например, модель, сохраненная в примере кода в этой статье, содержит только следующее содержимое:
variables/
variables.data-*****-of-*****
variables.index
saved_model.pb
спасти
Для простоты мы используем очень простой код распознавания рукописного ввода в качестве примера, код выглядит следующим образом:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), 1))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
tf.global_variables_initializer().run()
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
train_step.run({x: batch_xs, y_: batch_ys})
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
Этот код очень прост, простая модель регрессии нисходящего градиента. Чтобы сохранить модель, нам также нужно внести небольшое изменение в код.
добавить наименование
Добавьте имена для входных и выходных операций, чтобы мы могли легко обращаться к операциям по имени при загрузке. Измените приведенный выше оператор присваивания x на:
x = tf.placeholder(tf.float32, [None, 784], name="myInput")
Конечно, вы также можете не давать имя. Система даст имя по умолчанию. Например, система x выше даст «заполнитель». Когда нам нужно сослаться на несколько операций, дайте каждой операции имя, которое это действительно удобно для нас, чтобы использовать позже.
Вы также можете использовать tf.identity для имени тензора, например, добавив строку в приведенный выше код:
tf.identity(y, name="myOutput")
Также дайте выходу имя.
сохранить в файл
Самый простой способ сохранения — использовать функцию tf.saved_model.simple_save, код такой:
tf.saved_model.simple_save(sess,
"./model",
inputs={"myInput": x},
outputs={"myOutput": y})
Этот код сохраняет модель в каталоге **./model**.
Конечно, вы также можете использовать более сложный способ записи:
builder = tf.saved_model.builder.SavedModelBuilder("./model")
signature = predict_signature_def(inputs={'myInput': x},
outputs={'myOutput': y})
builder.add_meta_graph_and_variables(sess=sess,
tags=[tag_constants.SERVING],
signature_def_map={'predict': signature})
builder.save()
Вроде новый код не сильно отличается, разница в том, что можно самому определять теги, что более гибко в определении сигнатур. Давайте поговорим об использовании тегов здесь.
Модель может содержать разные MetaGraphDef, когда вам нужно несколько MetaGraphDef? Возможно, вы хотите сохранить версии графа для процессора и графического процессора или хотите различать обучающую и релизную версии. В настоящее время теги можно использовать для различения разных MetaGraphDefs.При загрузке различные расчетные графики модели могут быть загружены в соответствии с тегами.
В методе simple_save система выдаст тег по умолчанию: «служить», также можно использовать константу tag_constants.SERVING.
нагрузка
Для разных языков процесс загрузки несколько похож, и вот python в качестве примера:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ["serve"], "./model")
graph = tf.get_default_graph()
input = np.expand_dims(mnist.test.images[0], 0)
x = sess.graph.get_tensor_by_name('myInput:0')
y = sess.graph.get_tensor_by_name('myOutput:0')
batch_xs, batch_ys = mnist.test.next_batch(1)
scores = sess.run(y,
feed_dict={x: batch_xs})
print("predict: %d, actual: %d" % (np.argmax(scores, 1), np.argmax(batch_ys, 1)))
Следует отметить, что вторым параметром в функции загрузки является тег, который должен совпадать с параметром при сохранении модели, а третьим параметром является папка, в которой сохраняется модель.
После вызова функции загрузки загружается не только график расчета, но и значения переменных, полученные во время обучения, С помощью этих двух мы можем вызвать его, чтобы вывести новые данные теста.
резюме
После того, как процесс пройдет гладко, вы обнаружите, что сохранение и загрузка SavedModel на самом деле очень просты. Но в процессе исследования я сделал много обходных путей.Основная причина в том, что большая часть искомых данных сейчас по-прежнему использует tf.train.Saver() для сохранения модели, а некоторые используют tf.gfile.FastGFile для сериализации модели. схема модели. .
Полный код этой статьи см. по адресу:GitHub.com/mogo Web/голодание…
Надеюсь, эта статья была вам полезна, спасибо за прочтение!