-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLinearRegression.java
More file actions
128 lines (121 loc) · 4.61 KB
/
LinearRegression.java
File metadata and controls
128 lines (121 loc) · 4.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import java.util.List;
/**
 * Linear regression model trained with full-batch gradient descent on the
 * mean squared error.
 *
 * <p>Columns whose declared type is not a {@link Number} subtype are replaced
 * by integer codes during {@link #fit(DataFrame, int)}; the category-to-code
 * maps are kept in {@link #encoder} and looked up again by the
 * {@code predict} methods.
 */
public class LinearRegression {
    /** Rows of the deep-copied training frame; {@code null} until {@code fit} is called. */
    List<List<Object>> data;
    /** One coefficient per feature column; {@code null} until {@code fit} is called. */
    double[] weights;
    /** Intercept term added to every prediction. */
    double bias;
    /** Number of feature columns (columns of the training frame minus the target). */
    int numFeatures;
    /** Number of training rows. */
    int numSamples;
    /** Category-to-code maps, one per encoded column, keyed {@code "col_" + columnIndex}. */
    EncodeData encoder;
    /** Gradient descent step size. */
    double learningRate = 0.01;
    /** Number of gradient descent iterations run by {@code fit}. */
    int epochs = 1000;

    /** Creates an unfitted model with a zero intercept. */
    LinearRegression() {
        this.data = null;
        this.bias = 0.0;
    }

    /**
     * Tells whether a column type is a {@link Number} subtype.
     *
     * @param cls declared type of a column
     * @return {@code true} if {@code cls} is {@code Number} or one of its subclasses
     */
    private boolean isNumeric(Class<?> cls) {
        return Number.class.isAssignableFrom(cls);
    }

    /**
     * Replaces, in place, every value of each non-numeric column with an
     * integer code and marks that column as {@code Integer} in {@code colType}.
     *
     * <p>Within a column, codes are handed out as 0, 1, 2, ... in order of
     * first appearance of each distinct {@code toString()} value.
     *
     * @param data    rows to encode; modified in place
     * @param colType declared type of each column; modified in place
     * @throws IllegalArgumentException if a non-numeric column contains {@code null}
     */
    private void checkEncoding(List<List<Object>> data, List<Class<?>> colType) {
        if (this.encoder == null) {
            this.encoder = new EncodeData();
        }
        int numCols = colType.size();
        for (int j = 0; j < numCols; j++) {
            if (isNumeric(colType.get(j))) {
                continue;
            }
            String key = "col_" + j;
            // NOTE(review): presumably newMap starts an empty map for this key;
            // confirm it also clears a map left over from an earlier fit.
            this.encoder.newMap(key);
            int code = 0;
            for (int i = 0; i < data.size(); i++) {
                List<Object> row = data.get(i);
                Object val = row.get(j);
                if (val == null) {
                    throw new IllegalArgumentException(
                            "Null value at row " + i + ", column " + j + " cannot be encoded");
                }
                String subKey = val.toString();
                if (!this.encoder.checkSubkey(key, subKey)) {
                    this.encoder.addSingleData(key, subKey, code);
                    code += 1;
                }
                int encodedVal = this.encoder.getData(key, subKey);
                row.set(j, encodedVal);
            }
            colType.set(j, Integer.class);
        }
    }

    /**
     * Trains the model on a deep copy of {@code dataframe}, using the column at
     * {@code targetIndex} as the value to predict and every other column as a
     * feature. Any previously learned weights and bias are discarded.
     *
     * @param dataframe   training data; not modified (a deep copy is used)
     * @param targetIndex index of the target column
     * @throws IllegalArgumentException if the frame has no rows or no columns,
     *                                  or a categorical column contains {@code null}
     */
    public void fit(DataFrame dataframe, int targetIndex) {
        DataFrame copy = dataframe.deepCopy();
        List<List<Object>> rows = copy.getData();
        if (rows == null || rows.isEmpty()) {
            throw new IllegalArgumentException("Cannot fit on an empty DataFrame");
        }
        int numColumns = rows.get(0).size();
        if (numColumns < 1) {
            throw new IllegalArgumentException("Cannot fit on a DataFrame without columns");
        }
        this.data = rows;
        this.numSamples = rows.size();
        this.numFeatures = numColumns - 1;
        this.weights = new double[numFeatures];
        // Reset the intercept together with the weights so a second fit on the
        // same instance does not start from a stale bias.
        this.bias = 0.0;
        List<Object> target = copy.getColumn(targetIndex);
        // NOTE(review): the steps below assume removeColumn edits the same row
        // lists that getData() returned above - confirm against DataFrame.
        copy.removeColumn(targetIndex);
        checkEncoding(this.data, copy.colType);
        fit(this.data, target);
    }

