AAE对抗自编码器

译自:https://hjweide.github.io/adversarial-autoencoders

1.自编码器AE作为生成模型

我们已经简要提到过,编码器输出的属性使我们能够将输入数据转换为有用的表示形式。在使用变分自动编码器的情况下,解码器已受过训练,可以从类似于我们选择的先验样本的样本中重建输入。因此,我们可以从此先验分布中采样数据点,并将其馈送到解码器中,以在原始数据空间中重建逼真的外观数据点。

不幸的是,变分自动编码器通常会在先验分布的空间中留下一些区域,这些区域不会映射到数据中的实际样本。对抗性自动编码器旨在通过鼓励编码器的输出完全填充先验分布的空间来改善此情况,从而使解码器能够从先验采样的任何数据点生成逼真的样本。对抗性自动编码器通过使用两个新组件,即鉴别器和生成器,来代替使用变分推理。接下来讨论这些。

2.训练更新过程

https://zhuanlan.zhihu.com/p/68903857

对抗自编码器的网络结构主要分成两大部分:自编码部分(上半部分)、GAN判别网络(下半部分)。整个框架也就是GAN和AutoEncoder框架二者的结合。训练过程分成两个阶段:首先是样本重构阶段,通过梯度下降更新自编码器encoder部分、以及decoder的参数、使得重构损失函数最小化;然后是正则化约束阶段,交替更新判别网络参数和生成网络(encoder部分)参数以此提高encoder部分混淆判别网络的能力。

下面这张图片似乎更加清晰:

图片来自:https://towardsdatascience.com/a-wizards-guide-to-adversarial-autoencoders-part-2-exploring-latent-space-with-adversarial-2d53a6f8a4f9

上面链接中比较清楚地讲解了AAE两阶段的训练过程: 

首先是重建的部分,这部分和正常的AE没有什么差别。下面是正则化部分:

首先,我们训练判别器对编码器输出(z)和一些随机输入(z’,目标分布)进行分类。 例如,随机输入可以正态分布,平均值为0,标准差为5。

因此,如果我们传入具有所需分布(真实值)的随机输入,则鉴别器应给我们输出1;而当我们传入编码器输出时,鉴别器应给我们输出0(伪值)。 直观地,编码器的输出和判别器的随机输入都应具有相同的size。
下一步将是强制编码器输出具有所需分布的隐码。 为此,我们将编码器输出作为输入连接到鉴别器:

 在AAE中编码部分相当于GAN的生成器。

我们将判别器的权重固定为当前的权重(使它们无法训练),并在判别器的输出端将目标固定为1。 稍后,我们将图像传递到编码器,并确定判别器输出,然后将其用于查找损失(交叉熵代价函数)。我们将仅通过编码器权重进行反向传播,这会导致编码器学习所需的分布并产生具有该分布的输出(将判别器目标固定为1会使得编码器通过查看判别器权重来学习所需的分布 )。

4.GAN与VAE的区别

转自:https://www.zhihu.com/question/317623081

一个本质区别就是loss的区别

VAE是pointwise loss 点匹配,一个典型的特征就是pointwise loss常常会脱离数据流形面,因此看起来生成的图片会模糊;

GAN是分布匹配的loss,更能贴近流行面,看起来就会清晰;

但分布匹配的难度较大,一个例子就是经常发生mode collapse问题,小分布丢失,而pointwise loss就没有这个问题,可以用于做初始化或做纠正,因此发展了一系列GAN+VAE的工作。

VAE希望通过一种显式(explicit)的方法找到一个概率密度,并通过最小化对数似函数的下限来得到最优解;GAN则是对抗的方式来寻找一种平衡,不需要认为给定一个显式的概率密度函数。

5.AAE的训练例子

https://blog.paperspace.com/adversarial-autoencoders-with-pytorch/

首先定义模型和各部分优化器:

torch.manual_seed(10)
Q, P = Q_net() = Q_net(), P_net(0)     # Encoder&Decoder
D_gauss = D_net_gauss()                # Discriminator adversarial
# Set optimizators
#AE部分:
P_decoder = optim.Adam(P.parameters(), lr=gen_lr)
Q_encoder = optim.Adam(Q.parameters(), lr=gen_lr)

#GAN部分:
Q_generator = optim.Adam(Q.parameters(), lr=reg_lr)#为Encoder又定义了一个优化器,作为生成器的更新
D_gauss_solver = optim.Adam(D_gauss.parameters(), lr=reg_lr)

包括三个部分的优化:

AE部分编码和解码器的优化:

    z_sample = Q(X)
    X_sample = P(z_sample)
    recon_loss = F.binary_cross_entropy(X_sample + TINY, 
                                        X.resize(train_batch_size, X_dim) + TINY)
    recon_loss.backward()
    P_decoder.step()//Enocder
    Q_encoder.step()//Decoder

 判别器D的优化:

    # Compute discriminator outputs and loss
    D_real_gauss, D_fake_gauss = D_gauss(z_real_gauss), D_gauss(z_fake_gauss)
    D_loss_gauss = -torch.mean(torch.log(D_real_gauss + TINY) + torch.log(1 - D_fake_gauss + TINY))
    D_loss.backward()       # Backpropagate loss
    D_gauss_solver.step()   # D判别器的优化

生成器(即Encoder的优化):

# Generator
Q.train()   # Back to use dropout
z_fake_gauss = Q(X)
D_fake_gauss = D_gauss(z_fake_gauss)

G_loss = -torch.mean(torch.log(D_fake_gauss + TINY))
G_loss.backward()
Q_generator.step()#优化Ecnoder,即生成器

 6.AAE示例代码

https://github.com/bfarzin/pytorch_aae/blob/master/main_aae.py#L114

https://blog.paperspace.com/adversarial-autoencoders-with-pytorch/

https://github.com/shidilrzf/Adversarial-Autoencoders/blob/master/train.py#L147

原文地址:https://www.cnblogs.com/BlueBlueSea/p/13149398.html