[Перевод] Реализация GAN в Keras: создание приложений для устранения размытия изображений

искусственный интеллект Программа перевода самородков GitHub Keras
[Перевод] Реализация GAN в Keras: создание приложений для устранения размытия изображений

В 2014 году Ян Гудфеллоу предложилГенеративно-состязательные сети(GAN), в этой статье основное внимание будет уделено использованиюKerasвыполнитьМодель устранения размытия изображения на основе состязательной генеративной сетиВесь код Keras находится вздесь.

Посмотреть исходный текстscientific publicationа такжеРеализация версии Pytorch.


Краткий обзор генеративно-состязательных сетей

В генеративно-состязательных сетях две сети обучаются друг против друга. Создайте модель черезСоздать поддельный вводВводящие в заблуждение дискриминационные модели. дискриминантная модельРазличать реальный и поддельный ввод.

Учебный процесс ГАН — Source

тренировался с3 основных шага:

  • Используйте генеративные моделиСоздание поддельного ввода на основе шума.
  • Используйте как настоящие, так и поддельные входные данныеОбучите дискриминационную модель.
  • Обучите всю модель:Модель состоит из модели генерации, за которой следует конкатенированная дискриминантная модель.

Обратите внимание, что на третьем этапе веса дискриминационной модели больше не обновляются.

Причина объединения двух модельных сетей заключается в том, что невозможно напрямую передать выходные данные генеративной модели.Единственный критерий, который мы измеряем (результат генеративной модели), заключается в том, принимает ли модель сгенерированные образцы или нет..

Вот краткий обзор структуры GAN. Если вам трудно понять, вы можете обратиться к этомуexcellent introduction.


набор данных

Ян Гудфеллоу впервые применил модель GAN для генерации данных MNIST. В этом уроке мы используемГенеративные состязательные сети для устранения размытия изображения. Таким образом, входными данными для генеративной модели являются не шумы, а размытые изображения.

Набор данных принимаетНабор данных ГОПРО. вы можете скачатьЛайт(9 ГБ) илиполная версия(35 ГБ). это содержитиз нескольких просмотров улицискусственно размытое изображение. Наборы данных находятся в подпапках по сценам.

Начнем с того, что поместим картинки в папки A (размытые) и B (четкие). Структура этих A и B такая же, как и в оригинальной статье.pix2pix articleПоследовательный. я написалпользовательский скриптЧтобы выполнить эту задачу, следуйте README, чтобы использовать его.


Модель

Тренировочный процесс остается прежним. Во-первых, давайте посмотрим на структуру нейронной сети!

генеративная модель

Генеративные модели предназначены для воспроизведения четких изображений. Сетевая модель основана наОстаточная сеть (ResNet) блокировать. Он отслеживает эволюцию исходного размытого изображения. Эта статья основана наUNetверсия, я еще не реализовал ее. Обе структуры подходят для устранения размытия изображения.

Сетевая структура генеративной модели DeblurGAN — Source

Ядро применяется к исходному изображению с апсэмплингом9 блоков ResNet. Давайте посмотрим на реализацию Keras!

from keras.layers import Input, Conv2D, Activation, BatchNormalization
from keras.layers.merge import Add
from keras.layers.core import Dropout

def res_block(input, filters, kernel_size=(3,3), strides=(1,1), use_dropout=False):
    """
    使用序贯(sequential) API 对 Keras Resnet 块进行实例化。
    :param input: 输入张量
    :param filters: 卷积核数目
    :param kernel_size: 卷积核大小
    :param strides: 卷积步幅大小
    :param use_dropout: 布尔值,确定是否使用 dropout
    :return: Keras 模型
    """
    x = ReflectionPadding2D((1,1))(input)
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides,)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    if use_dropout:
        x = Dropout(0.5)(x)

    x = ReflectionPadding2D((1,1))(x)
    x = Conv2D(filters=filters,
                kernel_size=kernel_size,
                strides=strides,)(x)
    x = BatchNormalization()(x)

    # 输入和输出之间连接两个卷积层
    merged = Add()([input, x])
    return merged

