提取 MNIST 数据集图片

  • 这里有已经提取好的图片:百度云共享
  • 主要依赖于 OpenCV（用 cv::imwrite 将每张图片保存为文件）
  • 代码
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <map>
#include <memory>
#include <string>

#include <opencv2/opencv.hpp>

using namespace std;

// Prints "<expression text>: <value>" followed by a newline to std::cout.
// The argument is parenthesized so expressions such as `a & b` or `a ? b : c`
// are evaluated before being streamed, instead of being split by operator<<.
#define PRINT(x) (std::cout << #x << ": " << (x) << std::endl)

/// Decodes a 4-byte big-endian integer (first byte most significant), the
/// byte order of the MNIST IDX header fields read by main().
///
/// @param data pointer to at least 4 readable bytes.
/// @return the decoded value. Values >= 2^31 do not fit in int; IDX header
///         fields (magic numbers, counts, dimensions) are far below that.
inline int to_int (const void* data)
{
    const unsigned char* bytes = static_cast<const unsigned char*> (data);
    // Shift-and-or keeps the arithmetic exact and integral: no pow(), no
    // <cmath> dependency and no double -> int conversion.
    std::uint32_t result = 0;

    for (int i = 0; i < 4; i++)
    {
        result = (result << 8) | bytes[i];
    }

    return static_cast<int> (result);
}

int main()
{
    const string mnist_data_path (R"(F:libscaffedatamnist	10k-images-idx3-ubyte)");
    const string mnist_label_path (R"(F:libscaffedatamnist	10k-labels-idx1-ubyte)");
    const string dst_img_dir (R"(F:libscaffedatamnist	est_10000)");
    FILE * f_data_handle = fopen (mnist_data_path.c_str(), "rb");
    FILE * f_label_handle = fopen (mnist_label_path.c_str(), "rb");
    fseek (f_label_handle, 8, SEEK_SET);
    char buff[4];
    fread ( (void*) &buff, 4, 1, f_data_handle);
    int magic_num = to_int (buff);
    PRINT (magic_num);
    fread ( (void*) &buff, 4, 1, f_data_handle);
    int num_imgs = to_int (buff);
    PRINT (num_imgs);
    fread ( (void*) &buff, 4, 1, f_data_handle);
    int num_rows = to_int (buff);
    PRINT (num_rows);
    fread ( (void*) &buff, 4, 1, f_data_handle);
    int num_cols = to_int (buff);
    PRINT (num_cols);
    cv::Mat img_buf (num_rows, num_cols, CV_8U, cv::Scalar (0));
    assert (img_buf.isContinuous());
    unsigned char* pixel_ptr = img_buf.ptr<unsigned char> (0);
    int count_all = -1;
    map<unsigned char, int > count_one;
    
    for (int i = 0; i < num_imgs; i++)
    {
        if (i % 5000 == 0) { cout << i << endl; }
        
        unsigned char label = fgetc (f_label_handle);
        count_all++;
        count_one[label]++;
        
        for (int j = 0; j < num_rows * num_cols; j++)
        {
            pixel_ptr[j] = fgetc (f_data_handle);
        }
        
        cv::imwrite (dst_img_dir + "\" + (to_string (label) + "_" + to_string (count_one[label]) + "_" + to_string (count_all) + ".jpg"), img_buf);
    }
    
    count_all++;
    PRINT (count_all);
#ifdef _MSC_VER
    system ("pause");
#endif // _MSC_V
}
原文地址:https://www.cnblogs.com/jiahu-Blog/p/7883097.html