bestSuggestions = new LinkedList<>();
+ double[] preSplitDist = observedClassDistribution.getArrayCopy();
+ if (!noPrePrune) {
+ // add null split as an option
+ bestSuggestions.add(new AttributeSplitSuggestion(null,
+ new double[0][], criterion.getMeritOfSplit(
+ preSplitDist, new double[][]{preSplitDist})));
+ }
+ for (int i = 0; i < attributeObservers.size(); i++) {
+ AttributeClassObserver obs = attributeObservers.get(i);
+ if (obs != null) {
+ AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(
+ criterion, preSplitDist, i, binaryOnly
+ );
+ if (bestSuggestion != null) {
+ bestSuggestions.add(bestSuggestion);
+ }
+ }
+ }
+ return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]);
+ }
+
    /**
     * Train on the provided instance and re-evaluate the tree structure.
     *
     * Order matters: statistics and observers are updated first, then a leaf may
     * attempt its initial split, an internal node may re-evaluate (and possibly
     * kill) its split, and only afterwards is the instance propagated downwards.
     *
     * @param instance the instance to train on
     * @param totalNumInstances the total number of instances observed so far. (not used in EFDT but might be overwritten by subclasses)
     **/
    public void learnInstance(Instance instance, int totalNumInstances) {
        seenWeight += instance.weight();
        nodeTime++;
        updateStatistics(instance);
        updateObservers(instance);

        // leaves try their first split every gracePeriod instances
        if (isLeaf() && nodeTime % gracePeriod == 0) {
            attemptInitialSplit(instance);
        }
        // internal nodes periodically re-check whether their split is still the best one
        if (!isLeaf() && nodeTime % minSamplesReevaluate == 0) {
            reevaluateSplit(instance);
        }
        // re-check isLeaf(): reevaluateSplit above may have killed the subtree
        if (!isLeaf()) {
            propagateToSuccessors(instance, totalNumInstances);
        }

        // keep the MC-vs-NB accuracy counters current for adaptive leaf prediction
        if (!"MC".equals(leafPrediction))
            updateNaiveBayesAdaptive(instance);
    }
+
+ public void updateNaiveBayesAdaptive(Instance inst) {
+ int trueClass = (int) inst.classValue();
+ if (this.observedClassDistribution.maxIndex() == trueClass) {
+ this.mcCorrectWeight += inst.weight();
+ }
+ if (Utils.maxIndex(NaiveBayes.doNaiveBayesPrediction(inst,
+ this.observedClassDistribution, this.attributeObservers)) == trueClass) {
+ this.nbCorrectWeight += inst.weight();
+ }
+ }
+
+ /**
+ * Predict the provided instance
+ *
+ * Traverses the tree until reaching a leaf and then returns the class votes of that leaf.
+ *
+ *
+ * @param instance the instance to predict
+ * @return the class votes
+ **/
+ public double[] predict(Instance instance) {
+ if (!isLeaf()) {
+ CustomEFDTNode successor = getSuccessor(instance);
+ if (successor == null)
+ return getClassVotes();
+ return successor.predict(instance);
+ }
+ if ("MC".equals(leafPrediction))
+ return getClassVotes();
+ else if ("NB".equals(leafPrediction))
+ return NaiveBayes.doNaiveBayesPrediction(instance, this.observedClassDistribution, this.attributeObservers);
+ else
+ return doNBAdaptive(instance);
+ }
+
+ private double[] doNBAdaptive(Instance instance) {
+ if (this.mcCorrectWeight > this.nbCorrectWeight) {
+ return this.observedClassDistribution.getArrayCopy();
+ }
+ return NaiveBayes.doNaiveBayesPrediction(instance,
+ this.observedClassDistribution, this.attributeObservers);
+ }
+
+ /**
+ * Finds the successor of the current node based on the provided instance.
+ *
+ * @param instance the instance
+ * @return the successor node if it exists. Else, returns null.
+ **/
+ CustomEFDTNode getSuccessor(Instance instance) {
+ if (isLeaf())
+ return null;
+ Double attVal = instance.value(splitAttribute);
+ return successors.getSuccessorNode(attVal);
+ }
+
+ /**
+ * Attempts to split this leaf
+ *
+ * @param instance the current instance
+ **/
+ protected void attemptInitialSplit(Instance instance) {
+ if (depth >= maxDepth) {
+ return;
+ }
+ if (isPure())
+ return;
+
+ numSplitAttempts++;
+
+ AttributeSplitSuggestion[] bestSuggestions = getBestSplitSuggestions(splitCriterion);
+ Arrays.sort(bestSuggestions);
+ AttributeSplitSuggestion xBest = bestSuggestions[bestSuggestions.length - 1];
+ xBest = replaceBestSuggestionIfAttributeIsBlocked(xBest, bestSuggestions, blockedAttributeIndex);
+
+ if (!shouldSplitLeaf(bestSuggestions, currentConfidence(), observedClassDistribution))
+ return;
+ if (xBest.splitTest == null) {
+ // preprune - null wins
+ System.out.println("preprune - null wins");
+ killSubtree();
+ resetSplitAttribute();
+ return;
+ }
+ int instanceIndex = modelAttIndexToInstanceAttIndex(xBest.splitTest.getAttsTestDependsOn()[0], instance);
+ Attribute newSplitAttribute = instance.attribute(instanceIndex);
+ makeSplit(newSplitAttribute, xBest);
+ classDistributionAtTimeOfCreation = new DoubleVector(observedClassDistribution.getArrayCopy());
+ }
+
+ /**
+ * Reevaluates a split decision that was already made
+ *
+ * @param instance the current instance
+ **/
+ protected void reevaluateSplit(Instance instance) {
+ numSplitAttempts++;
+
+ AttributeSplitSuggestion[] bestSuggestions = getBestSplitSuggestions(splitCriterion);
+ Arrays.sort(bestSuggestions);
+ if (bestSuggestions.length == 0)
+ return;
+
+ // get best split suggestions
+ AttributeSplitSuggestion[] bestSplitSuggestions = getBestSplitSuggestions(splitCriterion);
+ Arrays.sort(bestSplitSuggestions);
+ AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
+
+ double bestSuggestionAverageMerit = bestSuggestion.splitTest == null ? 0.0 : bestSuggestion.merit;
+ double currentAverageMerit = getCurrentSuggestionAverageMerit(bestSuggestions);
+ double deltaG = bestSuggestionAverageMerit - currentAverageMerit;
+ double eps = computeHoeffdingBound();
+
+ if (deltaG > eps || (eps < tauReevaluate && deltaG > tauReevaluate * relMinDeltaG)) {
+
+ if (bestSuggestion.splitTest == null) {
+ System.out.println("preprune - null wins");
+ killSubtree();
+ resetSplitAttribute();
+ } else {
+ boolean doResplit = true;
+ if (
+ getSplitTest() instanceof NumericAttributeBinaryTest
+ && getSplitTest().getAttsTestDependsOn()[0] == bestSuggestion.splitTest.getAttsTestDependsOn()[0]
+ ) {
+ Set keys = successors.getKeyset();
+ for (SuccessorIdentifier key: keys) {
+ if (key.isLower()) {
+ if (argmax(bestSuggestion.resultingClassDistributions[0]) == argmax(successors.getSuccessorNode(key).observedClassDistribution.getArrayRef())) {
+ doResplit = false;
+ break;
+ }
+ }
+ else {
+ if (argmax(bestSuggestion.resultingClassDistributions[1]) == argmax(successors.getSuccessorNode(key).observedClassDistribution.getArrayRef())) {
+ doResplit = false;
+ break;
+ }
+ }
+ }
+ }
+ performedTreeRevision = true;
+ if (!doResplit) {
+ NumericAttributeBinaryTest test = (NumericAttributeBinaryTest) bestSuggestion.splitTest;
+ successors.adjustThreshold(test.getSplitValue());
+ setSplitTest(bestSuggestion.splitTest);
+ nodeTime = 0;
+ seenWeight = 0.0;
+ }
+ else {
+ int instanceIndex = modelAttIndexToInstanceAttIndex(bestSuggestion.splitTest.getAttsTestDependsOn()[0], instance);
+ Attribute newSplitAttribute = instance.attribute(instanceIndex);
+ makeSplit(newSplitAttribute, bestSuggestion);
+ nodeTime = 0;
+ seenWeight = 0.0;
+ }
+ }
+ }
+ }
+
    /**
     * Initializes the successor nodes when performing a split.
     *
     * Handles three cases: multiway nominal splits (one child per value with a
     * non-empty distribution), binary nominal splits (value child + default
     * child) and binary numeric splits (lower + upper child). On any failure
     * the successor container is discarded so the node stays a leaf.
     *
     * @param xBest the split suggestion for the best split
     * @param splitAttribute the attribute to split on
     * @return if the initialization was successful.
     **/
    protected boolean initializeSuccessors(AttributeSplitSuggestion xBest, Attribute splitAttribute) {
        boolean isNominal = splitAttribute.isNominal();
        boolean isBinary = !(xBest.splitTest instanceof NominalAttributeMultiwayTest);
        Double splitValue = null;
        if (isNominal && isBinary)
            splitValue = ((NominalAttributeBinaryTest) xBest.splitTest).getValue();
        else if (!isNominal)
            splitValue = ((NumericAttributeBinaryTest) xBest.splitTest).getSplitValue();

        Integer splitAttributeIndex = xBest.splitTest.getAttsTestDependsOn()[0];
        if (splitAttribute.isNominal()) {
            if (!isBinary) {
                // multiway: one successor per attribute value with observed statistics
                for (int i = 0; i < xBest.numSplits(); i++) {
                    double[] stats = xBest.resultingClassDistributionFromSplit(i);
                    if (stats.length == 0)
                        continue;
                    CustomEFDTNode s = newNode(
                            depth + 1,
                            new DoubleVector(stats),
                            getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex));
                    boolean success = successors.addSuccessorNominalMultiway((double) i, s);
                    if (!success) {
                        // roll back: drop the partially initialized successor set
                        successors = null;
                        return false;
                    }
                }
                return !isLeaf();
            } else {
                // binary nominal: one child for the split value, one default child
                double[] stats1 = xBest.resultingClassDistributionFromSplit(0);
                double[] stats2 = xBest.resultingClassDistributionFromSplit(1);
                CustomEFDTNode s1 = newNode(depth + 1, new DoubleVector(stats1), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex));
                CustomEFDTNode s2 = newNode(depth + 1, new DoubleVector(stats2), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex));
                boolean success = successors.addSuccessorNominalBinary(splitValue, s1);
                success = success && successors.addDefaultSuccessorNominalBinary(s2);
                if (!success) {
                    successors = null;
                    return false;
                }
                return !isLeaf();
            }
        } else {
            // binary numeric: lower (true) and upper (false) children around splitValue
            boolean success = successors.addSuccessorNumeric(
                    splitValue,
                    newNode(depth + 1, new DoubleVector(xBest.resultingClassDistributionFromSplit(0)), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex)),
                    true
            );
            success = success && successors.addSuccessorNumeric(
                    splitValue,
                    newNode(depth + 1, new DoubleVector(xBest.resultingClassDistributionFromSplit(1)), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex)),
                    false
            );
            if (!success) {
                successors = null;
                return false;
            }
            return !isLeaf();
        }
    }
+
    /**
     * Sets the split attribute and the corresponding split test for this node.
     *
     * @param xBest the split suggestion for the best split
     * @param splitAttribute the attribute to split on
     **/
    protected void setSplitAttribute(AttributeSplitSuggestion xBest, Attribute splitAttribute) {
        this.splitAttribute = splitAttribute;
        setSplitTest(xBest.splitTest);
    }
+
    /** Clears the split attribute and split test, i.e. forgets the current split decision. */
    protected void resetSplitAttribute() {
        splitAttribute = null;
        splitTest = null;
    }
+
    /**
     * Kills the subtree by removing the link to the successors.
     * Afterwards this node acts as a leaf again (see isLeaf()).
     */
    protected void killSubtree() {
        successors = null;
    }
+
    /** @return a copy of this node's observed class counts, used as its class votes */
    double[] getClassVotes() {
        return observedClassDistribution.getArrayCopy();
    }
+
    /**
     * Updates the class statistics (which class occurred how often),
     * weighting the count by the instance weight.
     * @param instance the current instance
     */
    protected void updateStatistics(Instance instance) {
        observedClassDistribution.addToValue((int) instance.classValue(), instance.weight());
    }
+
+ /**
+ * Propagates the instance down the tree
+ * @param instance the current instance
+ * @param totalNumInstances the number of instances seen so far
+ */
+ protected void propagateToSuccessors(Instance instance, int totalNumInstances) {
+ Double attValue = instance.value(splitAttribute);
+ CustomEFDTNode successor = successors.getSuccessorNode(attValue);
+ if (successor == null)
+ successor = addSuccessor(instance);
+ if (successor != null)
+ successor.learnInstance(instance, totalNumInstances);
+ }
+
    /**
     * Add a successor to the children.
     * Which successor is created depends on the split attribute type and on
     * which branches (value/default, lower/upper) are still missing.
     *
     * @param instance the current instance
     * @return the creates successor. Null, if the successor could not be created
     */
    protected CustomEFDTNode addSuccessor(Instance instance) {
        List usedNomAttributes = new ArrayList<>(usedNominalAttributes); //deep copy
        CustomEFDTNode successor = newNode(depth + 1, null, usedNomAttributes);
        double value = instance.value(splitAttribute);
        if (splitAttribute.isNominal()) {
            if (!successors.isBinary()) {
                boolean success = successors.addSuccessorNominalMultiway(value, successor);
                return success ? successor : null;
            } else {
                boolean success = successors.addSuccessorNominalBinary(value, successor);
                if (!success) // this is the case if the split is binary nominal but the "left" successor exists.
                    success = successors.addDefaultSuccessorNominalBinary(successor);
                return success ? successor : null;
            }
        } else {
            if (successors.lowerIsMissing()) {
                boolean success = successors.addSuccessorNumeric(value, successor, true);
                return success ? successor : null;
            } else if (successors.upperIsMissing()) {
                boolean success = successors.addSuccessorNumeric(value, successor, false);
                return success ? successor : null;
            }
        }
        // both numeric branches already exist -> nothing to add
        return null;
    }
+
    /**
     * Factory for successor nodes; subclasses override this to create nodes of their own type.
     *
     * @param depth depth of the new node
     * @param classDistribution initial class distribution of the new node (may be null)
     * @param usedNominalAttributes nominal attribute indices already used on the path
     * @return the freshly created node
     */
    protected CustomEFDTNode newNode(int depth, DoubleVector classDistribution, List usedNominalAttributes) {
        return new CustomEFDTNode(
                splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence,
                leafPrediction, minSamplesReevaluate, depth, maxDepth,
                tau, tauReevaluate, relMinDeltaG, binaryOnly, noPrePrune, nominalObserverBlueprint,
                classDistribution, usedNominalAttributes, -1 // we don't block attributes in EFDT
        );
    }
+
+ protected void updateObservers(Instance instance) {
+ for (int i = 0; i < instance.numAttributes() - 1; i++) { //update likelihood
+ int instAttIndex = modelAttIndexToInstanceAttIndex(i, instance);
+ AttributeClassObserver obs = this.attributeObservers.get(i);
+ if (obs == null) {
+ obs = instance.attribute(instAttIndex).isNominal() ? newNominalClassObserver() : newNumericClassObserver();
+ this.attributeObservers.set(i, obs);
+ }
+ obs.observeAttributeClass(instance.value(instAttIndex), (int) instance.classValue(), instance.weight());
+ }
+ }
+
+ /**
+ * If this node is a leaf node
+ * @return true if the node is a leaf
+ */
+ boolean isLeaf() {
+ if (successors == null)
+ return true;
+ return successors.size() == 0;
+ }
+
+ /**
+ * If this node only has seen instances of the same class
+ * @return true if the number of observed classes is less than 2
+ */
+ boolean isPure() {
+ return observedClassDistribution.numNonZeroEntries() < 2;
+ }
+
    /** @return a fresh nominal attribute observer cloned from the blueprint */
    protected NominalAttributeClassObserver newNominalClassObserver() {
        return (NominalAttributeClassObserver) nominalObserverBlueprint.copy();
    }
+
    /** @return a fresh numeric attribute observer cloned from the blueprint */
    protected NumericAttributeClassObserver newNumericClassObserver() {
        return (NumericAttributeClassObserver) numericObserverBlueprint.copy();
    }
+
+ protected List getUsedNominalAttributesForSuccessor(Attribute splitAttribute, Integer splitAttributeIndex) {
+ List usedNomAttributesCpy = new ArrayList<>(usedNominalAttributes); //deep copy
+ if (splitAttribute.isNominal())
+ usedNomAttributesCpy.add(splitAttributeIndex);
+ return usedNomAttributesCpy;
+ }
+
+ protected void updateInfogainSum(AttributeSplitSuggestion[] suggestions) {
+ for (AttributeSplitSuggestion sugg : suggestions) {
+ if (sugg.splitTest != null) {
+ if (!infogainSum.containsKey((sugg.splitTest.getAttsTestDependsOn()[0]))) {
+ infogainSum.put((sugg.splitTest.getAttsTestDependsOn()[0]), 0.0);
+ }
+ double currentSum = infogainSum.get((sugg.splitTest.getAttsTestDependsOn()[0]));
+ infogainSum.put((sugg.splitTest.getAttsTestDependsOn()[0]), currentSum + sugg.merit);
+ } else { // handle the null attribute
+ double currentSum = infogainSum.get(-1); // null split
+ infogainSum.put(-1, Math.max(0.0, currentSum + sugg.merit));
+ assert infogainSum.get(-1) >= 0.0 : "Negative infogain shouldn't be possible here.";
+ }
+ }
+ }
+
+ protected boolean shouldSplitLeaf(AttributeSplitSuggestion[] suggestions,
+ double confidence,
+ DoubleVector observedClassDistribution
+ ) {
+ boolean shouldSplit = false;
+ if (suggestions.length < 2) {
+ shouldSplit = suggestions.length > 0;
+ } else {
+ AttributeSplitSuggestion bestSuggestion = suggestions[suggestions.length - 1];
+
+ double bestSuggestionAverageMerit = bestSuggestion.merit;
+ double currentAverageMerit = 0.0;
+ double eps = computeHoeffdingBound();
+
+ shouldSplit = bestSuggestionAverageMerit - currentAverageMerit > eps || eps < tau;
+ if (bestSuggestion.merit < 1e-10)
+ shouldSplit = false; // we don't use average here
+
+ if (shouldSplit) {
+ for (Integer i : usedNominalAttributes) {
+ if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i) {
+ shouldSplit = false;
+ break;
+ }
+ }
+ }
+ }
+ return shouldSplit;
+ }
+
+ /**
+ * Get the merit if the current split
+ * @param suggestions the suggestions for the possible splits
+ * @return the merit of the current split
+ */
+ double getCurrentSuggestionAverageMerit(AttributeSplitSuggestion[] suggestions) {
+ double merit = 0.0;
+ if (splitTest != null) {
+ if (splitTest instanceof NominalAttributeMultiwayTest) {
+ for (AttributeSplitSuggestion s: suggestions) {
+ if (s.splitTest == null)
+ continue;
+ if (s.splitTest.getAttsTestDependsOn()[0] == getSplitAttributeIndex()) {
+ merit = s.merit;
+ break;
+ }
+ }
+ }
+ else if (splitTest instanceof NominalAttributeBinaryTest) {
+ double currentValue = successors.getReferenceValue();
+ NominalAttributeClassObserver obs = (NominalAttributeClassObserver) attributeObservers.get(getSplitAttributeIndex());
+ AttributeSplitSuggestion xCurrent = obs.forceSplit(splitCriterion, observedClassDistribution.getArrayCopy(), getSplitAttributeIndex(), true, currentValue);
+ merit = xCurrent == null ? 0.0 : xCurrent.merit;
+ if (xCurrent != null)
+ merit = xCurrent.splitTest == null ? 0.0 : xCurrent.merit;
+ }
+ else if (splitTest instanceof NumericAttributeBinaryTest) {
+ double currentThreshold = successors.getReferenceValue();
+ GaussianNumericAttributeClassObserver obs = (GaussianNumericAttributeClassObserver) attributeObservers.get(getSplitAttributeIndex());
+ AttributeSplitSuggestion xCurrent = obs.forceSplit(splitCriterion, observedClassDistribution.getArrayCopy(), getSplitAttributeIndex(), currentThreshold);
+ merit = xCurrent == null ? 0.0 : xCurrent.merit;
+ if (xCurrent != null)
+ merit = xCurrent.splitTest == null ? 0.0 : xCurrent.merit;
+ }
+ }
+ return merit == Double.NEGATIVE_INFINITY ? 0.0 : merit;
+ }
+
+ double getSuggestionAverageMerit(InstanceConditionalTest splitTest) {
+ double averageMerit;
+
+ if (splitTest == null) {
+ averageMerit = infogainSum.get(-1) / Math.max(numSplitAttempts, 1.0);
+ } else {
+ Integer key = splitTest.getAttsTestDependsOn()[0];
+ if (!infogainSum.containsKey(key)) {
+ infogainSum.put(key, 0.0);
+ }
+ averageMerit = infogainSum.get(key) / Math.max(numSplitAttempts, 1.0);
+ }
+ return averageMerit;
+ }
+
+ int argmax(double[] array) {
+ double max = array[0];
+ int maxarg = 0;
+
+ for (int i = 1; i < array.length; i++) {
+
+ if (array[i] > max) {
+ max = array[i];
+ maxarg = i;
+ }
+ }
+ return maxarg;
+ }
+
+ public int getSubtreeDepth() {
+ if (isLeaf())
+ return depth;
+ Set succDepths = new HashSet<>();
+ for (CustomEFDTNode successor: successors.getAllSuccessors()) {
+ succDepths.add(successor.getSubtreeDepth());
+ }
+ return Collections.max(succDepths);
+ }
+
+ /**
+ * Replaces the best split suggestion if the node is not allowed to split on that attribute.
+ * This is the case when the parent node splits on that attribute. It prevents splitting the space into thinner and thinner slices.
+ * @param bestSuggestion the best split suggestion
+ * @param suggestions all suggestions
+ * @param blockedAttributeIndex the attribute index of the blocked attribute
+ * @return
+ */
+ AttributeSplitSuggestion replaceBestSuggestionIfAttributeIsBlocked(AttributeSplitSuggestion bestSuggestion, AttributeSplitSuggestion[] suggestions, int blockedAttributeIndex) {
+ if (suggestions.length == 0)
+ return null;
+ if (bestSuggestion.splitTest == null)
+ return bestSuggestion;
+ if (suggestions.length == 1)
+ return bestSuggestion;
+ if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == blockedAttributeIndex) {
+ ArrayUtils.remove(suggestions, suggestions.length - 1);
+ return suggestions[suggestions.length - 1];
+ }
+ return bestSuggestion;
+ }
+
    /**
     * Turns this node into a split node on the given attribute.
     * Creates the successor container and delegates child creation to initializeSuccessors.
     *
     * @param splitAttribute the attribute to split on
     * @param suggestion the winning split suggestion
     * @return true if the successors were initialized successfully
     */
    protected boolean makeSplit(Attribute splitAttribute, AttributeSplitSuggestion suggestion) {
        boolean isNominal = splitAttribute.isNominal();
        boolean isBinary = !(suggestion.splitTest instanceof NominalAttributeMultiwayTest);
        Double splitValue = null;
        // splitValue stays null only for multiway nominal splits
        if (isNominal && isBinary)
            splitValue = ((NominalAttributeBinaryTest) suggestion.splitTest).getValue();
        else if (!isNominal)
            splitValue = ((NumericAttributeBinaryTest) suggestion.splitTest).getSplitValue();

        successors = new Successors(isBinary, !isNominal, splitValue);

        setSplitAttribute(suggestion, splitAttribute);
        return initializeSuccessors(suggestion, splitAttribute);
    }
