Как загрузить несколько моделей в TensorFlow

Google машинное обучение искусственный интеллект TensorFlow

При использовании TensorFlow иногда нам нужно загрузить более одной модели, так как же загрузить несколько моделей?

оригинал:Не связывайтесь с ним. Look.com/2017/04/imp…


О TensorFlow можно сказать много. Но в этот раз я только описываю, как импортировать обученную модель (граф), потому что я не могу сделать это, чтобы импортировать вторую модель и использовать ее с первой моделью. Кроме того, этот импорт очень медленный, и я не хочу делать это во второй раз. С другой стороны, нецелесообразно помещать все в одну модель.

В этом уроке я расскажу, как сохранять и загружать модели, и сделаю еще один шаг, как загрузить несколько моделей.

Загрузите модель TensorFlow

Прежде чем представить загрузку нескольких моделей, давайте сначала представим, как загрузить одну модель, официальный документ:woohoo.tensorflow.org/programmers…

Во-первых, нам нужно создать модель, обучить ее и сохранить. Я не хочу вдаваться в подробности в этой части, просто сосредоточусь на том, как сохранить модель, и не забудьте назвать каждую операцию.

Код для создания модели, обучения и сохранения выглядит следующим образом:

import tensorflow as tf
### Linear Regression 线性回归###
# Input placeholders
x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')
# Model parameters 定义模型的权值参数
W1 = tf.Variable([0.1], tf.float32)
W2 = tf.Variable([0.1], tf.float32)
W3 = tf.Variable([0.1], tf.float32)
b = tf.Variable([0.1], tf.float32)

# Output 模型的输出
linear_model = tf.identity(W1 * x + W2 * x**2 + W3 * x**3 + b,
                           name='activation_opt')

# Loss 定义损失函数
loss = tf.reduce_sum(tf.square(linear_model - y), name='loss')
# Optimizer and training step 定义优化器运算
optimizer = tf.train.AdamOptimizer(0.001)
train = optimizer.minimize(loss, name='train_step')

# Remember output operation for later aplication
# Adding it to a collections for easy acces
# This is not required if you NAME your output operation
# 记得将输出操作添加到一个集合中,但如何你命名了输出操作,这一步可以省略
tf.add_to_collection("activation", linear_model)

## Start the session ##
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#  CREATE SAVER
saver = tf.train.Saver()

# Training loop 训练
for i in range(10000):
    sess.run(train, {x: data, y: expected})
    if i % 1000 == 0:
        # You can also save checkpoints using global_step variable
        saver.save(sess, "models/model_name", global_step=i)

# SAVE TensorFlow graph into path models/model_name
# 保存模型到指定路径并命名模型文件名字
saver.save(sess, "models/model_name")

Обратите внимание, вот первый пункт --Именование переменных и операций. Это сделано для того, чтобы некоторые весовые параметры можно было указать после загрузки модели.Если они не названы, то эти переменные будут автоматически названы как-то вроде «Placeholder_1». В более сложных моделях использование осциллографов является хорошей практикой, но здесь это не рассматривается.

Короче говоря, смысл == в том, что для того, чтобы иметь возможность вызывать весовые параметры или какие-то операции при загрузке модели, вы должны назвать их или поместить в коллекцию. ==

При сохранении модели эти файлы должны быть включены в папку, предназначенную для сохранения модели:model_name.index,model_name.metaи другие документы. При использованииcheckpointsСуффикс имени модели, также будут имена, содержащиеmodel_name-1000файл, где числа - соответствующие переменныеglobal_step, что является текущим количеством итераций обучения.

Теперь мы можем начать загрузку модели. Загрузить модель на самом деле очень просто, все, что нам нужно, это две функции:tf.train.import_meta_graphиsaver.restore(). Кроме того, это должно обеспечить правильное расположение пути сохранения модели. Кроме того, если мы хотим использовать модель на разных машинах, нам также необходимо задать параметры:clear_device=True.

Затем мы можем вызвать сохраненную операцию или параметр веса с ранее названным именем или именем сохраненной коллекции. Если используется область, то необходимо также включить имя области. При фактическом вызове этих операций вы также должны использовать что-то вроде{'PlaceholderName:0': data}введите заполнитель для , иначе произойдет ошибка.

Код для загрузки модели выглядит следующим образом:

sess = tf.Session()