Слои ResNet в основном представляют собой сверточные слои с добавленными входными и выходными данными для формирования окончательного вывода.

from keras.layers import Input, Activation, Add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.core import Lambda
from keras.layers.normalization import BatchNormalization
from keras.models import Model

from layer_utils import ReflectionPadding2D, res_block

ngf = 64
input_nc = 3
output_nc = 3
input_shape_generator = (256, 256, input_nc)
n_blocks_gen = 9


def generator_model():
    """构建生成模型"""
    # Current version : ResNet block
    inputs = Input(shape=image_shape)

    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7,7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Increase filter number
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2**i
        x = Conv2D(filters=ngf*mult*2, kernel_size=(3,3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # 应用 9 ResNet blocks
    mult = 2**n_downsampling
    for i in range(n_blocks_gen):
        x = res_block(x, ngf*mult, use_dropout=True)

    # 减少卷积核到3个 (RGB)
    for i in range(n_downsampling):
        mult = 2**(n_downsampling - i)
        x = Conv2DTranspose(filters=int(ngf * mult / 2), kernel_size=(3,3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    x = ReflectionPadding2D((3,3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7,7), padding='valid')(x)
    x = Activation('tanh')(x)

    # Add direct connection from input to output and recenter to [-1, 1]
    outputs = Add()([x, inputs])
    outputs = Lambda(lambda z: z/2)(outputs)

    model = Model(inputs=inputs, outputs=outputs, name='Generator')
    return model

Keras реализует генеративные модели

Как и планировалось, 9 блоков ResNet применяются к версии ввода с повышенной дискретизацией. мы добавляемСоединение от входа к выходуИ разделите на 2, чтобы сохранить нормализованный вывод.

Это все для генеративных моделей, давайте посмотрим на дискриминационные модели.

дискриминантная модель

Цель дискриминационной модели состоит в том, чтобы определить, является ли входное изображение искусственным. Следовательно, структура дискриминационной модели является сверточной, ивывод - это одно значение.

from keras.layers import Input
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Dense, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model

ndf = 64
output_nc = 3
input_shape_discriminator = (256, 256, output_nc)


def discriminator_model():
    """构建判别模型."""
    n_layers, use_sigmoid = 3, False
    inputs = Input(shape=input_shape_discriminator)

    x = Conv2D(filters=ndf, kernel_size=(4,4), strides=2, padding='same')(inputs)
    x = LeakyReLU(0.2)(x)

    nf_mult, nf_mult_prev = 1, 1
    for n in range(n_layers):
        nf_mult_prev, nf_mult = nf_mult, min(2**n, 8)
        x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    nf_mult_prev, nf_mult = nf_mult, min(2**n_layers, 8)
    x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(filters=1, kernel_size=(4,4), strides=1, padding='same')(x)
    if use_sigmoid:
        x = Activation('sigmoid')(x)

    x = Flatten()(x)
    x = Dense(1024, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=x, name='Discriminator')
    return model

Keras реализует дискриминационную модель

Завершающим этапом является построение полной модели. Это ГАНособенностьзаключается в том, что вход представляет собой реальное изображение, а не шум. Таким образом, мы получаем прямую обратную связь на выходе генеративной модели.

from keras.layers import Input
from keras.models import Model

def generator_containing_discriminator_multiple_outputs(generator, discriminator):
    inputs = Input(shape=image_shape)
    generated_images = generator(inputs)
    outputs = discriminator(generated_images)
    model = Model(inputs=inputs, outputs=[generated_images, outputs])
    return model

Давайте посмотрим, как воспользоваться этой особенностью, используя две функции потерь.


тренироваться

функция потерь

Мы извлекаем значения потерь на двух уровнях: один в конце генеративной модели, а другой в конце всей модели.

Первый заключается в расчете непосредственно на основе выходных данных генеративной модели.потеря восприятия. Это значение потерь гарантирует, что модель GAN ориентирована на задачу устранения размытия. Он сравнивает VGGпервая сверткавывод.

import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model

image_shape = (256, 256, 3)

def perceptual_loss(y_true, y_pred):
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))

Второе значение потерь заключается в вычислении выходных данных всей модели.Wasserstein loss. этоСредняя разница между двумя изображениями. Он известен улучшением сходимости состязательных генеративных сетей.

import keras.backend as K

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true*y_pred)

тренировочный процесс

Первым шагом является загрузка данных и инициализация модели. Мы используем пользовательские функции для загрузки набора данных и добавления оптимизатора Adam в модель. Мы предотвращаем обучение дискриминативной модели, установив обучаемый параметр Keras.

# 载入数据集
data = load_images('./images/train', n_images)
y_train, x_train = data['B'], data['A']

# 初始化模型
g = generator_model()
d = discriminator_model()
d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

# 初始化优化器
g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

# 编译模型
d.trainable = True
d.compile(optimizer=d_opt, loss=wasserstein_loss)
d.trainable = False
loss = [perceptual_loss, wasserstein_loss]
loss_weights = [100, 1]
d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
d.trainable = True

Затем мы начинаем итерацию, разделяя набор данных на пакеты.

for epoch in range(epoch_num):
  print('epoch: {}/{}'.format(epoch, epoch_num))
  print('batches: {}'.format(x_train.shape[0] / batch_size))

  # 将图像随机划入不同批次
  permutated_indexes = np.random.permutation(x_train.shape[0])

  for index in range(int(x_train.shape[0] / batch_size)):
      batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
      image_blur_batch = x_train[batch_indexes]
      image_full_batch = y_train[batch_indexes]

Наконец, мы последовательно обучаем генеративную модель и дискриминационную модель в соответствии с двумя потерями. Мы используем генеративные модели для создания поддельных входных данных. Мы обучаем дискриминационную модель различать поддельные и настоящие входные данные, а затем обучаем всю модель.

for epoch in range(epoch_num):
  for index in range(batches):
    # [Batch Preparation]

    # 生成假输入
    generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)
    
    # 在真假输入上训练多次判别模型
    for _ in range(critic_updates):
        d_loss_real = d.train_on_batch(image_full_batch, output_true_batch)
        d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
        d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

    d.trainable = False
    # Train generator only on discriminator's decision and generated images
    d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])

    d.trainable = True

