«Это 13-й день моего участия в ноябрьском испытании обновлений. Подробную информацию об этом событии см.:Вызов последнего обновления 2021 г."
Интуитивное понимание GAN
Ян Гудфеллоу первым предложил GAN и использовал метафору изображения, чтобы представить модель GAN: функция сети генерации G состоит в том, чтобы генерировать реалистичные поддельные банкноты, чтобы попытаться обмануть дискриминатор D, а дискриминатор D узнает настоящие банкноты и поддельные банкноты. генерируется генератором G. Узнайте, как идентифицировать банкноты. Две сети обучаются игре друг против друга до тех пор, пока фальшивые банкноты, произведенные генератором G, не станут трудными для распознавания дискриминатором D. DCGAN использует операции свертки и деконволюции, чтобы заменить полносвязную операцию в исходной GAN.
Структура сети DCGAN
GAN содержит генеративную сеть (генератор,G
) и дискриминаторной сети (Дискриминатор,D
),вG
истинное распределение для изучения данных,D
дляG
Сгенерированные данные отличаются от реальных образцов.
генерирующая сеть G
из предыдущей раздачискрытые переменные в середине выборки, обучение распределению через G, чтобы получить сгенерированные образцы. где априорное распределение скрытой переменной zМожно предположить общее распределение.
дискриминантная сеть D
представляет собой сеть бинарной классификации, которая оценивает выборку по реальному распределению данных.Данныеи сгенерированные данные, отобранные из генеративной сети, обучающий набор данных дискриминантной сети состоит изисочинение. реальный образецМетка 1 помечается как 1, а образец, сгенерированный генеративной сетьюПомеченная как 0, дискриминативная сеть оптимизируется путем минимизации ошибки между прогнозируемым значением дискриминационной сети D и меткой.
Учебная цель ГАН
Цель дискриминационной сети состоит в том, чтобы отличить реальные образцыс поддельными образцами. Его цель — минимизировать функцию кросс-энтропийных потерь между прогнозируемым значением и истинным значением:
CE представляет функцию потерь перекрестной энтропии CrossEntropy:
Цель оптимизации дискриминантной сети D:
Пучокпреобразовать в:
для генеративных сетей, надеемся, что сгенерированные данные смогут обмануть дискриминантную сеть D, поддельные образцыЧем ближе выход дискриминационной сети к истинной метке, тем лучше. То есть при обучении генеративной сети предполагается различать выходные данные сети.Чем ближе к 1, тем лучше, минимизироватьФункция кросс-энтропийных потерь между и 1:
Пучокпреобразовать в:
в– параметры для построения сети G.
Обучите дискриминатор и генератор итеративно в процессе обучения.
Реализация DCGAN
использоватьcifar10
Учебный набор DCGAN реализован как обучающий набор GAN.
загрузка данных
нагрузкаcifar10
обучающий набор и предварительная обработка данных
#批大小
batch_size = 64
(train_x,_),_ = keras.datasets.cifar10.load_data()
#数据归一化
train_x = train_x / (255. / 2) - 1
print(train_x.shape)
dataset = tf.data.Dataset.from_tensor_slices(train_x)
dataset = dataset.shuffle(1000)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
Интернет
Сеть состоит из дискриминативной сети и генеративной сети.
различающая сеть
class Discriminator(keras.Model):
def __init__(self):
super(Discriminator,self).__init__()
filters = 64
self.conv1 = keras.layers.Conv2D(filters,4,2,'valid',use_bias=False)
self.bn1 = keras.layers.BatchNormalization()
self.conv2 = keras.layers.Conv2D(filters*2,4,2,'valid',use_bias=False)
self.bn2 = keras.layers.BatchNormalization()
self.conv3 = keras.layers.Conv2D(filters*4,3,1,'valid',use_bias=False)
self.bn3 = keras.layers.BatchNormalization()
self.conv4 = keras.layers.Conv2D(filters*8,3,1,'valid',use_bias=False)
self.bn4 = keras.layers.BatchNormalization()
#全局池化
self.pool = keras.layers.GlobalAveragePooling2D()
self.flatten = keras.layers.Flatten()
self.fc = keras.layers.Dense(1)
def call(self,inputs,training=True):
x = inputs
x = tf.nn.leaky_relu(self.bn1(self.conv1(x),training=training))
x = tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))
x = tf.nn.leaky_relu(self.bn3(self.conv3(x),training=training))
x = tf.nn.leaky_relu(self.bn4(self.conv4(x),training=training))
x = self.pool(x)
x = self.flatten(x)
logits = self.fc(x)
return logits
генерирующая сеть
class Generator(keras.Model):
def __init__(self):
super(Generator,self).__init__()
filters = 64
self.conv1 = keras.layers.Conv2DTranspose(filters*4,4,1,'valid',use_bias=False)
self.bn1 = keras.layers.BatchNormalization()
self.conv2 = keras.layers.Conv2DTranspose(filters*3,4,2,'same',use_bias=False)
self.bn2 = keras.layers.BatchNormalization()
self.conv3 = keras.layers.Conv2DTranspose(filters*1,4,2,'same',use_bias=False)
self.bn3 = keras.layers.BatchNormalization()
self.conv4 = keras.layers.Conv2DTranspose(3,4,2,'same',use_bias=False)
def call(self,inputs,training=False):
x = inputs
x = tf.reshape(x,(x.shape[0],1,1,x.shape[1]))
x = tf.nn.relu(x)
x = tf.nn.relu(self.bn1(self.conv1(x),training=training))
x = tf.nn.relu(self.bn2(self.conv2(x),training=training))
x = tf.nn.relu(self.bn3(self.conv3(x),training=training))
x = self.conv4(x)
x = tf.tanh(x)
return x
сетевое обучение
При обучении вы можете обучать дискриминатор несколько раз, а затем один раз обучать генератор.
Определите функцию потерь
def celoss_ones(logits):
# 计算属于与标签为1的交叉熵
y = tf.ones_like(logits)
loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
return tf.reduce_mean(loss)
def celoss_zeros(logits):
# 计算属于与标签为0的交叉熵
y = tf.zeros_like(logits)
loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
return tf.reduce_mean(loss)
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
# 计算鉴别器的损失函数
# 采样生成图片
fake_image = generator(batch_z, is_training)
# 判定生成图片
d_fake_logits = discriminator(fake_image, is_training)
# 判定真实图片
d_real_logits = discriminator(batch_x, is_training)
# 真实图片与1之间的误差
d_loss_real = celoss_ones(d_real_logits)
# 生成图片与0之间的误差
d_loss_fake = celoss_zeros(d_fake_logits)
# 合并误差
loss = d_loss_fake + d_loss_real
return loss
def g_loss_fn(generator, discriminator, batch_z, is_training):
#计算生成器的损失函数
# 采样生成图片
fake_image = generator(batch_z, is_training)
# 在训练生成网络时,需要迫使生成图片判定为真
d_fake_logits = discriminator(fake_image, is_training)
# 计算生成图片与1之间的误差
loss = celoss_ones(d_fake_logits)
return loss
Создание сети и оптимизатора
#定义超参数
#潜变量维度
z_dim = 100
#epoch大小
epochs = 300
#批大小
batch_size = 64
#学习率
lr = 0.0002
is_training = True
#实例化网络
discriminator = Discriminator()
discriminator.build(input_shape=(4,32,32,3))
discriminator.summary()
generator = Generator()
generator.build(input_shape=(4,z_dim))
generator.summary()
#实例化优化器
g_optimizer = keras.optimizers.Adam(learning_rate=lr,beta_1=0.5)
d_optimizer = keras.optimizers.Adam(learning_rate=lr,beta_1=0.5)
тренироваться
#统计损失值
d_losses = []
g_losses = []
for epoch in range(epochs):
for _,batch_x in enumerate(dataset):
batch_z = tf.random.normal([batch_size,z_dim])
with tf.GradientTape() as tape:
d_loss = d_loss_fn(generator,discriminator,batch_z,batch_x,is_training)
grads = tape.gradient(d_loss,discriminator.trainable_variables)
d_optimizer.apply_gradients(zip(grads,discriminator.trainable_variables))
with tf.GradientTape() as tape:
g_loss = g_loss_fn(generator,discriminator,batch_z,is_training)
grads = tape.gradient(g_loss,generator.trainable_variables)
g_optimizer.apply_gradients(zip(grads,generator.trainable_variables))
Показать результаты
Тренируясь и тестируя, вы можете получить лучшие результаты, настроив гиперпараметры.
Эффект тренировки на 26 эпох: