学习进度笔记32

观看TensorFlow案例实战视频课程22：迭代及测试网络效果

# Generate one training batch.
def get_next_batch(batch_size=128):
    """Build one batch of flattened grayscale captcha images and label vectors.

    Args:
        batch_size: Number of captcha samples to generate.

    Returns:
        A ``(batch_x, batch_y)`` tuple of NumPy arrays. ``batch_x`` has shape
        ``[batch_size, IMAGE_HEIGHT*IMAGE_WIDTH]``; row ``i`` is
        ``convert2gray(image).flatten()/255``. ``batch_y`` has shape
        ``[batch_size, MAX_CAPTCHA*CHAR_SET_LEN]``; row ``i`` is
        ``text2vec(text)`` for the same sample.
    """
    batch_x=np.zeros([batch_size,IMAGE_HEIGHT*IMAGE_WIDTH])
    batch_y=np.zeros([batch_size,MAX_CAPTCHA*CHAR_SET_LEN])

    # The generator sometimes produces an image of an unexpected size, so keep
    # drawing until the shape matches what a row of batch_x can hold.
    expected_shape=(IMAGE_HEIGHT,IMAGE_WIDTH,3)

    def wrap_gen_captcha_text_and_image():
        while True:
            text,image=gen_captcha_text_and_image()
            if image.shape==expected_shape:
                return text,image

    for i in range(batch_size):
        text,image=wrap_gen_captcha_text_and_image()
        image=convert2gray(image)

        # Scale by 255; (image.flatten()-128)/128 would centre the mean at 0 instead.
        batch_x[i,:]=image.flatten()/255
        batch_y[i,:]=text2vec(text)

    # Return only after the loop: returning inside it would hand back a batch
    # in which every row except the first is still all zeros.
    return batch_x,batch_y

# Training
def train_crack_captcha_cnn():
    """Train the captcha CNN and save it once sampled accuracy exceeds 85%.

    Builds the loss, optimizer and accuracy ops on top of
    ``crack_captcha_cnn()``, then loops forever feeding batches of 64 samples.
    Every 100 steps a fresh batch of 100 samples is evaluated with dropout
    disabled; when its accuracy is above 0.85 the session is saved to
    ``./model/crack_captcha.model-<step>`` and training stops.

    Relies on the module-level placeholders ``X``, ``Y`` and ``keep_prob``.
    """
    output=crack_captcha_cnn()
    # TF 1.x rejects positional arguments for this op (it raises ValueError),
    # so labels/logits must be passed by keyword.
    # NOTE(review): each row of Y holds MAX_CAPTCHA one-hot groups, so it is not
    # a single probability distribution; a per-character softmax or a sigmoid
    # cross-entropy may be a better fit — confirm before changing.
    loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=Y,logits=output))
    optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss)

    # Per-character accuracy: compare the arg-max class of each character slot.
    predict=tf.reshape(output,[-1,MAX_CAPTCHA,CHAR_SET_LEN])
    max_idx_p=tf.argmax(predict,2)
    max_idx_l=tf.argmax(tf.reshape(Y,[-1,MAX_CAPTCHA,CHAR_SET_LEN]),2)
    correct_pred=tf.equal(max_idx_p,max_idx_l)
    accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))

    saver=tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        step=0
        while True:
            batch_x,batch_y=get_next_batch(64)
            _,loss_=sess.run([optimizer,loss],feed_dict={X:batch_x,Y:batch_y,keep_prob:0.75})
            print(step,loss_)

            # Evaluate accuracy once every 100 steps.
            if step%100==0:
                batch_x_test,batch_y_test=get_next_batch(100)
                acc=sess.run(accuracy,feed_dict={X:batch_x_test,Y:batch_y_test,keep_prob:1.})
                print(step,acc)
                # Once accuracy exceeds 85%, save the model and finish training.
                if acc>0.85:
                    saver.save(sess,"./model/crack_captcha.model",global_step=step)
                    break

            step+=1
def crack_captcha(captcha_image,model_path="./model/crack_captcha.model-1500"):
    """Predict the character class indices for one preprocessed captcha image.

    Args:
        captcha_image: A single image in the form expected by placeholder
            ``X`` (the caller passes a flattened grayscale image scaled by 255).
        model_path: Checkpoint prefix to restore. Training saves
            ``./model/crack_captcha.model-<step>`` at whichever step first
            passes the accuracy threshold, so pass the actual path when it is
            not step 1500. The default keeps the original hard-coded value.

    Returns:
        A list with one arg-max class index per character position.
    """
    output=crack_captcha_cnn()

    saver=tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess,model_path)

        # Split the flat logits into one group per character and take the
        # highest-scoring class in each group.
        predict=tf.argmax(tf.reshape(output,[-1,MAX_CAPTCHA,CHAR_SET_LEN]),2)
        # Dropout is disabled (keep_prob=1) for inference.
        text_list=sess.run(predict,feed_dict={X:[captcha_image],keep_prob:1})
        text=text_list[0].tolist()
        return text
if __name__=='__main__':
    # Mode switch: 0 trains the network, 1 runs a single prediction with a saved model.
    #train=0
    train = 1
    if train==0:
        number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        #alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't','u', 'v', 'w', 'x', 'y', 'z']
        #ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T','U', 'V', 'W', 'X', 'Y', 'Z']

        # Generate one sample to discover the captcha text length.
        text,image=gen_captcha_text_and_image()
        print("验证码图像channel:",image.shape)#(60,160,3)
        # Image size
        IMAGE_HEIGHT=60
        IMAGE_WIDTH=160
        MAX_CAPTCHA=len(text)
        print("验证码文本最长字符数",MAX_CAPTCHA)
        # Text-to-vector character set
        #char_set=number+alphabet+ALPHABET+['_']  # '_' pads captchas shorter than 4 characters
        char_set=number
        CHAR_SET_LEN=len(char_set)

        X=tf.placeholder(tf.float32,[None,IMAGE_HEIGHT*IMAGE_WIDTH])
        Y=tf.placeholder(tf.float32,[None,MAX_CAPTCHA*CHAR_SET_LEN])
        keep_prob=tf.placeholder(tf.float32)# dropout

        train_crack_captcha_cnn()
    elif train==1:
        number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        IMAGE_HEIGHT=60
        IMAGE_WIDTH=160
        char_set=number
        CHAR_SET_LEN=len(char_set)

        text,image=gen_captcha_text_and_image()

        # Show the generated captcha with its ground-truth text overlaid.
        f=plt.figure()
        ax=f.add_subplot(111)
        ax.text(0.1, 0.9, text, ha='center', va='center', transform=ax.transAxes)
        plt.imshow(image)

        plt.show()

        # Same preprocessing as the training batches: grayscale, flatten, scale by 255.
        MAX_CAPTCHA=len(text)
        image=convert2gray(image)
        image=image.flatten()/255

        X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH])
        Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN])
        keep_prob = tf.placeholder(tf.float32)  # dropout

        predict_text=crack_captcha(image)
        # "{}" placeholders are required for str.format to substitute the values.
        print("正确:{} 预测:{}".format(text,predict_text))
原文地址:https://www.cnblogs.com/zql-42/p/14632781.html