+
+ /**
+ * Checks if the subtree of this node performed split revision
+ * @return true if any node in the subtree performed a split revision
+ */
+ @Override
+ public boolean didPerformTreeRevision() {
+ boolean didRevise = performedTreeRevision;
+ performedTreeRevision = false;
+ if (isLeaf()) {
+ return didRevise;
+ }
+ for (CustomEFDTNode child: successors.getAllSuccessors()) {
+ didRevise |= child.didPerformTreeRevision();
+ }
+ return didRevise;
+ }
+
+ @Override
+ public int getLeafNumber() {
+ if (isLeaf())
+ return 1;
+ int sum = 0;
+ for (CustomEFDTNode s: successors.getAllSuccessors()) {
+ sum += s.getLeafNumber();
+ }
+ return sum;
+ }
+
    @Override
    public void getDescription(StringBuilder sb, int indent) {
        // intentionally empty: this node type produces no textual model description
    }
+}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomHTNode.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomHTNode.java
new file mode 100644
index 000000000..9ab193dec
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomHTNode.java
@@ -0,0 +1,80 @@
+package moa.classifiers.trees.plastic_util;
+
+import com.yahoo.labs.samoa.instances.Instance;
+import moa.classifiers.core.AttributeSplitSuggestion;
+import moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver;
+import moa.classifiers.core.splitcriteria.SplitCriterion;
+import moa.core.DoubleVector;
+
+import java.util.LinkedList;
+import java.util.List;
+
+public class CustomHTNode extends CustomEFDTNode {
+
+ public CustomHTNode(SplitCriterion splitCriterion,
+ int gracePeriod,
+ Double confidence,
+ Double adaptiveConfidence,
+ boolean useAdaptiveConfidence,
+ String leafPrediction,
+ Integer depth,
+ Integer maxDepth,
+ Double tau,
+ boolean binaryOnly,
+ boolean noPrePrune,
+ NominalAttributeClassObserver nominalObserverBlueprint,
+ DoubleVector observedClassDistribution,
+ List usedNominalAttributes,
+ int blockedAttributeIndex) {
+ super(splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence, leafPrediction,
+ Integer.MAX_VALUE, depth, maxDepth, tau, 0.0, 0.0, binaryOnly, noPrePrune,
+ nominalObserverBlueprint, observedClassDistribution, usedNominalAttributes, blockedAttributeIndex);
+
+ }
+
+ @Override
+ protected void reevaluateSplit(Instance instance) {
+ }
+
+ @Override
+ protected CustomHTNode newNode(int depth, DoubleVector classDistribution, List usedNominalAttributes) {
+ return new CustomHTNode(
+ splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence,
+ leafPrediction, depth, maxDepth,
+ tau, binaryOnly, noPrePrune, nominalObserverBlueprint,
+ classDistribution, new LinkedList<>(), -1 // we don't block attributes in HT
+ );
+ }
+
+ @Override
+ protected boolean shouldSplitLeaf(AttributeSplitSuggestion[] suggestions,
+ double confidence,
+ DoubleVector observedClassDistribution
+ ) {
+ boolean shouldSplit;
+ if (suggestions.length < 2) {
+ shouldSplit = suggestions.length > 0;
+ } else {
+ AttributeSplitSuggestion bestSuggestion = suggestions[suggestions.length - 1];
+ AttributeSplitSuggestion secondBestSuggestion = suggestions[suggestions.length - 2];
+
+ double bestSuggestionAverageMerit = bestSuggestion.merit;
+ double currentAverageMerit = secondBestSuggestion.merit;
+ double eps = computeHoeffdingBound();
+
+ shouldSplit = bestSuggestionAverageMerit - currentAverageMerit > eps || eps < tau;
+ if (bestSuggestion.merit < 1e-10)
+ shouldSplit = false; // we don't use average here
+
+ if (shouldSplit) {
+ for (Integer i : usedNominalAttributes) {
+ if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i) {
+ shouldSplit = false;
+ break;
+ }
+ }
+ }
+ }
+ return shouldSplit;
+ }
+}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/EFHATNode.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/EFHATNode.java
new file mode 100644
index 000000000..c76596c0c
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/EFHATNode.java
@@ -0,0 +1,174 @@
+package moa.classifiers.trees.plastic_util;
+
+import com.yahoo.labs.samoa.instances.Instance;
+import moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver;
+import moa.classifiers.core.splitcriteria.SplitCriterion;
+import moa.core.DoubleVector;
+
+import java.util.LinkedList;
+import java.util.List;
+
+public class EFHATNode extends CustomEFDTNode {
+
+ private CustomADWINChangeDetector changeDetector; // we need to access the width of adwin to compute switch significance. This is not possible with the default adwin change detector class.
+ private EFHATNode backgroundLearner;
+ private LinkedList predictions = new LinkedList<>();
+ private boolean isBackgroundLearner = false;
+
+ public EFHATNode(SplitCriterion splitCriterion,
+ int gracePeriod,
+ Double confidence,
+ Double adaptiveConfidence,
+ boolean useAdaptiveConfidence,
+ String leafPrediction,
+ Integer minSamplesReevaluate,
+ Integer depth,
+ Integer maxDepth,
+ Double tau,
+ boolean binaryOnly,
+ boolean noPrePrune,
+ NominalAttributeClassObserver nominalObserverBlueprint,
+ DoubleVector observedClassDistribution,
+ List usedNominalAttributes,
+ int blockedAttributeIndex,
+ CustomADWINChangeDetector changeDetector) {
+ super(splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence, leafPrediction,
+ minSamplesReevaluate, depth, maxDepth, tau, 0.0, 0.0, binaryOnly, noPrePrune,
+ nominalObserverBlueprint, observedClassDistribution, usedNominalAttributes, blockedAttributeIndex);
+ this.changeDetector = changeDetector == null ? new CustomADWINChangeDetector() : changeDetector;
+ }
+
+ public EFHATNode(SplitCriterion splitCriterion,
+ int gracePeriod,
+ Double confidence,
+ Double adaptiveConfidence,
+ boolean useAdaptiveConfidence,
+ String leafPrediction,
+ Integer minSamplesReevaluate,
+ Integer depth,
+ Integer maxDepth,
+ Double tau,
+ boolean binaryOnly,
+ boolean noPrePrune,
+ NominalAttributeClassObserver nominalObserverBlueprint,
+ DoubleVector observedClassDistribution,
+ List usedNominalAttributes,
+ int blockedAttributeIndex,
+ CustomADWINChangeDetector changeDetector,
+ boolean isBackgroundLearner) {
+ super(splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence, leafPrediction,
+ minSamplesReevaluate, depth, maxDepth, tau, 0.0, 0.0, binaryOnly, noPrePrune,
+ nominalObserverBlueprint, observedClassDistribution, usedNominalAttributes, blockedAttributeIndex);
+ this.changeDetector = changeDetector == null ? new CustomADWINChangeDetector() : changeDetector;
+ this.isBackgroundLearner = isBackgroundLearner;
+ }
+
+ @Override
+ public double[] predict(Instance instance) {
+ double[] pred = super.predict(instance);
+ if (pred.length > 0)
+ predictions.add((double) argmax(pred));
+ if (backgroundLearner != null)
+ backgroundLearner.predict(instance);
+ return pred;
+ }
+
+ @Override
+ public void learnInstance(Instance instance, int totalNumInstances) {
+ seenWeight += instance.weight();
+ nodeTime++;
+ updateStatistics(instance);
+ updateObservers(instance);
+ updateChangeDetector(instance.classValue());
+
+ if (backgroundLearner != null)
+ backgroundLearner.learnInstance(instance, totalNumInstances);
+
+ if (isLeaf() && nodeTime % gracePeriod == 0)
+ attemptInitialSplit(instance);
+
+ if (!isLeaf()
+// && nodeTime % minSamplesReevaluate == 0
+ )
+ hatGrow();
+
+ if (!isLeaf()) //Do NOT! put this in the upper (!isleaf()) block. This is not the same since we might kill the subtree during reevaluation!
+ propagateToSuccessors(instance, totalNumInstances);
+ }
+
+ private void hatGrow() {
+ if (changeDetector.getChange()) {
+ backgroundLearner = newNode(depth, new DoubleVector(), new LinkedList<>(usedNominalAttributes));
+ backgroundLearner.isBackgroundLearner = true;
+ }
+ else if (backgroundLearner != null) {
+ // adopted from HAT implementation
+ if (changeDetector.getWidth() > minSamplesReevaluate && backgroundLearner.changeDetector.getWidth() > minSamplesReevaluate) {
+ double oldErrorRate = changeDetector.getEstimation();
+ double oldWS = changeDetector.getWidth();
+ double altErrorRate = backgroundLearner.changeDetector.getEstimation();
+ double altWS = backgroundLearner.changeDetector.getWidth();
+
+ double fDelta = .05;
+ //if (gNumAlts>0) fDelta=fDelta/gNumAlts;
+// double fN = 1.0 / w2 + 1.0 / w1;
+ double fN = 1.0 / altWS + 1.0 / oldWS;
+ double Bound = (double) Math.sqrt((double) 2.0 * oldErrorRate * (1.0 - oldErrorRate) * Math.log(2.0 / fDelta) * fN);
+
+// double significance = switchSignificance(e1, e2, w1, w2);
+ if (Bound < oldErrorRate - altErrorRate) {
+ performedTreeRevision = true;
+ makeBackgroundLearnerMainLearner();
+ }
+ else if (Bound < altErrorRate - oldErrorRate) {
+ // Erase alternate tree
+ backgroundLearner = null;
+ }
+ }
+ }
+ }
+
+ @Override
+ protected EFHATNode newNode(int depth, DoubleVector classDistribution, List usedNominalAttributes) {
+ return new EFHATNode(
+ splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence,
+ leafPrediction, minSamplesReevaluate, depth, maxDepth,
+ tau, binaryOnly, noPrePrune, nominalObserverBlueprint,
+ classDistribution, new LinkedList<>(),
+ -1, // we don't block attributes in HT
+ (CustomADWINChangeDetector) changeDetector.copy(),
+ isBackgroundLearner
+ );
+ }
+
+ private void makeBackgroundLearnerMainLearner() {
+ splitAttribute = backgroundLearner.splitAttribute;
+ successors = backgroundLearner.successors;
+ observedClassDistribution = backgroundLearner.observedClassDistribution;
+ classDistributionAtTimeOfCreation = backgroundLearner.classDistributionAtTimeOfCreation;
+ attributeObservers = backgroundLearner.attributeObservers;
+ seenWeight = backgroundLearner.seenWeight;
+ nodeTime = backgroundLearner.nodeTime;
+ numSplitAttempts = backgroundLearner.numSplitAttempts;
+ setSplitTest(backgroundLearner.getSplitTest());
+ backgroundLearner = null;
+ }
+
+ private void updateChangeDetector(double label) {
+ if (predictions.size() == 0)
+ return;
+ Double pred = predictions.removeFirst();
+ if (pred == null)
+ return;
+ changeDetector.input(pred == label ? 0.0 : 1.0); // monitoring error rate, not accuracy.
+ }
+
+ private void resetIsBackgroundLearnerInSubtree() {
+ isBackgroundLearner = false;
+ if (isLeaf())
+ return;
+ for (CustomEFDTNode n: successors.getAllSuccessors()) {
+ ((EFHATNode) n).resetIsBackgroundLearnerInSubtree();
+ }
+ }
+}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/MappedTree.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/MappedTree.java
new file mode 100644
index 000000000..7a678f7b0
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/MappedTree.java
@@ -0,0 +1,409 @@
+package moa.classifiers.trees.plastic_util;
+
+import com.yahoo.labs.samoa.instances.Attribute;
+import moa.classifiers.core.conditionaltests.InstanceConditionalTest;
+
+import java.util.*;
+
/**
 * Lazily enumerates the branches of a subtree after "pulling up" a new split
 * attribute to the root, as part of PLASTIC's tree-restructuring step.
 *
 * The class implements {@code Iterator} (NOTE(review): raw type — presumably
 * {@code Iterator<PlasticBranch>}; the generic parameter looks stripped, confirm
 * against the original sources). {@link #next()} produces one restructured
 * {@link PlasticBranch} at a time so callers can rebuild the subtree incrementally.
 */
class MappedTree implements Iterator {

    // Branches still being extended downwards; replaced wholesale by mapBranches.
    private LinkedList branchQueue;
    // Fully mapped branches ready to be handed out by next().
    private final LinkedList finishedBranches = new LinkedList<>();

    // The attribute/value that should become the new root split of the subtree.
    private final Attribute splitAttribute;
    private final int splitAttributeIndex;
    private final Double splitValue;
    // Upper bound on how deep a branch may grow during mapping.
    private final int maxBranchLength;

    public MappedTree(PlasticNode root, Attribute splitAttribute, int splitAttributeIndex, Double splitValue, int maxBranchLength) {
        // NOTE(review): disconnectRoot may return null (root is a leaf / has no
        // successors); hasNext() would then NPE — presumably callers guarantee
        // the root is an inner node. Confirm.
        branchQueue = disconnectRoot(root);

        this.splitAttribute = splitAttribute;
        this.splitAttributeIndex = splitAttributeIndex;
        this.splitValue = splitValue;
        this.maxBranchLength = maxBranchLength;
    }


    /** @return true while there are unexpanded or finished branches left. */
    @Override
    public boolean hasNext() {
        return !(branchQueue.size() == 0 && finishedBranches.size() == 0);
    }

    /**
     * Returns the next finished (fully mapped) branch, expanding queued branches
     * as needed until at least one is finished.
     * NOTE(review): if branchQueue empties without producing a finished branch,
     * the while loop would spin forever — presumably prevented by the hasNext()
     * contract; confirm.
     */
    @Override
    public PlasticBranch next() {
        if (finishedBranches.size() > 0) {
            return finishedBranches.removeFirst();
        }
        else {
            while (finishedBranches.size() == 0) {
                mapBranches(branchQueue, splitAttribute, splitAttributeIndex, splitValue, maxBranchLength);
            }
            return finishedBranches.removeFirst();
        }
    }

    /**
     * One expansion round: every queued branch is either finished (end condition
     * reached: swap attribute found, leaf reached, or max length hit) — in which
     * case it is force-split, decoupled and moved to finishedBranches — or it is
     * extended by one level per successor of its last node.
     */
    private void mapBranches(
            LinkedList branches,
            Attribute swapAttribute,
            int swapAttributeIndex,
            Double splitValue,
            int maxBranchLength
    ) {
        // iterate over a snapshot count: branches are re-appended while looping
        int numBranches = branches.size();
        for (int i = 0; i < numBranches; i++) {
            PlasticBranch branch = branches.removeFirst();
            boolean branchIsFinished = getEndConditionForBranch(branch, swapAttribute, swapAttributeIndex, splitValue, maxBranchLength);
            if (branchIsFinished) {
                expandBranch(branch, swapAttribute, swapAttributeIndex, splitValue);
                List decoupledBranches = decoupleLastNode(branch);
                decoupledBranches.forEach(b -> modifyBranch(b, swapAttribute));
                finishedBranches.addAll(decoupledBranches);
                // stop early: next() only needs one batch of finished branches
                branchQueue = branches;
                return;
            }
            PlasticTreeElement lastElement = branch.getLast();
            PlasticBranch branchExtensions = disconnectSuccessors(lastElement, swapAttribute);
            for (PlasticTreeElement extension: branchExtensions.getBranchRef()) {
                // shallow-copy the element list so each extension grows independently
                PlasticBranch extendedBranch = new PlasticBranch((LinkedList) branch.getBranchRef().clone());
                extendedBranch.getBranchRef().add(extension);
                branches.add(extendedBranch);
            }
        }
        branchQueue = branches;
    }

    /**
     * A branch is finished when its last node already splits on the swap
     * attribute, is a leaf, or the branch has reached the maximum length.
     */
    private boolean getEndConditionForBranch(PlasticBranch branch, Attribute swapAttribute, int swapAttributeIndex, Double splitValue, int maxBranchLength) {
        CustomEFDTNode lastNodeOfBranch = branch.getLast().getNode();
        // reference comparison: Attribute instances are assumed to be shared, not equal-by-value
        boolean isSwapAttribute = lastNodeOfBranch.getSplitAttribute() == swapAttribute;
        boolean isLeaf = lastNodeOfBranch.isLeaf();
        if (isSwapAttribute || isLeaf)
            return true;
        return branch.getBranchRef().size() >= maxBranchLength;
    }

    /**
     * Builds the one-level extensions of a branch below {@code ancestor}:
     * either a single terminal element (successor is a leaf or splits on the
     * swap attribute) or one element per successor key of the successor node.
     */
    private PlasticBranch disconnectSuccessors(PlasticTreeElement ancestor, Attribute swapAttribute) {
        PlasticBranch branchExtensions = new PlasticBranch();
        PlasticNode ancestorNode = ancestor.getNode();
        SuccessorIdentifier key = ancestor.getKey();
        PlasticNode successor = (PlasticNode) ancestorNode.getSuccessors().getSuccessorNode(key);

        boolean endCondition = successor.isLeaf() || successor.getSplitAttribute() == swapAttribute;

        if (endCondition) {
            // terminal element: key == null marks the end of a branch
            PlasticTreeElement element = new PlasticTreeElement(successor, null);
            branchExtensions.getBranchRef().add(element);
        }
        else {
            for (SuccessorIdentifier successorSuccessorKey: successor.getSuccessors().getKeyset()) {
                PlasticTreeElement element = new PlasticTreeElement(successor, successorSuccessorKey);
                branchExtensions.getBranchRef().add(element);
            }
        }
        return branchExtensions;
    }

    /**
     * Seeds the branch queue: one single-element branch per successor key of the
     * root. Returns null if the root has no successors (is a leaf).
     */
    private LinkedList disconnectRoot(PlasticNode root) {
        Successors successors = root.getSuccessors();
        if (successors == null) {
            return null;
        }
        if (successors.size() == 0) {
            return null;
        }
        Set keys = successors.getKeyset();
        LinkedList branches = new LinkedList<>();
        for (SuccessorIdentifier key: keys) {
            PlasticTreeElement newElement = new PlasticTreeElement(root, key);
            PlasticBranch newBranch = new PlasticBranch();
            newBranch.getBranchRef().add(newElement);
            branches.add(newBranch);
        }
        return branches;
    }

    /**
     * Ensures the last node of a finished branch splits on the target attribute,
     * dispatching to the categorical or numeric variant.
     */
    private void expandBranch(PlasticBranch branch, Attribute splitAttribute, int splitAttributeIndex, Double splitValue) {
        if (splitAttribute.isNominal())
            splitLastNodeIfRequiredCategorical(branch, splitAttribute, splitAttributeIndex, splitValue);
        else
            splitLastNodeIfRequiredNumeric(branch, splitAttribute, splitAttributeIndex, splitValue);
    }

    /**
     * Splits a finished branch into one branch per successor of its last node:
     * the former end element is re-keyed with the successor key and the successor
     * itself becomes the new terminal element (key == null).
     */
    private LinkedList decoupleLastNode(PlasticBranch branch) {
        PlasticNode lastNode = branch.getLast().getNode();
        LinkedList decoupledBranches = new LinkedList<>();
        Successors lastNodeSuccessors = lastNode.getSuccessors();

        for (SuccessorIdentifier key: lastNodeSuccessors.getKeyset()) {
            PlasticBranch branchCopy = branch.copy(); //TODO: Doublecheck if this is fine.
            PlasticTreeElement replacedEndElement = new PlasticTreeElement(branchCopy.getLast().getNode(), key);
            branchCopy.getBranchRef().removeLast();
            branchCopy.getBranchRef().add(replacedEndElement);
            PlasticTreeElement finalElement = new PlasticTreeElement((PlasticNode) lastNodeSuccessors.getSuccessorNode(key), null);
            branchCopy.getBranchRef().add(finalElement);
            decoupledBranches.add(branchCopy);
        }

        return decoupledBranches;
    }

    /**
     * Makes the last node of the branch split on {@code swapAttribute}
     * (categorical case), covering the transformations 1-3 described in the
     * PLASTIC paper plus the forced-split fallback.
     */
    private void splitLastNodeIfRequiredCategorical(PlasticBranch branch,
                                                    Attribute swapAttribute,
                                                    int swapAttributeIndex,
                                                    Double splitValue) {
        PlasticNode lastNode = branch.getLast().getNode();

        // a concrete split value implies a binary (value vs. rest) split
        boolean shouldBeBinary = splitValue != null;
        boolean splitAttributesMatch = lastNode.splitAttribute == swapAttribute;

        if (lastNode.isLeaf() || !splitAttributesMatch) { // Option 1: the split attributes don't match
            lastNode.forceSplit(
                    swapAttribute,
                    swapAttributeIndex,
                    splitValue,
                    shouldBeBinary
            );
            if (lastNode.isLeaf()) {
                // NOTE(review): forceSplit failed to create successors — should not happen
                System.out.println("Error");
            }
            // successors created here did not come from real splits
            lastNode.successors.getAllSuccessors().forEach(s -> ((PlasticNode) s).setIsArtificial());
            return;
        }

        boolean isBinary = lastNode.successors.isBinary();

        if (!isBinary && !shouldBeBinary) // Option 2: the split attributes match and the splits are multiway (this is really the best case possible)
            // do nothing
            return;

        if (isBinary && shouldBeBinary) { // Option 3: the split attributes match and also both splits should be binary
            if (splitValue.equals(lastNode.successors.getReferenceValue()))
                // do nothing
                return;
            Successors lastNodeSuccessors = new Successors(lastNode.getSuccessors(), true);

            // Corresponds to transformation 3 in the paper
            Attribute lastNodeSplitAttribute = lastNode.getSplitAttribute();
            int lastNodeSplitAttributeIndex = lastNode.getSplitAttributeIndex();
            InstanceConditionalTest splitTest = lastNode.getSplitTest();
            lastNode.successors = null;
            lastNode.forceSplit(
                    swapAttribute,
                    swapAttributeIndex,
                    splitValue,
                    shouldBeBinary
            );
            // the old binary split is pushed down into the "default" (rest) branch
            PlasticNode defaultSuccessor = (PlasticNode) lastNode.getSuccessors().getSuccessorNode(SuccessorIdentifier.DEFAULT_NOMINAL_VALUE);
            defaultSuccessor.transferSplit(
                    lastNodeSuccessors, lastNodeSplitAttribute, lastNodeSplitAttributeIndex, splitTest
            );
            return;
        }

        if (!isBinary && shouldBeBinary) { // Option 4: The split attributes match and the current split is multiway while the old one was binary. In this case, we do something similar to the numeric splits
            // Corresponds to transformation 2 in the paper
            Successors lastNodeSuccessors = new Successors(lastNode.getSuccessors(), true);
            Attribute lastNodeSplitAttribute = lastNode.getSplitAttribute();
            int lastNodeSplitAttributeIndex = lastNode.getSplitAttributeIndex();
            InstanceConditionalTest splitTest = lastNode.getSplitTest();
            lastNode.successors = null;
            lastNode.forceSplit(
                    swapAttribute,
                    swapAttributeIndex,
                    splitValue,
                    shouldBeBinary
            );
            PlasticNode defaultSuccessor = (PlasticNode) lastNode.getSuccessors().getSuccessorNode(SuccessorIdentifier.DEFAULT_NOMINAL_VALUE);
//        if (lastNodeSuccessors.contains(splitValue)) {
//            SuccessorIdentifier foundKey = null;
//            for (SuccessorIdentifier k: lastNodeSuccessors.getKeyset()) {
//                if (k.getSelectorValue() == splitValue) {
//                    foundKey = k;
//                    break;
//                }
//            }
//            if (foundKey != null)
//                lastNodeSuccessors.removeSuccessor(foundKey);
//        }
            defaultSuccessor.transferSplit(
                    lastNodeSuccessors, lastNodeSplitAttribute, lastNodeSplitAttributeIndex, splitTest
            );
            return;
        }

        if (isBinary && !shouldBeBinary) { // Option 5: The split is binary but should be multiway. In this case, we force the split and then use the old subtree of the left branch of the old subtree.
            // Corresponds to transformation 1 in the paper
            SuccessorIdentifier keyToPreviousSuccessor = new SuccessorIdentifier(false, splitValue, splitValue, false);
            PlasticNode previousSuccessor = (PlasticNode) lastNode.getSuccessors().getSuccessorNode(keyToPreviousSuccessor);

            lastNode.successors = null;
            lastNode.forceSplit(
                    swapAttribute,
                    swapAttributeIndex,
                    splitValue,
                    shouldBeBinary
            );
            // re-attach the previously existing subtree under its old key
            lastNode.successors.removeSuccessor(keyToPreviousSuccessor);
            lastNode.successors.addSuccessor(previousSuccessor, keyToPreviousSuccessor);
            return;
        }

        // NOTE(review): this branch is effectively dead — a leaf is already handled
        // by Option 1 above; kept as a debugging safety net.
        else if (lastNode.isLeaf()) {
            System.out.println("Do nothing");
        }
    }

