Введение
Помимо VAE, Generative Adversarial Nets (GAN) также является очень популярной неконтролируемой генеративной моделью.
GAN в основном включает две основные сети.
- Генератор (Generator): обозначается как G, за счет изучения большого количества образцов он может генерировать некоторые поддельные образцы, аналогичные VAE.
- Дискриминатор (Discriminator): обозначается как D, принимает реальные выборки и выборки, сгенерированные G, и различает и различает
- G и D играют друг против друга. Благодаря обучению как генерирующая способность G, так и различительная способность D постепенно усиливаются и сближаются.
Обучение GAN очень сложное, и есть много деталей, на которые нужно обратить внимание, чтобы создавать высококачественные изображения.
- Надлежащее использование пакетной нормализации, LeakyReLU
- использовать
strides
Свертка 2 вместо объединения - Чередуйте тренировки, чтобы одна сторона не была слишком сильной
Здесь мы беремMNIST
Например, поTensorFlow
Для реализации GAN его также называют DCGAN (Deep Convolutional GAN) из-за использования глубокой сверточной нейронной сети.
принцип
Для случайного распределения шума z генератор генерирует поддельные выборки через сложную функцию отображения
Дискриминатор использует еще одну сложную функцию отображения.Для реальных образцов или поддельных образцов он выводит значение от 0 до 1. Чем больше значение, тем больше вероятность того, что это настоящий образец.
Общая целевая функция выглядит следующим образом
выполнить
загрузить библиотеку
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import os, imageio
Скачать данные
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')
Определите некоторые константы, сетевые входы, вспомогательные функции
batch_size = 100
z_dim = 100
OUTPUT_DIR = 'samples'
if not os.path.exists(OUTPUT_DIR):
os.mkdir(OUTPUT_DIR)
X = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='X')
noise = tf.placeholder(dtype=tf.float32, shape=[None, z_dim], name='noise')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')
def lrelu(x, leak=0.2):
return tf.maximum(x, leak * x)
def sigmoid_cross_entropy_with_logits(x, y):
return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
Часть дискриминатора
def discriminator(image, reuse=None, is_training=is_training):
momentum = 0.9
with tf.variable_scope('discriminator', reuse=reuse):
h0 = lrelu(tf.layers.conv2d(image, kernel_size=5, filters=64, strides=2, padding='same'))
h1 = tf.layers.conv2d(h0, kernel_size=5, filters=128, strides=2, padding='same')
h1 = lrelu(tf.contrib.layers.batch_norm(h1, is_training=is_training, decay=momentum))
h2 = tf.layers.conv2d(h1, kernel_size=5, filters=256, strides=2, padding='same')
h2 = lrelu(tf.contrib.layers.batch_norm(h2, is_training=is_training, decay=momentum))
h3 = tf.layers.conv2d(h2, kernel_size=5, filters=512, strides=2, padding='same')
h3 = lrelu(tf.contrib.layers.batch_norm(h3, is_training=is_training, decay=momentum))
h4 = tf.contrib.layers.flatten(h3)
h4 = tf.layers.dense(h4, units=1)
return tf.nn.sigmoid(h4), h4
секция генератора
def generator(z, is_training=is_training):
momentum = 0.9
with tf.variable_scope('generator', reuse=None):
d = 3
h0 = tf.layers.dense(z, units=d * d * 512)
h0 = tf.reshape(h0, shape=[-1, d, d, 512])
h0 = tf.nn.relu(tf.contrib.layers.batch_norm(h0, is_training=is_training, decay=momentum))
h1 = tf.layers.conv2d_transpose(h0, kernel_size=5, filters=256, strides=2, padding='same')
h1 = tf.nn.relu(tf.contrib.layers.batch_norm(h1, is_training=is_training, decay=momentum))
h2 = tf.layers.conv2d_transpose(h1, kernel_size=5, filters=128, strides=2, padding='same')
h2 = tf.nn.relu(tf.contrib.layers.batch_norm(h2, is_training=is_training, decay=momentum))
h3 = tf.layers.conv2d_transpose(h2, kernel_size=5, filters=64, strides=2, padding='same')
h3 = tf.nn.relu(tf.contrib.layers.batch_norm(h3, is_training=is_training, decay=momentum))
h4 = tf.layers.conv2d_transpose(h3, kernel_size=5, filters=1, strides=1, padding='valid', activation=tf.nn.tanh, name='g')
return h4
Определите функцию потерь, обратите внимание, что здесь реализованы два дискриминатора, но параметры являются общими.
g = generator(noise)
d_real, d_real_logits = discriminator(X)
d_fake, d_fake_logits = discriminator(g, reuse=True)
vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]
loss_d_real = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_real_logits, tf.ones_like(d_real)))
loss_d_fake = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits, tf.zeros_like(d_fake)))
loss_g = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits, tf.ones_like(d_fake)))
loss_d = loss_d_real + loss_d_fake
Определите функцию оптимизации, обратите внимание, что функция потерь должна соответствовать регулируемым параметрам.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_d, var_list=vars_d)
optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_g, var_list=vars_g)
Определите вспомогательную функцию для отображения нескольких изображений вместе в сетке.
def montage(images):
if isinstance(images, list):
images = np.array(images)
img_h = images.shape[1]
img_w = images.shape[2]
n_plots = int(np.ceil(np.sqrt(images.shape[0])))
m = np.ones((images.shape[1] * n_plots + n_plots + 1, images.shape[2] * n_plots + n_plots + 1)) * 0.5
for i in range(n_plots):
for j in range(n_plots):
this_filter = i * n_plots + j
if this_filter < images.shape[0]:
this_img = images[this_filter]
m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
return m
Начните обучение, тренируйте G дважды за итерацию
sess = tf.Session()
sess.run(tf.global_variables_initializer())
z_samples = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
samples = []
loss = {'d': [], 'g': []}
for i in range(60000):
n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
batch = mnist.train.next_batch(batch_size=batch_size)[0]
batch = np.reshape(batch, [-1, 28, 28, 1])
batch = (batch - 0.5) * 2
d_ls, g_ls = sess.run([loss_d, loss_g], feed_dict={X: batch, noise: n, is_training: True})
loss['d'].append(d_ls)
loss['g'].append(g_ls)
sess.run(optimizer_d, feed_dict={X: batch, noise: n, is_training: True})
sess.run(optimizer_g, feed_dict={X: batch, noise: n, is_training: True})
sess.run(optimizer_g, feed_dict={X: batch, noise: n, is_training: True})
if i % 1000 == 0:
print(i, d_ls, g_ls)
gen_imgs = sess.run(g, feed_dict={noise: z_samples, is_training: False})
gen_imgs = (gen_imgs + 1) / 2
imgs = [img[:, :, 0] for img in gen_imgs]
gen_imgs = montage(imgs)
plt.axis('off')
plt.imshow(gen_imgs, cmap='gray')
plt.savefig(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i))
plt.show()
samples.append(gen_imgs)
plt.plot(loss['d'], label='Discriminator')
plt.plot(loss['g'], label='Generator')
plt.legend(loc='upper right')
plt.savefig('Loss.png')
plt.show()
imageio.mimsave(os.path.join(OUTPUT_DIR, 'samples.gif'), samples, fps=5)
Сгенерированное изображение выглядит следующим образом: поскольку в функции потерь не используется попиксельное сравнение, края графика не будут размыты.
Сохраните модель для последующего использования
saver = tf.train.Saver()
saver.save(sess, './mnist_dcgan', global_step=60000)
Загрузите модель, если необходимо, например, для использования на одной машине
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
batch_size = 100
z_dim = 100
def montage(images):
if isinstance(images, list):
images = np.array(images)
img_h = images.shape[1]
img_w = images.shape[2]
n_plots = int(np.ceil(np.sqrt(images.shape[0])))
m = np.ones((images.shape[1] * n_plots + n_plots + 1, images.shape[2] * n_plots + n_plots + 1)) * 0.5
for i in range(n_plots):
for j in range(n_plots):
this_filter = i * n_plots + j
if this_filter < images.shape[0]:
this_img = images[this_filter]
m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
return m
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph('./mnist_dcgan-60000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()
g = graph.get_tensor_by_name('generator/g/Tanh:0')
noise = graph.get_tensor_by_name('noise:0')
is_training = graph.get_tensor_by_name('is_training:0')
n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
gen_imgs = sess.run(g, feed_dict={noise: n, is_training: False})
gen_imgs = (gen_imgs + 1) / 2
imgs = [img[:, :, 0] for img in gen_imgs]
gen_imgs = montage(imgs)
plt.axis('off')
plt.imshow(gen_imgs, cmap='gray')
plt.show()
Ссылаться на
- Генеративно-состязательные сети:АР Вест V.org/ABS/1406.26…
- Генеративно-состязательные сети:Woohoo Ян Гудфеллоу.com/slides/2017…
- Документы о состязательных сетях:GitHub.com/Чжан Цяньху…
- DCGAN-тензорный поток:GitHub.com/карп EDM20/D…