TensorFlow 学习笔记 6

MNIST 数据集简介（二）

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import input_data

print("packs loaded")

print("Download and Extract MNIST dataset")
# one_hot=True encodes each label as a one-hot (0/1) vector rather than an
# integer class id.
mnist = input_data.read_data_sets('data/', one_hot=True)
# A bare `print` is a Python 2 statement; in Python 3 it is a no-op expression.
# Call it explicitly to emit the intended blank separator line.
print()
print("type of 'mnist' is %s" % (type(mnist)))
print("number of train data is %d" % (mnist.train.num_examples))
print("number of test data is %d" % (mnist.test.num_examples))

# Pull the image and label arrays out of the train and test splits.
trainimg, trainlabel = mnist.train.images, mnist.train.labels
testimg, testlabel = mnist.test.images, mnist.test.labels

# Take a first look at the dataset: display a few randomly chosen training
# images together with their labels.
nsample = 5
# Random row indices into the training set (drawn with replacement, so the
# same image may appear more than once).
randidx = np.random.randint(trainimg.shape[0], size=nsample)

for i in randidx:
    # Each row is a flattened image; reshape it back to 28x28 for display.
    curr_img = np.reshape(trainimg[i, :], (28, 28))
    # Labels are one-hot vectors, so argmax recovers the class index.
    curr_label = np.argmax(trainlabel[i, :])
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    # curr_label was previously computed but never used; show it with the image.
    plt.title("[%d] label is %d" % (i, curr_label))
    plt.show()

原文地址:https://www.cnblogs.com/xrj-/p/14456123.html