This repository was archived by the owner on Aug 1, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDataset.h
More file actions
94 lines (72 loc) · 2.68 KB
/
Dataset.h
File metadata and controls
94 lines (72 loc) · 2.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
//
// Created by Felix Zhang on 2022-12-28.
//
#ifndef MNIST_CLASSIFIER_DATASET_H
#define MNIST_CLASSIFIER_DATASET_H
#include "globals.h"
#include "fileHelpers.h"
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <vector>
template<int NUM_TRAIN, int NUM_TEST>
class Dataset {
private:
// only call on initialization
// This function returns a pointer to heap-allocated memory
template<int N>
DataPoint* readSet(const std::filesystem::path& xFile, const std::filesystem::path& yFile);
public:
DataPoint *trainSet, *testSet;
Dataset(const std::filesystem::path& trainImages, const std::filesystem::path& trainLabels, const std::filesystem::path& testImages,
const std::filesystem::path& testLabels) {
trainSet = readSet<NUM_TRAIN>(TRAIN_IMAGES, TRAIN_LABELS);
testSet = readSet<NUM_TEST>(TEST_IMAGES, TEST_LABELS);
}
~Dataset() {
// Free memory allocated with the "readSet" function
delete[] trainSet;
delete[] testSet;
}
};
// template functions MUST be defined in the header file
/**
 * Reads up to N samples from an image file and its label file into a newly
 * allocated array.
 *
 * Both paths are resolved against the directory returned by getSrcDir() for the
 * current working directory. The first 16 bytes of the image file and the first
 * 8 bytes of the label file are skipped as headers. Each sample is then
 * IMAGE_SIZE bytes of pixel data (interpreted as unsigned and widened to
 * double) plus one label byte (interpreted as unsigned).
 *
 * @tparam N    number of DataPoint slots to allocate; also the maximum read
 * @param xFile image file, relative to the source directory
 * @param yFile label file, relative to the source directory
 * @return      array of N DataPoints allocated with new[]; the caller owns it
 *              and must delete[] it. If a file cannot be opened or ends before
 *              N samples were read, an error is printed and the remaining
 *              slots are left default-constructed.
 */
template<int NUM_TRAIN, int NUM_TEST>
template<int N>
DataPoint* Dataset<NUM_TRAIN, NUM_TEST>::readSet(const std::filesystem::path& xFile, const std::filesystem::path& yFile) {
    static_assert(N >= 0, "the number of samples must be non-negative");

    // Retrieve the data set directory and build the full file paths.
    const auto srcDir = getSrcDir(std::filesystem::current_path());
    const std::filesystem::path xFileAbs = (srcDir / xFile).make_preferred();
    const std::filesystem::path yFileAbs = (srcDir / yFile).make_preferred();

    // The streams close themselves when they go out of scope.
    std::ifstream fDom(xFileAbs, std::ios::binary);
    std::ifstream fLab(yFileAbs, std::ios::binary);

    // Scratch buffers are allocated once and reused for every sample; the
    // DataPoint stores its own copy of the pixel values.
    std::vector<char> rawPixels(IMAGE_SIZE);
    std::vector<double> pixels(IMAGE_SIZE);

    // Allocated last so that nothing between here and the return can throw
    // and leak the array.
    auto* points = new DataPoint[N];

    bool opened = true;
    if (!fDom.is_open()) {
        std::cout << "Error opening " << xFileAbs << "!" << std::endl;
        opened = false;
    }
    if (!fLab.is_open()) {
        std::cout << "Error opening " << yFileAbs << "!" << std::endl;
        opened = false;
    }
    if (!opened) {
        // Do not fabricate samples from an unreadable file.
        return points;
    }

    // skip the header
    fDom.ignore(16);
    fLab.ignore(8);

    for (int i = 0; i < N; ++i) {
        char labelByte = 0;
        // Check the result of each read instead of testing eof() beforehand:
        // eof() only becomes true after a read has already failed.
        if (!fDom.read(rawPixels.data(), IMAGE_SIZE) || !fLab.get(labelByte)) {
            std::cout << "Error reading sample " << i << " from " << xFileAbs << " or " << yFileAbs << "!" << std::endl;
            break;
        }

        // Bytes are unsigned intensities; go through unsigned char so values
        // above 127 do not turn negative where char is signed.
        std::transform(rawPixels.begin(), rawPixels.end(), pixels.begin(),
                       [](const char c) { return static_cast<double>(static_cast<unsigned char>(c)); });

        Eigen::Map<VecDom> point(pixels.data());
        points[i] = DataPoint{point, static_cast<unsigned char>(labelByte)};
    }

    return points;
}
#endif //MNIST_CLASSIFIER_DATASET_H