    /**
     * Makes the last node of the branch split on {@code swapAttribute}
     * (numeric case). Numeric swaps are always binary, so {@code splitValue}
     * must be non-null.
     */
    private void splitLastNodeIfRequiredNumeric(PlasticBranch branch,
                                                Attribute swapAttribute,
                                                int swapAttributeIndex,
                                                Double splitValue) {
        assert splitValue != null;
        PlasticNode lastNode = branch.getLast().getNode();
        boolean forceSplit = lastNode.isLeaf() || lastNode.splitAttribute != swapAttribute;

        if (forceSplit) {
            lastNode.forceSplit(
                    swapAttribute,
                    swapAttributeIndex,
                    splitValue,
                    true
            );
            lastNode.getSuccessors().getAllSuccessors().forEach(s -> ((PlasticNode) s).setIsArtificial());
            return;
        }

        Double oldThreshold = lastNode.getSuccessors().getReferenceValue();
        if (splitValue.equals(oldThreshold))
            return; // do nothing

        if (lastNode.getSuccessors().size() > 1) {
            // both sides exist: keep the subtree, just move the threshold
            updateThreshold(lastNode, swapAttributeIndex, splitValue);
            return;
        }

        lastNode.forceSplit(
                swapAttribute,
                swapAttributeIndex,
                splitValue,
                true
        );
        lastNode.getSuccessors().getAllSuccessors().forEach(s -> ((PlasticNode) s).setIsArtificial());
    }

    /**
     * Replaces a numeric split's threshold with {@code splitValue}, re-keying the
     * two existing successors, then prunes subtree regions that became
     * unreachable under the new threshold.
     */
    private void updateThreshold(PlasticNode node, int splitAttributeIndex, double splitValue) {
        Double oldThreshold = node.getSuccessors().getReferenceValue();

        SuccessorIdentifier leftKey = new SuccessorIdentifier(true, oldThreshold, oldThreshold, true);
        SuccessorIdentifier rightKey = new SuccessorIdentifier(true, oldThreshold, oldThreshold, false);
        PlasticNode succ1 = (PlasticNode) node.getSuccessors().getSuccessorNode(leftKey);
        PlasticNode succ2 = (PlasticNode) node.getSuccessors().getSuccessorNode(rightKey);
        Successors newSuccessors = new Successors(true, true, splitValue);
        if (succ1 != null)
            newSuccessors.addSuccessorNumeric(splitValue, succ1, true);
        if (succ2 != null)
            newSuccessors.addSuccessorNumeric(splitValue, succ2, false);
        node.successors = newSuccessors;

        if (node.isLeaf())
            return;

        for (SuccessorIdentifier key: node.getSuccessors().getKeyset()) {
            PlasticNode s = (PlasticNode) node.getSuccessors().getSuccessorNode(key);
            removeUnreachableSubtree(s, splitAttributeIndex, splitValue, key.isLower());
        }
    }

    /**
     * Recursively removes successors that can never be reached any more after
     * the threshold on {@code splitAttributeIndex} moved: on the lower side,
     * right-branches at or above the threshold; on the upper side, left-branches
     * at or below it.
     */
    private void removeUnreachableSubtree(PlasticNode node, int splitAttributeIndex, double threshold, boolean isLower) {
        if (node.isLeaf())
            return;

        if (node.getSplitAttributeIndex() != splitAttributeIndex) {
            // different attribute: nothing to prune here, but descend
            for (CustomEFDTNode successor: node.getSuccessors().getAllSuccessors()) {
                removeUnreachableSubtree((PlasticNode) successor, splitAttributeIndex, threshold, isLower);
            }
            return;
        }

        Set keysToRemove = new HashSet<>();
        for (SuccessorIdentifier key: node.getSuccessors().getKeyset()) {
            assert key.isNumeric();
            if (isLower) {
                if (!key.isLower() && key.getSelectorValue() >= threshold) {
                    keysToRemove.add(key);
                }
            }
            else {
                if (key.isLower() && key.getSelectorValue() <= threshold) {
                    keysToRemove.add(key);
                }
            }
        }
        for (SuccessorIdentifier key: keysToRemove) {
            node.getSuccessors().removeSuccessor(key);
        }

        if (!node.isLeaf()) {
            node.getSuccessors().getAllSuccessors().forEach(s -> removeUnreachableSubtree((PlasticNode) s, splitAttributeIndex, threshold, isLower));
        }
    }

    /**
     * Post-processing of a finished branch: rotate the swap-attribute node to the
     * front, detach successor links, flag the interior nodes as restructured and
     * recompute depths.
     */
    private void modifyBranch(PlasticBranch branch, Attribute splitAttribute) {
        putLastElementToFront(branch, splitAttribute);
        resetSuccessorsInBranch(branch);
        setRestructuredFlagInBranch(branch);
        setDepth(branch);
    }

    /**
     * Moves the second-to-last element (the node splitting on the new root
     * attribute) to the front of the branch and transfers the old front node's
     * statistics to it.
     */
    private void putLastElementToFront(PlasticBranch branch, Attribute splitAttribute) {
        PlasticTreeElement oldFirstBranchElement = branch.getBranchRef().getFirst();
        // size - 2 because the very last element is the terminal (key == null) element
        PlasticTreeElement newFirstBranchElement = branch.getBranchRef().remove(branch.getBranchRef().size() - 2);

        branch.getBranchRef().addFirst(newFirstBranchElement);
        if (splitAttribute != newFirstBranchElement.getNode().splitAttribute) {
            // NOTE(review): invariant violation — debug output only, not handled
            System.out.println(branch.getDescription());
        }

        PlasticNode oldFirstNode = oldFirstBranchElement.getNode();
        PlasticNode newFirstNode = newFirstBranchElement.getNode();

        //TODO not sure this is actually required! I think it could be sufficient to just change the successors when building the tree in a later step.
        newFirstNode.observedClassDistribution = oldFirstNode.observedClassDistribution;
        newFirstNode.depth = oldFirstNode.depth;
        newFirstNode.attributeObservers = oldFirstNode.attributeObservers;
    }

    /**
     * Replaces the successor sets of all but the terminal element with shallow
     * copies, so the rebuilt branch no longer shares links with the old tree.
     */
    private void resetSuccessorsInBranch(PlasticBranch branch) {
        if (branch.getBranchRef().size() == 1)
            return;
        int i = 0;
        for (PlasticTreeElement item: branch.getBranchRef()) {
            if (i == branch.getBranchRef().size() - 1)
                break;
            item.getNode().successors = new Successors(item.getNode().getSuccessors(), false);
            i++;
        }
    }

    /** Flags the interior nodes (neither first nor terminal) as restructured. */
    private void setRestructuredFlagInBranch(PlasticBranch branch) {
        if (branch.getBranchRef().size() <= 2)
            return;
        int i = 0;
        for (PlasticTreeElement item: branch.getBranchRef()) {
            if (i == 0 || i == branch.getBranchRef().size() - 1) {
                i++;
                continue;
            }
            item.getNode().setRestructuredFlag();
            i++;
        }
    }

    /** Re-numbers node depths along the branch, starting from the first node's depth. */
    private void setDepth(PlasticBranch branch) {
        PlasticNode firstNode = branch.getBranchRef().getFirst().getNode();
        int i = 0;
        for (PlasticTreeElement item: branch.getBranchRef()) {
            item.getNode().depth = firstNode.getDepth() + i;
            i++;
        }
    }
}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/MeasuresNumberOfLeaves.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/MeasuresNumberOfLeaves.java
new file mode 100644
index 000000000..3873c5812
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/MeasuresNumberOfLeaves.java
@@ -0,0 +1,5 @@
+package moa.classifiers.trees.plastic_util;
+
/**
 * Implemented by tree components that can report how many leaf nodes they contain.
 */
public interface MeasuresNumberOfLeaves {
    /**
     * @return the number of leaves in this (sub)tree
     */
    int getLeafNumber(); // interface members are implicitly public; modifier removed
}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/PerformsTreeRevision.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/PerformsTreeRevision.java
new file mode 100644
index 000000000..6e87a7986
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/PerformsTreeRevision.java
@@ -0,0 +1,5 @@
+package moa.classifiers.trees.plastic_util;
+
/**
 * Implemented by tree learners that may revise (restructure) their tree
 * during training and can report whether such a revision took place.
 */
public interface PerformsTreeRevision {
    /**
     * @return true if a tree revision was performed
     */
    boolean didPerformTreeRevision();
}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticBranch.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticBranch.java
new file mode 100644
index 000000000..00dab0ef4
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticBranch.java
@@ -0,0 +1,63 @@
+package moa.classifiers.trees.plastic_util;
+
+import java.util.ArrayList;
+import java.util.LinkedList;
+
+class PlasticBranch implements Comparable {
+ private LinkedList branch = new LinkedList<>();
+
+ public PlasticBranch(){}
+
+ public PlasticBranch(LinkedList branch) {
+ this.branch = branch;
+ }
+
+ public LinkedList getBranchRef() {
+ return branch;
+ }
+
+ public String getDescription() {
+ if (branch == null) {
+ return "Branch is null";
+ }
+ StringBuilder s = new StringBuilder();
+ int i = 0;
+ for (PlasticTreeElement e: branch) {
+ i++;
+ s.append(e.getDescription()).append(i == branch.size() ? "" : " --> ");
+ }
+ return s.toString();
+ }
+
+ public PlasticTreeElement getLast() {
+ if (branch == null)
+ return null;
+ if (branch.size() == 0) {
+ return null;
+ }
+ return branch.getLast();
+ }
+
+ public PlasticBranch copy() {
+ PlasticBranch cpy = new PlasticBranch();
+ for (PlasticTreeElement item: branch) {
+ cpy.getBranchRef().add(item.copy());
+ }
+ return cpy;
+ }
+
+ public ArrayList branchArrayCpy() {
+ return new ArrayList<>(branch);
+ }
+
+ public int compareTo(PlasticBranch other)
+ {
+ int a = branch.getLast().getNode().observedClassDistribution.numValues();
+ int b = other.branch.getLast().getNode().observedClassDistribution.numValues();
+ if (a < b)
+ return -1;
+ if (a == b)
+ return 0;
+ return 1;
+ }
+}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticNode.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticNode.java
new file mode 100644
index 000000000..91b056246
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticNode.java
@@ -0,0 +1,398 @@
+package moa.classifiers.trees.plastic_util;
+
+import com.yahoo.labs.samoa.instances.Attribute;
+import com.yahoo.labs.samoa.instances.Instance;
+import moa.classifiers.core.AttributeSplitSuggestion;
+import moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
+import moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver;
+import moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver;
+import moa.classifiers.core.conditionaltests.InstanceConditionalTest;
+import moa.classifiers.core.conditionaltests.NominalAttributeBinaryTest;
+import moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest;
+import moa.classifiers.core.splitcriteria.SplitCriterion;
+import moa.core.AutoExpandVector;
+import moa.core.DoubleVector;
+
+import java.util.*;
+
+public class PlasticNode extends CustomEFDTNode {
+
+ private Set childrenSplitAttributes;
+ private boolean nodeGotRestructured = false;
+ private boolean isArtificial = false;
+ private final Restructurer restructurer;
+ protected final int maxBranchLength;
+ protected final double acceptedNumericThresholdDeviation;
+ private boolean isDummy = false;
+
+ protected void setRestructuredFlag() {
+ nodeGotRestructured = true;
+ }
+
+ protected void resetRestructuredFlag() {
+ nodeGotRestructured = false;
+ }
+
+ protected boolean getRestructuredFlag() {
+ return nodeGotRestructured;
+ }
+
+ public PlasticNode(
+ SplitCriterion splitCriterion, int gracePeriod, Double confidence, Double adaptiveConfidence,
+ boolean useAdaptiveConfidence, String leafPrediction, Integer minSamplesReevaluate, Integer depth,
+ Integer maxDepth, Double tau, Double tauReevaluate, Double relMinDeltaG, boolean binaryOnly,
+ boolean noPrePrune, NominalAttributeClassObserver nominalObserverBlueprint,
+ DoubleVector observedClassDistribution, List usedNominalAttributes,
+ int maxBranchLength, double acceptedNumericThresholdDeviation, int blockedAttributeIndex
+ ) {
+ super(splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence, leafPrediction,
+ minSamplesReevaluate, depth, maxDepth, tau, tauReevaluate, relMinDeltaG, binaryOnly, noPrePrune,
+ nominalObserverBlueprint, observedClassDistribution, usedNominalAttributes, blockedAttributeIndex);
+ this.maxBranchLength = maxBranchLength;
+ this.acceptedNumericThresholdDeviation = acceptedNumericThresholdDeviation;
+ restructurer = new Restructurer(maxBranchLength, acceptedNumericThresholdDeviation);
+ }
+
+ public PlasticNode(PlasticNode other) {
+ super((SplitCriterion) other.splitCriterion.copy(), other.gracePeriod, other.confidence,
+ other.adaptiveConfidence, other.useAdaptiveConfidence, other.leafPrediction,
+ other.minSamplesReevaluate, other.depth, other.maxDepth, other.tau, other.tauReevaluate,
+ other.relMinDeltaG, other.binaryOnly, other.noPrePrune, other.nominalObserverBlueprint,
+ (DoubleVector) other.observedClassDistribution.copy(), other.usedNominalAttributes,
+ other.blockedAttributeIndex);
+ this.acceptedNumericThresholdDeviation = other.acceptedNumericThresholdDeviation;
+ this.maxBranchLength = other.maxBranchLength;
+ if (other.successors != null)
+ this.successors = new Successors(other.successors, true);
+ if (other.getSplitTest() != null)
+ setSplitTest((InstanceConditionalTest) other.getSplitTest().copy());
+ this.infogainSum = new HashMap<>(infogainSum);
+ this.numSplitAttempts = other.numSplitAttempts;
+ this.classDistributionAtTimeOfCreation = other.classDistributionAtTimeOfCreation;
+ this.nodeTime = other.nodeTime;
+ this.splitAttribute = other.splitAttribute;
+ this.seenWeight = other.seenWeight;
+ this.isArtificial = other.isArtificial;
+ if (other.attributeObservers != null)
+ this.attributeObservers = (AutoExpandVector) other.attributeObservers.copy();
+ restructurer = other.restructurer;
+ blockedAttributeIndex = other.blockedAttributeIndex;
+ }
+
+ /**
+ * In some cases during restructuring, we create dummy nodes that we prune restructuring has finished
+ * @return true if the node is a dummy node
+ */
+ public boolean isDummy() {
+ return isDummy;
+ }
+
+ @Override
+ protected PlasticNode addSuccessor(Instance instance) {
+ List usedNomAttributes = new ArrayList<>(usedNominalAttributes); //deep copy
+ PlasticNode successor = newNode(depth + 1, new DoubleVector(), usedNomAttributes);
+ double value = instance.value(splitAttribute);
+ if (splitAttribute.isNominal()) {
+ if (!successors.isBinary()) {
+ boolean success = successors.addSuccessorNominalMultiway(value, successor);
+ return success ? successor : null;
+ } else {
+ boolean success = successors.addSuccessorNominalBinary(value, successor);
+ if (!success) // this is the case if the split is binary nominal but the "left" successor exists.
+ success = successors.addDefaultSuccessorNominalBinary(successor);
+ return success ? successor : null;
+ }
+ } else {
+ NumericAttributeBinaryTest test = (NumericAttributeBinaryTest) getSplitTest();
+ if (successors.lowerIsMissing()) {
+ boolean success = successors.addSuccessorNumeric(test.getValue(), successor, true);
+ return success ? successor : null;
+ } else if (successors.upperIsMissing()) {
+ boolean success = successors.addSuccessorNumeric(test.getValue(), successor, false);
+ return success ? successor : null;
+ }
+ }
+ return null;
+ }
+
+ @Override
+ protected PlasticNode newNode(int depth, DoubleVector classDistribution, List usedNominalAttributes) {
+ return new PlasticNode(
+ splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence,
+ leafPrediction, minSamplesReevaluate, depth, maxDepth,
+ tau, tauReevaluate, relMinDeltaG, binaryOnly, noPrePrune, nominalObserverBlueprint,
+ classDistribution, usedNominalAttributes, maxBranchLength, acceptedNumericThresholdDeviation, getSplitAttributeIndex()
+ );
+ }
+
+ /**
+ * Collect the split attributes of the children and this node.
+ * For performance reasons, this updates the `childrenSplitAttributes` property at every node.
+ * @return the set of attributes the children and this node split on.
+ */
+ public Set collectChildrenSplitAttributes() {
+ childrenSplitAttributes = new HashSet<>();
+ if (isLeaf()) {
+ // we have no split attribute
+ return childrenSplitAttributes;
+ }
+ // add the split attribute of this node to the set
+ childrenSplitAttributes.add(splitAttribute);
+ for (CustomEFDTNode successor : successors.getAllSuccessors()) {
+ // add the split attributes of the subtree
+ childrenSplitAttributes.addAll(((PlasticNode) successor).collectChildrenSplitAttributes());
+ }
+ return childrenSplitAttributes;
+ }
+
+ /**
+ * Simply returns `childrenSplitAttributes`. In doubt, call `collectChildrenSplitAttributes` before accessing the property.
+ * @return the set of attributes the children and this node split on.
+ */
+ public Set getChildrenSplitAttributes() {
+ return childrenSplitAttributes;
+ }
+
+ /**
+ * Transfers the split from another node to this node
+ * @param successors the sucessors of the other node
+ * @param splitAttribute the split attribute of the other node
+ * @param splitAttributeIndex the split attribute index of the other node
+ * @param splitTest the split test of the other node
+ */
+ public void transferSplit(Successors successors,
+ Attribute splitAttribute,
+ int splitAttributeIndex,
+ InstanceConditionalTest splitTest) {
+ this.successors = successors;
+ this.splitAttribute = splitAttribute;
+ setSplitTest(splitTest);
+ }
+
+ /**
+ * Increases the depth property in all subtree nodes by 1
+ */
+ protected void incrementDepthInSubtree() {
+ depth++;
+ if (isLeaf())
+ return;
+ for (SuccessorIdentifier key : successors.getKeyset()) {
+ PlasticNode successor = (PlasticNode) successors.getSuccessorNode(key);
+ successor.incrementDepthInSubtree();
+ }
+ }
+
+ protected void resetObservers() {
+ AutoExpandVector newObservers = new AutoExpandVector<>();
+ for (AttributeClassObserver observer : attributeObservers) {
+ if (observer.getClass() == nominalObserverBlueprint.getClass())
+ newObservers.add(newNominalClassObserver());
+ else
+ newObservers.add(newNumericClassObserver());
+ }
+ attributeObservers = newObservers;
+ }
+
+ /**
+ * If the node was artificially created during restructuring or if it originated from a 'normal' split
+ * @return true if the node was created artificially
+ */
+ public boolean isArtificial() {
+ return isArtificial;
+ }
+
+ public void setIsArtificial() {
+ isArtificial = true;
+ }
+
+ public void setIsArtificial(boolean val) {
+ isArtificial = val;
+ }
+
+ /**
+ * Forces a split of this node. This is required during restructuring to make sure the branch contains the desired split attribute.
+ * See step 3 of the algorithm
+ * @param splitAttribute the attribute to split on
+ * @param splitAttributeIndex the index of the attribute to split on
+ * @param splitValue the value of the split (e.g., for numerical splits or binary-nominal)
+ * @param isBinary flag if the split is binary or multiway
+ * @return true, if the split was successful
+ */
+ protected boolean forceSplit(Attribute splitAttribute, int splitAttributeIndex, Double splitValue, boolean isBinary) {
+ AttributeClassObserver observer = attributeObservers.get(splitAttributeIndex);
+ if (observer == null) {
+ observer = splitAttribute.isNominal() ? newNominalClassObserver() : newNumericClassObserver();
+ this.attributeObservers.set(splitAttributeIndex, observer);
+ }
+
+ boolean success;
+ if (splitAttribute.isNominal()) {
+ NominalAttributeClassObserver nominalObserver = (NominalAttributeClassObserver) observer;
+ AttributeSplitSuggestion suggestion = nominalObserver.forceSplit(
+ splitCriterion, observedClassDistribution.getArrayCopy(), splitAttributeIndex, isBinary, splitValue
+ );
+ if (suggestion != null) {
+ success = makeSplit(splitAttribute, suggestion);
+ } else
+ success = false;
+
+ if (!success) {
+ successors = new Successors(isBinary, splitAttribute.isNumeric(), splitValue);
+ this.splitAttribute = splitAttribute;
+ setSplitTest(suggestion == null ? null : suggestion.splitTest);
+
+ if (!isBinary) {
+ PlasticNode dummyNode = newNode(depth, new DoubleVector(), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex));
+ SuccessorIdentifier dummyKey = new SuccessorIdentifier(splitAttribute.isNumeric(), 0.0, 0.0, false);
+ success = successors.addSuccessor(dummyNode, dummyKey); // will be pruned later on.
+ dummyNode.isDummy = true;
+ return success;
+ } else {
+ PlasticNode a = newNode(depth + 1, new DoubleVector(), new LinkedList<>(usedNominalAttributes));
+ PlasticNode b = newNode(depth + 1, new DoubleVector(), new LinkedList<>(usedNominalAttributes));
+ SuccessorIdentifier keyA = new SuccessorIdentifier(splitAttribute.isNumeric(), splitValue, splitValue, false);
+ successors.addSuccessor(a, keyA);
+ successors.addDefaultSuccessorNominalBinary(b);
+ return true;
+ }
+ }
+ } else {
+ GaussianNumericAttributeClassObserver numericObserver = (GaussianNumericAttributeClassObserver) observer;
+ AttributeSplitSuggestion suggestion = numericObserver.forceSplit(
+ splitCriterion, observedClassDistribution.getArrayCopy(), splitAttributeIndex, splitValue
+ );
+ if (suggestion != null) {
+ success = makeSplit(splitAttribute, suggestion);
+ } else
+ success = false;
+
+ if (!success) {
+ successors = new Successors(isBinary, splitAttribute.isNumeric(), splitValue);
+ this.splitAttribute = splitAttribute;
+ setSplitTest(suggestion == null ? null : suggestion.splitTest);
+
+ for (int i = 0; i < 1; i++) {
+ PlasticNode dummyNode = newNode(depth, new DoubleVector(), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex));
+ SuccessorIdentifier dummyKey = new SuccessorIdentifier(splitAttribute.isNumeric(), splitValue, splitValue, i == 0);
+ success = successors.addSuccessor(dummyNode, dummyKey); // will be pruned later on.
+ dummyNode.isDummy = true;
+ if (!success)
+ break;
+ }
+ }
+ }
+
+ return success;
+ }
+
+ /**
+ * Reevaluate the split and restructure if needed
+ * @param instance the current instance
+ */
+ @Override
+ protected void reevaluateSplit(Instance instance) {
+ if (isPure())
+ return;
+
+ numSplitAttempts++;
+
+ AttributeSplitSuggestion[] bestSuggestions = getBestSplitSuggestions(splitCriterion);
+ Arrays.sort(bestSuggestions);
+ if (bestSuggestions.length == 0)
+ return;
+ updateInfogainSum(bestSuggestions);
+
+ AttributeSplitSuggestion[] bestSplitSuggestions = getBestSplitSuggestions(splitCriterion);
+ Arrays.sort(bestSplitSuggestions);
+ AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
+
+ double bestSuggestionAverageMerit = bestSuggestion.splitTest == null ? 0.0 : bestSuggestion.merit;
+ double currentAverageMerit = getCurrentSuggestionAverageMerit(bestSuggestions);
+ double deltaG = bestSuggestionAverageMerit - currentAverageMerit;
+ double eps = computeHoeffdingBound();
+
+ if (deltaG > eps || (eps < tauReevaluate && deltaG > tauReevaluate * relMinDeltaG)) {
+
+ if (bestSuggestion.splitTest == null) {
+ System.out.println("preprune - null wins");
+ killSubtree();
+ resetSplitAttribute();
+ return;
+ }
+
+ Attribute newSplitAttribute = instance.attribute(bestSuggestion.splitTest.getAttsTestDependsOn()[0]);
+ boolean success = false;
+ performedTreeRevision = true;
+ if (maxBranchLength > 1) {
+ success = performReordering(bestSuggestion, newSplitAttribute);
+ if (success)
+ setSplitAttribute(bestSuggestion, newSplitAttribute);
+ }
+ if (!success) {
+ makeSplit(newSplitAttribute, bestSuggestion);
+ }
+ nodeTime = 0;
+ seenWeight = 0.0;
+ }
+ }
+
+ /**
+ * Perform the restructuring and replace the subtree with the restructured subtree
+ * @param xBest the suggestion for the best split
+ * @param splitAttribute the attribute of the best split
+ * @return true if restructuring was successful
+ */
+ private boolean performReordering(AttributeSplitSuggestion xBest, Attribute splitAttribute) {
+ Double splitValue = null;
+ InstanceConditionalTest test = xBest.splitTest;
+ if (test instanceof NominalAttributeBinaryTest)
+ splitValue = ((NominalAttributeBinaryTest) test).getValue();
+ else if (test instanceof NumericAttributeBinaryTest)
+ splitValue = ((NumericAttributeBinaryTest) test).getValue();
+
+ PlasticNode restructuredNode = restructurer.restructure(this, xBest, splitAttribute, splitValue);
+
+ if (restructuredNode != null)
+ successors = restructuredNode.getSuccessors();
+
+ return restructuredNode != null;
+ }
+
+ protected void updateUsedNominalAttributesInSubtree(Attribute splitAttribute, Integer splitAttributeIndex) {
+ if (isLeaf())
+ return;
+ for (CustomEFDTNode successor : successors.getAllSuccessors()) {
+ PlasticNode s = (PlasticNode) successor;
+ s.usedNominalAttributes = getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex);
+ s.updateUsedNominalAttributesInSubtree(splitAttribute, splitAttributeIndex);
+ }
+ }
+
+ protected Set getMajorityVotesOfLeaves() {
+ Set majorityVotes = new HashSet<>();
+ if (isLeaf()) {
+ if (observedClassDistribution.numValues() == 0)
+ return majorityVotes;
+ majorityVotes.add((double) argmax(observedClassDistribution.getArrayRef()));
+ return majorityVotes;
+ }
+ for (CustomEFDTNode s : getSuccessors().getAllSuccessors()) {
+ majorityVotes.addAll(((PlasticNode) s).getMajorityVotesOfLeaves());
+ }
+ return majorityVotes;
+ }
+
+ protected void setObservedClassDistribution(double[] newDistribution) {
+ observedClassDistribution = new DoubleVector(newDistribution);
+ classDistributionAtTimeOfCreation = new DoubleVector(newDistribution);
+ mcCorrectWeight = 0.0;
+ nbCorrectWeight = 0.0;
+ }
+
+ protected void resetObservedClassDistribution() {
+ observedClassDistribution = new DoubleVector();
+ classDistributionAtTimeOfCreation = new DoubleVector();
+ mcCorrectWeight = 0.0;
+ nbCorrectWeight = 0.0;
+ }
+}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticTreeElement.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticTreeElement.java
new file mode 100644
index 000000000..16eccecef
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticTreeElement.java
@@ -0,0 +1,40 @@
+package moa.classifiers.trees.plastic_util;
+
+class PlasticTreeElement {
+ private PlasticNode node;
+ private SuccessorIdentifier key;
+
+ public PlasticTreeElement(PlasticNode node, SuccessorIdentifier key) {
+ this.node = node;
+ this.key = key;
+ }
+
+ public PlasticTreeElement(PlasticTreeElement other) {
+ this.node = other.node;
+ this.key = other.key;
+ }
+
+ public PlasticNode getNode() {
+ return node;
+ }
+
+ public SuccessorIdentifier getKey() {
+ return key;
+ }
+
+ public String getDescription() {
+ String blueprint = "%s%s -- %s";
+ return String.format(blueprint,
+ node.splitAttribute != null ? node.splitAttribute.toString() : "L",
+ node.isArtificial() ? "*" : "",
+ key != null ? Double.toString(key.getReferencevalue()) : (node.isLeaf() ? "X" : "...")
+ );
+ }
+
+ public PlasticTreeElement copy() {
+ PlasticNode nodeCpy;
+ nodeCpy = new PlasticNode(node);
+ SuccessorIdentifier keyCpy = key != null ? new SuccessorIdentifier(key) : null;
+ return new PlasticTreeElement(nodeCpy, keyCpy);
+ }
+}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/Restructurer.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/Restructurer.java
new file mode 100644
index 000000000..04f88771b
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/Restructurer.java
@@ -0,0 +1,315 @@
+package moa.classifiers.trees.plastic_util;
+
+import com.yahoo.labs.samoa.instances.Attribute;
+import moa.AbstractMOAObject;
+import moa.classifiers.core.AttributeSplitSuggestion;
+import moa.classifiers.core.conditionaltests.NominalAttributeBinaryTest;
+import moa.classifiers.core.conditionaltests.NominalAttributeMultiwayTest;
+import moa.core.DoubleVector;
+
+import java.util.*;
+
/**
 * Performs the PLASTIC subtree restructuring: given a node whose best split
 * changed, it rearranges the existing subtree so that the new split attribute
 * is tested at the subtree root, reusing learned branches where possible.
 */
