- Оригинальный адрес:GAN with Keras: Application to Image Deblurring
- Оригинальный автор:Рафаэль Медек
- Перевод с:Программа перевода самородков
- Постоянная ссылка на эту статью:GitHub.com/rare earth/gold-no…
- Переводчик:luochen
- Корректор:SergeyChang mingxing47
В 2014 году Ян Гудфеллоу предложилГенеративно-состязательные сети(GAN), в этой статье основное внимание будет уделено использованиюKerasвыполнитьМодель устранения размытия изображения на основе состязательной генеративной сетиВесь код Keras находится вздесь.
Посмотреть исходный текстscientific publicationа такжеРеализация версии Pytorch.
Краткий обзор генеративно-состязательных сетей
В генеративно-состязательных сетях две сети обучаются друг против друга. Создайте модель черезСоздать поддельный вводВводящие в заблуждение дискриминационные модели. дискриминантная модельРазличать реальный и поддельный ввод.
Учебный процесс ГАН — Source
тренировался с3 основных шага:
- Используйте генеративные моделиСоздание поддельного ввода на основе шума.
- Используйте как настоящие, так и поддельные входные данныеОбучите дискриминационную модель.
- Обучите всю модель:Модель состоит из модели генерации, за которой следует конкатенированная дискриминантная модель.
Обратите внимание, что на третьем этапе веса дискриминационной модели больше не обновляются.
Причина объединения двух модельных сетей заключается в том, что невозможно напрямую передать выходные данные генеративной модели.Единственный критерий, который мы измеряем (результат генеративной модели), заключается в том, принимает ли модель сгенерированные образцы или нет..
Вот краткий обзор структуры GAN. Если вам трудно понять, вы можете обратиться к этомуexcellent introduction.
набор данных
Ян Гудфеллоу впервые применил модель GAN для генерации данных MNIST. В этом уроке мы используемГенеративные состязательные сети для устранения размытия изображения. Таким образом, входными данными для генеративной модели являются не шумы, а размытые изображения.
Набор данных принимаетНабор данных ГОПРО. вы можете скачатьЛайт(9 ГБ) илиполная версия(35 ГБ). это содержитиз нескольких просмотров улицискусственно размытое изображение. Наборы данных находятся в подпапках по сценам.
Начнем с того, что поместим картинки в папки A (размытые) и B (четкие). Структура этих A и B такая же, как и в оригинальной статье.pix2pix articleПоследовательный. я написалпользовательский скриптЧтобы выполнить эту задачу, следуйте README, чтобы использовать его.
Модель
Тренировочный процесс остается прежним. Во-первых, давайте посмотрим на структуру нейронной сети!
генеративная модель
Генеративные модели предназначены для воспроизведения четких изображений. Сетевая модель основана наОстаточная сеть (ResNet) блокировать. Он отслеживает эволюцию исходного размытого изображения. Эта статья основана наUNetверсия, я еще не реализовал ее. Обе структуры подходят для устранения размытия изображения.
Сетевая структура генеративной модели DeblurGAN — Source
Ядро применяется к исходному изображению с апсэмплингом9 блоков ResNet. Давайте посмотрим на реализацию Keras!
from keras.layers import Input, Conv2D, Activation, BatchNormalization
from keras.layers.merge import Add
from keras.layers.core import Dropout
def res_block(input, filters, kernel_size=(3,3), strides=(1,1), use_dropout=False):
"""
使用序贯(sequential) API 对 Keras Resnet 块进行实例化。
:param input: 输入张量
:param filters: 卷积核数目
:param kernel_size: 卷积核大小
:param strides: 卷积步幅大小
:param use_dropout: 布尔值,确定是否使用 dropout
:return: Keras 模型
"""
x = ReflectionPadding2D((1,1))(input)
x = Conv2D(filters=filters,
kernel_size=kernel_size,
strides=strides,)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
if use_dropout:
x = Dropout(0.5)(x)
x = ReflectionPadding2D((1,1))(x)
x = Conv2D(filters=filters,
kernel_size=kernel_size,
strides=strides,)(x)
x = BatchNormalization()(x)
# 输入和输出之间连接两个卷积层
merged = Add()([input, x])
return merged
Слои ResNet в основном представляют собой сверточные слои с добавленными входными и выходными данными для формирования окончательного вывода.
from keras.layers import Input, Activation, Add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.core import Lambda
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from layer_utils import ReflectionPadding2D, res_block
ngf = 64
input_nc = 3
output_nc = 3
input_shape_generator = (256, 256, input_nc)
n_blocks_gen = 9
def generator_model():
"""构建生成模型"""
# Current version : ResNet block
inputs = Input(shape=image_shape)
x = ReflectionPadding2D((3, 3))(inputs)
x = Conv2D(filters=ngf, kernel_size=(7,7), padding='valid')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
# Increase filter number
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
x = Conv2D(filters=ngf*mult*2, kernel_size=(3,3), strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
# 应用 9 ResNet blocks
mult = 2**n_downsampling
for i in range(n_blocks_gen):
x = res_block(x, ngf*mult, use_dropout=True)
# 减少卷积核到3个 (RGB)
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
x = Conv2DTranspose(filters=int(ngf * mult / 2), kernel_size=(3,3), strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = ReflectionPadding2D((3,3))(x)
x = Conv2D(filters=output_nc, kernel_size=(7,7), padding='valid')(x)
x = Activation('tanh')(x)
# Add direct connection from input to output and recenter to [-1, 1]
outputs = Add()([x, inputs])
outputs = Lambda(lambda z: z/2)(outputs)
model = Model(inputs=inputs, outputs=outputs, name='Generator')
return model
Keras реализует генеративные модели
Как и планировалось, 9 блоков ResNet применяются к версии ввода с повышенной дискретизацией. мы добавляемСоединение от входа к выходуИ разделите на 2, чтобы сохранить нормализованный вывод.
Это все для генеративных моделей, давайте посмотрим на дискриминационные модели.
дискриминантная модель
Цель дискриминационной модели состоит в том, чтобы определить, является ли входное изображение искусственным. Следовательно, структура дискриминационной модели является сверточной, ивывод - это одно значение.
from keras.layers import Input
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Dense, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model
ndf = 64
output_nc = 3
input_shape_discriminator = (256, 256, output_nc)
def discriminator_model():
"""构建判别模型."""
n_layers, use_sigmoid = 3, False
inputs = Input(shape=input_shape_discriminator)
x = Conv2D(filters=ndf, kernel_size=(4,4), strides=2, padding='same')(inputs)
x = LeakyReLU(0.2)(x)
nf_mult, nf_mult_prev = 1, 1
for n in range(n_layers):
nf_mult_prev, nf_mult = nf_mult, min(2**n, 8)
x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
nf_mult_prev, nf_mult = nf_mult, min(2**n_layers, 8)
x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=1, padding='same')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(filters=1, kernel_size=(4,4), strides=1, padding='same')(x)
if use_sigmoid:
x = Activation('sigmoid')(x)
x = Flatten()(x)
x = Dense(1024, activation='tanh')(x)
x = Dense(1, activation='sigmoid')(x)
model = Model(inputs=inputs, outputs=x, name='Discriminator')
return model
Keras реализует дискриминационную модель
Завершающим этапом является построение полной модели. Это ГАНособенностьзаключается в том, что вход представляет собой реальное изображение, а не шум. Таким образом, мы получаем прямую обратную связь на выходе генеративной модели.
from keras.layers import Input
from keras.models import Model
def generator_containing_discriminator_multiple_outputs(generator, discriminator):
inputs = Input(shape=image_shape)
generated_images = generator(inputs)
outputs = discriminator(generated_images)
model = Model(inputs=inputs, outputs=[generated_images, outputs])
return model
Давайте посмотрим, как воспользоваться этой особенностью, используя две функции потерь.
тренироваться
функция потерь
Мы извлекаем значения потерь на двух уровнях: один в конце генеративной модели, а другой в конце всей модели.
Первый заключается в расчете непосредственно на основе выходных данных генеративной модели.потеря восприятия. Это значение потерь гарантирует, что модель GAN ориентирована на задачу устранения размытия. Он сравнивает VGGпервая сверткавывод.
import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model
image_shape = (256, 256, 3)
def perceptual_loss(y_true, y_pred):
vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
loss_model.trainable = False
return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
Второе значение потерь заключается в вычислении выходных данных всей модели.Wasserstein loss. этоСредняя разница между двумя изображениями. Он известен улучшением сходимости состязательных генеративных сетей.
import keras.backend as K
def wasserstein_loss(y_true, y_pred):
return K.mean(y_true*y_pred)
тренировочный процесс
Первым шагом является загрузка данных и инициализация модели. Мы используем пользовательские функции для загрузки набора данных и добавления оптимизатора Adam в модель. Мы предотвращаем обучение дискриминативной модели, установив обучаемый параметр Keras.
# 载入数据集
data = load_images('./images/train', n_images)
y_train, x_train = data['B'], data['A']
# 初始化模型
g = generator_model()
d = discriminator_model()
d_on_g = generator_containing_discriminator_multiple_outputs(g, d)
# 初始化优化器
g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
# 编译模型
d.trainable = True
d.compile(optimizer=d_opt, loss=wasserstein_loss)
d.trainable = False
loss = [perceptual_loss, wasserstein_loss]
loss_weights = [100, 1]
d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
d.trainable = True
Затем мы начинаем итерацию, разделяя набор данных на пакеты.
for epoch in range(epoch_num):
print('epoch: {}/{}'.format(epoch, epoch_num))
print('batches: {}'.format(x_train.shape[0] / batch_size))
# 将图像随机划入不同批次
permutated_indexes = np.random.permutation(x_train.shape[0])
for index in range(int(x_train.shape[0] / batch_size)):
batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
image_blur_batch = x_train[batch_indexes]
image_full_batch = y_train[batch_indexes]
Наконец, мы последовательно обучаем генеративную модель и дискриминационную модель в соответствии с двумя потерями. Мы используем генеративные модели для создания поддельных входных данных. Мы обучаем дискриминационную модель различать поддельные и настоящие входные данные, а затем обучаем всю модель.
for epoch in range(epoch_num):
for index in range(batches):
# [Batch Preparation]
# 生成假输入
generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)
# 在真假输入上训练多次判别模型
for _ in range(critic_updates):
d_loss_real = d.train_on_batch(image_full_batch, output_true_batch)
d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
d.trainable = False
# Train generator only on discriminator's decision and generated images
d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])
d.trainable = True
вы можете обратиться кGithubСмотрите весь цикл!
какой-то материал
Я использую Deep Learning AMI (версия 3.0)AWS Instance(p2.xlarge) вНабор данных ГОПРОВ облегченной версии время обучения составляет около 5 часов (50 итераций).
Результат устранения размытия изображения
Слева направо: исходное изображение, размытое изображение, выход GAN.
Вышеприведенный результат является результатом нашего Keras Deblur GAN. Даже при сильном размытии сеть смогла уменьшить и сформировать более убедительное изображение. Фары четче и ветки четче.
Слева: тестовое изображение GOPRO, справа: выход GAN.
ОграничениеИндуцированные узоры на изображениях, что может быть вызвано использованием VGG в качестве потерь.
Слева: тестовое изображение GOPRO, справа: выход GAN.
Надеюсь, вам понравилась эта статья об устранении размытости изображения с помощью генеративно-состязательных моделей. Не стесняйтесь комментировать, подписывайтесь на нас илиСвяжитесь со мной.
Если вас интересует компьютерное зрение, вы можете ознакомиться с одной из наших предыдущих статей.Keras реализует поиск изображений на основе содержимого. Ниже приведен список ресурсов для генеративно-состязательных сетей.
Слева: тестовое изображение GOPRO, справа: выход GAN.
Список ресурсов для создания состязательных сетей.
-
NIPS 2016: генеративно-состязательные сети by Ian Goodfellow
-
ICCV 2017: Учебное пособие по состязательным генеративным сетям
-
Keras Реализация состязательных генеративных сетей by Eric Linder-Noren
-
Состязательная генерация списка сетевых ресурсов by deeplearning4j
-
Потрясающие состязательные генеративные сети by Holger Caesar
Программа перевода самородковэто сообщество, которое переводит высококачественные технические статьи из Интернета сНаггетсДелитесь статьями на английском языке на . Охват контентаAndroid,iOS,внешний интерфейс,задняя часть,блокчейн,продукт,дизайн,искусственный интеллектЕсли вы хотите видеть более качественные переводы, пожалуйста, продолжайте обращать вниманиеПрограмма перевода самородков,официальный Вейбо,Знай колонку.