В прошлой статье мы говорили об использовании метода в предварительно обученной сети, извлечении признаков, Сегодня мы обсуждаем другой метод, тонкую настройку модели, которая также является методом трансферного обучения.
Точная настройка модели
Зачем нужна тонкая настройка модели? Догадываемся и предыдущие эксперименты, у нас такой консенсус, чем меньше объем данных, чем больше узлов признаков в сети, тем легче это приведет к переобучению, на что мы конечно не надеемся, но для тех, кто заранее обученные модели, Также возможно, что работа, которую необходимо выполнить, не может быть выполнена хорошо в конце, поэтому нам также необходимо ее изменить.По этой причине нам нужно взять обученную модель и изменить ее более абстрактный слой в нем, то есть слой за сетевым слоем, а затем использовать новый классификатор, способный лучше решить предложенную выше проблему переобучения.
Шаги для тонкой настройки сети:
-
На основе обученной сети (базовой сети) добавлять пользовательские слои;
-
Заморозить базовую сеть и обучить только что добавленные слои;
-
Заморозить некоторые слои базовой сети, а другую часть можно обучить;
-
Совместно тренируйте размороженные слои и добавленные части.
Метод, упомянутый в предыдущей статье, может выполнить первые два шага, а затем мы увидим, как решить последние два шага. Здесь мы также хотим более подробно рассказать о проблемах, вызванных слишком большой настройкой количества слоев: по мере увеличения количества переменных слоев возрастает риск переобучения. Кроме того, явная настройка слоев в сети, которые распознают пиксели и линии, не так эффективна, как настройка слоев, распознающих уши, потому что уровень метода является более общим, независимо от того, распознает ли он кошек или таблицы, распознающие линии.
Код, который необходимо написать для выполнения этой задачи, также очень прост, то есть настроить модель как обучаемую, затем пройти каждый слой сети и установить, обучаема ли она для каждого слоя, до уровня layer_name, все предыдущие слои не поддаются обучению:
conv_base.trainable = True
set_trainable = False
for layer in conv_base.layers:
if layer.name == 'layer_name':
set_trainable = True
if set_trainable:
layer.trainable = True
else:
layer.trainable = False
Вот ключевая часть кода, старые правила и наконец весь код будет дан, давайте посмотрим на результаты:
Здесь необходимо обратить внимание на данные, они в начале нестабильны и быстро набирают высоту, поэтому данные ординаты не так хороши, но если мы посмотрим на более поздние данные, то точность обучения и точность проверки оба от 90% до 100%.Были некоторые колебания точности проверки, вызванные некоторым шумом в сети.Я не хочу заставлять их быть такими красивыми.Во-первых, потому что время обучения будет больше , а потому не думаю, что это особо нужно.Верхняя точка колебания и Нижняя точка находятся в пределах допустимого диапазона, и основное внимание следует уделить более важным вопросам.
На основе этой статьи и предыдущей подведем итоги:
-
В области компьютерного зрения производительность сверточных нейронных сетей очень хорошая, а в случае небольших наборов данных производительность очень хорошая.
-
Увеличение данных — хороший способ избежать переобучения, основной причиной которого может быть слишком мало данных или слишком много параметров.
-
Извлечение признаков может лучше применять существующие нейронные сети к небольшим наборам данных, а также может быть оптимизировано с помощью тонкой настройки.
Давайте посмотрим на код.Еще одно предложение.По возможности попробуйте использовать GPU для обучения сетевой модели.ЦП будет немного бессилен для решения этих проблем на данном этапе,и это займет много времени.Читатели также можно рассмотреть возможность уменьшения объема данных для увеличения скорости. , но во избежание переоснащения, пожалуйста, помните об этом виде проблемы, и это направление при возникновении проблемы (конечно, автор очень несчастен, и там не является простым в использовании GPU, поэтому очень мучительно ждать данных, чтобы рисовать скриншоты одно):
#!/usr/bin/env python3
import os
import time
import matplotlib.pyplot as plt
from keras import layers
from keras import models
from keras import optimizers
from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator
def cat():
base_dir = '/Users/renyuzhuo/Desktop/cat/dogs-vs-cats-small'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
train_datagen = ImageDataGenerator(
rescale=1. / 255,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')
test_datagen = ImageDataGenerator(rescale=1. / 255)
train_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(150, 150),
batch_size=20,
class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
validation_dir,
target_size=(150, 150),
batch_size=20,
class_mode='binary')
# 定义密集连接分类器
conv_base = VGG16(weights='imagenet',
include_top=False,
input_shape=(150, 150, 3))
conv_base.trainable = True
set_trainable = False
for layer in conv_base.layers:
if layer.name == 'block5_conv1':
set_trainable = True
if set_trainable:
layer.trainable = True
else:
layer.trainable = False
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation='sigmoid'))
conv_base.summary()
# 对模型进行配置
model.compile(loss='binary_crossentropy',
optimizer=optimizers.RMSprop(lr=1e-5),
metrics=['acc'])
# 对模型进行训练
history = model.fit_generator(
train_generator,
steps_per_epoch=100,
epochs=100,
validation_data=validation_generator,
validation_steps=50)
# 画图
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.show()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
if __name__ == "__main__":
time_start = time.time()
cat()
time_end = time.time()
print('Time Used: ', time_end - time_start)
Эта статья была впервые опубликована из публичного аккаунта: РАИС