class Restructurer extends AbstractMOAObject {
    // Maximum branch length enumerated while mapping / reassembling the subtree.
    private final int maxBranchLength;
    // If a new numeric threshold deviates from the old one by more than this,
    // the statistics of the affected subtree are flagged for reset.
    private final double acceptedThresholdDeviation;

    public Restructurer(int maxBranchLength,
                        double acceptedNumericThresholdDeviation) {
        this.maxBranchLength = maxBranchLength;
        acceptedThresholdDeviation = acceptedNumericThresholdDeviation;
    }

    /**
     * Restructures the subtree rooted at {@code node} so that it splits on
     * {@code splitAttribute} / {@code splitValue}.
     *
     * @param node           subtree root to restructure
     * @param suggestion     the winning split suggestion for the new split
     * @param splitAttribute attribute of the new split
     * @param splitValue     numeric threshold or nominal reference value (may be null for multiway splits)
     * @return the restructured root, the original node when only a threshold
     *         adjustment (or nothing) was needed, or null if preconditions fail
     */
    public PlasticNode restructure(PlasticNode node, AttributeSplitSuggestion suggestion, Attribute splitAttribute, Double splitValue) {
        boolean isBinary = !(suggestion.splitTest instanceof NominalAttributeMultiwayTest);
        int splitAttributeIndex = suggestion.splitTest.getAttsTestDependsOn()[0];

        boolean checkSucceeds = checkPreconditions(node, splitAttribute, splitValue, isBinary);

        if (!checkSucceeds)
            return null;

        // Same attribute, same binary split value: the tree is already in the requested shape.
        if (splitAttribute == node.splitAttribute && isBinary) {
            assert splitValue != null;
            Double currentNominalBinarysplitValue = node.getSuccessors().getReferenceValue();
            if (currentNominalBinarysplitValue.equals(splitValue))
                return node;
        }

        // Re-split on the same numeric attribute: just move the threshold, no restructuring.
        if (node.splitAttribute.isNumeric() && splitAttribute.isNumeric()) {
            assert splitValue != null;
            Double currentSplitValue = node.getSuccessors().getReferenceValue();
            if (node.splitAttribute == splitAttribute) {
                if (!currentSplitValue.equals(splitValue))
                    updateThreshold(node, splitAttributeIndex, splitValue);
                return node;
            }
        }

//        node.collectChildrenSplitAttributes();
        // Enumerate all branches with the new split pulled to the front, then rebuild the tree.
        MappedTree mappedTree = new MappedTree(node, splitAttribute, splitAttributeIndex, splitValue, maxBranchLength);
        PlasticNode newRoot = reassembleTree(mappedTree);

        newRoot.setSplitAttribute(suggestion, splitAttribute);
        newRoot.updateUsedNominalAttributesInSubtree(splitAttribute, splitAttributeIndex);

        // Reset counters in restructured nodes
        newRoot.getSuccessors().getAllSuccessors().forEach(s -> cleanupSubtree((PlasticNode) s));

        // Initialize the statistics of the root's direct successors
        List sortedKeys = new LinkedList<>(newRoot.getSuccessors().getKeyset());
        Collections.sort(sortedKeys);
        for (SuccessorIdentifier key: sortedKeys) {
            PlasticNode successor = (PlasticNode) newRoot.getSuccessors().getSuccessorNode(key);
            if (splitAttribute.isNominal()) {
                double selectorValue = key.getSelectorValue();
                if (selectorValue == SuccessorIdentifier.DEFAULT_NOMINAL_VALUE) {
                    // The default branch of a nominal binary split gets the "not equal" distribution.
                    assert isBinary;
                    successor.setObservedClassDistribution(suggestion.resultingClassDistributions[1]);
                }
                else if (selectorValue < suggestion.numSplits()) {
                    successor.setObservedClassDistribution(suggestion.resultingClassDistributionFromSplit((int) selectorValue));
                }
                else {
                    // No statistics available in the suggestion for this branch.
                    successor.resetObservedClassDistribution();
                }
            }
            else {
                if (key.isLower()) {
                    successor.setObservedClassDistribution(suggestion.resultingClassDistributions[0]);
                }
                else {
                    successor.setObservedClassDistribution(suggestion.resultingClassDistributions[1]);
                }
            }
        }

        // NOTE(review): finalProcessing is applied to the *original* node, not newRoot —
        // presumably the two share the reused subtree after reassembly; confirm intent.
        finalProcessing(node);
        return newRoot;
    }

    /**
     * Sanity checks before restructuring. Returns false only for leaves; the
     * remaining checks merely warn about situations that should be impossible.
     */
    private boolean checkPreconditions(PlasticNode node, Attribute splitAttribute, Double splitValue, boolean isBinary) {
        if (node.isLeaf())
            return false;
        if (splitAttribute.isNominal()) {
            if (node.getSplitTest() instanceof NominalAttributeBinaryTest && isBinary) {
                // NOTE(review): `==` on a boxed Double compares references (or unboxes
                // only one side); .equals was probably intended — confirm. Only affects
                // this diagnostic, not control flow.
                if (
                        ((NominalAttributeBinaryTest) node.getSplitTest()).getValue() == splitValue
                                && splitAttribute == node.splitAttribute
                ) {
                    System.err.println("This should never be triggered. A binary re-split with the same attribute and split value should never happen");
                }
            }
            else if (node.getSplitTest() instanceof NominalAttributeMultiwayTest) {
                if (splitAttribute == node.splitAttribute && !isBinary)
                    System.err.println("This should never be triggered. A multiway re-split on the same nominal attribute should never happen");
            }
        }
        return true;
    }

    /**
     * Moves the numeric threshold of {@code node} to {@code splitValue},
     * re-keying the two children, pruning branches that become unreachable,
     * and flagging the subtree for a statistics reset if the threshold moved
     * further than the accepted deviation.
     */
    private void updateThreshold(PlasticNode node, int splitAttributeIndex, double splitValue) {
        Double oldThreshold = node.getSuccessors().getReferenceValue();

        SuccessorIdentifier leftKey = new SuccessorIdentifier(true, oldThreshold, oldThreshold, true);
        SuccessorIdentifier rightKey = new SuccessorIdentifier(true, oldThreshold, oldThreshold, false);
        PlasticNode succ1 = (PlasticNode) node.getSuccessors().getSuccessorNode(leftKey);
        PlasticNode succ2 = (PlasticNode) node.getSuccessors().getSuccessorNode(rightKey);
        Successors newSuccessors = new Successors(true, true, splitValue);
        if (succ1 != null)
            newSuccessors.addSuccessorNumeric(splitValue, succ1, true);
        if (succ2 != null)
            newSuccessors.addSuccessorNumeric(splitValue, succ2, false);
        node.successors = newSuccessors;

        // If both old children were missing the node is now effectively a leaf.
        if (node.isLeaf())
            return;

        for (SuccessorIdentifier key: node.getSuccessors().getKeyset()) {
            PlasticNode s = (PlasticNode) node.getSuccessors().getSuccessorNode(key);
            removeUnreachableSubtree(s, splitAttributeIndex, splitValue, key.isLower());
        }

        if (Math.abs(splitValue - oldThreshold) > acceptedThresholdDeviation) {
            setRestructuredFlagInSubtree(node);
        }
    }

    /**
     * Eager variant of {@link #reassembleTree(MappedTree)} operating on an
     * already-materialized branch list.
     * NOTE(review): appears unused within this file — the iterator-based
     * overload below is the one called from restructure(); confirm before removal.
     */
    private PlasticNode reassembleTree(LinkedList mappedTree) {
        if (mappedTree.size() == 0) {
            System.out.println("MappedTree is empty");
        }

        PlasticNode root = mappedTree.getFirst().getBranchRef().getFirst().getNode();
        for (PlasticBranch branch: mappedTree) {
            PlasticNode currentNode = root;

            int depth = 0;
            for (PlasticTreeElement thisElement: branch.getBranchRef()) {
                // The last element of a branch is its leaf; nothing to link below it.
                if (depth == branch.getBranchRef().size() - 1)
                    break;

                PlasticNode thisNode = thisElement.getNode();
                SuccessorIdentifier thisKey = thisElement.getKey();
                if (currentNode.getSplitAttribute() == thisNode.getSplitAttribute()) {
                    if (currentNode.getSuccessors().contains(thisKey)) {
                        // Edge already exists: descend.
                        currentNode = (PlasticNode) currentNode.getSuccessors().getSuccessorNode(thisKey);
                    }
                    else {
                        // Graft the next node of this branch as a new successor.
                        PlasticNode newSuccessor = branch.getBranchRef().get(depth + 1).getNode();
                        boolean success = currentNode.getSuccessors().addSuccessor(newSuccessor, thisKey);
                        assert success;
                        currentNode = newSuccessor;
                    }
                }
                depth++;
            }
        }
        return root;
    }

    /**
     * Rebuilds a tree from the lazily-enumerated branches of {@code mappedTree}:
     * the first node of the first branch becomes the root, and each branch is
     * merged into the tree edge by edge.
     */
    private PlasticNode reassembleTree(MappedTree mappedTree) {
        if (!mappedTree.hasNext()) {
            System.out.println("MappedTree is empty");
        }

        PlasticNode root = null;
        while (mappedTree.hasNext()) {
            PlasticBranch branch = mappedTree.next();
            if (root == null)
                root = branch.getBranchRef().getFirst().getNode();

            PlasticNode currentNode = root;

            int depth = 0;
            for (PlasticTreeElement thisElement: branch.getBranchRef()) {
                // The last element of a branch is its leaf; nothing to link below it.
                if (depth == branch.getBranchRef().size() - 1)
                    break;

                PlasticNode thisNode = thisElement.getNode();
                SuccessorIdentifier thisKey = thisElement.getKey();
                if (currentNode.getSplitAttribute() == thisNode.getSplitAttribute()) {
                    if (currentNode.getSuccessors().contains(thisKey)) {
                        // Edge already exists: descend.
                        currentNode = (PlasticNode) currentNode.getSuccessors().getSuccessorNode(thisKey);
                    }
                    else {
                        // Graft the next node of this branch as a new successor.
                        PlasticNode newSuccessor = branch.getBranchRef().get(depth + 1).getNode();
                        boolean success = currentNode.getSuccessors().addSuccessor(newSuccessor, thisKey);
                        assert success;
                        currentNode = newSuccessor;
                    }
                }
                depth++;
            }
        }
        return root;
    }

    /**
     * Resets statistics, observers and split-attempt counters in every node of
     * the subtree that carries the "restructured" flag.
     */
    private void cleanupSubtree(PlasticNode node) {
        if (!node.getRestructuredFlag())
            return;
        if (!node.isLeaf()) {
            node.resetObservedClassDistribution();
        }
        node.resetObservers();
        node.seenWeight = 0.0;
        node.nodeTime = 0;
        node.numSplitAttempts = 0;
        if (!node.isLeaf())
            node.successors.getAllSuccessors().forEach(s -> cleanupSubtree((PlasticNode) s));
    }

    /**
     * Post-restructuring pass: drops dummy successors, collapses subtrees that
     * became pure or exceeded the maximum depth back into leaves, and clears
     * artificial-node markers.
     */
    private void finalProcessing(PlasticNode node) {
        node.setIsArtificial(false);
        if (node.isLeaf()) {
            return;
        }

        Set keys = new HashSet<>(node.getSuccessors().getKeyset());
        boolean allSuccessorsArePure = true;
        for (SuccessorIdentifier key : keys) {
            PlasticNode thisNode = (PlasticNode) node.getSuccessors().getSuccessorNode(key);
            if (thisNode.isDummy()) {
                node.getSuccessors().removeSuccessor(key);
            }
            if (!thisNode.isPure())
                allSuccessorsArePure = false;
        }

        // Removing dummies may have turned this node into a leaf.
        if (node.isLeaf() || node.depth >= node.maxDepth) {
            node.setObservedClassDistribution(collectStatsFromSuccessors(node).getArrayCopy());
            node.killSubtree();
            node.resetSplitAttribute();
            return;
        }

        // All children pure and agreeing on one class: the split is useless, collapse it.
        if ((allSuccessorsArePure && node.getMajorityVotesOfLeaves().size() <= 1)) {
            node.setObservedClassDistribution(collectStatsFromSuccessors(node).getArrayCopy());
            node.killSubtree();
            node.resetSplitAttribute();
            return;
        }

        for (SuccessorIdentifier key : node.getSuccessors().getKeyset()) {
            PlasticNode successor = (PlasticNode) node.getSuccessors().getSuccessorNode(key);
            finalProcessing(successor);
        }
    }

    /**
     * Sums the class distributions of a node's direct successors (or returns a
     * leaf's own distribution). Intentionally one level deep — see the disabled
     * recursive call.
     */
    private DoubleVector collectStatsFromSuccessors(CustomEFDTNode node) {
        if (node.isLeaf()) {
            return node.observedClassDistribution;
        }
        else {
            DoubleVector stats = new DoubleVector();
            for (CustomEFDTNode successor : node.getSuccessors().getAllSuccessors()) {
                DoubleVector fromSuccessor = successor.observedClassDistribution; //collectStatsFromSuccessors(successor);
                stats.addValues(fromSuccessor);
            }
            return stats;
        }
    }

    /**
     * After a threshold move, prunes branches below {@code node} that can no
     * longer be reached: on the lower side, upper branches at or beyond the new
     * threshold (and vice versa).
     */
    private void removeUnreachableSubtree(PlasticNode node, int splitAttributeIndex, double threshold, boolean isLower) {
        if (node.isLeaf())
            return;

        if (node.getSplitAttributeIndex() != splitAttributeIndex) {
            // Different attribute here: keep descending, the target splits may be deeper.
            for (CustomEFDTNode successor: node.getSuccessors().getAllSuccessors()) {
                removeUnreachableSubtree((PlasticNode) successor, splitAttributeIndex, threshold, isLower);
            }
            return;
        }

        Set keysToRemove = new HashSet<>();
        for (SuccessorIdentifier key: node.getSuccessors().getKeyset()) {
            assert key.isNumeric();
            if (isLower) {
                if (!key.isLower() && key.getSelectorValue() >= threshold) {
                    keysToRemove.add(key);
                }
            }
            else {
                if (key.isLower() && key.getSelectorValue() <= threshold) {
                    keysToRemove.add(key);
                }
            }
        }
        for (SuccessorIdentifier key: keysToRemove) {
            node.getSuccessors().removeSuccessor(key);
        }

        if (!node.isLeaf()) {
            node.getSuccessors().getAllSuccessors().forEach(s -> removeUnreachableSubtree((PlasticNode) s, splitAttributeIndex, threshold, isLower));
        }
    }

