4-5 Cifar10数据集解析

import glob
import os

import numpy as np
import cv2

classification=[
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck']

def unpick(file):#这是cifar10官网提供的解压函数
    import pickle
    with open(file,'rb') as fo:
        dict=pickle.load(fo,encoding='bytes')
    return dict


folders='/home/ubuntu/WorkPlace/data_manager/data/cifar-10-batches-py'#cifar10源数据集
trfiles=glob.glob(folders+'/data_batch*')#获取训练样本的地址

data=[]
labels=[]
for file in trfiles:#各小包解压后数据存在data中,label存在labels中
    dt=unpick(file)
    data+=list(dt[b'data'])
    labels+=list(dt[b'labels'])
print(labels)

#讲数据转换为4维度的数据(也就是直观的图片),cifar中图片32*32
imgs=np.reshape(data,[-1,3,32,32])#-1代表自动获取data的数量
for i in range(imgs.shape[0]):#shape[0]代表图片总量
    im_data=imgs[i,...]
    im_data=np.transpose(im_data,[1,2,0])#维度转换应为opencv非通道优先顺序存储
    im_data=cv2.cvtColor(im_data,cv2.COLOR_RGB2BGR)#cv非RGB格式

    f='{}/{}'.format('data/image/train',classification[labels[i]])#即将储存图片,这里定义每个图片的存放地址

    if not os.path.exists(f):#判断路径是否存在
        os.mkdir(f)
    cv2.imwrite('{}/{}.jpg'.format(f,str(i)),im_data)
原文地址:https://www.cnblogs.com/thgpddl/p/12843536.html