-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMyID3.java
More file actions
268 lines (249 loc) · 9.49 KB
/
MyID3.java
File metadata and controls
268 lines (249 loc) · 9.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
package decisiontree;
import support.decisiontree.Attribute;
import support.decisiontree.DecisionTreeData;
import support.decisiontree.DecisionTreeNode;
import support.decisiontree.ID3;
import java.util.ArrayList;
import java.util.Set;
/**
* This class is where your ID3 algorithm should be implemented.
*/
/**
 * ID3 decision-tree learner. Builds a tree by recursively selecting the
 * attribute with the highest information gain, splitting the examples on its
 * values, and recursing until the data is empty, pure, or out of attributes.
 *
 * NOTE(review): the entropy/remainder math assumes exactly TWO classification
 * labels (it only counts _classifications[0] and _classifications[1]) — this
 * matches the framework's binary data sets; confirm before using multi-class data.
 */
public class MyID3 implements ID3 {

    // The possible classification labels (exactly two are assumed, see class doc).
    String[] _classifications;
    // Scratch field set by sameClassificationCheck() when all examples agree.
    String _classification;
    // The full training data handed to id3Trigger().
    DecisionTreeData _new_data;

    /**
     * Constructor. You don't need to edit this.
     */
    public MyID3() {
    }

    /**
     * Trigger method that runs the algorithm; called by the visualizer on 'train'.
     *
     * @param data the full training data set
     * @return the root of the learned decision tree
     */
    @Override
    public DecisionTreeNode id3Trigger(DecisionTreeData data) {
        _new_data = data;
        _classifications = data.getClassifications();
        ArrayList<Attribute> attributes = data.getAttributeList();
        // The root's "parent data" is the data itself (used only for the
        // empty-data fallback, which cannot fire at the root anyway).
        return this.myID3Algorithm(_new_data, data, attributes);
    }

    /**
     * The recursive ID3 algorithm.
     *
     * @param data       examples reaching this node
     * @param parentData examples at the parent (fallback when data is empty)
     * @param attributes attributes still available for splitting at this node
     * @return the subtree rooted at this node
     */
    private DecisionTreeNode myID3Algorithm(DecisionTreeData data, DecisionTreeData parentData, ArrayList<Attribute> attributes) {
        DecisionTreeNode node = new DecisionTreeNode();
        if (checkDataEmpty(data)) {
            // No examples left: inherit the parent's majority classification.
            node.setElement(findMostFrequentClassification(parentData));
            return node;
        }
        if (sameClassificationCheck(data)) {
            // Pure node: every example agrees, so emit that label.
            node.setElement(_classification);
            return node;
        }
        if (checkAttributesEmpty(attributes)) {
            // Out of attributes: fall back to the majority label here.
            node.setElement(findMostFrequentClassification(data));
            return node;
        }
        Attribute maxInfoAttribute = this.calculateMaxInfoAttribute(data, attributes);
        node.setElement(maxInfoAttribute.getName());
        // BUGFIX: remove the chosen attribute from a COPY of the list. The
        // original removed it from the shared list inside the value loop, so
        // the removal leaked into sibling branches and back up the recursion,
        // starving other subtrees of attributes they were still entitled to.
        // (It also re-ran the full gain computation via maxInfoAttributeIndex
        // just to find the index — remove(Object) does the same thing in O(n).)
        ArrayList<Attribute> remaining = new ArrayList<>(attributes);
        remaining.remove(maxInfoAttribute);
        for (String value : maxInfoAttribute.getValues()) {
            DecisionTreeData subset = this.newDataInitializer(data, maxInfoAttribute, value, remaining);
            DecisionTreeNode subTree = this.myID3Algorithm(subset, data, remaining);
            node.addChild(value, subTree);
        }
        return node;
    }

