diff --git a/.gitignore b/.gitignore index 4c17f0a1..179d6cd3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ *.class .DS_Store +*.idea +*.iml bin/* ABAGAIL.jar # Package Files # diff --git a/build.xml b/build.xml index a14c973c..57ebf971 100644 --- a/build.xml +++ b/build.xml @@ -4,6 +4,7 @@ + @@ -18,7 +19,6 @@ - @@ -31,18 +31,28 @@ - - - + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/test/RunTests.java b/src/test/RunTests.java new file mode 100644 index 00000000..f9854bb9 --- /dev/null +++ b/src/test/RunTests.java @@ -0,0 +1,82 @@ +package test; + + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Method; +import java.net.URL; +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.List; + +/** + * Author: Derek Miller dmiller303@gatech.edu + */ +public class RunTests { + + public static void main(String args[]) throws Exception{ + System.out.println(handleTests()); + } + + private static String handleTests() throws Exception{ + String exceptionHolder = ""; + + List classes = getClasses("test.unit"); + int passCount = 0; + int failCount = 0; + + for (Class clazz : classes){ + Method[] methods = clazz.getDeclaredMethods(); + for (Method method : methods){ + if (method.getName().startsWith("test")){ + try { + method.invoke(clazz.newInstance()); + passCount++; + } catch (Exception e){ + failCount++; + exceptionHolder += "\n Class: " + clazz.getName() + " method: " + method.getName() + " message: " + e.getCause().getMessage() + "\n"; + } + + } + } + } + String finalMessage = "Total Tests Ran: " + (passCount + failCount) + " Tests Passed: " + passCount + " Tests Failed: " + failCount + "\n \n"; + return finalMessage + exceptionHolder; + } + + private static List getClasses(String packageName) throws ClassNotFoundException, IOException, Exception { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + if (classLoader == null) throw new ClassNotFoundException("Class isn't found"); + //replaces the packagePath with slashes + String path = packageName.replace('.', '/'); + Enumeration resources = classLoader.getResources(path); + List dirs = new ArrayList(); + while (resources.hasMoreElements()) { + URL resource = resources.nextElement(); + dirs.add(new File(resource.getFile())); + } + ArrayList classes = new ArrayList(); + for (File directory : dirs) { + classes.addAll(findClasses(directory, packageName)); + } + return classes; + } + + private static List findClasses(File directory, String packageName) throws ClassNotFoundException, Exception { + List classes = new ArrayList(); + if (!directory.exists()) { + return classes; + } + File[] files = directory.listFiles(); + for (File file : files) { + if (file.isDirectory()) { + assert !file.getName().contains("."); + classes.addAll(findClasses(file, packageName + "." + file.getName())); + } else if (file.getName().endsWith(".class")) { + //adds class, substring removes .class + classes.add(Class.forName(packageName + '.' + file.getName().substring(0, file.getName().length() - 6))); + } + } + return classes; + } +} diff --git a/src/test/unit/func/dtree/DecisionTreeTest.java b/src/test/unit/func/dtree/DecisionTreeTest.java new file mode 100644 index 00000000..a8f40bcb --- /dev/null +++ b/src/test/unit/func/dtree/DecisionTreeTest.java @@ -0,0 +1,41 @@ +package test.unit.func.dtree; + + +import func.DecisionTreeClassifier; +import func.dtree.*; +import shared.DataSet; +import shared.Instance; +import util.TestUtil; + +public class DecisionTreeTest { + + public void testDiscrete(){ + Instance[] instances = { + new Instance(new double[] {0, 0, 0, 1}, 1), + new Instance(new double[] {1, 0, 0, 0}, 1), + new Instance(new double[] {1, 0, 0, 0}, 1), + new Instance(new double[] {1, 0, 0, 0}, 1), + new Instance(new double[] {1, 0, 0, 1}, 0), + new Instance(new double[] {1, 0, 0, 1}, 0), + new Instance(new double[] {1, 0, 0, 1}, 0), + new Instance(new double[] {1, 0, 0, 1}, 0) + }; + Instance[] tests = { + new Instance(new double[] {0, 1, 1, 1}), + new Instance(new double[] {0, 0, 0, 0}), + new Instance(new double[] {1, 0, 0, 0}), + new Instance(new double[] {1, 1, 1, 1}) + }; + DataSet set = new DataSet(instances); + PruningCriteria cspc = new ChiSquarePruningCriteria(0); + SplitEvaluator gse = new GINISplitEvaluator(); + SplitEvaluator igse = new InformationGainSplitEvaluator(); + DecisionTreeClassifier dt = new DecisionTreeClassifier(igse, null, true); + dt.estimate(set); + + TestUtil.assertEquals(dt.value(tests[0]).getDiscrete(), 1); + TestUtil.assertEquals(dt.value(tests[1]).getDiscrete(), 1); + TestUtil.assertEquals(dt.value(tests[2]).getDiscrete(), 1); + TestUtil.assertEquals(dt.value(tests[3]).getDiscrete(), 0); + } +} diff --git a/src/util/TestUtil.java b/src/util/TestUtil.java new file mode 100644 index 00000000..81d3a379 --- /dev/null +++ b/src/util/TestUtil.java @@ -0,0 +1,29 @@ +package util; + + +public class TestUtil { + + public static void assertEquals(Object item, Object item2){ + boolean isEqual = item.equals(item2); + handleFailureOrSuccess(isEqual, "The first item does not match the expected value"); + } + + public static void assertNotEquals(Object item, Object item2){ + boolean isEqual = item.equals(item2); + handleFailureOrSuccess(isEqual, "The first item matches the expected value"); + } + + public static void assertTrue(boolean expression){ + handleFailureOrSuccess(expression, "The provided expression was not true"); + } + + public static void assertFalse(boolean expression){ + handleFailureOrSuccess(!expression, "The provided expression was not false"); + } + + private static void handleFailureOrSuccess(boolean expression, String message){ + if (!expression){ + throw new RuntimeException(message); + } + } +}