Автор|Та-Йинг Ченг, аспирант Оксфордского университета, блогер Medium, многие статьи включены в официальное издание платформы Towards Data Science.
Перевод|Сун Сянь
В прошлом обычно считалось, что генерация изображений — невыполнимая задача, потому что, согласно традиционным представлениям о машинном обучении, у нас вообще не было оснований для проверки качества сгенерированных изображений.
В 2014 году Гудфеллоу и др. предложили создатьГенеративно-состязательная сеть (GAN), что позволяет нам полностью полагаться на машинное обучение для создания чрезвычайно реалистичных изображений. Появление GAN потрясло всю индустрию искусственного интеллекта, а область компьютерного зрения и генерации изображений претерпела большие изменения.
Эта статья познакомит вас сКак работают GAN, и описывает, какЛегко начать работу с GAN с помощью PyTorch.
Принцип ГАН
В соответствии с традиционным методом результаты прогнозирования модели можно напрямую сравнивать с существующими наземными правдами. Однако трудно определить и измерить, что считается «правильными» сгенерированными изображениями.
Гудфеллоу и др. предлагают интересное решение: мы можем сначала обучить инструмент классификации автоматически отличать сгенерированные изображения от реальных изображений. Таким образом, мы можем использовать этот инструмент классификации для обучения генеративной сети до тех пор, пока она не сможет выводить полностью поддельные изображения, и даже сам инструмент классификации не сможет судить об истинности и ложности.Следуя этому ходу мысли, у нас есть GAN: то естьгенераторидискриминатор. Генератор отвечает за создание изображений на основе заданного набора данных, а дискриминатор отвечает за определение того, являются ли изображения реальными или поддельными. Процесс работы GAN показан на рисунке выше.
функция потерь
В процессе работы GAN мы можем обнаружить очевидное противоречие: сложно одновременно оптимизировать генератор и дискриминатор. Как вы понимаете, у этих двух моделей совершенно противоположные цели: генератор хочет максимально подделать реальную вещь, в то время как дискриминатор должен видеть сквозь изображения, сгенерированные генератором.
Чтобы проиллюстрировать это, пусть D(x) будет выходом дискриминатора, т. е. вероятностью того, что x является реальным изображением, и пусть G(z) будет выходом генератора. Дискриминатор похож на бинарный классификатор, поэтому его цель — максимизировать результат этой функции:Эта функция по существу представляет собой неотрицательную бинарную функцию кросс-энтропийных потерь. С другой стороны, целью генератора является минимизация вероятности того, что дискриминатор примет правильное решение, поэтому его цель состоит в минимизации результата вышеуказанной функции.
Следовательно, итоговая функция потерь будет минимаксной игрой между двумя классификаторами, выраженной следующим образом:Теоретически окончательный результат игры будет состоять в том, чтобы позволить дискриминатору оценить вероятность успеха и приблизиться к 0,5. Однако на практике минимаксные игры часто приводят к несходимости сети, поэтому важно тщательно настраивать параметры обучения модели.
При обучении GAN мы должны уделять особое внимание гиперпараметрам, таким как скорость обучения, Небольшая скорость обучения может позволить GAN иметь более равномерный вывод в случае большого количества входного шума.
вычислительная среда
библиотека
Эта статья поможет вам собрать всю программу (включая torchvision) с помощью PyTorch. В то же время мы будем использовать Matplotlib для визуализации сгенерированных результатов GAN. Следующий код может импортировать все вышеперечисленные библиотеки:
"""
Import necessary libraries to create a generative adversarial network
The code is mainly developed using the PyTorch library
"""
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from model import discriminator, generator
import numpy as np
import matplotlib.pyplot as plt
набор данных
Наборы данных очень важны для обучения GAN, особенно если учесть, что мы обычно имеем дело с неструктурированными данными (обычно изображениями, видео и т. д.) в GAN, а распределение данных может быть у любого класса. Именно это распределение данных является основой для вывода, генерируемого GAN.
Чтобы лучше продемонстрировать процесс построения GAN, в этой статье вы будете использовать простейший набор данных MNIST, который содержит 60 000 изображений рукописных арабских цифр.
рисунокMNISTТакие высококачественные неструктурированные наборы данных могут бытьСетка Титанизобщедоступный набор данныхнашел на сайте. На самом деле, платформа Gewuti Open Datasets охватывает множество высококачественных общедоступных наборов данных, а также может достигатьХостинг наборов данных и возможности универсального поиска, которая является очень практичной платформой сообщества для разработчиков ИИ.
аппаратные требования
В целом, несмотря на то, что можно обучать нейронную сеть с помощью ЦП, на самом деле лучшим выбором является графический процессор, так как это может значительно ускорить обучение. Мы можем использовать следующий код, чтобы проверить, можно ли обучить нашу машину с помощью GPU:
"""
Determine if any GPUs are available
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
выполнить
сетевая структура
Поскольку числа — это очень простая информация, мы можем объединить и дискриминатор, и генератор в полносвязные слои.
Мы можем создать дискриминатор и генератор в PyTorch с помощью следующего кода:
"""
Network Architectures
The following are the discriminator and generator architectures
"""
class discriminator(nn.Module):
def __init__(self):
super(discriminator, self).__init__()
self.fc1 = nn.Linear(784, 512)
self.fc2 = nn.Linear(512, 1)
self.activation = nn.LeakyReLU(0.1)
def forward(self, x):
x = x.view(-1, 784)
x = self.activation(self.fc1(x))
x = self.fc2(x)
return nn.Sigmoid()(x)
class generator(nn.Module):
def __init__(self):
super(generator, self).__init__()
self.fc1 = nn.Linear(128, 1024)
self.fc2 = nn.Linear(1024, 2048)
self.fc3 = nn.Linear(2048, 784)
self.activation = nn.ReLU()
def forward(self, x):
x = self.activation(self.fc1(x))
x = self.activation(self.fc2(x))
x = self.fc3(x)
x = x.view(-1, 1, 28, 28)
return nn.Tanh()(x)
тренироваться
При обучении GAN нам нужно оптимизировать дискриминатор при улучшении генератора, поэтому на каждой итерации нам нужно оптимизировать две конфликтующие функции потерь одновременно.
Для генератора мы введем случайный шум и позволим генератору изменить выходное изображение на основе небольшого количества шума:
"""
Network training procedure
Every step both the loss for disciminator and generator is updated
Discriminator aims to classify reals and fakes
Generator aims to generate images as realistic as possible
"""
for epoch in range(epochs):
for idx, (imgs, _) in enumerate(train_loader):
idx += 1
# Training the discriminator
# Real inputs are actual images of the MNIST dataset
# Fake inputs are from the generator
# Real inputs should be classified as 1 and fake as 0
real_inputs = imgs.to(device)
real_outputs = D(real_inputs)
real_label = torch.ones(real_inputs.shape[0], 1).to(device)
noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
noise = noise.to(device)
fake_inputs = G(noise)
fake_outputs = D(fake_inputs)
fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)
outputs = torch.cat((real_outputs, fake_outputs), 0)
targets = torch.cat((real_label, fake_label), 0)
D_loss = loss(outputs, targets)
D_optimizer.zero_grad()
D_loss.backward()
D_optimizer.step()
# Training the generator
# For generator, goal is to make the discriminator believe everything is 1
noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
noise = noise.to(device)
fake_inputs = G(noise)
fake_outputs = D(fake_inputs)
fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
G_loss = loss(fake_outputs, fake_targets)
G_optimizer.zero_grad()
G_loss.backward()
G_optimizer.step()
if idx % 100 == 0 or idx == len(train_loader):
print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))
if (epoch+1) % 10 == 0:
torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
print('Model saved.')
результат
После 100 эпох обучения мы можем визуализировать набор данных и непосредственно видеть числа, которые модель генерирует из случайного шума:Мы видим, что сгенерированные результаты очень похожи на реальные данные. Учитывая, что мы построили здесь только очень простую модель, у фактического эффекта приложения будет очень много возможностей для улучшения.
не только подражать
GAN отличается от идей, предложенных экспертами по машинному зрению в прошлом, и применение конкретных сценариев с использованием GAN заставило многих людей восхищаться безграничным потенциалом глубоких сетей. Давайте взглянем на два самых известных приложения расширения GAN.
CycleGAN
CycleGAN, опубликованный Чжу Джуняном и другими в 2017 году, может напрямую преобразовывать изображение из домена X в домен Y без сопряжения изображений, например, превращая лошадь в зебру, превращая жаркое лето в середину зимы и изменяя картины Моне. в картины Ван Гога и так далее. Эти, казалось бы, фантастические преобразования CycleGAN может легко сделать, и результаты очень точны.
GauGAN
Nvidia использует GAN, чтобы люди могли получить очень реалистичную картину реальной сцены всего несколькими штрихами, чтобы обрисовать в общих чертах свои идеи. Хотя вычислительная стоимость такого приложения чрезвычайно высока, преобразовательные возможности GauGAN позволили исследовать беспрецедентные области исследований и приложений.
Эпилог
Я считаю, что когда вы это видите, вы уже знаете общий принцип работы GAN, и вы можете легко построить GAN самостоятельно.