Глубоко интересно | 16 удивительных WGAN

искусственный интеллект JavaScript
Глубоко интересно | 16 удивительных WGAN

Введение

На основе DCGAN познакомить с принципом и реализацией WGAN, а в дальнейшем практиковаться на наборах данных LFW и CelebA.

проблема

GAN столкнулся со следующими проблемами и вызовами

  • Трудно обучать, требует тщательного проектирования структуры модели и тщательного согласования степеней обучения G и D.
  • Функции потерь G и D не указывают на процесс обучения, поскольку в них отсутствует значимая метрика, связанная с качеством сгенерированных изображений.
  • коллапс режима, результирующее изображение выглядит реальным, но ему не хватает разнообразия

принцип

По сравнению с традиционной GAN, WGAN вносит только следующие три простых изменения.

  • D удалить сигмовидную из последнего слоя
  • Потери G и D не берут лог(sigmoid_cross_entropy_with_logits)
  • После каждого обновления параметров D усекайте его абсолютное значение не более чем до фиксированной константы c, то есть отсечение градиента (предыдущая работа); или используйте штраф за градиент, то есть штраф за градиент (более поздняя работа).

Функция потерь G изначально

\mathbb{E}_{z\sim p_z(z)}[\log(1-D(G(z)))]

В результате, если D слишком хорошо обучен, G не выучит эффективные градиенты.

Однако, если D недостаточно хорошо обучен, G также не выучит эффективные градиенты.

Точно так же, как если полиция слишком хороша, они сразу убьют вора, но если полиция плохая, они не могут заставить вора стать сильнее.

Следовательно, приведенная выше функция потерь делает обучение GAN особенно нестабильным, и необходимо тщательно координировать уровни обучения G и D.

Авторы GAN предложили другую версию функции потерь G, так называемую-logD trick

\mathbb{E}_{z\sim p_z(z)}[-\log(D(G(z)))]

G необходимо минимизировать указанную выше функцию потерь, что эквивалентно минимизации следующей функции потерь

KL(P_g||P_{data})-2JS(P_{data}||P_g)

Первое - это расхождение KL (расхождение Кульбака-Лейблера).

KL(P_1||P_2)=\mathbb{E}_{x\sim P_1}\log\frac{P_1}{P_2}

Последнее - JS Divergence (Дивергенция Дженсена-Шеннона).

\frac{1}{2}KL(P_1||\frac{P_1+P_2}{2})+\frac{1}{2}KL(P_2||\frac{P_1+P_2}{2})

Оба могут быть использованы для измерения расстояния между двумя распределениями.Чем меньше расстояние, тем более похожи два распределения.

Следовательно, вышеприведенная функция потерь, с одной стороны, должна уменьшить расхождение KL, с другой стороны, увеличить расхождение JS, что приведет к нестабильному обучению.

В дополнение к этому асимметрия расхождения KL приводит к разным штрафам для следующих двух случаев.

  • Г порождает недоверчивую картину, т. е. неточность, наказание
  • G генерирует изображения, похожие на реальные изображения, т. е. отсутствие разнообразия и меньший штраф.

В результате G имеет тенденцию генерировать некоторые уверенные, но похожие изображения и не осмеливается легко пытаться генерировать новые неуверенные изображения, так называемая проблема коллапса моды.

Три изменения, внесенные WGAN, решают проблемы сложности и нестабильности обучения GAN, коллапса режима и т. Д., И чем меньше функция потерь G, тем выше качество соответствующего генерируемого изображения.

Процесс обучения WGAN выглядит следующим образом: Штраф за градиент заставляет D соответствовать непрерывному условию 1. Подробные принципы и подробности см. в соответствующих документах для дальнейшего понимания.

WGAN训练过程

Некоторые экспериментальные результаты в статье следующие: хотя WGAN требует больше времени на обучение, сходимость более стабильна.

WGAN论文实验结果

Что еще более важно, WGAN обеспечивает более стабильную структуру GAN. G в DCGAN рухнет, если пакетная нормализация будет удалена, но WGAN не имеет этого ограничения.

Если WGAN реализована со структурой Deep Convolutional, результаты аналогичны DCGAN. Но в рамках WGAN G и D могут быть реализованы с более глубокими и сложными сетями, такими как ResNet (АР Вест V.org/ABS/1512.03…), чтобы добиться лучшего эффекта генерации

данные