вы можете обратиться кGithubСмотрите весь цикл!

какой-то материал

Я использую Deep Learning AMI (версия 3.0)AWS Instance(p2.xlarge) вНабор данных ГОПРОВ облегченной версии время обучения составляет около 5 часов (50 итераций).

Результат устранения размытия изображения

Слева направо: исходное изображение, размытое изображение, выход GAN.

Вышеприведенный результат является результатом нашего Keras Deblur GAN. Даже при сильном размытии сеть смогла уменьшить и сформировать более убедительное изображение. Фары четче и ветки четче.

Слева: тестовое изображение GOPRO, справа: выход GAN.

ОграничениеИндуцированные узоры на изображениях, что может быть вызвано использованием VGG в качестве потерь.

Слева: тестовое изображение GOPRO, справа: выход GAN.

Надеюсь, вам понравилась эта статья об устранении размытости изображения с помощью генеративно-состязательных моделей. Не стесняйтесь комментировать, подписывайтесь на нас илиСвяжитесь со мной.

Если вас интересует компьютерное зрение, вы можете ознакомиться с одной из наших предыдущих статей.Keras реализует поиск изображений на основе содержимого. Ниже приведен список ресурсов для генеративно-состязательных сетей.

Слева: тестовое изображение GOPRO, справа: выход GAN.

Список ресурсов для создания состязательных сетей.


Программа перевода самородковэто сообщество, которое переводит высококачественные технические статьи из Интернета сНаггетсДелитесь статьями на английском языке на . Охват контентаAndroid,iOS,внешний интерфейс,задняя часть,блокчейн,продукт,дизайн,искусственный интеллектЕсли вы хотите видеть более качественные переводы, пожалуйста, продолжайте обращать вниманиеПрограмма перевода самородков,официальный Вейбо,Знай колонку.