Принцип и реализация глубокой сверточной генеративно-состязательной сети (DCGAN) (реализована с помощью Tensorflow2)

искусственный интеллект глубокое обучение
Принцип и реализация глубокой сверточной генеративно-состязательной сети (DCGAN) (реализована с помощью Tensorflow2)

«Это 13-й день моего участия в ноябрьском испытании обновлений. Подробную информацию об этом событии см.:Вызов последнего обновления 2021 г."

Интуитивное понимание GAN

Ян Гудфеллоу первым предложил GAN и использовал метафору изображения, чтобы представить модель GAN: функция сети генерации G состоит в том, чтобы генерировать реалистичные поддельные банкноты, чтобы попытаться обмануть дискриминатор D, а дискриминатор D узнает настоящие банкноты и поддельные банкноты. генерируется генератором G. Узнайте, как идентифицировать банкноты. Две сети обучаются игре друг против друга до тех пор, пока фальшивые банкноты, произведенные генератором G, не станут трудными для распознавания дискриминатором D. DCGAN использует операции свертки и деконволюции, чтобы заменить полносвязную операцию в исходной GAN.

Структура сети DCGAN

GAN содержит генеративную сеть (генератор,G) и дискриминаторной сети (Дискриминатор,D),вGистинное распределение для изучения данных,DдляGСгенерированные данные отличаются от реальных образцов.

DCGAN网络架构

генерирующая сетьG(z)G(z) Gиз предыдущей раздачиpz()p_z(\cdot )скрытые переменные в середине выборкиzpz()z\sim p_z(\cdot), обучение распределению через Gpg(xz)p_g (x|z), чтобы получить сгенерированные образцыx pg(xz)x\sim ~p_g (x|z). где априорное распределение скрытой переменной zpz()p_z (\cdot)Можно предположить общее распределение.

дискриминантная сетьD(x)D(x) Dпредставляет собой сеть бинарной классификации, которая оценивает выборку по реальному распределению данных.pr()p_r (\cdot)Данныеxrpr()x_r\sim p_r(\cdot )и сгенерированные данные, отобранные из генеративной сетиxfpg(xz)x_f\sim p_g (x|z), обучающий набор данных дискриминантной сети состоит изxrx_rиxfx_fсочинение. реальный образецxrx_rМетка 1 помечается как 1, а образец, сгенерированный генеративной сетьюxfx_fПомеченная как 0, дискриминативная сеть оптимизируется путем минимизации ошибки между прогнозируемым значением дискриминационной сети D и меткой.

Учебная цель ГАН

Цель дискриминационной сети состоит в том, чтобы отличить реальные образцыxrx_rс поддельными образцамиxfx_f. Его цель — минимизировать функцию кросс-энтропийных потерь между прогнозируемым значением и истинным значением:

mθinL=CE(Dθ(xr),yr,Dθ(xf),yf)\ underset {θ} мин \ mathcal L = CE (D_θ (x_r), y_r, D_θ (x_f), y_f)

CE представляет функцию потерь перекрестной энтропии CrossEntropy:

L=xrpr()logDθ(xr)xfpg()log(1Dθ(xf))\mathcal L = - \sum_{x_r \sim p_r (\cdot)}logD_θ (x_r) -\sum_{x_f \sim p_g (\cdot)} log (1 - D_θ (x_f))

Цель оптимизации дискриминантной сети D:

θ*=aθrgminxrpr()logDθ(xr)xfpg()log(1Dθ(xf))θ^* = \underset{θ}argmin - \sum_{x_r \sim p_r (\cdot)}logD_θ (x_r) -\sum_{x_f \sim p_g (\cdot)} log (1 - D_θ (x_f))

ПучокminLmin \mathcal Lпреобразовать вmaxLмакс - \ mathcal L:

θ*=aθrgmaxExrpr() logDθ(xr)+Exfpg()log(1Dθ(xf))θ^∗ = \underset{θ}argmax \mathbb E_{x_r \sim p_r (\cdot)}\ logD_θ (x_r) +\mathbb E_{x_f \sim p_g (\cdot)} log (1 − D_θ (x_f) )

для генеративных сетейG(z)G(z), надеемся, что сгенерированные данные смогут обмануть дискриминантную сеть D, поддельные образцыxfx_fЧем ближе выход дискриминационной сети к истинной метке, тем лучше. То есть при обучении генеративной сети предполагается различать выходные данные сети.D(G(z))D(G(z))Чем ближе к 1, тем лучше, минимизироватьD(G(z))D(G(z))Функция кросс-энтропийных потерь между и 1:

mфinL=CE(D(Gф(z)),1)=logD(Gф(z))\underset{φ}min \mathcal L= CE (D (G_φ (z)) , 1) = −logD (G_φ (z))

ПучокminLmin \mathcal Lпреобразовать вmaxLмакс - \ mathcal L:

ф*=aфrgminL=Ezpz()log[1D(Gф(z))]φ ^ * = \ underset {φ} argmin \ mathcal L = \ mathbb E_ {z \ sim p_z (\ cdot)} log [1 - D (G_φ (z))]

вфф– параметры для построения сети G.

Обучите дискриминатор и генератор итеративно в процессе обучения.

Реализация DCGAN

использоватьcifar10Учебный набор DCGAN реализован как обучающий набор GAN.

загрузка данных

нагрузкаcifar10обучающий набор и предварительная обработка данных

#批大小
batch_size = 64
(train_x,_),_ = keras.datasets.cifar10.load_data()
#数据归一化
train_x = train_x / (255. / 2) - 1
print(train_x.shape)
dataset = tf.data.Dataset.from_tensor_slices(train_x)
dataset = dataset.shuffle(1000)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)

