转自:https://blog.csdn.net/weixin_41923961/article/details/80715321
1.条件GAN
生成器和判别器都增加额外信息y为条件, y可以是任意信息,例如类别信息,或者其他模态的数据。
通过将额外信息y输送给判别模型和生成模型,作为输入层的一部分,从而实现条件GAN。
在生成模型中,先验输入噪声p(z)和条件信息y联合组成了联合隐层表征。
条件GAN的目标函数是带有条件概率的二人极小极大值博弈(two-player minimax game ):
比如应用在条件生成MNIST手写数据集上。
2.cGAN中y的不同引入方式
- y直接和x拼接输入;
- 和x的中间层输入拼接;
- 额外计算一个生成损失;
- 和中间变量作为做内积和输出拼接在一起
https://blog.csdn.net/qq_34914551/article/details/90732874
判别器部分代码实现:
def forward(self, x, y=None): h = x h = self.block1(h) h = self.block2(h) h = self.block3(h) h = self.block4(h) h = self.block5(h) h = self.activation(h) # Global pooling h = torch.sum(h, dim=(2, 3)) output = self.l6(h) if y is not None: output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True) return output
y和倒数第二层做内积之后和输出拼接在一起,作为输出。