This repository was archived by the owner on Aug 1, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSGDLearner.cpp
More file actions
59 lines (52 loc) · 1.89 KB
/
SGDLearner.cpp
File metadata and controls
59 lines (52 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
//
// Created by Felix Zhang on 2022-12-28.
//
#include "SGDLearner.h"

#include <filesystem>
#include <fstream>
#include <iostream>
#include <vector>

#include "fmt/core.h"
// Construct a learner with the given learning rate and an all-zero weight
// vector. loadWeigths() may later overwrite the weights from a file.
SGDLearner::SGDLearner(float rate)
    : LEARNING_RATE(rate)
{
    weights.setZero();
}
void SGDLearner::saveWeigths(std::filesystem::path relFilepath)
{
    // Save trained weights to a file so that the model doesn't need to be
    // re-trained when testing. Input is a filepath relative to the source dir.
    // Format: every element of `weights`, converted to double and written as
    // raw bytes in the host's native byte order (the layout loadWeigths reads).
    const auto srcDir = getSrcDir(std::filesystem::current_path());
    const auto path = (srcDir / relFilepath).make_preferred();

    std::ofstream ostream(path, std::ios::binary);
    if (!ostream.is_open()) {
        std::cout << fmt::format("Error writing to \"{}\".\n", path.string());
        return;
    }

    for (const double value : weights) {
        ostream.write(reinterpret_cast<const char*>(&value), sizeof(value));
    }

    // Close explicitly so that a failure while flushing buffered data is
    // reflected in the stream state before we report success.
    ostream.close();
    if (!ostream) {
        std::cout << fmt::format("Error writing to \"{}\".\n", path.string());
        return;
    }
    std::cout << fmt::format("Weights successfully saved to \"{}\".\n", path.string());
}
void SGDLearner::loadWeigths(std::filesystem::path relFilepath)
{
auto srcDir = getSrcDir(std::filesystem::current_path());
auto path = (srcDir / relFilepath).make_preferred();
std::ifstream istream(path, std::ios::binary);
if (istream.is_open()) {
int i = 0;
double* inputBuffer = new double[IMAGE_SIZE * 10];
while (!istream.eof() && i < IMAGE_SIZE * 10) {
istream.read(reinterpret_cast<char*>(&(inputBuffer[i])), sizeof(double));
i++;
}
Eigen::Map<VecLab> readWeights(inputBuffer);
weights = readWeights;
std::cout << fmt::format("Weights successfully read from \"{}\".\n", path.string());
delete[] inputBuffer;
}
else {
std::cout << fmt::format("Error opening \"{}\".\n", path.string());
}
istream.close();
}