GAN生成式对抗网络(二)——tensorflow代码示例

代码实现

当初学习时,主要学习的这个博客 https://xyang35.github.io/2017/08/22/GAN-1/ ,写的挺好的。

本文目的,用GAN实现最简单的例子,帮助认识GAN算法。

import numpy as np
from matplotlib import pyplot as plt
batch_size = 4

2. 真实数据集,我们要通过GAN学习这个数据集,然后生成和他分布规则一样的数据集

X = np.random.normal(size=(1000, 2))
A = np.array([[1, 2], [-0.1, 0.5]])
b = np.array([1, 2])
X = np.dot(X, A) + b

plt.scatter(X[:, 0], X[:, 1])
plt.show()


# 等会通过这个函数,不断从中取x值,取值数量为batch_size
def iterate_minibatch(x, batch_size, shuffle=True):
    indices = np.arange(x.shape[0])
    if shuffle:
        np.random.shuffle(indices)

    for i in range(0, x.shape[0], batch_size):
        yield x[indices[i:i + batch_size], :]

图片名称

3.封装GAN对象

包含生成器,判别器

class GAN(object):
    def __init__(self):
        #初始函数,在这里对初始化模型
    def netG(self, z):
        #生成器模型
    def netD(self, x, reuse=False):
        #判别器模型
    

4.生成器netG

随意输入的z,通过z*w+b的矩阵运算(全连接运算),返回结果

    def netG(self, z):
        """1-layer fully connected network"""

        with tf.variable_scope("generator") as scope:
            W = tf.get_variable(name="g_W", shape=[2, 2],
                                initializer=tf.contrib.layers.xavier_initializer(),
                                trainable=True)
            b = tf.get_variable(name="g_b", shape=[2],
                                initializer=tf.zeros_initializer(),
                                trainable=True)
            return tf.matmul(z, W) + b

5.判别器nefD

判别器为三层全连接网络。隐层部分使用tanh激活函数。输出部分没有激活函数

    def netD(self, x, reuse=False):
        """3-layer fully connected network"""

        with tf.variable_scope("discriminator") as scope:
            if reuse:
                scope.reuse_variables()

            W1 = tf.get_variable(name="d_W1", shape=[2, 5],
                                 initializer=tf.contrib.layers.xavier_initializer(),
                                 trainable=True)
            b1 = tf.get_variable(name="d_b1", shape=[5],
                                 initializer=tf.zeros_initializer(),
                                 trainable=True)
            W2 = tf.get_variable(name="d_W2", shape=[5, 3],
                                 initializer=tf.contrib.layers.xavier_initializer(),
                                 trainable=True)
            b2 = tf.get_variable(name="d_b2", shape=[3],
                                 initializer=tf.zeros_initializer(),
                                 trainable=True)
            W3 = tf.get_variable(name="d_W3", shape=[3, 1],
                                 initializer=tf.contrib.layers.xavier_initializer(),
                                 trainable=True)
            b3 = tf.get_variable(name="d_b3", shape=[1],
                                 initializer=tf.zeros_initializer(),
                                 trainable=True)

            layer1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
            layer2 = tf.nn.tanh(tf.matmul(layer1, W2) + b2)
            return tf.matmul(layer2, W3) + b3

6.初始化__init__函数

def __init__(self):
        # input, output
         #占位变量,等会用来保存随机产生的数,
        self.z = tf.placeholder(tf.float32, shape=[None, 2], name='z')   
        #占位变量,真实数据的
        self.x = tf.placeholder(tf.float32, shape=[None, 2], name='real_x')  

        # define the network
        #生成器,对随机变量进行加工,产生伪造的数据
        self.fake_x = self.netG(self.z)  

         #判别器对真实数据进行判别,返回判别结果
         #reuse=false,表示不是共享变量,需要tensorflow开辟变量地址
        self.real_logits = self.netD(self.x, reuse=False)  

        #判别器对伪造数据进行判别,返回判别结果
         #reuse=true,表示是共享变量,复用netD中已有的变量
        self.fake_logits = self.netD(self.fake_x, reuse=True)


        # define losses
        #判定器的损失值,将真实数据的判定为真实数据,将伪造数据的判断为伪造数据的得分情况
        self.loss_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_logits,
                                                                             labels=tf.ones_like(self.real_logits))) + 
                      tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                             labels=tf.zeros_like(self.real_logits)))
        #生成器的生成分数。伪造的数据,别判断器判定为真实数据的得分情况
        self.loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                             labels=tf.ones_like(self.real_logits)))

        # collect variables
        t_vars = tf.trainable_variables()
        #存放判别器中用到的变量
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        #存放生成器中用到的变量
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

7.开始训练

gan = GAN()

#使用随机梯度下降
d_optim = tf.train.AdamOptimizer(learning_rate=0.05).minimize(gan.loss_D, var_list=gan.d_vars)
g_optim = tf.train.AdamOptimizer(learning_rate=0.01).minimize(gan.loss_G, var_list=gan.g_vars)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    #将数据循环10次
    for epoch in range(10):
        avg_loss = 0.
        count = 0
        #从真实数据集当中,随机抓取batch_size数量个值
        for x_batch in iterate_minibatch(X, batch_size=batch_size):
            # generate noise z
            #随机变量,数量为batch_size
            z_batch = np.random.normal(size=(4, 2))

            # update D network
             #将拿到的真实数据值和随机生成的数值,喂养给sess,并bp优化一次
            loss_D, _ = sess.run([gan.loss_D, d_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     gan.x: x_batch,
                                 })

            # update G network
            loss_G, _ = sess.run([gan.loss_G, g_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     gan.x: np.zeros(z_batch.shape),  # dummy input
                                 })

            avg_loss += loss_D
            count += 1

        avg_loss /= count
        #每一个epoch都展示一次生成效果
        z = np.random.normal(size=(100, 2))
        # 随机生成100个数值,0到1000---用来从真实值里面取数据
        excerpt = np.random.randint(1000, size=1000)
        fake_x, real_logits, fake_logits = sess.run([gan.fake_x, gan.real_logits, gan.fake_logits],
                                                    feed_dict={gan.z: z, gan.x: X[excerpt, :]})
        accuracy = 0.5 * (np.sum(real_logits > 0.5) / 100. + np.sum(fake_logits < 0.5) / 100.)
        print('
discriminator loss at epoch %d: %f' % (epoch, avg_loss))
        print('
discriminator accuracy at epoch %d: %f' % (epoch, accuracy))
        plt.scatter(X[:, 0], X[:, 1])
        plt.scatter(fake_x[:, 0], fake_x[:, 1])
        plt.show()


效果

完整代码下载

欢迎转载,转载请注明出处。欢迎沟通交流: panfengqqs@qq.com)

原文地址:https://www.cnblogs.com/panfengde/p/10020224.html