Или два набора данных лиц, которые использовались ранее

  • ЛЧВ:vis-www.cs.umass.edu/lfw/, Labeled Faces in the Wild, в том числе 1680 человек с общим количеством изображений более 1,3 Вт.
  • CelebA:MM lab.IE. Толстый черный ящик. Квота. Скоро/проекты/CE…, CelebFaces Attributes Dataset, включающий 10177 человек с общим количеством изображений более 20 Вт, и каждое изображение также включает 5 позиций ключевых точек лица и 01 аннотацию 40 атрибутов, например, есть ли очки, шляпы, бороды и т. д.

выполнить

загрузить библиотеку

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
%matplotlib inline
from imageio import imread, imsave, mimsave
import cv2
import glob
from tqdm import tqdm

Выберите набор данных

dataset = 'lfw_new_imgs' # LFW
# dataset = 'celeba' # CelebA
images = glob.glob(os.path.join(dataset, '*.*')) 
print(len(images))

Определите некоторые константы, сетевые входы, вспомогательные функции

batch_size = 100
z_dim = 100
WIDTH = 64
HEIGHT = 64
LAMBDA = 10
DIS_ITERS = 3 # 5

OUTPUT_DIR = 'samples_' + dataset
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

X = tf.placeholder(dtype=tf.float32, shape=[batch_size, HEIGHT, WIDTH, 3], name='X')
noise = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='noise')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')

def lrelu(x, leak=0.2):
    return tf.maximum(x, leak * x)

В части дискриминатора обратите внимание, что пакетную нормализацию необходимо удалить, иначе это приведет к корреляции между пакетами, что повлияет на расчет штрафа за градиент.

def discriminator(image, reuse=None, is_training=is_training):
    momentum = 0.9
    with tf.variable_scope('discriminator', reuse=reuse):
        h0 = lrelu(tf.layers.conv2d(image, kernel_size=5, filters=64, strides=2, padding='same'))
        
        h1 = lrelu(tf.layers.conv2d(h0, kernel_size=5, filters=128, strides=2, padding='same'))
        
        h2 = lrelu(tf.layers.conv2d(h1, kernel_size=5, filters=256, strides=2, padding='same'))
        
        h3 = lrelu(tf.layers.conv2d(h2, kernel_size=5, filters=512, strides=2, padding='same'))
        
        h4 = tf.contrib.layers.flatten(h3)
        h4 = tf.layers.dense(h4, units=1)
        return h4

секция генератора

def generator(z, is_training=is_training):
    momentum = 0.9
    with tf.variable_scope('generator', reuse=None):
        d = 4
        h0 = tf.layers.dense(z, units=d * d * 512)
        h0 = tf.reshape(h0, shape=[-1, d, d, 512])
        h0 = tf.nn.relu(tf.contrib.layers.batch_norm(h0, is_training=is_training, decay=momentum))
        
        h1 = tf.layers.conv2d_transpose(h0, kernel_size=5, filters=256, strides=2, padding='same')
        h1 = tf.nn.relu(tf.contrib.layers.batch_norm(h1, is_training=is_training, decay=momentum))
        
        h2 = tf.layers.conv2d_transpose(h1, kernel_size=5, filters=128, strides=2, padding='same')
        h2 = tf.nn.relu(tf.contrib.layers.batch_norm(h2, is_training=is_training, decay=momentum))
        
        h3 = tf.layers.conv2d_transpose(h2, kernel_size=5, filters=64, strides=2, padding='same')
        h3 = tf.nn.relu(tf.contrib.layers.batch_norm(h3, is_training=is_training, decay=momentum))
        
        h4 = tf.layers.conv2d_transpose(h3, kernel_size=5, filters=3, strides=2, padding='same', activation=tf.nn.tanh, name='g')
        return h4

функция потерь

g = generator(noise)
d_real = discriminator(X)
d_fake = discriminator(g, reuse=True)

loss_d_real = -tf.reduce_mean(d_real)
loss_d_fake = tf.reduce_mean(d_fake)
loss_g = -tf.reduce_mean(d_fake)
loss_d = loss_d_real + loss_d_fake

alpha = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.)
interpolates = alpha * X + (1 - alpha) * g
grad = tf.gradients(discriminator(interpolates, reuse=True), [interpolates])[0]
slop = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1]))
gp = tf.reduce_mean((slop - 1.) ** 2)
loss_d += LAMBDA * gp

vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]

функция оптимизации

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_d, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_g, var_list=vars_g)

функция чтения изображения

def read_image(path, height, width):
    image = imread(path)
    h = image.shape[0]
    w = image.shape[1]
    
    if h > w:
        image = image[h // 2 - w // 2: h // 2 + w // 2, :, :]
    else:
        image = image[:, w // 2 - h // 2: w // 2 + h // 2, :]    
    
    image = cv2.resize(image, (width, height))
    return image / 255.

функция синтеза изображений

def montage(images):    
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    if len(images.shape) == 4 and images.shape[3] == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 3)) * 0.5
    elif len(images.shape) == 4 and images.shape[3] == 1:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 1)) * 0.5
    elif len(images.shape) == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1)) * 0.5
    else:
        raise ValueError('Could not parse image shape of {}'.format(images.shape))
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m

Функция для случайной генерации пакетов данных

def get_random_batch(nums):
    img_index = np.arange(len(images))
    np.random.shuffle(img_index)
    img_index = img_index[:nums]
    batch = np.array([read_image(images[i], HEIGHT, WIDTH) for i in img_index])
    batch = (batch - 0.5) * 2
    
    return batch

обучение модели

sess = tf.Session()
sess.run(tf.global_variables_initializer())
z_samples = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
samples = []
loss = {'d': [], 'g': []}

for i in tqdm(range(60000)):
    for j in range(DIS_ITERS):
        n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
        batch = get_random_batch(batch_size)
        _, d_ls = sess.run([optimizer_d, loss_d], feed_dict={X: batch, noise: n, is_training: True})
    
    _, g_ls = sess.run([optimizer_g, loss_g], feed_dict={X: batch, noise: n, is_training: True})
    
    loss['d'].append(d_ls)
    loss['g'].append(g_ls)
    
    if i % 500 == 0:
        print(i, d_ls, g_ls)
        gen_imgs = sess.run(g, feed_dict={noise: z_samples, is_training: False})
        gen_imgs = (gen_imgs + 1) / 2
        imgs = [img[:, :, :] for img in gen_imgs]
        gen_imgs = montage(imgs)
        plt.axis('off')
        plt.imshow(gen_imgs)
        imsave(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i), gen_imgs)
        plt.show()
        samples.append(gen_imgs)

plt.plot(loss['d'], label='Discriminator')
plt.plot(loss['g'], label='Generator')
plt.legend(loc='upper right')
plt.savefig(os.path.join(OUTPUT_DIR, 'Loss.png'))
plt.show()
mimsave(os.path.join(OUTPUT_DIR, 'samples.gif'), samples, fps=10)

Результаты генерации лиц LFW следующие: они более стабильны, чем DCGAN.

WGAN生成LFW

Результаты генерации лица CelebA следующие:

WGAN生成CelebA

Сохраните модель для последующего использования

saver = tf.train.Saver()
saver.save(sess, os.path.join(OUTPUT_DIR, 'wgan_' + dataset), global_step=60000)

Используйте модель для создания изображений лиц на одном компьютере.

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

batch_size = 100
z_dim = 100
# dataset = 'lfw_new_imgs'
dataset = 'celeba'

def montage(images):    
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    if len(images.shape) == 4 and images.shape[3] == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 3)) * 0.5
    elif len(images.shape) == 4 and images.shape[3] == 1:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 1)) * 0.5
    elif len(images.shape) == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1)) * 0.5
    else:
        raise ValueError('Could not parse image shape of {}'.format(images.shape))
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m

sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver = tf.train.import_meta_graph(os.path.join('samples_' + dataset, 'wgan_' + dataset + '-60000.meta'))
saver.restore(sess, tf.train.latest_checkpoint('samples_' + dataset))
graph = tf.get_default_graph()
g = graph.get_tensor_by_name('generator/g/Tanh:0')
noise = graph.get_tensor_by_name('noise:0')
is_training = graph.get_tensor_by_name('is_training:0')

n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
gen_imgs = sess.run(g, feed_dict={noise: n, is_training: False})
gen_imgs = (gen_imgs + 1) / 2
imgs = [img[:, :, :] for img in gen_imgs]
gen_imgs = montage(imgs)
gen_imgs = np.clip(gen_imgs, 0, 1)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(gen_imgs)
plt.show()

Ссылаться на

видеоурок

Глубоко и интересно (1)