tensorflow学习014——tf.data运用实例

3.2tf.data运用实例

使用tf.data作为输入,改写之前写过的MNIST代码

点击查看代码
import tensorflow as tf
#下载数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
#对图片数据进行归一化
train_images = train_images / 255
test_images = test_images / 255

ds_train_images = tf.data.Dataset.from_tensor_slices(train_images)
ds_train_labels = tf.data.Dataset.from_tensor_slices(train_labels)
#zip到一起,为了后面的shuffle,否则image与label的会对应错误
ds_train = tf.data.Dataset.zip((ds_train_images,ds_train_labels))

ds_train  = ds_train.shuffle(10000).repeat().batch(4)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28)),
    tf.keras.layers.Dense(128,activation='relu'),
    tf.keras.layers.Dense(10,activation= 'softmax')
])
model.compile(optimizer = 'adam',
              loss= 'sparse_categorical_crossentropy',
              metrics = ['accuracy'])
ds_test = tf.data.Dataset.from_tensor_slices((test_images,test_labels))
ds_test = ds_test.batch(4)
steps_per_epoch = train_images.shape[0] / 4 #表明每轮训练多少步,这是因为上面对dataser进行了repeat()所以需要指定每一轮训练多少步
model.fit(ds_train,epochs=10,steps_per_epoch=steps_per_epoch,validation_data=ds_test) 


作者:孙建钊
出处:http://www.cnblogs.com/sunjianzhao/
本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。

原文地址:https://www.cnblogs.com/sunjianzhao/p/15581513.html