# Import graph from the path and recover session
# 加载模型并恢复到会话中
saver = tf.train.import_meta_graph('models/model_name.meta', clear_devices=True)
saver.restore(sess, 'models/model_name')

# There are TWO options how to access the operation (choose one)
# 两种方法来调用指定的运算操作,选择其中一个都可以
  # FROM SAVED COLLECTION: 从保存的集合中调用
activation = tf.get_collection('activation')[0]
  # BY NAME: 采用命名的方式
activation = tf.get_default_graph.get_operation_by_name('activation_opt').outputs[0]

# Use imported graph for data
# You have to feed data as {'x:0': data}
# Don't forget on ':0' part!
# 采用加载的模型进行操作,不要忘记输入占位符
data = 50
result = sess.run(activation, {'x:0': data})
print(result)

несколько моделей

Выше описано, как загрузить одну модель, но как загрузить несколько моделей?

Если вы загрузите несколько моделей, загрузив одну модель, вы получите ошибку конфликта переменных и не будете работать. Причина этой проблемы из-за графика по умолчанию. Конфликт возникает из-за того, что мы загружаем все переменные в график по умолчанию, взятый текущим сеансом. Когда мы используем сеансы, мы можем передатьtf.Session(graph=MyGraph)чтобы указать использование другого уже созданного графа. Поэтому, если мы хотим загрузить несколько моделей, все, что нам нужно сделать, это загрузить их на разные графики и использовать в разных сеансах.

Здесь класс настраивается для выполнения операции загрузки модели указанного пути в локальный граф. Этот класс также обеспечиваетrunфункция для работы с входными данными с использованием загруженной модели. Этот класс полезен для меня, потому что я всегда помещаю вывод модели в коллекцию или называю ее.activation_optи назовите заполнитель ввода какx. Вы можете изменить и расширить этот класс в соответствии с реальными потребностями вашего приложения.

код показывает, как показано ниже:

import tensorflow as tf

class ImportGraph():
    """  Importing and running isolated TF graph """
    def __init__(self, loc):
        # Create local graph and use it in the session
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.graph.as_default():
            # Import saved model from location 'loc' into local graph
            # 从指定路径加载模型到局部图中
            saver = tf.train.import_meta_graph(loc + '.meta',
                                               clear_devices=True)
            saver.restore(self.sess, loc)
            # There are TWO options how to get activation operation:
            # 两种方式来调用运算或者参数
              # FROM SAVED COLLECTION:            
            self.activation = tf.get_collection('activation')[0]
              # BY NAME:
            self.activation = self.graph.get_operation_by_name('activation_opt').outputs[0]

    def run(self, data):
        """ Running the activation operation previously imported """
        # The 'x' corresponds to name of input placeholder
        return self.sess.run(self.activation, feed_dict={"x:0": data})
      
      
### Using the class ###
# 测试样例
data = 50         # random data
model = ImportGraph('models/model_name')
result = model.run(data)
print(result)

Суммировать

Загрузка нескольких моделей не так уж сложна, если вы понимаете механику TensorFlow. Приведенное выше решение может быть не идеальным, но оно простое и быстрое. Наконец, приведен пример кода, обобщающий весь процесс, который находится в блокноте Jupyter, и адрес кода выглядит следующим образом:

gist.GitHub.com/br ETA01/post20…


Наконец, дайте адреса github нескольких примеров кода в статье:

  1. Code for creating, training and saving TensorFlow model.
  2. Importing and using TensorFlow graph (model)
  3. Class for importing multiple TensorFlow graphs.
  4. Example of importing multiple TensorFlow modules

Добро пожаловать, чтобы обратить внимание на мою общедоступную учетную запись WeChat - машинное обучение и компьютерное зрение или отсканируйте QR-код ниже, оставьте сообщение в фоновом режиме, поделитесь со мной своими предложениями и мнениями, исправьте возможные ошибки в статье, и давайте общаться, учиться и прогрессируйте вместе!

Рекомендуемое чтение

1.Серия «Введение в машинное обучение» (1) — Обзор машинного обучения (часть 1)

2.Серия «Введение в машинное обучение» (2) — Обзор машинного обучения (часть 2)

3.[Обучающая серия GAN] Первое знакомство с GAN

4.[GAN Learning Series 2] Происхождение GAN

5.Библиотека Google GAN с открытым исходным кодом — TFGAN