第二十一节,条件变分自编码

一 条件变分自编码(CVAE)

变分自编码存在一个问题,虽然可以生成一个样本,但是只能输出与输入图片相同类别的样本。虽然也可以随机从符合模型生成的高斯分布中取数据来还原成样本,但是这样的话饿哦们并不知道生成的样本属于哪个类别。条件变分编码则可以解决这个问题,让网络按指定的类别生成样本。

在变分自编码的基础上,再取理解条件编码自编码会很容易。主要的改动是,在训练测试时加入一个one-hot向量,用于表示标签向量。其实就是给编码自编码网络加入一个条件,让网络学习图片时加入标签因素,这样就可以按照指定的标签生成图片。 

二 CVAE实例 

在编码节点需要在输入端添加标签对应的特征,在解码阶段同样也需要将标签加入输入,这样,再解码的结果向原始的输入样本不断逼近,最终得到的模型会把输入的标签的特征当成MNIST数据的一部分,从而实现通过标签生成指定的图片。

 该程序在上一节程序上作了一些修改,代码如下:

# -*- coding: utf-8 -*-
"""
Created on Thu May 31 15:34:08 2018

@author: zy
"""

'''
条件变分自编码
'''


import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data


mnist = input_data.read_data_sets('MNIST-data',one_hot=True)

print(type(mnist)) #<class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>

print('Training data shape:',mnist.train.images.shape)           #Training data shape: (55000, 784)
print('Test data shape:',mnist.test.images.shape)                #Test data shape: (10000, 784)
print('Validation data shape:',mnist.validation.images.shape)    #Validation data shape: (5000, 784)
print('Training label shape:',mnist.train.labels.shape)          #Training label shape: (55000, 10)

train_X = mnist.train.images
train_Y = mnist.train.labels
test_X = mnist.test.images
test_Y = mnist.test.labels


'''
定义网络参数
'''
n_input = 784
n_hidden_1 = 256
n_hidden_2 = 2
n_classes = 10
learning_rate = 0.001
training_epochs = 20               #迭代轮数
batch_size = 128                   #小批量数量大小
display_epoch = 3
show_num = 10

x = tf.placeholder(dtype=tf.float32,shape=[None,n_input])
y = tf.placeholder(dtype=tf.float32,shape=[None,n_classes])
#后面通过它输入分布数据,用来生成模拟样本数据
zinput = tf.placeholder(dtype=tf.float32,shape=[None,n_hidden_2])


'''
定义学习参数
'''
weights = {
        'w1':tf.Variable(tf.truncated_normal([n_input,n_hidden_1],stddev = 0.001)),
        'w_lab1':tf.Variable(tf.truncated_normal([n_classes,n_hidden_1],stddev = 0.001)),
        'mean_w1':tf.Variable(tf.truncated_normal([n_hidden_1*2,n_hidden_2],stddev = 0.001)),
        'log_sigma_w1':tf.Variable(tf.truncated_normal([n_hidden_1*2,n_hidden_2],stddev = 0.001)),
        'w2':tf.Variable(tf.truncated_normal([n_hidden_2+n_classes,n_hidden_1],stddev = 0.001)),
        'w3':tf.Variable(tf.truncated_normal([n_hidden_1,n_input],stddev = 0.001))
        }

biases = {
        'b1':tf.Variable(tf.zeros([n_hidden_1])),
        'b_lab1':tf.Variable(tf.zeros([n_hidden_1])),
        'mean_b1':tf.Variable(tf.zeros([n_hidden_2])),
        'log_sigma_b1':tf.Variable(tf.zeros([n_hidden_2])),
        'b2':tf.Variable(tf.zeros([n_hidden_1])),
        'b3':tf.Variable(tf.zeros([n_input]))
        }


'''
定义网络结构
'''
#第一个全连接层是由784个维度的输入样->256个维度的输出
h1 = tf.nn.relu(tf.add(tf.matmul(x,weights['w1']),biases['b1']))
#输入标签
h_lab1 = tf.nn.relu(tf.add(tf.matmul(y,weights['w_lab1']),biases['b_lab1']))
#合并
hall1 = tf.concat([h1,h_lab1],1)

#第二个全连接层并列了两个输出网络
z_mean = tf.add(tf.matmul(hall1,weights['mean_w1']),biases['mean_b1'])
z_log_sigma_sq = tf.add(tf.matmul(hall1,weights['log_sigma_w1']),biases['log_sigma_b1'])