    /**
     * Returns the most frequent of the two classification labels in the data.
     * Ties go to the second label (preserved from the original behavior).
     *
     * @param data data set to tally (must be non-empty)
     * @return the majority classification label
     */
    public String findMostFrequentClassification(DecisionTreeData data) {
        String[][] examples = data.getExamples();
        _classifications = data.getClassifications();
        int lastCol = examples[0].length - 1;  // classification lives in the last column
        int numFirst = 0;
        int numSecond = 0;
        for (int i = 0; i < examples.length; i++) {
            if (examples[i][lastCol].equals(_classifications[0])) {
                numFirst += 1;
            }
            if (examples[i][lastCol].equals(_classifications[1])) {
                numSecond += 1;
            }
        }
        return (numFirst > numSecond) ? _classifications[0] : _classifications[1];
    }

    /**
     * @param data data set to inspect
     * @return true if the data set contains no example rows
     */
    public boolean checkDataEmpty(DecisionTreeData data) {
        return data.getExamples().length == 0;
    }

    /**
     * @param attributes candidate attribute list
     * @return true if no attributes remain to split on
     */
    public boolean checkAttributesEmpty(ArrayList<Attribute> attributes) {
        return attributes.isEmpty();
    }

    /**
     * Checks whether every example in the data set carries the same
     * classification; if so, records that label in {@code _classification}.
     *
     * @param data data set to inspect (must be non-empty)
     * @return true if all examples share one classification
     */
    public boolean sameClassificationCheck(DecisionTreeData data) {
        String[][] examples = data.getExamples();
        int lastCol = examples[0].length - 1;
        String baseClassification = examples[0][lastCol];
        for (int i = 0; i < examples.length; i++) {
            if (!examples[i][lastCol].equals(baseClassification)) {
                return false;
            }
        }
        _classification = baseClassification;  // side effect read by the caller
        return true;
    }

    /**
     * Returns the attribute with the highest information gain for this data.
     * Ties keep the earliest attribute (strict '>' comparison preserved).
     *
     * @param data       examples at the current node
     * @param attributes non-empty list of candidate attributes
     * @return the attribute maximizing information gain
     */
    public Attribute calculateMaxInfoAttribute(DecisionTreeData data, ArrayList<Attribute> attributes) {
        Attribute best = attributes.get(0);
        double maxInfo = this.calculateInformationGain(data, best);
        for (int i = 1; i < attributes.size(); i++) {
            double gain = this.calculateInformationGain(data, attributes.get(i));
            if (gain > maxInfo) {
                maxInfo = gain;
                best = attributes.get(i);
            }
        }
        return best;
    }

    /**
     * Information gain = entropy(data) - remainder(data, attribute).
     *
     * @param data      examples at the current node
     * @param attribute attribute under consideration
     * @return the information gain of splitting on the attribute
     */
    public double calculateInformationGain(DecisionTreeData data, Attribute attribute) {
        return this.calculateEntropy(data) - this.calculateRemainder(data, attribute);
    }

    /**
     * Entropy of the data set's binary classification distribution.
     *
     * @param data examples to measure (must be non-empty)
     * @return entropy in bits, in [0, 1]
     */
    public double calculateEntropy(DecisionTreeData data) {
        _classifications = data.getClassifications();
        String[][] examples = data.getExamples();
        int lastCol = examples[0].length - 1;
        double positive = 0;
        double negative = 0;
        for (int i = 0; i < examples.length; i++) {
            if (examples[i][lastCol].equals(_classifications[0])) {
                positive += 1;
            }
            if (examples[i][lastCol].equals(_classifications[1])) {
                negative += 1;
            }
        }
        return calculateEntropyHelper(positive, negative);
    }

    /**
     * Binary entropy of a (positive, negative) count pair:
     * -(p*log2(p) + (1-p)*log2(1-p)) where p = positive/(positive+negative).
     *
     * @param positive count of examples with the first label
     * @param negative count of examples with the second label
     * @return entropy in bits
     */
    public double calculateEntropyHelper(double positive, double negative) {
        double ratio = positive / (positive + negative);
        return -1 * ((ratio * this.logBaseTwo(ratio))
                + ((1 - ratio) * this.logBaseTwo(1 - ratio)));
    }