    /**
     * Computes {@code bias + sum(weights[j] * row[j])} for an already numeric row.
     *
     * @param row feature values; each of the first {@code weights.length}
     *            entries must be a {@link Number}
     * @return the model output for {@code row}
     */
    private double predictRow(List<Object> row) {
        double yHat = this.bias;
        for (int j = 0; j < this.weights.length; j++) {
            yHat += this.weights[j] * ((Number) row.get(j)).doubleValue();
        }
        return yHat;
    }

    /**
     * Runs {@link #epochs} iterations of full-batch gradient descent on the
     * mean squared error and prints the loss every 100th epoch.
     *
     * @param features numeric feature rows ({@code numSamples} rows)
     * @param target   numeric target value per row
     */
    private void fit(List<List<Object>> features, List<Object> target) {
        // Loop-invariant factor of the MSE gradient: d/dw mean(err^2) = (2/n) * sum(err * x).
        final double step = 2 * (learningRate / numSamples);
        for (int epoch = 0; epoch < epochs; epoch++) {
            double[] weightGradients = new double[numFeatures];
            double biasGradient = 0.0;
            // The loss is taken from the same residuals as the gradient (the
            // parameters do not change inside this pass), so one pass suffices.
            double squaredErrorSum = 0.0;
            for (int i = 0; i < numSamples; i++) {
                List<Object> row = features.get(i);
                double yHat = predictRow(row);
                double yTrue = ((Number) target.get(i)).doubleValue();
                double error = yHat - yTrue;
                for (int j = 0; j < numFeatures; j++) {
                    weightGradients[j] += error * ((Number) row.get(j)).doubleValue();
                }
                biasGradient += error;
                squaredErrorSum += error * error;
            }
            if (epoch % 100 == 0) {
                double loss = squaredErrorSum / numSamples;
                System.out.println("Epoch " + epoch + " Loss: " + loss);
            }
            for (int j = 0; j < numFeatures; j++) {
                this.weights[j] -= step * weightGradients[j];
            }
            this.bias -= step * biasGradient;
        }
    }

    /**
     * Predicts the target for a single feature row. {@link Number} values are
     * used as they are; any other value is translated through the encoder map
     * of its column using its {@code toString()} form.
     *
     * @param row feature values, one per feature column, in training order
     * @return the predicted target value
     * @throws IllegalStateException    if the model has not been fitted
     * @throws IllegalArgumentException if the row length differs from the number
     *                                  of features, a value is {@code null}, or a
     *                                  category is not known to the encoder
     */
    public double predict(List<Object> row) {
        if (this.weights == null) {
            throw new IllegalStateException("Model has not been fitted; call fit() before predict()");
        }
        if (row.size() != numFeatures) {
            throw new IllegalArgumentException("Feature size mismatch" +
                    ": expected " + numFeatures + ", got " + row.size());
        }
        List<Object> encodedRow = new java.util.ArrayList<>(numFeatures);
        for (int i = 0; i < numFeatures; i++) {
            Object val = row.get(i);
            if (val instanceof Number) {
                encodedRow.add(((Number) val).doubleValue());
                continue;
            }
            if (val == null) {
                throw new IllegalArgumentException("Null value in column " + i);
            }
            String key = "col_" + i;
            String subKey = val.toString();
            if (this.encoder == null || !this.encoder.checkSubkey(key, subKey)) {
                throw new IllegalArgumentException("Unknown category: " + subKey + " in column " + i);
            }
            double encodedVal = (double) this.encoder.getData(key, subKey);
            encodedRow.add(encodedVal);
        }
        return predictRow(encodedRow);
    }

    /**
     * Predicts the target for every row of {@code dataframe}.
     *
     * @param dataframe rows holding feature columns only, in training order
     * @return one prediction per row, in row order
     * @throws IllegalStateException    if the model has not been fitted
     * @throws IllegalArgumentException if any row is rejected by {@link #predict(List)}
     */
    public double[] predict(DataFrame dataframe) {
        List<List<Object>> features = dataframe.getData();
        int n = features.size();
        double[] predictions = new double[n];
        for (int i = 0; i < n; i++) {
            predictions[i] = predict(features.get(i));
        }
        return predictions;
    }
}