Интернет

Сеть состоит из дискриминативной сети и генеративной сети.

различающая сеть

class Discriminator(keras.Model):
    def __init__(self):
        super(Discriminator,self).__init__()
        filters = 64
        self.conv1 = keras.layers.Conv2D(filters,4,2,'valid',use_bias=False)
        self.bn1 = keras.layers.BatchNormalization()
        self.conv2 = keras.layers.Conv2D(filters*2,4,2,'valid',use_bias=False)
        self.bn2 = keras.layers.BatchNormalization()
        self.conv3 = keras.layers.Conv2D(filters*4,3,1,'valid',use_bias=False)
        self.bn3 = keras.layers.BatchNormalization()
        self.conv4 = keras.layers.Conv2D(filters*8,3,1,'valid',use_bias=False)
        self.bn4 = keras.layers.BatchNormalization()
        #全局池化
        self.pool = keras.layers.GlobalAveragePooling2D()
        self.flatten = keras.layers.Flatten()
        self.fc = keras.layers.Dense(1)

    def call(self,inputs,training=True):
        x = inputs
        x = tf.nn.leaky_relu(self.bn1(self.conv1(x),training=training))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x),training=training))
        x = tf.nn.leaky_relu(self.bn4(self.conv4(x),training=training))
        x = self.pool(x)
        x = self.flatten(x)
        logits = self.fc(x)
        return logits

генерирующая сеть

class Generator(keras.Model):
    def __init__(self):
        super(Generator,self).__init__()
        filters = 64
        self.conv1 = keras.layers.Conv2DTranspose(filters*4,4,1,'valid',use_bias=False)
        self.bn1 = keras.layers.BatchNormalization()
        self.conv2 = keras.layers.Conv2DTranspose(filters*3,4,2,'same',use_bias=False)
        self.bn2 = keras.layers.BatchNormalization()
        self.conv3 = keras.layers.Conv2DTranspose(filters*1,4,2,'same',use_bias=False)
        self.bn3 = keras.layers.BatchNormalization()
        self.conv4 = keras.layers.Conv2DTranspose(3,4,2,'same',use_bias=False)

    def call(self,inputs,training=False):
        x = inputs
        x = tf.reshape(x,(x.shape[0],1,1,x.shape[1]))
        x = tf.nn.relu(x)
        x = tf.nn.relu(self.bn1(self.conv1(x),training=training))
        x = tf.nn.relu(self.bn2(self.conv2(x),training=training))
        x = tf.nn.relu(self.bn3(self.conv3(x),training=training))
        x = self.conv4(x)
        x = tf.tanh(x)
        return x

сетевое обучение

При обучении вы можете обучать дискриминатор несколько раз, а затем один раз обучать генератор.

Определите функцию потерь

def celoss_ones(logits):
    # 计算属于与标签为1的交叉熵
    y = tf.ones_like(logits)
    loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
    return tf.reduce_mean(loss)


def celoss_zeros(logits):
    # 计算属于与标签为0的交叉熵
    y = tf.zeros_like(logits)
    loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
    return tf.reduce_mean(loss)

def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    # 计算鉴别器的损失函数
    # 采样生成图片
    fake_image = generator(batch_z, is_training)
    # 判定生成图片
    d_fake_logits = discriminator(fake_image, is_training)
    # 判定真实图片
    d_real_logits = discriminator(batch_x, is_training)
    # 真实图片与1之间的误差
    d_loss_real = celoss_ones(d_real_logits)
    # 生成图片与0之间的误差
    d_loss_fake = celoss_zeros(d_fake_logits)
    # 合并误差
    loss = d_loss_fake + d_loss_real

    return loss


def g_loss_fn(generator, discriminator, batch_z, is_training):
	#计算生成器的损失函数
    # 采样生成图片
    fake_image = generator(batch_z, is_training)
    # 在训练生成网络时,需要迫使生成图片判定为真
    d_fake_logits = discriminator(fake_image, is_training)
    # 计算生成图片与1之间的误差
    loss = celoss_ones(d_fake_logits)

    return loss

Создание сети и оптимизатора

#定义超参数
#潜变量维度
z_dim = 100
#epoch大小
epochs = 300
#批大小
batch_size = 64
#学习率
lr = 0.0002
is_training = True
#实例化网络
discriminator = Discriminator()
discriminator.build(input_shape=(4,32,32,3))
discriminator.summary()
generator = Generator()
generator.build(input_shape=(4,z_dim))
generator.summary()
#实例化优化器
g_optimizer = keras.optimizers.Adam(learning_rate=lr,beta_1=0.5)
d_optimizer = keras.optimizers.Adam(learning_rate=lr,beta_1=0.5)

тренироваться

#统计损失值
d_losses = []
g_losses = []
for epoch in range(epochs):
    for _,batch_x in enumerate(dataset):
        batch_z = tf.random.normal([batch_size,z_dim])
        with tf.GradientTape() as tape:
            d_loss = d_loss_fn(generator,discriminator,batch_z,batch_x,is_training)
        grads = tape.gradient(d_loss,discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads,discriminator.trainable_variables))
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator,discriminator,batch_z,is_training)
        grads = tape.gradient(g_loss,generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads,generator.trainable_variables))

Показать результаты

Тренируясь и тестируя, вы можете получить лучшие результаты, настроив гиперпараметры.

Эффект тренировки на 26 эпох:

训练26epoch的效果