При использовании TensorFlow иногда нам нужно загрузить более одной модели, так как же загрузить несколько моделей?
оригинал:Не связывайтесь с ним. Look.com/2017/04/imp…
О TensorFlow можно сказать много. Но в этот раз я только описываю, как импортировать обученную модель (граф), потому что я не могу сделать это, чтобы импортировать вторую модель и использовать ее с первой моделью. Кроме того, этот импорт очень медленный, и я не хочу делать это во второй раз. С другой стороны, нецелесообразно помещать все в одну модель.
В этом уроке я расскажу, как сохранять и загружать модели, и сделаю еще один шаг, как загрузить несколько моделей.
Загрузите модель TensorFlow
Прежде чем представить загрузку нескольких моделей, давайте сначала представим, как загрузить одну модель, официальный документ:woohoo.tensorflow.org/programmers…
Во-первых, нам нужно создать модель, обучить ее и сохранить. Я не хочу вдаваться в подробности в этой части, просто сосредоточусь на том, как сохранить модель, и не забудьте назвать каждую операцию.
Код для создания модели, обучения и сохранения выглядит следующим образом:
import tensorflow as tf
### Linear Regression 线性回归###
# Input placeholders
x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')
# Model parameters 定义模型的权值参数
W1 = tf.Variable([0.1], tf.float32)
W2 = tf.Variable([0.1], tf.float32)
W3 = tf.Variable([0.1], tf.float32)
b = tf.Variable([0.1], tf.float32)
# Output 模型的输出
linear_model = tf.identity(W1 * x + W2 * x**2 + W3 * x**3 + b,
name='activation_opt')
# Loss 定义损失函数
loss = tf.reduce_sum(tf.square(linear_model - y), name='loss')
# Optimizer and training step 定义优化器运算
optimizer = tf.train.AdamOptimizer(0.001)
train = optimizer.minimize(loss, name='train_step')
# Remember output operation for later aplication
# Adding it to a collections for easy acces
# This is not required if you NAME your output operation
# 记得将输出操作添加到一个集合中,但如何你命名了输出操作,这一步可以省略
tf.add_to_collection("activation", linear_model)
## Start the session ##
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# CREATE SAVER
saver = tf.train.Saver()
# Training loop 训练
for i in range(10000):
sess.run(train, {x: data, y: expected})
if i % 1000 == 0:
# You can also save checkpoints using global_step variable
saver.save(sess, "models/model_name", global_step=i)
# SAVE TensorFlow graph into path models/model_name
# 保存模型到指定路径并命名模型文件名字
saver.save(sess, "models/model_name")
Обратите внимание, вот первый пункт --Именование переменных и операций. Это сделано для того, чтобы некоторые весовые параметры можно было указать после загрузки модели.Если они не названы, то эти переменные будут автоматически названы как-то вроде «Placeholder_1». В более сложных моделях использование осциллографов является хорошей практикой, но здесь это не рассматривается.
Короче говоря, смысл == в том, что для того, чтобы иметь возможность вызывать весовые параметры или какие-то операции при загрузке модели, вы должны назвать их или поместить в коллекцию. ==
При сохранении модели эти файлы должны быть включены в папку, предназначенную для сохранения модели:model_name.index
,model_name.meta
и другие документы. При использованииcheckpoints
Суффикс имени модели, также будут имена, содержащиеmodel_name-1000
файл, где числа - соответствующие переменныеglobal_step
, что является текущим количеством итераций обучения.
Теперь мы можем начать загрузку модели. Загрузить модель на самом деле очень просто, все, что нам нужно, это две функции:tf.train.import_meta_graph
иsaver.restore()
. Кроме того, это должно обеспечить правильное расположение пути сохранения модели. Кроме того, если мы хотим использовать модель на разных машинах, нам также необходимо задать параметры:clear_device=True
.
Затем мы можем вызвать сохраненную операцию или параметр веса с ранее названным именем или именем сохраненной коллекции. Если используется область, то необходимо также включить имя области. При фактическом вызове этих операций вы также должны использовать что-то вроде{'PlaceholderName:0': data}
введите заполнитель для , иначе произойдет ошибка.
Код для загрузки модели выглядит следующим образом:
sess = tf.Session()
# Import graph from the path and recover session
# 加载模型并恢复到会话中
saver = tf.train.import_meta_graph('models/model_name.meta', clear_devices=True)
saver.restore(sess, 'models/model_name')
# There are TWO options how to access the operation (choose one)
# 两种方法来调用指定的运算操作,选择其中一个都可以
# FROM SAVED COLLECTION: 从保存的集合中调用
activation = tf.get_collection('activation')[0]
# BY NAME: 采用命名的方式
activation = tf.get_default_graph.get_operation_by_name('activation_opt').outputs[0]
# Use imported graph for data
# You have to feed data as {'x:0': data}
# Don't forget on ':0' part!
# 采用加载的模型进行操作,不要忘记输入占位符
data = 50
result = sess.run(activation, {'x:0': data})
print(result)
несколько моделей
Выше описано, как загрузить одну модель, но как загрузить несколько моделей?
Если вы загрузите несколько моделей, загрузив одну модель, вы получите ошибку конфликта переменных и не будете работать. Причина этой проблемы из-за графика по умолчанию. Конфликт возникает из-за того, что мы загружаем все переменные в график по умолчанию, взятый текущим сеансом. Когда мы используем сеансы, мы можем передатьtf.Session(graph=MyGraph)
чтобы указать использование другого уже созданного графа. Поэтому, если мы хотим загрузить несколько моделей, все, что нам нужно сделать, это загрузить их на разные графики и использовать в разных сеансах.
Здесь класс настраивается для выполнения операции загрузки модели указанного пути в локальный граф. Этот класс также обеспечиваетrun
функция для работы с входными данными с использованием загруженной модели. Этот класс полезен для меня, потому что я всегда помещаю вывод модели в коллекцию или называю ее.activation_opt
и назовите заполнитель ввода какx
. Вы можете изменить и расширить этот класс в соответствии с реальными потребностями вашего приложения.
код показывает, как показано ниже:
import tensorflow as tf
class ImportGraph():
""" Importing and running isolated TF graph """
def __init__(self, loc):
# Create local graph and use it in the session
self.graph = tf.Graph()
self.sess = tf.Session(graph=self.graph)
with self.graph.as_default():
# Import saved model from location 'loc' into local graph
# 从指定路径加载模型到局部图中
saver = tf.train.import_meta_graph(loc + '.meta',
clear_devices=True)
saver.restore(self.sess, loc)
# There are TWO options how to get activation operation:
# 两种方式来调用运算或者参数
# FROM SAVED COLLECTION:
self.activation = tf.get_collection('activation')[0]
# BY NAME:
self.activation = self.graph.get_operation_by_name('activation_opt').outputs[0]
def run(self, data):
""" Running the activation operation previously imported """
# The 'x' corresponds to name of input placeholder
return self.sess.run(self.activation, feed_dict={"x:0": data})
### Using the class ###
# 测试样例
data = 50 # random data
model = ImportGraph('models/model_name')
result = model.run(data)
print(result)
Суммировать
Загрузка нескольких моделей не так уж сложна, если вы понимаете механику TensorFlow. Приведенное выше решение может быть не идеальным, но оно простое и быстрое. Наконец, приведен пример кода, обобщающий весь процесс, который находится в блокноте Jupyter, и адрес кода выглядит следующим образом:
gist.GitHub.com/br ETA01/post20…
Наконец, дайте адреса github нескольких примеров кода в статье:
- Code for creating, training and saving TensorFlow model.
- Importing and using TensorFlow graph (model)
- Class for importing multiple TensorFlow graphs.
- Example of importing multiple TensorFlow modules
Добро пожаловать, чтобы обратить внимание на мою общедоступную учетную запись WeChat - машинное обучение и компьютерное зрение или отсканируйте QR-код ниже, оставьте сообщение в фоновом режиме, поделитесь со мной своими предложениями и мнениями, исправьте возможные ошибки в статье, и давайте общаться, учиться и прогрессируйте вместе!
Рекомендуемое чтение
1.Серия «Введение в машинное обучение» (1) — Обзор машинного обучения (часть 1)
2.Серия «Введение в машинное обучение» (2) — Обзор машинного обучения (часть 2)
3.[Обучающая серия GAN] Первое знакомство с GAN