# 将mnist数据集存储到本地文件 (Save the MNIST dataset to local PNG files)

import os
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from matplotlib.image import imsave
import itertools

# Load MNIST; keras returns separate training and test splits.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
for caption, array in (("X_train original shape", X_train),
                       ("y_train original shape", y_train)):
    print(caption, array.shape)

# Draw the first nine training images on a 3x3 grid, each titled with its label.
# NOTE(review): nothing calls plt.show() or plt.savefig(), so when this runs as
# a plain script the preview figure is built but never displayed.
for cell in range(1, 10):
    sample = cell - 1
    plt.subplot(3, 3, cell)
    plt.imshow(X_train[sample], cmap='gray', interpolation='none')
    plt.title("Class {}".format(y_train[sample]))

train_path = './MNIST_data/train'
test_path = './MNIST_data/test'


def _save_split(images, labels, root, counter):
    """Write each image as a PNG at ``root/<label>/<n>.png``.

    images: sequence of 2-D grayscale arrays (the keras MNIST arrays here).
    labels: class labels parallel to ``images``; one subdirectory per label.
    root: directory for this split; it and any missing parents are created.
    counter: iterator yielding file numbers. Passing the same counter for
        every split keeps file names unique across train and test combined.
    """
    for image, label in zip(images, labels):
        dest_folder = os.path.join(root, str(label))
        # makedirs also creates the missing parents (./MNIST_data/<split>),
        # which os.mkdir could not; exist_ok makes re-runs harmless.
        os.makedirs(dest_folder, exist_ok=True)
        image_path = os.path.join(dest_folder, '{}.png'.format(next(counter)))
        # Pin the colour limits to the uint8 range (keras documents MNIST
        # pixels as 0-255); without vmin/vmax imsave stretches every image
        # to its own min/max and the saved intensities change.
        imsave(image_path, image, cmap='gray', vmin=0, vmax=255)


# One counter shared by both splits: training files are numbered 0..59999
# and test files continue from there, so no two files share a name.
image_counter = itertools.count(0)
_save_split(X_train, y_train, train_path, image_counter)
_save_split(X_test, y_test, test_path, image_counter)

  

# 原文地址 (original article): https://www.cnblogs.com/php-linux/p/11948928.html