#然后将两个输出通过一个公式的计算,输入到以一个2节点为开始的解码部分 高斯分布样本
eps = tf.random_normal(tf.stack([tf.shape(h1)[0],n_hidden_2]),0,1,dtype=tf.float32)
z = tf.add(z_mean,tf.multiply(tf.sqrt(tf.exp(z_log_sigma_sq)),eps))
#合并
zall = tf.concat([z,y],1)    #None x 12


#解码器 由12个维度的输入->256个维度的输出
h2 = tf.nn.relu(tf.matmul(zall,weights['w2']) + biases['b2'])
#解码器 由256个维度的输入->784个维度的输出  即还原成原始输入数据
reconstruction = tf.matmul(h2,weights['w3']) + biases['b3']


#这两个节点不属于训练中的结构,是为了生成指定数据时用的
zinputall = tf.concat([zinput,y],1)
h2out = tf.nn.relu(tf.matmul(zinputall,weights['w2']) + biases['b2'])
reconstructionout = tf.matmul(h2out,weights['w3']) + biases['b3']

'''
构建模型的反向传播
'''
#计算重建loss
#计算原始数据和重构数据之间的损失,这里除了使用平方差代价函数,也可以使用交叉熵代价函数  
reconstr_loss = 0.5*tf.reduce_sum((reconstruction-x)**2)
print(reconstr_loss.shape)    #(,) 标量
#使用KL离散度的公式
latent_loss = -0.5*tf.reduce_sum(1 + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq),1)
print(latent_loss.shape)      #(128,)
cost = tf.reduce_mean(reconstr_loss+latent_loss)


#定义优化器    
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

num_batch = int(np.ceil(mnist.train.num_examples / batch_size))

'''
开始训练
'''
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    print('开始训练')
    for epoch in range(training_epochs):
        total_cost = 0.0
        for i in range(num_batch):
            batch_x,batch_y = mnist.train.next_batch(batch_size)            
            _,loss = sess.run([optimizer,cost],feed_dict={x:batch_x,y:batch_y})
            total_cost += loss
            
        #打印信息
        if epoch % display_epoch == 0:
            print('Epoch {}/{}  average cost {:.9f}'.format(epoch+1,training_epochs,total_cost/num_batch))
                        
    print('训练完成')
    
    #测试
    print('Result:',cost.eval({x:mnist.test.images,y:mnist.test.labels}))
    #数据可视化   根据原始图片生成自编码数据                  
    reconstruction = sess.run(reconstruction,feed_dict = {x:mnist.test.images[:show_num],y:mnist.test.labels[:show_num]})
    plt.figure(figsize=(1.0*show_num,1*2))        
    for i in range(show_num):
        #原始图像
        plt.subplot(2,show_num,i+1)            
        plt.imshow(np.reshape(mnist.test.images[i],(28,28)),cmap='gray')   
        plt.axis('off')
           
        #变分自编码器重构图像
        plt.subplot(2,show_num,i+show_num+1)
        plt.imshow(np.reshape(reconstruction[i],(28,28)),cmap='gray')       
        plt.axis('off')
    plt.show()
    

        
    '''
    高斯分布取样,根据标签生成模拟数据
    '''        
    z_sample = np.random.randn(show_num,2)
    reconstruction = sess.run(reconstructionout,feed_dict={zinput:z_sample,y:mnist.test.labels[:show_num]})    
    plt.figure(figsize=(1.0*show_num,1*2))        
    for i in range(show_num):
        #原始图像
        plt.subplot(2,show_num,i+1)            
        plt.imshow(np.reshape(mnist.test.images[i],(28,28)),cmap='gray')   
        plt.axis('off')
           
        #根据标签成成模拟数据
        plt.subplot(2,show_num,i+show_num+1)
        plt.imshow(np.reshape(reconstruction[i],(28,28)),cmap='gray')       
        plt.axis('off')
    plt.show()
    

上面第一幅图是根据原始图片生成的自编码数据,第一行为原始数据,第二行为自编码数据,该数据仍然保留一些原始图片的特征。

第二幅图片是利用样本数据的标签和高斯分布之z_sample一起生成的模拟数据,我们可以看到通过标签生成的数据,已经彻底学会了样本数据的分布,并生成了与输入截然不同但带有相同意义的数据。

原文地址:https://www.cnblogs.com/zyly/p/9123443.html