    /**
     * Base-2 logarithm, with the entropy convention that 0*log2(0) = 0
     * (so log2(0) is defined as 0 here).
     *
     * @param logNumber value in [0, 1]
     * @return log base 2 of the value, or 0 for input 0
     */
    public double logBaseTwo(double logNumber) {
        if (logNumber == 0) {
            return 0;
        }
        return Math.log(logNumber) / Math.log(2);
    }

    /**
     * Remainder of an attribute: the weighted sum of the entropies of the
     * subsets produced by splitting the data on each of the attribute's values.
     *
     * @param data      examples at the current node (must be non-empty)
     * @param attribute attribute to evaluate
     * @return the weighted post-split entropy
     */
    public double calculateRemainder(DecisionTreeData data, Attribute attribute) {
        _classifications = data.getClassifications();
        double remainder = 0;
        String[][] examples = data.getExamples();
        double numExamples = examples.length;
        int lastCol = examples[0].length - 1;
        Set<String> attributeValues = attribute.getValues();
        int attributeColumn = attribute.getColumn();
        for (String value : attributeValues) {
            double positive = 0;
            double negative = 0;
            for (int i = 0; i < numExamples; i++) {
                if (examples[i][attributeColumn].equals(value)) {
                    if (examples[i][lastCol].equals(_classifications[0])) {
                        positive += 1;
                    }
                    if (examples[i][lastCol].equals(_classifications[1])) {
                        negative += 1;
                    }
                }
            }
            // Skip empty subsets: they contribute zero weight and would
            // otherwise produce 0/0 in the entropy helper.
            if (positive + negative != 0) {
                double subEntropy = this.calculateEntropyHelper(positive, negative);
                double weight = (positive + negative) / numExamples;
                remainder += subEntropy * weight;
            }
        }
        return remainder;
    }

    /**
     * Builds a new data set containing only the examples whose value for the
     * given attribute equals {@code value}.
     *
     * @param data       source data set
     * @param attribute  attribute being split on
     * @param value      the attribute value this subset keeps
     * @param attributes attribute list to attach to the new data set
     * @return the filtered data set
     */
    public DecisionTreeData newDataInitializer(DecisionTreeData data, Attribute attribute, String value, ArrayList<Attribute> attributes) {
        _classifications = data.getClassifications();
        String[][] examples = data.getExamples();
        int attributeColumn = attribute.getColumn();
        // First pass: count matching rows so the result array can be sized exactly.
        int numNewRows = 0;
        for (int row = 0; row < examples.length; row++) {
            if (examples[row][attributeColumn].equals(value)) {
                numNewRows += 1;
            }
        }
        // Second pass: copy row references into the new array.
        String[][] newExamples = new String[numNewRows][examples[0].length];
        int newRows = 0;
        for (int row = 0; row < examples.length; row++) {
            if (examples[row][attributeColumn].equals(value)) {
                newExamples[newRows] = examples[row];
                newRows += 1;
            }
        }
        return new DecisionTreeData(newExamples, attributes, _classifications);
    }

    /**
     * Index (in {@code attributes}) of the attribute with the maximum
     * information gain, or 0 if it is somehow absent.
     *
     * BUGFIX: the original re-ran calculateMaxInfoAttribute on every loop
     * iteration (O(n^2) full gain computations); compute it once and use indexOf.
     *
     * @param data       examples at the current node
     * @param attributes non-empty candidate attribute list
     * @return index of the max-gain attribute
     */
    public int maxInfoAttributeIndex(DecisionTreeData data, ArrayList<Attribute> attributes) {
        Attribute best = this.calculateMaxInfoAttribute(data, attributes);
        int index = attributes.indexOf(best);
        return (index >= 0) ? index : 0;
    }
}