-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnn_utils.h
More file actions
138 lines (120 loc) · 2.34 KB
/
nn_utils.h
File metadata and controls
138 lines (120 loc) · 2.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include <math.h>
#include <stdlib.h>
#include<string.h>
#ifndef NN_UTILS_H
#define NN_UTILS_H
/* Activation-function selector dispatched on by act() and dact(). */
typedef enum
{
    RELU = 0,   /* leaky ReLU: ReLu() / ReLu_derivative() */
    SIG = 1,    /* logistic sigmoid: sigmoidf() / sigmoidf_derivative() */
    LINEAR = 2  /* identity */
} AF;
/* Weight-initialisation scheme selector dispatched on by weights_init(). */
typedef enum
{
    RAND = 0, /* rand_float() */
    HE = 1,   /* he_init() */
    XG = 2    /* xavier_init() (Xavier/Glorot) */
} WI;
/* Negative-side slope of the leaky ReLU used by ReLu() and ReLu_derivative(). */
#define RELU_PARAM 0.001
/*
 * Number of elements in a true array object. Does NOT work on pointers or on
 * array function parameters (those have already decayed to pointers).
 * The whole expansion is parenthesised so that e.g. `n / ARRAY_LEN(a)`
 * divides by the element count rather than expanding to
 * `n / sizeof(a) / sizeof(a[0])`.
 */
#define ARRAY_LEN(x) (sizeof(x) / sizeof((x)[0]))
/*
 * Forward declarations for the helpers defined below.
 * rand_float takes no arguments: declared with (void) so it is a real
 * prototype and calls passing arguments are diagnosed.
 */
float rand_float(void);
float sigmoidf(float x);
float sigmoidf_derivative(float x);
float ReLu(float x);
float ReLu_derivative(float x);
float he_init(int input_size);
float xavier_init(int input_size, int output_size);
int isCSVFile(const char *filename);
/*
 * Returns a pseudo-random float in [0, 1] (both ends reachable), computed as
 * rand() / RAND_MAX. Seed with srand() for reproducible sequences.
 */
float rand_float(void)
{
    return (float)rand() / (float)RAND_MAX;
}
float sigmoidf(float x)
{
return 1.0 / (1.0 + expf(-x));
}
/*
 * Sigmoid derivative expressed in terms of the sigmoid's OUTPUT:
 * returns x * (1 - x), which equals sigma'(z) only when x == sigma(z).
 * NOTE(review): callers (e.g. dact) presumably pass the already-activated
 * value rather than the pre-activation z -- confirm at the call sites.
 */
float sigmoidf_derivative(float x)
{
    const float complement = 1 - x;
    return x * complement;
}
/*
 * Leaky ReLU: passes positive x through unchanged; otherwise scales x by
 * RELU_PARAM (the product is formed in double and narrowed on return).
 */
float ReLu(float x)
{
    if (x > 0)
    {
        return x;
    }
    return RELU_PARAM * x;
}
/*
 * Slope of the leaky ReLU: 1 for x > 0, RELU_PARAM otherwise
 * (including x == 0 and NaN).
 */
float ReLu_derivative(float x)
{
    if (!(x > 0))
    {
        return RELU_PARAM;
    }
    return 1.0;
}
/*
 * Applies the activation selected by af to x:
 *   RELU   -> ReLu(x)      (leaky ReLU)
 *   SIG    -> sigmoidf(x)
 *   LINEAR -> x
 * Any other af value is treated as LINEAR. Without that fallback, control
 * would reach the end of this non-void function, and using the result would
 * be undefined behaviour.
 */
float act(float x, AF af)
{
    switch (af)
    {
    case RELU:
        return ReLu(x);
    case SIG:
        return sigmoidf(x);
    case LINEAR:
    default:
        return x;
    }
}
/*
 * Derivative of the activation selected by af, evaluated at x:
 *   SIG    -> sigmoidf_derivative(x) == x * (1 - x); this matches sigma'
 *             only if x is the sigmoid OUTPUT (NOTE(review): confirm callers
 *             pass activated values)
 *   RELU   -> ReLu_derivative(x)
 *   LINEAR -> 1
 * Any other af value is treated as LINEAR, so a defined value is always
 * returned instead of falling off the end of the function.
 */
float dact(float x, AF af)
{
    switch (af)
    {
    case SIG:
        return sigmoidf_derivative(x);
    case RELU:
        return ReLu_derivative(x);
    case LINEAR:
    default:
        return 1.0f;
    }
}
float he_init(int input_size)
{
return (float)rand() / RAND_MAX * sqrt(2.0 / input_size);
}
float xavier_init(int input_size, int output_size)
{
return (float)rand() / RAND_MAX * sqrt(1.0 / (input_size + output_size));
}
/*
 * Draws one initial weight using the scheme selected by wi:
 *   RAND -> rand_float()                       (value in [0, 1])
 *   HE   -> he_init(input_size)
 *   XG   -> xavier_init(input_size, output_size)
 * output_size is only consulted by XG. Any other wi value falls back to RAND,
 * so a defined value is always returned instead of falling off the end of
 * the function.
 */
float weights_init(int input_size, int output_size, WI wi)
{
    switch (wi)
    {
    case HE:
        return he_init(input_size);
    case XG:
        return xavier_init(input_size, output_size);
    case RAND:
    default:
        return rand_float();
    }
}
/*
 * Returns 1 if filename ends with the exact, case-sensitive suffix ".csv",
 * and 0 otherwise. A bare ".csv" counts as a match.
 * A NULL filename returns 0; calling strlen(NULL) would be undefined
 * behaviour.
 */
int isCSVFile(const char *filename)
{
    static const char suffix[] = ".csv";
    const size_t suffix_len = sizeof suffix - 1; /* exclude the '\0' */

    if (filename == NULL)
    {
        return 0;
    }
    const size_t len = strlen(filename);
    return len >= suffix_len &&
           strcmp(filename + len - suffix_len, suffix) == 0;
}
#endif