理解GAN对抗神经网络的损失函数和训练过程

GAN最不好理解的就是Loss函数的定义和训练过程,这里用一段代码来辅助理解,就能明白到底是怎么回事。其实GAN的损失函数并没有特殊之处,就是常用的binary_crossentropy,关键在于训练过程中存在两个神经网络和两个损失函数。

np.random.seed(42)
tf.random.set_seed(42)

codings_size = 30

generator = keras.models.Sequential([
    keras.layers.Dense(100, activation="selu", input_shape=[codings_size]),
    keras.layers.Dense(150, activation="selu"),
    keras.layers.Dense(28 * 28, activation="sigmoid"),
    keras.layers.Reshape([28, 28])
])
discriminator = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28]),
    keras.layers.Dense(150, activation="selu"),
    keras.layers.Dense(100, activation="selu"),
    keras.layers.Dense(1, activation="sigmoid")
])
gan = keras.models.Sequential([generator, discriminator])

discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
discriminator.trainable = False
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")

batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)

这里generator并不用compile,因为gan网络已经compile了。具体原因见下文。

训练过程的代码如下

def train_gan(gan, dataset, batch_size, codings_size, n_epochs=50):
    generator, discriminator = gan.layers
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))              # not shown in the book
        for X_batch in dataset:
            # phase 1 - training the discriminator
            noise = tf.random.normal(shape=[batch_size, codings_size])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real, y1)
            # phase 2 - training the generator
            noise = tf.random.normal(shape=[batch_size, codings_size])
            y2 = tf.constant([[1.]] * batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y2)
        plot_multiple_images(generated_images, 8)                     # not shown
        plt.show()                                                    # not shown

第一阶段(discriminator训练)

# phase 1 - training the discriminator
noise = tf.random.normal(shape=[batch_size, codings_size])
generated_images = generator(noise)
X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
discriminator.trainable = True
discriminator.train_on_batch(X_fake_and_real, y1)

这个阶段首先生成数量相同的真实图片和假图片,concat在一起,即X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)。然后是label,真图片的label是1,假图片的label是0。

然后是迅速阶段,首先将discrinimator设置为可训练,discriminator.trainable = True,然后开始阶段。第一个阶段的训练过程只训练discriminator,discriminator.train_on_batch(X_fake_and_real, y1),而不是整个GAN网络gan

第二阶段(generator训练)

# phase 2 - training the generator
noise = tf.random.normal(shape=[batch_size, codings_size])
y2 = tf.constant([[1.]] * batch_size)
discriminator.trainable = False
gan.train_on_batch(noise, y2)

在第二阶段首先生成假图片,但是不再生成真图片。把假图片的label全部设置为1,并把discriminator的权重冻结,即discriminator.trainable = False。这一步很关键,应该这么理解:

前面第一阶段的是discriminator的训练,使真图片的预测值尽量接近1,假图片的预测值尽量接近0,以此来达到优化损失函数的目的。现在将discrinimator的权重冻结,网络中输入假图片,并故意把label设置为1。

注意,在整个gan网络中,从上向下的顺序是先通过geneartor,再通过discriminator,即gan = keras.models.Sequential([generator, discriminator])。第二个阶段将discrinimator冻结,并训练网络gan.train_on_batch(noise, y2)。如果generator生成的图片足够真实,经过discrinimator后label会尽可能接近1。由于故意把y2的label设置为1,所以如果genrator生成的图片足够真实,此时generator训练已经达到最优状态,不会大幅度更新权重;如果genrator生成的图片不够真实,经过discriminator之后,预测值会接近0,由于y2的label是1,相当于预测值不准确,这时候gan网络的损失函数较大,generator会通过更新generator的权重来降低损失函数。

之后,重新回到第一阶段训练discriminator,然后第二阶段训练generator。假设整个GAN网络达到理想状态,这时候generator产生的假图片,经过discriminator之后,预测值应该是0.5。假如这个值小于0.5,证明generator不是特别准确,在第二阶段训练过程中,generator的权重会被继续更新。假如这个值大于0.5,证明discriminator不是特别准确,在第一阶段训练中,discriminator的权征会被继续更新。

简单说,对于一张generator生成的假图片,discriminator会尽量把预测值拉下拉,generator会尽量把预测值往上扯,类似一个拔河的过程,最后达到均衡状态,例如0.6, 0.4, 0.55, 0.45, 0.51, 0.49, 0.50。

原文地址:https://www.cnblogs.com/yaos/p/14014139.html