diff --git a/.gitignore b/.gitignore
index 4c17f0a1..179d6cd3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,7 @@
*.class
.DS_Store
+*.idea
+*.iml
bin/*
ABAGAIL.jar
# Package Files #
diff --git a/build.xml b/build.xml
index a14c973c..57ebf971 100644
--- a/build.xml
+++ b/build.xml
@@ -4,6 +4,7 @@
+
@@ -18,7 +19,6 @@
-
@@ -31,18 +31,28 @@
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/test/RunTests.java b/src/test/RunTests.java
new file mode 100644
index 00000000..f9854bb9
--- /dev/null
+++ b/src/test/RunTests.java
@@ -0,0 +1,82 @@
+package test;
+
+
+import java.io.File;
+import java.io.IOException;
+import java.lang.reflect.Method;
+import java.net.URL;
+import java.util.ArrayList;
+import java.util.Enumeration;
+import java.util.List;
+
+/**
+ * Author: Derek Miller dmiller303@gatech.edu
+ */
+public class RunTests {
+
+ public static void main(String args[]) throws Exception{
+ System.out.println(handleTests());
+ }
+
+ private static String handleTests() throws Exception{
+ String exceptionHolder = "";
+
+ List classes = getClasses("test.unit");
+ int passCount = 0;
+ int failCount = 0;
+
+ for (Class clazz : classes){
+ Method[] methods = clazz.getDeclaredMethods();
+ for (Method method : methods){
+ if (method.getName().startsWith("test")){
+ try {
+ method.invoke(clazz.newInstance());
+ passCount++;
+ } catch (Exception e){
+ failCount++;
+ exceptionHolder += "\n Class: " + clazz.getName() + " method: " + method.getName() + " message: " + e.getCause().getMessage() + "\n";
+ }
+
+ }
+ }
+ }
+ String finalMessage = "Total Tests Ran: " + (passCount + failCount) + " Tests Passed: " + passCount + " Tests Failed: " + failCount + "\n \n";
+ return finalMessage + exceptionHolder;
+ }
+
+ private static List getClasses(String packageName) throws ClassNotFoundException, IOException, Exception {
+ ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
+ if (classLoader == null) throw new ClassNotFoundException("Class isn't found");
+ //replaces the packagePath with slashes
+ String path = packageName.replace('.', '/');
+ Enumeration resources = classLoader.getResources(path);
+ List dirs = new ArrayList();
+ while (resources.hasMoreElements()) {
+ URL resource = resources.nextElement();
+ dirs.add(new File(resource.getFile()));
+ }
+ ArrayList classes = new ArrayList();
+ for (File directory : dirs) {
+ classes.addAll(findClasses(directory, packageName));
+ }
+ return classes;
+ }
+
+ private static List findClasses(File directory, String packageName) throws ClassNotFoundException, Exception {
+ List classes = new ArrayList();
+ if (!directory.exists()) {
+ return classes;
+ }
+ File[] files = directory.listFiles();
+ for (File file : files) {
+ if (file.isDirectory()) {
+ assert !file.getName().contains(".");
+ classes.addAll(findClasses(file, packageName + "." + file.getName()));
+ } else if (file.getName().endsWith(".class")) {
+ //adds class, substring removes .class
+ classes.add(Class.forName(packageName + '.' + file.getName().substring(0, file.getName().length() - 6)));
+ }
+ }
+ return classes;
+ }
+}
diff --git a/src/test/unit/func/dtree/DecisionTreeTest.java b/src/test/unit/func/dtree/DecisionTreeTest.java
new file mode 100644
index 00000000..a8f40bcb
--- /dev/null
+++ b/src/test/unit/func/dtree/DecisionTreeTest.java
@@ -0,0 +1,41 @@
+package test.unit.func.dtree;
+
+
+import func.DecisionTreeClassifier;
+import func.dtree.*;
+import shared.DataSet;
+import shared.Instance;
+import util.TestUtil;
+
+public class DecisionTreeTest {
+
+ public void testDiscrete(){
+ Instance[] instances = {
+ new Instance(new double[] {0, 0, 0, 1}, 1),
+ new Instance(new double[] {1, 0, 0, 0}, 1),
+ new Instance(new double[] {1, 0, 0, 0}, 1),
+ new Instance(new double[] {1, 0, 0, 0}, 1),
+ new Instance(new double[] {1, 0, 0, 1}, 0),
+ new Instance(new double[] {1, 0, 0, 1}, 0),
+ new Instance(new double[] {1, 0, 0, 1}, 0),
+ new Instance(new double[] {1, 0, 0, 1}, 0)
+ };
+ Instance[] tests = {
+ new Instance(new double[] {0, 1, 1, 1}),
+ new Instance(new double[] {0, 0, 0, 0}),
+ new Instance(new double[] {1, 0, 0, 0}),
+ new Instance(new double[] {1, 1, 1, 1})
+ };
+ DataSet set = new DataSet(instances);
+ PruningCriteria cspc = new ChiSquarePruningCriteria(0);
+ SplitEvaluator gse = new GINISplitEvaluator();
+ SplitEvaluator igse = new InformationGainSplitEvaluator();
+ DecisionTreeClassifier dt = new DecisionTreeClassifier(igse, null, true);
+ dt.estimate(set);
+
+ TestUtil.assertEquals(dt.value(tests[0]).getDiscrete(), 1);
+ TestUtil.assertEquals(dt.value(tests[1]).getDiscrete(), 1);
+ TestUtil.assertEquals(dt.value(tests[2]).getDiscrete(), 1);
+ TestUtil.assertEquals(dt.value(tests[3]).getDiscrete(), 0);
+ }
+}
diff --git a/src/util/TestUtil.java b/src/util/TestUtil.java
new file mode 100644
index 00000000..81d3a379
--- /dev/null
+++ b/src/util/TestUtil.java
@@ -0,0 +1,29 @@
+package util;
+
+
+public class TestUtil {
+
+ public static void assertEquals(Object item, Object item2){
+ boolean isEqual = item.equals(item2);
+ handleFailureOrSuccess(isEqual, "The first item does not match the expected value");
+ }
+
+ public static void assertNotEquals(Object item, Object item2){
+ boolean isEqual = item.equals(item2);
+ handleFailureOrSuccess(isEqual, "The first item matches the expected value");
+ }
+
+ public static void assertTrue(boolean expression){
+ handleFailureOrSuccess(expression, "The provided expression was not true");
+ }
+
+ public static void assertFalse(boolean expression){
+ handleFailureOrSuccess(!expression, "The provided expression was not false");
+ }
+
+ private static void handleFailureOrSuccess(boolean expression, String message){
+ if (!expression){
+ throw new RuntimeException(message);
+ }
+ }
+}