Come leggere i dati MNIST in C++?

Ho difficoltà a leggere in C++ il database MNIST di cifre scritte a mano.

I dati sono in formato binario, che so leggere, ma non conosco il formato esatto dei file MNIST.

Vorrei quindi chiedere a chi ha già letto i dati MNIST: com'è strutturato il formato di questi file, e avete qualche suggerimento su come leggerli in C++?

 

5 risposte
  1. 12

    Di recente ho lavorato con i dati MNIST. Ecco il codice che ho scritto in Java; dovrebbe essere abbastanza facile portarlo in C++:

    import net.vivin.digit.DigitImage;    
    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.nio.ByteBuffer;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    
    /**
     * Loads MNIST handwritten-digit images and their labels from a pair of IDX
     * classpath resources (format described at http://yann.lecun.com/exdb/mnist/).
     * Every header field is a big-endian 32-bit integer, which matches
     * ByteBuffer's default byte order.
     *
     * Created by IntelliJ IDEA.
     * User: vivin
     * Date: 11/11/11
     * Time: 10:07 AM
     */
    public class DigitImageLoadingService {

        private String labelFileName;
        private String imageFileName;

        /** the following constants are defined as per the values described at http://yann.lecun.com/exdb/mnist/**/

        private static final int MAGIC_OFFSET = 0;
        private static final int OFFSET_SIZE = 4; //in bytes

        private static final int LABEL_MAGIC = 2049;
        private static final int IMAGE_MAGIC = 2051;

        private static final int NUMBER_ITEMS_OFFSET = 4;
        private static final int ITEMS_SIZE = 4;

        private static final int NUMBER_OF_ROWS_OFFSET = 8;
        private static final int ROWS_SIZE = 4;
        public static final int ROWS = 28;

        private static final int NUMBER_OF_COLUMNS_OFFSET = 12;
        private static final int COLUMNS_SIZE = 4;
        public static final int COLUMNS = 28;

        private static final int IMAGE_OFFSET = 16;
        private static final int IMAGE_SIZE = ROWS * COLUMNS;

        /** Label bytes start right after the label file's magic number and item count. */
        private static final int LABEL_OFFSET = OFFSET_SIZE + ITEMS_SIZE;


        public DigitImageLoadingService(String labelFileName, String imageFileName) {
            this.labelFileName = labelFileName;
            this.imageFileName = imageFileName;
        }

        /**
         * Reads both resources and pairs the i-th label with the i-th image.
         *
         * @return one DigitImage per label/image pair, in file order
         * @throws IOException if a resource cannot be found, a magic number is wrong,
         *         the label and image counts differ, the images are not ROWS x COLUMNS,
         *         or a file is shorter than its header declares
         */
        public List<DigitImage> loadDigitImages() throws IOException {
            byte[] labelBytes = readResource(labelFileName);
            byte[] imageBytes = readResource(imageFileName);

            if (labelBytes.length < LABEL_OFFSET) {
                throw new IOException("Label file is too short to contain a header!");
            }
            if (imageBytes.length < IMAGE_OFFSET) {
                throw new IOException("Image file is too short to contain a header!");
            }

            ByteBuffer labelHeader = ByteBuffer.wrap(labelBytes);
            ByteBuffer imageHeader = ByteBuffer.wrap(imageBytes);

            if (labelHeader.getInt(MAGIC_OFFSET) != LABEL_MAGIC) {
                throw new IOException("Bad magic number in label file!");
            }

            if (imageHeader.getInt(MAGIC_OFFSET) != IMAGE_MAGIC) {
                throw new IOException("Bad magic number in image file!");
            }

            int numberOfLabels = labelHeader.getInt(NUMBER_ITEMS_OFFSET);
            int numberOfImages = imageHeader.getInt(NUMBER_ITEMS_OFFSET);

            if (numberOfImages != numberOfLabels) {
                throw new IOException("The number of labels and images do not match!");
            }

            int numRows = imageHeader.getInt(NUMBER_OF_ROWS_OFFSET);
            int numCols = imageHeader.getInt(NUMBER_OF_COLUMNS_OFFSET);

            // Both dimensions must match (the former check compared numRows twice
            // and never looked at numCols).
            if (numRows != ROWS || numCols != COLUMNS) {
                throw new IOException("Bad image. Rows and columns do not equal " + ROWS + "x" + COLUMNS);
            }

            // long arithmetic so a corrupt, huge count cannot overflow int.
            if (numberOfLabels < 0
                    || labelBytes.length < (long) LABEL_OFFSET + numberOfLabels
                    || imageBytes.length < (long) IMAGE_OFFSET + (long) numberOfImages * IMAGE_SIZE) {
                throw new IOException("Label or image file is truncated!");
            }

            List<DigitImage> images = new ArrayList<DigitImage>(numberOfLabels);
            for (int i = 0; i < numberOfLabels; i++) {
                // Labels are stored as unsigned bytes; mask to avoid sign extension.
                int label = labelBytes[LABEL_OFFSET + i] & 0xFF;
                int start = IMAGE_OFFSET + i * IMAGE_SIZE;
                byte[] imageData = Arrays.copyOfRange(imageBytes, start, start + IMAGE_SIZE);

                images.add(new DigitImage(label, imageData));
            }

            return images;
        }

        /**
         * Reads a classpath resource fully into memory and always closes it.
         *
         * @param name resource name, resolved relative to this class
         * @return the resource's bytes
         * @throws IOException if the resource does not exist or cannot be read
         */
        private byte[] readResource(String name) throws IOException {
            InputStream in = this.getClass().getResourceAsStream(name);
            if (in == null) {
                throw new IOException("Resource not found: " + name);
            }
            try {
                ByteArrayOutputStream out = new ByteArrayOutputStream();
                byte[] buffer = new byte[16384];
                int read;
                while ((read = in.read(buffer, 0, buffer.length)) != -1) {
                    out.write(buffer, 0, read);
                }
                return out.toByteArray();
            } finally {
                in.close();
            }
        }
    }
    • Dov'è definita DigitImage? E la condizione numRows != COLUMNS non mi sembra corretta: dovrebbe controllare numCols…
  2. 22
    /// Reverses the byte order of a 32-bit int (0xAABBCCDD -> 0xDDCCBBAA).
    ///
    /// MNIST/IDX header fields are big-endian, so after reading one raw into an
    /// int on a little-endian host this yields its numeric value.
    /// The shuffling is done on an unsigned copy: shifting a byte >= 0x80 into
    /// the top position of a signed int overflows (undefined behaviour before C++20).
    /// Assumes int is 32 bits wide, as the original did.
    int reverseInt (int i)
    {
        const unsigned int u = static_cast<unsigned int>(i);
        const unsigned int swapped = ((u & 0x000000FFu) << 24) |
                                     ((u & 0x0000FF00u) << 8) |
                                     ((u & 0x00FF0000u) >> 8) |
                                     ((u >> 24) & 0x000000FFu);
        // Modular conversion back to int: well-defined in C++20, and what every
        // mainstream compiler does before that.
        return static_cast<int>(swapped);
    }
    #include <fstream>
    #include <stdexcept>
    #include <string>
    #include <utility>
    #include <vector>

    /// Reads an uncompressed MNIST/IDX3 image file (e.g. t10k-images-idx3-ubyte).
    ///
    /// Layout: four big-endian 32-bit integers (magic 2051, image count, rows,
    /// cols) followed by count * rows * cols pixel bytes, one image after another.
    ///
    /// @param full_path path to the file. The ".gz" archives from the MNIST site
    ///        must be decompressed first; a gzipped file starts with 0x1f8b...
    ///        and is rejected by the magic-number check.
    /// @return one buffer of rows * cols pixels per image, in file order.
    ///         Empty if the file cannot be opened (same silent guard as before).
    /// @throws std::runtime_error on a wrong magic number, a zero dimension,
    ///         or truncated header/pixel data.
    std::vector<std::vector<unsigned char>> read_mnist(const std::string& full_path = "t10k-images-idx3-ubyte")
    {
        std::vector<std::vector<unsigned char>> images;

        // ios::binary is required: text mode may translate bytes on some platforms.
        std::ifstream file(full_path, std::ios::binary);
        if (!file.is_open())
        {
            return images;
        }

        // Decode big-endian fields byte by byte so the result does not depend
        // on the host's byte order.
        auto read_be32 = [&file, &full_path]() -> unsigned long
        {
            unsigned char b[4] = {0, 0, 0, 0};
            if (!file.read(reinterpret_cast<char*>(b), sizeof b))
            {
                throw std::runtime_error("Truncated MNIST header in `" + full_path + "`");
            }
            return (static_cast<unsigned long>(b[0]) << 24) |
                   (static_cast<unsigned long>(b[1]) << 16) |
                   (static_cast<unsigned long>(b[2]) << 8) |
                   static_cast<unsigned long>(b[3]);
        };

        const unsigned long magic_number = read_be32();
        if (magic_number != 2051)
        {
            throw std::runtime_error("Not an uncompressed MNIST image file (bad magic number): `" + full_path + "`");
        }
        const unsigned long number_of_images = read_be32();
        const unsigned long n_rows = read_be32();
        const unsigned long n_cols = read_be32();
        if (n_rows == 0 || n_cols == 0)
        {
            throw std::runtime_error("Invalid MNIST image dimensions in `" + full_path + "`");
        }
        const unsigned long image_size = n_rows * n_cols;

        for (unsigned long i = 0; i < number_of_images; ++i)
        {
            std::vector<unsigned char> pixels(image_size);
            if (!file.read(reinterpret_cast<char*>(pixels.data()), static_cast<std::streamsize>(image_size)))
            {
                throw std::runtime_error("Truncated MNIST image data in `" + full_path + "`");
            }
            images.push_back(std::move(pixels));
        }
        return images;
    }
    • Nota, per evitare ad altri di ripetere il mio stupido errore: anche se i nomi dei file in questa risposta hanno l'estensione “.gz”, i file vanno in realtà decompressi prima di essere usati. (Sul sito originale c'è un commento che potrebbe essere interpretato come se suggerisse il contrario.) Se i primi quattro byte del file sono 0x1f8b0808 invece di 0x00000801 o 0x00000803, significa che il file non è stato decompresso.
  3. 3

    Per quel che vale, ho perfezionato il codice di @mrgloom:

    Per leggere le immagini del dataset:

    #include <fstream>
    #include <limits>
    #include <stdexcept>
    #include <string>

    /// Loads every image of an uncompressed MNIST/IDX3 image file.
    ///
    /// Layout: big-endian int32 magic (2051), image count, rows, cols, then
    /// count * rows * cols pixel bytes, one image after another.
    ///
    /// @param full_path        path to the decompressed file
    /// @param number_of_images set to the image count on success
    /// @param image_size       set to rows * cols on success
    /// @return array of number_of_images heap buffers of image_size bytes each.
    ///         The caller owns it and must `delete[]` every row, then the outer array.
    /// @throws std::runtime_error if the file cannot be opened, the magic number is
    ///         wrong, a header field is invalid, or the data is truncated. Nothing
    ///         is leaked when an exception is thrown.
    unsigned char** read_mnist_images(std::string full_path, int& number_of_images, int& image_size) {
        // Binary mode is essential: text mode can alter bytes on some platforms.
        std::ifstream file(full_path, std::ios::binary);
        if (!file.is_open()) {
            throw std::runtime_error("Cannot open file `" + full_path + "`!");
        }

        // Decodes one big-endian 32-bit header field independently of host byte
        // order, rejecting values that do not fit in an int.
        auto read_be32 = [&file, &full_path]() -> int {
            unsigned char b[4];
            if (!file.read(reinterpret_cast<char*>(b), sizeof b)) {
                throw std::runtime_error("Truncated MNIST header in `" + full_path + "`!");
            }
            const unsigned long value = (static_cast<unsigned long>(b[0]) << 24) |
                                        (static_cast<unsigned long>(b[1]) << 16) |
                                        (static_cast<unsigned long>(b[2]) << 8) |
                                        static_cast<unsigned long>(b[3]);
            if (value > static_cast<unsigned long>(std::numeric_limits<int>::max())) {
                throw std::runtime_error("Invalid MNIST header value in `" + full_path + "`!");
            }
            return static_cast<int>(value);
        };

        if (read_be32() != 2051) throw std::runtime_error("Invalid MNIST image file!");

        const int count = read_be32();
        const int n_rows = read_be32();
        const int n_cols = read_be32();
        if (n_rows <= 0 || n_cols <= 0 || n_rows > std::numeric_limits<int>::max() / n_cols) {
            throw std::runtime_error("Invalid MNIST image dimensions in `" + full_path + "`!");
        }
        const int size = n_rows * n_cols;

        // Value-initialised to nullptr so the cleanup path can delete[] every slot.
        unsigned char** dataset = new unsigned char*[count]();
        try {
            for (int i = 0; i < count; ++i) {
                dataset[i] = new unsigned char[size];
                if (!file.read(reinterpret_cast<char*>(dataset[i]), size)) {
                    throw std::runtime_error("Truncated MNIST image data in `" + full_path + "`!");
                }
            }
        } catch (...) {
            for (int i = 0; i < count; ++i) {
                delete[] dataset[i];
            }
            delete[] dataset;
            throw;
        }

        number_of_images = count;
        image_size = size;
        return dataset;
    }

    Per leggere le etichette del dataset:

    #include <fstream>
    #include <limits>
    #include <stdexcept>
    #include <string>

    /// Loads every label of an uncompressed MNIST/IDX1 label file.
    ///
    /// Layout: big-endian int32 magic (2049) and label count, then one byte per label.
    ///
    /// @param full_path        path to the decompressed file
    /// @param number_of_labels set to the label count on success
    /// @return heap array of number_of_labels bytes; the caller must `delete[]` it.
    /// @throws std::runtime_error if the file cannot be opened, the magic number is
    ///         wrong, a header field is invalid, or the label data is truncated.
    unsigned char* read_mnist_labels(std::string full_path, int& number_of_labels) {
        // Binary mode is essential: text mode can alter bytes on some platforms.
        std::ifstream file(full_path, std::ios::binary);
        if (!file.is_open()) {
            throw std::runtime_error("Unable to open file `" + full_path + "`!");
        }

        // Decodes one big-endian 32-bit header field independently of host byte
        // order, rejecting values that do not fit in an int.
        auto read_be32 = [&file, &full_path]() -> int {
            unsigned char b[4];
            if (!file.read(reinterpret_cast<char*>(b), sizeof b)) {
                throw std::runtime_error("Truncated MNIST header in `" + full_path + "`!");
            }
            const unsigned long value = (static_cast<unsigned long>(b[0]) << 24) |
                                        (static_cast<unsigned long>(b[1]) << 16) |
                                        (static_cast<unsigned long>(b[2]) << 8) |
                                        static_cast<unsigned long>(b[3]);
            if (value > static_cast<unsigned long>(std::numeric_limits<int>::max())) {
                throw std::runtime_error("Invalid MNIST header value in `" + full_path + "`!");
            }
            return static_cast<int>(value);
        };

        if (read_be32() != 2049) throw std::runtime_error("Invalid MNIST label file!");

        const int count = read_be32();

        unsigned char* dataset = new unsigned char[count];
        // One bulk read instead of one read() call per label.
        if (!file.read(reinterpret_cast<char*>(dataset), count)) {
            delete[] dataset;
            throw std::runtime_error("Truncated MNIST label data in `" + full_path + "`!");
        }

        number_of_labels = count;
        return dataset;
    }

    MODIFICA:
    Grazie a @Jürgen Brauer per avermi ricordato di correggere la risposta: avevo risolto il problema nel mio codice già da tempo, ma mi ero dimenticato di aggiornare la risposta.

    • Ho stampato i valori di una singola cifra 28×28 ed erano quasi tutti 0. Ho scoperto che nel codice sopra è importante scrivere ifstream file(full_path, ios::binary); e non ifstream file(full_path);. Spero che questo eviti ad altri di perdere tempo.
  4. 0

    Il codice qui sotto viene da Caffe; ho apportato qualche modifica e aggiunto la conversione in cv::Mat:

    /// Reverses the byte order of a 32-bit word (0xAABBCCDD -> 0xDDCCBBAA).
    /// Used below to turn the big-endian IDX header fields into host values
    /// on a little-endian machine.
    uint32_t swap_endian(uint32_t val) {
        const uint32_t lowest = val & 0xFFu;
        const uint32_t second = (val >> 8) & 0xFFu;
        const uint32_t third = (val >> 16) & 0xFFu;
        const uint32_t highest = val >> 24;
        return (lowest << 24) | (second << 16) | (third << 8) | highest;
    }
    
    void read_mnist_cv(const char* image_filename, const char* label_filename){
        //Open files
        std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
        std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
    
        //Read the magic and the meta data
        uint32_t magic;
        uint32_t num_items;
        uint32_t num_labels;
        uint32_t rows;
        uint32_t cols;
    
        image_file.read(reinterpret_cast<char*>(&magic), 4);
        magic = swap_endian(magic);
        if(magic != 2051){
            cout<<"Incorrect image file magic: "<<magic<<endl;
            return;
        }
    
        label_file.read(reinterpret_cast<char*>(&magic), 4);
        magic = swap_endian(magic);
        if(magic != 2049){
            cout<<"Incorrect image file magic: "<<magic<<endl;
            return;
        }
    
        image_file.read(reinterpret_cast<char*>(&num_items), 4);
        num_items = swap_endian(num_items);
        label_file.read(reinterpret_cast<char*>(&num_labels), 4);
        num_labels = swap_endian(num_labels);
        if(num_items != num_labels){
            cout<<"image file nums should equal to label num"<<endl;
            return;
        }
    
        image_file.read(reinterpret_cast<char*>(&rows), 4);
        rows = swap_endian(rows);
        image_file.read(reinterpret_cast<char*>(&cols), 4);
        cols = swap_endian(cols);
    
        cout<<"image and label num is: "<<num_items<<endl;
        cout<<"image rows: "<<rows<<", cols: "<<cols<<endl;
    
        char label;
        char* pixels = new char[rows * cols];
    
        for (int item_id = 0; item_id < num_items; ++item_id) {
            //read image pixel
            image_file.read(pixels, rows * cols);
            //read label
            label_file.read(&label, 1);
    
            string sLabel = std::to_string(int(label));
            cout<<"lable is: "<<sLabel<<endl;
            //convert it to cv Mat, and show it
            cv::Mat image_tmp(rows,cols,CV_8UC1,pixels);
            //resize bigger for showing
            cv::resize(image_tmp, image_tmp, cv::Size(100, 100));
            cv::imshow(sLabel, image_tmp);
            cv::waitKey(0);
        }
    
        delete[] pixels;
    }

    Utilizzo (ho semplificato il codice, omettendo gli header e lo spazio dei nomi):

    string base_dir = "/home/xy/caffe-master/data/mnist/";
    string img_path = base_dir + "train-images-idx3-ubyte";
    string label_path = base_dir + "train-labels-idx1-ubyte";
    
    read_mnist_cv(img_path.c_str(), label_path.c_str());

    L'output è il seguente:

    [Immagine dell'output non disponibile]

  5. 0

    Usando in() è possibile leggere qualsiasi campo dei dati, di qualunque dimensione in byte.

    // Capacity of the fixed-size buffers below (presumably sized for the
    // 60000-image MNIST training set plus a little slack).
    const int MAXN = 60000 + 7;
    // Pixel values, one unsigned int per pixel: image[index][row][col], up to 30x30.
    unsigned int image[MAXN][30][30];
    // Header fields of the most recently read IDX file.
    unsigned int num;
    unsigned int magic;
    unsigned int rows;
    unsigned int cols;
    // One label per image, indexed like image[].
    unsigned int label[MAXN];
    /// Reads `size` bytes from `icin` and combines them as a big-endian unsigned
    /// integer (the byte order of IDX/MNIST header fields).
    ///
    /// @param icin stream to read from, expected to be opened in binary mode
    /// @param size number of bytes to read (e.g. 4 for a header field, 1 for a
    ///        pixel or label); beyond sizeof(unsigned int) only the low-order
    ///        bytes survive
    /// @return the decoded value. A byte that cannot be read (EOF/error) counts
    ///         as 0 instead of reading an uninitialized variable; the stream's
    ///         failbit is left set so callers can detect truncation.
    unsigned int in(std::ifstream& icin, unsigned int size) {
        unsigned int ans = 0;
        for (unsigned int i = 0; i < size; ++i) {
            char byte = 0;  // get() leaves it unchanged (0) on failure
            icin.get(byte);
            ans = (ans << 8) | static_cast<unsigned char>(byte);
        }
        return ans;
    }
    void input() {
        ifstream icin;
        icin.open("train-images.idx3-ubyte", ios::binary);
        magic = in(icin, 4), num = in(icin, 4), rows = in(icin, 4), cols = in(icin, 4);
        for (int i = 0; i < num; i++) {
            for (int x = 0; x < rows; x++) {
                for (int y = 0; y < cols; y++) {
                    image[i][x][y] = in(icin, 1);
                }
            }
        }
        icin.close();
        icin.open("train-labels.idx1-ubyte", ios::binary);
        magic = in(icin, 4), num = in(icin, 4);
        for (int i = 0; i < num; i++) {
            label[i] = in(icin, 1);
        }
    }

Lascia un commento