    /** Marks every inner node of the subtree as restructured (leaves excluded). */
    private void setRestructuredFlagInSubtree(PlasticNode node) {
        if (node.isLeaf())
            return;
        node.setRestructuredFlag();
        node.getSuccessors().getAllSuccessors().forEach(s -> setRestructuredFlagInSubtree((PlasticNode) s));
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {}
}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/SuccessorIdentifier.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/SuccessorIdentifier.java
new file mode 100644
index 000000000..b4f847665
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/SuccessorIdentifier.java
@@ -0,0 +1,149 @@
+package moa.classifiers.trees.plastic_util;
+
+import moa.AbstractMOAObject;
+
+import java.util.Objects;
+
+class SuccessorIdentifier extends AbstractMOAObject implements Comparable {
+ public static final double DEFAULT_NOMINAL_VALUE = -1.0;
+ private final boolean isNumeric;
+ private final Double selectorValue;
+ private final Double referenceValue;
+ private final boolean isLower;
+ private int hashCode;
+
+ public SuccessorIdentifier(SuccessorIdentifier other) {
+ this.isLower = other.isLower;
+ this.isNumeric = other.isNumeric;
+ this.selectorValue = other.selectorValue;
+ this.referenceValue = other.referenceValue;
+ hashCode = toString().hashCode();
+ }
+
+ public SuccessorIdentifier(boolean isNumeric, Double referenceValue, Double selectorValue, boolean isLower) {
+ this.isNumeric = isNumeric;
+ this.isLower = isLower;
+ this.selectorValue = selectorValue;
+
+ if (referenceValue == null)
+ this.referenceValue = selectorValue;
+ else
+ this.referenceValue = referenceValue;
+
+ hashCode = toString().hashCode();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (o == null)
+ return false;
+ if (getClass() == o.getClass()) {
+ SuccessorIdentifier that = (SuccessorIdentifier) o;
+ boolean equal = isNumeric() == that.isNumeric();
+ equal &= selectorValue.equals(that.getSelectorValue());
+ equal &= referenceValue.equals(that.getReferencevalue());
+ if (isNumeric) {
+ equal &= isLower() == that.isLower();
+ }
+ return equal;
+ }
+ Double that = (Double) o;
+ if (isNumeric)
+ return containsNumericAttribute(that);
+ else {
+ boolean result = Objects.equals(selectorValue, that);
+ if (!result) {
+ // use the default successor if the reference value is not `that` and the selector value is the selectorvalue of the default successor
+ result = Objects.equals(selectorValue, DEFAULT_NOMINAL_VALUE) && !Objects.equals(referenceValue, that);
+ }
+ return result;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return hashCode;
+ }
+
+ @Override
+ public int compareTo(SuccessorIdentifier other) {
+ if (isNumeric != other.isNumeric)
+ return 0;
+ if (isNumeric) {
+ if (!Objects.equals(referenceValue, other.referenceValue))
+ return 0;
+ if (isLower == other.isLower)
+ return 0;
+ return isLower ? -1 : 1;
+ }
+ else {
+ if (selectorValue == null) // this can only happen in the case of a dummy split (which will be pruned after reordering)
+ return 0;
+ if (selectorValue == DEFAULT_NOMINAL_VALUE || other.selectorValue == DEFAULT_NOMINAL_VALUE)
+ return referenceValue == DEFAULT_NOMINAL_VALUE ? 1 : -1;
+ if (selectorValue.equals(other.selectorValue))
+ return 0;
+ if (selectorValue < other.selectorValue)
+ return -1;
+ return 1;
+ }
+ }
+
+ public Double getSelectorValue() {
+ return selectorValue;
+ }
+
+ public Double getReferencevalue() {
+ return referenceValue;
+ }
+
+ private boolean matchesCategoricalValue(double attValue) {
+ if (isNumeric)
+ return false;
+ return selectorValue == attValue;
+ }
+
+ private boolean containsNumericAttribute(double attValue) {
+ if (!isNumeric)
+ return false;
+ return isLower ? attValue <= selectorValue : attValue > selectorValue;
+ }
+
+ public SuccessorIdentifier getOther() {
+ if (!isNumeric)
+ return null;
+ return new SuccessorIdentifier(isNumeric, selectorValue, selectorValue, !isLower);
+ }
+
+ public boolean isNumeric() {
+ return isNumeric;
+ }
+
+ public boolean isLower() {
+ if (!isNumeric)
+ return false;
+ return isLower;
+ }
+
+ public Double getValue() {
+ return selectorValue;
+ }
+
+ public String toString() {
+ if (isNumeric) {
+ String s = "%b%f%f%b";
+ return String.format(s, true, referenceValue, selectorValue, isLower);
+ }
+ else {
+ String s = "%b%f%f";
+ return String.format(s, false, referenceValue, selectorValue);
+ }
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+
+ }
+}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/Successors.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/Successors.java
new file mode 100644
index 000000000..1f378cda9
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/Successors.java
@@ -0,0 +1,221 @@
+package moa.classifiers.trees.plastic_util;
+
+import moa.AbstractMOAObject;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Set;
+
+class Successors extends AbstractMOAObject {
+ private Double referenceValue;
+ private HashMap successors = new HashMap<>();
+
+ public Successors(Successors other, boolean transferNodes) {
+ isBinarySplit = other.isBinary();
+ isNumericSplit = !other.isNominal();
+ referenceValue = other.getReferenceValue();
+ if (transferNodes) {
+ successors = new HashMap<>(other.successors);
+ }
+ }
+
+ public Successors(boolean isBinarySplit, boolean isNumericSplit, Double splitValue) {
+ this.isBinarySplit = isBinarySplit;
+ this.isNumericSplit = isNumericSplit;
+ this.referenceValue = splitValue;
+ }
+
+ private final boolean isBinarySplit;
+ private final boolean isNumericSplit;
+
+
+ protected boolean addSuccessor(CustomEFDTNode node, SuccessorIdentifier key) {
+ if (node == null)
+ return false;
+ if (isNumericSplit != key.isNumeric())
+ return false;
+ if (successors.size() >= 2 && isBinary()) {
+ return false;
+ }
+ if (successors.containsKey(key)) {
+ return false;
+ }
+ successors.put(key, node);
+ return true;
+ }
+
+
+ public boolean addSuccessorNumeric(Double attValue, CustomEFDTNode n, boolean isLower) {
+ if (n == null)
+ return false;
+ if (!isNumericSplit)
+ return false;
+ if (successors.size() >= 2)
+ return false;
+ if (referenceValue != null && !referenceValue.equals(attValue))
+ return false;
+
+ SuccessorIdentifier id = new SuccessorIdentifier(true, attValue, attValue, isLower);
+ if (successors.containsKey(id))
+ return false;
+
+ referenceValue = attValue;
+ successors.put(id, n);
+ return true;
+ }
+
+
+ public boolean addSuccessorNominalBinary(Double attValue, CustomEFDTNode n) {
+ if (n == null)
+ return false;
+ if (isNumericSplit)
+ return false;
+ if (!isBinarySplit)
+ return false;
+ if (successors.size() >= 2)
+ return false;
+ if (successors.size() == 1) {
+ SuccessorIdentifier key = (SuccessorIdentifier) successors.keySet().toArray()[0]; // get key of existing successor
+ if (key.getValue() == SuccessorIdentifier.DEFAULT_NOMINAL_VALUE) { // check if the key is the default key for nominal values
+ if (!referenceValue.equals(attValue)) // if the key is the default key, only add the successor if the referenceValue of the split matches the provided value
+ return false;
+ }
+ }
+
+ SuccessorIdentifier id = new SuccessorIdentifier(false, attValue, attValue, false);
+ if (successors.containsKey(id))
+ return false;
+
+ referenceValue = attValue;
+ successors.put(id, n);
+ return true;
+ }
+
+ public boolean addDefaultSuccessorNominalBinary(CustomEFDTNode n) {
+ if (n == null)
+ return false;
+ if (isNumericSplit)
+ return false;
+ if (!isBinarySplit)
+ return false;
+ if (successors.size() >= 2)
+ return false;
+
+ SuccessorIdentifier id = new SuccessorIdentifier(false, referenceValue, SuccessorIdentifier.DEFAULT_NOMINAL_VALUE, false);
+ if (successors.containsKey(id))
+ return false;
+
+ successors.put(id, n);
+ return true;
+ }
+
+ public boolean addSuccessorNominalMultiway(Double attValue, CustomEFDTNode n) {
+ if (n == null)
+ return false;
+ if (isNumericSplit)
+ return false;
+ if (isBinarySplit)
+ return false;
+
+ SuccessorIdentifier id = new SuccessorIdentifier(false, attValue, attValue, false);
+ if (successors.containsKey(id))
+ return false;
+
+ successors.put(id, n);
+ return true;
+ }
+
+ public CustomEFDTNode getSuccessorNode(SuccessorIdentifier key) {
+ return successors.get(key);
+ }
+
+ public CustomEFDTNode getSuccessorNode(Double attributeValue) {
+ for (SuccessorIdentifier s : successors.keySet()) {
+ if (s.equals(attributeValue))
+ return successors.get(s);
+ }
+ return null;
+ }
+
+ public SuccessorIdentifier getSuccessorKey(Object key) {
+ //TODO: Looping over a set is probably not the best way to do this.
+ for (SuccessorIdentifier successorKey : successors.keySet()) {
+ if (successorKey.equals(key)) {
+ return successorKey;
+ }
+ }
+ return null;
+ }
+
+ public boolean isNominal() {
+ return !isNumericSplit;
+ }
+
+ public boolean isBinary() {
+ return isBinarySplit;
+ }
+
+ public boolean contains(Object key) {
+ return successors.containsKey((SuccessorIdentifier) key);
+ }
+
+ public Double getReferenceValue() {
+ return referenceValue;
+ }
+
+ public int size() {
+ return successors.size();
+ }
+
+ public SuccessorIdentifier getMissingKey() {
+ if (successors.size() > 1)
+ return null;
+ SuccessorIdentifier someKey = (SuccessorIdentifier) successors.keySet().toArray()[0];
+ return someKey.getOther();
+ }
+
+ public boolean lowerIsMissing() {
+ SuccessorIdentifier key = getMissingKey();
+ if (key == null)
+ return false;
+ return key.isLower();
+ }
+
+ public boolean upperIsMissing() {
+ SuccessorIdentifier key = getMissingKey();
+ if (key == null)
+ return false;
+ return !key.isLower();
+ }
+
+ public void adjustThreshold(double newThreshold) {
+ HashMap newSuccessors = new HashMap<>();
+ for (SuccessorIdentifier oldId: successors.keySet()) {
+ SuccessorIdentifier newId = new SuccessorIdentifier(true, newThreshold, newThreshold, oldId.isLower());
+ newSuccessors.put(newId, successors.get(oldId));
+ }
+ referenceValue = newThreshold;
+ successors = newSuccessors;
+ }
+
+ public Collection getAllSuccessors() {
+ return successors.values();
+ }
+
+ public Set getKeyset() {
+ return successors.keySet();
+ }
+
+ protected void forceSuccessorForKey(SuccessorIdentifier key, CustomEFDTNode node) {
+ successors.put(key, node);
+ }
+
+ protected CustomEFDTNode removeSuccessor(SuccessorIdentifier key) {
+ return successors.remove(key);
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+
+ }
+}
diff --git a/moa/src/main/java/moa/clusterers/clustream/Clustream.java b/moa/src/main/java/moa/clusterers/clustream/Clustream.java
index 58e0428bd..93e221c7a 100644
--- a/moa/src/main/java/moa/clusterers/clustream/Clustream.java
+++ b/moa/src/main/java/moa/clusterers/clustream/Clustream.java
@@ -14,8 +14,8 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
- *
- *
+ *
+ *
*/
package moa.clusterers.clustream;
@@ -98,7 +98,9 @@ public void trainOnInstanceImpl(Instance instance) {
// Clustering kmeans_clustering = kMeans(k, buffer);
for ( int i = 0; i < kmeans_clustering.size(); i++ ) {
- kernels[i] = new ClustreamKernel( new DenseInstance(1.0,centers[i].getCenter()), dim, timestamp, t, m );
+ Instance newInstance = new DenseInstance(1.0,centers[i].getCenter());
+ newInstance.setDataset(instance.dataset());
+ kernels[i] = new ClustreamKernel(newInstance, dim, timestamp, t, m );
}
buffer.clear();
@@ -111,7 +113,7 @@ public void trainOnInstanceImpl(Instance instance) {
double minDistance = Double.MAX_VALUE;
for ( int i = 0; i < kernels.length; i++ ) {
//System.out.println(i+" "+kernels[i].getWeight()+" "+kernels[i].getDeviation());
- double distance = distance(instance.toDoubleArray(), kernels[i].getCenter() );
+ double distance = distanceIgnoreNaN(instance.toDoubleArray(), kernels[i].getCenter() );
if ( distance < minDistance ) {
closestKernel = kernels[i];
minDistance = distance;
@@ -213,6 +215,26 @@ private static double distance(double[] pointA, double [] pointB){
return Math.sqrt(distance);
}
+ /***
+ * This function avoids the undesirable situation where the whole distance becomes NaN if one of the attributes
+ * is NaN.
+ * (SSL) This was observed when calculating the distance between an instance without the class label and a center
+ * which was updated using the class label.
+ * @param pointA
+ * @param pointB
+ * @return
+ */
+ public static double distanceIgnoreNaN(double[] pointA, double [] pointB){
+ double distance = 0.0;
+ for (int i = 0; i < pointA.length; i++) {
+ if(!(Double.isNaN(pointA[i]) || Double.isNaN(pointB[i]))) {
+ double d = pointA[i] - pointB[i];
+ distance += d * d;
+ }
+ }
+ return Math.sqrt(distance);
+ }
+
//wrapper... we need to rewrite kmeans to points, not clusters, doesnt make sense anymore
// public static Clustering kMeans( int k, ArrayList points, int dim ) {
// ArrayList cl = new ArrayList();
diff --git a/moa/src/main/java/moa/clusterers/clustream/ClustreamKernel.java b/moa/src/main/java/moa/clusterers/clustream/ClustreamKernel.java
index d4f901ba4..609ad8fdb 100644
--- a/moa/src/main/java/moa/clusterers/clustream/ClustreamKernel.java
+++ b/moa/src/main/java/moa/clusterers/clustream/ClustreamKernel.java
@@ -14,8 +14,8 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
- *
- *
+ *
+ *
*/
package moa.clusterers.clustream;
@@ -25,9 +25,9 @@
import com.yahoo.labs.samoa.instances.Instance;
public class ClustreamKernel extends CFCluster {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- private final static double EPSILON = 0.00005;
+ private final static double EPSILON = 0.00005;
public static final double MIN_VARIANCE = 1e-50;
protected double LST;
@@ -36,66 +36,111 @@ public class ClustreamKernel extends CFCluster {
int m;
double t;
+ public double[] classObserver;
+
+ public static int ID_GENERATOR = 0;
- public ClustreamKernel( Instance instance, int dimensions, long timestamp , double t, int m) {
+ public ClustreamKernel(Instance instance, int dimensions, long timestamp , double t, int m) {
super(instance, dimensions);
+
+// Avoid situations where the instance header hasn't been defined and runtime errors.
+ if(instance.dataset() != null) {
+ this.classObserver = new double[instance.numClasses()];
+// instance.numAttributes() <= instance.classIndex() -> edge case where the class index is equal the
+// number of attributes (i.e. there is no class value in the attributes array).
+ if (instance.numAttributes() > instance.classIndex() &&
+ !instance.classIsMissing() &&
+ instance.classValue() >= 0 &&
+ instance.classValue() < instance.numClasses()) {
+ this.classObserver[(int) instance.classValue()]++;
+ }
+ }
+ this.setId(ID_GENERATOR++);
this.t = t;
this.m = m;
this.LST = timestamp;
- this.SST = timestamp*timestamp;
+ this.SST = timestamp*timestamp;
}
public ClustreamKernel( ClustreamKernel cluster, double t, int m ) {
super(cluster);
+ this.setId(ID_GENERATOR++);
this.t = t;
this.m = m;
this.LST = cluster.LST;
this.SST = cluster.SST;
+ this.classObserver = cluster.classObserver;
}
public void insert( Instance instance, long timestamp ) {
- N++;
- LST += timestamp;
- SST += timestamp*timestamp;
-
- for ( int i = 0; i < instance.numValues(); i++ ) {
- LS[i] += instance.value(i);
- SS[i] += instance.value(i)*instance.value(i);
- }
+ if(this.classObserver == null)
+ this.classObserver = new double[instance.numClasses()];
+ if(!instance.classIsMissing() &&
+ instance.classValue() >= 0 &&
+ instance.classValue() < instance.numClasses()) {
+ this.classObserver[(int)instance.classValue()]++;
+ }
+ N++;
+ LST += timestamp;
+ SST += timestamp*timestamp;
+
+ for ( int i = 0; i < instance.numValues(); i++ ) {
+ LS[i] += instance.value(i);
+ SS[i] += instance.value(i)*instance.value(i);
+ }
}
@Override
public void add( CFCluster other2 ) {
ClustreamKernel other = (ClustreamKernel) other2;
- assert( other.LS.length == this.LS.length );
- this.N += other.N;
- this.LST += other.LST;
- this.SST += other.SST;
-
- for ( int i = 0; i < LS.length; i++ ) {
- this.LS[i] += other.LS[i];
- this.SS[i] += other.SS[i];
- }
+ assert( other.LS.length == this.LS.length );
+ this.N += other.N;
+ this.LST += other.LST;
+ this.SST += other.SST;
+ this.classObserver = sumClassObservers(other.classObserver, this.classObserver);
+
+ for ( int i = 0; i < LS.length; i++ ) {
+ this.LS[i] += other.LS[i];
+ this.SS[i] += other.SS[i];
+ }
}
+ private double[] sumClassObservers(double[] A, double[] B) {
+ double[] result = null;
+ if (A != null && B != null) {
+ result = new double[A.length];
+ if(A.length == B.length)
+ for(int i = 0 ; i < A.length ; ++i)
+ result[i] += A[i] + B[i];
+ }
+ return result;
+ }
+
+// @Override
+// public void add( CFCluster other2, long timestamp) {
+// this.add(other2);
+// // accumulate the count
+// this.accumulateWeight(other2, timestamp);
+// }
+
public double getRelevanceStamp() {
- if ( N < 2*m )
- return getMuTime();
-
- return getMuTime() + getSigmaTime() * getQuantile( ((double)m)/(2*N) );
+ if ( N < 2*m )
+ return getMuTime();
+
+ return getMuTime() + getSigmaTime() * getQuantile( ((double)m)/(2*N) );
}
private double getMuTime() {
- return LST / N;
+ return LST / N;
}
private double getSigmaTime() {
- return Math.sqrt(SST/N - (LST/N)*(LST/N));
+ return Math.sqrt(SST/N - (LST/N)*(LST/N));
}
private double getQuantile( double z ) {
- assert( z >= 0 && z <= 1 );
- return Math.sqrt( 2 ) * inverseError( 2*z - 1 );
+ assert( z >= 0 && z <= 1 );
+ return Math.sqrt( 2 ) * inverseError( 2*z - 1 );
}
@Override
@@ -187,7 +232,7 @@ private double[] getVarianceVector() {
}
}
else{
-
+
}
}
return res;
@@ -223,7 +268,7 @@ private double calcNormalizedDistance(double[] point) {
return Math.sqrt(res);
}
- /**
+ /**
* Approximates the inverse error function. Clustream needs this.
* @param x
*/
@@ -266,7 +311,7 @@ protected void getClusterSpecificInfo(ArrayList infoTitle, ArrayList windowedResults;
public double[] cumulativeResults;
- public ArrayList targets;
- public ArrayList predictions;
-
+ public ArrayList targets;
+ public ArrayList predictions;
public HashMap otherMeasurements;
- public PrequentialResult(ArrayList windowedResults, double[] cumulativeResults) {
- this.windowedResults = windowedResults;
- this.cumulativeResults = cumulativeResults;
- this.targets = null;
- this.predictions = null;
- }
-
- public PrequentialResult(ArrayList windowedResults, double[] cumulativeResults,
- ArrayList targets, ArrayList predictions) {
+ public PrequentialResult(
+ ArrayList windowedResults,
+ double[] cumulativeResults,
+ ArrayList targets,
+ ArrayList predictions,
+ HashMap otherMeasurements
+ ) {
this.windowedResults = windowedResults;
this.cumulativeResults = cumulativeResults;
this.targets = targets;
this.predictions = predictions;
+ this.otherMeasurements = otherMeasurements;
}
- /***
- * This constructor is useful to store metrics beyond the evaluation metrics available through the evaluators.
- * @param windowedResults
- * @param cumulativeResults
- * @param otherMeasurements
- */
- public PrequentialResult(ArrayList windowedResults, double[] cumulativeResults,
- HashMap otherMeasurements) {
- this(windowedResults, cumulativeResults);
- this.otherMeasurements = otherMeasurements;
+ public PrequentialResult(
+ ArrayList windowedResults,
+ double[] cumulativeResults,
+ ArrayList targets,
+ ArrayList predictions
+ ) {
+ this(windowedResults, cumulativeResults, targets, predictions, null);
+ }
+
+ public PrequentialResult(
+ ArrayList windowedResults,
+ double[] cumulativeResults,
+ HashMap otherMeasurements
+ ) {
+ this(windowedResults, cumulativeResults, null, null, otherMeasurements);
}
}
@@ -93,11 +74,13 @@ public PrequentialResult(ArrayList windowedResults, double[] cumulativ
* @param windowedEvaluator
* @param maxInstances
* @param windowSize
- * @return PrequentialResult is a custom class that holds the respective results from the execution
+ * @param storeY
+ * @param storePredictions
+ * @return the return has to be an ArrayList because we don't know ahead of time how many windows will be produced
*/
public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Learner learner,
- LearningPerformanceEvaluator basicEvaluator,
- LearningPerformanceEvaluator windowedEvaluator,
+ LearningPerformanceEvaluator<Example<Instance>> basicEvaluator,
+ LearningPerformanceEvaluator<Example<Instance>> windowedEvaluator,
long maxInstances, long windowSize,
boolean storeY, boolean storePredictions) {
int instancesProcessed = 0;
@@ -106,27 +89,184 @@ public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Lear
stream.restart();
ArrayList windowed_results = new ArrayList<>();
- ArrayList targetValues = new ArrayList<>();
- ArrayList predictions = new ArrayList<>();
+ ArrayList targetValues = new ArrayList<>();
+ ArrayList predictions = new ArrayList<>();
while (stream.hasMoreInstances() &&
(maxInstances == -1 || instancesProcessed < maxInstances)) {
Example instance = stream.nextInstance();
- if (storeY)
- targetValues.add(instance.getData().classValue());
double[] prediction = learner.getVotesForInstance(instance);
+
+ // Update evaluators and store predictions if requested
if (basicEvaluator != null)
basicEvaluator.addResult(instance, prediction);
if (windowedEvaluator != null)
windowedEvaluator.addResult(instance, prediction);
-
if (storePredictions)
- predictions.add(prediction.length == 0? 0 : prediction[0]);
+ predictions.add(Utils.maxIndex(prediction));
+ if (storeY)
+ targetValues.add((int)Math.round(instance.getData().classValue()));
learner.trainOnInstance(instance);
+ instancesProcessed++;
+
+ // Store windowed results if requested
+ if (windowedEvaluator != null)
+ if (instancesProcessed % windowSize == 0) {
+ Measurement[] measurements = windowedEvaluator.getPerformanceMeasurements();
+ double[] values = new double[measurements.length];
+ for (int i = 0; i < values.length; ++i)
+ values[i] = measurements[i].getValue();
+ windowed_results.add(values);
+ }
+ }
+ if (windowedEvaluator != null)
+ if (instancesProcessed % windowSize != 0) {
+ Measurement[] measurements = windowedEvaluator.getPerformanceMeasurements();
+ double[] values = new double[measurements.length];
+ for (int i = 0; i < values.length; ++i)
+ values[i] = measurements[i].getValue();
+ windowed_results.add(values);
+ }
+
+ double[] cumulative_results = null;
+
+ if (basicEvaluator != null) {
+ Measurement[] measurements = basicEvaluator.getPerformanceMeasurements();
+ cumulative_results = new double[measurements.length];
+ for (int i = 0; i < cumulative_results.length; ++i)
+ cumulative_results[i] = measurements[i].getValue();
+ }
+
+ return new PrequentialResult(
+ windowed_results,
+ cumulative_results,
+ targetValues,
+ predictions
+ );
+ }
+
+ public static PrequentialResult PrequentialSSLEvaluation(
+ ExampleStream<Example<Instance>> stream,
+ Learner learner,
+ LearningPerformanceEvaluator basicEvaluator,
+ LearningPerformanceEvaluator windowedEvaluator,
+ long maxInstances,
+ long windowSize,
+ long initialWindowSize,
+ long delayLength,
+ double labelProbability,
+ int randomSeed,
+ boolean debugPseudoLabels,
+ boolean storeY,
+ boolean storePredictions
+ ) {
+// int delayLength = this.delayLengthOption.getValue();
+// double labelProbability = this.labelProbabilityOption.getValue();
+
+ RandomGenerator taskRandom = new MersenneTwister(randomSeed);
+// ExampleStream stream = (ExampleStream) getPreparedClassOption(this.streamOption);
+// Learner learner = getLearner(stream);
+
+ int instancesProcessed = 0;
+ int numCorrectPseudoLabeled = 0;
+ int numUnlabeledData = 0;
+ int numInstancesTested = 0;
+
+ if (!stream.hasMoreInstances())
+ stream.restart();
+
+ ArrayList windowed_results = new ArrayList<>();
+
+ ArrayList targetValues = new ArrayList<>();
+ ArrayList predictions = new ArrayList<>();
+ HashMap other_measures = new HashMap<>();
+
+ // The buffer is a list of tuples. The first element is the index when
+ // it should be emitted. The second element is the instance itself.
+ List<Pair<Integer, Example<Instance>>> delayBuffer = new ArrayList<Pair<Integer, Example<Instance>>>();
+
+ while (stream.hasMoreInstances() &&
+ (maxInstances == -1 || instancesProcessed < maxInstances)) {
+
+ // TRAIN on delayed instances
+ while (delayBuffer.size() > 0
+ && delayBuffer.get(0).getKey() == instancesProcessed) {
+ Example delayedExample = delayBuffer.remove(0).getValue();
+// System.out.println("[TRAIN][DELAY] "+delayedExample.getData().toString());
+ learner.trainOnInstance(delayedExample);
+ }
+
+ Example instance = stream.nextInstance();
+ Example unlabeledExample = instance.copy();
+ int trueClass = (int) ((Instance) instance.getData()).classValue();
+
+ // In case it is set, then the label is not removed. We want to pass the
+ // labelled data to the learner even in trainOnUnlabeled data to generate statistics such as number
+ // of correctly pseudo-labeled instances.
+ if (!debugPseudoLabels) {
+ // Remove the label of the unlabeledExample indirectly through
+ // unlabeledInstanceData.
+ Instance __instance = (Instance) unlabeledExample.getData();
+ __instance.setMissing(__instance.classIndex());
+ }
+
+ // WARMUP
+ // Train on the initial instances. These are not used for testing!
+ if (instancesProcessed < initialWindowSize) {
+// if (learner instanceof SemiSupervisedLearner)
+// ((SemiSupervisedLearner) learner).addInitialWarmupTrainingInstances();
+// System.out.println("[TRAIN][INITIAL_WINDOW] "+instance.getData().toString());
+ learner.trainOnInstance(instance);
+ instancesProcessed++;
+ continue;
+ }
+
+ Boolean is_labeled = labelProbability > taskRandom.nextDouble();
+ if (!is_labeled) {
+ numUnlabeledData++;
+ }
+
+ // TEST
+ // Obtain the prediction for the testInst (i.e. no label)
+// System.out.println("[TEST] " + unlabeledExample.getData().toString());
+ double[] prediction = learner.getVotesForInstance(unlabeledExample);
+ numInstancesTested++;
+
+ if (basicEvaluator != null)
+ basicEvaluator.addResult(instance, prediction);
+ if (windowedEvaluator != null)
+ windowedEvaluator.addResult(instance, prediction);
+ if (storeY)
+ targetValues.add((int)Math.round(instance.getData().classValue()));
+ if (storePredictions)
+ predictions.add(Utils.maxIndex(prediction));
+
+ int pseudoLabel = -1;
+ // TRAIN
+ if (is_labeled && delayLength >= 0) {
+ // The instance will be labeled but has been delayed
+ if (learner instanceof SemiSupervisedLearner) {
+// System.out.println("[TRAIN_UNLABELED][DELAYED] " + unlabeledExample.getData().toString());
+ pseudoLabel = ((SemiSupervisedLearner) learner).trainOnUnlabeledInstance((Instance) unlabeledExample.getData());
+ }
+ delayBuffer.add(new MutablePair<>(1 + instancesProcessed + delayLength, instance));
+ } else if (is_labeled) {
+// System.out.println("[TRAIN] " + instance.getData().toString());
+ // The instance will be labeled and is not delayed e.g delayLength = -1
+ learner.trainOnInstance(instance);
+ } else {
+ // The instance will never be labeled
+ if (learner instanceof SemiSupervisedLearner) {
+// System.out.println("[TRAIN_UNLABELED][IMMEDIATE] " + unlabeledExample.getData().toString());
+ pseudoLabel = ((SemiSupervisedLearner) learner).trainOnUnlabeledInstance((Instance) unlabeledExample.getData());
+ }
+ }
+ if(trueClass == pseudoLabel)
+ numCorrectPseudoLabeled++;
instancesProcessed++;
@@ -156,62 +296,153 @@ public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Lear
for (int i = 0; i < cumulative_results.length; ++i)
cumulative_results[i] = measurements[i].getValue();
}
- if (!storePredictions && !storeY)
- return new PrequentialResult(windowed_results, cumulative_results);
- else
- return new PrequentialResult(windowed_results, cumulative_results, targetValues, predictions);
+
+ // TODO: Add this measures in a windowed way.
+ other_measures.put("num_unlabeled_instances", (double) numUnlabeledData);
+ other_measures.put("num_correct_pseudo_labeled", (double) numCorrectPseudoLabeled);
+ other_measures.put("num_instances_tested", (double) numInstancesTested);
+ other_measures.put("pseudo_label_accuracy", (double) numCorrectPseudoLabeled/numInstancesTested);
+
+
+ return new PrequentialResult(
+ windowed_results,
+ cumulative_results,
+ targetValues,
+ predictions,
+ other_measures
+ );
}
+ /******************************************************************************************************************/
+ /******************************************************************************************************************/
+ /***************************************** TESTS ******************************************************************/
+ /******************************************************************************************************************/
+ /******************************************************************************************************************/
+
+ private static void testPrequentialSSL(String file_path, Learner learner,
+ long maxInstances,
+ long windowSize,
+ long initialWindowSize,
+ long delayLength,
+ double labelProbability) {
+ System.out.println(
+ "maxInstances: " + maxInstances + ", " +
+ "windowSize: " + windowSize + ", " +
+ "initialWindowSize: " + initialWindowSize + ", " +
+ "delayLength: " + delayLength + ", " +
+ "labelProbability: " + labelProbability
+ );
- /***
- * The following code can be used to provide examples of how to use the class.
- * In the future, some of these examples can be turned into tests.
- * @param args
- */
- public static void main(String[] args) {
- examplePrequentialEvaluation_edge_cases1();
- examplePrequentialEvaluation_edge_cases2();
- examplePrequentialEvaluation_edge_cases3();
- examplePrequentialEvaluation_edge_cases4();
- examplePrequentialEvaluation_SampleFrequency_TestThenTrain();
- examplePrequentialRegressionEvaluation();
- examplePrequentialEvaluation();
- exampleTestThenTrainEvaluation();
- exampleWindowedEvaluation();
-
- // Run time efficiency evaluation examples
- StreamingRandomPatches srp10 = new StreamingRandomPatches();
- srp10.getOptions().setViaCLIString("-s 10"); // 10 learners
- srp10.setRandomSeed(5);
- srp10.prepareForUse();
-
- StreamingRandomPatches srp100 = new StreamingRandomPatches();
- srp100.getOptions().setViaCLIString("-s 100"); // 100 learners
- srp100.setRandomSeed(5);
- srp100.prepareForUse();
-
- int maxInstances = 100000;
- examplePrequentialEfficiency(srp10, maxInstances);
- examplePrequentialEfficiency(srp100, maxInstances);
+ // Record the start time
+ long startTime = System.currentTimeMillis();
+
+ ArffFileStream stream = new ArffFileStream(file_path, -1);
+ stream.prepareForUse();
+
+ BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator();
+ basic_evaluator.recallPerClassOption.setValue(true);
+ basic_evaluator.prepareForUse();
+
+ WindowClassificationPerformanceEvaluator windowed_evaluator = new WindowClassificationPerformanceEvaluator();
+ windowed_evaluator.widthOption.setValue((int) windowSize);
+ windowed_evaluator.prepareForUse();
+
+ PrequentialResult result = PrequentialSSLEvaluation(stream, learner,
+ basic_evaluator,
+ windowed_evaluator,
+ maxInstances,
+ windowSize,
+ initialWindowSize,
+ delayLength,
+ labelProbability,
+ 1,
+ true,
+ false,
+ false
+ );
+
+ // Record the end time
+ long endTime = System.currentTimeMillis();
+
+ // Calculate the elapsed time in milliseconds
+ long elapsedTime = endTime - startTime;
+
+ // Print the elapsed time
+ System.out.println("Elapsed Time: " + elapsedTime / 1000 + " seconds");
+ System.out.println("Number of unlabeled instances: " + result.otherMeasurements.get("num_unlabeled_instances"));
+
+ System.out.println("\tBasic performance");
+ for (int i = 0; i < result.cumulativeResults.length; ++i)
+ System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " + result.cumulativeResults[i]);
+
+ System.out.println("\tWindowed performance");
+ for (int j = 0; j < result.windowedResults.size(); ++j) {
+ System.out.print("Window: " + j + ", ");
+ for (int i = 0; i < 2; ++i) // results.get(results.size()-1).length; ++i)
+ System.out.println(windowed_evaluator.getPerformanceMeasurements()[i].getName() + ": " + result.windowedResults.get(j)[i]);
+ }
}
+ public static void main(String[] args) {
+ String hyper_arff = "/Users/gomeshe/Desktop/data/Hyper100k.arff";
+ String debug_arff = "/Users/gomeshe/Desktop/data/debug_prequential_SSL.arff";
+ String ELEC_arff = "/Users/gomeshe/Dropbox/ciencia_computacao/lecturer/research/ssl_disagreement/datasets/ELEC/elecNormNew.arff";
+
+ NaiveBayes learner = new NaiveBayes();
+ learner.prepareForUse();
+
+// testPrequentialSSL(debug_arff, learner, 100, 10, 0, 0, 1.0); // OK
+// testPrequentialSSL(debug_arff, learner, 100, 10, 1, 0, 1.0); //OK
+// testPrequentialSSL(debug_arff, learner, 10, 10, 5, 0, 1.0); // OK
+// testPrequentialSSL(debug_arff, learner, 10, 10, -1, 1, 1.0); // OK
+// testPrequentialSSL(debug_arff, learner, 20, 10, -1, 10, 1.0); // OK
+// testPrequentialSSL(debug_arff, learner, 20, 10, -1, 2, 0.5); // OK
+// testPrequentialSSL(debug_arff, learner, 100, 10, 50, 2, 0.0); // OK
+// testPrequentialSSL(debug_arff, learner, 100, 10, 0, 90, 1.0); // OK
+// testPrequentialSSL(debug_arff, learner, 100, 10, 0, -1, 0.5); // OK
+
+// testPrequentialSSL(hyper_arff, learner, -1, 1000, -1, -1, 1.0);
+// testPrequentialSSL(hyper_arff, learner, -1, 1000, -1, -1, 0.5); // OK
+
+// testPrequentialSSL(hyper_arff, learner, -1, 1000, 1000, -1, 0.5);
+
+ ClusterAndLabelClassifier ssl_learner = new ClusterAndLabelClassifier();
+ ssl_learner.prepareForUse();
+
+ testPrequentialSSL(ELEC_arff, ssl_learner, 10000, 1000, -1, -1, 0.01);
+
+// testWindowedEvaluation();
+// testTestThenTrainEvaluation();
+// testPrequentialEvaluation();
+//
+// StreamingRandomPatches learner = new StreamingRandomPatches();
+// learner.getOptions().setViaCLIString("-s 100"); // 10 learners
+//// learner.setRandomSeed(5);
+// learner.prepareForUse();
+// testPrequentialEfficiency1(learner);
+
+// testPrequentialEvaluation_edge_cases1();
+// testPrequentialEvaluation_edge_cases2();
+// testPrequentialEvaluation_edge_cases3();
+// testPrequentialEvaluation_edge_cases4();
+// testPrequentialEvaluation_SampleFrequency_TestThenTrain();
+
+// testPrequentialRegressionEvaluation();
+ }
- private static void examplePrequentialEfficiency(Learner learner, int maxInstances) {
- System.out.println("Assessing efficiency for " + learner.getCLICreationString(learner.getClass()) +
- " maxInstances: " + maxInstances);
+ private static void testPrequentialEfficiency1(Learner learner) {
// Record the start time
long startTime = System.currentTimeMillis();
- AgrawalGenerator stream = new AgrawalGenerator();
+ ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1);
stream.prepareForUse();
BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator();
basic_evaluator.recallPerClassOption.setValue(true);
basic_evaluator.prepareForUse();
- PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, null,
- maxInstances, 1, false, false);
+ PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, null, 100000, 1, false, false);
// Record the end time
long endTime = System.currentTimeMillis();
@@ -227,18 +458,25 @@ private static void examplePrequentialEfficiency(Learner learner, int maxInstanc
System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.cumulativeResults[i]);
}
- private static void examplePrequentialEvaluation_edge_cases1() {
+ private static void testPrequentialEvaluation_edge_cases1() {
// Record the start time
long startTime = System.currentTimeMillis();
NaiveBayes learner = new NaiveBayes();
learner.prepareForUse();
- AgrawalGenerator stream = new AgrawalGenerator();
+ ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1);
stream.prepareForUse();
- PrequentialResult results = PrequentialEvaluation(stream, learner, null, null,
- 100000, 1000, false, false);
+// BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator();
+// basic_evaluator.recallPerClassOption.setValue(true);
+// basic_evaluator.prepareForUse();
+//
+// WindowClassificationPerformanceEvaluator windowed_evaluator = new WindowClassificationPerformanceEvaluator();
+// windowed_evaluator.widthOption.setValue(1000);
+// windowed_evaluator.prepareForUse();
+
+ PrequentialResult results = PrequentialEvaluation(stream, learner, null, null, 100000, 1000, false, false);
// Record the end time
long endTime = System.currentTimeMillis();
@@ -248,16 +486,28 @@ private static void examplePrequentialEvaluation_edge_cases1() {
// Print the elapsed time
System.out.println("Elapsed Time: " + elapsedTime / 1000 + " seconds");
+
+// System.out.println("\tBasic performance");
+// for (int i = 0; i < results.basicResults.length; ++i)
+// System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.basicResults[i]);
+
+// System.out.println("\tWindowed performance");
+// for (int j = 0; j < results.windowedResults.size(); ++j) {
+// System.out.println("\t" + j);
+// for (int i = 0; i < 2; ++i) // results.get(results.size()-1).length; ++i)
+// System.out.println(windowed_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.windowedResults.get(j)[i]);
+// }
}
- private static void examplePrequentialEvaluation_edge_cases2() {
+
+ private static void testPrequentialEvaluation_edge_cases2() {
// Record the start time
long startTime = System.currentTimeMillis();
NaiveBayes learner = new NaiveBayes();
learner.prepareForUse();
- AgrawalGenerator stream = new AgrawalGenerator();
+ ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1);
stream.prepareForUse();
BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator();
@@ -268,8 +518,7 @@ private static void examplePrequentialEvaluation_edge_cases2() {
windowed_evaluator.widthOption.setValue(1000);
windowed_evaluator.prepareForUse();
- PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator,
- 1000, 10000, false, false);
+ PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 1000, 10000, false, false);
// Record the end time
long endTime = System.currentTimeMillis();
@@ -292,14 +541,14 @@ private static void examplePrequentialEvaluation_edge_cases2() {
}
}
- private static void examplePrequentialEvaluation_edge_cases3() {
+ private static void testPrequentialEvaluation_edge_cases3() {
// Record the start time
long startTime = System.currentTimeMillis();
NaiveBayes learner = new NaiveBayes();
learner.prepareForUse();
- AgrawalGenerator stream = new AgrawalGenerator();
+ ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1);
stream.prepareForUse();
BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator();
@@ -310,8 +559,7 @@ private static void examplePrequentialEvaluation_edge_cases3() {
windowed_evaluator.widthOption.setValue(1000);
windowed_evaluator.prepareForUse();
- PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator,
- 10, 1, false, false);
+ PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 10, 1, false, false);
// Record the end time
long endTime = System.currentTimeMillis();
@@ -324,26 +572,24 @@ private static void examplePrequentialEvaluation_edge_cases3() {
System.out.println("\tBasic performance");
for (int i = 0; i < results.cumulativeResults.length; ++i)
- System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " +
- results.cumulativeResults[i]);
+ System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.cumulativeResults[i]);
System.out.println("\tWindowed performance");
for (int j = 0; j < results.windowedResults.size(); ++j) {
System.out.println("\t" + j);
for (int i = 0; i < 2; ++i) // results.get(results.size()-1).length; ++i)
- System.out.println(windowed_evaluator.getPerformanceMeasurements()[i].getName() + ": " +
- results.windowedResults.get(j)[i]);
+ System.out.println(windowed_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.windowedResults.get(j)[i]);
}
}
- private static void examplePrequentialEvaluation_edge_cases4() {
+ private static void testPrequentialEvaluation_edge_cases4() {
// Record the start time
long startTime = System.currentTimeMillis();
NaiveBayes learner = new NaiveBayes();
learner.prepareForUse();
- AgrawalGenerator stream = new AgrawalGenerator();
+ ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1);
stream.prepareForUse();
BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator();
@@ -354,8 +600,7 @@ private static void examplePrequentialEvaluation_edge_cases4() {
windowed_evaluator.widthOption.setValue(10000);
windowed_evaluator.prepareForUse();
- PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator,
- 100000, 10000, false, false);
+ PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, -1, 10000, false, false);
// Record the end time
long endTime = System.currentTimeMillis();
@@ -378,22 +623,26 @@ private static void examplePrequentialEvaluation_edge_cases4() {
}
}
- private static void examplePrequentialEvaluation_SampleFrequency_TestThenTrain() {
+
+ private static void testPrequentialEvaluation_SampleFrequency_TestThenTrain() {
// Record the start time
long startTime = System.currentTimeMillis();
NaiveBayes learner = new NaiveBayes();
learner.prepareForUse();
- AgrawalGenerator stream = new AgrawalGenerator();
+ ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1);
stream.prepareForUse();
BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator();
basic_evaluator.recallPerClassOption.setValue(true);
basic_evaluator.prepareForUse();
- PrequentialResult results = PrequentialEvaluation(stream, learner, null, basic_evaluator,
- 100000, 10000, false, false);
+// WindowClassificationPerformanceEvaluator windowed_evaluator = new WindowClassificationPerformanceEvaluator();
+// windowed_evaluator.widthOption.setValue(10000);
+// windowed_evaluator.prepareForUse();
+
+ PrequentialResult results = PrequentialEvaluation(stream, learner, null, basic_evaluator, -1, 10000, false, false);
// Record the end time
long endTime = System.currentTimeMillis();
@@ -404,6 +653,10 @@ private static void examplePrequentialEvaluation_SampleFrequency_TestThenTrain()
// Print the elapsed time
System.out.println("Elapsed Time: " + elapsedTime / 1000 + " seconds");
+// System.out.println("\tBasic performance");
+// for (int i = 0; i < results.basicResults.length; ++i)
+// System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.basicResults[i]);
+
System.out.println("\tWindowed performance");
for (int j = 0; j < results.windowedResults.size(); ++j) {
System.out.println("\t" + j);
@@ -412,23 +665,26 @@ private static void examplePrequentialEvaluation_SampleFrequency_TestThenTrain()
}
}
- private static void examplePrequentialRegressionEvaluation() {
+
+ private static void testPrequentialRegressionEvaluation() {
// Record the start time
long startTime = System.currentTimeMillis();
FIMTDD learner = new FIMTDD();
+// learner.getOptions().setViaCLIString("-s 10"); // 10 learners
+// learner.setRandomSeed(5);
learner.prepareForUse();
- HyperplaneGenerator stream = new HyperplaneGenerator();
+ ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/metrotraffic_with_nominals.arff", -1);
stream.prepareForUse();
BasicRegressionPerformanceEvaluator basic_evaluator = new BasicRegressionPerformanceEvaluator();
WindowRegressionPerformanceEvaluator windowed_evaluator = new WindowRegressionPerformanceEvaluator();
windowed_evaluator.widthOption.setValue(1000);
+// windowed_evaluator.prepareForUse();
- PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator,
- 10000, 1000, false, false);
+ PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 100000, 1000, false, false);
// Record the end time
long endTime = System.currentTimeMillis();
@@ -451,7 +707,7 @@ private static void examplePrequentialRegressionEvaluation() {
}
}
- private static void examplePrequentialEvaluation() {
+ private static void testPrequentialEvaluation() {
// Record the start time
long startTime = System.currentTimeMillis();
@@ -494,22 +750,23 @@ private static void examplePrequentialEvaluation() {
}
}
- private static void exampleTestThenTrainEvaluation() {
+ private static void testTestThenTrainEvaluation() {
// Record the start time
long startTime = System.currentTimeMillis();
NaiveBayes learner = new NaiveBayes();
+// learner.getOptions().setViaCLIString("-s 10"); // 10 learners
+// learner.setRandomSeed(5);
learner.prepareForUse();
- HyperplaneGenerator stream = new HyperplaneGenerator();
+ ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1);
stream.prepareForUse();
BasicClassificationPerformanceEvaluator evaluator = new BasicClassificationPerformanceEvaluator();
evaluator.recallPerClassOption.setValue(true);
evaluator.prepareForUse();
- PrequentialResult results = PrequentialEvaluation(stream, learner, evaluator, null,
- 100000, 100000, false, false);
+ PrequentialResult results = PrequentialEvaluation(stream, learner, evaluator, null, 100000, 100000, false, false);
// Record the end time
long endTime = System.currentTimeMillis();
@@ -521,18 +778,19 @@ private static void exampleTestThenTrainEvaluation() {
System.out.println("Elapsed Time: " + elapsedTime / 1000 + " seconds");
for (int i = 0; i < results.cumulativeResults.length; ++i)
- System.out.println(evaluator.getPerformanceMeasurements()[i].getName() + ": " +
- results.cumulativeResults[i]);
+ System.out.println(evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.cumulativeResults[i]);
}
- private static void exampleWindowedEvaluation() {
+ private static void testWindowedEvaluation() {
// Record the start time
long startTime = System.currentTimeMillis();
NaiveBayes learner = new NaiveBayes();
+// learner.getOptions().setViaCLIString("-s 10"); // 10 learners
+// learner.setRandomSeed(5);
learner.prepareForUse();
- HyperplaneGenerator stream = new HyperplaneGenerator();
+ ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1);
stream.prepareForUse();
WindowClassificationPerformanceEvaluator evaluator = new WindowClassificationPerformanceEvaluator();
@@ -540,8 +798,7 @@ private static void exampleWindowedEvaluation() {
evaluator.recallPerClassOption.setValue(true);
evaluator.prepareForUse();
- PrequentialResult results = PrequentialEvaluation(stream, learner, null, evaluator,
- 100000, 10000, false, false);
+ PrequentialResult results = PrequentialEvaluation(stream, learner, null, evaluator, 100000, 10000, false, false);
// Record the end time
long endTime = System.currentTimeMillis();
@@ -555,9 +812,7 @@ private static void exampleWindowedEvaluation() {
for (int j = 0; j < results.windowedResults.size(); ++j) {
System.out.println("\t" + j);
for (int i = 0; i < 2; ++i) // results.get(results.size()-1).length; ++i)
- System.out.println(evaluator.getPerformanceMeasurements()[i].getName() + ": " +
- results.windowedResults.get(j)[i]);
+ System.out.println(evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.windowedResults.get(j)[i]);
}
}
-
}
\ No newline at end of file
diff --git a/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java b/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java
index a7c655be8..911ac4c50 100644
--- a/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java
+++ b/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java
@@ -15,7 +15,7 @@
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
+ *
*/
package moa.evaluation;
@@ -35,35 +35,37 @@
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
-public interface LearningPerformanceEvaluator<E extends Example> extends MOAObject, CapabilitiesHandler {
+public interface LearningPerformanceEvaluator<E extends Example> extends MOAObject, CapabilitiesHandler, AutoCloseable {
- /**
- * Resets this evaluator. It must be similar to
- * starting a new evaluator from scratch.
- *
- */
+ /**
+ * Resets this evaluator. It must be similar to
+ * starting a new evaluator from scratch.
+ *
+ */
public void reset();
- /**
- * Adds a learning result to this evaluator.
- *
- * @param example the example to be classified
- * @param classVotes an array containing the estimated membership
- * probabilities of the test instance in each class
- */
- public void addResult(E example, double[] classVotes);
- public void addResult(E testInst, Prediction prediction);
+ /**
+ * Adds a learning result to this evaluator.
+ *
+ * @param example the example to be classified
+ * @param classVotes an array containing the estimated membership
+ * probabilities of the test instance in each class
+ */
+ public void addResult(E example, double[] classVotes);
+ public void addResult(E testInst, Prediction prediction);
- /**
- * Gets the current measurements monitored by this evaluator.
- *
- * @return an array of measurements monitored by this evaluator
- */
+ /**
+ * Gets the current measurements monitored by this evaluator.
+ *
+ * @return an array of measurements monitored by this evaluator
+ */
public Measurement[] getPerformanceMeasurements();
@Override
default ImmutableCapabilities defineImmutableCapabilities() {
- return new ImmutableCapabilities(Capability.VIEW_STANDARD);
+ return new ImmutableCapabilities(Capability.VIEW_STANDARD);
}
+ default void close() throws Exception {
+ }
}
diff --git a/moa/src/main/java/moa/evaluation/WindowRegressionPerformanceEvaluator.java b/moa/src/main/java/moa/evaluation/WindowRegressionPerformanceEvaluator.java
index eb9745a08..23e977598 100644
--- a/moa/src/main/java/moa/evaluation/WindowRegressionPerformanceEvaluator.java
+++ b/moa/src/main/java/moa/evaluation/WindowRegressionPerformanceEvaluator.java
@@ -175,6 +175,7 @@ public double getCoefficientOfDetermination() {
return 0.0;
}
+
public double getAdjustedCoefficientOfDetermination() {
return 1 - ((1-getCoefficientOfDetermination())*(getTotalWeightObserved() - 1)) /
(getTotalWeightObserved() - numAttributes - 1);
@@ -197,6 +198,7 @@ private double getRelativeSquareError() {
}
public double getTotalWeightObserved() {
+// return this.weightObserved.total();
return this.TotalweightObserved;
}
diff --git a/moa/src/main/java/moa/gui/RegressionTabPanel.java b/moa/src/main/java/moa/gui/RegressionTabPanel.java
index d7a5f5ad9..a6ebd6473 100644
--- a/moa/src/main/java/moa/gui/RegressionTabPanel.java
+++ b/moa/src/main/java/moa/gui/RegressionTabPanel.java
@@ -40,8 +40,7 @@ public class RegressionTabPanel extends AbstractTabPanel {
public RegressionTabPanel() {
this.taskManagerPanel = new RegressionTaskManagerPanel();
- if (Objects.equals(this.taskManagerPanel.currentTask.getTaskName(),
- "EvaluatePrequentialPredictionIntervals")){
+ if (Objects.equals(this.taskManagerPanel.currentTask.getTaskName(), "EvaluatePrequentialPredictionIntervals")) {
this.previewPanel = new PreviewPanel(TypePanel.PREDICTIONINTERVAL);
}
else{
diff --git a/moa/src/main/java/moa/gui/SemiSupervisedTabPanel.java b/moa/src/main/java/moa/gui/SemiSupervisedTabPanel.java
new file mode 100644
index 000000000..f9dc784e1
--- /dev/null
+++ b/moa/src/main/java/moa/gui/SemiSupervisedTabPanel.java
@@ -0,0 +1,29 @@
+package moa.gui;
+
+import java.awt.*;
+
+public class SemiSupervisedTabPanel extends AbstractTabPanel {
+
+ protected SemiSupervisedTaskManagerPanel taskManagerPanel;
+
+ protected PreviewPanel previewPanel;
+
+ public SemiSupervisedTabPanel() {
+ this.taskManagerPanel = new SemiSupervisedTaskManagerPanel();
+ this.previewPanel = new PreviewPanel();
+ this.taskManagerPanel.setPreviewPanel(this.previewPanel);
+ setLayout(new BorderLayout());
+ add(this.taskManagerPanel, BorderLayout.NORTH);
+ add(this.previewPanel, BorderLayout.CENTER);
+ }
+
+ @Override
+ public String getTabTitle() {
+ return "Semi-Supervised Learning";
+ }
+
+ @Override
+ public String getDescription() {
+ return "MOA Semi-Supervised Learning";
+ }
+}
diff --git a/moa/src/main/java/moa/gui/SemiSupervisedTaskManagerPanel.java b/moa/src/main/java/moa/gui/SemiSupervisedTaskManagerPanel.java
new file mode 100644
index 000000000..0e2f251c6
--- /dev/null
+++ b/moa/src/main/java/moa/gui/SemiSupervisedTaskManagerPanel.java
@@ -0,0 +1,468 @@
+package moa.gui;
+
+import moa.core.StringUtils;
+import moa.options.ClassOption;
+import moa.options.OptionHandler;
+import moa.tasks.EvaluateInterleavedTestThenTrainSSLDelayed;
+import moa.tasks.SemiSupervisedMainTask;
+import moa.tasks.Task;
+import moa.tasks.TaskThread;
+import nz.ac.waikato.cms.gui.core.BaseFileChooser;
+
+import javax.swing.*;
+import javax.swing.event.ListSelectionEvent;
+import javax.swing.event.ListSelectionListener;
+import javax.swing.table.AbstractTableModel;
+import javax.swing.table.DefaultTableCellRenderer;
+import javax.swing.table.TableCellRenderer;
+import java.awt.*;
+import java.awt.datatransfer.Clipboard;
+import java.awt.datatransfer.StringSelection;
+import java.awt.event.ActionEvent;
+import java.awt.event.ActionListener;
+import java.awt.event.MouseAdapter;
+import java.awt.event.MouseEvent;
+import java.io.*;
+import java.util.ArrayList;
+import java.util.prefs.Preferences;
+
+public class SemiSupervisedTaskManagerPanel extends JPanel {
+
+ private static final long serialVersionUID = 1L;
+
+ public static final int MILLISECS_BETWEEN_REFRESH = 600;
+
+ public static String exportFileExtension = "log";
+
+ public class ProgressCellRenderer extends JProgressBar implements
+ TableCellRenderer {
+
+ private static final long serialVersionUID = 1L;
+
+ public ProgressCellRenderer() {
+ super(SwingConstants.HORIZONTAL, 0, 10000);
+ setBorderPainted(false);
+ setStringPainted(true);
+ }
+
+ @Override
+ public Component getTableCellRendererComponent(JTable table,
+ Object value, boolean isSelected, boolean hasFocus, int row,
+ int column) {
+ double frac = -1.0;
+ if (value instanceof Double) {
+ frac = ((Double) value).doubleValue();
+ }
+ if (frac >= 0.0) {
+ setIndeterminate(false);
+ setValue((int) (frac * 10000.0));
+ setString(StringUtils.doubleToString(frac * 100.0, 2, 2));
+ } else {
+ setValue(0);
+ }
+ return this;
+ }
+
+ @Override
+ public void validate() { }
+
+ @Override
+ public void revalidate() { }
+
+ @Override
+ protected void firePropertyChange(String propertyName, Object oldValue,
+ Object newValue) { }
+
+ @Override
+ public void firePropertyChange(String propertyName, boolean oldValue,
+ boolean newValue) { }
+ }
+
+ protected class TaskTableModel extends AbstractTableModel {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public String getColumnName(int col) {
+ switch (col) {
+ case 0:
+ return "command";
+ case 1:
+ return "status";
+ case 2:
+ return "time elapsed";
+ case 3:
+ return "current activity";
+ case 4:
+ return "% complete";
+ }
+ return null;
+ }
+
+ @Override
+ public int getColumnCount() {
+ return 5;
+ }
+
+ @Override
+ public int getRowCount() {
+ return SemiSupervisedTaskManagerPanel.this.taskList.size();
+ }
+
+ @Override
+ public Object getValueAt(int row, int col) {
+ TaskThread thread = SemiSupervisedTaskManagerPanel.this.taskList.get(row);
+ switch (col) {
+ case 0:
+ return ((OptionHandler) thread.getTask()).getCLICreationString(SemiSupervisedMainTask.class);
+ case 1:
+ return thread.getCurrentStatusString();
+ case 2:
+ return StringUtils.secondsToDHMSString(thread.getCPUSecondsElapsed());
+ case 3:
+ return thread.getCurrentActivityString();
+ case 4:
+ return Double.valueOf(thread.getCurrentActivityFracComplete());
+ }
+ return null;
+ }
+
+ @Override
+ public boolean isCellEditable(int row, int col) {
+ return false;
+ }
+ }
+
+ protected SemiSupervisedMainTask currentTask;
+
+ protected java.util.List<TaskThread> taskList = new ArrayList<>();
+
+ protected JButton configureTaskButton = new JButton("Configure");
+
+ protected JTextField taskDescField = new JTextField();
+
+ protected JButton runTaskButton = new JButton("Run");
+
+ protected TaskTableModel taskTableModel;
+
+ protected JTable taskTable;
+
+ protected JButton pauseTaskButton = new JButton("Pause");
+
+ protected JButton resumeTaskButton = new JButton("Resume");
+
+ protected JButton cancelTaskButton = new JButton("Cancel");
+
+ protected JButton deleteTaskButton = new JButton("Delete");
+
+ protected PreviewPanel previewPanel;
+
+ private Preferences prefs;
+
+ private final String PREF_NAME = "currentTask";
+
+ public SemiSupervisedTaskManagerPanel() {
+ // Read current task preference
+ prefs = Preferences.userRoot().node(this.getClass().getName());
+ currentTask = new EvaluateInterleavedTestThenTrainSSLDelayed();
+ String taskText = this.currentTask.getCLICreationString(SemiSupervisedMainTask.class);
+ String propertyValue = prefs.get(PREF_NAME, taskText);
+ //this.taskDescField.setText(propertyValue);
+ setTaskString(propertyValue, false); //Not store preference
+ this.taskDescField.setEditable(false);
+
+ final Component comp = this.taskDescField;
+ this.taskDescField.addMouseListener(new MouseAdapter() {
+
+ @Override
+ public void mouseClicked(MouseEvent evt) {
+ if (evt.getClickCount() == 1) {
+ if ((evt.getButton() == MouseEvent.BUTTON3)
+ || ((evt.getButton() == MouseEvent.BUTTON1) && evt.isAltDown() && evt.isShiftDown())) {
+ JPopupMenu menu = new JPopupMenu();
+ JMenuItem item;
+
+ item = new JMenuItem("Copy configuration to clipboard");
+ item.addActionListener(new ActionListener() {
+
+ @Override
+ public void actionPerformed(ActionEvent e) {
+ copyClipBoardConfiguration();
+ }
+ });
+ menu.add(item);
+
+ item = new JMenuItem("Save selected tasks to file");
+ item.addActionListener(new ActionListener() {
+
+ @Override
+ public void actionPerformed(ActionEvent arg0) {
+ saveLogSelectedTasks();
+ }
+ });
+ menu.add(item);
+
+
+ item = new JMenuItem("Enter configuration...");
+ item.addActionListener(new ActionListener() {
+
+ @Override
+ public void actionPerformed(ActionEvent arg0) {
+ String newTaskString = JOptionPane.showInputDialog("Insert command line");
+ if (newTaskString != null) {
+ setTaskString(newTaskString);
+ }
+ }
+ });
+ menu.add(item);
+
+ menu.show(comp, evt.getX(), evt.getY());
+ }
+ }
+ }
+ });
+
+ JPanel configPanel = new JPanel();
+ configPanel.setLayout(new BorderLayout());
+ configPanel.add(this.configureTaskButton, BorderLayout.WEST);
+ configPanel.add(this.taskDescField, BorderLayout.CENTER);
+ configPanel.add(this.runTaskButton, BorderLayout.EAST);
+ this.taskTableModel = new TaskTableModel();
+ this.taskTable = new JTable(this.taskTableModel);
+ DefaultTableCellRenderer centerRenderer = new DefaultTableCellRenderer();
+ centerRenderer.setHorizontalAlignment(SwingConstants.CENTER);
+ this.taskTable.getColumnModel().getColumn(1).setCellRenderer(
+ centerRenderer);
+ this.taskTable.getColumnModel().getColumn(2).setCellRenderer(
+ centerRenderer);
+ this.taskTable.getColumnModel().getColumn(4).setCellRenderer(
+ new ProgressCellRenderer());
+ JPanel controlPanel = new JPanel();
+ controlPanel.add(this.pauseTaskButton);
+ controlPanel.add(this.resumeTaskButton);
+ controlPanel.add(this.cancelTaskButton);
+ controlPanel.add(this.deleteTaskButton);
+ setLayout(new BorderLayout());
+ add(configPanel, BorderLayout.NORTH);
+ add(new JScrollPane(this.taskTable), BorderLayout.CENTER);
+ add(controlPanel, BorderLayout.SOUTH);
+ this.taskTable.getSelectionModel().addListSelectionListener(
+ new ListSelectionListener() {
+
+ @Override
+ public void valueChanged(ListSelectionEvent arg0) {
+ taskSelectionChanged();
+ }
+ });
+ this.configureTaskButton.addActionListener(new ActionListener() {
+
+ @Override
+ public void actionPerformed(ActionEvent arg0) {
+ String newTaskString = ClassOptionSelectionPanel.showSelectClassDialog(
+ SemiSupervisedTaskManagerPanel.this,
+ "Configure task", SemiSupervisedMainTask.class,
+ SemiSupervisedTaskManagerPanel.this.currentTask.getCLICreationString(SemiSupervisedMainTask.class),
+ null);
+ setTaskString(newTaskString);
+ }
+ });
+ this.runTaskButton.addActionListener(new ActionListener() {
+
+ @Override
+ public void actionPerformed(ActionEvent arg0) {
+ runTask((Task) SemiSupervisedTaskManagerPanel.this.currentTask.copy());
+ }
+ });
+ this.pauseTaskButton.addActionListener(new ActionListener() {
+
+ @Override
+ public void actionPerformed(ActionEvent arg0) {
+ pauseSelectedTasks();
+ }
+ });
+ this.resumeTaskButton.addActionListener(new ActionListener() {
+
+ @Override
+ public void actionPerformed(ActionEvent arg0) {
+ resumeSelectedTasks();
+ }
+ });
+ this.cancelTaskButton.addActionListener(new ActionListener() {
+
+ @Override
+ public void actionPerformed(ActionEvent arg0) {
+ cancelSelectedTasks();
+ }
+ });
+ this.deleteTaskButton.addActionListener(new ActionListener() {
+
+ @Override
+ public void actionPerformed(ActionEvent arg0) {
+ deleteSelectedTasks();
+ }
+ });
+
+ Timer updateListTimer = new Timer(
+ MILLISECS_BETWEEN_REFRESH, new ActionListener() {
+
+ @Override
+ public void actionPerformed(ActionEvent e) {
+ SemiSupervisedTaskManagerPanel.this.taskTable.repaint();
+ }
+ });
+ updateListTimer.start();
+ setPreferredSize(new Dimension(0, 200));
+ }
+
+ public void setPreviewPanel(PreviewPanel previewPanel) {
+ this.previewPanel = previewPanel;
+ }
+
+ public void setTaskString(String cliString) {
+ setTaskString(cliString, true);
+ }
+
+ public void setTaskString(String cliString, boolean storePreference) {
+ try {
+ this.currentTask = (SemiSupervisedMainTask) ClassOption.cliStringToObject(
+ cliString, SemiSupervisedMainTask.class, null);
+ String taskText = this.currentTask.getCLICreationString(SemiSupervisedMainTask.class);
+ this.taskDescField.setText(taskText);
+ if (storePreference) {
+ //Save task text as a preference
+ prefs.put(PREF_NAME, taskText);
+ }
+ } catch (Exception ex) {
+ GUIUtils.showExceptionDialog(this, "Problem with task", ex);
+ }
+ }
+
+ public void runTask(Task task) {
+ TaskThread thread = new TaskThread(task);
+ this.taskList.add(0, thread);
+ this.taskTableModel.fireTableDataChanged();
+ this.taskTable.setRowSelectionInterval(0, 0);
+ thread.start();
+ }
+
+ public void taskSelectionChanged() {
+ TaskThread[] selectedTasks = getSelectedTasks();
+ if (selectedTasks.length == 1) {
+ setTaskString(((OptionHandler) selectedTasks[0].getTask()).getCLICreationString(SemiSupervisedMainTask.class));
+ if (this.previewPanel != null) {
+ this.previewPanel.setTaskThreadToPreview(selectedTasks[0]);
+ }
+ } else {
+ this.previewPanel.setTaskThreadToPreview(null);
+ }
+ }
+
+ public TaskThread[] getSelectedTasks() {
+ int[] selectedRows = this.taskTable.getSelectedRows();
+ TaskThread[] selectedTasks = new TaskThread[selectedRows.length];
+ for (int i = 0; i < selectedRows.length; i++) {
+ selectedTasks[i] = this.taskList.get(selectedRows[i]);
+ }
+ return selectedTasks;
+ }
+
+ public void pauseSelectedTasks() {
+ TaskThread[] selectedTasks = getSelectedTasks();
+ for (TaskThread thread : selectedTasks) {
+ thread.pauseTask();
+ }
+ }
+
+ public void resumeSelectedTasks() {
+ TaskThread[] selectedTasks = getSelectedTasks();
+ for (TaskThread thread : selectedTasks) {
+ thread.resumeTask();
+ }
+ }
+
+ public void cancelSelectedTasks() {
+ TaskThread[] selectedTasks = getSelectedTasks();
+ for (TaskThread thread : selectedTasks) {
+ thread.cancelTask();
+ }
+ }
+
+ public void deleteSelectedTasks() {
+ TaskThread[] selectedTasks = getSelectedTasks();
+ for (TaskThread thread : selectedTasks) {
+ thread.cancelTask();
+ this.taskList.remove(thread);
+ }
+ this.taskTableModel.fireTableDataChanged();
+ }
+
+ public void copyClipBoardConfiguration() {
+
+ StringSelection selection = new StringSelection(this.taskDescField.getText().trim());
+ Clipboard clipboard = Toolkit.getDefaultToolkit().getSystemClipboard();
+ clipboard.setContents(selection, selection);
+
+ }
+
+ public void saveLogSelectedTasks() {
+ String tasksLog = "";
+ TaskThread[] selectedTasks = getSelectedTasks();
+ for (TaskThread thread : selectedTasks) {
+ tasksLog += ((OptionHandler) thread.getTask()).getCLICreationString(SemiSupervisedMainTask.class) + "\n";
+ }
+
+ BaseFileChooser fileChooser = new BaseFileChooser();
+ fileChooser.setAcceptAllFileFilterUsed(true);
+ fileChooser.addChoosableFileFilter(new FileExtensionFilter(
+ exportFileExtension));
+ if (fileChooser.showSaveDialog(this) == BaseFileChooser.APPROVE_OPTION) {
+ File chosenFile = fileChooser.getSelectedFile();
+ String fileName = chosenFile.getPath();
+ if (!chosenFile.exists()
+ && !fileName.endsWith(exportFileExtension)) {
+ fileName = fileName + "." + exportFileExtension;
+ }
+ try {
+ PrintWriter out = new PrintWriter(new BufferedWriter(
+ new FileWriter(fileName)));
+ out.write(tasksLog);
+ out.close();
+ } catch (IOException ioe) {
+ GUIUtils.showExceptionDialog(
+ this,
+ "Problem saving file " + fileName, ioe);
+ }
+ }
+ }
+
+ private static void createAndShowGUI() {
+
+ // Create and set up the labeledInstancesBuffer.
+ JFrame frame = new JFrame("Test");
+ frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
+
+ // Create and set up the content pane.
+ JPanel panel = new SemiSupervisedTabPanel();
+ panel.setOpaque(true); // content panes must be opaque
+ frame.setContentPane(panel);
+
+ // Display the labeledInstancesBuffer.
+ frame.pack();
+ // frame.setSize(400, 400);
+ frame.setVisible(true);
+ }
+
+ public static void main(String[] args) {
+ try {
+ UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());
+ SwingUtilities.invokeLater(new Runnable() {
+ @Override
+ public void run() {
+ createAndShowGUI();
+ }
+ });
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+}
diff --git a/moa/src/main/java/moa/learners/Learner.java b/moa/src/main/java/moa/learners/Learner.java
index be959a8d0..a806ad587 100644
--- a/moa/src/main/java/moa/learners/Learner.java
+++ b/moa/src/main/java/moa/learners/Learner.java
@@ -19,14 +19,10 @@
*/
package moa.learners;
+import com.yahoo.labs.samoa.instances.*;
import moa.MOAObject;
import moa.core.Example;
-import com.yahoo.labs.samoa.instances.InstanceData;
-import com.yahoo.labs.samoa.instances.InstancesHeader;
-import com.yahoo.labs.samoa.instances.MultiLabelInstance;
-import com.yahoo.labs.samoa.instances.Prediction;
-
import moa.core.Measurement;
import moa.gui.AWTRenderable;
import moa.options.OptionHandler;
@@ -95,6 +91,14 @@ public interface Learner extends MOAObject, OptionHandler, AW
*/
public double[] getVotesForInstance(E example);
+ /**
+ * Gets the confidence of predicting the given class label for the given example.
+ * @param example the instance whose confidence we are observing
+ * @param label the class label whose prediction confidence is requested
+ * @return the confidence score of predicting {@code label} for {@code example}
+ */
+ public double getConfidenceForPrediction(E example, double label);
+
/**
* Gets the current measurements of this learner.
*
diff --git a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java
index e0c69a5bd..e16d43791 100644
--- a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java
+++ b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java
@@ -25,6 +25,7 @@
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
+import moa.classifiers.AbstractClassifierMiniBatch;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.Example;
@@ -217,6 +218,9 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (immediateResultStream != null) {
immediateResultStream.close();
}
+ if (learner instanceof AbstractClassifierMiniBatch) {
+ ((AbstractClassifierMiniBatch) learner).trainingHasEnded();
+ }
return learningCurve;
}
diff --git a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrainSSLDelayed.java b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrainSSLDelayed.java
new file mode 100644
index 000000000..b5b02904a
--- /dev/null
+++ b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrainSSLDelayed.java
@@ -0,0 +1,351 @@
+package moa.tasks;
+
+import com.github.javacliparser.FileOption;
+import com.github.javacliparser.FlagOption;
+import com.github.javacliparser.FloatOption;
+import com.github.javacliparser.IntOption;
+import com.yahoo.labs.samoa.instances.Instance;
+import moa.classifiers.MultiClassClassifier;
+import moa.classifiers.SemiSupervisedLearner;
+import moa.core.*;
+import moa.evaluation.LearningEvaluation;
+import moa.evaluation.LearningPerformanceEvaluator;
+import moa.evaluation.preview.LearningCurve;
+import moa.learners.Learner;
+import moa.options.ClassOption;
+import moa.streams.ExampleStream;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.PrintStream;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang3.tuple.MutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.math3.random.MersenneTwister;
+import org.apache.commons.math3.random.RandomGenerator;
+
+/**
+ * An evaluation task that relies on the mechanism of Interleaved Test Then
+ * Train,
+ * applied on semi-supervised data streams
+ */
+public class EvaluateInterleavedTestThenTrainSSLDelayed extends SemiSupervisedMainTask {
+
+ @Override
+ public String getPurposeString() {
+ return "Evaluates a classifier on a semi-supervised stream by testing only the labeled data, " +
+ "then training with each example in sequence.";
+ }
+
+ private static final long serialVersionUID = 1L;
+
+ public IntOption randomSeedOption = new IntOption(
+ "instanceRandomSeed", 'r',
+ "Seed for random generation of instances.", 1);
+
+ public FlagOption onlyLabeledDataOption = new FlagOption("labeledDataOnly", 'a',
+ "Learner only trained on labeled data");
+
+ public ClassOption standardLearnerOption = new ClassOption("standardLearner", 'b',
+ "A standard learner to train. This will be ignored if labeledDataOnly flag is not set.",
+ MultiClassClassifier.class, "moa.classifiers.trees.HoeffdingTree");
+
+ public ClassOption sslLearnerOption = new ClassOption("sslLearner", 'l',
+ "A semi-supervised learner to train.", SemiSupervisedLearner.class,
+ "moa.classifiers.semisupervised.ClusterAndLabelClassifier");
+
+ public ClassOption streamOption = new ClassOption("stream", 's',
+ "Stream to learn from.", ExampleStream.class,
+ "moa.streams.ArffFileStream");
+
+ public ClassOption evaluatorOption = new ClassOption("evaluator", 'e',
+ "Classification performance evaluation method.",
+ LearningPerformanceEvaluator.class,
+ "BasicClassificationPerformanceEvaluator");
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /** Option: Probability of instance being unlabeled */
+ public FloatOption labelProbabilityOption = new FloatOption("labelProbability", 'j',
+ "The ratio of labeled data",
+ 0.01);
+
+ public IntOption delayLengthOption = new IntOption("delay", 'k',
+ "Number of instances before test instance is used for training. -1 = no delayed labeling.",
+ -1, -1, Integer.MAX_VALUE);
+
+ public IntOption initialWindowSizeOption = new IntOption("initialTrainingWindow", 'p',
+ "Number of instances used for training in the beginning of the stream (-1 = no initialWindow).",
+ -1, -1, Integer.MAX_VALUE);
+
+ public FlagOption debugPseudoLabelsOption = new FlagOption("debugPseudoLabels", 'w',
+ "Learner also receives the labeled data, but it is not used for training (just for statistics)");
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ public IntOption instanceLimitOption = new IntOption("instanceLimit", 'i',
+ "Maximum number of instances to test/train on (-1 = no limit).",
+ 100000000, -1, Integer.MAX_VALUE);
+
+ public IntOption timeLimitOption = new IntOption("timeLimit", 't',
+ "Maximum number of seconds to test/train for (-1 = no limit).", -1,
+ -1, Integer.MAX_VALUE);
+
+ public IntOption sampleFrequencyOption = new IntOption("sampleFrequency",
+ 'f',
+ "How many instances between samples of the learning performance.",
+ 100000, 0, Integer.MAX_VALUE);
+
+ public IntOption memCheckFrequencyOption = new IntOption(
+ "memCheckFrequency", 'q',
+ "How many instances between memory bound checks.", 100000, 0,
+ Integer.MAX_VALUE);
+
+ public FileOption dumpFileOption = new FileOption("dumpFile", 'd',
+ "File to append intermediate csv results to.", null, "csv", true);
+
+ public FileOption outputPredictionFileOption = new FileOption("outputPredictionFile", 'o',
+ "File to append output predictions to.", null, "pred", true);
+
+ public FileOption debugOutputUnlabeledClassInformation = new FileOption("debugOutputUnlabeledClassInformation", 'h',
+ "Single column containing the class label or -999 indicating missing labels.", null, "csv", true);
+
+ private int numUnlabeledData = 0;
+
+ private Learner getLearner(ExampleStream stream) {
+ Learner learner;
+ if (this.onlyLabeledDataOption.isSet()) {
+ learner = (Learner) getPreparedClassOption(this.standardLearnerOption);
+ } else {
+ learner = (SemiSupervisedLearner) getPreparedClassOption(this.sslLearnerOption);
+ }
+
+ learner.setModelContext(stream.getHeader());
+ if (learner.isRandomizable()) {
+ learner.setRandomSeed(this.randomSeedOption.getValue());
+ learner.resetLearning();
+ }
+ return learner;
+ }
+
+ private String getLearnerString() {
+ if (this.onlyLabeledDataOption.isSet()) {
+ return this.standardLearnerOption.getValueAsCLIString();
+ } else {
+ return this.sslLearnerOption.getValueAsCLIString();
+ }
+ }
+
+ private PrintStream newPrintStream(File f, String err_msg) {
+ if (f == null)
+ return null;
+ try {
+ return new PrintStream(new FileOutputStream(f, f.exists()), true);
+ } catch (FileNotFoundException e) {
+ throw new RuntimeException(err_msg, e);
+ }
+ }
+
+ private Object internalDoMainTask(TaskMonitor monitor, ObjectRepository repository, LearningPerformanceEvaluator evaluator)
+ {
+ int maxInstances = this.instanceLimitOption.getValue();
+ int maxSeconds = this.timeLimitOption.getValue();
+ int delayLength = this.delayLengthOption.getValue();
+ double labelProbability = this.labelProbabilityOption.getValue();
+ String streamString = this.streamOption.getValueAsCLIString();
+ RandomGenerator taskRandom = new MersenneTwister(this.randomSeedOption.getValue());
+ ExampleStream stream = (ExampleStream) getPreparedClassOption(this.streamOption);
+ Learner learner = getLearner(stream);
+ String learnerString = getLearnerString();
+
+ // A number of output files used for debugging and manual evaluation
+ PrintStream dumpStream = newPrintStream(this.dumpFileOption.getFile(), "Failed to create dump file");
+ PrintStream predStream = newPrintStream(this.outputPredictionFileOption.getFile(),
+ "Failed to create prediction file");
+ PrintStream labelStream = newPrintStream(this.debugOutputUnlabeledClassInformation.getFile(),
+ "Failed to create unlabeled class information file");
+ if (labelStream != null)
+ labelStream.println("class");
+
+ // Setup evaluation
+ monitor.setCurrentActivity("Evaluating learner...", -1.0);
+ LearningCurve learningCurve = new LearningCurve("learning evaluation instances");
+
+ boolean firstDump = true;
+ boolean preciseCPUTiming = TimingUtils.enablePreciseTiming();
+ long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
+ long lastEvaluateStartTime = evaluateStartTime;
+ long instancesProcessed = 0;
+ int secondsElapsed = 0;
+ double RAMHours = 0.0;
+
+ // The buffer is a list of tuples. The first element is the index when
+ // it should be emitted. The second element is the instance itself.
+ List<Pair<Long, Example>> delayBuffer = new ArrayList<>();
+
+ while (stream.hasMoreInstances()
+ && ((maxInstances < 0) || (instancesProcessed < maxInstances))
+ && ((maxSeconds < 0) || (secondsElapsed < maxSeconds))) {
+ instancesProcessed++;
+
+ // TRAIN on delayed instances
+ while (delayBuffer.size() > 0
+ && delayBuffer.get(0).getKey() == instancesProcessed) {
+ Example delayedExample = delayBuffer.remove(0).getValue();
+ learner.trainOnInstance(delayedExample);
+ }
+
+ // Obtain the next Example from the stream.
+ // The instance is expected to be labeled.
+ Example originalExample = stream.nextInstance();
+ Example unlabeledExample = originalExample.copy();
+ int trueClass = (int) ((Instance) originalExample.getData()).classValue();
+
+ // In case it is set, then the label is not removed. We want to pass the
+ // labelled data to the learner even in trainOnUnlabeled data to generate statistics such as number
+ // of correctly pseudo-labeled instances.
+ if (!debugPseudoLabelsOption.isSet()) {
+ // Remove the label of the unlabeledExample indirectly through
+ // unlabeledInstanceData.
+ Instance instance = (Instance) unlabeledExample.getData();
+ instance.setMissing(instance.classIndex());
+ }
+
+ // WARMUP
+ // Train on the initial instances. These are not used for testing!
+ if (instancesProcessed <= this.initialWindowSizeOption.getValue()) {
+ if (learner instanceof SemiSupervisedLearner)
+ ((SemiSupervisedLearner) learner).addInitialWarmupTrainingInstances();
+ learner.trainOnInstance(originalExample);
+ continue;
+ }
+
+ Boolean is_labeled = labelProbability > taskRandom.nextDouble();
+ if (!is_labeled) {
+ this.numUnlabeledData++;
+ if (labelStream != null)
+ labelStream.println(-999);
+ } else {
+ if (labelStream != null)
+ labelStream.println((int) trueClass);
+ }
+
+ // TEST
+ // Obtain the prediction for the testInst (i.e. no label)
+ double[] prediction = learner.getVotesForInstance(unlabeledExample);
+
+ // Output prediction
+ if (predStream != null) {
+ // Assuming that the class label is not missing for the originalInstanceData
+ predStream.println(Utils.maxIndex(prediction) + "," + trueClass);
+ }
+ evaluator.addResult(originalExample, prediction);
+
+ // TRAIN
+ if (is_labeled && delayLength >= 0) {
+ // The instance will be labeled but has been delayed
+ if (learner instanceof SemiSupervisedLearner)
+ {
+ ((SemiSupervisedLearner) learner).trainOnUnlabeledInstance((Instance) unlabeledExample.getData());
+ }
+ delayBuffer.add(
+ new MutablePair<>(1 + instancesProcessed + delayLength, originalExample));
+ } else if (is_labeled) {
+ // The instance will be labeled and is not delayed e.g delayLength = -1
+ learner.trainOnInstance(originalExample);
+ } else {
+ // The instance will never be labeled
+ if (learner instanceof SemiSupervisedLearner)
+ ((SemiSupervisedLearner) learner).trainOnUnlabeledInstance((Instance) unlabeledExample.getData());
+ }
+
+ if (instancesProcessed % this.sampleFrequencyOption.getValue() == 0 || !stream.hasMoreInstances()) {
+ long evaluateTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
+ double time = TimingUtils.nanoTimeToSeconds(evaluateTime - evaluateStartTime);
+ double timeIncrement = TimingUtils.nanoTimeToSeconds(evaluateTime - lastEvaluateStartTime);
+ double RAMHoursIncrement = learner.measureByteSize() / (1024.0 * 1024.0 * 1024.0); // GBs
+ RAMHoursIncrement *= (timeIncrement / 3600.0); // Hours
+ RAMHours += RAMHoursIncrement;
+ lastEvaluateStartTime = evaluateTime;
+ learningCurve.insertEntry(new LearningEvaluation(
+ new Measurement[] {
+ new Measurement(
+ "learning evaluation instances",
+ instancesProcessed),
+ new Measurement(
+ "evaluation time ("
+ + (preciseCPUTiming ? "cpu "
+ : "")
+ + "seconds)",
+ time),
+ new Measurement(
+ "model cost (RAM-Hours)",
+ RAMHours),
+ new Measurement(
+ "Unlabeled instances",
+ this.numUnlabeledData)
+ },
+ evaluator, learner));
+ if (dumpStream != null) {
+ if (firstDump) {
+ dumpStream.print("Learner,stream,randomSeed,");
+ dumpStream.println(learningCurve.headerToString());
+ firstDump = false;
+ }
+ dumpStream.print(learnerString + "," + streamString + ","
+ + this.randomSeedOption.getValueAsCLIString() + ",");
+ dumpStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1));
+ dumpStream.flush();
+ }
+ }
+ if (instancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
+ if (monitor.taskShouldAbort()) {
+ return null;
+ }
+ long estimatedRemainingInstances = stream.estimatedRemainingInstances();
+ if (maxInstances > 0) {
+ long maxRemaining = maxInstances - instancesProcessed;
+ if ((estimatedRemainingInstances < 0)
+ || (maxRemaining < estimatedRemainingInstances)) {
+ estimatedRemainingInstances = maxRemaining;
+ }
+ }
+ monitor.setCurrentActivityFractionComplete(estimatedRemainingInstances < 0 ? -1.0
+ : (double) instancesProcessed / (double) (instancesProcessed + estimatedRemainingInstances));
+ if (monitor.resultPreviewRequested()) {
+ monitor.setLatestResultPreview(learningCurve.copy());
+ }
+ secondsElapsed = (int) TimingUtils.nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread()
+ - evaluateStartTime);
+ }
+ }
+ if (dumpStream != null) {
+ dumpStream.close();
+ }
+ if (predStream != null) {
+ predStream.close();
+ }
+ return learningCurve;
+ }
+
+ @Override
+ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
+ // Some resource must be closed at the end of the task
+ try (
+ LearningPerformanceEvaluator evaluator = (LearningPerformanceEvaluator) getPreparedClassOption(this.evaluatorOption)
+ ) {
+ return internalDoMainTask(monitor, repository, evaluator);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+
+ }
+
+ @Override
+ public Class<?> getTaskResultType() {
+ return LearningCurve.class;
+ }
+}
diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialPredictionIntervals.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialPredictionIntervals.java
index c8e559d12..b577acb9c 100644
--- a/moa/src/main/java/moa/tasks/EvaluatePrequentialPredictionIntervals.java
+++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialPredictionIntervals.java
@@ -16,7 +16,7 @@
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
+ *
*/
package moa.tasks;
@@ -224,17 +224,17 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
lastEvaluateStartTime = evaluateTime;
learningCurve.insertEntry(new LearningEvaluation(
new Measurement[]{
- new Measurement(
- "learning evaluation instances",
- instancesProcessed),
- new Measurement(
- "evaluation time ("
- + (preciseCPUTiming ? "cpu "
- : "") + "seconds)",
- time),
- new Measurement(
- "model cost (RAM-Hours)",
- RAMHours)
+ new Measurement(
+ "learning evaluation instances",
+ instancesProcessed),
+ new Measurement(
+ "evaluation time ("
+ + (preciseCPUTiming ? "cpu "
+ : "") + "seconds)",
+ time),
+ new Measurement(
+ "model cost (RAM-Hours)",
+ RAMHours)
},
evaluator, learner));
@@ -277,4 +277,4 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
}
return learningCurve;
}
-}
\ No newline at end of file
+}
diff --git a/moa/src/main/java/moa/tasks/SemiSupervisedMainTask.java b/moa/src/main/java/moa/tasks/SemiSupervisedMainTask.java
new file mode 100644
index 000000000..fecf7feae
--- /dev/null
+++ b/moa/src/main/java/moa/tasks/SemiSupervisedMainTask.java
@@ -0,0 +1,24 @@
+package moa.tasks;
+
+import moa.streams.clustering.ClusterEvent;
+
+import java.util.ArrayList;
+
+/**
+ *
+ */
+public abstract class SemiSupervisedMainTask extends MainTask {
+
+ private static final long serialVersionUID = 1L;
+
+ protected ArrayList<ClusterEvent> events;
+
+ protected void setEventsList(ArrayList<ClusterEvent> events) {
+ this.events = events;
+ }
+
+ public ArrayList<ClusterEvent> getEventsList() {
+ return this.events;
+ }
+
+}
diff --git a/moa/src/main/resources/moa/gui/GUI.props b/moa/src/main/resources/moa/gui/GUI.props
index d1deabd85..3bb990469 100644
--- a/moa/src/main/resources/moa/gui/GUI.props
+++ b/moa/src/main/resources/moa/gui/GUI.props
@@ -8,6 +8,7 @@
Tabs=\
moa.gui.ClassificationTabPanel,\
moa.gui.RegressionTabPanel,\
+ moa.gui.SemiSupervisedTabPanel,\
moa.gui.MultiLabelTabPanel,\
moa.gui.MultiTargetTabPanel,\
moa.gui.clustertab.ClusteringTabPanel,\
diff --git a/moa/src/test/java/moa/classifiers/meta/PEARLTest.java b/moa/src/test/java/moa/classifiers/meta/PEARLTest.java
new file mode 100644
index 000000000..cc53afc78
--- /dev/null
+++ b/moa/src/test/java/moa/classifiers/meta/PEARLTest.java
@@ -0,0 +1,81 @@
+/*
+ * PEARLTest.java
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+/**
+ *
+ */
+package moa.classifiers.meta;
+
+import junit.framework.Test;
+import junit.framework.TestSuite;
+import moa.classifiers.AbstractMultipleClassifierTestCase;
+import moa.classifiers.Classifier;
+
+/**
+ * Tests the PEARL classifier.
+ *
+ * @author Nuwan Gunasekara (ng98 at students dot waikato dot ac dot nz)
+ * @version $Revision$
+ */
+public class PEARLTest
+ extends AbstractMultipleClassifierTestCase {
+
+ /**
+ * Constructs the test case. Called by subclasses.
+ *
+ * @param name the name of the test
+ */
+ public PEARLTest(String name) {
+ super(name);
+ this.setNumberTests(1);
+ }
+
+ /**
+ * Returns the classifier setups to use in the regression test.
+ *
+ * @return the setups
+ */
+ @Override
+ protected Classifier[] getRegressionClassifierSetups() {
+ PEARL pearl = new PEARL();
+ pearl.ensembleSizeOption.setValue(5);
+ pearl.mFeaturesModeOption.setChosenIndex(0);
+ pearl.mFeaturesPerTreeSizeOption.setValue(2);
+
+ return new Classifier[]{
+ pearl,
+ };
+ }
+
+ /**
+ * Returns a test suite.
+ *
+ * @return the test suite
+ */
+ public static Test suite() {
+ return new TestSuite(PEARLTest.class);
+ }
+
+ /**
+ * Runs the test from commandline.
+ *
+ * @param args ignored
+ */
+ public static void main(String[] args) {
+ runTest(suite());
+ }
+}
diff --git a/moa/src/test/java/moa/classifiers/meta/SelfOptimisingKNearestLeavesTest.java b/moa/src/test/java/moa/classifiers/meta/SelfOptimisingKNearestLeavesTest.java
index 7c1510a4d..b76795378 100644
--- a/moa/src/test/java/moa/classifiers/meta/SelfOptimisingKNearestLeavesTest.java
+++ b/moa/src/test/java/moa/classifiers/meta/SelfOptimisingKNearestLeavesTest.java
@@ -55,7 +55,7 @@ protected Classifier[] getRegressionClassifierSetups() {
SOKNLTest.ensembleSizeOption.setValue(5);
SOKNLTest.mFeaturesModeOption.setChosenIndex(0);
SOKNLTest.mFeaturesPerTreeSizeOption.setValue(2);
- SOKNLTest.DisableSelfOptimisingOption.set();
+ SOKNLTest.disableSelfOptimisingOption.set();
SOKNLTest.kOption.setValue(5);
diff --git a/moa/src/test/resources/moa/classifiers/meta/PEARL.ref b/moa/src/test/resources/moa/classifiers/meta/PEARL.ref
new file mode 100644
index 000000000..202a255c5
--- /dev/null
+++ b/moa/src/test/resources/moa/classifiers/meta/PEARL.ref
@@ -0,0 +1,145 @@
+--> classification-out0.arff
+moa.classifiers.meta.PEARL -s 5 -o (Specified m (integer value)) -m 2 -x (ADWINChangeDetector -a 0.001) -p (ADWINChangeDetector -a 0.01)
+
+Index
+ 10000
+Votes
+ 0: 201.05038779
+ 1: 71.46686394
+Measurements
+ classified instances: 9999
+ classifications correct (percent): 76.07760776
+ Kappa Statistic (percent): 49.28392856
+ Kappa Temporal Statistic (percent): 49.57841484
+ Kappa M Statistic (percent): 41.58730159
+Model measurements
+ model training instances: 9999
+
+Index
+ 20000
+Votes
+ 0: 38.6256937
+ 1: 312.13184418
+Measurements
+ classified instances: 19999
+ classifications correct (percent): 78.18890945
+ Kappa Statistic (percent): 54.27721317
+ Kappa Temporal Statistic (percent): 54.46288757
+ Kappa M Statistic (percent): 47.7103812
+Model measurements
+ model training instances: 19999
+
+Index
+ 30000
+Votes
+ 0: 2.11253342
+ 1: 354.36268242
+Measurements
+ classified instances: 29999
+ classifications correct (percent): 79.79932664
+ Kappa Statistic (percent): 57.69378087
+ Kappa Temporal Statistic (percent): 58.14049872
+ Kappa M Statistic (percent): 51.47341448
+Model measurements
+ model training instances: 29999
+
+Index
+ 40000
+Votes
+ 0: 5.27561538
+ 1: 281.57905599
+Measurements
+ classified instances: 39999
+ classifications correct (percent): 80.73701843
+ Kappa Statistic (percent): 59.76778227
+ Kappa Temporal Statistic (percent): 60.18293628
+ Kappa M Statistic (percent): 53.91746411
+Model measurements
+ model training instances: 39999
+
+Index
+ 50000
+Votes
+ 0: 203.06408337
+ 1: 86.63171055
+Measurements
+ classified instances: 49999
+ classifications correct (percent): 81.39562791
+ Kappa Statistic (percent): 61.17254121
+ Kappa Temporal Statistic (percent): 61.49515688
+ Kappa M Statistic (percent): 55.48430322
+Model measurements
+ model training instances: 49999
+
+Index
+ 60000
+Votes
+ 0: 5.48964024
+ 1: 360.37979092
+Measurements
+ classified instances: 59999
+ classifications correct (percent): 81.91136519
+ Kappa Statistic (percent): 62.34778415
+ Kappa Temporal Statistic (percent): 62.59520937
+ Kappa M Statistic (percent): 56.93595746
+Model measurements
+ model training instances: 59999
+
+Index
+ 70000
+Votes
+ 0: 3.2953865
+ 1: 364.58558323
+Measurements
+ classified instances: 69999
+ classifications correct (percent): 82.35689081
+ Kappa Statistic (percent): 63.32418526
+ Kappa Temporal Statistic (percent): 63.6326158
+ Kappa M Statistic (percent): 58.14126898
+Model measurements
+ model training instances: 69999
+
+Index
+ 80000
+Votes
+ 0: 171.67957388
+ 1: 197.86379541
+Measurements
+ classified instances: 79999
+ classifications correct (percent): 82.71728397
+ Kappa Statistic (percent): 64.09329532
+ Kappa Temporal Statistic (percent): 64.41092435
+ Kappa M Statistic (percent): 58.99276308
+Model measurements
+ model training instances: 79999
+
+Index
+ 90000
+Votes
+ 0: 2.74400602
+ 1: 368.37233971
+Measurements
+ classified instances: 89999
+ classifications correct (percent): 83.05203391
+ Kappa Statistic (percent): 64.80130245
+ Kappa Temporal Statistic (percent): 65.08412499
+ Kappa M Statistic (percent): 59.77584388
+Model measurements
+ model training instances: 89999
+
+Index
+ 100000
+Votes
+ 0: 357.30819817
+ 1: 15.06752559
+Measurements
+ classified instances: 99999
+ classifications correct (percent): 83.31583316
+ Kappa Statistic (percent): 65.37306693
+ Kappa Temporal Statistic (percent): 65.63117996
+ Kappa M Statistic (percent): 60.44758428
+Model measurements
+ model training instances: 99999
+
+
+