From 5ca06f6eb32a6eb2af25dcf39ddd813c3ad88d03 Mon Sep 17 00:00:00 2001 From: Heitor Date: Wed, 17 Apr 2024 14:46:51 +1200 Subject: [PATCH 01/31] Update adding SSL tab and fast evaluation (capymoa support) --- .../moa/classifiers/AbstractClassifier.java | 20 + .../main/java/moa/classifiers/Classifier.java | 10 +- .../classifiers/SemiSupervisedLearner.java | 16 +- .../ClusterAndLabelClassifier.java | 330 ++++++++++++ .../semisupervised/SSLTaskTester.java | 81 +++ .../SelfTrainingClassifier.java | 349 +++++++++++++ .../SelfTrainingIncrementalClassifier.java | 222 ++++++++ .../SelfTrainingWeightingClassifier.java | 115 +++++ .../AttributeSimilarityCalculator.java | 235 +++++++++ ...EuclideanDistanceSimilarityCalculator.java | 23 + .../GoodAll3SimilarityCalculator.java | 23 + .../IgnoreSimilarityCalculator.java | 20 + ...currenceFrequencySimilarityCalculator.java | 24 + .../LinSimilarityCalculator.java | 32 ++ ...currenceFrequencySimilarityCalculator.java | 26 + .../moa/clusterers/clustream/Clustream.java | 30 +- .../clusterers/clustream/ClustreamKernel.java | 113 ++-- .../evaluation/EfficientEvaluationLoops.java | 484 ++++++++++++------ .../LearningPerformanceEvaluator.java | 46 +- .../java/moa/gui/SemiSupervisedTabPanel.java | 29 ++ .../gui/SemiSupervisedTaskManagerPanel.java | 468 +++++++++++++++++ moa/src/main/java/moa/learners/Learner.java | 14 +- ...ateInterleavedTestThenTrainSSLDelayed.java | 351 +++++++++++++ .../moa/tasks/SemiSupervisedMainTask.java | 24 + moa/src/main/resources/moa/gui/GUI.props | 1 + 25 files changed, 2865 insertions(+), 221 deletions(-) create mode 100644 moa/src/main/java/moa/classifiers/semisupervised/ClusterAndLabelClassifier.java create mode 100644 moa/src/main/java/moa/classifiers/semisupervised/SSLTaskTester.java create mode 100644 moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingClassifier.java create mode 100644 moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingIncrementalClassifier.java create mode 100644 
moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingWeightingClassifier.java create mode 100644 moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/AttributeSimilarityCalculator.java create mode 100644 moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/EuclideanDistanceSimilarityCalculator.java create mode 100644 moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/GoodAll3SimilarityCalculator.java create mode 100644 moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/IgnoreSimilarityCalculator.java create mode 100644 moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/InverseOccurrenceFrequencySimilarityCalculator.java create mode 100644 moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/LinSimilarityCalculator.java create mode 100644 moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/OccurrenceFrequencySimilarityCalculator.java create mode 100644 moa/src/main/java/moa/gui/SemiSupervisedTabPanel.java create mode 100644 moa/src/main/java/moa/gui/SemiSupervisedTaskManagerPanel.java create mode 100644 moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrainSSLDelayed.java create mode 100644 moa/src/main/java/moa/tasks/SemiSupervisedMainTask.java diff --git a/moa/src/main/java/moa/classifiers/AbstractClassifier.java b/moa/src/main/java/moa/classifiers/AbstractClassifier.java index f60467d0d..29350a48e 100644 --- a/moa/src/main/java/moa/classifiers/AbstractClassifier.java +++ b/moa/src/main/java/moa/classifiers/AbstractClassifier.java @@ -105,6 +105,26 @@ public double[] getVotesForInstance(Example example){ @Override public abstract double[] getVotesForInstance(Instance inst); + @Override + public double getConfidenceForPrediction(Instance inst, double prediction) { + double[] votes = this.getVotesForInstance(inst); + double predictionValue = votes[(int) prediction]; + + double sum = 0.0; + for (double vote : votes) + sum += vote; + + 
// Check if the sum is zero + if (sum == 0.0) + return 0.0; // Return 0 if sum is zero to avoid division by zero + return predictionValue / sum; + } + + @Override + public double getConfidenceForPrediction(Example example, double prediction) { + return getConfidenceForPrediction(example.getData(), prediction); + } + @Override public Prediction getPredictionForInstance(Example example){ return getPredictionForInstance(example.getData()); diff --git a/moa/src/main/java/moa/classifiers/Classifier.java b/moa/src/main/java/moa/classifiers/Classifier.java index 101d7fe3d..7a5acaa4e 100644 --- a/moa/src/main/java/moa/classifiers/Classifier.java +++ b/moa/src/main/java/moa/classifiers/Classifier.java @@ -15,7 +15,7 @@ * * You should have received a copy of the GNU General Public License * along with this program. If not, see . - * + * */ package moa.classifiers; @@ -76,7 +76,7 @@ public interface Classifier extends Learner> { * test instance in each class */ public double[] getVotesForInstance(Instance inst); - + /** * Sets the reference to the header of the data stream. The header of the * data stream is extended from WEKA @@ -86,7 +86,7 @@ public interface Classifier extends Learner> { * @param ih the reference to the data stream header */ //public void setModelContext(InstancesHeader ih); - + /** * Gets the reference to the header of the data stream. 
The header of the * data stream is extended from WEKA @@ -96,6 +96,8 @@ public interface Classifier extends Learner> { * @return the reference to the data stream header */ //public InstancesHeader getModelContext(); - + public Prediction getPredictionForInstance(Instance inst); + + public double getConfidenceForPrediction(Instance inst, double prediction); } diff --git a/moa/src/main/java/moa/classifiers/SemiSupervisedLearner.java b/moa/src/main/java/moa/classifiers/SemiSupervisedLearner.java index 05caec5a0..ba6954de0 100644 --- a/moa/src/main/java/moa/classifiers/SemiSupervisedLearner.java +++ b/moa/src/main/java/moa/classifiers/SemiSupervisedLearner.java @@ -19,12 +19,16 @@ */ package moa.classifiers; +import com.yahoo.labs.samoa.instances.Instance; +import moa.core.Example; +import moa.learners.Learner; + /** - * Learner interface for incremental semi supervised models. It is used only in the GUI Regression Tab. - * - * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) - * @version $Revision: 7 $ + * Updated learner interface for semi-supervised methods. */ -public interface SemiSupervisedLearner { - +public interface SemiSupervisedLearner extends Learner> { + // Returns the pseudo-label used. If no pseudo-label was used, then return -1. 
+ int trainOnUnlabeledInstance(Instance instance); + + void addInitialWarmupTrainingInstances(); } diff --git a/moa/src/main/java/moa/classifiers/semisupervised/ClusterAndLabelClassifier.java b/moa/src/main/java/moa/classifiers/semisupervised/ClusterAndLabelClassifier.java new file mode 100644 index 000000000..95ad2bf54 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/ClusterAndLabelClassifier.java @@ -0,0 +1,330 @@ +package moa.classifiers.semisupervised; + +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.IntOption; +import com.yahoo.labs.samoa.instances.Instance; +import moa.classifiers.AbstractClassifier; +import moa.classifiers.SemiSupervisedLearner; +import moa.cluster.Cluster; +import moa.cluster.Clustering; +import moa.clusterers.clustream.Clustream; +import moa.clusterers.clustream.ClustreamKernel; +import moa.core.*; +import moa.options.ClassOption; +import moa.tasks.TaskMonitor; + +import java.util.*; + +/** + * A simple semi-supervised classifier that serves as a baseline. + * The idea is to group the incoming data into micro-clusters, each of which + * is assigned a label. The micro-clusters will then be used for classification of unlabeled data. 
+ */ +public class ClusterAndLabelClassifier extends AbstractClassifier + implements SemiSupervisedLearner { + + private static final long serialVersionUID = 1L; + + public ClassOption clustererOption = new ClassOption("clustream", 'c', + "Used to configure clustream", + Clustream.class, "Clustream"); + + /** Lets user decide if they want to use pseudo-labels */ + public FlagOption usePseudoLabelOption = new FlagOption("pseudoLabel", 'p', + "Using pseudo-label while training"); + + public FlagOption debugModeOption = new FlagOption("debugMode", 'e', + "Print information about the clusters on stdout"); + + /** Decides the labels based on k-nearest cluster, k defaults to 1 */ + public IntOption kNearestClusterOption = new IntOption("kNearestCluster", 'k', + "Issue predictions based on the majority vote from k-nearest cluster", 1); + + /** Number of nearest clusters used to issue prediction */ + private int k; + + private Clustream clustream; + + /** To train using pseudo-label or not */ + private boolean usePseudoLabel; + + /** Number of nearest clusters used to issue prediction */ +// private int k; + + // Statistics + protected long instancesSeen; + protected long instancesPseudoLabeled; + protected long instancesCorrectPseudoLabeled; + + @Override + public String getPurposeString() { + return "A basic semi-supervised learner"; + } + + @Override + public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + this.clustream = (Clustream) getPreparedClassOption(this.clustererOption); + this.clustream.prepareForUse(); + this.usePseudoLabel = usePseudoLabelOption.isSet(); + this.k = kNearestClusterOption.getValue(); + super.prepareForUseImpl(monitor, repository); + } + + @Override + public void resetLearningImpl() { + this.clustream.resetLearning(); + this.instancesSeen = 0; + this.instancesCorrectPseudoLabeled = 0; + this.instancesPseudoLabeled = 0; + } + + + @Override + public void trainOnInstanceImpl(Instance instance) { + ++this.instancesSeen; 
+ Objects.requireNonNull(this.clustream, "Cluster must not be null!"); + if(this.clustream.getModelContext() == null) + this.clustream.setModelContext(this.getModelContext()); + this.clustream.trainOnInstance(instance); + } + + @Override + public int trainOnUnlabeledInstance(Instance instance) { + // Creates a copy of the instance to be pseudoLabeled + Instance unlabeledInstance = instance.copy(); + // In case the label is available for debugging purposes (i.e. checking the pseudoLabel accuracy), + // we want to save it, but then immediately remove the label to avoid it being used + int groundTruthClassLabel = -999; + if(! unlabeledInstance.classIsMissing()) { + groundTruthClassLabel = (int) unlabeledInstance.classValue(); + unlabeledInstance.setMissing(unlabeledInstance.classIndex()); + } + + int pseudoLabel = -1; + if (this.usePseudoLabel) { + ClustreamKernel closestCluster = getNearestClustreamKernel(this.clustream, unlabeledInstance, false); + pseudoLabel = (closestCluster != null ? Utils.maxIndex(closestCluster.classObserver) : -1); + + unlabeledInstance.setClassValue(pseudoLabel); + this.clustream.trainOnInstance(unlabeledInstance); + + if (pseudoLabel == groundTruthClassLabel) { + ++this.instancesCorrectPseudoLabeled; + } + ++this.instancesPseudoLabeled; + } + else { // Update the cluster without using the pseudoLabel + this.clustream.trainOnInstance(unlabeledInstance); + } + return pseudoLabel; + } + + @Override + public void addInitialWarmupTrainingInstances() { + } + + @Override + public double[] getVotesForInstance(Instance instance) { + Objects.requireNonNull(this.clustream, "Cluster must not be null!"); + // Creates a copy of the instance to be used in here (avoid changing the instance passed to this method) + Instance unlabeledInstance = instance.copy(); + + if(! 
unlabeledInstance.classIsMissing()) + unlabeledInstance.setMissing(unlabeledInstance.classIndex()); + + Clustering clustering = clustream.getMicroClusteringResult(); + + double[] votes = new double[unlabeledInstance.numClasses()]; + + if(clustering != null) { + if (k == 1) { + ClustreamKernel closestKernel = getNearestClustreamKernel(clustream, unlabeledInstance, false); + if (closestKernel != null) + votes = closestKernel.classObserver; + } + else { + votes = getVotesFromKClusters(this.findKNearestClusters(unlabeledInstance, this.k)); + } + } + return votes; + } + + /** + * Gets the predictions from K nearest clusters + * @param kClusters array of k nearest clusters + * @return the final predictions + */ + private double[] getVotesFromKClusters(ClustreamKernel[] kClusters) { + DoubleVector result = new DoubleVector(); + + for(ClustreamKernel microCluster : kClusters) { + if(microCluster == null) + continue; + + int maxIndex = Utils.maxIndex(microCluster.classObserver); + result.setValue(maxIndex, 1.0); + } + if(result.numValues() > 0) { + result.normalize(); + } + return result.getArrayRef(); + } + + /** + * Finds K nearest cluster from an instance + * @param instance the instance X + * @param k K closest clusters + * @return set of K closest clusters + */ + private ClustreamKernel[] findKNearestClusters(Instance instance, int k) { + Set sortedClusters = new TreeSet<>(new DistanceKernelComparator(instance)); + Clustering clustering = clustream.getMicroClusteringResult(); + + if (clustering == null || clustering.size() == 0) + return new ClustreamKernel[0]; + + // There should be a better way of doing this instead of creating a separate array list + ArrayList clusteringArray = new ArrayList<>(); + for(int i = 0 ; i < clustering.getClustering().size() ; ++i) + clusteringArray.add((ClustreamKernel) clustering.getClustering().get(i)); + + // Sort the clusters according to their distance to instance + sortedClusters.addAll(clusteringArray); + ClustreamKernel[] topK = 
new ClustreamKernel[k]; + // Keep only the topK clusters, i.e. the closest clusters to instance + Iterator it = sortedClusters.iterator(); + int i = 0; + while (it.hasNext() && i < k) + topK[i++] = it.next(); + + ////////////////////////////////// + if(this.debugModeOption.isSet()) + debugVotingScheme(clustering, instance, topK, true); + ////////////////////////////////// + + return topK; + } + + class DistanceKernelComparator implements Comparator { + + private Instance instance; + + public DistanceKernelComparator(Instance instance) { + this.instance = instance; + } + + @Override + public int compare(ClustreamKernel C1, ClustreamKernel C2) { + double distanceC1 = Clustream.distanceIgnoreNaN(C1.getCenter(), instance.toDoubleArray()); + double distanceC2 = Clustream.distanceIgnoreNaN(C2.getCenter(), instance.toDoubleArray()); + return Double.compare(distanceC1, distanceC2); + } + } + + private ClustreamKernel getNearestClustreamKernel(Clustream clustream, Instance instance, boolean includeClass) { + double minDistance = Double.MAX_VALUE; + ClustreamKernel closestCluster = null; + + List excluded = new ArrayList<>(); + if (!includeClass) + excluded.add(instance.classIndex()); + + Clustering clustering = clustream.getMicroClusteringResult(); + AutoExpandVector kernels = clustering.getClustering(); + + double[] arrayInstance = instance.toDoubleArray(); + + + for(int i = 0 ; i < kernels.size() ; ++i) { + double[] clusterCenter = kernels.get(i).getCenter(); + double distance = Clustream.distanceIgnoreNaN(arrayInstance, clusterCenter); + ////////////////////////////// + if(this.debugModeOption.isSet()) + debugClustreamMicroCluster((ClustreamKernel) kernels.get(i), clusterCenter, distance, true); + ////////////////////////////// + if(distance < minDistance) { + minDistance = distance; + closestCluster = (ClustreamKernel) kernels.get(i); + } + } + /////////////////////////// + if(this.debugModeOption.isSet()) + debugShowInstance(instance); + /////////////////////////// + + 
return closestCluster; + } + + @Override + protected Measurement[] getModelMeasurementsImpl() { + // instances seen * the number of ensemble members + return new Measurement[]{ + new Measurement("#pseudo-labeled", this.instancesPseudoLabeled), + new Measurement("#correct pseudo-labeled", this.instancesCorrectPseudoLabeled), + new Measurement("accuracy pseudo-labeled", this.instancesCorrectPseudoLabeled / (double) this.instancesPseudoLabeled * 100) + }; + } + + @Override + public void getModelDescription(StringBuilder out, int indent) { + throw new UnsupportedOperationException("Not supported yet."); + } + + @Override + public boolean isRandomizable() { + return false; + } + + //////////////////////////////////////////////////////////////////////////////////////////////// + /////////////////////////////////// DEBUG METHODS ////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////////////// + + private void debugShowInstance(Instance instance) { + System.out.print("Instance: ["); + for(int i = 0 ; i < instance.numAttributes() ; ++i) { + System.out.print(instance.value(i) + " "); + } + System.out.println("]"); + } + + private void debugClustreamMicroCluster(ClustreamKernel cluster, double[] clusterCenter, double distance, boolean showMicroClusterValues) { + System.out.print(" MicroCluster: " + cluster.getId()); + if(showMicroClusterValues) { + System.out.print(" ["); + for (int j = 0; j < clusterCenter.length; ++j) { + System.out.print(String.format("%.4f ", clusterCenter[j]) + " "); + } + System.out.print("]"); + } + System.out.print(" distance to instance: " + String.format("%.4f ",distance) + " classObserver: [ "); + + for(int g = 0 ; g < cluster.classObserver.length ; ++g) { + System.out.print(cluster.classObserver[g] + " "); + } + System.out.print("] maxIndex (vote): " + Utils.maxIndex(cluster.classObserver)); + System.out.println(); + } + + private void debugVotingScheme(Clustering 
clustering, Instance instance, ClustreamKernel[] topK, boolean showAllClusters) { + System.out.println("[DEBUG] Voting Scheme: "); + AutoExpandVector kernels = clustering.getClustering(); + + double[] arrayInstance = instance.toDoubleArray(); + + System.out.println(" TopK: "); + for(int z = 0 ; z < topK.length ; ++z) { + double[] clusterCenter = topK[z].getCenter(); + double distance = Clustream.distanceIgnoreNaN(arrayInstance, clusterCenter); + debugClustreamMicroCluster(topK[z], clusterCenter, distance, true); + } + + if(showAllClusters) { + System.out.println(" All microclusters: "); + for (int x = 0; x < kernels.size(); ++x) { + double[] clusterCenter = kernels.get(x).getCenter(); + double distance = Clustream.distanceIgnoreNaN(arrayInstance, clusterCenter); + debugClustreamMicroCluster((ClustreamKernel) kernels.get(x), clusterCenter, distance, true); + } + } + } +} diff --git a/moa/src/main/java/moa/classifiers/semisupervised/SSLTaskTester.java b/moa/src/main/java/moa/classifiers/semisupervised/SSLTaskTester.java new file mode 100644 index 000000000..b878f31f2 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/SSLTaskTester.java @@ -0,0 +1,81 @@ +package moa.classifiers.semisupervised; + +import com.yahoo.labs.samoa.instances.Instance; +import moa.classifiers.AbstractClassifier; +import moa.classifiers.SemiSupervisedLearner; +import moa.core.Measurement; +import moa.core.ObjectRepository; +import moa.tasks.TaskMonitor; + +/*** + * This class shall be removed later. Just used to verify the EvaluateInterleavedTestThenTrainSSLDelayed + * works as expected. 
+ */ +public class SSLTaskTester extends AbstractClassifier implements SemiSupervisedLearner{ + + protected long instancesWarmupCounter; + protected long instancesLabeledCounter; + protected long instancesUnlabeledCounter; + protected long instancesTestCounter; + + @Override + public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + super.prepareForUseImpl(monitor, repository); + + this.instancesTestCounter = 0; + this.instancesUnlabeledCounter = 0; + this.instancesLabeledCounter = 0; + } + + @Override + public boolean isRandomizable() { + return false; + } + + @Override + public double[] getVotesForInstance(Instance inst) { + // TODO Auto-generated method stub + ++this.instancesTestCounter; + double[] dummy = new double[inst.numClasses()]; + return dummy; + } + + @Override + public void resetLearningImpl() { + // TODO Auto-generated method stub + } + + @Override + public void addInitialWarmupTrainingInstances() { + ++this.instancesWarmupCounter; + } + + @Override + public void trainOnInstanceImpl(Instance inst) { + // TODO Auto-generated method stub + ++this.instancesLabeledCounter; + } + + @Override + public int trainOnUnlabeledInstance(Instance instance) { + ++this.instancesUnlabeledCounter; + return -1; + } + + @Override + protected Measurement[] getModelMeasurementsImpl() { + return new Measurement[]{ + new Measurement("#labeled", this.instancesLabeledCounter), + new Measurement("#unlabeled", this.instancesUnlabeledCounter), + new Measurement("#warmup", this.instancesWarmupCounter) +// new Measurement("accuracy supervised learner", this.evaluatorSupervisedDebug.getPerformanceMeasurements()[1].getValue()) + }; + } + + @Override + public void getModelDescription(StringBuilder out, int indent) { + // TODO Auto-generated method stub + + } + +} diff --git a/moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingClassifier.java b/moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingClassifier.java new file mode 100644 index 
000000000..851aace57 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingClassifier.java @@ -0,0 +1,349 @@ +package moa.classifiers.semisupervised; + +import com.github.javacliparser.FloatOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.MultiChoiceOption; +import com.yahoo.labs.samoa.instances.Instance; +import moa.classifiers.AbstractClassifier; +import moa.classifiers.Classifier; +import moa.classifiers.SemiSupervisedLearner; +import moa.core.Measurement; +import moa.core.ObjectRepository; +import moa.core.Utils; +import moa.options.ClassOption; +import moa.tasks.TaskMonitor; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; + +/** + * Self-training classifier: it is trained with a limited number of labeled data at first, + * then it predicts the labels of unlabeled data, the most confident predictions are used + * for training in the next iteration. + */ +public class SelfTrainingClassifier extends AbstractClassifier implements SemiSupervisedLearner { + + private static final long serialVersionUID = 1L; + + /* ------------------- + * GUI options + * -------------------*/ + public ClassOption learnerOption = new ClassOption("learner", 'l', + "Any learner to be self-trained", AbstractClassifier.class, + "moa.classifiers.trees.HoeffdingTree"); + + public IntOption batchSizeOption = new IntOption("batchSize", 'b', + "Size of one batch to self-train", + 1000, 1, Integer.MAX_VALUE); + + public MultiChoiceOption thresholdChoiceOption = new MultiChoiceOption("thresholdValue", 't', + "Ways to define the confidence threshold", + new String[] { "Fixed", "AdaptiveWindowing", "AdaptiveVariance" }, + new String[] { + "The threshold is input once and remains unchanged", + "The threshold is updated every h-interval of time", + "The threshold is updated if the confidence score drifts off from the average" + }, 0); + + public FloatOption thresholdOption = new 
FloatOption("confidenceThreshold", 'c', + "Threshold to evaluate the confidence of a prediction", + 0.7, 0.0, Double.MAX_VALUE); + + public IntOption horizonOption = new IntOption("horizon", 'h', + "The interval of time to update the threshold", 1000); + + public FloatOption ratioThresholdOption = new FloatOption("ratioThreshold", 'r', + "How large should the threshold be wrt to the average confidence score", + 0.8, 0.0, Double.MAX_VALUE); + + public MultiChoiceOption confidenceOption = new MultiChoiceOption("confidenceComputation", + 's', "Choose the method to estimate the prediction uncertainty", + new String[]{ "DistanceMeasure", "FromLearner" }, + new String[]{ "Confidence score from pair-wise distance with the ground truth", + "Confidence score estimated by the learner itself" }, 1); + + /* ------------------- + * Attributes + * -------------------*/ + + /** A learner to be self-trained */ + private Classifier learner; + + /** The size of one batch */ + private int batchSize; + + /** The confidence threshold to decide which predictions to include in the next training batch */ + private double threshold; + + /** Contains the unlabeled instances */ + private List U; + + /** Contains the labeled instances */ + private List L; + + /** Contains the predictions of one batch's training */ +// private List Uhat; + + /** Contains the most confident prediction */ +// private List mostConfident; + + private int horizon; + private int t; + private double ratio; + private double LS; + private double SS; + private double N; + private double lastConfidenceScore; + + // Statistics + protected long instancesSeen; + protected long instancesPseudoLabeled; + protected long instancesCorrectPseudoLabeled; + + + @Override + public String getPurposeString() { return "A self-training classifier"; } + + @Override + public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + this.learner = (Classifier) getPreparedClassOption(learnerOption); + this.batchSize = 
batchSizeOption.getValue(); + this.threshold = thresholdOption.getValue(); + this.ratio = ratioThresholdOption.getValue(); + this.horizon = horizonOption.getValue(); + LS = SS = N = t = 0; + allocateBatch(); + super.prepareForUseImpl(monitor, repository); + } + + @Override + public double[] getVotesForInstance(Instance inst) { + return learner.getVotesForInstance(inst); + } + + @Override + public void resetLearningImpl() { + this.learner.resetLearning(); + lastConfidenceScore = LS = SS = N = t = 0; + allocateBatch(); + + this.instancesSeen = 0; + this.instancesCorrectPseudoLabeled = 0; + this.instancesPseudoLabeled = 0; + } + + @Override + public void trainOnInstanceImpl(Instance inst) { + this.instancesSeen++; + updateThreshold(); + t++; + + L.add(inst); + learner.trainOnInstance(inst); + + + /* if batch B is full, launch the self-training process */ + if (isBatchFull()) { + trainOnUnlabeledBatch(); + } + } + + private void trainOnUnlabeledBatch() { + List> Uhat = predictOnBatch(U); + List> mostConfident = null; + + // chose the method to estimate prediction uncertainty +// if (confidenceOption.getChosenIndex() == 0) +// mostConfident = getMostConfidentDistanceBased(Uhat); +// else + mostConfident = getMostConfidentFromLearner(Uhat); + // train from the most confident examples + for(AbstractMap.SimpleEntry x : mostConfident) { + learner.trainOnInstance(x.getKey()); + if(x.getKey().classValue() == x.getValue()) + ++this.instancesCorrectPseudoLabeled; + ++this.instancesPseudoLabeled; + } + cleanBatch(); + } + + @Override + public void addInitialWarmupTrainingInstances() { + // TODO: add counter, but this may not be necessary for this class + } + + // TODO: Verify if we need to do something else. 
+ @Override + public int trainOnUnlabeledInstance(Instance instance) { + this.instancesSeen++; + U.add(instance); + + if (isBatchFull()) { + trainOnUnlabeledBatch(); + } +// this.trainOnInstanceImpl(instance); + return -1; + } + + private void updateThreshold() { + if (thresholdChoiceOption.getChosenIndex() == 1) updateThresholdWindowing(); + if (thresholdChoiceOption.getChosenIndex() == 2) updateThresholdVariance(); + } + + /** + * Dynamically updates the confidence threshold at the end of each labeledInstancesBuffer horizon + */ + private void updateThresholdWindowing() { + if (t % horizon == 0) { + if (N == 0 || LS == 0 || SS == 0) return; + threshold = (LS / N) * ratio; + t = 0; + } + } + + /** + * Dynamically updates the confidence threshold: + * adapt the threshold if the last confidence score falls out of z-index = 1 zone + */ + private void updateThresholdVariance() { + // TODO update right when it detects a drift, or to wait until H drifts have happened? + if (N == 0 || LS == 0 || SS == 0) return; + double variance = (SS - LS * LS / N) / (N - 1); + double mean = LS / N; + double zscore = (lastConfidenceScore - mean) / variance; + if (Math.abs(zscore) > 1.0) { + threshold = mean * ratio; + } + } + + /** + * Gives prediction for each instance in a given batch. + * @param batch the batch containing unlabeled instances + * @return result the result to save the prediction in + */ + private List> predictOnBatch(List batch) { + List> batchWithPredictions = new ArrayList<>(); + + + for (Instance instance : batch) { + Instance copy = instance.copy(); // use copy because we do not want to modify the original data + double classValue = -1.0; + if(!instance.classIsMissing()) // if it is not missing, assume this is a debug execution and store it for checking pseudo-labelling accuracy. 
+ classValue = instance.classValue(); + + copy.setClassValue(Utils.maxIndex(learner.getVotesForInstance(copy))); + batchWithPredictions.add(new AbstractMap.SimpleEntry (copy, classValue)); + } + + return batchWithPredictions; + } + + /** + * Gets the most confident predictions + * @param batch batch of instances to give prediction to + * @return mostConfident instances that are more confidence than a threshold + */ + private List> getMostConfidentFromLearner(List> batch) { + List> mostConfident = new ArrayList<>(); + for (AbstractMap.SimpleEntry x : batch) { + double[] votes = learner.getVotesForInstance(x.getKey()); + if (votes[Utils.maxIndex(votes)] >= threshold) { + mostConfident.add(x); + } + } + return mostConfident; + } + + /** + * Gets the most confident predictions that exceed the indicated threshold + * @param batch the batch containing the predictions + * @return mostConfident the result containing the most confident prediction from the given batch + */ +// private List> getMostConfidentDistanceBased(List> batch) { +// /* +// * Use distance measure to estimate the confidence of a prediction +// * +// * for each instance X in the batch: +// * for each instance XL in the labeled data: (ground-truth) +// * if X.label == XL.label: (only consider instances sharing the same label) +// * confidence[X] += distance(X, XL) +// * confidence[X] = confidence[X] / |L| (taking the average) +// */ +// List> mostConfident = new ArrayList<>(); +// +// double[] confidences = new double[batch.size()]; +// double conf; +// int i = 0; +// for (AbstractMap.SimpleEntry X : batch) { +// conf = 0; +// for (Instance XL : this.L) { +// if (XL.classValue() == X.getKey().classValue()) { +// conf += Clusterer.distance(XL.toDoubleArray(), X.getKey().toDoubleArray()) / this.L.size(); +// } +// } +// conf = (1.0 / conf > 1.0 ? 
1.0 : 1 / conf); // reverse so the distance becomes the confidence +// confidences[i++] = conf; +// // accumulate the statistics +// LS += conf; +// SS += conf * conf; +// N++; +// } +// +// for (double confidence : confidences) lastConfidenceScore += confidence / confidences.length; +// +// /* The confidences are computed using the distance measures, +// * so naturally, the lower the score, the more certain the prediction is. +// * Here we simply retrieve the instances whose confidence score are below a threshold */ +// for (int j = 0; j < confidences.length; j++) { +// if (confidences[j] >= threshold) { +// mostConfident.add(batch.get(j)); +// } +// } +// +// return mostConfident; +// } + + /** + * Checks whether the batch is full + * @return true if the batch is full, false otherwise + */ + private boolean isBatchFull() { + return U.size() + L.size() >= batchSize; + } + + /** Cleans the batch (and its associated variables) */ + private void cleanBatch() { + L.clear(); + U.clear(); +// mostConfident.clear(); + } + + /** Allocates memory to the batch */ + private void allocateBatch() { + this.U = new ArrayList<>(); + this.L = new ArrayList<>(); +// this.mostConfident = new ArrayList<>(); + } + + + @Override + protected Measurement[] getModelMeasurementsImpl() { + // instances seen * the number of ensemble members + return new Measurement[]{ + new Measurement("#pseudo-labeled", this.instancesPseudoLabeled), + new Measurement("#correct pseudo-labeled", this.instancesCorrectPseudoLabeled), + new Measurement("accuracy pseudo-labeled", this.instancesCorrectPseudoLabeled / (double) this.instancesPseudoLabeled * 100) + }; + } + + @Override + public void getModelDescription(StringBuilder out, int indent) { + + } + + @Override + public boolean isRandomizable() { + return false; + } +} diff --git a/moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingIncrementalClassifier.java b/moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingIncrementalClassifier.java 
new file mode 100644 index 000000000..d9f60a977 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingIncrementalClassifier.java @@ -0,0 +1,222 @@ +package moa.classifiers.semisupervised; + +import com.github.javacliparser.FloatOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.MultiChoiceOption; +import com.yahoo.labs.samoa.instances.Instance; +import moa.classifiers.AbstractClassifier; +import moa.classifiers.Classifier; +import moa.classifiers.SemiSupervisedLearner; +import moa.core.Measurement; +import moa.core.ObjectRepository; +import moa.core.Utils; +import moa.options.ClassOption; +import moa.tasks.TaskMonitor; + +/** + * Self-training classifier: Incremental version. + * Instead of using a batch, the model will be update with every instance that arrives. + */ +public class SelfTrainingIncrementalClassifier extends AbstractClassifier implements SemiSupervisedLearner { + + private static final long serialVersionUID = 1L; + + public ClassOption learnerOption = new ClassOption("learner", 'l', + "Any learner to be self-trained", AbstractClassifier.class, + "moa.classifiers.trees.HoeffdingTree"); + + public MultiChoiceOption thresholdChoiceOption = new MultiChoiceOption("thresholdValue", 't', + "Ways to define the confidence threshold", + new String[] { "Fixed", "AdaptiveWindowing", "AdaptiveVariance" }, + new String[] { + "The threshold is input once and remains unchanged", + "The threshold is updated every h-interval of time", + "The threshold is updated if the confidence score drifts off from the average" + }, 0); + + public FloatOption thresholdOption = new FloatOption("confidenceThreshold", 'c', + "Threshold to evaluate the confidence of a prediction", 0.9, 0.0, 1.0); + + public IntOption horizonOption = new IntOption("horizon", 'h', + "The interval of time to update the threshold", 1000); + + public FloatOption ratioThresholdOption = new FloatOption("ratioThreshold", 'r', + "How large should the 
threshold be wrt to the average confidence score", + 0.95, 0.0, Double.MAX_VALUE); + + /* ------------------- + * Attributes + * -------------------*/ + /** A learner to be self-trained */ + private Classifier learner; + + /** The confidence threshold to decide which predictions to include in the next training batch */ + private double threshold; + + /** Whether the threshold is to be adaptive or fixed*/ + private boolean adaptiveThreshold; + + /** Interval of time to update the threshold */ + private int horizon; + + /** Keep track of time */ + private int t; + + /** Ratio of the threshold wrt the average confidence score*/ + private double ratio; + + // statistics needed to update the confidence threshold + private double LS; + private double SS; + private double N; + private double lastConfidenceScore; + + // Statistics + protected long instancesSeen; + protected long instancesPseudoLabeled; + protected long instancesCorrectPseudoLabeled; + + @Override + public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + this.learner = (Classifier) getPreparedClassOption(learnerOption); + this.threshold = thresholdOption.getValue(); + this.horizon = horizonOption.getValue(); + this.ratio = ratioThresholdOption.getValue(); + super.prepareForUseImpl(monitor, repository); + } + + @Override + public String getPurposeString() { + return "A self-training classifier that trains at every instance (not using a batch)"; + } + + @Override + public double[] getVotesForInstance(Instance inst) { + return learner.getVotesForInstance(inst); + } + + @Override + public void resetLearningImpl() { + LS = SS = N = lastConfidenceScore = 0; + this.instancesSeen = 0; + this.instancesCorrectPseudoLabeled = 0; + this.instancesPseudoLabeled = 0; + } + + @Override + public void trainOnInstanceImpl(Instance inst) { + /* + * update the threshold + * + * if X is labeled: + * L.train(X) + * else: + * X_hat <- L.predict_probab(X) + * if X_hat.highest_proba > threshold: + * 
L.train(X_hat) + */ + + updateThreshold(); + + ++this.instancesSeen; + + if (/*!inst.classIsMasked() &&*/ !inst.classIsMissing()) { + learner.trainOnInstance(inst); + } else { + double pseudoLabel = getPrediction(inst); + double confidenceScore = learner.getConfidenceForPrediction(inst, pseudoLabel); + if (confidenceScore >= threshold) { + Instance instCopy = inst.copy(); + instCopy.setClassValue(pseudoLabel); + learner.trainOnInstance(instCopy); + } + // accumulate the statistics to update the adaptive threshold + LS += confidenceScore; + SS += confidenceScore * confidenceScore; + N++; + lastConfidenceScore = confidenceScore; + +// if(pseudoLabel == inst.maskedClassValue()) { +// ++this.instancesCorrectPseudoLabeled; +// } + ++this.instancesPseudoLabeled; + } + + t++; + } + + private void updateThreshold() { + if (thresholdChoiceOption.getChosenIndex() == 1) updateThresholdWindowing(); + if (thresholdChoiceOption.getChosenIndex() == 2) updateThresholdVariance(); + } + + @Override + public void addInitialWarmupTrainingInstances() { + // TODO: add counter, but this may not be necessary for this class + } + + // TODO: Verify if we need to do something else. + @Override + public int trainOnUnlabeledInstance(Instance instance) { + this.trainOnInstanceImpl(instance); + return -1; + } + + /** + * Updates the threshold after each labeledInstancesBuffer horizon + */ + private void updateThresholdWindowing() { + if (t % horizon == 0) { + if (N == 0 || LS == 0 || SS == 0) return; + threshold = (LS / N) * ratio; + threshold = (Math.min(threshold, 1.0)); + // N = LS = SS = 0; // to reset or not? 
+ t = 0; + } + } + + /** + * Update the thresholds based on the variance: + * if the z-score of the last confidence score wrt the mean is more than 1.0, + * update the confidence threshold + */ + private void updateThresholdVariance() { + if (N == 0 || LS == 0 || SS == 0) return; + double variance = (SS - LS * LS / N) / (N - 1); + double mean = LS / N; + double zscore = (lastConfidenceScore - mean) / variance; + if (Math.abs(zscore) > 1.0) { + threshold = mean * ratio; + threshold = (Math.min(threshold, 1.0)); + } + } + + /** + * Gets the prediction from an instance (a shortcut to pass getVotesForInstance) + * @param inst the instance + * @return the most likely prediction (the label with the highest probability in getVotesForInstance) + */ + private double getPrediction(Instance inst) { + return Utils.maxIndex(this.getVotesForInstance(inst)); + } + + @Override + protected Measurement[] getModelMeasurementsImpl() { + // instances seen * the number of ensemble members + return new Measurement[]{ + new Measurement("#pseudo-labeled", -1), // this.instancesPseudoLabeled), + new Measurement("#correct pseudo-labeled", -1), //this.instancesCorrectPseudoLabeled), + new Measurement("accuracy pseudo-labeled", -1) //this.instancesCorrectPseudoLabeled / (double) this.instancesPseudoLabeled * 100) + }; + } + + @Override + public void getModelDescription(StringBuilder out, int indent) { + + } + + @Override + public boolean isRandomizable() { + return false; + } +} diff --git a/moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingWeightingClassifier.java b/moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingWeightingClassifier.java new file mode 100644 index 000000000..39520ea03 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/SelfTrainingWeightingClassifier.java @@ -0,0 +1,115 @@ +package moa.classifiers.semisupervised; + +import com.github.javacliparser.FlagOption; +import com.yahoo.labs.samoa.instances.Instance; +import 
moa.classifiers.AbstractClassifier; +import moa.classifiers.Classifier; +import moa.classifiers.SemiSupervisedLearner; +import moa.core.Measurement; +import moa.core.ObjectRepository; +import moa.core.Utils; +import moa.options.ClassOption; +import moa.tasks.TaskMonitor; + +/** + * Variance of Self-training: all instances are used to self-train the learner, but each has a weight, depending + * on the confidence of their prediction + */ +public class SelfTrainingWeightingClassifier extends AbstractClassifier implements SemiSupervisedLearner { + + + @Override + public String getPurposeString() { + return "Self-training classifier that weights instances by confidence score (threshold not used)"; + } + + public ClassOption learnerOption = new ClassOption("learner", 'l', + "Any learner to be self-trained", AbstractClassifier.class, + "moa.classifiers.trees.HoeffdingTree"); + + public FlagOption equalWeightOption = new FlagOption("equalWeight", 'w', + "Assigns to all instances a weight equal to 1"); + + /** If set to True, all instances have weight 1; otherwise, the weights are based on the confidence score */ + private boolean equalWeight; + + /** The learner to be self-trained */ + private Classifier learner; + + // Statistics + protected long instancesSeen; + protected long instancesPseudoLabeled; + protected long instancesCorrectPseudoLabeled; + + @Override + public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + this.learner = (Classifier) getPreparedClassOption(learnerOption); + this.equalWeight = equalWeightOption.isSet(); + super.prepareForUseImpl(monitor, repository); + } + + @Override + public double[] getVotesForInstance(Instance inst) { + return learner.getVotesForInstance(inst); + } + + @Override + public void resetLearningImpl() { + this.learner.resetLearning(); + this.instancesSeen = 0; + this.instancesCorrectPseudoLabeled = 0; + this.instancesPseudoLabeled = 0; + } + + @Override + public void trainOnInstanceImpl(Instance inst) 
{ + ++this.instancesSeen; + + if (/*!inst.classIsMasked() &&*/ !inst.classIsMissing()) { + learner.trainOnInstance(inst); + } else { + Instance instCopy = inst.copy(); + int pseudoLabel = Utils.maxIndex(learner.getVotesForInstance(instCopy)); + instCopy.setClassValue(pseudoLabel); + if (!equalWeight) instCopy.setWeight(learner.getConfidenceForPrediction(instCopy, pseudoLabel)); + learner.trainOnInstance(instCopy); + +// if(pseudoLabel == inst.maskedClassValue()) { +// ++this.instancesCorrectPseudoLabeled; +// } + ++this.instancesPseudoLabeled; + } + } + + @Override + public void addInitialWarmupTrainingInstances() { + // TODO: add counter, but this may not be necessary for this class + } + + // TODO: Verify if we need to do something else. + @Override + public int trainOnUnlabeledInstance(Instance instance) { + this.trainOnInstanceImpl(instance); + return -1; + } + + @Override + protected Measurement[] getModelMeasurementsImpl() { + // instances seen * the number of ensemble members + return new Measurement[]{ + new Measurement("#pseudo-labeled", -1), // this.instancesPseudoLabeled), + new Measurement("#correct pseudo-labeled", -1), //this.instancesCorrectPseudoLabeled), + new Measurement("accuracy pseudo-labeled", -1) //this.instancesCorrectPseudoLabeled / (double) this.instancesPseudoLabeled * 100) + }; + } + + + + @Override + public void getModelDescription(StringBuilder out, int indent) {} + + @Override + public boolean isRandomizable() { + return false; + } +} diff --git a/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/AttributeSimilarityCalculator.java b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/AttributeSimilarityCalculator.java new file mode 100644 index 000000000..43be4f75f --- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/AttributeSimilarityCalculator.java @@ -0,0 +1,235 @@ +package moa.classifiers.semisupervised.attributeSimilarity; + +import 
com.yahoo.labs.samoa.instances.Attribute; +import com.yahoo.labs.samoa.instances.Instance; +import moa.core.DoubleVector; + +import java.util.HashMap; +import java.util.Map; + +/** + * An observer that collects statistics for similarity computation of categorical attributes. + * This observer observes the categorical attributes of one dataset. + */ +public abstract class AttributeSimilarityCalculator { + + /** + *

Collection of statistics of one attribute, including:
+ * <ul>
+ *   <li>ID: index of the attribute</li>
+ *   <li>f_k: frequency of a value of an attribute</li>
+ * </ul>
+ */ + class AttributeStatistics extends Attribute { + /** ID of the attribute */ + private int id; + + /** Frequency of the values of the attribute */ + private DoubleVector fk; + + /** The decorated attribute */ + private Attribute attribute; + + /** + * Creates a new collection of statistics of an attribute + * @param id ID of the attribute + */ + AttributeStatistics(int id) { + this.id = id; + this.fk = new DoubleVector(); + } + + AttributeStatistics(Attribute attr, int id) { + this.attribute = attr; + this.id = id; + this.fk = new DoubleVector(); + } + + /** Gets the ID of the attribute */ + int getId() { return this.id; } + + /** Gets the decorated attribute */ + Attribute getAttribute() { return this.attribute; } + + /** + * Gets f_k(x) i.e. the number of times x is a value of the attribute of ID k + * @param value the attribute value + * @return the number of times x is a value of the attribute of ID k + */ + int getFrequencyOfValue(int value) { + return (int) this.fk.getValue(value); + } + + /** + * Updates the frequency of a value + * @param value the value X_k + * @param frequency the frequency + */ + void updateFrequencyOfValue(int value, int frequency) { + this.fk.addToValue((int)value, frequency); + } + } + + /** Size of the dataset (number of instances) */ + protected int N; + + /** Dimension of the dataset (number of attributes) */ + protected int d; + + /** Storing the statistics of each attribute */ + //private AttributeStatistics[] attrStats; + protected Map attributeStats; + + /** A small value to avoid division by 0 */ + protected static double SMALL_VALUE = 1e-5; + + /** Creates a new observer */ + public AttributeSimilarityCalculator() { + this.N = this.d = 0; + this.attributeStats = new HashMap<>(); + } + + /** + * Creates a new observer with a predefined number of attributes + * @param d number of attributes + */ + public AttributeSimilarityCalculator(int d) { + this.d = d; + this.attributeStats = new HashMap<>(); + } + + /** + * Returns 
the size of the dataset + * @return the size of the dataset (number of instances) + */ + public int getSize() { return this.N; } + + /** + * Increases the number of instances seen so far + * @param amount the amount to increase + */ + public void increaseSize(int amount) { this.N += amount; } + + /** + * Returns the dimension size + * @return the dimension size (number of attributes) + */ + public int getDimension() { return this.d; } + + /** + * Specifies the dimension of the dataset + * @param d the dimension + */ + public void setDimension(int d) { this.d = d; } + + /** + * Returns the number of values taken by A_k collected online i.e. n_k + * @param attr the attribute A_k + * @return number of values taken by A_k (n_k) + */ + public int getNumberOfAttributes(Attribute attr) { + if (attributeStats.containsKey(attr)) return attributeStats.get(attr).numValues(); + return 0; + } + + /** + * Gets the frequency of value x of attribute A_k i.e. f_k(x) + * @param attr the attribute + * @param value the value + * @return the number of times x occurs as value of attribute A_k; 0 if attribute k has not been observed so far + */ + public double getFrequencyOfValueByAttribute(Attribute attr, int value) { + if (attributeStats.containsKey(attr)) return attributeStats.get(attr).getFrequencyOfValue(value); + return 0; + } + + /** + * Gets the sample probability of attribute A_k to take the value x in the dataset + * i.e. p_k(x) = f_k(x) / N + * @param attr the attribute A_k + * @param value the value x + * @return the sample probability p_k(x) + */ + public double getSampleProbabilityOfAttributeByValue(Attribute attr, int value) { + return this.getFrequencyOfValueByAttribute(attr, value) / this.N; + } + + /** + * Gets another probability estimate of attribute A_k to take the value x in the dataset + * i.e. 
p_k^2 = f_k(x) * [ f_k(x) - 1 ] / [ N * (N - 1) ] + * @param attr the attribute A_k + * @param value the value x + * @return the sample probability p_k^2(x) + */ + public double getProbabilityEstimateOfAttributeByValue(Attribute attr, int value) { + double fX = getFrequencyOfValueByAttribute(attr, value); + if (N == 1) return 0; + return (fX * (fX - 1)) / (N * (N - 1)); + } + + /** + * Updates the statistics of an attribute A_k, e.g. frequency of the value (f_k) + * @param id ID of the attribute A_k + * @param attr the attribute A_k + * @param value the value of A_k + */ + public void updateAttributeStatistics(int id, Attribute attr, int value) { + if (!attributeStats.containsKey(attr)) { + AttributeStatistics stat = new AttributeStatistics(attr, id); + stat.updateFrequencyOfValue(value, 1); + attributeStats.put(attr, stat); + } else { +// System.out.println("attributeStats.get(attr).updateFrequencyOfValue(value, 1);" + attr + " " + value); + if(value >= 0) + attributeStats.get(attr).updateFrequencyOfValue(value, 1); + else + System.out.println("if(value < 0)"); + } + } + + /** + * Computes the similarity of categorical attributes of two instances X and Y, denoted S(X, Y). + * S(X, Y) = Sum of [w_k * S_k(X_k, Y_k)] for k from 1 to d, + * X_k and Y_k are from A_k (attribute k of the dataset). + * + * Note that X and Y must come from the same dataset, contain the same set of attributes, + * and numeric attributes will not be taken into account. 
+ * @param X instance X + * @param Y instance Y + * @return the similarity of categorical attributes of X and Y + */ + public double computeSimilarityOfInstance(Instance X, Instance Y) { + // for k from 1 to d + double S = 0; + for (int i = 0; i < X.numAttributes(); i++) { + // sanity check + if (!X.attribute(i).equals(Y.attribute(i))) continue; // if X and Y's attributes are not aligned + Attribute Ak = X.attribute(i); + if (Ak.isNumeric() || !attributeStats.containsKey(Ak) || i == X.classIndex()) continue; + // computation + double wk = computeWeightOfAttribute(Ak, X, Y); + double Sk = computePerAttributeSimilarity(Ak, (int)X.value(Ak), (int)Y.value(Ak)); + S += (wk * Sk); + } + return S; + } + + /** + * Computes the per-attribute similarity S_k(X_k, Y_k) between two value X_k and Y_k + * of the attribute A_k. X_k and Y_k must be from A_k. + * + * To be overriden by subclasses. + * @param attr the attribute A_k + * @param X_k the value of X_k + * @param Y_k the value of Y_k + * @return the per-attribute similarity S_k(X_k, Y_k) + */ + public abstract double computePerAttributeSimilarity(Attribute attr, double X_k, double Y_k); + + /** + * Computes the weight w_k of an attribute A_k. To be overriden by subclasses. 
+ * @param attr the attribute A_k + * @return the weight w_k of A_k + */ + public abstract double computeWeightOfAttribute(Attribute attr, Instance X, Instance Y); +} diff --git a/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/EuclideanDistanceSimilarityCalculator.java b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/EuclideanDistanceSimilarityCalculator.java new file mode 100644 index 000000000..3d08cf4d6 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/EuclideanDistanceSimilarityCalculator.java @@ -0,0 +1,23 @@ +package moa.classifiers.semisupervised.attributeSimilarity; + +import com.yahoo.labs.samoa.instances.Attribute; +import com.yahoo.labs.samoa.instances.Instance; + +/** + * Computes the per-attribute similarity of categorical attributes with Euclidean distance, + * i.e. to consider them as numeric attributes + */ +public class EuclideanDistanceSimilarityCalculator extends AttributeSimilarityCalculator { + + @Override + public double computePerAttributeSimilarity(Attribute attr, double X_k, double Y_k) { + // TODO NOT CORRECT !!! To fix!!! 
+ return Math.sqrt((X_k - Y_k) * (X_k - Y_k)); + } + + @Override + public double computeWeightOfAttribute(Attribute attr, Instance X, Instance Y) { + return 1; + } + +} diff --git a/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/GoodAll3SimilarityCalculator.java b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/GoodAll3SimilarityCalculator.java new file mode 100644 index 000000000..0c0119b9a --- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/GoodAll3SimilarityCalculator.java @@ -0,0 +1,23 @@ +package moa.classifiers.semisupervised.attributeSimilarity; + +import com.yahoo.labs.samoa.instances.Attribute; +import com.yahoo.labs.samoa.instances.Instance; + +/** + * Computes the similarity of categorical attributes using GoodAll3: + * if X_k == Y_k: 1 - p_k^2(x) + * else: 0 + */ +public class GoodAll3SimilarityCalculator extends AttributeSimilarityCalculator { + + @Override + public double computePerAttributeSimilarity(Attribute attr, double X_k, double Y_k) { + if (X_k == Y_k) return 1 - getProbabilityEstimateOfAttributeByValue(attr, (int)X_k); + return 0; + } + + @Override + public double computeWeightOfAttribute(Attribute attr, Instance X, Instance Y) { + return 1.0 / (float) d; + } +} diff --git a/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/IgnoreSimilarityCalculator.java b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/IgnoreSimilarityCalculator.java new file mode 100644 index 000000000..2c92e6037 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/IgnoreSimilarityCalculator.java @@ -0,0 +1,20 @@ +package moa.classifiers.semisupervised.attributeSimilarity; + +import com.yahoo.labs.samoa.instances.Attribute; +import com.yahoo.labs.samoa.instances.Instance; + +/** + * Does nothing, just ignores the categorical attributes + */ +public class IgnoreSimilarityCalculator extends AttributeSimilarityCalculator { 
+ + @Override + public double computePerAttributeSimilarity(Attribute attr, double X_k, double Y_k) { + return 0; + } + + @Override + public double computeWeightOfAttribute(Attribute attr, Instance X, Instance Y) { + return 0; + } +} diff --git a/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/InverseOccurrenceFrequencySimilarityCalculator.java b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/InverseOccurrenceFrequencySimilarityCalculator.java new file mode 100644 index 000000000..a45a5c4d2 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/InverseOccurrenceFrequencySimilarityCalculator.java @@ -0,0 +1,24 @@ +package moa.classifiers.semisupervised.attributeSimilarity; + +import com.yahoo.labs.samoa.instances.Attribute; +import com.yahoo.labs.samoa.instances.Instance; + +/** + * Computes the similarity between categorical attributes using Inverse Occurrence Frequency (IOF) + */ +public class InverseOccurrenceFrequencySimilarityCalculator extends AttributeSimilarityCalculator { + @Override + public double computePerAttributeSimilarity(Attribute attr, double X_k, double Y_k) { + if (X_k == Y_k) return 1.0; + double fX = Math.max(attributeStats.get(attr).getFrequencyOfValue((int)X_k), SMALL_VALUE); + double fY = Math.max(attributeStats.get(attr).getFrequencyOfValue((int)Y_k), SMALL_VALUE); + double logX = fX > 0 ? Math.log(fX) : 0.0; + double logY = fY > 0 ? 
Math.log(fY) : 0.0; + return 1 / (1 + logX * logY); + } + + @Override + public double computeWeightOfAttribute(Attribute attr, Instance X, Instance Y) { + return 0; + } +} diff --git a/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/LinSimilarityCalculator.java b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/LinSimilarityCalculator.java new file mode 100644 index 000000000..35bd68b66 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/LinSimilarityCalculator.java @@ -0,0 +1,32 @@ +package moa.classifiers.semisupervised.attributeSimilarity; + +import com.yahoo.labs.samoa.instances.Attribute; +import com.yahoo.labs.samoa.instances.Instance; + +/** + * Computes the similarity of categorical attributes using Lin formula: + * if X_k == Y_k: S_k = 2 * log[p_k(X_k)] + * else: S_k = 2 * log[p_k(X_k) + p_k(Y_k)] + */ +public class LinSimilarityCalculator extends AttributeSimilarityCalculator { + + @Override + public double computePerAttributeSimilarity(Attribute attr, double X_k, double Y_k) { + double pX = getSampleProbabilityOfAttributeByValue(attr, (int)X_k); + double pY = getSampleProbabilityOfAttributeByValue(attr, (int)Y_k); + if (X_k == Y_k) return 2.0 * Math.log(pX); + return 2.0 * Math.log(pX + pY); + } + + @Override + public double computeWeightOfAttribute(Attribute attr, Instance X, Instance Y) { + double deno = 0; + for (int i = 0; i < d; i++) { + double pX = getSampleProbabilityOfAttributeByValue(attr, (int)X.value(i)); + double pY = getSampleProbabilityOfAttributeByValue(attr, (int)Y.value(i)); + deno += Math.log(pX) + Math.log(pY); + } + if (deno == 0) return 1.0; + return 1.0 / deno; + } +} diff --git a/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/OccurrenceFrequencySimilarityCalculator.java b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/OccurrenceFrequencySimilarityCalculator.java new file mode 100644 index 000000000..b23b53859 
--- /dev/null +++ b/moa/src/main/java/moa/classifiers/semisupervised/attributeSimilarity/OccurrenceFrequencySimilarityCalculator.java @@ -0,0 +1,26 @@ +package moa.classifiers.semisupervised.attributeSimilarity; + +import com.yahoo.labs.samoa.instances.Attribute; +import com.yahoo.labs.samoa.instances.Instance; + +/** + * Computes the attribute similarity using Occurrence Frequency (OF): + * if X_k == Y_k: S_k(X_k, Y_k) = 1 + * else: S_k(X_k, Y_k) = 1 / (1 + log(N / f_k(X_k)) * log(N / f_k(Y_k))) + */ +public class OccurrenceFrequencySimilarityCalculator extends AttributeSimilarityCalculator { + + @Override + public double computePerAttributeSimilarity(Attribute attr, double X_k, double Y_k) { + if (X_k == Y_k) return 1; + if (attributeStats.get(attr) == null) return SMALL_VALUE; + double fX = Math.max(attributeStats.get(attr).getFrequencyOfValue((int)X_k), SMALL_VALUE); + double fY = Math.max(attributeStats.get(attr).getFrequencyOfValue((int)Y_k), SMALL_VALUE); + return 1.0 / (1.0 + Math.log(N / fX) * Math.log(N / fY)); + } + + @Override + public double computeWeightOfAttribute(Attribute attr, Instance X, Instance Y) { + return 1.0 / (double) d; + } +} diff --git a/moa/src/main/java/moa/clusterers/clustream/Clustream.java b/moa/src/main/java/moa/clusterers/clustream/Clustream.java index 58e0428bd..93e221c7a 100644 --- a/moa/src/main/java/moa/clusterers/clustream/Clustream.java +++ b/moa/src/main/java/moa/clusterers/clustream/Clustream.java @@ -14,8 +14,8 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
- * - * + * + * */ package moa.clusterers.clustream; @@ -98,7 +98,9 @@ public void trainOnInstanceImpl(Instance instance) { // Clustering kmeans_clustering = kMeans(k, buffer); for ( int i = 0; i < kmeans_clustering.size(); i++ ) { - kernels[i] = new ClustreamKernel( new DenseInstance(1.0,centers[i].getCenter()), dim, timestamp, t, m ); + Instance newInstance = new DenseInstance(1.0,centers[i].getCenter()); + newInstance.setDataset(instance.dataset()); + kernels[i] = new ClustreamKernel(newInstance, dim, timestamp, t, m ); } buffer.clear(); @@ -111,7 +113,7 @@ public void trainOnInstanceImpl(Instance instance) { double minDistance = Double.MAX_VALUE; for ( int i = 0; i < kernels.length; i++ ) { //System.out.println(i+" "+kernels[i].getWeight()+" "+kernels[i].getDeviation()); - double distance = distance(instance.toDoubleArray(), kernels[i].getCenter() ); + double distance = distanceIgnoreNaN(instance.toDoubleArray(), kernels[i].getCenter() ); if ( distance < minDistance ) { closestKernel = kernels[i]; minDistance = distance; @@ -213,6 +215,26 @@ private static double distance(double[] pointA, double [] pointB){ return Math.sqrt(distance); } + /*** + * This function avoids the undesirable situation where the whole distance becomes NaN if one of the attributes + * is NaN. + * (SSL) This was observed when calculating the distance between an instance without the class label and a center + * which was updated using the class label. + * @param pointA + * @param pointB + * @return + */ + public static double distanceIgnoreNaN(double[] pointA, double [] pointB){ + double distance = 0.0; + for (int i = 0; i < pointA.length; i++) { + if(!(Double.isNaN(pointA[i]) || Double.isNaN(pointB[i]))) { + double d = pointA[i] - pointB[i]; + distance += d * d; + } + } + return Math.sqrt(distance); + } + //wrapper... 
we need to rewrite kmeans to points, not clusters, doesnt make sense anymore // public static Clustering kMeans( int k, ArrayList points, int dim ) { // ArrayList cl = new ArrayList(); diff --git a/moa/src/main/java/moa/clusterers/clustream/ClustreamKernel.java b/moa/src/main/java/moa/clusterers/clustream/ClustreamKernel.java index d4f901ba4..609ad8fdb 100644 --- a/moa/src/main/java/moa/clusterers/clustream/ClustreamKernel.java +++ b/moa/src/main/java/moa/clusterers/clustream/ClustreamKernel.java @@ -14,8 +14,8 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * - * + * + * */ package moa.clusterers.clustream; @@ -25,9 +25,9 @@ import com.yahoo.labs.samoa.instances.Instance; public class ClustreamKernel extends CFCluster { - private static final long serialVersionUID = 1L; + private static final long serialVersionUID = 1L; - private final static double EPSILON = 0.00005; + private final static double EPSILON = 0.00005; public static final double MIN_VARIANCE = 1e-50; protected double LST; @@ -36,66 +36,111 @@ public class ClustreamKernel extends CFCluster { int m; double t; + public double[] classObserver; + + public static int ID_GENERATOR = 0; - public ClustreamKernel( Instance instance, int dimensions, long timestamp , double t, int m) { + public ClustreamKernel(Instance instance, int dimensions, long timestamp , double t, int m) { super(instance, dimensions); + +// Avoid situations where the instance header hasn't been defined and runtime errors. + if(instance.dataset() != null) { + this.classObserver = new double[instance.numClasses()]; +// instance.numAttributes() <= instance.classIndex() -> edge case where the class index is equal the +// number of attributes (i.e. there is no class value in the attributes array). 
+ if (instance.numAttributes() > instance.classIndex() && + !instance.classIsMissing() && + instance.classValue() >= 0 && + instance.classValue() < instance.numClasses()) { + this.classObserver[(int) instance.classValue()]++; + } + } + this.setId(ID_GENERATOR++); this.t = t; this.m = m; this.LST = timestamp; - this.SST = timestamp*timestamp; + this.SST = timestamp*timestamp; } public ClustreamKernel( ClustreamKernel cluster, double t, int m ) { super(cluster); + this.setId(ID_GENERATOR++); this.t = t; this.m = m; this.LST = cluster.LST; this.SST = cluster.SST; + this.classObserver = cluster.classObserver; } public void insert( Instance instance, long timestamp ) { - N++; - LST += timestamp; - SST += timestamp*timestamp; - - for ( int i = 0; i < instance.numValues(); i++ ) { - LS[i] += instance.value(i); - SS[i] += instance.value(i)*instance.value(i); - } + if(this.classObserver == null) + this.classObserver = new double[instance.numClasses()]; + if(!instance.classIsMissing() && + instance.classValue() >= 0 && + instance.classValue() < instance.numClasses()) { + this.classObserver[(int)instance.classValue()]++; + } + N++; + LST += timestamp; + SST += timestamp*timestamp; + + for ( int i = 0; i < instance.numValues(); i++ ) { + LS[i] += instance.value(i); + SS[i] += instance.value(i)*instance.value(i); + } } @Override public void add( CFCluster other2 ) { ClustreamKernel other = (ClustreamKernel) other2; - assert( other.LS.length == this.LS.length ); - this.N += other.N; - this.LST += other.LST; - this.SST += other.SST; - - for ( int i = 0; i < LS.length; i++ ) { - this.LS[i] += other.LS[i]; - this.SS[i] += other.SS[i]; - } + assert( other.LS.length == this.LS.length ); + this.N += other.N; + this.LST += other.LST; + this.SST += other.SST; + this.classObserver = sumClassObservers(other.classObserver, this.classObserver); + + for ( int i = 0; i < LS.length; i++ ) { + this.LS[i] += other.LS[i]; + this.SS[i] += other.SS[i]; + } } + private double[] 
sumClassObservers(double[] A, double[] B) { + double[] result = null; + if (A != null && B != null) { + result = new double[A.length]; + if(A.length == B.length) + for(int i = 0 ; i < A.length ; ++i) + result[i] += A[i] + B[i]; + } + return result; + } + +// @Override +// public void add( CFCluster other2, long timestamp) { +// this.add(other2); +// // accumulate the count +// this.accumulateWeight(other2, timestamp); +// } + public double getRelevanceStamp() { - if ( N < 2*m ) - return getMuTime(); - - return getMuTime() + getSigmaTime() * getQuantile( ((double)m)/(2*N) ); + if ( N < 2*m ) + return getMuTime(); + + return getMuTime() + getSigmaTime() * getQuantile( ((double)m)/(2*N) ); } private double getMuTime() { - return LST / N; + return LST / N; } private double getSigmaTime() { - return Math.sqrt(SST/N - (LST/N)*(LST/N)); + return Math.sqrt(SST/N - (LST/N)*(LST/N)); } private double getQuantile( double z ) { - assert( z >= 0 && z <= 1 ); - return Math.sqrt( 2 ) * inverseError( 2*z - 1 ); + assert( z >= 0 && z <= 1 ); + return Math.sqrt( 2 ) * inverseError( 2*z - 1 ); } @Override @@ -187,7 +232,7 @@ private double[] getVarianceVector() { } } else{ - + } } return res; @@ -223,7 +268,7 @@ private double calcNormalizedDistance(double[] point) { return Math.sqrt(res); } - /** + /** * Approximates the inverse error function. Clustream needs this. 
* @param x */ @@ -266,7 +311,7 @@ protected void getClusterSpecificInfo(ArrayList infoTitle, ArrayList windowedResults; public double[] cumulativeResults; - public ArrayList targets; - public ArrayList predictions; public HashMap otherMeasurements; public PrequentialResult(ArrayList windowedResults, double[] cumulativeResults) { this.windowedResults = windowedResults; this.cumulativeResults = cumulativeResults; - this.targets = null; - this.predictions = null; - } - - public PrequentialResult(ArrayList windowedResults, double[] cumulativeResults, - ArrayList targets, ArrayList predictions) { - this.windowedResults = windowedResults; - this.cumulativeResults = cumulativeResults; - this.targets = targets; - this.predictions = predictions; } - /*** - * This constructor is useful to store metrics beyond the evaluation metrics available through the evaluators. - * @param windowedResults - * @param cumulativeResults - * @param otherMeasurements - */ public PrequentialResult(ArrayList windowedResults, double[] cumulativeResults, HashMap otherMeasurements) { this(windowedResults, cumulativeResults); @@ -93,29 +52,23 @@ public PrequentialResult(ArrayList windowedResults, double[] cumulativ * @param windowedEvaluator * @param maxInstances * @param windowSize - * @return PrequentialResult is a custom class that holds the respective results from the execution + * @return the return has to be an ArrayList because we don't know ahead of time how many windows will be produced */ public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Learner learner, LearningPerformanceEvaluator basicEvaluator, LearningPerformanceEvaluator windowedEvaluator, - long maxInstances, long windowSize, - boolean storeY, boolean storePredictions) { + long maxInstances, long windowSize) { int instancesProcessed = 0; if (!stream.hasMoreInstances()) stream.restart(); ArrayList windowed_results = new ArrayList<>(); - ArrayList targetValues = new ArrayList<>(); - ArrayList predictions = 
new ArrayList<>(); - while (stream.hasMoreInstances() && (maxInstances == -1 || instancesProcessed < maxInstances)) { - Example instance = stream.nextInstance(); - if (storeY) - targetValues.add(instance.getData().classValue()); + Example instance = stream.nextInstance(); double[] prediction = learner.getVotesForInstance(instance); if (basicEvaluator != null) @@ -123,9 +76,6 @@ public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Lear if (windowedEvaluator != null) windowedEvaluator.addResult(instance, prediction); - if (storePredictions) - predictions.add(prediction.length == 0? 0 : prediction[0]); - learner.trainOnInstance(instance); instancesProcessed++; @@ -156,62 +106,280 @@ public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Lear for (int i = 0; i < cumulative_results.length; ++i) cumulative_results[i] = measurements[i].getValue(); } - if (!storePredictions && !storeY) - return new PrequentialResult(windowed_results, cumulative_results); - else - return new PrequentialResult(windowed_results, cumulative_results, targetValues, predictions); + + return new PrequentialResult(windowed_results, cumulative_results); } + public static PrequentialResult PrequentialSSLEvaluation(ExampleStream stream, Learner learner, + LearningPerformanceEvaluator basicEvaluator, + LearningPerformanceEvaluator windowedEvaluator, + long maxInstances, + long windowSize, + long initialWindowSize, + long delayLength, + double labelProbability, + int randomSeed, + boolean debugPseudoLabels) { +// int delayLength = this.delayLengthOption.getValue(); +// double labelProbability = this.labelProbabilityOption.getValue(); + + RandomGenerator taskRandom = new MersenneTwister(randomSeed); +// ExampleStream stream = (ExampleStream) getPreparedClassOption(this.streamOption); +// Learner learner = getLearner(stream); - /*** - * The following code can be used to provide examples of how to use the class. 
- * In the future, some of these examples can be turned into tests. - * @param args - */ - public static void main(String[] args) { - examplePrequentialEvaluation_edge_cases1(); - examplePrequentialEvaluation_edge_cases2(); - examplePrequentialEvaluation_edge_cases3(); - examplePrequentialEvaluation_edge_cases4(); - examplePrequentialEvaluation_SampleFrequency_TestThenTrain(); - examplePrequentialRegressionEvaluation(); - examplePrequentialEvaluation(); - exampleTestThenTrainEvaluation(); - exampleWindowedEvaluation(); - - // Run time efficiency evaluation examples - StreamingRandomPatches srp10 = new StreamingRandomPatches(); - srp10.getOptions().setViaCLIString("-s 10"); // 10 learners - srp10.setRandomSeed(5); - srp10.prepareForUse(); - - StreamingRandomPatches srp100 = new StreamingRandomPatches(); - srp100.getOptions().setViaCLIString("-s 100"); // 100 learners - srp100.setRandomSeed(5); - srp100.prepareForUse(); - - int maxInstances = 100000; - examplePrequentialEfficiency(srp10, maxInstances); - examplePrequentialEfficiency(srp100, maxInstances); + int instancesProcessed = 0; + int numCorrectPseudoLabeled = 0; + int numUnlabeledData = 0; + int numInstancesTested = 0; + + if (!stream.hasMoreInstances()) + stream.restart(); + + ArrayList windowed_results = new ArrayList<>(); + + HashMap other_measures = new HashMap<>(); + + // The buffer is a list of tuples. The first element is the index when + // it should be emitted. The second element is the instance itself. 
+ List> delayBuffer = new ArrayList>(); + + while (stream.hasMoreInstances() && + (maxInstances == -1 || instancesProcessed < maxInstances)) { + + // TRAIN on delayed instances + while (delayBuffer.size() > 0 + && delayBuffer.get(0).getKey() == instancesProcessed) { + Example delayedExample = delayBuffer.remove(0).getValue(); +// System.out.println("[TRAIN][DELAY] "+delayedExample.getData().toString()); + learner.trainOnInstance(delayedExample); + } + + Example instance = stream.nextInstance(); + Example unlabeledExample = instance.copy(); + int trueClass = (int) ((Instance) instance.getData()).classValue(); + + // In case it is set, then the label is not removed. We want to pass the + // labelled data to the learner even in trainOnUnlabeled data to generate statistics such as number + // of correctly pseudo-labeled instances. + if (!debugPseudoLabels) { + // Remove the label of the unlabeledExample indirectly through + // unlabeledInstanceData. + Instance __instance = (Instance) unlabeledExample.getData(); + __instance.setMissing(__instance.classIndex()); + } + + // WARMUP + // Train on the initial instances. These are not used for testing! + if (instancesProcessed < initialWindowSize) { +// if (learner instanceof SemiSupervisedLearner) +// ((SemiSupervisedLearner) learner).addInitialWarmupTrainingInstances(); +// System.out.println("[TRAIN][INITIAL_WINDOW] "+instance.getData().toString()); + learner.trainOnInstance(instance); + instancesProcessed++; + continue; + } + + Boolean is_labeled = labelProbability > taskRandom.nextDouble(); + if (!is_labeled) { + numUnlabeledData++; + } + + // TEST + // Obtain the prediction for the testInst (i.e. 
no label) +// System.out.println("[TEST] " + unlabeledExample.getData().toString()); + double[] prediction = learner.getVotesForInstance(unlabeledExample); + numInstancesTested++; + + if (basicEvaluator != null) + basicEvaluator.addResult(instance, prediction); + if (windowedEvaluator != null) + windowedEvaluator.addResult(instance, prediction); + + int pseudoLabel = -1; + // TRAIN + if (is_labeled && delayLength >= 0) { + // The instance will be labeled but has been delayed + if (learner instanceof SemiSupervisedLearner) { +// System.out.println("[TRAIN_UNLABELED][DELAYED] " + unlabeledExample.getData().toString()); + pseudoLabel = ((SemiSupervisedLearner) learner).trainOnUnlabeledInstance((Instance) unlabeledExample.getData()); + } + delayBuffer.add(new MutablePair(1 + instancesProcessed + delayLength, instance)); + } else if (is_labeled) { +// System.out.println("[TRAIN] " + instance.getData().toString()); + // The instance will be labeled and is not delayed e.g delayLength = -1 + learner.trainOnInstance(instance); + } else { + // The instance will never be labeled + if (learner instanceof SemiSupervisedLearner) { +// System.out.println("[TRAIN_UNLABELED][IMMEDIATE] " + unlabeledExample.getData().toString()); + pseudoLabel = ((SemiSupervisedLearner) learner).trainOnUnlabeledInstance((Instance) unlabeledExample.getData()); + } + } + if(trueClass == pseudoLabel) + numCorrectPseudoLabeled++; + + instancesProcessed++; + + if (windowedEvaluator != null) + if (instancesProcessed % windowSize == 0) { + Measurement[] measurements = windowedEvaluator.getPerformanceMeasurements(); + double[] values = new double[measurements.length]; + for (int i = 0; i < values.length; ++i) + values[i] = measurements[i].getValue(); + windowed_results.add(values); + } + } + if (windowedEvaluator != null) + if (instancesProcessed % windowSize != 0) { + Measurement[] measurements = windowedEvaluator.getPerformanceMeasurements(); + double[] values = new double[measurements.length]; + for (int 
i = 0; i < values.length; ++i) + values[i] = measurements[i].getValue(); + windowed_results.add(values); + } + + double[] cumulative_results = null; + + if (basicEvaluator != null) { + Measurement[] measurements = basicEvaluator.getPerformanceMeasurements(); + cumulative_results = new double[measurements.length]; + for (int i = 0; i < cumulative_results.length; ++i) + cumulative_results[i] = measurements[i].getValue(); + } + + // TODO: Add this measures in a windowed way. + other_measures.put("num_unlabeled_instances", (double) numUnlabeledData); + other_measures.put("num_correct_pseudo_labeled", (double) numCorrectPseudoLabeled); + other_measures.put("num_instances_tested", (double) numInstancesTested); + other_measures.put("pseudo_label_accuracy", (double) numCorrectPseudoLabeled/numInstancesTested); + return new PrequentialResult(windowed_results, cumulative_results, other_measures); + } + + /******************************************************************************************************************/ + /******************************************************************************************************************/ + /***************************************** TESTS ******************************************************************/ + /******************************************************************************************************************/ + /******************************************************************************************************************/ + + private static void testPrequentialSSL(String file_path, Learner learner, + long maxInstances, + long windowSize, + long initialWindowSize, + long delayLength, + double labelProbability) { + System.out.println( + "maxInstances: " + maxInstances + ", " + + "windowSize: " + windowSize + ", " + + "initialWindowSize: " + initialWindowSize + ", " + + "delayLength: " + delayLength + ", " + + "labelProbability: " + labelProbability + ); + + // Record the start time + long startTime = 
System.currentTimeMillis(); + + ArffFileStream stream = new ArffFileStream(file_path, -1); + stream.prepareForUse(); + + BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator(); + basic_evaluator.recallPerClassOption.setValue(true); + basic_evaluator.prepareForUse(); + + WindowClassificationPerformanceEvaluator windowed_evaluator = new WindowClassificationPerformanceEvaluator(); + windowed_evaluator.widthOption.setValue((int) windowSize); + windowed_evaluator.prepareForUse(); + + PrequentialResult result = PrequentialSSLEvaluation(stream, learner, + basic_evaluator, + windowed_evaluator, + maxInstances, + windowSize, + initialWindowSize, + delayLength, + labelProbability, 1, true); + + // Record the end time + long endTime = System.currentTimeMillis(); + + // Calculate the elapsed time in milliseconds + long elapsedTime = endTime - startTime; + + // Print the elapsed time + System.out.println("Elapsed Time: " + elapsedTime / 1000 + " seconds"); + System.out.println("Number of unlabeled instances: " + result.otherMeasurements.get("num_unlabeled_instances")); + + System.out.println("\tBasic performance"); + for (int i = 0; i < result.cumulativeResults.length; ++i) + System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " + result.cumulativeResults[i]); + + System.out.println("\tWindowed performance"); + for (int j = 0; j < result.windowedResults.size(); ++j) { + System.out.print("Window: " + j + ", "); + for (int i = 0; i < 2; ++i) // results.get(results.size()-1).length; ++i) + System.out.println(windowed_evaluator.getPerformanceMeasurements()[i].getName() + ": " + result.windowedResults.get(j)[i]); + } } + public static void main(String[] args) { + String hyper_arff = "/Users/gomeshe/Desktop/data/Hyper100k.arff"; + String debug_arff = "/Users/gomeshe/Desktop/data/debug_prequential_SSL.arff"; + String ELEC_arff = 
"/Users/gomeshe/Dropbox/ciencia_computacao/lecturer/research/ssl_disagreement/datasets/ELEC/elecNormNew.arff"; + + NaiveBayes learner = new NaiveBayes(); + learner.prepareForUse(); + +// testPrequentialSSL(debug_arff, learner, 100, 10, 0, 0, 1.0); // OK +// testPrequentialSSL(debug_arff, learner, 100, 10, 1, 0, 1.0); //OK +// testPrequentialSSL(debug_arff, learner, 10, 10, 5, 0, 1.0); // OK +// testPrequentialSSL(debug_arff, learner, 10, 10, -1, 1, 1.0); // OK +// testPrequentialSSL(debug_arff, learner, 20, 10, -1, 10, 1.0); // OK +// testPrequentialSSL(debug_arff, learner, 20, 10, -1, 2, 0.5); // OK +// testPrequentialSSL(debug_arff, learner, 100, 10, 50, 2, 0.0); // OK +// testPrequentialSSL(debug_arff, learner, 100, 10, 0, 90, 1.0); // OK +// testPrequentialSSL(debug_arff, learner, 100, 10, 0, -1, 0.5); // OK + +// testPrequentialSSL(hyper_arff, learner, -1, 1000, -1, -1, 1.0); +// testPrequentialSSL(hyper_arff, learner, -1, 1000, -1, -1, 0.5); // OK + +// testPrequentialSSL(hyper_arff, learner, -1, 1000, 1000, -1, 0.5); + + ClusterAndLabelClassifier ssl_learner = new ClusterAndLabelClassifier(); + ssl_learner.prepareForUse(); + + testPrequentialSSL(ELEC_arff, ssl_learner, 10000, 1000, -1, -1, 0.01); + +// testWindowedEvaluation(); +// testTestThenTrainEvaluation(); +// testPrequentialEvaluation(); +// +// StreamingRandomPatches learner = new StreamingRandomPatches(); +// learner.getOptions().setViaCLIString("-s 100"); // 10 learners +//// learner.setRandomSeed(5); +// learner.prepareForUse(); +// testPrequentialEfficiency1(learner); + +// testPrequentialEvaluation_edge_cases1(); +// testPrequentialEvaluation_edge_cases2(); +// testPrequentialEvaluation_edge_cases3(); +// testPrequentialEvaluation_edge_cases4(); +// testPrequentialEvaluation_SampleFrequency_TestThenTrain(); + +// testPrequentialRegressionEvaluation(); + } - private static void examplePrequentialEfficiency(Learner learner, int maxInstances) { - System.out.println("Assessing efficiency for " + 
learner.getCLICreationString(learner.getClass()) + - " maxInstances: " + maxInstances); + private static void testPrequentialEfficiency1(Learner learner) { // Record the start time long startTime = System.currentTimeMillis(); - AgrawalGenerator stream = new AgrawalGenerator(); + ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1); stream.prepareForUse(); BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator(); basic_evaluator.recallPerClassOption.setValue(true); basic_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, null, - maxInstances, 1, false, false); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, null, 100000, 1); // Record the end time long endTime = System.currentTimeMillis(); @@ -227,18 +395,25 @@ private static void examplePrequentialEfficiency(Learner learner, int maxInstanc System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.cumulativeResults[i]); } - private static void examplePrequentialEvaluation_edge_cases1() { + private static void testPrequentialEvaluation_edge_cases1() { // Record the start time long startTime = System.currentTimeMillis(); NaiveBayes learner = new NaiveBayes(); learner.prepareForUse(); - AgrawalGenerator stream = new AgrawalGenerator(); + ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1); stream.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, null, null, - 100000, 1000, false, false); +// BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator(); +// basic_evaluator.recallPerClassOption.setValue(true); +// basic_evaluator.prepareForUse(); +// +// WindowClassificationPerformanceEvaluator windowed_evaluator = new WindowClassificationPerformanceEvaluator(); +// 
windowed_evaluator.widthOption.setValue(1000); +// windowed_evaluator.prepareForUse(); + + PrequentialResult results = PrequentialEvaluation(stream, learner, null, null, 100000, 1000); // Record the end time long endTime = System.currentTimeMillis(); @@ -248,16 +423,28 @@ private static void examplePrequentialEvaluation_edge_cases1() { // Print the elapsed time System.out.println("Elapsed Time: " + elapsedTime / 1000 + " seconds"); + +// System.out.println("\tBasic performance"); +// for (int i = 0; i < results.basicResults.length; ++i) +// System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.basicResults[i]); + +// System.out.println("\tWindowed performance"); +// for (int j = 0; j < results.windowedResults.size(); ++j) { +// System.out.println("\t" + j); +// for (int i = 0; i < 2; ++i) // results.get(results.size()-1).length; ++i) +// System.out.println(windowed_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.windowedResults.get(j)[i]); +// } } - private static void examplePrequentialEvaluation_edge_cases2() { + + private static void testPrequentialEvaluation_edge_cases2() { // Record the start time long startTime = System.currentTimeMillis(); NaiveBayes learner = new NaiveBayes(); learner.prepareForUse(); - AgrawalGenerator stream = new AgrawalGenerator(); + ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1); stream.prepareForUse(); BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator(); @@ -268,8 +455,7 @@ private static void examplePrequentialEvaluation_edge_cases2() { windowed_evaluator.widthOption.setValue(1000); windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, - 1000, 10000, false, false); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 1000, 10000); // 
Record the end time long endTime = System.currentTimeMillis(); @@ -292,14 +478,14 @@ private static void examplePrequentialEvaluation_edge_cases2() { } } - private static void examplePrequentialEvaluation_edge_cases3() { + private static void testPrequentialEvaluation_edge_cases3() { // Record the start time long startTime = System.currentTimeMillis(); NaiveBayes learner = new NaiveBayes(); learner.prepareForUse(); - AgrawalGenerator stream = new AgrawalGenerator(); + ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1); stream.prepareForUse(); BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator(); @@ -310,8 +496,7 @@ private static void examplePrequentialEvaluation_edge_cases3() { windowed_evaluator.widthOption.setValue(1000); windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, - 10, 1, false, false); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 10, 1); // Record the end time long endTime = System.currentTimeMillis(); @@ -324,26 +509,24 @@ private static void examplePrequentialEvaluation_edge_cases3() { System.out.println("\tBasic performance"); for (int i = 0; i < results.cumulativeResults.length; ++i) - System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " + - results.cumulativeResults[i]); + System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.cumulativeResults[i]); System.out.println("\tWindowed performance"); for (int j = 0; j < results.windowedResults.size(); ++j) { System.out.println("\t" + j); for (int i = 0; i < 2; ++i) // results.get(results.size()-1).length; ++i) - System.out.println(windowed_evaluator.getPerformanceMeasurements()[i].getName() + ": " + - results.windowedResults.get(j)[i]); + 
System.out.println(windowed_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.windowedResults.get(j)[i]); } } - private static void examplePrequentialEvaluation_edge_cases4() { + private static void testPrequentialEvaluation_edge_cases4() { // Record the start time long startTime = System.currentTimeMillis(); NaiveBayes learner = new NaiveBayes(); learner.prepareForUse(); - AgrawalGenerator stream = new AgrawalGenerator(); + ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1); stream.prepareForUse(); BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator(); @@ -354,8 +537,7 @@ private static void examplePrequentialEvaluation_edge_cases4() { windowed_evaluator.widthOption.setValue(10000); windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, - 100000, 10000, false, false); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, -1, 10000); // Record the end time long endTime = System.currentTimeMillis(); @@ -378,22 +560,26 @@ private static void examplePrequentialEvaluation_edge_cases4() { } } - private static void examplePrequentialEvaluation_SampleFrequency_TestThenTrain() { + + private static void testPrequentialEvaluation_SampleFrequency_TestThenTrain() { // Record the start time long startTime = System.currentTimeMillis(); NaiveBayes learner = new NaiveBayes(); learner.prepareForUse(); - AgrawalGenerator stream = new AgrawalGenerator(); + ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1); stream.prepareForUse(); BasicClassificationPerformanceEvaluator basic_evaluator = new BasicClassificationPerformanceEvaluator(); basic_evaluator.recallPerClassOption.setValue(true); basic_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, null, 
basic_evaluator, - 100000, 10000, false, false); +// WindowClassificationPerformanceEvaluator windowed_evaluator = new WindowClassificationPerformanceEvaluator(); +// windowed_evaluator.widthOption.setValue(10000); +// windowed_evaluator.prepareForUse(); + + PrequentialResult results = PrequentialEvaluation(stream, learner, null, basic_evaluator, -1, 10000); // Record the end time long endTime = System.currentTimeMillis(); @@ -404,6 +590,10 @@ private static void examplePrequentialEvaluation_SampleFrequency_TestThenTrain() // Print the elapsed time System.out.println("Elapsed Time: " + elapsedTime / 1000 + " seconds"); +// System.out.println("\tBasic performance"); +// for (int i = 0; i < results.basicResults.length; ++i) +// System.out.println(basic_evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.basicResults[i]); + System.out.println("\tWindowed performance"); for (int j = 0; j < results.windowedResults.size(); ++j) { System.out.println("\t" + j); @@ -412,23 +602,26 @@ private static void examplePrequentialEvaluation_SampleFrequency_TestThenTrain() } } - private static void examplePrequentialRegressionEvaluation() { + + private static void testPrequentialRegressionEvaluation() { // Record the start time long startTime = System.currentTimeMillis(); FIMTDD learner = new FIMTDD(); +// learner.getOptions().setViaCLIString("-s 10"); // 10 learners +// learner.setRandomSeed(5); learner.prepareForUse(); - HyperplaneGenerator stream = new HyperplaneGenerator(); + ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/metrotraffic_with_nominals.arff", -1); stream.prepareForUse(); BasicRegressionPerformanceEvaluator basic_evaluator = new BasicRegressionPerformanceEvaluator(); WindowRegressionPerformanceEvaluator windowed_evaluator = new WindowRegressionPerformanceEvaluator(); windowed_evaluator.widthOption.setValue(1000); +// windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, 
basic_evaluator, windowed_evaluator, - 10000, 1000, false, false); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 100000, 1000); // Record the end time long endTime = System.currentTimeMillis(); @@ -451,7 +644,7 @@ private static void examplePrequentialRegressionEvaluation() { } } - private static void examplePrequentialEvaluation() { + private static void testPrequentialEvaluation() { // Record the start time long startTime = System.currentTimeMillis(); @@ -471,7 +664,7 @@ private static void examplePrequentialEvaluation() { windowed_evaluator.widthOption.setValue(1000); windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 100000, 1000, false, false); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 100000, 1000); // Record the end time long endTime = System.currentTimeMillis(); @@ -494,22 +687,23 @@ private static void examplePrequentialEvaluation() { } } - private static void exampleTestThenTrainEvaluation() { + private static void testTestThenTrainEvaluation() { // Record the start time long startTime = System.currentTimeMillis(); NaiveBayes learner = new NaiveBayes(); +// learner.getOptions().setViaCLIString("-s 10"); // 10 learners +// learner.setRandomSeed(5); learner.prepareForUse(); - HyperplaneGenerator stream = new HyperplaneGenerator(); + ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1); stream.prepareForUse(); BasicClassificationPerformanceEvaluator evaluator = new BasicClassificationPerformanceEvaluator(); evaluator.recallPerClassOption.setValue(true); evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, evaluator, null, - 100000, 100000, false, false); + PrequentialResult results = PrequentialEvaluation(stream, learner, evaluator, null, 100000, 100000); // Record 
the end time long endTime = System.currentTimeMillis(); @@ -521,18 +715,19 @@ private static void exampleTestThenTrainEvaluation() { System.out.println("Elapsed Time: " + elapsedTime / 1000 + " seconds"); for (int i = 0; i < results.cumulativeResults.length; ++i) - System.out.println(evaluator.getPerformanceMeasurements()[i].getName() + ": " + - results.cumulativeResults[i]); + System.out.println(evaluator.getPerformanceMeasurements()[i].getName() + ": " + results.cumulativeResults[i]); } - private static void exampleWindowedEvaluation() { + private static void testWindowedEvaluation() { // Record the start time long startTime = System.currentTimeMillis(); NaiveBayes learner = new NaiveBayes(); +// learner.getOptions().setViaCLIString("-s 10"); // 10 learners +// learner.setRandomSeed(5); learner.prepareForUse(); - HyperplaneGenerator stream = new HyperplaneGenerator(); + ArffFileStream stream = new ArffFileStream("/Users/gomeshe/Desktop/data/Hyper100k.arff", -1); stream.prepareForUse(); WindowClassificationPerformanceEvaluator evaluator = new WindowClassificationPerformanceEvaluator(); @@ -540,8 +735,7 @@ private static void exampleWindowedEvaluation() { evaluator.recallPerClassOption.setValue(true); evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, null, evaluator, - 100000, 10000, false, false); + PrequentialResult results = PrequentialEvaluation(stream, learner, null, evaluator, 100000, 10000); // Record the end time long endTime = System.currentTimeMillis(); @@ -555,9 +749,7 @@ private static void exampleWindowedEvaluation() { for (int j = 0; j < results.windowedResults.size(); ++j) { System.out.println("\t" + j); for (int i = 0; i < 2; ++i) // results.get(results.size()-1).length; ++i) - System.out.println(evaluator.getPerformanceMeasurements()[i].getName() + ": " + - results.windowedResults.get(j)[i]); + System.out.println(evaluator.getPerformanceMeasurements()[i].getName() + ": " + 
results.windowedResults.get(j)[i]); } } - } \ No newline at end of file diff --git a/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java b/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java index a7c655be8..911ac4c50 100644 --- a/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java +++ b/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java @@ -15,7 +15,7 @@ * * You should have received a copy of the GNU General Public License * along with this program. If not, see . - * + * */ package moa.evaluation; @@ -35,35 +35,37 @@ * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @version $Revision: 7 $ */ -public interface LearningPerformanceEvaluator extends MOAObject, CapabilitiesHandler { +public interface LearningPerformanceEvaluator extends MOAObject, CapabilitiesHandler, AutoCloseable { - /** - * Resets this evaluator. It must be similar to - * starting a new evaluator from scratch. - * - */ + /** + * Resets this evaluator. It must be similar to + * starting a new evaluator from scratch. + * + */ public void reset(); - /** - * Adds a learning result to this evaluator. - * - * @param example the example to be classified - * @param classVotes an array containing the estimated membership - * probabilities of the test instance in each class - */ - public void addResult(E example, double[] classVotes); - public void addResult(E testInst, Prediction prediction); + /** + * Adds a learning result to this evaluator. + * + * @param example the example to be classified + * @param classVotes an array containing the estimated membership + * probabilities of the test instance in each class + */ + public void addResult(E example, double[] classVotes); + public void addResult(E testInst, Prediction prediction); - /** - * Gets the current measurements monitored by this evaluator. - * - * @return an array of measurements monitored by this evaluator - */ + /** + * Gets the current measurements monitored by this evaluator. 
+ * + * @return an array of measurements monitored by this evaluator + */ public Measurement[] getPerformanceMeasurements(); @Override default ImmutableCapabilities defineImmutableCapabilities() { - return new ImmutableCapabilities(Capability.VIEW_STANDARD); + return new ImmutableCapabilities(Capability.VIEW_STANDARD); } + default void close() throws Exception { + } } diff --git a/moa/src/main/java/moa/gui/SemiSupervisedTabPanel.java b/moa/src/main/java/moa/gui/SemiSupervisedTabPanel.java new file mode 100644 index 000000000..f9dc784e1 --- /dev/null +++ b/moa/src/main/java/moa/gui/SemiSupervisedTabPanel.java @@ -0,0 +1,29 @@ +package moa.gui; + +import java.awt.*; + +public class SemiSupervisedTabPanel extends AbstractTabPanel { + + protected SemiSupervisedTaskManagerPanel taskManagerPanel; + + protected PreviewPanel previewPanel; + + public SemiSupervisedTabPanel() { + this.taskManagerPanel = new SemiSupervisedTaskManagerPanel(); + this.previewPanel = new PreviewPanel(); + this.taskManagerPanel.setPreviewPanel(this.previewPanel); + setLayout(new BorderLayout()); + add(this.taskManagerPanel, BorderLayout.NORTH); + add(this.previewPanel, BorderLayout.CENTER); + } + + @Override + public String getTabTitle() { + return "Semi-Supervised Learning"; + } + + @Override + public String getDescription() { + return "MOA Semi-Supervised Learning"; + } +} diff --git a/moa/src/main/java/moa/gui/SemiSupervisedTaskManagerPanel.java b/moa/src/main/java/moa/gui/SemiSupervisedTaskManagerPanel.java new file mode 100644 index 000000000..0e2f251c6 --- /dev/null +++ b/moa/src/main/java/moa/gui/SemiSupervisedTaskManagerPanel.java @@ -0,0 +1,468 @@ +package moa.gui; + +import moa.core.StringUtils; +import moa.options.ClassOption; +import moa.options.OptionHandler; +import moa.tasks.EvaluateInterleavedTestThenTrainSSLDelayed; +import moa.tasks.SemiSupervisedMainTask; +import moa.tasks.Task; +import moa.tasks.TaskThread; +import nz.ac.waikato.cms.gui.core.BaseFileChooser; + +import 
javax.swing.*; +import javax.swing.event.ListSelectionEvent; +import javax.swing.event.ListSelectionListener; +import javax.swing.table.AbstractTableModel; +import javax.swing.table.DefaultTableCellRenderer; +import javax.swing.table.TableCellRenderer; +import java.awt.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.StringSelection; +import java.awt.event.ActionEvent; +import java.awt.event.ActionListener; +import java.awt.event.MouseAdapter; +import java.awt.event.MouseEvent; +import java.io.*; +import java.util.ArrayList; +import java.util.prefs.Preferences; + +public class SemiSupervisedTaskManagerPanel extends JPanel { + + private static final long serialVersionUID = 1L; + + public static final int MILLISECS_BETWEEN_REFRESH = 600; + + public static String exportFileExtension = "log"; + + public class ProgressCellRenderer extends JProgressBar implements + TableCellRenderer { + + private static final long serialVersionUID = 1L; + + public ProgressCellRenderer() { + super(SwingConstants.HORIZONTAL, 0, 10000); + setBorderPainted(false); + setStringPainted(true); + } + + @Override + public Component getTableCellRendererComponent(JTable table, + Object value, boolean isSelected, boolean hasFocus, int row, + int column) { + double frac = -1.0; + if (value instanceof Double) { + frac = ((Double) value).doubleValue(); + } + if (frac >= 0.0) { + setIndeterminate(false); + setValue((int) (frac * 10000.0)); + setString(StringUtils.doubleToString(frac * 100.0, 2, 2)); + } else { + setValue(0); + } + return this; + } + + @Override + public void validate() { } + + @Override + public void revalidate() { } + + @Override + protected void firePropertyChange(String propertyName, Object oldValue, + Object newValue) { } + + @Override + public void firePropertyChange(String propertyName, boolean oldValue, + boolean newValue) { } + } + + protected class TaskTableModel extends AbstractTableModel { + + private static final long serialVersionUID = 1L; + + 
@Override + public String getColumnName(int col) { + switch (col) { + case 0: + return "command"; + case 1: + return "status"; + case 2: + return "time elapsed"; + case 3: + return "current activity"; + case 4: + return "% complete"; + } + return null; + } + + @Override + public int getColumnCount() { + return 5; + } + + @Override + public int getRowCount() { + return SemiSupervisedTaskManagerPanel.this.taskList.size(); + } + + @Override + public Object getValueAt(int row, int col) { + TaskThread thread = SemiSupervisedTaskManagerPanel.this.taskList.get(row); + switch (col) { + case 0: + return ((OptionHandler) thread.getTask()).getCLICreationString(SemiSupervisedMainTask.class); + case 1: + return thread.getCurrentStatusString(); + case 2: + return StringUtils.secondsToDHMSString(thread.getCPUSecondsElapsed()); + case 3: + return thread.getCurrentActivityString(); + case 4: + return Double.valueOf(thread.getCurrentActivityFracComplete()); + } + return null; + } + + @Override + public boolean isCellEditable(int row, int col) { + return false; + } + } + + protected SemiSupervisedMainTask currentTask; + + protected java.util.List taskList = new ArrayList<>(); + + protected JButton configureTaskButton = new JButton("Configure"); + + protected JTextField taskDescField = new JTextField(); + + protected JButton runTaskButton = new JButton("Run"); + + protected TaskTableModel taskTableModel; + + protected JTable taskTable; + + protected JButton pauseTaskButton = new JButton("Pause"); + + protected JButton resumeTaskButton = new JButton("Resume"); + + protected JButton cancelTaskButton = new JButton("Cancel"); + + protected JButton deleteTaskButton = new JButton("Delete"); + + protected PreviewPanel previewPanel; + + private Preferences prefs; + + private final String PREF_NAME = "currentTask"; + + public SemiSupervisedTaskManagerPanel() { + // Read current task preference + prefs = Preferences.userRoot().node(this.getClass().getName()); + currentTask = new 
EvaluateInterleavedTestThenTrainSSLDelayed(); + String taskText = this.currentTask.getCLICreationString(SemiSupervisedMainTask.class); + String propertyValue = prefs.get(PREF_NAME, taskText); + //this.taskDescField.setText(propertyValue); + setTaskString(propertyValue, false); //Not store preference + this.taskDescField.setEditable(false); + + final Component comp = this.taskDescField; + this.taskDescField.addMouseListener(new MouseAdapter() { + + @Override + public void mouseClicked(MouseEvent evt) { + if (evt.getClickCount() == 1) { + if ((evt.getButton() == MouseEvent.BUTTON3) + || ((evt.getButton() == MouseEvent.BUTTON1) && evt.isAltDown() && evt.isShiftDown())) { + JPopupMenu menu = new JPopupMenu(); + JMenuItem item; + + item = new JMenuItem("Copy configuration to clipboard"); + item.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent e) { + copyClipBoardConfiguration(); + } + }); + menu.add(item); + + item = new JMenuItem("Save selected tasks to file"); + item.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent arg0) { + saveLogSelectedTasks(); + } + }); + menu.add(item); + + + item = new JMenuItem("Enter configuration..."); + item.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent arg0) { + String newTaskString = JOptionPane.showInputDialog("Insert command line"); + if (newTaskString != null) { + setTaskString(newTaskString); + } + } + }); + menu.add(item); + + menu.show(comp, evt.getX(), evt.getY()); + } + } + } + }); + + JPanel configPanel = new JPanel(); + configPanel.setLayout(new BorderLayout()); + configPanel.add(this.configureTaskButton, BorderLayout.WEST); + configPanel.add(this.taskDescField, BorderLayout.CENTER); + configPanel.add(this.runTaskButton, BorderLayout.EAST); + this.taskTableModel = new TaskTableModel(); + this.taskTable = new JTable(this.taskTableModel); + DefaultTableCellRenderer centerRenderer = new 
DefaultTableCellRenderer(); + centerRenderer.setHorizontalAlignment(SwingConstants.CENTER); + this.taskTable.getColumnModel().getColumn(1).setCellRenderer( + centerRenderer); + this.taskTable.getColumnModel().getColumn(2).setCellRenderer( + centerRenderer); + this.taskTable.getColumnModel().getColumn(4).setCellRenderer( + new ProgressCellRenderer()); + JPanel controlPanel = new JPanel(); + controlPanel.add(this.pauseTaskButton); + controlPanel.add(this.resumeTaskButton); + controlPanel.add(this.cancelTaskButton); + controlPanel.add(this.deleteTaskButton); + setLayout(new BorderLayout()); + add(configPanel, BorderLayout.NORTH); + add(new JScrollPane(this.taskTable), BorderLayout.CENTER); + add(controlPanel, BorderLayout.SOUTH); + this.taskTable.getSelectionModel().addListSelectionListener( + new ListSelectionListener() { + + @Override + public void valueChanged(ListSelectionEvent arg0) { + taskSelectionChanged(); + } + }); + this.configureTaskButton.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent arg0) { + String newTaskString = ClassOptionSelectionPanel.showSelectClassDialog( + SemiSupervisedTaskManagerPanel.this, + "Configure task", SemiSupervisedMainTask.class, + SemiSupervisedTaskManagerPanel.this.currentTask.getCLICreationString(SemiSupervisedMainTask.class), + null); + setTaskString(newTaskString); + } + }); + this.runTaskButton.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent arg0) { + runTask((Task) SemiSupervisedTaskManagerPanel.this.currentTask.copy()); + } + }); + this.pauseTaskButton.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent arg0) { + pauseSelectedTasks(); + } + }); + this.resumeTaskButton.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent arg0) { + resumeSelectedTasks(); + } + }); + this.cancelTaskButton.addActionListener(new ActionListener() { + + @Override 
+ public void actionPerformed(ActionEvent arg0) { + cancelSelectedTasks(); + } + }); + this.deleteTaskButton.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent arg0) { + deleteSelectedTasks(); + } + }); + + Timer updateListTimer = new Timer( + MILLISECS_BETWEEN_REFRESH, new ActionListener() { + + @Override + public void actionPerformed(ActionEvent e) { + SemiSupervisedTaskManagerPanel.this.taskTable.repaint(); + } + }); + updateListTimer.start(); + setPreferredSize(new Dimension(0, 200)); + } + + public void setPreviewPanel(PreviewPanel previewPanel) { + this.previewPanel = previewPanel; + } + + public void setTaskString(String cliString) { + setTaskString(cliString, true); + } + + public void setTaskString(String cliString, boolean storePreference) { + try { + this.currentTask = (SemiSupervisedMainTask) ClassOption.cliStringToObject( + cliString, SemiSupervisedMainTask.class, null); + String taskText = this.currentTask.getCLICreationString(SemiSupervisedMainTask.class); + this.taskDescField.setText(taskText); + if (storePreference) { + //Save task text as a preference + prefs.put(PREF_NAME, taskText); + } + } catch (Exception ex) { + GUIUtils.showExceptionDialog(this, "Problem with task", ex); + } + } + + public void runTask(Task task) { + TaskThread thread = new TaskThread(task); + this.taskList.add(0, thread); + this.taskTableModel.fireTableDataChanged(); + this.taskTable.setRowSelectionInterval(0, 0); + thread.start(); + } + + public void taskSelectionChanged() { + TaskThread[] selectedTasks = getSelectedTasks(); + if (selectedTasks.length == 1) { + setTaskString(((OptionHandler) selectedTasks[0].getTask()).getCLICreationString(SemiSupervisedMainTask.class)); + if (this.previewPanel != null) { + this.previewPanel.setTaskThreadToPreview(selectedTasks[0]); + } + } else { + this.previewPanel.setTaskThreadToPreview(null); + } + } + + public TaskThread[] getSelectedTasks() { + int[] selectedRows = 
this.taskTable.getSelectedRows(); + TaskThread[] selectedTasks = new TaskThread[selectedRows.length]; + for (int i = 0; i < selectedRows.length; i++) { + selectedTasks[i] = this.taskList.get(selectedRows[i]); + } + return selectedTasks; + } + + public void pauseSelectedTasks() { + TaskThread[] selectedTasks = getSelectedTasks(); + for (TaskThread thread : selectedTasks) { + thread.pauseTask(); + } + } + + public void resumeSelectedTasks() { + TaskThread[] selectedTasks = getSelectedTasks(); + for (TaskThread thread : selectedTasks) { + thread.resumeTask(); + } + } + + public void cancelSelectedTasks() { + TaskThread[] selectedTasks = getSelectedTasks(); + for (TaskThread thread : selectedTasks) { + thread.cancelTask(); + } + } + + public void deleteSelectedTasks() { + TaskThread[] selectedTasks = getSelectedTasks(); + for (TaskThread thread : selectedTasks) { + thread.cancelTask(); + this.taskList.remove(thread); + } + this.taskTableModel.fireTableDataChanged(); + } + + public void copyClipBoardConfiguration() { + + StringSelection selection = new StringSelection(this.taskDescField.getText().trim()); + Clipboard clipboard = Toolkit.getDefaultToolkit().getSystemClipboard(); + clipboard.setContents(selection, selection); + + } + + public void saveLogSelectedTasks() { + String tasksLog = ""; + TaskThread[] selectedTasks = getSelectedTasks(); + for (TaskThread thread : selectedTasks) { + tasksLog += ((OptionHandler) thread.getTask()).getCLICreationString(SemiSupervisedMainTask.class) + "\n"; + } + + BaseFileChooser fileChooser = new BaseFileChooser(); + fileChooser.setAcceptAllFileFilterUsed(true); + fileChooser.addChoosableFileFilter(new FileExtensionFilter( + exportFileExtension)); + if (fileChooser.showSaveDialog(this) == BaseFileChooser.APPROVE_OPTION) { + File chosenFile = fileChooser.getSelectedFile(); + String fileName = chosenFile.getPath(); + if (!chosenFile.exists() + && !fileName.endsWith(exportFileExtension)) { + fileName = fileName + "." 
+ exportFileExtension; + } + try { + PrintWriter out = new PrintWriter(new BufferedWriter( + new FileWriter(fileName))); + out.write(tasksLog); + out.close(); + } catch (IOException ioe) { + GUIUtils.showExceptionDialog( + this, + "Problem saving file " + fileName, ioe); + } + } + } + + private static void createAndShowGUI() { + + // Create and set up the labeledInstancesBuffer. + JFrame frame = new JFrame("Test"); + frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); + + // Create and set up the content pane. + JPanel panel = new SemiSupervisedTabPanel(); + panel.setOpaque(true); // content panes must be opaque + frame.setContentPane(panel); + + // Display the labeledInstancesBuffer. + frame.pack(); + // frame.setSize(400, 400); + frame.setVisible(true); + } + + public static void main(String[] args) { + try { + UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName()); + SwingUtilities.invokeLater(new Runnable() { + @Override + public void run() { + createAndShowGUI(); + } + }); + } catch (Exception e) { + e.printStackTrace(); + } + } +} diff --git a/moa/src/main/java/moa/learners/Learner.java b/moa/src/main/java/moa/learners/Learner.java index be959a8d0..a806ad587 100644 --- a/moa/src/main/java/moa/learners/Learner.java +++ b/moa/src/main/java/moa/learners/Learner.java @@ -19,14 +19,10 @@ */ package moa.learners; +import com.yahoo.labs.samoa.instances.*; import moa.MOAObject; import moa.core.Example; -import com.yahoo.labs.samoa.instances.InstanceData; -import com.yahoo.labs.samoa.instances.InstancesHeader; -import com.yahoo.labs.samoa.instances.MultiLabelInstance; -import com.yahoo.labs.samoa.instances.Prediction; - import moa.core.Measurement; import moa.gui.AWTRenderable; import moa.options.OptionHandler; @@ -95,6 +91,14 @@ public interface Learner extends MOAObject, OptionHandler, AW */ public double[] getVotesForInstance(E example); + /** + * + * @param example the instance whose confidence we are observing + * @param label + * @return + */ + 
public double getConfidenceForPrediction(E example, double label); + /** * Gets the current measurements of this learner. * diff --git a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrainSSLDelayed.java b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrainSSLDelayed.java new file mode 100644 index 000000000..b5b02904a --- /dev/null +++ b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrainSSLDelayed.java @@ -0,0 +1,351 @@ +package moa.tasks; + +import com.github.javacliparser.FileOption; +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.FloatOption; +import com.github.javacliparser.IntOption; +import com.yahoo.labs.samoa.instances.Instance; +import moa.classifiers.MultiClassClassifier; +import moa.classifiers.SemiSupervisedLearner; +import moa.core.*; +import moa.evaluation.LearningEvaluation; +import moa.evaluation.LearningPerformanceEvaluator; +import moa.evaluation.preview.LearningCurve; +import moa.learners.Learner; +import moa.options.ClassOption; +import moa.streams.ExampleStream; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.lang3.tuple.MutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.math3.random.MersenneTwister; +import org.apache.commons.math3.random.RandomGenerator; + +/** + * An evaluation task that relies on the mechanism of Interleaved Test Then + * Train, + * applied on semi-supervised data streams + */ +public class EvaluateInterleavedTestThenTrainSSLDelayed extends SemiSupervisedMainTask { + + @Override + public String getPurposeString() { + return "Evaluates a classifier on a semi-supervised stream by testing only the labeled data, " + + "then training with each example in sequence."; + } + + private static final long serialVersionUID = 1L; + + public IntOption randomSeedOption = new IntOption( + 
"instanceRandomSeed", 'r', + "Seed for random generation of instances.", 1); + + public FlagOption onlyLabeledDataOption = new FlagOption("labeledDataOnly", 'a', + "Learner only trained on labeled data"); + + public ClassOption standardLearnerOption = new ClassOption("standardLearner", 'b', + "A standard learner to train. This will be ignored if labeledDataOnly flag is not set.", + MultiClassClassifier.class, "moa.classifiers.trees.HoeffdingTree"); + + public ClassOption sslLearnerOption = new ClassOption("sslLearner", 'l', + "A semi-supervised learner to train.", SemiSupervisedLearner.class, + "moa.classifiers.semisupervised.ClusterAndLabelClassifier"); + + public ClassOption streamOption = new ClassOption("stream", 's', + "Stream to learn from.", ExampleStream.class, + "moa.streams.ArffFileStream"); + + public ClassOption evaluatorOption = new ClassOption("evaluator", 'e', + "Classification performance evaluation method.", + LearningPerformanceEvaluator.class, + "BasicClassificationPerformanceEvaluator"); + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + /** Option: Probability of instance being unlabeled */ + public FloatOption labelProbabilityOption = new FloatOption("labelProbability", 'j', + "The ratio of labeled data", + 0.01); + + public IntOption delayLengthOption = new IntOption("delay", 'k', + "Number of instances before test instance is used for training. 
-1 = no delayed labeling.", + -1, -1, Integer.MAX_VALUE); + + public IntOption initialWindowSizeOption = new IntOption("initialTrainingWindow", 'p', + "Number of instances used for training in the beginning of the stream (-1 = no initialWindow).", + -1, -1, Integer.MAX_VALUE); + + public FlagOption debugPseudoLabelsOption = new FlagOption("debugPseudoLabels", 'w', + "Learner also receives the labeled data, but it is not used for training (just for statistics)"); + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + public IntOption instanceLimitOption = new IntOption("instanceLimit", 'i', + "Maximum number of instances to test/train on (-1 = no limit).", + 100000000, -1, Integer.MAX_VALUE); + + public IntOption timeLimitOption = new IntOption("timeLimit", 't', + "Maximum number of seconds to test/train for (-1 = no limit).", -1, + -1, Integer.MAX_VALUE); + + public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", + 'f', + "How many instances between samples of the learning performance.", + 100000, 0, Integer.MAX_VALUE); + + public IntOption memCheckFrequencyOption = new IntOption( + "memCheckFrequency", 'q', + "How many instances between memory bound checks.", 100000, 0, + Integer.MAX_VALUE); + + public FileOption dumpFileOption = new FileOption("dumpFile", 'd', + "File to append intermediate csv results to.", null, "csv", true); + + public FileOption outputPredictionFileOption = new FileOption("outputPredictionFile", 'o', + "File to append output predictions to.", null, "pred", true); + + public FileOption debugOutputUnlabeledClassInformation = new FileOption("debugOutputUnlabeledClassInformation", 'h', + "Single column containing the class label or -999 indicating missing labels.", null, "csv", true); + + private int numUnlabeledData = 0; + + private Learner getLearner(ExampleStream stream) { + Learner learner; + if (this.onlyLabeledDataOption.isSet()) { + learner = 
(Learner) getPreparedClassOption(this.standardLearnerOption); + } else { + learner = (SemiSupervisedLearner) getPreparedClassOption(this.sslLearnerOption); + } + + learner.setModelContext(stream.getHeader()); + if (learner.isRandomizable()) { + learner.setRandomSeed(this.randomSeedOption.getValue()); + learner.resetLearning(); + } + return learner; + } + + private String getLearnerString() { + if (this.onlyLabeledDataOption.isSet()) { + return this.standardLearnerOption.getValueAsCLIString(); + } else { + return this.sslLearnerOption.getValueAsCLIString(); + } + } + + private PrintStream newPrintStream(File f, String err_msg) { + if (f == null) + return null; + try { + return new PrintStream(new FileOutputStream(f, f.exists()), true); + } catch (FileNotFoundException e) { + throw new RuntimeException(err_msg, e); + } + } + + private Object internalDoMainTask(TaskMonitor monitor, ObjectRepository repository, LearningPerformanceEvaluator evaluator) + { + int maxInstances = this.instanceLimitOption.getValue(); + int maxSeconds = this.timeLimitOption.getValue(); + int delayLength = this.delayLengthOption.getValue(); + double labelProbability = this.labelProbabilityOption.getValue(); + String streamString = this.streamOption.getValueAsCLIString(); + RandomGenerator taskRandom = new MersenneTwister(this.randomSeedOption.getValue()); + ExampleStream stream = (ExampleStream) getPreparedClassOption(this.streamOption); + Learner learner = getLearner(stream); + String learnerString = getLearnerString(); + + // A number of output files used for debugging and manual evaluation + PrintStream dumpStream = newPrintStream(this.dumpFileOption.getFile(), "Failed to create dump file"); + PrintStream predStream = newPrintStream(this.outputPredictionFileOption.getFile(), + "Failed to create prediction file"); + PrintStream labelStream = newPrintStream(this.debugOutputUnlabeledClassInformation.getFile(), + "Failed to create unlabeled class information file"); + if (labelStream != null) + 
labelStream.println("class"); + + // Setup evaluation + monitor.setCurrentActivity("Evaluating learner...", -1.0); + LearningCurve learningCurve = new LearningCurve("learning evaluation instances"); + + boolean firstDump = true; + boolean preciseCPUTiming = TimingUtils.enablePreciseTiming(); + long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread(); + long lastEvaluateStartTime = evaluateStartTime; + long instancesProcessed = 0; + int secondsElapsed = 0; + double RAMHours = 0.0; + + // The buffer is a list of tuples. The first element is the index when + // it should be emitted. The second element is the instance itself. + List> delayBuffer = new ArrayList>(); + + while (stream.hasMoreInstances() + && ((maxInstances < 0) || (instancesProcessed < maxInstances)) + && ((maxSeconds < 0) || (secondsElapsed < maxSeconds))) { + instancesProcessed++; + + // TRAIN on delayed instances + while (delayBuffer.size() > 0 + && delayBuffer.get(0).getKey() == instancesProcessed) { + Example delayedExample = delayBuffer.remove(0).getValue(); + learner.trainOnInstance(delayedExample); + } + + // Obtain the next Example from the stream. + // The instance is expected to be labeled. + Example originalExample = stream.nextInstance(); + Example unlabeledExample = originalExample.copy(); + int trueClass = (int) ((Instance) originalExample.getData()).classValue(); + + // In case it is set, then the label is not removed. We want to pass the + // labelled data to the learner even in trainOnUnlabeled data to generate statistics such as number + // of correctly pseudo-labeled instances. + if (!debugPseudoLabelsOption.isSet()) { + // Remove the label of the unlabeledExample indirectly through + // unlabeledInstanceData. + Instance instance = (Instance) unlabeledExample.getData(); + instance.setMissing(instance.classIndex()); + } + + // WARMUP + // Train on the initial instances. These are not used for testing! 
+ if (instancesProcessed <= this.initialWindowSizeOption.getValue()) { + if (learner instanceof SemiSupervisedLearner) + ((SemiSupervisedLearner) learner).addInitialWarmupTrainingInstances(); + learner.trainOnInstance(originalExample); + continue; + } + + Boolean is_labeled = labelProbability > taskRandom.nextDouble(); + if (!is_labeled) { + this.numUnlabeledData++; + if (labelStream != null) + labelStream.println(-999); + } else { + if (labelStream != null) + labelStream.println((int) trueClass); + } + + // TEST + // Obtain the prediction for the testInst (i.e. no label) + double[] prediction = learner.getVotesForInstance(unlabeledExample); + + // Output prediction + if (predStream != null) { + // Assuming that the class label is not missing for the originalInstanceData + predStream.println(Utils.maxIndex(prediction) + "," + trueClass); + } + evaluator.addResult(originalExample, prediction); + + // TRAIN + if (is_labeled && delayLength >= 0) { + // The instance will be labeled but has been delayed + if (learner instanceof SemiSupervisedLearner) + { + ((SemiSupervisedLearner) learner).trainOnUnlabeledInstance((Instance) unlabeledExample.getData()); + } + delayBuffer.add( + new MutablePair(1 + instancesProcessed + delayLength, originalExample)); + } else if (is_labeled) { + // The instance will be labeled and is not delayed e.g delayLength = -1 + learner.trainOnInstance(originalExample); + } else { + // The instance will never be labeled + if (learner instanceof SemiSupervisedLearner) + ((SemiSupervisedLearner) learner).trainOnUnlabeledInstance((Instance) unlabeledExample.getData()); + } + + if (instancesProcessed % this.sampleFrequencyOption.getValue() == 0 || !stream.hasMoreInstances()) { + long evaluateTime = TimingUtils.getNanoCPUTimeOfCurrentThread(); + double time = TimingUtils.nanoTimeToSeconds(evaluateTime - evaluateStartTime); + double timeIncrement = TimingUtils.nanoTimeToSeconds(evaluateTime - lastEvaluateStartTime); + double RAMHoursIncrement = 
learner.measureByteSize() / (1024.0 * 1024.0 * 1024.0); // GBs + RAMHoursIncrement *= (timeIncrement / 3600.0); // Hours + RAMHours += RAMHoursIncrement; + lastEvaluateStartTime = evaluateTime; + learningCurve.insertEntry(new LearningEvaluation( + new Measurement[] { + new Measurement( + "learning evaluation instances", + instancesProcessed), + new Measurement( + "evaluation time (" + + (preciseCPUTiming ? "cpu " + : "") + + "seconds)", + time), + new Measurement( + "model cost (RAM-Hours)", + RAMHours), + new Measurement( + "Unlabeled instances", + this.numUnlabeledData) + }, + evaluator, learner)); + if (dumpStream != null) { + if (firstDump) { + dumpStream.print("Learner,stream,randomSeed,"); + dumpStream.println(learningCurve.headerToString()); + firstDump = false; + } + dumpStream.print(learnerString + "," + streamString + "," + + this.randomSeedOption.getValueAsCLIString() + ","); + dumpStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1)); + dumpStream.flush(); + } + } + if (instancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) { + if (monitor.taskShouldAbort()) { + return null; + } + long estimatedRemainingInstances = stream.estimatedRemainingInstances(); + if (maxInstances > 0) { + long maxRemaining = maxInstances - instancesProcessed; + if ((estimatedRemainingInstances < 0) + || (maxRemaining < estimatedRemainingInstances)) { + estimatedRemainingInstances = maxRemaining; + } + } + monitor.setCurrentActivityFractionComplete(estimatedRemainingInstances < 0 ? 
-1.0 + : (double) instancesProcessed / (double) (instancesProcessed + estimatedRemainingInstances)); + if (monitor.resultPreviewRequested()) { + monitor.setLatestResultPreview(learningCurve.copy()); + } + secondsElapsed = (int) TimingUtils.nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread() + - evaluateStartTime); + } + } + if (dumpStream != null) { + dumpStream.close(); + } + if (predStream != null) { + predStream.close(); + } + return learningCurve; + } + + @Override + protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { + // Some resource must be closed at the end of the task + try ( + LearningPerformanceEvaluator evaluator = (LearningPerformanceEvaluator) getPreparedClassOption(this.evaluatorOption) + ) { + return internalDoMainTask(monitor, repository, evaluator); + } catch (Exception e) { + throw new RuntimeException(e); + } + + } + + @Override + public Class getTaskResultType() { + return LearningCurve.class; + } +} diff --git a/moa/src/main/java/moa/tasks/SemiSupervisedMainTask.java b/moa/src/main/java/moa/tasks/SemiSupervisedMainTask.java new file mode 100644 index 000000000..fecf7feae --- /dev/null +++ b/moa/src/main/java/moa/tasks/SemiSupervisedMainTask.java @@ -0,0 +1,24 @@ +package moa.tasks; + +import moa.streams.clustering.ClusterEvent; + +import java.util.ArrayList; + +/** + * + */ +public abstract class SemiSupervisedMainTask extends MainTask { + + private static final long serialVersionUID = 1L; + + protected ArrayList events; + + protected void setEventsList(ArrayList events) { + this.events = events; + } + + public ArrayList getEventsList() { + return this.events; + } + +} diff --git a/moa/src/main/resources/moa/gui/GUI.props b/moa/src/main/resources/moa/gui/GUI.props index d1deabd85..3bb990469 100644 --- a/moa/src/main/resources/moa/gui/GUI.props +++ b/moa/src/main/resources/moa/gui/GUI.props @@ -8,6 +8,7 @@ Tabs=\ moa.gui.ClassificationTabPanel,\ moa.gui.RegressionTabPanel,\ + 
moa.gui.SemiSupervisedTabPanel,\ moa.gui.MultiLabelTabPanel,\ moa.gui.MultiTargetTabPanel,\ moa.gui.clustertab.ClusteringTabPanel,\ From 335734df23229913c460ed82387f374d73a0508b Mon Sep 17 00:00:00 2001 From: sunyibin Date: Thu, 18 Apr 2024 23:58:22 +1200 Subject: [PATCH 02/31] updating SOKNL in MOA to the newest version --- .../classifiers/meta/SelfOptimisingKNearestLeaves.java | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/moa/src/main/java/moa/classifiers/meta/SelfOptimisingKNearestLeaves.java b/moa/src/main/java/moa/classifiers/meta/SelfOptimisingKNearestLeaves.java index d8dec74f7..8135eca2d 100644 --- a/moa/src/main/java/moa/classifiers/meta/SelfOptimisingKNearestLeaves.java +++ b/moa/src/main/java/moa/classifiers/meta/SelfOptimisingKNearestLeaves.java @@ -45,12 +45,10 @@ public String getPurposeString() { public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of trees.", 100, 1, Integer.MAX_VALUE); - public FlagOption DisableSelfOptimisingOption = new FlagOption("DisableSelfOptimising",'f',"Disable the self optimising procedure."); + public FlagOption disableSelfOptimisingOption = new FlagOption("disableSelfOptimising",'f',"Disable the self optimising procedure."); public IntOption kOption = new IntOption("kNearestLeaves",'k',"Specify k value when not self-optimising",10,1,this.ensembleSizeOption.getMaxValue()); - public IntOption randomSeedOption = new IntOption("randomSeed", 'r', "The random seed", 1); - public MultiChoiceOption mFeaturesModeOption = new MultiChoiceOption("mFeaturesMode", 'o', "Defines how m, defined by mFeaturesPerTreeSize, is interpreted. 
M represents the total number of features.", new String[]{"Specified m (integer value)", "sqrt(M)+1", "M-(sqrt(M)+1)", @@ -97,8 +95,6 @@ public void resetLearningImpl() { this.instancesSeen = 0; this.evaluator = new BasicRegressionPerformanceEvaluator(); - this.classifierRandom = new Random(randomSeedOption.getValue()); - this.previousPrediction = new double[this.ensembleSizeOption.getValue()]; this.selfOptimisingEvaluators = new BasicRegressionPerformanceEvaluator[this.ensembleSizeOption.getValue()]; @@ -157,7 +153,7 @@ public double[] getVotesForInstance(Instance instance) { } // Activate Self-Optimising K-Nearest Leaves - if (!this.DisableSelfOptimisingOption.isSet()) { + if (!this.disableSelfOptimisingOption.isSet()) { InstanceExample example = new InstanceExample(instance); int n = selfOptimisingEvaluators.length; double[] performances = new double[n]; From d9842f6589ac559393e000bdc9203b8b8103835c Mon Sep 17 00:00:00 2001 From: sunyibin Date: Fri, 19 Apr 2024 13:43:04 +1200 Subject: [PATCH 03/31] uploading prediction interval package --- .../AdaptivePredictionInterval.java | 1 + .../WindowPredictionIntervalEvaluator.java | 17 ++++++++---- .../main/java/moa/gui/RegressionTabPanel.java | 3 +-- ...valuatePrequentialPredictionIntervals.java | 26 +++++++++---------- 4 files changed, 27 insertions(+), 20 deletions(-) diff --git a/moa/src/main/java/moa/classifiers/predictioninterval/AdaptivePredictionInterval.java b/moa/src/main/java/moa/classifiers/predictioninterval/AdaptivePredictionInterval.java index c6d675e4e..8c22a32d7 100644 --- a/moa/src/main/java/moa/classifiers/predictioninterval/AdaptivePredictionInterval.java +++ b/moa/src/main/java/moa/classifiers/predictioninterval/AdaptivePredictionInterval.java @@ -5,6 +5,7 @@ import moa.capabilities.Capabilities; import moa.classifiers.AbstractClassifier; import moa.classifiers.Regressor; +import moa.classifiers.predictioninterval.PredictionIntervalLearner; import moa.core.InstanceExample; import 
moa.core.Measurement; import moa.core.ObjectRepository; diff --git a/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java b/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java index b2ec1e305..b1c0b80ed 100644 --- a/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java +++ b/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java @@ -43,16 +43,18 @@ public class WindowPredictionIntervalEvaluator extends AbstractOptionHandler 'w', "Size of Window", 1000); protected double TotalweightObserved = 0; + protected Estimator weightObserved; + protected Estimator squareError; + protected Estimator averageError; protected Estimator lower; protected Estimator upper; - protected Estimator counterCorrect; - protected Estimator truth; + protected Estimator counterCorrect; protected int numClasses; @@ -118,7 +120,6 @@ public void reset(int numClasses) { this.lower = new Estimator(this.widthOption.getValue()); this.upper = new Estimator(this.widthOption.getValue()); this.counterCorrect = new Estimator(this.widthOption.getValue()); - this.truth = new Estimator(this.widthOption.getValue()); this.TotalweightObserved = 0; } @@ -139,7 +140,6 @@ public void addResult(Example example, double[] prediction) { this.lower.add(prediction[0]); this.upper.add(prediction[2]); this.counterCorrect.add( inst.classValue() >= prediction[0] && inst.classValue() <= prediction[2]? 
1 : 0); - this.truth.add(inst.classValue()); } //System.out.println(inst.classValue()+", "+prediction[0]); } @@ -185,7 +185,14 @@ public double getAverageLength(){ } public double getNMPIW(){ - return Math.round(getAverageLength() / (this.truth.max() - this.truth.min()) * 10000.0) / 100.0; +// return Math.round((getAverageLength() / +// (((this.upper.max() - this.lower.max()) /2 + this.lower.max()) +// - ((this.upper.min() - this.lower.min()) /2 + this.lower.min()) +// ) +// * 10000.0) / 100.0); + + + return Math.round(getAverageLength() / ((this.upper.max() + this.lower.max() - this.upper.min() - this.lower.min()) / 2) * 10000.0) / 100.0; } @Override diff --git a/moa/src/main/java/moa/gui/RegressionTabPanel.java b/moa/src/main/java/moa/gui/RegressionTabPanel.java index d7a5f5ad9..a6ebd6473 100644 --- a/moa/src/main/java/moa/gui/RegressionTabPanel.java +++ b/moa/src/main/java/moa/gui/RegressionTabPanel.java @@ -40,8 +40,7 @@ public class RegressionTabPanel extends AbstractTabPanel { public RegressionTabPanel() { this.taskManagerPanel = new RegressionTaskManagerPanel(); - if (Objects.equals(this.taskManagerPanel.currentTask.getTaskName(), - "EvaluatePrequentialPredictionIntervals")){ + if (Objects.equals(this.taskManagerPanel.currentTask.getTaskName(), "EvaluatePrequentialPredictionIntervals")) { this.previewPanel = new PreviewPanel(TypePanel.PREDICTIONINTERVAL); } else{ diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialPredictionIntervals.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialPredictionIntervals.java index c8e559d12..b577acb9c 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialPredictionIntervals.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialPredictionIntervals.java @@ -16,7 +16,7 @@ * * You should have received a copy of the GNU General Public License * along with this program. If not, see . 
- * + * */ package moa.tasks; @@ -224,17 +224,17 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { lastEvaluateStartTime = evaluateTime; learningCurve.insertEntry(new LearningEvaluation( new Measurement[]{ - new Measurement( - "learning evaluation instances", - instancesProcessed), - new Measurement( - "evaluation time (" - + (preciseCPUTiming ? "cpu " - : "") + "seconds)", - time), - new Measurement( - "model cost (RAM-Hours)", - RAMHours) + new Measurement( + "learning evaluation instances", + instancesProcessed), + new Measurement( + "evaluation time (" + + (preciseCPUTiming ? "cpu " + : "") + "seconds)", + time), + new Measurement( + "model cost (RAM-Hours)", + RAMHours) }, evaluator, learner)); @@ -277,4 +277,4 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { } return learningCurve; } -} \ No newline at end of file +} From b0349c6b4b7714e69d8af379e8692abe14cf375a Mon Sep 17 00:00:00 2001 From: sunyibin Date: Sat, 20 Apr 2024 01:02:00 +1200 Subject: [PATCH 04/31] SOKNL-test --- .../moa/classifiers/meta/SelfOptimisingKNearestLeavesTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moa/src/test/java/moa/classifiers/meta/SelfOptimisingKNearestLeavesTest.java b/moa/src/test/java/moa/classifiers/meta/SelfOptimisingKNearestLeavesTest.java index 7c1510a4d..b76795378 100644 --- a/moa/src/test/java/moa/classifiers/meta/SelfOptimisingKNearestLeavesTest.java +++ b/moa/src/test/java/moa/classifiers/meta/SelfOptimisingKNearestLeavesTest.java @@ -55,7 +55,7 @@ protected Classifier[] getRegressionClassifierSetups() { SOKNLTest.ensembleSizeOption.setValue(5); SOKNLTest.mFeaturesModeOption.setChosenIndex(0); SOKNLTest.mFeaturesPerTreeSizeOption.setValue(2); - SOKNLTest.DisableSelfOptimisingOption.set(); + SOKNLTest.disableSelfOptimisingOption.set(); SOKNLTest.kOption.setValue(5); From b76334069d45564e528d66048ffe8fe0c38e5aac Mon Sep 17 00:00:00 2001 From: cassales Date: Tue, 23 Apr 2024 
15:21:59 +1200 Subject: [PATCH 05/31] Initial commig with Mini-Batch classes --- .../moa/classifiers/AbstractClassifier.java | 1 - .../AbstractParallelClassifierMiniBatch.java | 101 ++++++++++++++++++ .../minibatch/AdaptiveRandomForestMB.java | 20 ++-- .../meta/minibatch/LeveragingBagMB.java | 25 ++--- .../meta/minibatch/OzaBagAdwinMB.java | 27 ++--- .../classifiers/meta/minibatch/OzaBagMB.java | 31 +++--- 6 files changed, 139 insertions(+), 66 deletions(-) create mode 100644 moa/src/main/java/moa/classifiers/AbstractParallelClassifierMiniBatch.java diff --git a/moa/src/main/java/moa/classifiers/AbstractClassifier.java b/moa/src/main/java/moa/classifiers/AbstractClassifier.java index 29350a48e..11dfba4c5 100644 --- a/moa/src/main/java/moa/classifiers/AbstractClassifier.java +++ b/moa/src/main/java/moa/classifiers/AbstractClassifier.java @@ -95,7 +95,6 @@ public void prepareForUseImpl(TaskMonitor monitor, resetLearning(); } } - @Override public double[] getVotesForInstance(Example example){ diff --git a/moa/src/main/java/moa/classifiers/AbstractParallelClassifierMiniBatch.java b/moa/src/main/java/moa/classifiers/AbstractParallelClassifierMiniBatch.java new file mode 100644 index 000000000..c008390fe --- /dev/null +++ b/moa/src/main/java/moa/classifiers/AbstractParallelClassifierMiniBatch.java @@ -0,0 +1,101 @@ +/* + * AbstractClassifier.java + * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand + * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. 
+ * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package moa.classifiers; + +import com.github.javacliparser.IntOption; +import com.yahoo.labs.samoa.instances.*; +import moa.capabilities.CapabilitiesHandler; +import java.util.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +public abstract class AbstractParallelClassifierMiniBatch extends AbstractClassifier + implements Classifier, CapabilitiesHandler { //Learner> { + + @Override + public String getPurposeString() { + return "MOA Parallel Classifier with MiniBatch: " + getClass().getCanonicalName(); + } + + public IntOption numberOfCoresOption = new IntOption("numCores", 'c', + "The amount of CPU Cores used for multi-threading", 1, + 1, Runtime.getRuntime().availableProcessors()); + + public IntOption batchSizeOption = new IntOption("batchSize", 'b', + "The amount of instances the classifier should buffer before training.", + 1, 1, Integer.MAX_VALUE); + + // The amount of CPU cores to be run in parallel + public int numOfCores; + + // The threadpool to be used, based on the number of cores + protected ExecutorService threadpool; + + protected ArrayList myBatch; + + public AbstractParallelClassifierMiniBatch() { + if (isRandomizable()) { + this.randomSeedOption = new IntOption("randomSeed", 'r', + "Seed for random behaviour of the classifier.", 1); + } + } + + /** + * Trains this classifier "incrementally" using the given instance.

+ * + * The reason for ...Impl methods: ease programmer burden by not requiring + * them to remember calls to super in overridden methods. + * Note that this will produce compiler errors if not overridden. + * + * @param inst the instance to be used for training + */ + public void trainOnInstanceImpl(Instance inst) { + if (myBatch != null) { + this.myBatch.add(inst); + if (this.myBatch.size() == this.batchSizeOption.getValue()){ + this.trainOnInstances(this.myBatch); + this.myBatch.clear(); + } + } + } + + public abstract void trainOnInstances(ArrayList instances); + + + public void trainingHasEnded() { + if (this.threadpool != null) + this.threadpool.shutdown(); + } + + @Override + public void resetLearning() { + this.numOfCores = this.numberOfCoresOption.getValue(); + if (this.numOfCores > 1) { + this.threadpool = Executors.newFixedThreadPool(this.numOfCores); + } + this.trainingWeightSeenByModel = 0.0; + if (isRandomizable()) { + this.classifierRandom = new Random(this.randomSeed); + } + this.myBatch = new ArrayList<>(); + resetLearningImpl(); + } +} diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/AdaptiveRandomForestMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/AdaptiveRandomForestMB.java index d5b2ef7ec..d5d4a0013 100644 --- a/moa/src/main/java/moa/classifiers/meta/minibatch/AdaptiveRandomForestMB.java +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/AdaptiveRandomForestMB.java @@ -26,7 +26,7 @@ import moa.AbstractMOAObject; import moa.capabilities.Capability; import moa.capabilities.ImmutableCapabilities; -import moa.classifiers.AbstractClassifierMiniBatch; +import moa.classifiers.AbstractParallelClassifierMiniBatch; import moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.classifiers.core.driftdetection.ChangeDetector; @@ -38,7 +38,6 @@ import moa.options.ClassOption; import java.util.ArrayList; -import java.util.Random; import java.util.concurrent.Callable; import 
java.util.concurrent.ThreadLocalRandom; @@ -78,7 +77,7 @@ * @author Heitor Murilo Gomes (heitor_murilo_gomes at yahoo dot com dot br) * @version $Revision: 1 $ */ -public class AdaptiveRandomForestMB extends AbstractClassifierMiniBatch implements MultiClassClassifier { +public class AdaptiveRandomForestMB extends AbstractParallelClassifierMiniBatch implements MultiClassClassifier { @Override public String getPurposeString() { @@ -244,7 +243,7 @@ protected void initEnsemble(Instance instance) { ARFHoeffdingTree treeLearner = (ARFHoeffdingTree) getPreparedClassOption(this.treeLearnerOption); treeLearner.resetLearning(); - int seed = this.randomSeedOption.getValue(); + for(int i = 0 ; i < ensembleSize ; ++i) { treeLearner.subspaceSizeOption.setValue(this.subspaceSize); ARFBaseLearner tempARFBL = new ARFBaseLearner( @@ -257,8 +256,7 @@ protected void initEnsemble(Instance instance) { driftDetectionMethodOption, warningDetectionMethodOption, false); - this.trainers.add(new TrainingRunnable(tempARFBL, this.lambdaOption.getValue(), seed)); - seed++; + this.trainers.add(new TrainingRunnable(tempARFBL, this.lambdaOption.getValue())); } } @@ -429,27 +427,23 @@ protected class TrainingRunnable implements Runnable, Callable { private ArrayList instances; private final double lambdaOption; private long instancesSeen; - private int localSeed; - private Random trRandom; - public TrainingRunnable(ARFBaseLearner learner, double lambdaOption, int seed) { + public TrainingRunnable(ARFBaseLearner learner, double lambdaOption) { this.learner = learner; this.lambdaOption = lambdaOption; this.instancesSeen = 0; - this.localSeed = seed; - this.trRandom = new Random(); - this.trRandom.setSeed(this.localSeed); } @Override public void run() { for (Instance instance : this.instances) { ++this.instancesSeen; - int k = MiscUtils.poisson(this.lambdaOption, this.trRandom); + int k = MiscUtils.poisson(this.lambdaOption, ThreadLocalRandom.current()); if (k > 0) { learner.trainOnInstance(instance, 
k, this.instancesSeen); } } + } @Override diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/LeveragingBagMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/LeveragingBagMB.java index b3e9d0e15..6d6b865c6 100644 --- a/moa/src/main/java/moa/classifiers/meta/minibatch/LeveragingBagMB.java +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/LeveragingBagMB.java @@ -27,7 +27,7 @@ import moa.capabilities.Capabilities; import moa.capabilities.Capability; import moa.capabilities.ImmutableCapabilities; -import moa.classifiers.AbstractClassifierMiniBatch; +import moa.classifiers.AbstractParallelClassifierMiniBatch; import moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.classifiers.core.driftdetection.ADWIN; @@ -37,7 +37,6 @@ import moa.options.ClassOption; import java.util.ArrayList; -import java.util.Random; import java.util.concurrent.Callable; import java.util.concurrent.ThreadLocalRandom; @@ -52,7 +51,7 @@ * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) * @version $Revision: 7 $ */ -public class LeveragingBagMB extends AbstractClassifierMiniBatch implements MultiClassClassifier { +public class LeveragingBagMB extends AbstractParallelClassifierMiniBatch implements MultiClassClassifier { private static final long serialVersionUID = 1L; @@ -109,9 +108,8 @@ public void resetLearningImpl() { if (ocos) { this.initMatrixCodes = true; } - int seed = this.randomSeedOption.getValue(); for (int i = 0; i < this.ensembleSizeOption.getValue(); i++) { - this.trainers.add(new TrainingRunnable(baseLearner.copy(), new ADWIN((double) this.deltaAdwinOption.getValue()), ocos, wso, lao, seed)); + this.trainers.add(new TrainingRunnable(baseLearner.copy(), new ADWIN((double) this.deltaAdwinOption.getValue()), ocos, wso, lao)); } } @@ -277,18 +275,13 @@ protected class TrainingRunnable implements Runnable, Callable { protected ADWIN ADError; protected boolean outputCodesOptionIsSet; protected int[] matrixCodes; - private 
int localSeed; - private Random trRandom; - public TrainingRunnable(Classifier learner, ADWIN ADError, boolean ocos, double wso, int lao, int seed) { + public TrainingRunnable(Classifier learner, ADWIN ADError, boolean ocos, double wso, int lao) { this.learner = learner; this.ADError = ADError; this.outputCodesOptionIsSet = ocos; this.w = wso; this.LevAlgOption = lao; - this.localSeed = seed; - this.trRandom = new Random(); - this.trRandom.setSeed(this.localSeed); } @Override @@ -297,24 +290,24 @@ public void run() { double k = 0.0; switch (this.LevAlgOption) { case 0: //LBagMC - k = MiscUtils.poisson(w, this.trRandom); + k = MiscUtils.poisson(w, ThreadLocalRandom.current()); break; case 1: //LeveragingBagME double error = this.ADError.getEstimation(); k = !this.learner.correctlyClassifies(instances.get(i).copy()) ? - 1.0 : (this.trRandom.nextDouble() < (error / (1.0 - error)) ? 1.0 : 0.0); + 1.0 : (ThreadLocalRandom.current().nextDouble() < (error / (1.0 - error)) ? 1.0 : 0.0); break; case 2: //LeveragingBagHalf w = 1.0; - k = this.trRandom.nextBoolean() ? 0.0 : w; + k = ThreadLocalRandom.current().nextBoolean() ? 0.0 : w; break; case 3: //LeveragingBagWT w = 1.0; - k = 1.0 + MiscUtils.poisson(w, this.trRandom); + k = 1.0 + MiscUtils.poisson(w, ThreadLocalRandom.current()); break; case 4: //LeveragingSubag w = 1.0; - k = MiscUtils.poisson(1, this.trRandom); + k = MiscUtils.poisson(1, ThreadLocalRandom.current()); k = (k > 0) ? 
w : 0; break; } diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java index ed638c746..30b4661f9 100644 --- a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java @@ -24,7 +24,7 @@ import moa.capabilities.Capabilities; import moa.capabilities.Capability; import moa.capabilities.ImmutableCapabilities; -import moa.classifiers.AbstractClassifierMiniBatch; +import moa.classifiers.AbstractParallelClassifierMiniBatch; import moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.classifiers.core.driftdetection.ADWIN; @@ -34,7 +34,6 @@ import moa.options.ClassOption; import java.util.ArrayList; -import java.util.Random; import java.util.concurrent.Callable; import java.util.concurrent.ThreadLocalRandom; @@ -87,7 +86,7 @@ * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) * @version $Revision: 7 $ */ -public class OzaBagAdwinMB extends AbstractClassifierMiniBatch implements MultiClassClassifier { +public class OzaBagAdwinMB extends AbstractParallelClassifierMiniBatch implements MultiClassClassifier { private static final long serialVersionUID = 2L; @@ -109,12 +108,10 @@ public String getPurposeString() { @Override public void resetLearningImpl() { this.trainers = new ArrayList<>(); - int seed = this.randomSeedOption.getValue(); Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption); baseLearner.resetLearning(); for (int i = 0; i < this.ensembleSizeOption.getValue(); i++) { - this.trainers.add(new TrainingRunnable(baseLearner.copy(), new ADWIN(), seed)); - seed++; + this.trainers.add(new TrainingRunnable(baseLearner.copy(), new ADWIN())); } } @@ -203,26 +200,20 @@ protected class TrainingRunnable implements Runnable, Callable { private Classifier learner; private ArrayList instances; protected ADWIN ADError; - private 
int localSeed; - private Random trRandom; - public TrainingRunnable(Classifier learner, ADWIN ADError, int seed) { + public TrainingRunnable(Classifier learner, ADWIN ADError) { this.learner = learner; this.instances = new ArrayList<>(); - this.localSeed = seed; this.ADError = ADError; - this.trRandom = new Random(); - this.trRandom.setSeed(this.localSeed); } @Override public void run() { - for (Instance inst : this.instances) { - int k = MiscUtils.poisson(1.0, this.trRandom); - Instance weightedInst = inst.copy(); - weightedInst.setWeight(inst.weight() * k); - this.learner.trainOnInstance(weightedInst); - boolean correctlyClassifies = this.learner.correctlyClassifies(inst); + for (Instance instance : this.instances) { + int k = MiscUtils.poisson(1.0, ThreadLocalRandom.current()); + instance.setWeight(instance.weight() * k); + this.learner.trainOnInstance(instance); + boolean correctlyClassifies = this.learner.correctlyClassifies(instance); double ErrEstimation = this.ADError.getEstimation(); if (this.ADError.setInput(correctlyClassifies ? 
0 : 1)) { if (this.ADError.getEstimation() > ErrEstimation) { diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java index fc2034bb2..26a723512 100644 --- a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java @@ -23,7 +23,7 @@ import com.yahoo.labs.samoa.instances.Instance; import moa.capabilities.Capability; import moa.capabilities.ImmutableCapabilities; -import moa.classifiers.AbstractClassifierMiniBatch; +import moa.classifiers.AbstractParallelClassifierMiniBatch; import moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.core.DoubleVector; @@ -32,8 +32,8 @@ import moa.options.ClassOption; import java.util.ArrayList; -import java.util.Random; import java.util.concurrent.Callable; +import java.util.concurrent.ThreadLocalRandom; /** * Incremental on-line bagging of Oza and Russell. @@ -55,7 +55,7 @@ * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @version $Revision: 7 $ */ -public class OzaBagMB extends AbstractClassifierMiniBatch implements MultiClassClassifier { +public class OzaBagMB extends AbstractParallelClassifierMiniBatch implements MultiClassClassifier { @Override public String getPurposeString() { @@ -75,12 +75,10 @@ public String getPurposeString() { @Override public void resetLearningImpl() { this.trainers = new ArrayList<>(); - int seed = this.randomSeedOption.getValue(); Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption); baseLearner.resetLearning(); for (int i = 0; i < this.ensembleSizeOption.getValue(); i++) { - trainers.add(new TrainingRunnable(baseLearner.copy(), seed)); - seed++; + trainers.add(new TrainingRunnable(baseLearner.copy())); } } @@ -146,29 +144,26 @@ public ImmutableCapabilities defineImmutableCapabilities() { /*** * Inner class to assist with the multi-thread execution. 
*/ - protected class TrainingRunnable implements Runnable, Callable { - // TODO: Fix bug that makes seed initialized random objects not give the same result in MOA private Classifier learner; private ArrayList instances; - private Random trRandom; - public int localSeed; +// private int instancesSeen; +// private int weightsSeen; + - public TrainingRunnable(Classifier learner, int seed) { + public TrainingRunnable(Classifier learner) { this.learner = learner; this.instances = new ArrayList<>(); - this.localSeed = seed; - this.trRandom = new Random(); - this.trRandom.setSeed(this.localSeed); } @Override public void run() { for (Instance inst : this.instances) { - int k = MiscUtils.poisson(1.0, this.trRandom); - Instance weightedInst = inst.copy(); - weightedInst.setWeight(inst.weight() * k); - this.learner.trainOnInstance(weightedInst); + int k = MiscUtils.poisson(1.0, ThreadLocalRandom.current()); +// this.weightsSeen += k; +// this.instancesSeen++; + inst.setWeight(inst.weight() * k); + this.learner.trainOnInstance(inst); } } From 498ef4a00f8de04189992ee578f518c8160a4a14 Mon Sep 17 00:00:00 2001 From: sunyibin Date: Wed, 24 Apr 2024 15:18:56 +1200 Subject: [PATCH 06/31] update the windowed PI evaluator to get correct NMPIW results --- .../WindowPredictionIntervalEvaluator.java | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java b/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java index b1c0b80ed..b2ec1e305 100644 --- a/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java +++ b/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java @@ -43,19 +43,17 @@ public class WindowPredictionIntervalEvaluator extends AbstractOptionHandler 'w', "Size of Window", 1000); protected double TotalweightObserved = 0; - protected Estimator weightObserved; - protected Estimator squareError; - protected Estimator averageError; 
protected Estimator lower; protected Estimator upper; - protected Estimator counterCorrect; + protected Estimator truth; + protected int numClasses; public class Estimator { @@ -120,6 +118,7 @@ public void reset(int numClasses) { this.lower = new Estimator(this.widthOption.getValue()); this.upper = new Estimator(this.widthOption.getValue()); this.counterCorrect = new Estimator(this.widthOption.getValue()); + this.truth = new Estimator(this.widthOption.getValue()); this.TotalweightObserved = 0; } @@ -140,6 +139,7 @@ public void addResult(Example example, double[] prediction) { this.lower.add(prediction[0]); this.upper.add(prediction[2]); this.counterCorrect.add( inst.classValue() >= prediction[0] && inst.classValue() <= prediction[2]? 1 : 0); + this.truth.add(inst.classValue()); } //System.out.println(inst.classValue()+", "+prediction[0]); } @@ -185,14 +185,7 @@ public double getAverageLength(){ } public double getNMPIW(){ -// return Math.round((getAverageLength() / -// (((this.upper.max() - this.lower.max()) /2 + this.lower.max()) -// - ((this.upper.min() - this.lower.min()) /2 + this.lower.min()) -// ) -// * 10000.0) / 100.0); - - - return Math.round(getAverageLength() / ((this.upper.max() + this.lower.max() - this.upper.min() - this.lower.min()) / 2) * 10000.0) / 100.0; + return Math.round(getAverageLength() / (this.truth.max() - this.truth.min()) * 10000.0) / 100.0; } @Override From 653eaa0685c76433863daded265654ee87327f23 Mon Sep 17 00:00:00 2001 From: sunyibin Date: Fri, 19 Apr 2024 13:43:04 +1200 Subject: [PATCH 07/31] uploading prediction interval package --- .../WindowPredictionIntervalEvaluator.java | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java b/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java index b2ec1e305..b1c0b80ed 100644 --- a/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java +++ 
b/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java @@ -43,16 +43,18 @@ public class WindowPredictionIntervalEvaluator extends AbstractOptionHandler 'w', "Size of Window", 1000); protected double TotalweightObserved = 0; + protected Estimator weightObserved; + protected Estimator squareError; + protected Estimator averageError; protected Estimator lower; protected Estimator upper; - protected Estimator counterCorrect; - protected Estimator truth; + protected Estimator counterCorrect; protected int numClasses; @@ -118,7 +120,6 @@ public void reset(int numClasses) { this.lower = new Estimator(this.widthOption.getValue()); this.upper = new Estimator(this.widthOption.getValue()); this.counterCorrect = new Estimator(this.widthOption.getValue()); - this.truth = new Estimator(this.widthOption.getValue()); this.TotalweightObserved = 0; } @@ -139,7 +140,6 @@ public void addResult(Example example, double[] prediction) { this.lower.add(prediction[0]); this.upper.add(prediction[2]); this.counterCorrect.add( inst.classValue() >= prediction[0] && inst.classValue() <= prediction[2]? 
1 : 0); - this.truth.add(inst.classValue()); } //System.out.println(inst.classValue()+", "+prediction[0]); } @@ -185,7 +185,14 @@ public double getAverageLength(){ } public double getNMPIW(){ - return Math.round(getAverageLength() / (this.truth.max() - this.truth.min()) * 10000.0) / 100.0; +// return Math.round((getAverageLength() / +// (((this.upper.max() - this.lower.max()) /2 + this.lower.max()) +// - ((this.upper.min() - this.lower.min()) /2 + this.lower.min()) +// ) +// * 10000.0) / 100.0); + + + return Math.round(getAverageLength() / ((this.upper.max() + this.lower.max() - this.upper.min() - this.lower.min()) / 2) * 10000.0) / 100.0; } @Override From d367cf84aa4a86800646db46b9fe644e6f02e980 Mon Sep 17 00:00:00 2001 From: sunyibin Date: Wed, 24 Apr 2024 15:18:56 +1200 Subject: [PATCH 08/31] update the windowed PI evaluator to get correct NMPIW results --- .../WindowPredictionIntervalEvaluator.java | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java b/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java index b1c0b80ed..b2ec1e305 100644 --- a/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java +++ b/moa/src/main/java/moa/evaluation/WindowPredictionIntervalEvaluator.java @@ -43,19 +43,17 @@ public class WindowPredictionIntervalEvaluator extends AbstractOptionHandler 'w', "Size of Window", 1000); protected double TotalweightObserved = 0; - protected Estimator weightObserved; - protected Estimator squareError; - protected Estimator averageError; protected Estimator lower; protected Estimator upper; - protected Estimator counterCorrect; + protected Estimator truth; + protected int numClasses; public class Estimator { @@ -120,6 +118,7 @@ public void reset(int numClasses) { this.lower = new Estimator(this.widthOption.getValue()); this.upper = new Estimator(this.widthOption.getValue()); this.counterCorrect = new 
Estimator(this.widthOption.getValue()); + this.truth = new Estimator(this.widthOption.getValue()); this.TotalweightObserved = 0; } @@ -140,6 +139,7 @@ public void addResult(Example example, double[] prediction) { this.lower.add(prediction[0]); this.upper.add(prediction[2]); this.counterCorrect.add( inst.classValue() >= prediction[0] && inst.classValue() <= prediction[2]? 1 : 0); + this.truth.add(inst.classValue()); } //System.out.println(inst.classValue()+", "+prediction[0]); } @@ -185,14 +185,7 @@ public double getAverageLength(){ } public double getNMPIW(){ -// return Math.round((getAverageLength() / -// (((this.upper.max() - this.lower.max()) /2 + this.lower.max()) -// - ((this.upper.min() - this.lower.min()) /2 + this.lower.min()) -// ) -// * 10000.0) / 100.0); - - - return Math.round(getAverageLength() / ((this.upper.max() + this.lower.max() - this.upper.min() - this.lower.min()) / 2) * 10000.0) / 100.0; + return Math.round(getAverageLength() / (this.truth.max() - this.truth.min()) * 10000.0) / 100.0; } @Override From 9571b830e4ced8dcb4d83b4dfe4629d02eff2743 Mon Sep 17 00:00:00 2001 From: Spencer Sun Date: Mon, 20 May 2024 14:42:04 +1200 Subject: [PATCH 09/31] fix: add storing functionality in EfficientEvaluationLoops --- .../evaluation/EfficientEvaluationLoops.java | 52 ++++++++++++++----- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/moa/src/main/java/moa/evaluation/EfficientEvaluationLoops.java b/moa/src/main/java/moa/evaluation/EfficientEvaluationLoops.java index 250aa3642..719c18ead 100644 --- a/moa/src/main/java/moa/evaluation/EfficientEvaluationLoops.java +++ b/moa/src/main/java/moa/evaluation/EfficientEvaluationLoops.java @@ -4,6 +4,7 @@ import moa.classifiers.SemiSupervisedLearner; import moa.classifiers.semisupervised.ClusterAndLabelClassifier; import moa.core.Example; +import moa.core.InstanceExample; import moa.core.Measurement; import moa.learners.Learner; import moa.streams.ArffFileStream; @@ -25,12 +26,24 @@ public class 
EfficientEvaluationLoops { public static class PrequentialResult { public ArrayList windowedResults; public double[] cumulativeResults; + public ArrayList targets; + public ArrayList predictions; public HashMap otherMeasurements; public PrequentialResult(ArrayList windowedResults, double[] cumulativeResults) { this.windowedResults = windowedResults; this.cumulativeResults = cumulativeResults; + this.targets = null; + this.predictions = null; + } + + public PrequentialResult(ArrayList windowedResults, double[] cumulativeResults, + ArrayList targets, ArrayList predictions) { + this.windowedResults = windowedResults; + this.cumulativeResults = cumulativeResults; + this.targets = targets; + this.predictions = predictions; } public PrequentialResult(ArrayList windowedResults, double[] cumulativeResults, @@ -57,18 +70,24 @@ public PrequentialResult(ArrayList windowedResults, double[] cumulativ public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Learner learner, LearningPerformanceEvaluator basicEvaluator, LearningPerformanceEvaluator windowedEvaluator, - long maxInstances, long windowSize) { + long maxInstances, long windowSize, + boolean storeY, boolean storePredictions) { int instancesProcessed = 0; if (!stream.hasMoreInstances()) stream.restart(); ArrayList windowed_results = new ArrayList<>(); + ArrayList targetValues = new ArrayList<>(); + ArrayList predictions = new ArrayList<>(); + while (stream.hasMoreInstances() && (maxInstances == -1 || instancesProcessed < maxInstances)) { - Example instance = stream.nextInstance(); + Example instance = stream.nextInstance(); + if (storeY) + targetValues.add(instance.getData().classValue()); double[] prediction = learner.getVotesForInstance(instance); if (basicEvaluator != null) @@ -76,6 +95,9 @@ public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Lear if (windowedEvaluator != null) windowedEvaluator.addResult(instance, prediction); + if (storePredictions) + 
predictions.add(prediction.length == 0? 0 : prediction[0]); + learner.trainOnInstance(instance); instancesProcessed++; @@ -106,8 +128,10 @@ public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Lear for (int i = 0; i < cumulative_results.length; ++i) cumulative_results[i] = measurements[i].getValue(); } - - return new PrequentialResult(windowed_results, cumulative_results); + if (!storePredictions && !storeY) + return new PrequentialResult(windowed_results, cumulative_results); + else + return new PrequentialResult(windowed_results, cumulative_results, targetValues, predictions); } public static PrequentialResult PrequentialSSLEvaluation(ExampleStream stream, Learner learner, @@ -379,7 +403,7 @@ private static void testPrequentialEfficiency1(Learner learner) { basic_evaluator.recallPerClassOption.setValue(true); basic_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, null, 100000, 1); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, null, 100000, 1, false, false); // Record the end time long endTime = System.currentTimeMillis(); @@ -413,7 +437,7 @@ private static void testPrequentialEvaluation_edge_cases1() { // windowed_evaluator.widthOption.setValue(1000); // windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, null, null, 100000, 1000); + PrequentialResult results = PrequentialEvaluation(stream, learner, null, null, 100000, 1000, false, false); // Record the end time long endTime = System.currentTimeMillis(); @@ -455,7 +479,7 @@ private static void testPrequentialEvaluation_edge_cases2() { windowed_evaluator.widthOption.setValue(1000); windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 1000, 10000); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, 
windowed_evaluator, 1000, 10000, false, false); // Record the end time long endTime = System.currentTimeMillis(); @@ -496,7 +520,7 @@ private static void testPrequentialEvaluation_edge_cases3() { windowed_evaluator.widthOption.setValue(1000); windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 10, 1); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 10, 1, false, false); // Record the end time long endTime = System.currentTimeMillis(); @@ -537,7 +561,7 @@ private static void testPrequentialEvaluation_edge_cases4() { windowed_evaluator.widthOption.setValue(10000); windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, -1, 10000); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, -1, 10000, false, false); // Record the end time long endTime = System.currentTimeMillis(); @@ -579,7 +603,7 @@ private static void testPrequentialEvaluation_SampleFrequency_TestThenTrain() { // windowed_evaluator.widthOption.setValue(10000); // windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, null, basic_evaluator, -1, 10000); + PrequentialResult results = PrequentialEvaluation(stream, learner, null, basic_evaluator, -1, 10000, false, false); // Record the end time long endTime = System.currentTimeMillis(); @@ -621,7 +645,7 @@ private static void testPrequentialRegressionEvaluation() { windowed_evaluator.widthOption.setValue(1000); // windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 100000, 1000); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 100000, 1000, false, false); // Record the end time long 
endTime = System.currentTimeMillis(); @@ -664,7 +688,7 @@ private static void testPrequentialEvaluation() { windowed_evaluator.widthOption.setValue(1000); windowed_evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 100000, 1000); + PrequentialResult results = PrequentialEvaluation(stream, learner, basic_evaluator, windowed_evaluator, 100000, 1000, false, false); // Record the end time long endTime = System.currentTimeMillis(); @@ -703,7 +727,7 @@ private static void testTestThenTrainEvaluation() { evaluator.recallPerClassOption.setValue(true); evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, evaluator, null, 100000, 100000); + PrequentialResult results = PrequentialEvaluation(stream, learner, evaluator, null, 100000, 100000, false, false); // Record the end time long endTime = System.currentTimeMillis(); @@ -735,7 +759,7 @@ private static void testWindowedEvaluation() { evaluator.recallPerClassOption.setValue(true); evaluator.prepareForUse(); - PrequentialResult results = PrequentialEvaluation(stream, learner, null, evaluator, 100000, 10000); + PrequentialResult results = PrequentialEvaluation(stream, learner, null, evaluator, 100000, 10000, false, false); // Record the end time long endTime = System.currentTimeMillis(); From 6853d2267408bd029843df872674e6ec74a0ccd3 Mon Sep 17 00:00:00 2001 From: cassales Date: Tue, 21 May 2024 14:08:08 +1200 Subject: [PATCH 10/31] new version of the parallel ensembles with minibatch (has the reproducibility issue) --- .../AbstractClassifierMiniBatch.java | 1 - .../AbstractParallelClassifierMiniBatch.java | 101 --------- .../minibatch/AdaptiveRandomForestMB.java | 20 +- .../meta/minibatch/LeveragingBagMB.java | 25 ++- .../meta/minibatch/OzaBagAdwinMB.java | 18 +- .../classifiers/meta/minibatch/OzaBagMB.java | 30 ++- .../meta/minibatch/threadTesting.java | 197 ++++++++++++++++++ 7 files changed, 258 
insertions(+), 134 deletions(-) delete mode 100644 moa/src/main/java/moa/classifiers/AbstractParallelClassifierMiniBatch.java create mode 100644 moa/src/main/java/moa/classifiers/meta/minibatch/threadTesting.java diff --git a/moa/src/main/java/moa/classifiers/AbstractClassifierMiniBatch.java b/moa/src/main/java/moa/classifiers/AbstractClassifierMiniBatch.java index 1c50f530d..b5026e761 100644 --- a/moa/src/main/java/moa/classifiers/AbstractClassifierMiniBatch.java +++ b/moa/src/main/java/moa/classifiers/AbstractClassifierMiniBatch.java @@ -82,7 +82,6 @@ public void trainOnInstanceImpl(Instance inst) { public void trainingHasEnded() { if (this.threadpool != null) this.threadpool.shutdown(); - this.myBatch = null; } @Override diff --git a/moa/src/main/java/moa/classifiers/AbstractParallelClassifierMiniBatch.java b/moa/src/main/java/moa/classifiers/AbstractParallelClassifierMiniBatch.java deleted file mode 100644 index c008390fe..000000000 --- a/moa/src/main/java/moa/classifiers/AbstractParallelClassifierMiniBatch.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * AbstractClassifier.java - * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand - * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . 
- * - */ - -package moa.classifiers; - -import com.github.javacliparser.IntOption; -import com.yahoo.labs.samoa.instances.*; -import moa.capabilities.CapabilitiesHandler; -import java.util.*; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; - -public abstract class AbstractParallelClassifierMiniBatch extends AbstractClassifier - implements Classifier, CapabilitiesHandler { //Learner> { - - @Override - public String getPurposeString() { - return "MOA Parallel Classifier with MiniBatch: " + getClass().getCanonicalName(); - } - - public IntOption numberOfCoresOption = new IntOption("numCores", 'c', - "The amount of CPU Cores used for multi-threading", 1, - 1, Runtime.getRuntime().availableProcessors()); - - public IntOption batchSizeOption = new IntOption("batchSize", 'b', - "The amount of instances the classifier should buffer before training.", - 1, 1, Integer.MAX_VALUE); - - // The amount of CPU cores to be run in parallel - public int numOfCores; - - // The threadpool to be used, based on the number of cores - protected ExecutorService threadpool; - - protected ArrayList myBatch; - - public AbstractParallelClassifierMiniBatch() { - if (isRandomizable()) { - this.randomSeedOption = new IntOption("randomSeed", 'r', - "Seed for random behaviour of the classifier.", 1); - } - } - - /** - * Trains this classifier "incrementally" using the given instance.

- * - * The reason for ...Impl methods: ease programmer burden by not requiring - * them to remember calls to super in overridden methods. - * Note that this will produce compiler errors if not overridden. - * - * @param inst the instance to be used for training - */ - public void trainOnInstanceImpl(Instance inst) { - if (myBatch != null) { - this.myBatch.add(inst); - if (this.myBatch.size() == this.batchSizeOption.getValue()){ - this.trainOnInstances(this.myBatch); - this.myBatch.clear(); - } - } - } - - public abstract void trainOnInstances(ArrayList instances); - - - public void trainingHasEnded() { - if (this.threadpool != null) - this.threadpool.shutdown(); - } - - @Override - public void resetLearning() { - this.numOfCores = this.numberOfCoresOption.getValue(); - if (this.numOfCores > 1) { - this.threadpool = Executors.newFixedThreadPool(this.numOfCores); - } - this.trainingWeightSeenByModel = 0.0; - if (isRandomizable()) { - this.classifierRandom = new Random(this.randomSeed); - } - this.myBatch = new ArrayList<>(); - resetLearningImpl(); - } -} diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/AdaptiveRandomForestMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/AdaptiveRandomForestMB.java index d5d4a0013..d5b2ef7ec 100644 --- a/moa/src/main/java/moa/classifiers/meta/minibatch/AdaptiveRandomForestMB.java +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/AdaptiveRandomForestMB.java @@ -26,7 +26,7 @@ import moa.AbstractMOAObject; import moa.capabilities.Capability; import moa.capabilities.ImmutableCapabilities; -import moa.classifiers.AbstractParallelClassifierMiniBatch; +import moa.classifiers.AbstractClassifierMiniBatch; import moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.classifiers.core.driftdetection.ChangeDetector; @@ -38,6 +38,7 @@ import moa.options.ClassOption; import java.util.ArrayList; +import java.util.Random; import java.util.concurrent.Callable; import 
java.util.concurrent.ThreadLocalRandom; @@ -77,7 +78,7 @@ * @author Heitor Murilo Gomes (heitor_murilo_gomes at yahoo dot com dot br) * @version $Revision: 1 $ */ -public class AdaptiveRandomForestMB extends AbstractParallelClassifierMiniBatch implements MultiClassClassifier { +public class AdaptiveRandomForestMB extends AbstractClassifierMiniBatch implements MultiClassClassifier { @Override public String getPurposeString() { @@ -243,7 +244,7 @@ protected void initEnsemble(Instance instance) { ARFHoeffdingTree treeLearner = (ARFHoeffdingTree) getPreparedClassOption(this.treeLearnerOption); treeLearner.resetLearning(); - + int seed = this.randomSeedOption.getValue(); for(int i = 0 ; i < ensembleSize ; ++i) { treeLearner.subspaceSizeOption.setValue(this.subspaceSize); ARFBaseLearner tempARFBL = new ARFBaseLearner( @@ -256,7 +257,8 @@ protected void initEnsemble(Instance instance) { driftDetectionMethodOption, warningDetectionMethodOption, false); - this.trainers.add(new TrainingRunnable(tempARFBL, this.lambdaOption.getValue())); + this.trainers.add(new TrainingRunnable(tempARFBL, this.lambdaOption.getValue(), seed)); + seed++; } } @@ -427,23 +429,27 @@ protected class TrainingRunnable implements Runnable, Callable { private ArrayList instances; private final double lambdaOption; private long instancesSeen; + private int localSeed; + private Random trRandom; - public TrainingRunnable(ARFBaseLearner learner, double lambdaOption) { + public TrainingRunnable(ARFBaseLearner learner, double lambdaOption, int seed) { this.learner = learner; this.lambdaOption = lambdaOption; this.instancesSeen = 0; + this.localSeed = seed; + this.trRandom = new Random(); + this.trRandom.setSeed(this.localSeed); } @Override public void run() { for (Instance instance : this.instances) { ++this.instancesSeen; - int k = MiscUtils.poisson(this.lambdaOption, ThreadLocalRandom.current()); + int k = MiscUtils.poisson(this.lambdaOption, this.trRandom); if (k > 0) { learner.trainOnInstance(instance, 
k, this.instancesSeen); } } - } @Override diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/LeveragingBagMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/LeveragingBagMB.java index 6d6b865c6..b3e9d0e15 100644 --- a/moa/src/main/java/moa/classifiers/meta/minibatch/LeveragingBagMB.java +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/LeveragingBagMB.java @@ -27,7 +27,7 @@ import moa.capabilities.Capabilities; import moa.capabilities.Capability; import moa.capabilities.ImmutableCapabilities; -import moa.classifiers.AbstractParallelClassifierMiniBatch; +import moa.classifiers.AbstractClassifierMiniBatch; import moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.classifiers.core.driftdetection.ADWIN; @@ -37,6 +37,7 @@ import moa.options.ClassOption; import java.util.ArrayList; +import java.util.Random; import java.util.concurrent.Callable; import java.util.concurrent.ThreadLocalRandom; @@ -51,7 +52,7 @@ * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) * @version $Revision: 7 $ */ -public class LeveragingBagMB extends AbstractParallelClassifierMiniBatch implements MultiClassClassifier { +public class LeveragingBagMB extends AbstractClassifierMiniBatch implements MultiClassClassifier { private static final long serialVersionUID = 1L; @@ -108,8 +109,9 @@ public void resetLearningImpl() { if (ocos) { this.initMatrixCodes = true; } + int seed = this.randomSeedOption.getValue(); for (int i = 0; i < this.ensembleSizeOption.getValue(); i++) { - this.trainers.add(new TrainingRunnable(baseLearner.copy(), new ADWIN((double) this.deltaAdwinOption.getValue()), ocos, wso, lao)); + this.trainers.add(new TrainingRunnable(baseLearner.copy(), new ADWIN((double) this.deltaAdwinOption.getValue()), ocos, wso, lao, seed)); } } @@ -275,13 +277,18 @@ protected class TrainingRunnable implements Runnable, Callable { protected ADWIN ADError; protected boolean outputCodesOptionIsSet; protected int[] matrixCodes; + private 
int localSeed; + private Random trRandom; - public TrainingRunnable(Classifier learner, ADWIN ADError, boolean ocos, double wso, int lao) { + public TrainingRunnable(Classifier learner, ADWIN ADError, boolean ocos, double wso, int lao, int seed) { this.learner = learner; this.ADError = ADError; this.outputCodesOptionIsSet = ocos; this.w = wso; this.LevAlgOption = lao; + this.localSeed = seed; + this.trRandom = new Random(); + this.trRandom.setSeed(this.localSeed); } @Override @@ -290,24 +297,24 @@ public void run() { double k = 0.0; switch (this.LevAlgOption) { case 0: //LBagMC - k = MiscUtils.poisson(w, ThreadLocalRandom.current()); + k = MiscUtils.poisson(w, this.trRandom); break; case 1: //LeveragingBagME double error = this.ADError.getEstimation(); k = !this.learner.correctlyClassifies(instances.get(i).copy()) ? - 1.0 : (ThreadLocalRandom.current().nextDouble() < (error / (1.0 - error)) ? 1.0 : 0.0); + 1.0 : (this.trRandom.nextDouble() < (error / (1.0 - error)) ? 1.0 : 0.0); break; case 2: //LeveragingBagHalf w = 1.0; - k = ThreadLocalRandom.current().nextBoolean() ? 0.0 : w; + k = this.trRandom.nextBoolean() ? 0.0 : w; break; case 3: //LeveragingBagWT w = 1.0; - k = 1.0 + MiscUtils.poisson(w, ThreadLocalRandom.current()); + k = 1.0 + MiscUtils.poisson(w, this.trRandom); break; case 4: //LeveragingSubag w = 1.0; - k = MiscUtils.poisson(1, ThreadLocalRandom.current()); + k = MiscUtils.poisson(1, this.trRandom); k = (k > 0) ? 
w : 0; break; } diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java index 30b4661f9..b992e1777 100644 --- a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java @@ -24,7 +24,7 @@ import moa.capabilities.Capabilities; import moa.capabilities.Capability; import moa.capabilities.ImmutableCapabilities; -import moa.classifiers.AbstractParallelClassifierMiniBatch; +import moa.classifiers.AbstractClassifierMiniBatch; import moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.classifiers.core.driftdetection.ADWIN; @@ -34,6 +34,7 @@ import moa.options.ClassOption; import java.util.ArrayList; +import java.util.Random; import java.util.concurrent.Callable; import java.util.concurrent.ThreadLocalRandom; @@ -86,7 +87,7 @@ * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) * @version $Revision: 7 $ */ -public class OzaBagAdwinMB extends AbstractParallelClassifierMiniBatch implements MultiClassClassifier { +public class OzaBagAdwinMB extends AbstractClassifierMiniBatch implements MultiClassClassifier { private static final long serialVersionUID = 2L; @@ -108,10 +109,12 @@ public String getPurposeString() { @Override public void resetLearningImpl() { this.trainers = new ArrayList<>(); + int seed = this.randomSeedOption.getValue(); Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption); baseLearner.resetLearning(); for (int i = 0; i < this.ensembleSizeOption.getValue(); i++) { - this.trainers.add(new TrainingRunnable(baseLearner.copy(), new ADWIN())); + this.trainers.add(new TrainingRunnable(baseLearner.copy(), new ADWIN(), seed)); + seed++; } } @@ -200,17 +203,22 @@ protected class TrainingRunnable implements Runnable, Callable { private Classifier learner; private ArrayList instances; protected ADWIN ADError; + private 
int localSeed; + private Random trRandom; - public TrainingRunnable(Classifier learner, ADWIN ADError) { + public TrainingRunnable(Classifier learner, ADWIN ADError, int seed) { this.learner = learner; this.instances = new ArrayList<>(); + this.localSeed = seed; this.ADError = ADError; + this.trRandom = new Random(); + this.trRandom.setSeed(this.localSeed); } @Override public void run() { for (Instance instance : this.instances) { - int k = MiscUtils.poisson(1.0, ThreadLocalRandom.current()); + int k = MiscUtils.poisson(1.0, this.trRandom); instance.setWeight(instance.weight() * k); this.learner.trainOnInstance(instance); boolean correctlyClassifies = this.learner.correctlyClassifies(instance); diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java index 26a723512..3f6249030 100644 --- a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java @@ -23,7 +23,7 @@ import com.yahoo.labs.samoa.instances.Instance; import moa.capabilities.Capability; import moa.capabilities.ImmutableCapabilities; -import moa.classifiers.AbstractParallelClassifierMiniBatch; +import moa.classifiers.AbstractClassifierMiniBatch; import moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.core.DoubleVector; @@ -32,8 +32,8 @@ import moa.options.ClassOption; import java.util.ArrayList; +import java.util.Random; import java.util.concurrent.Callable; -import java.util.concurrent.ThreadLocalRandom; /** * Incremental on-line bagging of Oza and Russell. 
@@ -55,7 +55,7 @@ * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @version $Revision: 7 $ */ -public class OzaBagMB extends AbstractParallelClassifierMiniBatch implements MultiClassClassifier { +public class OzaBagMB extends AbstractClassifierMiniBatch implements MultiClassClassifier { @Override public String getPurposeString() { @@ -75,10 +75,12 @@ public String getPurposeString() { @Override public void resetLearningImpl() { this.trainers = new ArrayList<>(); + int seed = this.randomSeedOption.getValue(); Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption); baseLearner.resetLearning(); for (int i = 0; i < this.ensembleSizeOption.getValue(); i++) { - trainers.add(new TrainingRunnable(baseLearner.copy())); + trainers.add(new TrainingRunnable(baseLearner.copy(), seed)); + seed++; } } @@ -144,24 +146,30 @@ public ImmutableCapabilities defineImmutableCapabilities() { /*** * Inner class to assist with the multi-thread execution. */ + protected class TrainingRunnable implements Runnable, Callable { + // TODO: Fix bug that makes seed initialized random objects not give the same result in MOA private Classifier learner; private ArrayList instances; -// private int instancesSeen; -// private int weightsSeen; - + private Random trRandom; + public int localSeed; - public TrainingRunnable(Classifier learner) { + public TrainingRunnable(Classifier learner, int seed) { this.learner = learner; this.instances = new ArrayList<>(); + this.localSeed = seed; + this.trRandom = new Random(); + this.trRandom.setSeed(this.localSeed); } @Override public void run() { + if (this.trRandom == null) { + this.trRandom = new Random(); + this.trRandom.setSeed(this.localSeed); + } for (Instance inst : this.instances) { - int k = MiscUtils.poisson(1.0, ThreadLocalRandom.current()); -// this.weightsSeen += k; -// this.instancesSeen++; + int k = MiscUtils.poisson(1.0, this.trRandom); inst.setWeight(inst.weight() * k); this.learner.trainOnInstance(inst); } diff 
--git a/moa/src/main/java/moa/classifiers/meta/minibatch/threadTesting.java b/moa/src/main/java/moa/classifiers/meta/minibatch/threadTesting.java new file mode 100644 index 000000000..585410323 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/threadTesting.java @@ -0,0 +1,197 @@ +/* + * OzaBag.java + * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand + * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ +package moa.classifiers.meta.minibatch; + +import com.github.javacliparser.IntOption; +import com.yahoo.labs.samoa.instances.Instance; +import moa.capabilities.Capability; +import moa.capabilities.ImmutableCapabilities; +import moa.classifiers.AbstractClassifierMiniBatch; +import moa.classifiers.Classifier; +import moa.classifiers.MultiClassClassifier; +import moa.core.DoubleVector; +import moa.core.Measurement; +import moa.core.MiscUtils; +import moa.options.ClassOption; + +import java.util.ArrayList; +import java.util.Random; +import java.util.concurrent.Callable; + +/** + * Incremental on-line bagging of Oza and Russell. + * + *

Oza and Russell developed online versions of bagging and boosting for + * Data Streams. They show how the process of sampling bootstrap replicates + * from training data can be simulated in a data stream context. They observe + * that the probability that any individual example will be chosen for a + * replicate tends to a Poisson(1) distribution.

+ * + *

[OR] N. Oza and S. Russell. Online bagging and boosting. + * In Artificial Intelligence and Statistics 2001, pages 105–112. + * Morgan Kaufmann, 2001.

+ * + *

Parameters:

    + *
  • -l : Classifier to train
  • + *
  • -s : The number of models in the bag
+ * + * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) + * @version $Revision: 7 $ + */ +public class threadTesting extends AbstractClassifierMiniBatch implements MultiClassClassifier { + + @Override + public String getPurposeString() { + return "Incremental on-line bagging of Oza and Russell using parallelism."; + } + + private static final long serialVersionUID = 2L; + + public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', + "Classifier to train.", Classifier.class, "trees.HoeffdingTree"); + + public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', + "The number of models in the bag.", 10, 1, Integer.MAX_VALUE); + + protected ArrayList trainers; + + @Override + public void resetLearningImpl() { + System.out.println("\n\n\n\n\n\nthreadTsting resetIMPL"); + this.trainers = new ArrayList<>(); + int seed = this.randomSeedOption.getValue(); + Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption); + baseLearner.resetLearning(); + for (int i = 0; i < this.ensembleSizeOption.getValue(); i++) { + trainers.add(new TrainingRunnable(baseLearner.copy(), seed)); + System.out.println("created trainer " + i + " with seed " + seed); + seed++; + } + } + + @Override + public void trainOnInstances(ArrayList instances) { + for (TrainingRunnable t : trainers) + t.instances = new ArrayList<>(instances); + if (this.threadpool != null) { + try { + this.threadpool.invokeAll(trainers); + } catch (InterruptedException ex) { + throw new RuntimeException("Could not call invokeAll() on training threads."); + } + } + } + + @Override + public double[] getVotesForInstance(Instance inst) { + DoubleVector combinedVote = new DoubleVector(); + for (TrainingRunnable trainer : this.trainers) { + DoubleVector vote = new DoubleVector(trainer.learner.getVotesForInstance(inst)); + if (vote.sumOfValues() > 0.0) { + vote.normalize(); + combinedVote.addValues(vote); + } + } + return combinedVote.getArrayRef(); + } + + @Override + public 
boolean isRandomizable() { + return true; + } + + @Override + public void getModelDescription(StringBuilder out, int indent) { + // TODO Auto-generated method stub + } + + @Override + protected Measurement[] getModelMeasurementsImpl() { + return new Measurement[]{new Measurement("ensemble size", + this.trainers != null ? this.trainers.size() : 0)}; + } + + @Override + public Classifier[] getSubClassifiers() { + Classifier[] ensemble = new Classifier[this.ensembleSizeOption.getValue()]; + for (int i = 0; i < this.trainers.size(); i++) { + ensemble[i] = this.trainers.get(i).learner; + } + return ensemble.clone(); + } + + @Override + public ImmutableCapabilities defineImmutableCapabilities() { + if (this.getClass() == threadTesting.class) + return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE); + else + return new ImmutableCapabilities(Capability.VIEW_STANDARD); + } + + /*** + * Inner class to assist with the multi-thread execution. + */ + protected class TrainingRunnable implements Runnable, Callable { + private Classifier learner; + private ArrayList instances; + private Random trRandom; + public int localSeed; +// private int instancesSeen; +// private int weightsSeen; + + + public TrainingRunnable(Classifier learner, int seed) { + this.learner = learner; + this.instances = new ArrayList<>(); + this.localSeed = seed; +// this.trRandom = null; + this.trRandom = new Random(seed); + } + + @Override + public void run() { + if (this.trRandom == null) { + this.trRandom = new Random(this.localSeed); + System.out.println("created local random in run method with seed " + this.localSeed); +// String debugging = ""; +// for (int i = 0; i < 10; i++) { +// debugging += this.trRandom.nextInt(1000) + ", "; +// } +// System.out.println("classifier with seed " + this.localSeed + " generated:\n" + debugging); + } +// String debugging = ""; +// for (int i = 0; i < 10; i++) { +// debugging += MiscUtils.poisson(1.0, this.trRandom) + ", "; +// } + for 
(Instance inst : this.instances) { + int k = MiscUtils.poisson(1.0, this.trRandom); + inst.setWeight(inst.weight() * k); + this.learner.trainOnInstance(inst); + } + } + + @Override + public Integer call() { + run(); + return 0; + } + } + +} From d84581b856abf1183b1a4c6318568013b8753506 Mon Sep 17 00:00:00 2001 From: cassales Date: Mon, 27 May 2024 19:16:59 +1200 Subject: [PATCH 11/31] fix bug from OzaBag and OzaBagADWIN ensembles --- .../moa/classifiers/AbstractClassifierMiniBatch.java | 1 + .../moa/classifiers/meta/minibatch/OzaBagAdwinMB.java | 9 +++++---- .../java/moa/classifiers/meta/minibatch/OzaBagMB.java | 9 +++------ 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/moa/src/main/java/moa/classifiers/AbstractClassifierMiniBatch.java b/moa/src/main/java/moa/classifiers/AbstractClassifierMiniBatch.java index b5026e761..1c50f530d 100644 --- a/moa/src/main/java/moa/classifiers/AbstractClassifierMiniBatch.java +++ b/moa/src/main/java/moa/classifiers/AbstractClassifierMiniBatch.java @@ -82,6 +82,7 @@ public void trainOnInstanceImpl(Instance inst) { public void trainingHasEnded() { if (this.threadpool != null) this.threadpool.shutdown(); + this.myBatch = null; } @Override diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java index b992e1777..ed638c746 100644 --- a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagAdwinMB.java @@ -217,11 +217,12 @@ public TrainingRunnable(Classifier learner, ADWIN ADError, int seed) { @Override public void run() { - for (Instance instance : this.instances) { + for (Instance inst : this.instances) { int k = MiscUtils.poisson(1.0, this.trRandom); - instance.setWeight(instance.weight() * k); - this.learner.trainOnInstance(instance); - boolean correctlyClassifies = this.learner.correctlyClassifies(instance); + Instance weightedInst = inst.copy(); + 
weightedInst.setWeight(inst.weight() * k); + this.learner.trainOnInstance(weightedInst); + boolean correctlyClassifies = this.learner.correctlyClassifies(inst); double ErrEstimation = this.ADError.getEstimation(); if (this.ADError.setInput(correctlyClassifies ? 0 : 1)) { if (this.ADError.getEstimation() > ErrEstimation) { diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java index 3f6249030..fc2034bb2 100644 --- a/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/OzaBagMB.java @@ -164,14 +164,11 @@ public TrainingRunnable(Classifier learner, int seed) { @Override public void run() { - if (this.trRandom == null) { - this.trRandom = new Random(); - this.trRandom.setSeed(this.localSeed); - } for (Instance inst : this.instances) { int k = MiscUtils.poisson(1.0, this.trRandom); - inst.setWeight(inst.weight() * k); - this.learner.trainOnInstance(inst); + Instance weightedInst = inst.copy(); + weightedInst.setWeight(inst.weight() * k); + this.learner.trainOnInstance(weightedInst); } } From 7bae559a7ebb788cea72c44d5b083a9489357f55 Mon Sep 17 00:00:00 2001 From: Anton Lee Date: Tue, 28 May 2024 10:09:38 +1200 Subject: [PATCH 12/31] Add instructions to build moa with dependencies --- README.md | 22 ++++++++++++++++++++++ moa/pom.xml | 16 +++++++++++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ca33590d2..c80de92c9 100755 --- a/README.md +++ b/README.md @@ -32,3 +32,25 @@ If you want to refer to MOA in a publication, please cite the following JMLR pap > MOA: Massive Online Analysis; Journal of Machine Learning Research 11: 1601-1604 +## Building MOA for CapyMOA + +> These steps assume you have Java installed and maven installed. If you don't +> have maven installed, you can download it from +> [here](https://maven.apache.org/download.cgi). 
You can achieve the same +> outcome with IntelliJ IDEA by [building moa with the IDE](https://moa.cms.waikato.ac.nz/tutorial-6-building-moa-from-the-source/) (The linked doc is a little out of date) +> and [packaging it as a single jar file](https://stackoverflow.com/questions/1082580/how-to-build-jars-from-intellij-idea-properly). + +You can compile moa as a single jar file with all dependencies included by running the following command in the `moa` directory: +```bash +cd ./moa +mvn compile assembly:single +``` + +If successful, the jar file will be built to a file like this `moa/target/moa-2023.04.1-SNAPSHOT-jar-with-dependencies.jar` with a different date. + +One way to verify that the jar file was built correctly is to run the following command: +```bash +java -jar ./moa/target/moa-2023.04.1-SNAPSHOT-jar-with-dependencies.jar +``` +This should start the MOA GUI. + diff --git a/moa/pom.xml b/moa/pom.xml index b1f03c533..07970201a 100644 --- a/moa/pom.xml +++ b/moa/pom.xml @@ -251,14 +251,24 @@ license-maven-plugin - + + - org.apache.maven.plugins maven-assembly-plugin + + + + moa.gui.GUI + + + + jar-with-dependencies + + From b231093578bc7194531685e25072bfee91648d45 Mon Sep 17 00:00:00 2001 From: DwayneAcosta Date: Sat, 27 Jul 2024 23:03:57 +1200 Subject: [PATCH 13/31] RW_kNN Random --- .project | 11 +++++++ moa/.classpath | 22 +++++++++++-- moa/.project | 11 +++++++ moa/.settings/org.eclipse.jdt.apt.core.prefs | 2 ++ moa/.settings/org.eclipse.jdt.core.prefs | 11 ++++--- .../java/moa/classifiers/lazy/RW_kNN.java | 4 +-- weka-package/.classpath | 32 ++++++++++++++++++- weka-package/.project | 11 +++++++ .../.settings/org.eclipse.jdt.apt.core.prefs | 2 ++ .../.settings/org.eclipse.jdt.core.prefs | 11 ++++--- 10 files changed, 104 insertions(+), 13 deletions(-) create mode 100644 moa/.settings/org.eclipse.jdt.apt.core.prefs create mode 100644 weka-package/.settings/org.eclipse.jdt.apt.core.prefs diff --git a/.project b/.project index 76bea900b..79e6f3bb8 100644 
--- a/.project +++ b/.project @@ -14,4 +14,15 @@ org.eclipse.m2e.core.maven2Nature + + + 1722078006188 + + 30 + + org.eclipse.core.resources.regexFilterMatcher + node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__ + + + diff --git a/moa/.classpath b/moa/.classpath index 5411c4697..1dfc1cab9 100644 --- a/moa/.classpath +++ b/moa/.classpath @@ -9,6 +9,7 @@ + @@ -22,11 +23,11 @@ + - + - @@ -35,5 +36,22 @@ + + + + + + + + + + + + + + + + + diff --git a/moa/.project b/moa/.project index c5879880d..fc77517ff 100644 --- a/moa/.project +++ b/moa/.project @@ -20,4 +20,15 @@ org.eclipse.jdt.core.javanature org.eclipse.m2e.core.maven2Nature + + + 1722078006159 + + 30 + + org.eclipse.core.resources.regexFilterMatcher + node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__ + + + diff --git a/moa/.settings/org.eclipse.jdt.apt.core.prefs b/moa/.settings/org.eclipse.jdt.apt.core.prefs new file mode 100644 index 000000000..d4313d4b2 --- /dev/null +++ b/moa/.settings/org.eclipse.jdt.apt.core.prefs @@ -0,0 +1,2 @@ +eclipse.preferences.version=1 +org.eclipse.jdt.apt.aptEnabled=false diff --git a/moa/.settings/org.eclipse.jdt.core.prefs b/moa/.settings/org.eclipse.jdt.core.prefs index 43c686ff7..ea7a397f8 100644 --- a/moa/.settings/org.eclipse.jdt.core.prefs +++ b/moa/.settings/org.eclipse.jdt.core.prefs @@ -1,13 +1,16 @@ eclipse.preferences.version=1 org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled -org.eclipse.jdt.core.compiler.codegen.targetPlatform=10 +org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8 org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve -org.eclipse.jdt.core.compiler.compliance=10 +org.eclipse.jdt.core.compiler.compliance=1.8 org.eclipse.jdt.core.compiler.debug.lineNumber=generate org.eclipse.jdt.core.compiler.debug.localVariable=generate org.eclipse.jdt.core.compiler.debug.sourceFile=generate org.eclipse.jdt.core.compiler.problem.assertIdentifier=error +org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled 
org.eclipse.jdt.core.compiler.problem.enumIdentifier=error org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning -org.eclipse.jdt.core.compiler.release=enabled -org.eclipse.jdt.core.compiler.source=10 +org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore +org.eclipse.jdt.core.compiler.processAnnotations=disabled +org.eclipse.jdt.core.compiler.release=disabled +org.eclipse.jdt.core.compiler.source=1.8 diff --git a/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java b/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java index 7b93a91b6..4613a9e6a 100755 --- a/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java +++ b/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java @@ -99,7 +99,7 @@ public void trainOnInstanceImpl(Instance inst) { this.reservoir = new Instances(inst.dataset()); } if (this.limitOptionReservoir.getValue() <= this.reservoir.numInstances()) { - int replaceIndex = r.nextInt(this.limitOptionReservoir.getValue() - 1); + int replaceIndex = this.classifierRandom.nextInt(this.limitOptionReservoir.getValue() - 1); this.reservoir.set(replaceIndex, inst); } else this.reservoir.add(inst); @@ -155,6 +155,6 @@ public void getModelDescription(StringBuilder out, int indent) { } public boolean isRandomizable() { - return false; + return true; } } \ No newline at end of file diff --git a/weka-package/.classpath b/weka-package/.classpath index 5e8a55fef..decd61996 100644 --- a/weka-package/.classpath +++ b/weka-package/.classpath @@ -13,7 +13,7 @@ - + @@ -23,5 +23,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/weka-package/.project b/weka-package/.project index f608fce93..0b9b9e1ed 100644 --- a/weka-package/.project +++ b/weka-package/.project @@ -20,4 +20,15 @@ org.eclipse.jdt.core.javanature org.eclipse.m2e.core.maven2Nature + + + 1722078006202 + + 30 + + org.eclipse.core.resources.regexFilterMatcher + node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__ + + + diff --git 
a/weka-package/.settings/org.eclipse.jdt.apt.core.prefs b/weka-package/.settings/org.eclipse.jdt.apt.core.prefs new file mode 100644 index 000000000..d4313d4b2 --- /dev/null +++ b/weka-package/.settings/org.eclipse.jdt.apt.core.prefs @@ -0,0 +1,2 @@ +eclipse.preferences.version=1 +org.eclipse.jdt.apt.aptEnabled=false diff --git a/weka-package/.settings/org.eclipse.jdt.core.prefs b/weka-package/.settings/org.eclipse.jdt.core.prefs index 51664a5f5..46235dc07 100644 --- a/weka-package/.settings/org.eclipse.jdt.core.prefs +++ b/weka-package/.settings/org.eclipse.jdt.core.prefs @@ -1,6 +1,9 @@ eclipse.preferences.version=1 -org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8 -org.eclipse.jdt.core.compiler.compliance=1.8 +org.eclipse.jdt.core.compiler.codegen.targetPlatform=11 +org.eclipse.jdt.core.compiler.compliance=11 +org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning -org.eclipse.jdt.core.compiler.release=enabled -org.eclipse.jdt.core.compiler.source=1.8 +org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore +org.eclipse.jdt.core.compiler.processAnnotations=disabled +org.eclipse.jdt.core.compiler.release=disabled +org.eclipse.jdt.core.compiler.source=11 From 864c805e9f37b6e5120b5d3e7aac890b2dcaa94f Mon Sep 17 00:00:00 2001 From: Heitor Murilo Gomes Date: Mon, 29 Jul 2024 10:10:49 +1200 Subject: [PATCH 14/31] Revert "RW_kNN Random" --- .project | 11 ------- moa/.classpath | 22 ++----------- moa/.project | 11 ------- moa/.settings/org.eclipse.jdt.apt.core.prefs | 2 -- moa/.settings/org.eclipse.jdt.core.prefs | 11 +++---- .../java/moa/classifiers/lazy/RW_kNN.java | 4 +-- weka-package/.classpath | 32 +------------------ weka-package/.project | 11 ------- .../.settings/org.eclipse.jdt.apt.core.prefs | 2 -- .../.settings/org.eclipse.jdt.core.prefs | 11 +++---- 10 files changed, 13 insertions(+), 104 deletions(-) delete mode 100644 
moa/.settings/org.eclipse.jdt.apt.core.prefs delete mode 100644 weka-package/.settings/org.eclipse.jdt.apt.core.prefs diff --git a/.project b/.project index 79e6f3bb8..76bea900b 100644 --- a/.project +++ b/.project @@ -14,15 +14,4 @@ org.eclipse.m2e.core.maven2Nature - - - 1722078006188 - - 30 - - org.eclipse.core.resources.regexFilterMatcher - node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__ - - - diff --git a/moa/.classpath b/moa/.classpath index 1dfc1cab9..5411c4697 100644 --- a/moa/.classpath +++ b/moa/.classpath @@ -9,7 +9,6 @@ - @@ -23,11 +22,11 @@ - - + + @@ -36,22 +35,5 @@ - - - - - - - - - - - - - - - - - diff --git a/moa/.project b/moa/.project index fc77517ff..c5879880d 100644 --- a/moa/.project +++ b/moa/.project @@ -20,15 +20,4 @@ org.eclipse.jdt.core.javanature org.eclipse.m2e.core.maven2Nature - - - 1722078006159 - - 30 - - org.eclipse.core.resources.regexFilterMatcher - node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__ - - - diff --git a/moa/.settings/org.eclipse.jdt.apt.core.prefs b/moa/.settings/org.eclipse.jdt.apt.core.prefs deleted file mode 100644 index d4313d4b2..000000000 --- a/moa/.settings/org.eclipse.jdt.apt.core.prefs +++ /dev/null @@ -1,2 +0,0 @@ -eclipse.preferences.version=1 -org.eclipse.jdt.apt.aptEnabled=false diff --git a/moa/.settings/org.eclipse.jdt.core.prefs b/moa/.settings/org.eclipse.jdt.core.prefs index ea7a397f8..43c686ff7 100644 --- a/moa/.settings/org.eclipse.jdt.core.prefs +++ b/moa/.settings/org.eclipse.jdt.core.prefs @@ -1,16 +1,13 @@ eclipse.preferences.version=1 org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled -org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8 +org.eclipse.jdt.core.compiler.codegen.targetPlatform=10 org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve -org.eclipse.jdt.core.compiler.compliance=1.8 +org.eclipse.jdt.core.compiler.compliance=10 org.eclipse.jdt.core.compiler.debug.lineNumber=generate org.eclipse.jdt.core.compiler.debug.localVariable=generate 
org.eclipse.jdt.core.compiler.debug.sourceFile=generate org.eclipse.jdt.core.compiler.problem.assertIdentifier=error -org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled org.eclipse.jdt.core.compiler.problem.enumIdentifier=error org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning -org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore -org.eclipse.jdt.core.compiler.processAnnotations=disabled -org.eclipse.jdt.core.compiler.release=disabled -org.eclipse.jdt.core.compiler.source=1.8 +org.eclipse.jdt.core.compiler.release=enabled +org.eclipse.jdt.core.compiler.source=10 diff --git a/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java b/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java index 4613a9e6a..7b93a91b6 100755 --- a/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java +++ b/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java @@ -99,7 +99,7 @@ public void trainOnInstanceImpl(Instance inst) { this.reservoir = new Instances(inst.dataset()); } if (this.limitOptionReservoir.getValue() <= this.reservoir.numInstances()) { - int replaceIndex = this.classifierRandom.nextInt(this.limitOptionReservoir.getValue() - 1); + int replaceIndex = r.nextInt(this.limitOptionReservoir.getValue() - 1); this.reservoir.set(replaceIndex, inst); } else this.reservoir.add(inst); @@ -155,6 +155,6 @@ public void getModelDescription(StringBuilder out, int indent) { } public boolean isRandomizable() { - return true; + return false; } } \ No newline at end of file diff --git a/weka-package/.classpath b/weka-package/.classpath index decd61996..5e8a55fef 100644 --- a/weka-package/.classpath +++ b/weka-package/.classpath @@ -13,7 +13,7 @@ - + @@ -23,35 +23,5 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/weka-package/.project b/weka-package/.project index 0b9b9e1ed..f608fce93 100644 --- a/weka-package/.project +++ b/weka-package/.project @@ -20,15 +20,4 @@ org.eclipse.jdt.core.javanature org.eclipse.m2e.core.maven2Nature - - - 
1722078006202 - - 30 - - org.eclipse.core.resources.regexFilterMatcher - node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__ - - - diff --git a/weka-package/.settings/org.eclipse.jdt.apt.core.prefs b/weka-package/.settings/org.eclipse.jdt.apt.core.prefs deleted file mode 100644 index d4313d4b2..000000000 --- a/weka-package/.settings/org.eclipse.jdt.apt.core.prefs +++ /dev/null @@ -1,2 +0,0 @@ -eclipse.preferences.version=1 -org.eclipse.jdt.apt.aptEnabled=false diff --git a/weka-package/.settings/org.eclipse.jdt.core.prefs b/weka-package/.settings/org.eclipse.jdt.core.prefs index 46235dc07..51664a5f5 100644 --- a/weka-package/.settings/org.eclipse.jdt.core.prefs +++ b/weka-package/.settings/org.eclipse.jdt.core.prefs @@ -1,9 +1,6 @@ eclipse.preferences.version=1 -org.eclipse.jdt.core.compiler.codegen.targetPlatform=11 -org.eclipse.jdt.core.compiler.compliance=11 -org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled +org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8 +org.eclipse.jdt.core.compiler.compliance=1.8 org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning -org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore -org.eclipse.jdt.core.compiler.processAnnotations=disabled -org.eclipse.jdt.core.compiler.release=disabled -org.eclipse.jdt.core.compiler.source=11 +org.eclipse.jdt.core.compiler.release=enabled +org.eclipse.jdt.core.compiler.source=1.8 From 25f446478ef317952a98eda252a65eb7d7839070 Mon Sep 17 00:00:00 2001 From: Spencer Sun Date: Tue, 11 Jun 2024 14:23:11 +1200 Subject: [PATCH 15/31] fix: fix the instance index for window regression evaluation --- .../moa/evaluation/WindowRegressionPerformanceEvaluator.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/moa/src/main/java/moa/evaluation/WindowRegressionPerformanceEvaluator.java b/moa/src/main/java/moa/evaluation/WindowRegressionPerformanceEvaluator.java index eb9745a08..23e977598 100644 --- 
a/moa/src/main/java/moa/evaluation/WindowRegressionPerformanceEvaluator.java +++ b/moa/src/main/java/moa/evaluation/WindowRegressionPerformanceEvaluator.java @@ -175,6 +175,7 @@ public double getCoefficientOfDetermination() { return 0.0; } + public double getAdjustedCoefficientOfDetermination() { return 1 - ((1-getCoefficientOfDetermination())*(getTotalWeightObserved() - 1)) / (getTotalWeightObserved() - numAttributes - 1); @@ -197,6 +198,7 @@ private double getRelativeSquareError() { } public double getTotalWeightObserved() { +// return this.weightObserved.total(); return this.TotalweightObserved; } From 1c0dcc6e38cadd61f70e92d6ed069b1f7c68c962 Mon Sep 17 00:00:00 2001 From: DwayneAcosta Date: Mon, 29 Jul 2024 10:30:08 +1200 Subject: [PATCH 16/31] Commented out Random r variable --- moa/src/main/java/moa/classifiers/lazy/RW_kNN.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java b/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java index 7b93a91b6..4d9601502 100755 --- a/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java +++ b/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java @@ -90,7 +90,7 @@ public void resetLearningImpl() { } public void trainOnInstanceImpl(Instance inst) { - Random r = new Random(); + // Random r = new Random(); if (inst.classValue() > (double)this.C) this.C = (int)inst.classValue(); From 1303e68dbdc4512aea02e0ad0d7e9482e2d11375 Mon Sep 17 00:00:00 2001 From: nuwangunasekara Date: Tue, 6 Aug 2024 09:42:15 +0900 Subject: [PATCH 17/31] Initial working version of PEARL --- .../main/java/moa/classifiers/meta/PEARL.java | 1220 +++++++++++++++++ 1 file changed, 1220 insertions(+) create mode 100644 moa/src/main/java/moa/classifiers/meta/PEARL.java diff --git a/moa/src/main/java/moa/classifiers/meta/PEARL.java b/moa/src/main/java/moa/classifiers/meta/PEARL.java new file mode 100644 index 000000000..dea321ee0 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/PEARL.java @@ 
-0,0 +1,1220 @@ +/* + * PEARL.java + * + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package moa.classifiers.meta; + +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.FloatOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.MultiChoiceOption; +import com.yahoo.labs.samoa.instances.Instance; +import moa.AbstractMOAObject; +import moa.capabilities.CapabilitiesHandler; +import moa.capabilities.Capability; +import moa.capabilities.ImmutableCapabilities; +import moa.classifiers.AbstractClassifier; +import moa.classifiers.MultiClassClassifier; +import moa.classifiers.core.driftdetection.ChangeDetector; +import moa.classifiers.trees.ARFHoeffdingTree; +import moa.core.*; +import moa.evaluation.BasicClassificationPerformanceEvaluator; +import moa.options.ClassOption; + +import java.util.*; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + + +/** + * PEARL + * + *

Parameters:

    + *
  • -l : Classifier to train. Must be set to ARFHoeffdingTree
  • + *
  • -s : The number of trees in the ensemble
  • + *
  • -o : How the number of features is interpreted (4 options): + * "Specified m (integer value)", "sqrt(M)+1", "M-(sqrt(M)+1)"
  • + *
  • -m : Number of features allowed considered for each split. Negative + * values corresponds to M - m
  • + *
  • -a : The lambda value for bagging (lambda=6 corresponds to levBag)
  • + *
  • -j : Number of threads to be used for training
  • + *
  • -x : Change detector for drifts and its parameters
  • + *
  • -p : Change detector for warnings (start training bkg learner)
  • + *
  • -w : Should use weighted voting?
  • + *
  • -u : Should use drift detection? If disabled then bkg learner is also disabled
  • + *
  • -q : Should use bkg learner? If disabled then reset tree immediately
  • + *
+ * + * @version $Revision: 1 $ + */ +public class PEARL extends AbstractClassifier implements MultiClassClassifier, + CapabilitiesHandler { + + @Override + public String getPurposeString() { + return "PEARL framework for evolving data streams from Wu et al."; + } + + private static final long serialVersionUID = 1L; + + public ClassOption treeLearnerOption = new ClassOption("treeLearner", 'l', + "PEARL Tree.", ARFHoeffdingTree.class, + "ARFHoeffdingTree -e 2000000 -g 50 -c 0.01"); + + public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', + "The number of trees.", 10, 1, Integer.MAX_VALUE); + + public MultiChoiceOption mFeaturesModeOption = new MultiChoiceOption("mFeaturesMode", 'o', + "Defines how m, defined by mFeaturesPerTreeSize, is interpreted. M represents the total number of features.", + new String[]{"Specified m (integer value)", "sqrt(M)+1", "M-(sqrt(M)+1)", + "Percentage (M * (m / 100))"}, + new String[]{"SpecifiedM", "SqrtM1", "MSqrtM1", "Percentage"}, 1); + + public IntOption mFeaturesPerTreeSizeOption = new IntOption("mFeaturesPerTreeSize", 'm', + "Number of features allowed considered for each split. 
Negative values corresponds to M - m", 2, Integer.MIN_VALUE, Integer.MAX_VALUE); + + public FloatOption lambdaOption = new FloatOption("lambda", 'a', + "The lambda parameter for bagging.", 6.0, 1.0, Float.MAX_VALUE); + + public IntOption numberOfJobsOption = new IntOption("numberOfJobs", 'j', + "Total number of concurrent jobs used for processing (-1 = as much as possible, 0 = do not use multithreading)", 1, -1, Integer.MAX_VALUE); + + public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'x', + "Change detector for drifts and its parameters", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-5"); + + public ClassOption warningDetectionMethodOption = new ClassOption("warningDetectionMethod", 'p', + "Change detector for warnings (start training bkg learner)", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-4"); + + public FlagOption disableWeightedVote = new FlagOption("disableWeightedVote", 'w', + "Should use weighted voting?"); + + public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'u', + "Should use drift detection? If disabled then bkg learner is also disabled"); + + public FlagOption disableBackgroundLearnerOption = new FlagOption("disableBackgroundLearner", 'q', + "Should use bkg learner? 
If disabled then reset tree immediately."); + + public IntOption treeRepoSizeOption = new IntOption("treeRepoSize", 'r', + "The number of trees in tree pool.", 100000, 60, Integer.MAX_VALUE); + + public IntOption candidatePoolSizeOption = new IntOption("candidatePoolSize", 'c', + "The number of candidate trees.", 120, 60, Integer.MAX_VALUE); + + public FloatOption cdKappaOption = new FloatOption("cdKappa", 'k', + "The kappa parameter for tree swapping.", 0.0, 0.0, Float.MAX_VALUE); + + public IntOption editDistanceOption = new IntOption("editDistance", 'e', + "The edit distance parameter for tree swapping", 100, 60, 120); + + public IntOption lruQueueSize = new IntOption("lruQueueSize", 'f', + "The size of LRU state queue", 10000000, 100, Integer.MAX_VALUE); + + public IntOption performanceEvalWindowSize = new IntOption("performanceEvalWindowSize", 'z', + "The window size for tracking candidate trees' performance",50, 1, Integer.MAX_VALUE); + + public FlagOption enableStateGraph = new FlagOption("enableStateGraph", 'g', + "Is lossy state graph enabled"); + + public IntOption lossyWindowSizeSizeOption = new IntOption("lossyWindowSize", 'd', + "The number of trees in tree pool.", 100000, 100, Integer.MAX_VALUE); + + public FloatOption candidateTreeReuseRate = new FloatOption("candidateTreeReuseRate", 'v', + "The kappa parameter for tree swapping.", 1.0, 0.0, 1.0); + + public IntOption reuseWindowSizeOption = new IntOption("reuseWindowSize", 'y', + "The reuse window size",100000, 1, Integer.MAX_VALUE); + + protected static final int FEATURES_M = 0; + protected static final int FEATURES_SQRT = 1; + protected static final int FEATURES_SQRT_INV = 2; + protected static final int FEATURES_PERCENT = 3; + + protected static final int SINGLE_THREAD = 0; + + protected ARFBaseLearner[] ensemble; + protected long instancesSeen; + protected int subspaceSize; + protected BasicClassificationPerformanceEvaluator evaluator; + + private ExecutorService executor; + + // PEARL data 
structures + protected ArrayList treePool = new ArrayList<>(); + protected ArrayList candidateTrees = new ArrayList(); + protected ArrayList actualLabels = new ArrayList<>(); + protected static LRUState stateQueue; + protected SortedSet curState = new TreeSet<>(); + protected LossyStateGraph stateGraph; + protected StateGraphSwitch graphSwitch; + +// @Override +// public int getClassifierPoolSize() { +// return treePool.size(); +// } + + public int getPredictedClass(Instance inst) { + return Utils.maxIndex(getVotesForInstance(inst)); + } + + @Override + public void resetLearningImpl() { + // Reset attributes + this.ensemble = null; + this.subspaceSize = 0; + this.instancesSeen = 0; + this.evaluator = new BasicClassificationPerformanceEvaluator(); + + // Multi-threading + int numberOfJobs; + if (this.numberOfJobsOption.getValue() == -1) + numberOfJobs = Runtime.getRuntime().availableProcessors(); + else + numberOfJobs = this.numberOfJobsOption.getValue(); + // SINGLE_THREAD and requesting for only 1 thread are equivalent. + // this.executor will be null and not used... 
+ if(numberOfJobs != PEARL.SINGLE_THREAD && numberOfJobs != 1) + this.executor = Executors.newFixedThreadPool(numberOfJobs); + } + + @Override + public void trainOnInstanceImpl(Instance instance) { + ++this.instancesSeen; + if(this.ensemble == null) + initEnsemble(instance); + + ArrayList warningTreePosList = new ArrayList<>(); + ArrayList driftedTreePosList = new ArrayList<>(); + + if (this.actualLabels.size() >= this.performanceEvalWindowSize.getValue()) { + this.actualLabels.remove(0); + } + this.actualLabels.add((int) instance.classValue()); + + Collection trainers = new ArrayList(); + for (int i = 0 ; i < this.ensemble.length ; i++) { + DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(instance)); + InstanceExample example = new InstanceExample(instance); + this.ensemble[i].evaluator.addResult(example, vote.getArrayRef()); + int k = MiscUtils.poisson(this.lambdaOption.getValue(), this.classifierRandom); + if (k > 0) { + if (this.executor != null) { + TrainingRunnable trainer = new TrainingRunnable(this.ensemble[i], + instance, k, this.instancesSeen); + trainers.add(trainer); + } else { // SINGLE_THREAD is in-place... 
+ DriftInfo driftInfo = this.ensemble[i].trainOnInstance(instance, k, this.instancesSeen); + boolean warningDetectedOnly = false; + if (driftInfo.warningDetected) { + warningDetectedOnly = true; + } + if (driftInfo.driftDetected) { + warningDetectedOnly = false; + driftedTreePosList.add(i); + } + if (warningDetectedOnly) { + warningTreePosList.add(i); + } + } + + } + } + + for (ARFBaseLearner tree : this.candidateTrees) { + // candidateTrees performs predictions to keep track of performance + if (tree.predictedLabelsWindow.size() >= performanceEvalWindowSize.getValue()) { + tree.predictedLabelsWindow.remove(0); + } + tree.predictedLabelsWindow.add(getPredictedClass(instance)); + } + + if (warningTreePosList.size() > 0) { + selectCandidateTrees(warningTreePosList); + } + + if (driftedTreePosList.size() > 0) { + adaptState(instance, driftedTreePosList); + } + + if (this.executor != null) { + try { + this.executor.invokeAll(trainers); + } catch (InterruptedException ex) { + throw new RuntimeException("Could not call invokeAll() on training threads."); + } + } + } + + private void selectCandidateTrees(ArrayList warningTreePosList) { + if (this.enableStateGraph.isSet()) { + // try trigger lossy counting + if (this.stateGraph.update(warningTreePosList.size())) { + // TODO log + } + } + + // add selected neighbors as candidate trees if graph is stable + if (this.stateGraph.getIsStable()) { + treeTransition(warningTreePosList); + } + + // trigger pattern matching if graph has become unstable + if (!this.stateGraph.getIsStable()) { + patternMatchCandidateTrees(warningTreePosList); + + } else { + // TODO log + } + } + + void patternMatchCandidateTrees(ArrayList warningTreePosList) { + + Set ids_to_exclude = new HashSet<>(); + + for (int tree_pos : warningTreePosList) { + ARFBaseLearner curTree = this.ensemble[tree_pos]; + + if (curTree.treePoolId == -1) { + System.out.println("Error: tree_pool_id is not updated"); + System.exit(1); + } + + 
ids_to_exclude.add(curTree.treePoolId); + } + + Set closestState = this.stateQueue.getClosestState(curState, ids_to_exclude); + + if (closestState.size() == 0) { + return; + } + + for (int i : closestState) { + if ( (i < this.curState.size()) && (i < this.treePool.size()) && !this.curState.contains(i) && !this.treePool.get(i).isCandidate) { + + if (this.candidateTrees.size() >= this.candidatePoolSizeOption.getValue()) { + this.candidateTrees.get(0).isCandidate = false; + this.candidateTrees.remove(0); + } + + this.treePool.get(i).isCandidate = true; + this.candidateTrees.add(treePool.get(i)); + } + } + } + + private void treeTransition(ArrayList warningTreePosList) { + ARFBaseLearner cur_tree; + for (int warning_tree_pos : warningTreePosList) { + cur_tree = ensemble[warning_tree_pos]; + + int warning_tree_id = cur_tree.treePoolId; + int next_id = stateGraph.get_next_tree_id(warning_tree_id); + + if (next_id == -1) { + stateGraph.set_is_stable(false); + } else { + if (!treePool.get(next_id).isCandidate) { + // TODO + if (candidateTrees.size() >= candidatePoolSizeOption.getValue()) { + candidateTrees.get(0).isCandidate = false; + candidateTrees.remove(0); + } + treePool.get(next_id).isCandidate = true; + candidateTrees.add(treePool.get(next_id)); + } + } + } + } + + private void adaptState(Instance instance, ArrayList driftedTreePosList) { + int class_count = instance.numClasses(); + + // sort candidate trees by kappa + for (ARFBaseLearner candidateTree: this.candidateTrees) { + candidateTree.updateKappa(this.actualLabels, class_count); + } + Collections.sort(this.candidateTrees, + (tree1, tree2) -> Double.compare(tree1.kappa, tree2.kappa)); + // TODO validate sorting order + + for (int i = 0; i < driftedTreePosList.size(); i++) { + // TODO + if (this.treePool.size() >= this.treeRepoSizeOption.getValue()) { + System.out.println("tree_pool full: " + this.treePool.size()); + System.exit(1); + } + + int drifted_pos = driftedTreePosList.get(i); + ARFBaseLearner 
drifted_tree = this.ensemble[drifted_pos]; + ARFBaseLearner swap_tree = null; + + drifted_tree.updateKappa(actualLabels, class_count); + + boolean add_to_repo = false; + + if (candidateTrees.size() > 0) { + ARFBaseLearner bestCandidate = candidateTrees.get(candidateTrees.size() - 1); + if (drifted_tree.isEvalReady + && bestCandidate.isEvalReady + && bestCandidate.kappa - drifted_tree.kappa >= cdKappaOption.getValue()) { + + bestCandidate.isCandidate = false; + swap_tree = bestCandidate; + candidateTrees.remove(candidateTrees.size() - 1); + + if (this.enableStateGraph.isSet()) { + graphSwitch.update_reuse_count(1); + } + } + } + + if (swap_tree == null) { + add_to_repo = true; + + if (this.enableStateGraph.isSet()) { + graphSwitch.update_reuse_count(0); + } + + ARFBaseLearner bkgLearner = drifted_tree.bkgLearner; + + if (bkgLearner == null) { + swap_tree = drifted_tree.makeTree(treePool.size()); + + } else { + bkgLearner.updateKappa(actualLabels, class_count); + + if (!bkgLearner.isEvalReady || !drifted_tree.isEvalReady + || bkgLearner.kappa - drifted_tree.kappa >= 0.0) { + // TODO 0.0: bg_kappa_threshold + swap_tree = bkgLearner; + + } else { + // bg tree is a false positive + add_to_repo = false; + + } + } + + if (add_to_repo) { + swap_tree.reset(false); + + // assign a new tree_pool_id for background tree + // and allocate a slot for background tree in tree_pool + swap_tree.treePoolId = treePool.size(); + treePool.add(swap_tree); + + } + } + + if (swap_tree != null) { + // update current state pattern + curState.remove(drifted_tree.treePoolId); + curState.add(swap_tree.treePoolId); + + // replace drifted_tree with swap tree + swap_tree.isBackgroundLearner = false; + ensemble[drifted_pos] = swap_tree; + + if (this.enableStateGraph.isSet()) { + stateGraph.add_edge(drifted_tree.treePoolId, swap_tree.treePoolId); + } + + } + + drifted_tree.reset(false); + } + + this.stateQueue.enqueue(new TreeSet<>(curState)); + + if (this.enableStateGraph.isSet()) { + 
this.graphSwitch.update_switch(); + } + } + + @Override + public double[] getVotesForInstance(Instance instance) { + Instance testInstance = instance.copy(); + if(this.ensemble == null) + initEnsemble(testInstance); + DoubleVector combinedVote = new DoubleVector(); + + for(int i = 0 ; i < this.ensemble.length ; ++i) { + DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(testInstance)); + if (vote.sumOfValues() > 0.0) { + vote.normalize(); + double acc = this.ensemble[i].evaluator.getPerformanceMeasurements()[1].getValue(); + if(! this.disableWeightedVote.isSet() && acc > 0.0) { + for(int v = 0 ; v < vote.numValues() ; ++v) { + vote.setValue(v, vote.getValue(v) * acc); + } + } + combinedVote.addValues(vote); + } + } + return combinedVote.getArrayRef(); + } + + @Override + public boolean isRandomizable() { + return true; + } + + @Override + public void getModelDescription(StringBuilder arg0, int arg1) { + } + + @Override + protected Measurement[] getModelMeasurementsImpl() { + return null; + } + + protected void initEnsemble(Instance instance) { + // Init the ensemble. 
+ int ensembleSize = this.ensembleSizeOption.getValue(); + this.ensemble = new ARFBaseLearner[ensembleSize]; + + // TODO: this should be an option with default = BasicClassificationPerformanceEvaluator +// BasicClassificationPerformanceEvaluator classificationEvaluator = (BasicClassificationPerformanceEvaluator) getPreparedClassOption(this.evaluatorOption); + BasicClassificationPerformanceEvaluator classificationEvaluator = new BasicClassificationPerformanceEvaluator(); + + this.subspaceSize = this.mFeaturesPerTreeSizeOption.getValue(); + + // The size of m depends on: + // 1) mFeaturesPerTreeSizeOption + // 2) mFeaturesModeOption + int n = instance.numAttributes()-1; // Ignore class label ( -1 ) + + switch(this.mFeaturesModeOption.getChosenIndex()) { + case PEARL.FEATURES_SQRT: + this.subspaceSize = (int) Math.round(Math.sqrt(n)) + 1; + break; + case PEARL.FEATURES_SQRT_INV: + this.subspaceSize = n - (int) Math.round(Math.sqrt(n) + 1); + break; + case PEARL.FEATURES_PERCENT: + // If subspaceSize is negative, then first find out the actual percent, i.e., 100% - m. + double percent = this.subspaceSize < 0 ? (100 + this.subspaceSize)/100.0 : this.subspaceSize / 100.0; + this.subspaceSize = (int) Math.round(n * percent); + break; + } + // Notice that if the selected mFeaturesModeOption was + // AdaptiveRandomForest.FEATURES_M then nothing is performed in the + // previous switch-case, still it is necessary to check (and adjusted) + // for when a negative value was used. + + // m is negative, use size(features) + -m + if(this.subspaceSize < 0) + this.subspaceSize = n + this.subspaceSize; + // Other sanity checks to avoid runtime errors. 
+ // m <= 0 (m can be negative if this.subspace was negative and + // abs(m) > n), then use m = 1 + if(this.subspaceSize <= 0) + this.subspaceSize = 1; + // m > n, then it should use n + if(this.subspaceSize > n) + this.subspaceSize = n; + + this.stateQueue = new LRUState(lruQueueSize.getValue(), this.editDistanceOption.getValue()); + + this.stateGraph = new LossyStateGraph(this.treeRepoSizeOption.getValue(), + this.lossyWindowSizeSizeOption.getValue()); + this.graphSwitch = new StateGraphSwitch(this.stateGraph, + this.reuseWindowSizeOption.getValue(), + this.candidateTreeReuseRate.getValue()); + + ARFHoeffdingTree treeLearner = (ARFHoeffdingTree) getPreparedClassOption(this.treeLearnerOption); + treeLearner.resetLearning(); + + for (int i = 0; i < ensembleSize; ++i) { + // treeLearner.setRandomSeed(this.classifierRandom.nextInt()); + // treeLearner.resetClassifierRandom(); + + treeLearner.subspaceSizeOption.setValue(this.subspaceSize); + this.ensemble[i] = new ARFBaseLearner( + i, + (ARFHoeffdingTree) treeLearner.copy(), + (BasicClassificationPerformanceEvaluator) classificationEvaluator.copy(), + this.instancesSeen, + ! this.disableBackgroundLearnerOption.isSet(), + ! this.disableDriftDetectionOption.isSet(), + driftDetectionMethodOption, + warningDetectionMethodOption, + false); + + this.treePool.add(this.ensemble[i]); + this.curState.add(i); + this.ensemble[i].treePoolId = i; + } + + this.stateQueue.enqueue(new TreeSet<>(curState)); + } + + @Override + public ImmutableCapabilities defineImmutableCapabilities() { + if (this.getClass() == PEARL.class) + return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE); + else + return new ImmutableCapabilities(Capability.VIEW_STANDARD); + } + + /** + * Inner class that represents a single tree member of the forest. 
+ * It contains some analysis information, such as the numberOfDriftsDetected, + */ + protected final class ARFBaseLearner extends AbstractMOAObject { + public int indexOriginal; + public long createdOn; + public long lastDriftOn; + public long lastWarningOn; + public ARFHoeffdingTree classifier; + public boolean isBackgroundLearner; + + // The drift and warning object parameters. + protected ClassOption driftOption; + protected ClassOption warningOption; + + // Drift and warning detection + protected ChangeDetector driftDetectionMethod; + protected ChangeDetector warningDetectionMethod; + + public boolean useBkgLearner; + public boolean useDriftDetector; + + // Bkg learner + protected ARFBaseLearner bkgLearner; + // Statistics + public BasicClassificationPerformanceEvaluator evaluator; + protected int numberOfDriftsDetected; + protected int numberOfWarningsDetected; + + // PEARL specific + protected int treePoolId; + protected double kappa; + protected boolean isCandidate; + protected boolean isEvalReady; + protected ArrayList predictedLabelsWindow; + + private void init(int indexOriginal, ARFHoeffdingTree instantiatedClassifier, BasicClassificationPerformanceEvaluator evaluatorInstantiated, + long instancesSeen, boolean useBkgLearner, boolean useDriftDetector, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) { + this.indexOriginal = indexOriginal; + this.createdOn = instancesSeen; + this.lastDriftOn = 0; + this.lastWarningOn = 0; + + this.classifier = instantiatedClassifier; + this.evaluator = evaluatorInstantiated; + this.useBkgLearner = useBkgLearner; + this.useDriftDetector = useDriftDetector; + + this.numberOfDriftsDetected = 0; + this.numberOfWarningsDetected = 0; + this.isBackgroundLearner = isBackgroundLearner; + + if(this.useDriftDetector) { + this.driftOption = driftOption; + this.driftDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.driftOption)).copy(); + } + + // Init Drift Detector for Warning detection. 
+ if(this.useBkgLearner) { + this.warningOption = warningOption; + this.warningDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.warningOption)).copy(); + } + + this.treePoolId = -1; + this.kappa = Integer.MIN_VALUE; + this.isCandidate = false; + this.isEvalReady = false; + this.predictedLabelsWindow = new ArrayList<>(); + } + + public ARFBaseLearner(int indexOriginal, ARFHoeffdingTree instantiatedClassifier, BasicClassificationPerformanceEvaluator evaluatorInstantiated, + long instancesSeen, boolean useBkgLearner, boolean useDriftDetector, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) { + init(indexOriginal, instantiatedClassifier, evaluatorInstantiated, instancesSeen, useBkgLearner, useDriftDetector, driftOption, warningOption, isBackgroundLearner); + } + + public void reset(boolean keepDriftDetectors) { + if (!keepDriftDetectors) { + this.driftDetectionMethod.resetLearning(); + this.warningDetectionMethod.resetLearning(); + } + reset(); + } + + public void reset() { + this.bkgLearner = null; + this.evaluator.reset(); + + this.isCandidate = false; + this.predictedLabelsWindow.clear(); + this.kappa = Integer.MIN_VALUE; + this.isEvalReady = false; + } + + public ARFBaseLearner makeTree(int treeId) { + ARFHoeffdingTree bkgClassifier = (ARFHoeffdingTree) this.classifier.copy(); + bkgClassifier.resetLearning(); + // bkgClassifier.setRandomSeed(this.classifier.classifierRandom.nextInt()); + // bkgClassifier.resetClassifierRandom(); + BasicClassificationPerformanceEvaluator bkgEvaluator = (BasicClassificationPerformanceEvaluator) this.evaluator.copy(); + bkgEvaluator.reset(); + ARFBaseLearner newTree = new ARFBaseLearner(this.indexOriginal, bkgClassifier, bkgEvaluator, instancesSeen, + this.useBkgLearner, this.useDriftDetector, this.driftOption, this.warningOption, false); + newTree.treePoolId = treeId; + return newTree; + } + + public DriftInfo trainOnInstance(Instance instance, double weight, long instancesSeen) { + 
Instance weightedInstance = (Instance) instance.copy(); + weightedInstance.setWeight(instance.weight() * weight); + this.classifier.trainOnInstance(weightedInstance); + + // train bg tree and track its performance + if (this.bkgLearner != null) { + this.bkgLearner.classifier.trainOnInstance(instance); + int prediction = getPredictedClass(instance); + if (this.bkgLearner.predictedLabelsWindow.size() >= performanceEvalWindowSize.getValue()) { + this.bkgLearner.predictedLabelsWindow.remove(0); + } + this.bkgLearner.predictedLabelsWindow.add(prediction); + } + + boolean warningDetected = false; + boolean driftDetected = false; + + // Should it use a drift detector? Also, is it a backgroundLearner? If so, then do not "incept" another one. + if (this.useDriftDetector && !this.isBackgroundLearner) { + // boolean correctlyClassifies = this.classifier.correctlyClassifies(instance); + int prediction = Utils.maxIndex(getVotesForInstance(instance)); + boolean correctlyClassifies = prediction == (int) instance.classValue(); + + if (this.predictedLabelsWindow.size() >= performanceEvalWindowSize.getValue()) { + this.predictedLabelsWindow.remove(0); + } + this.predictedLabelsWindow.add(prediction); + + // Check for warning only if useBkgLearner is active + if (this.useBkgLearner) { + // Update the warning detection method + this.warningDetectionMethod.input(correctlyClassifies ? 
0 : 1); + // Check if there was a change + if (this.warningDetectionMethod.getChange()) { + warningDetected = true; + this.lastWarningOn = instancesSeen; + this.numberOfWarningsDetected++; + // Create a new bkgTree classifier + ARFHoeffdingTree bkgClassifier = (ARFHoeffdingTree) this.classifier.copy(); + bkgClassifier.resetLearning(); + + // Resets the evaluator + BasicClassificationPerformanceEvaluator bkgEvaluator = (BasicClassificationPerformanceEvaluator) this.evaluator.copy(); + bkgEvaluator.reset(); + + // Create a new bkgLearner object + this.bkgLearner = new ARFBaseLearner(indexOriginal, bkgClassifier, bkgEvaluator, instancesSeen, + this.useBkgLearner, this.useDriftDetector, this.driftOption, this.warningOption, true); + // this.bkgLearner.classifier.setRandomSeed(this.classifier.classifierRandom.nextInt()); + // this.bkgLearner.classifier.resetClassifierRandom(); + + // Update the warning detection object for the current object + // (this effectively resets changes made to the object while it was still a bkg learner). + this.warningDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.warningOption)).copy(); + } + } + + /*********** drift detection ***********/ + + // Update the DRIFT detection method + this.driftDetectionMethod.input(correctlyClassifies ? 
0 : 1); + // Check if there was a change + if (this.driftDetectionMethod.getChange()) { + warningDetected = false; + driftDetected = true; + this.lastDriftOn = instancesSeen; + this.numberOfDriftsDetected++; + // this.reset(); + } + } + + return new DriftInfo(warningDetected, driftDetected, getPredictedClass(instance)); + } + + public void updateKappa(ArrayList actualLabels, int classCount) { + if (predictedLabelsWindow.size() < performanceEvalWindowSize.getValue() + || actualLabels.size() < performanceEvalWindowSize.getValue()) { + this.isEvalReady = false; + return; + } + + this.isEvalReady = true; + + int[][] confusionMatrix = new int[classCount][classCount]; + int correct = 0; + + for (int i = 0; i < performanceEvalWindowSize.getValue(); i++) { + confusionMatrix[actualLabels.get(i)][predictedLabelsWindow.get(i)]++; + if (actualLabels.get(i) == predictedLabelsWindow.get(i)) { + correct++; + } + } + + double accuracy = (double) correct / performanceEvalWindowSize.getValue(); + this.kappa = computeKappa(confusionMatrix, accuracy, performanceEvalWindowSize.getValue(), classCount); + } + + private double computeKappa(int[][] confusionMatrix, double accuracy, int sample_count, int classCount) { + // computes the Cohen's kappa coefficient + double p0 = accuracy; + double pc = 0.0; + int row_count = classCount; + int col_count = classCount; + + for (int i = 0; i < row_count; i++) { + double row_sum = 0; + for (int j = 0; j < col_count; j++) { + row_sum += confusionMatrix[i][j]; + } + + double col_sum = 0; + for (int j = 0; j < row_count; j++) { + col_sum += confusionMatrix[j][i]; + } + + pc += (row_sum / sample_count) * (col_sum / sample_count); + } + + if (pc == 1) { + return 1; + } + + return (p0 - pc) / (1.0 - pc); + } + + public double[] getVotesForInstance(Instance instance) { + DoubleVector vote = new DoubleVector(this.classifier.getVotesForInstance(instance)); + return vote.getArrayRef(); + } + + @Override + public void getDescription(StringBuilder sb, int 
indent) { + + } + + // @Override + // public void getDescription(StringBuilder sb, int indent) { + // } + } + + /*** + * Inner class to assist with the multi-thread execution. + */ + protected class TrainingRunnable implements Runnable, Callable { + final private ARFBaseLearner learner; + final private Instance instance; + final private double weight; + final private long instancesSeen; + + public TrainingRunnable(ARFBaseLearner learner, Instance instance, + double weight, long instancesSeen) { + this.learner = learner; + this.instance = instance; + this.weight = weight; + this.instancesSeen = instancesSeen; + } + + @Override + public void run() { + learner.trainOnInstance(this.instance, this.weight, this.instancesSeen); + } + + @Override + public Integer call() throws Exception { + run(); + return 0; + } + } + + protected class LRUState extends AbstractMOAObject { + + @Override + public void getDescription(StringBuilder sb, int indent) { + + } + + class State extends AbstractMOAObject{ + SortedSet pattern; + public int freq; + + State(SortedSet pattern, int freq) { + this.pattern = pattern; + this.freq = freq; + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + + } + } + + int capacity; + int editDistanceThreshold; + LinkedHashMap map; + + protected LRUState(int capacity, int editDistanceThreshold) { + this.capacity = capacity; + this.editDistanceThreshold = editDistanceThreshold; + this.map = new LinkedHashMap(capacity, 0.75f, true){ + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > capacity; + } + }; + } + + Set getClosestState(Set targetPattern, Set idsToExclude) { + int minEditDistance = Integer.MAX_VALUE; + int maxFreq = 0; + Set closestPattern = new HashSet<>(); + + // find the smallest edit distance + for (Map.Entry entry : map.entrySet()) { + State curState = entry.getValue(); + Set curPattern = curState.pattern; + + int curFreq = curState.freq; + int curEditDistance = 0; + + boolean updateFlag = true; 
+ for (int id : idsToExclude) { + if (curPattern.contains(id)) { + // tree with drift must be unset + updateFlag = false; + break; + } + } + + if (updateFlag) { + for (int id : targetPattern) { + if (!curPattern.contains(id)) { + curEditDistance += 2; + } + + if (curEditDistance > editDistanceThreshold + || curEditDistance > minEditDistance) { + updateFlag = false; + break; + } + } + } + + if (!updateFlag) { + continue; + } + + if (minEditDistance == curEditDistance && curFreq < maxFreq) { + continue; + } + + minEditDistance = curEditDistance; + maxFreq = curFreq; + closestPattern = curPattern; + } + + return closestPattern; + } + + public void enqueue(SortedSet pattern) { + String key = patternToKey(pattern); + + if (map.containsKey(key)) { + State state = map.get(key); + state.freq++; + + } else { + map.put(key, new State(pattern, 1)); + } + } + + String patternToKey(SortedSet pattern) { + StringBuilder sb = new StringBuilder(); + for (int i : pattern) { + sb.append(i); + sb.append(","); + } + + return sb.toString(); + } + + public String toString() { + String str = ""; + + for (Map.Entry entry : map.entrySet()) { + State s = entry.getValue(); + Set cur_pattern = s.pattern; + String freq = "" + s.freq; + + String delim = ""; + for (int i : cur_pattern) { + str += delim; + str += i; + delim = ","; + } + str += ":" + freq + "->"; + } + + return str; + } + } + + + class LossyStateGraph extends AbstractMOAObject{ + + @Override + public void getDescription(StringBuilder sb, int indent) { + + } + + class Node extends AbstractMOAObject{ + int indegree; + int total_weight; + Map neighbors; // + public Node() { + this.indegree = 0; + this.total_weight = 0; + this.neighbors = new HashMap<>(); + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + + } + }; + + Node[] graph; + int capacity; + int window_size; + Random mrand; + + int drifted_tree_counter = 0; + boolean is_stable = false; + + LossyStateGraph(int capacity, int window_size) { + 
this.capacity = capacity; + this.window_size = window_size; + this.mrand = new Random(); + + is_stable = false; + graph = new Node[capacity]; + } + + int get_next_tree_id(int src) { + if (graph[src] == null|| graph[src].total_weight == 0){ + return -1; + } + + int r = mrand.nextInt(graph[src].total_weight + 1); + int sum = 0; + + // weighted selection + for (Map.Entry nei : graph[src].neighbors.entrySet()){ + int treeId = nei.getKey(); + int freq = nei.getValue(); + sum += freq; + if (r < sum) { + nei.setValue(nei.getValue() + 1); + graph[src].total_weight++; + return treeId; + } + } + + return -1; + } + + boolean update(int warning_tree_count) { + drifted_tree_counter += warning_tree_count; + + if (drifted_tree_counter < window_size) { + return false; + } + + drifted_tree_counter -= window_size; + + // lossy count + for (int i = 0; i < graph.length; i++) { + if (graph[i] != null) { + continue; + } + + ArrayList keys_to_remove = new ArrayList<>(); + + for (Map.Entry nei : graph[i].neighbors.entrySet()){ + int treeId = nei.getKey(); + int freq = nei.getValue(); + + // decrement freq by 1 + graph[i].total_weight--; + nei.setValue(nei.getValue() - 1); // decrement freq + + if (freq == 0) { + // remove edge + graph[treeId].indegree--; + try_remove_node(treeId); + + keys_to_remove.add(treeId); + } + } + + for (int key : keys_to_remove){ + graph[i].neighbors.remove(key); + } + + try_remove_node(i); + } + + return true; + } + + void try_remove_node(int key) { + if (graph[key].indegree == 0 && graph[key].neighbors.size() == 0){ + graph[key] = null; + } + } + + void add_node(int key) { + if (key >= capacity) { + // System.out.println("id exceeded graph capacity"); + return; + } + + graph[key] = new Node(); + } + + void add_edge(int src, int dest) { + if (graph[src] == null) { + add_node(src); + } + + if (graph[dest] == null) { + add_node(dest); + } + + graph[src].total_weight++; + + if (!graph[src].neighbors.containsKey(dest)) { + graph[src].neighbors.put(dest, 0); + 
graph[dest].indegree++; + } + + graph[src].neighbors.put(dest, graph[src].neighbors.get(dest) + 1); + } + + void set_is_stable(boolean is_stable_) { + is_stable = is_stable_; + } + + boolean getIsStable() { + return is_stable; + } + + // String to_string() { + // stringstream ss; + // for (int i = 0; i < graph.size(); i++) { + // ss << i; + // if (!graph[i]) { + // ss << " {}" << endl; + // continue; + // } + + // ss << " w:" << std::to_string (graph[i]->total_weight) <<" {"; + // for (auto & nei :graph[i]->neighbors){ + // ss << std::to_string (nei.first) << ":" << std::to_string (nei.second) << " "; + // } + // ss << "}" << endl; + // } + + // return ss.str(); + // } + } + + class StateGraphSwitch extends AbstractMOAObject{ + + int window_size = 0; + int reused_tree_count = 0; + int total_tree_count = 0; + double reuse_rate = 1.0; + + LossyStateGraph state_graph; + ArrayDeque window; + + public StateGraphSwitch(LossyStateGraph state_graph, + int window_size, + double reuse_rate) { + this.state_graph = state_graph; + this.window_size = window_size; + this.reuse_rate = reuse_rate; + this.window = new ArrayDeque<>(); + } + + void update_reuse_count(int num_reused_trees) { + reused_tree_count += num_reused_trees; + total_tree_count++; + + if (window_size <= 0) { + return; + } + + if (window.size() >= window_size) { + reused_tree_count -= window.poll(); + } + + window.offer(num_reused_trees); + } + + void update_switch() { + double cur_reuse_rate = 0; + if (window_size <= 0) { + cur_reuse_rate = (double) reused_tree_count / total_tree_count; + } else { + cur_reuse_rate = (double) reused_tree_count / window_size; + } + + // cout << "reused_tree_count: " << to_string(reused_tree_count) << endl; + // cout << "total_tree_count: " << to_string(total_tree_count) << endl; + // cout << "cur_reuse_rate: " << to_string(cur_reuse_rate) << endl; + + if (cur_reuse_rate >= reuse_rate) { + state_graph.set_is_stable(true); + } else { + state_graph.set_is_stable(false); + } + } + + 
@Override + public void getDescription(StringBuilder sb, int indent) { + + } + } + + public class DriftInfo { + public boolean warningDetected; + public boolean driftDetected; + public int predictedClass; + public DriftInfo(boolean warningDetected, boolean driftDetected, int predictedClass) { + this.warningDetected = warningDetected; + this.driftDetected = driftDetected; + this.predictedClass = predictedClass; + } + } +} From b41578882a05f558830832e21f934ad41e835077 Mon Sep 17 00:00:00 2001 From: nuwangunasekara Date: Wed, 7 Aug 2024 07:02:20 +0900 Subject: [PATCH 18/31] set default parameters as per the paper --- .../main/java/moa/classifiers/meta/PEARL.java | 45 ++++++++++++++----- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/moa/src/main/java/moa/classifiers/meta/PEARL.java b/moa/src/main/java/moa/classifiers/meta/PEARL.java index dea321ee0..f6945f0eb 100644 --- a/moa/src/main/java/moa/classifiers/meta/PEARL.java +++ b/moa/src/main/java/moa/classifiers/meta/PEARL.java @@ -40,8 +40,19 @@ import java.util.concurrent.Executors; + /** - * PEARL + * Probabilistic Exact Adaptive Random Forest with Lossy Counting (PEARL) + * + *

Probabilistic Exact Adaptive Random Forest with Lossy Counting (PEARL), + * utilises an exact technique and a probabilistic technique to replace drifted + * trees in an ensemble with relevant trees from a repository of trees for + * learning from a data stream with recurrent concept drifts.

+ * + *

See details in:
Ocean Wu, Yun Sing Koh, Gillian Dobbie, Thomas Lacombe + * Probabilistic exact adaptive random forest for recurrent concepts in data streams. + * International Journal of Data Science and Analytics, 2022. + * DOI.

* *

Parameters:

    *
  • -l : Classifier to train. Must be set to ARFHoeffdingTree
  • @@ -57,8 +68,19 @@ *
  • -w : Should use weighted voting?
  • *
  • -u : Should use drift detection? If disabled then bkg learner is also disabled
  • *
  • -q : Should use bkg learner? If disabled then reset tree immediately
  • + *
  • -r : The number of trees in tree pool (default 100000)
  • + *
  • -c : The number of candidate trees. (default 120)
  • + *
  • -k : The kappa parameter for tree swapping. (default 0.0)
  • + *
  • -e : The edit distance parameter for tree swapping (default 100)
  • + *
  • -f : The size of LRU state queue (default 10000000)
  • + *
  • -z : The window size for tracking candidate trees' performance (default 50)
  • + *
  • -g : Enable lossy state graph
  • + *
  • -d : Lossy window size (default 100000)
  • + *
  • -v : Candidate tree reuse rate (default 1.0 [0.0 - 1.0])
  • + *
  • -y : The reuse window size (default 100000)
  • *
* + * BEST values for parameters are not available in the * @version $Revision: 1 $ */ public class PEARL extends AbstractClassifier implements MultiClassClassifier, @@ -70,22 +92,22 @@ public String getPurposeString() { } private static final long serialVersionUID = 1L; - +// ##### START of ARF parameters (unchanged for fair comparission with ARF) public ClassOption treeLearnerOption = new ClassOption("treeLearner", 'l', - "PEARL Tree.", ARFHoeffdingTree.class, + "Random Forest Tree.", ARFHoeffdingTree.class, "ARFHoeffdingTree -e 2000000 -g 50 -c 0.01"); public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', - "The number of trees.", 10, 1, Integer.MAX_VALUE); + "The number of trees.", 100, 1, Integer.MAX_VALUE); public MultiChoiceOption mFeaturesModeOption = new MultiChoiceOption("mFeaturesMode", 'o', "Defines how m, defined by mFeaturesPerTreeSize, is interpreted. M represents the total number of features.", new String[]{"Specified m (integer value)", "sqrt(M)+1", "M-(sqrt(M)+1)", "Percentage (M * (m / 100))"}, - new String[]{"SpecifiedM", "SqrtM1", "MSqrtM1", "Percentage"}, 1); + new String[]{"SpecifiedM", "SqrtM1", "MSqrtM1", "Percentage"}, 3); public IntOption mFeaturesPerTreeSizeOption = new IntOption("mFeaturesPerTreeSize", 'm', - "Number of features allowed considered for each split. Negative values corresponds to M - m", 2, Integer.MIN_VALUE, Integer.MAX_VALUE); + "Number of features allowed considered for each split. 
Negative values corresponds to M - m", 60, Integer.MIN_VALUE, Integer.MAX_VALUE); public FloatOption lambdaOption = new FloatOption("lambda", 'a', "The lambda parameter for bagging.", 6.0, 1.0, Float.MAX_VALUE); @@ -94,10 +116,10 @@ public String getPurposeString() { "Total number of concurrent jobs used for processing (-1 = as much as possible, 0 = do not use multithreading)", 1, -1, Integer.MAX_VALUE); public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'x', - "Change detector for drifts and its parameters", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-5"); + "Change detector for drifts and its parameters", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-3"); public ClassOption warningDetectionMethodOption = new ClassOption("warningDetectionMethod", 'p', - "Change detector for warnings (start training bkg learner)", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-4"); + "Change detector for warnings (start training bkg learner)", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-2"); public FlagOption disableWeightedVote = new FlagOption("disableWeightedVote", 'w', "Should use weighted voting?"); @@ -107,6 +129,7 @@ public String getPurposeString() { public FlagOption disableBackgroundLearnerOption = new FlagOption("disableBackgroundLearner", 'q', "Should use bkg learner? 
If disabled then reset tree immediately."); +// ##### END of ARF parameters (unchanged for fair comparission with ARF) public IntOption treeRepoSizeOption = new IntOption("treeRepoSize", 'r', "The number of trees in tree pool.", 100000, 60, Integer.MAX_VALUE); @@ -130,10 +153,12 @@ public String getPurposeString() { "Is lossy state graph enabled"); public IntOption lossyWindowSizeSizeOption = new IntOption("lossyWindowSize", 'd', - "The number of trees in tree pool.", 100000, 100, Integer.MAX_VALUE); + "Lossy window size", 100000, 100, Integer.MAX_VALUE); + // according to the paper though v does not have a significan effect on predictive capability of the model + // using the default used in the Agrawal experiments https://github.com/ingako/PEARL/blob/master/run/run-agrawal-3.sh public FloatOption candidateTreeReuseRate = new FloatOption("candidateTreeReuseRate", 'v', - "The kappa parameter for tree swapping.", 1.0, 0.0, 1.0); + "Candidate tree reuse rate.", 0.4, 0.0, 1.0); public IntOption reuseWindowSizeOption = new IntOption("reuseWindowSize", 'y', "The reuse window size",100000, 1, Integer.MAX_VALUE); From 04c7b3253888f858ad05e900f91ad58b9888df0e Mon Sep 17 00:00:00 2001 From: nuwangunasekara Date: Wed, 7 Aug 2024 07:38:18 +0900 Subject: [PATCH 19/31] Add PEARL tests --- .../java/moa/classifiers/meta/PEARLTest.java | 81 ++++++++++ .../resources/moa/classifiers/meta/PEARL.ref | 145 ++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 moa/src/test/java/moa/classifiers/meta/PEARLTest.java create mode 100644 moa/src/test/resources/moa/classifiers/meta/PEARL.ref diff --git a/moa/src/test/java/moa/classifiers/meta/PEARLTest.java b/moa/src/test/java/moa/classifiers/meta/PEARLTest.java new file mode 100644 index 000000000..836986cb9 --- /dev/null +++ b/moa/src/test/java/moa/classifiers/meta/PEARLTest.java @@ -0,0 +1,81 @@ +/* + * AdaptiveRandomForestTest.java + * + * This program is free software: you can redistribute it and/or modify + * it under 
the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +/** + * + */ +package moa.classifiers.meta; + +import junit.framework.Test; +import junit.framework.TestSuite; +import moa.classifiers.AbstractMultipleClassifierTestCase; +import moa.classifiers.Classifier; + +/** + * Tests the PEARL classifier. + * + * @author Nuwan Gunasekara (ng98 at students dot waikato dot ac dot nz) + * @version $Revision$ + */ +public class PEARLTest + extends AbstractMultipleClassifierTestCase { + + /** + * Constructs the test case. Called by subclasses. + * + * @param name the name of the test + */ + public PEARLTest(String name) { + super(name); + this.setNumberTests(1); + } + + /** + * Returns the classifier setups to use in the regression test. + * + * @return the setups + */ + @Override + protected Classifier[] getRegressionClassifierSetups() { + PEARL PEARLTest = new PEARL(); + PEARLTest.ensembleSizeOption.setValue(5); + PEARLTest.mFeaturesModeOption.setChosenIndex(0); + PEARLTest.mFeaturesPerTreeSizeOption.setValue(2); + + return new Classifier[]{ + PEARLTest, + }; + } + + /** + * Returns a test suite. + * + * @return the test suite + */ + public static Test suite() { + return new TestSuite(PEARLTest.class); + } + + /** + * Runs the test from commandline. 
+ * + * @param args ignored + */ + public static void main(String[] args) { + runTest(suite()); + } +} diff --git a/moa/src/test/resources/moa/classifiers/meta/PEARL.ref b/moa/src/test/resources/moa/classifiers/meta/PEARL.ref new file mode 100644 index 000000000..202a255c5 --- /dev/null +++ b/moa/src/test/resources/moa/classifiers/meta/PEARL.ref @@ -0,0 +1,145 @@ +--> classification-out0.arff +moa.classifiers.meta.PEARL -s 5 -o (Specified m (integer value)) -m 2 -x (ADWINChangeDetector -a 0.001) -p (ADWINChangeDetector -a 0.01) + +Index + 10000 +Votes + 0: 201.05038779 + 1: 71.46686394 +Measurements + classified instances: 9999 + classifications correct (percent): 76.07760776 + Kappa Statistic (percent): 49.28392856 + Kappa Temporal Statistic (percent): 49.57841484 + Kappa M Statistic (percent): 41.58730159 +Model measurements + model training instances: 9999 + +Index + 20000 +Votes + 0: 38.6256937 + 1: 312.13184418 +Measurements + classified instances: 19999 + classifications correct (percent): 78.18890945 + Kappa Statistic (percent): 54.27721317 + Kappa Temporal Statistic (percent): 54.46288757 + Kappa M Statistic (percent): 47.7103812 +Model measurements + model training instances: 19999 + +Index + 30000 +Votes + 0: 2.11253342 + 1: 354.36268242 +Measurements + classified instances: 29999 + classifications correct (percent): 79.79932664 + Kappa Statistic (percent): 57.69378087 + Kappa Temporal Statistic (percent): 58.14049872 + Kappa M Statistic (percent): 51.47341448 +Model measurements + model training instances: 29999 + +Index + 40000 +Votes + 0: 5.27561538 + 1: 281.57905599 +Measurements + classified instances: 39999 + classifications correct (percent): 80.73701843 + Kappa Statistic (percent): 59.76778227 + Kappa Temporal Statistic (percent): 60.18293628 + Kappa M Statistic (percent): 53.91746411 +Model measurements + model training instances: 39999 + +Index + 50000 +Votes + 0: 203.06408337 + 1: 86.63171055 +Measurements + classified instances: 49999 + 
classifications correct (percent): 81.39562791 + Kappa Statistic (percent): 61.17254121 + Kappa Temporal Statistic (percent): 61.49515688 + Kappa M Statistic (percent): 55.48430322 +Model measurements + model training instances: 49999 + +Index + 60000 +Votes + 0: 5.48964024 + 1: 360.37979092 +Measurements + classified instances: 59999 + classifications correct (percent): 81.91136519 + Kappa Statistic (percent): 62.34778415 + Kappa Temporal Statistic (percent): 62.59520937 + Kappa M Statistic (percent): 56.93595746 +Model measurements + model training instances: 59999 + +Index + 70000 +Votes + 0: 3.2953865 + 1: 364.58558323 +Measurements + classified instances: 69999 + classifications correct (percent): 82.35689081 + Kappa Statistic (percent): 63.32418526 + Kappa Temporal Statistic (percent): 63.6326158 + Kappa M Statistic (percent): 58.14126898 +Model measurements + model training instances: 69999 + +Index + 80000 +Votes + 0: 171.67957388 + 1: 197.86379541 +Measurements + classified instances: 79999 + classifications correct (percent): 82.71728397 + Kappa Statistic (percent): 64.09329532 + Kappa Temporal Statistic (percent): 64.41092435 + Kappa M Statistic (percent): 58.99276308 +Model measurements + model training instances: 79999 + +Index + 90000 +Votes + 0: 2.74400602 + 1: 368.37233971 +Measurements + classified instances: 89999 + classifications correct (percent): 83.05203391 + Kappa Statistic (percent): 64.80130245 + Kappa Temporal Statistic (percent): 65.08412499 + Kappa M Statistic (percent): 59.77584388 +Model measurements + model training instances: 89999 + +Index + 100000 +Votes + 0: 357.30819817 + 1: 15.06752559 +Measurements + classified instances: 99999 + classifications correct (percent): 83.31583316 + Kappa Statistic (percent): 65.37306693 + Kappa Temporal Statistic (percent): 65.63117996 + Kappa M Statistic (percent): 60.44758428 +Model measurements + model training instances: 99999 + + + From a9769a4d05a4e110adbfea747949cb4bb8e88858 Mon Sep 17 
00:00:00 2001 From: nuwangunasekara Date: Wed, 7 Aug 2024 07:44:22 +0900 Subject: [PATCH 20/31] Add PEARL tests --- .../java/moa/classifiers/meta/AdaptiveRandomForestTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moa/src/test/java/moa/classifiers/meta/AdaptiveRandomForestTest.java b/moa/src/test/java/moa/classifiers/meta/AdaptiveRandomForestTest.java index 351aa6d01..baa0f1c30 100644 --- a/moa/src/test/java/moa/classifiers/meta/AdaptiveRandomForestTest.java +++ b/moa/src/test/java/moa/classifiers/meta/AdaptiveRandomForestTest.java @@ -1,5 +1,5 @@ /* - * AdaptiveRandomForestTest.java + * PEARLTest.java * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by From 3617fc273b61c6a73578d30e6e14c783d0d517f6 Mon Sep 17 00:00:00 2001 From: nuwangunasekara Date: Wed, 7 Aug 2024 07:46:07 +0900 Subject: [PATCH 21/31] Revert changes ARF tests This reverts commit a2c690c3123c408123190ac8e61ca9542945aafc. 
--- .../java/moa/classifiers/meta/AdaptiveRandomForestTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moa/src/test/java/moa/classifiers/meta/AdaptiveRandomForestTest.java b/moa/src/test/java/moa/classifiers/meta/AdaptiveRandomForestTest.java index baa0f1c30..351aa6d01 100644 --- a/moa/src/test/java/moa/classifiers/meta/AdaptiveRandomForestTest.java +++ b/moa/src/test/java/moa/classifiers/meta/AdaptiveRandomForestTest.java @@ -1,5 +1,5 @@ /* - * PEARLTest.java + * AdaptiveRandomForestTest.java * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by From ec8da623312107869f2af1bef3b37a55938ea756 Mon Sep 17 00:00:00 2001 From: nuwangunasekara Date: Wed, 7 Aug 2024 07:48:38 +0900 Subject: [PATCH 22/31] Change file name at the comment --- moa/src/test/java/moa/classifiers/meta/PEARLTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moa/src/test/java/moa/classifiers/meta/PEARLTest.java b/moa/src/test/java/moa/classifiers/meta/PEARLTest.java index 836986cb9..cc53afc78 100644 --- a/moa/src/test/java/moa/classifiers/meta/PEARLTest.java +++ b/moa/src/test/java/moa/classifiers/meta/PEARLTest.java @@ -1,5 +1,5 @@ /* - * AdaptiveRandomForestTest.java + * PEARLTest.java * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by From 37c14c55b332fb6338658f8ca2ef30cafc615edc Mon Sep 17 00:00:00 2001 From: cassales Date: Tue, 20 Aug 2024 14:25:13 +1200 Subject: [PATCH 23/31] fix: removing a debug print in DBSCAN --- moa/src/main/java/moa/clusterers/macro/dbscan/DBScan.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/moa/src/main/java/moa/clusterers/macro/dbscan/DBScan.java b/moa/src/main/java/moa/clusterers/macro/dbscan/DBScan.java index 74efb2a5c..77941ac14 100644 --- a/moa/src/main/java/moa/clusterers/macro/dbscan/DBScan.java +++ 
b/moa/src/main/java/moa/clusterers/macro/dbscan/DBScan.java @@ -171,8 +171,8 @@ public Clustering getClustering(Clustering microClusters) { noise++; } } - System.out.println("microclusters which are not clustered:: " - + noise); +// System.out.println("microclusters which are not clustered:: " +// + noise); Clustering result = new Clustering(res); setClusterIDs(result); // int i = 0; From de4c132b8d5157bb5e4360ed1b4f535dea80f2b1 Mon Sep 17 00:00:00 2001 From: Maroua Bahri Date: Wed, 4 Sep 2024 19:32:16 +0200 Subject: [PATCH 24/31] add autoclass for autoML --- .DS_Store | Bin 0 -> 6148 bytes moa/.DS_Store | Bin 0 -> 6148 bytes moa/src/.DS_Store | Bin 0 -> 6148 bytes moa/src/main/.DS_Store | Bin 0 -> 6148 bytes moa/src/main/java/.DS_Store | Bin 0 -> 6148 bytes moa/src/main/java/moa/.DS_Store | Bin 0 -> 6148 bytes moa/src/main/java/moa/classifiers/.DS_Store | Bin 0 -> 6148 bytes .../main/java/moa/classifiers/meta/.DS_Store | Bin 0 -> 6148 bytes .../moa/classifiers/meta/AutoML/.DS_Store | Bin 0 -> 6148 bytes .../classifiers/meta/AutoML/Algorithm.java | 199 +++++ .../classifiers/meta/AutoML/AutoClass.java | 697 ++++++++++++++++++ .../meta/AutoML/BooleanParameter.java | 127 ++++ .../meta/AutoML/CategoricalParameter.java | 125 ++++ .../AutoML/HeterogeneousEnsembleAbstract.java | 228 ++++++ .../classifiers/meta/AutoML/IParameter.java | 16 + .../meta/AutoML/IntegerParameter.java | 89 +++ .../meta/AutoML/NumericalParameter.java | 93 +++ .../meta/AutoML/OrdinalParameter.java | 98 +++ .../meta/AutoML/TruncatedNormal.java | 56 ++ .../moa/classifiers/meta/AutoML/settings.json | 52 ++ 20 files changed, 1780 insertions(+) create mode 100644 .DS_Store create mode 100644 moa/.DS_Store create mode 100644 moa/src/.DS_Store create mode 100644 moa/src/main/.DS_Store create mode 100644 moa/src/main/java/.DS_Store create mode 100644 moa/src/main/java/moa/.DS_Store create mode 100644 moa/src/main/java/moa/classifiers/.DS_Store create mode 100644 
moa/src/main/java/moa/classifiers/meta/.DS_Store create mode 100644 moa/src/main/java/moa/classifiers/meta/AutoML/.DS_Store create mode 100755 moa/src/main/java/moa/classifiers/meta/AutoML/Algorithm.java create mode 100755 moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java create mode 100755 moa/src/main/java/moa/classifiers/meta/AutoML/BooleanParameter.java create mode 100755 moa/src/main/java/moa/classifiers/meta/AutoML/CategoricalParameter.java create mode 100755 moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java create mode 100755 moa/src/main/java/moa/classifiers/meta/AutoML/IParameter.java create mode 100755 moa/src/main/java/moa/classifiers/meta/AutoML/IntegerParameter.java create mode 100755 moa/src/main/java/moa/classifiers/meta/AutoML/NumericalParameter.java create mode 100755 moa/src/main/java/moa/classifiers/meta/AutoML/OrdinalParameter.java create mode 100755 moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java create mode 100755 moa/src/main/java/moa/classifiers/meta/AutoML/settings.json diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..a361f1af412f356a6f96409579fabdc215d28871 GIT binary patch literal 6148 zcmeHK%}T>S5Z-O8O({YS3VK`cS}?7mEnY&bFJMFuDm5WRgK4%jtvQrJ?)pN$h|lB9 z?glL8EMjM1_nY6{><8H&#u)b&QI|2BF=jzSJW~a6G=`M>+c=G6?#V?OWvbTK0jp)TM)tvS*&UqqJL0(StU6*j=s6v6 zdeU31TGrm-(fQTrC7Go1P4mftZY4VgOLzyxEa%mqrHM?Rz*AmunFpinIe7tZqJNSi4XWZ3DJuyHGtTRy8rj6(SCHyj#kNou# zvWNj<;GZ$T8$*BS!J^FB`eS)`)(U8k&`>b1Km`Q!wMzgPxR30sppFaFAS5Z-O8O({YS3VK`cS}?7mEnY&bFJMFuDm5`hgE3p0)*MP9cYPsW#OHBl zcLNr47O^w1`_1oe_JiyXV~l(AsLPnm7_*=ua#Sh=-Ibx5Nk-&2Mm7l(8G`i@PE72t z1Acp*Wo*hG!uH#Rn#R?})*cm9*i{aG-ZXI?P9LF-b=BrNqHyo$&9 z$lgAa=`4uT@l+MW(HK&0uj4e5xhLmol&M-@2dtLW8rge`MR#!0?})>`v+Rh)pyzbN z@lkKNY*{<|2d5XK=j0`oZ<Q5`*}*SVI^(WJ>WKkjV3mQoHf=or&*7J;eB`f| zkVOm-1OJQx-Wd8r4;E$4)*s8mvsOU6hlYZA1u7t*uUrDazfwk#I#3d-k#p8au}sPqARMShV-NDL4I!~ij{ 
zwhY)KLDbfk0a`dQKn(oE0PYVG8lr14H>kG`==84kzE}$afv$QxfXMS UI12i8IUrpGG$GUx1HZt)2XazMhyVZp literal 0 HcmV?d00001 diff --git a/moa/src/main/.DS_Store b/moa/src/main/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..11fd813eb4998949bc63e3738d6f353653be2d1a GIT binary patch literal 6148 zcmeHK%}T>S5Z<-XrW7Fu1-&hJEtpo(7B8XJ7cim+m70*E(U>hw+8#^Bqp z>ww>Gu}2oL2@Ae{e>hCyEO*^^zE-i;H>!5kuG_c%gPi(#kdL!&FuF$TLdqm8^&q^A z2h+Z@bt=<5h||GH6~xg1Qf{u|G?LS<9H&vHYJDBBYj&;g?9OJ*&T+dT4%+U#A!ePH z+Ym>Gt@*rWZ}07&ocEuS=TyFEMmZ3!WYb^?ub_M`XyT93M5g!PEAz`dLSldzAO?tm zm1V#j33h#D8K8v|1H`~j4B-ACpdoq&3yo^)fDW(E7`G5nK*zTPqA=(gEHpv{gzHj3 zUCPZ9gX?ne3zO#=EHvtJ#?{O)j+wc9yl^!;_=QSm+|x)sF+dC~GEmW`gXjMl{4z@) z`HLlF5d*})KVyJfJ-^q5MVYhp+w$^d6KoLS6G4Klvd;se>N`e3Y literal 0 HcmV?d00001 diff --git a/moa/src/main/java/.DS_Store b/moa/src/main/java/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..a361f1af412f356a6f96409579fabdc215d28871 GIT binary patch literal 6148 zcmeHK%}T>S5Z-O8O({YS3VK`cS}?7mEnY&bFJMFuDm5WRgK4%jtvQrJ?)pN$h|lB9 z?glL8EMjM1_nY6{><8H&#u)b&QI|2BF=jzSJW~a6G=`M>+c=G6?#V?OWvbTK0jp)TM)tvS*&UqqJL0(StU6*j=s6v6 zdeU31TGrm-(fQTrC7Go1P4mftZY4VgOLzyxEa%mqrHM?Rz*AmunFpinIe7tZqJNSi4XWZ3DJuyHGtTRy8rj6(SCHyj#kNou# zvWNj<;GZ$T8$*BS!J^FB`eS)`)(U8k&`>b1Km`Q!wMzgPxR30sppFaFAmd~S83}}NkJ^Xf zLNq)6BLnpA%5Vn}_>kc5`~4+Bnhb(S1sL%j45Ba@)oLFimoKcX7oDP0c5b{!m3kw8 zG)~(7@QP;VN=3oU_JfPCpLQ#oCn_HKVcZ|;gs|7gkn77Z?y0n`#&Ivvxt>{YN=~U; z*_ll0t)pgD?l)`Gs+_bMwW>TkXiTRiXKQ!w_^kUBJ;&I+s6k$>kO{%b03}MpIFKwJ_F*9h=LFk?FId*4ZZzw|Vj((}bLAVBa zWCoalc?R-kTA}`by8Qk>pTsj}fEidR21KFbb=p{x?X3&NQLmM#x2Pl(ml^y@K|{A< gjHOn*jj9Fxk_<%GVrCFMDEuLyY2bkw_)`Wx0mQOS?EnA( literal 0 HcmV?d00001 diff --git a/moa/src/main/java/moa/classifiers/.DS_Store b/moa/src/main/java/moa/classifiers/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..11c33a3d2a008291ef4ab231397b4bad43844b67 GIT binary patch literal 6148 zcmeHK-AcnS6i&A3GKSC#1-%P+JFuIg8{U*TU%-l9sLYlPEq2Y=I(sn&z1J7=MSLF5 
zNm6k*Z$;cWkbLJiX+CIv7-PIU8#Ebn7-Iq&B1dI~pnGMgWRnp&juB?laTu!*>^B?x z>ww>0VIwwVDGR=Se>je#Y3jP~yi~4iY*s~8)Wx0us4_nd(%G~dOm5J+R4NV&JqWL& zVb*uH&s35IQ8Jw9f@m;=l-uhl8K|tQX31cxYkdfwk#I#3d-k#p8ZJ@tKkG`=)WEM|O44#3kyG=UL1R U;wb3X<$!b%(1cJ&4EzEEAIz{yi~s-t literal 0 HcmV?d00001 diff --git a/moa/src/main/java/moa/classifiers/meta/.DS_Store b/moa/src/main/java/moa/classifiers/meta/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..76d826fc257b1b5bee5cc826571beff1b9796070 GIT binary patch literal 6148 zcmeHK%}xR_5N-jXgqU#9L~l*Jk_c)P;$^}0rWe=fK@GAlfsMi_%4v(C6E&hi zzt-qo$lj?CfjhMu?wQ{l*7GX|B8(b-*zBnU-)ln1<+&evVptP{(Ce#QM-Q+p%hvPj zqfx1{S1$1Fa&cVXqsmsXz;}1H#^Wqo-Pqhesow^jP~2-KDSUKFmJLo}0>;dyf>z17 z5uNTd;Ny5l*tYp-+0)B}w(Q61e2n*FOt=00I(JSl3`E$1xu?(m?vc<$29N<{U{MU1 zBZ%b|wIHs83?Ku4h5>m#a8QYs!C0d@I-pS{0ALE-O2C%2gv{XvErYQ}7y+R=6;P*A zQ({n^4t{RpEQ7H|oldAJKB!rlnhJ%g)nR_F!U?rBVv7tQ1G5YybvG~X|HH5A|Jfw$ zAp^+3zhZ!f#&a?oo~DKakB_?7~Ox{4uIUd3ymO2E&h0caVFHG&5O Oe*_c_*dPN7W#AQWt7v!t literal 0 HcmV?d00001 diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/.DS_Store b/moa/src/main/java/moa/classifiers/meta/AutoML/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 GIT binary patch literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0= 2){ + System.out.println("Copy failed for " + x.classifier.getCLICreationString(Classifier.class) + "! 
Reinitialise instead."); + } + this.classifier = x.classifier; // keep the old algorithm for now + keepCurrentModel = false; + } + } else{ + this.classifier = x.classifier; // keep the old algorithm for now + } + + adjustAlgorithm(keepCurrentModel, verbose); + } + + // init constructor + public Algorithm(AlgorithmConfiguration x) { + + this.algorithm = x.algorithm; + this.parameters = new IParameter[x.parameters.length]; + this.preventRemoval = false; + this.isDefault = true; + + this.attributes = new Attribute[x.parameters.length]; + for (int i = 0; i < x.parameters.length; i++) { + + + ParameterConfiguration paramConfig = x.parameters[i]; + if (paramConfig.type.equals("numeric") || paramConfig.type.equals("float") || paramConfig.type.equals("real")) { + NumericalParameter param = new NumericalParameter(paramConfig); + this.parameters[i] = param; + this.attributes[i] = new Attribute(param.getParameter()); + } else if (paramConfig.type.equals("integer")) { + IntegerParameter param = new IntegerParameter(paramConfig); + this.parameters[i] = param; + this.attributes[i] = new Attribute(param.getParameter()); + } else if (paramConfig.type.equals("nominal") || paramConfig.type.equals("categorical") || paramConfig.type.equals("factor")) { + CategoricalParameter param = new CategoricalParameter(paramConfig); + this.parameters[i] = param; + this.attributes[i] = new Attribute(param.getParameter(), Arrays.asList(param.getRange())); + } else if (paramConfig.type.equals("boolean") || paramConfig.type.equals("flag")) { + BooleanParameter param = new BooleanParameter(paramConfig); + this.parameters[i] = param; + this.attributes[i] = new Attribute(param.getParameter(), Arrays.asList(param.getRange())); + } else if (paramConfig.type.equals("ordinal")) { + OrdinalParameter param = new OrdinalParameter(paramConfig); + this.parameters[i] = param; + this.attributes[i] = new Attribute(param.getParameter()); + } else { + throw new RuntimeException("Unknown parameter type: '" + 
paramConfig.type + + "'. Available options are 'numeric', 'integer', 'nominal', 'boolean' or 'ordinal'"); + } + } + init(); + } + + // initialise a new algorithm using the Command Line Interface (CLI) + public void init() { + // construct CLI string from settings, e.g. denstream.WithDBSCAN -e 0.08 -b 0.3 + StringBuilder commandLine = new StringBuilder(); + commandLine.append(this.algorithm); // first the algorithm class + for (IParameter param : this.parameters) { + commandLine.append(" "); + commandLine.append(param.getCLIString()); + } + + // create new classifier from CLI string + ClassOption opt = new ClassOption("", ' ', "", Classifier.class, commandLine.toString()); + this.classifier = (AbstractClassifier) opt.materializeObject(null, null); + this.classifier.prepareForUse(); + } + + // sample a new configuration based on the current one + public void adjustAlgorithm(boolean keepCurrentModel, int verbose) { + + if (keepCurrentModel) { + // Option 1: keep the old state and just change parameter + StringBuilder commandLine = new StringBuilder(); + for (IParameter param : this.parameters) { + commandLine.append(param.getCLIString()); + } + + Options opts = this.classifier.getOptions(); + for (IParameter param : this.parameters) { + Option opt = opts.getOption(param.getParameter().charAt(0)); + opt.setValueViaCLIString(param.getCLIValueString()); + } + + // these changes do not transfer over directly since all algorithms cache the + // option values. 
Therefore we try to adjust the cached values if possible + try { + ((AbstractClassifier) this.classifier).adjustParameters(); + if (verbose >= 2) { + System.out.println("Changed: " + this.classifier.getCLICreationString(Classifier.class)); + } + } catch (UnsupportedOperationException e) { + if (verbose >= 2) { + System.out.println("Cannot change parameters of " + this.algorithm + " on the fly, reset instead."); + } + adjustAlgorithm(false, verbose); + } + } else{ + // Option 2: reinitialise the entire state + this.init(); + if (verbose >= 2) { + System.out.println("Initialise: " + this.classifier.getCLICreationString(Classifier.class)); + } + + if (verbose >= 2) { + System.out.println("Train with existing classifiers."); + } + + } + } + + // returns the parameter values as an array + public double[] getParamVector(int padding) { + double[] params = new double[this.parameters.length + padding]; + int pos = 0; + for (IParameter param : this.parameters) { + params[pos++] = param.getValue(); + } + return params; + } +} \ No newline at end of file diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java new file mode 100755 index 000000000..c9b39aa68 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java @@ -0,0 +1,697 @@ +package moa.classifiers.meta.AutoML; + +import com.github.javacliparser.FileOption; +import com.google.gson.Gson; +import com.yahoo.labs.samoa.instances.DenseInstance; +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.instances.Instances; +import moa.classifiers.Classifier; +import moa.classifiers.MultiClassClassifier; +import moa.classifiers.meta.AdaptiveRandomForestRegressor; +import moa.core.DoubleVector; +import moa.core.Measurement; +import moa.core.ObjectRepository; +import moa.tasks.TaskMonitor; +import java.io.BufferedReader; +import java.io.FileNotFoundException; +import java.io.FileReader; +import 
java.util.ArrayList; +import java.util.HashMap; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + + + public class AutoClass extends HeterogeneousEnsembleAbstract implements MultiClassClassifier { + + private static final long serialVersionUID = 1L; + + int instancesSeen; + int iter; + public int bestModel; + public ArrayList ensemble; + public ArrayList candidateEnsemble; + public Instances windowPoints; + HashMap ARFregs = new HashMap(); + GeneralConfiguration settings; + ArrayList performanceMeasures; + int verbose = 0; + protected ExecutorService executor; + int numberOfCores; + int corrects; + protected boolean[][] onlineHistory; + // the file option dialogue in the UI + public FileOption fileOption = new FileOption("ConfigurationFile", 'f', "Configuration file in json format.", + "/Users/mbahri/Desktop/Dell/moa/src/main/java/moa/classifiers/meta/AutoML/settings.json", ".json", false); + + public void init() { + this.fileOption.getFile(); + } + + + @Override + public boolean isRandomizable() { + return false; + } + + // @Override + //public double[] getVotesForInstance(Instance inst) { + //return null; } + + + @Override + public void resetLearningImpl() { + this.performanceMeasures = new ArrayList<>(); + this.onlineHistory = new boolean[this.settings.ensembleSize][this.settings.windowSize]; + this.historyTotal = new double[this.settings.ensembleSize]; + this.instancesSeen = 0; + this.corrects = 0; + this.bestModel = 0; + this.iter = 0; + this.windowPoints = null ; + + // reset ARFrefs + for (AdaptiveRandomForestRegressor ARFreg : this.ARFregs.values()) { + ARFreg.resetLearning(); + } + + // reset individual classifiers + for (int i = 0; i < this.ensemble.size(); i++) { + //this.ensemble.get(i).init(); + this.ensemble.get(i).classifier.resetLearning(); + } + + if (this.settings.numberOfCores == -1) { + this.numberOfCores = Runtime.getRuntime().availableProcessors(); + } else { + 
this.numberOfCores = this.settings.numberOfCores; + } + this.executor = Executors.newFixedThreadPool(this.numberOfCores); + } + + @Override + public void trainOnInstanceImpl(Instance inst) { + + if (this.windowPoints == null) { + this.windowPoints = new Instances(inst.dataset()); + } + + if (this.settings.windowSize <= this.windowPoints.numInstances()) { + this.windowPoints.delete(0); + } + this.windowPoints.add(inst); // remember points of the current window + this.instancesSeen++; + + int wValue = this.settings.windowSize; + + for (int i = 0; i < this.ensemble.size(); i++) { + // Online performance estimation + double [] votes = ensemble.get(i).classifier.getVotesForInstance(inst); + boolean correct = (maxIndex(votes) * 1.0 == inst.classValue()); + + + // Maroua: boolean correct = (votes[maxIndex(votes)] * 1.0 == inst.classValue()); + if (correct && ! this.onlineHistory[i][instancesSeen % wValue]) { + // performance estimation increases + this.onlineHistory[i][instancesSeen % wValue] = true; + + this.historyTotal[i] += 1.0/wValue; + + } else if (!correct && this.onlineHistory[i][instancesSeen % wValue]) { + // performance estimation decreases + this.onlineHistory[i][instancesSeen % wValue] = false; + this.historyTotal[i] -= 1.0/wValue; + + } else { + // nothing happens + } + } + + if (this.numberOfCores == 1) { + // train all models with the instance + this.performanceMeasures = new ArrayList(this.ensemble.size()); + double bestPerformance = Double.NEGATIVE_INFINITY; + for (int i = 0; i < this.ensemble.size(); i++) { + this.ensemble.get(i).classifier.trainOnInstance(inst); + + // To extract the best performing mode + double performance = this.historyTotal[i]; + this.performanceMeasures.add(performance); + if (performance > bestPerformance) { + this.bestModel = i; + bestPerformance = performance; + } + } + + if (this.settings.useTestEnsemble && this.candidateEnsemble.size() > 0) { + // train all models with the instance + for (int i = 0; i < 
this.candidateEnsemble.size(); i++) { //Maroua; cétait Ensemble + this.candidateEnsemble.get(i).classifier.trainOnInstance(inst); + } + } + } else { + //EnsembleClassifierAbstractAUTO.EnsembleRunnable + ArrayList trainers = new ArrayList(); + for (int i = 0; i < this.ensemble.size(); i++) { + EnsembleRunnable trainer = new EnsembleRunnable(this.ensemble.get(i).classifier, inst); + trainers.add(trainer); + } + if (this.settings.useTestEnsemble && this.candidateEnsemble.size() > 0) { + // train all models with the instance + for (int i = 0; i < this.candidateEnsemble.size(); i++) { + EnsembleRunnable trainer = new EnsembleRunnable(this.candidateEnsemble.get(i).classifier, inst); + trainers.add(trainer); + } + } + try { + this.executor.invokeAll(trainers); + } catch (InterruptedException ex) { + throw new RuntimeException("Could not call invokeAll() on training threads."); + } + } + + // every windowSize, we update the configurations + if (this.instancesSeen % this.settings.windowSize == 0) { + + if (this.verbose >= 1) { + System.out.println(" "); + System.out.println("-------------- Processed " + instancesSeen + " Instances --------------"); + } + + updateConfiguration(); // update configuration + } + + } + + /** + * Returns votes using the best performing method + * @param inst + * @return votes + */ + @Override + public double[] getVotesForInstance(Instance inst) { + + + return this.ensemble.get(this.bestModel).classifier.getVotesForInstance(inst); +// DoubleVector combinedVote = new DoubleVector(); +// +// for(int i = 0 ; i < this.ensemble.size() ; ++i) { +// DoubleVector vote = new DoubleVector(this.ensemble.get(i).classifier.getVotesForInstance(inst)); +// if (vote.sumOfValues() > 0.0) { +// vote.normalize(); +// +// combinedVote.addValues(vote); +// } +// } +// return combinedVote.getArrayRef(); + } + + + protected void updateConfiguration() { + // init evaluation measure + if (this.verbose >= 2) { + System.out.println(" "); + System.out.println("---- Evaluate 
performance of current ensemble:"); + } + evaluatePerformance(); + + if (this.settings.useTestEnsemble) { + promoteCandidatesIntoEnsemble(); + } + + if (this.verbose >= 1) { + System.out.println("Classifier " + this.bestModel + " (" + + this.ensemble.get(this.bestModel).classifier.getCLICreationString(Classifier.class) + + ") is the active classifier with performance: " + this.performanceMeasures.get(this.bestModel)); + } + + generateNewConfigurations(); + + // this.windowPoints.delete(); // flush the current window + this.iter++; + } + + protected void evaluatePerformance() { + + HashMap bestPerformanceValMap = new HashMap(); + HashMap bestPerformanceIdxMap = new HashMap(); + HashMap algorithmCount = new HashMap(); + + this.performanceMeasures = new ArrayList(this.ensemble.size()); + double bestPerformance = Double.NEGATIVE_INFINITY; + for (int i = 0; i < this.ensemble.size(); i++) { + + // predict performance just for evaluation + predictPerformance(this.ensemble.get(i)); + + double performance = this.historyTotal[i]; + this.performanceMeasures.add(performance); + if (performance > bestPerformance) { + this.bestModel = i; + bestPerformance = performance; + } + + if (this.verbose >= 1) { + System.out.println(i + ") " + this.ensemble.get(i).classifier.getCLICreationString(Classifier.class) + + "\t => \t performance: " + performance); + } + + String algorithm = this.ensemble.get(i).classifier.getPurposeString(); + if (!bestPerformanceIdxMap.containsKey(algorithm) || performance > bestPerformanceValMap.get(algorithm)) { + bestPerformanceValMap.put(algorithm, performance); // best performance per algorithm + bestPerformanceIdxMap.put(algorithm, i); // index of best performance per algorithm + } + + // number of instances per algorithm in ensemble + + algorithmCount.put(algorithm, algorithmCount.getOrDefault(algorithm, 0) + 1); + trainRegressor(this.ensemble.get(i), performance); + } + + updateRemovalFlags(bestPerformanceValMap, bestPerformanceIdxMap, algorithmCount); 
+ } + + + /** + * Computes the accuracy of a learner for a given window of instances. + * @param algorithm classifier to compute error + * @return the computed accuracy. + */ + protected double computePerformanceMeasure(Algorithm algorithm) { + double acc = 0; + this.trainOnChunk(algorithm); + for (int i = 0; i < this.windowPoints.numInstances(); i++) { + try { + + double[] votes = algorithm.classifier.getVotesForInstance(this.windowPoints.instance(i)); + boolean correct = (maxIndex(votes)* 1.0 == this.windowPoints.instance(i).classValue()); + + if (correct){ + acc += 1.0/this.windowPoints.numInstances(); + }else + acc -= 1.0/this.windowPoints.numInstances(); + // algorithm.classifier.trainOnInstance(this.windowPoints.instance(i));.... + } catch (Exception e) { + System.out.println("computePerformanceMeasure Error"); + } + } + algorithm.performanceMeasure = acc; + return acc; + } + /** + * Trains a classifier on the most recent window of data. + * + * @param algorithm + * Classifier being trained. 
+ */ + private void trainOnChunk(Algorithm algorithm) { + for (int i = 0; i < this.windowPoints.numInstances(); i++) { + algorithm.classifier.trainOnInstance(this.windowPoints.instance(i)); + } + } + + + protected void promoteCandidatesIntoEnsemble() { + for (int i = 0; i < this.candidateEnsemble.size(); i++) { + + Algorithm newAlgorithm = this.candidateEnsemble.get(i); + + // predict performance just for evaluation + predictPerformance(newAlgorithm); + + // evaluate + double performance = computePerformanceMeasure(newAlgorithm); + + if (this.verbose >= 1) { + System.out.println("Test " + i + ") " + newAlgorithm.classifier.getCLICreationString(Classifier.class) + + "\t => \t Performance: " + performance); + } + + // replace if better than existing + + if (this.ensemble.size() < this.settings.ensembleSize) { + if (this.verbose >= 1) { + System.out.println("Promote " + newAlgorithm.classifier.getCLICreationString(Classifier.class) + + " from test ensemble to the ensemble as new configuration"); + } + + this.performanceMeasures.add(newAlgorithm.performanceMeasure); + + this.ensemble.add(newAlgorithm); + + } else if (performance > AutoClass.getWorstSolution(this.performanceMeasures)) { + + HashMap replace = getReplaceMap(this.performanceMeasures); + + if (replace.size() == 0) { + return; + } + + int replaceIdx = AutoClass.sampleProportionally(replace, + !this.settings.performanceMeasureMaximisation); // false + if (this.verbose >= 1) { + System.out.println("Promote " + newAlgorithm.classifier.getCLICreationString(Classifier.class) + + " from test ensemble to the ensemble by replacing " + replaceIdx); + } + + // update performance measure + this.performanceMeasures.set(replaceIdx, newAlgorithm.performanceMeasure); + + // replace in ensemble + + this.ensemble.set(replaceIdx, newAlgorithm); + } + + } + } + + protected void trainRegressor(Algorithm algortihm, double performance) { + double[] params = algortihm.getParamVector(1); + params[params.length - 1] = performance; 
// add performance as class + Instance inst = new DenseInstance(1.0, params); + + // add header to dataset TODO: do we need an attribute for the class label? + Instances dataset = new Instances(null, algortihm.attributes, 0); + dataset.setClassIndex(dataset.numAttributes()); // set class index to our performance feature + inst.setDataset(dataset); + + // train adaptive random forest regressor based on performance of model + this.ARFregs.get(algortihm.algorithm).trainOnInstanceImpl(inst); + } + + protected void updateRemovalFlags(HashMap bestPerformanceValMap, + HashMap bestPerformanceIdxMap, HashMap algorithmCount) { + + // reset flags + for (Algorithm algorithm : ensemble) { + algorithm.preventRemoval = false; + } + + // only keep best overall algorithm + if (this.settings.keepGlobalIncumbent) { + this.ensemble.get(this.bestModel).preventRemoval = true; + } + + // keep best instance per algorithm + if (this.settings.keepAlgorithmIncumbents) { + for (int idx : bestPerformanceIdxMap.values()) { + this.ensemble.get(idx).preventRemoval = true; + } + } + + // keep all default configurations + if (this.settings.keepInitialConfigurations) { + for (Algorithm algorithm : this.ensemble) { + if (algorithm.isDefault) { + algorithm.preventRemoval = true; + } + } + } + + // keep at least one instance per algorithm + if (this.settings.preventAlgorithmDeath) { + for (Algorithm algorithm : this.ensemble) { + if (algorithmCount.get(algorithm.algorithm) == 1) { + algorithm.preventRemoval = true; + } + } + } + } + + // predict performance of new configuration + protected void generateNewConfigurations() { + + // get performance values + if (this.settings.useTestEnsemble) { + candidateEnsemble.clear(); + } + + for (int z = 0; z < this.settings.newConfigurations; z++) { + + if (this.verbose == 2) { + System.out.println(" "); + System.out.println("---- Sample new configuration " + z + ":"); + } + + int parentIdx = sampleParent(this.performanceMeasures); + Algorithm newAlgorithm = 
sampleNewConfiguration(parentIdx); + + if (this.settings.useTestEnsemble) { + if (this.verbose >= 1) { + System.out.println("Based on " + parentIdx + " add " + + newAlgorithm.classifier.getCLICreationString(Classifier.class) + " to test ensemble"); + } + candidateEnsemble.add(newAlgorithm); + } else { + double prediction = predictPerformance(newAlgorithm); + + if (this.verbose >= 1) { + System.out.println("Based on " + parentIdx + " predict: " + + newAlgorithm.classifier.getCLICreationString(Classifier.class) + "\t => \t Performance: " + + prediction); + } + + // the random forest only works with at least two training samples + if (Double.isNaN(prediction)) { + return; + } + + // if we still have open slots in the ensemble (not full) + if (this.ensemble.size() < this.settings.ensembleSize) { + if (this.verbose >= 1) { + System.out.println("Add configuration as new algorithm."); + } + + // add to ensemble + this.ensemble.add(newAlgorithm); + + // update current performance with the prediction + this.performanceMeasures.add(prediction); + } else if (prediction > AutoClass.getWorstSolution(this.performanceMeasures)) { + // if the predicted performance is better than the one we have in the ensemble + HashMap replace = getReplaceMap(this.performanceMeasures); + + if (replace.size() == 0) { + return; + } + + int replaceIdx = AutoClass.sampleProportionally(replace, + !this.settings.performanceMeasureMaximisation); // false + + if (this.verbose >= 1) { + System.out.println("Replace algorithm: " + replaceIdx); + } + + // update current performance with the prediction + this.performanceMeasures.set(replaceIdx, prediction); + + // replace in ensemble + this.ensemble.set(replaceIdx, newAlgorithm); + } + } + + } + + } + + protected int sampleParent(ArrayList performM) { + // copy existing classifier configuration + HashMap parents = new HashMap(); + for (int i = 0; i < performM.size(); i++) { + parents.put(i, performM.get(i)); + } + int parentIdx = 
AutoClass.sampleProportionally(parents, + this.settings.performanceMeasureMaximisation); // true + + return parentIdx; + } + + protected Algorithm sampleNewConfiguration(int parentIdx) { + + if (this.verbose >= 2) { + System.out.println("Selected Configuration " + parentIdx + " as parent: " + + this.ensemble.get(parentIdx).classifier.getCLICreationString(Classifier.class)); + } + Algorithm newAlgorithm = new Algorithm(this.ensemble.get(parentIdx), this.settings.lambda, + this.settings.resetProbability, this.settings.keepCurrentModel, this.verbose); + + return newAlgorithm; + } + + protected double predictPerformance(Algorithm newAlgorithm) { + // create instance from new configuration + double[] params = newAlgorithm.getParamVector(0); + Instance newInst = new DenseInstance(1.0, params); + Instances newDataset = new Instances(null, newAlgorithm.attributes, 0); + newDataset.setClassIndex(newDataset.numAttributes()); + newInst.setDataset(newDataset); + + // predict the performance of the new configuration using the trained adaptive + // random forest + double prediction = this.ARFregs.get(newAlgorithm.algorithm).getVotesForInstance(newInst)[0]; + + newAlgorithm.prediction = prediction; // remember prediction + + return prediction; + } + + // get mapping of algorithms and their performance that could be removed + HashMap getReplaceMap(ArrayList performM) { + HashMap replace = new HashMap(); + + double worst = AutoClass.getWorstSolution(performM); + + // replace solutions that cannot get worse first + if (worst <= -1.0) { + for (int i = 0; i < this.ensemble.size(); i++) { + if (performM.get(i) <= -1.0 && !this.ensemble.get(i).preventRemoval) { + replace.put(i, performM.get(i)); + } + } + } + + if (replace.size() == 0) { + for (int i = 0; i < this.ensemble.size(); i++) { + if (!this.ensemble.get(i).preventRemoval) { + replace.put(i, performM.get(i)); + } + } + } + + return replace; + } + + // get lowest value in arraylist + static double getWorstSolution(ArrayList 
values) { + + double min = Double.POSITIVE_INFINITY; + for (int i = 0; i < values.size(); i++) { + if (values.get(i) < min) { + min = values.get(i); + } + } + return (min); + } + + static int sampleProportionally(HashMap values, boolean maximisation) { + + // if we want to sample lower values with higher probability, we invert here + if (!maximisation) { + HashMap vals = new HashMap(values.size()); + + for (int i : values.keySet()) { + vals.put(i, -1 * values.get(i)); + } + return (AutoClass.rouletteWheelSelection(vals)); + } + + return (AutoClass.rouletteWheelSelection(values)); + } + + // sample an index from a list of values, proportionally to the respective value + static int rouletteWheelSelection(HashMap values) { + + // get min + double minVal = Double.POSITIVE_INFINITY; + for (Double value : values.values()) { + if (value < minVal) { + minVal = value; + } + } + + // to have a positive range we shift here + double shift = Math.abs(minVal) - minVal; + + double completeWeight = 0.0; + for (Double value : values.values()) { + completeWeight += value + shift; + } + + // sample random number within range of total weight + double r = Math.random() * completeWeight; + double countWeight = 0.0; + + for (int j : values.keySet()) { + countWeight += values.get(j) + shift; + if (countWeight >= r) { + return j; + } + } + throw new RuntimeException("Sampling failed"); + } + + @Override + protected Measurement[] getModelMeasurementsImpl() { + // TODO Auto-generated method stub + return null; + } + + @Override + public void getModelDescription(StringBuilder out, int indent) { + // TODO Auto-generated method stub + } + @Override + public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + + try { + + // read settings from json + BufferedReader bufferedReader = new BufferedReader(new FileReader(fileOption.getValue())); + Gson gson = new Gson(); + // store settings in dedicated class structure + this.settings = gson.fromJson(bufferedReader, 
GeneralConfiguration.class); + + this.instancesSeen = 0; + this.bestModel = 0; + this.iter = 0; + this.windowPoints = null; + + // create the ensemble + this.ensemble = new ArrayList(this.settings.ensembleSize); + // copy and initialise the provided starting configurations in the ensemble + for (int i = 0; i < this.settings.algorithms.length; i++) { + this.ensemble.add(new Algorithm(this.settings.algorithms[i])); + } + + if (this.settings.useTestEnsemble) { + this.candidateEnsemble = new ArrayList(this.settings.newConfigurations); + } + + // create one regressor per algorithm + for (int i = 0; i < this.settings.algorithms.length; i++) { + AdaptiveRandomForestRegressor ARFreg = new AdaptiveRandomForestRegressor(); + ARFreg.prepareForUse(); + this.ARFregs.put(this.settings.algorithms[i].algorithm, ARFreg); + } + + } catch (FileNotFoundException e) { + e.printStackTrace(); + } + + super.prepareForUseImpl(monitor, repository); + + } + + // Modified from: + // https://github.com/Waikato/moa/blob/master/moa/src/main/java/moa/classifiers/meta/AdaptiveRandomForest.java#L157 + // Helper class for parallelisation + protected class EnsembleRunnable implements Runnable, Callable { + final private Classifier classifier; + final private Instance instance; + + public EnsembleRunnable(Classifier classifier, Instance instance) { + this.classifier = classifier; + this.instance = instance; + } + + @Override + public void run() { + classifier.trainOnInstance(this.instance); + } + + @Override + public Integer call() throws Exception { + run(); + return 0; + } + } diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/BooleanParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/BooleanParameter.java new file mode 100755 index 000000000..4342668e7 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/BooleanParameter.java @@ -0,0 +1,127 @@ +package moa.classifiers.meta.AutoML; + +import com.yahoo.labs.samoa.instances.Attribute; +import java.util.ArrayList; 
+import java.util.HashMap; + +// the representation of a boolean / binary / flag parameter +public class BooleanParameter implements IParameter { + private String parameter; + private int numericValue; + private String value; + private String[] range = { "false", "true" }; + private Attribute attribute; + private ArrayList probabilities; + private boolean optimise; + + public BooleanParameter(BooleanParameter x) { + this.parameter = x.parameter; + this.numericValue = x.numericValue; + this.value = x.value; + this.attribute = x.attribute; + this.optimise = x.optimise; + + if(this.optimise){ + this.range = x.range.clone(); + this.probabilities = new ArrayList(x.probabilities); + } + + } + + public BooleanParameter(ParameterConfiguration x) { + this.parameter = x.parameter; + this.value = String.valueOf(x.value); + for (int i = 0; i < this.range.length; i++) { + if (this.range[i].equals(this.value)) { + this.numericValue = i; // get index of init value + } + } + this.attribute = new Attribute(x.parameter); + this.optimise = x.optimise; + + if(this.optimise){ + this.probabilities = new ArrayList(2); + for (int i = 0; i < 2; i++) { + this.probabilities.add(0.5); // equal probabilities + } + } + } + + public BooleanParameter copy() { + return new BooleanParameter(this); + } + + public String getCLIString() { + // if option is set + if (this.numericValue == 1) { + return ("-" + this.parameter); // only the parameter + } + return ""; + } + + public String getCLIValueString() { + if (this.numericValue == 1) { + return (""); + } else { + return (null); + } + } + + public double getValue() { + return this.numericValue; + } + + public String getParameter() { + return this.parameter; + } + + public String[] getRange() { + return this.range; + } + + public void sampleNewConfig(double lambda, double reset, int verbose) { + + if(!this.optimise){ + return; + } + + if (Math.random() < reset) { + for (int i = 0; i < this.probabilities.size(); i++) { + this.probabilities.set(i, 
1.0/this.probabilities.size()); + } + } + + HashMap map = new HashMap(); + for (int i = 0; i < this.probabilities.size(); i++) { + map.put(i, this.probabilities.get(i)); + } + + // update configuration + this.numericValue = AutoClass.sampleProportionally(map, true); + String newValue = this.range[this.numericValue]; + if (verbose >= 3) { + System.out + .print("Sample new configuration for boolean parameter -" + this.parameter + " with probabilities"); + for (int i = 0; i < this.probabilities.size(); i++) { + System.out.print(" " + this.probabilities.get(i)); + } + System.out.println("\t=>\t -" + this.parameter + " " + newValue); + } + this.value = newValue; + + // adapt distribution + // this.probabilities.set(this.numericValue, + // this.probabilities.get(this.numericValue) + (1.0/iter)); + this.probabilities.set(this.numericValue, + this.probabilities.get(this.numericValue) * (2 - Math.pow(2, -1 * lambda))); + + // divide by sum + double sum = 0.0; + for (int i = 0; i < this.probabilities.size(); i++) { + sum += this.probabilities.get(i); + } + for (int i = 0; i < this.probabilities.size(); i++) { + this.probabilities.set(i, this.probabilities.get(i) / sum); + } + } +} \ No newline at end of file diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/CategoricalParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/CategoricalParameter.java new file mode 100755 index 000000000..e85e672c8 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/CategoricalParameter.java @@ -0,0 +1,125 @@ +package moa.classifiers.meta.AutoML; + +import com.yahoo.labs.samoa.instances.Attribute; +import moa.classifiers.meta.AutoML.IParameter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; + +// the representation of a categorical / nominal parameter +public class CategoricalParameter implements IParameter { + private String parameter; + private int numericValue; + private String value; + private String[] range; + private 
Attribute attribute; + private ArrayList probabilities; + private boolean optimise; + + public CategoricalParameter(CategoricalParameter x) { + this.parameter = x.parameter; + this.numericValue = x.numericValue; + this.value = x.value; + this.attribute = x.attribute; + this.optimise = x.optimise; + + if(this.optimise){ + this.range = x.range.clone(); + this.probabilities = new ArrayList(x.probabilities); + } + } + + public CategoricalParameter(ParameterConfiguration x) { + this.parameter = x.parameter; + this.value = String.valueOf(x.value); + this.attribute = new Attribute(x.parameter, Arrays.asList(range)); + this.optimise = x.optimise; + + if(this.optimise){ + this.range = new String[x.range.length]; + for (int i = 0; i < x.range.length; i++) { + range[i] = String.valueOf(x.range[i]); + if (this.range[i].equals(this.value)) { + this.numericValue = i; // get index of init value + } + } + this.probabilities = new ArrayList(x.range.length); + for (int i = 0; i < x.range.length; i++) { + this.probabilities.add(1.0 / x.range.length); // equal probabilities + } + } + } + + public CategoricalParameter copy() { + return new CategoricalParameter(this); + } + + public String getCLIString() { + return ("-" + this.parameter + " " + this.value); + } + + public String getCLIValueString() { + return ("" + this.value); + } + + public double getValue() { + return this.numericValue; + } + + public String getParameter() { + return this.parameter; + } + + public String[] getRange() { + return this.range; + } + + public void sampleNewConfig(double lambda, double reset, int verbose) { + + if(!this.optimise){ + return; + } + + if (Math.random() < reset) { + for (int i = 0; i < this.probabilities.size(); i++) { + this.probabilities.set(i, 1.0/this.probabilities.size()); + } + } + + HashMap map = new HashMap(); + for (int i = 0; i < this.probabilities.size(); i++) { + map.put(i, this.probabilities.get(i)); + } + // update configuration + this.numericValue = 
AutoClass.sampleProportionally(map, true); + String newValue = this.range[this.numericValue]; + + if (verbose >= 3) { + System.out + .print("Sample new configuration for nominal parameter -" + this.parameter + "with probabilities"); + for (int i = 0; i < this.probabilities.size(); i++) { + System.out.print(" " + this.probabilities.get(i)); + } + System.out.println("\t=>\t -" + this.parameter + " " + newValue); + } + this.value = newValue; + + // adapt distribution + // TODO not directly transferable from irace: (1-((iter -1) / maxIter)) + // this.probabilities.set(this.numericValue, + // this.probabilities.get(this.numericValue) + (1.0/iter)); + this.probabilities.set(this.numericValue, + this.probabilities.get(this.numericValue) * (2 - Math.pow(2, -1 * lambda))); + + // divide by sum (TODO is this even necessary with our proportional sampling + // strategy?) + double sum = 0.0; + for (int i = 0; i < this.probabilities.size(); i++) { + sum += this.probabilities.get(i); + } + for (int i = 0; i < this.probabilities.size(); i++) { + this.probabilities.set(i, this.probabilities.get(i) / sum); + } + } +} \ No newline at end of file diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java b/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java new file mode 100755 index 000000000..7a17f8d64 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java @@ -0,0 +1,228 @@ +/* + * HeterogeneousEnsembleAbstract.java + * Copyright (C) 2017 University of Waikato, Hamilton, New Zealand + * @author Jan N. van Rijn (j.n.van.rijn@liacs.leidenuniv.nl) + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ +package moa.classifiers.meta.AutoML; + +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.ListOption; +import com.github.javacliparser.Option; +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.instances.InstancesHeader; +import moa.classifiers.AbstractClassifier; +import moa.classifiers.Classifier; +import moa.classifiers.MultiClassClassifier; +import moa.core.Measurement; +import moa.core.ObjectRepository; +import moa.options.ClassOption; +import moa.tasks.TaskMonitor; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * BLAST (Best Last) for Heterogeneous Ensembles Abstract Base Class + * + *

+ * Given a set of (heterogeneous) classifiers, BLAST builds an ensemble, and + * determines the weights of all ensemble members based on their performance on + * recent observed instances. Used as Abstact Base Class for + * HeterogeneousEnsembleBlast and HeterogeneousEnsembleBlastFadingFactors. + *

+ * + *

+ * J. N. van Rijn, G. Holmes, B. Pfahringer, J. Vanschoren. Having a Blast: + * Meta-Learning and Heterogeneous Ensembles for Data Streams. In 2015 IEEE + * International Conference on Data Mining, pages 1003-1008. IEEE, 2015. + *

+ * + *

+ * Parameters: + *

+ *
    + *
  • -b : Comma-separated string of classifiers
  • + *
  • -g : Grace period (1 = optimal)
  • + *
  • -k : Number of active classifiers
  • + *
+ * Maroua: Created for EnsembleClassifierAbstractAuto2 + * @author Jan N. van Rijn (j.n.van.rijn@liacs.leidenuniv.nl) + * @version $Revision: 1 $ + */ +public abstract class HeterogeneousEnsembleAbstract extends AbstractClassifier implements MultiClassClassifier { + + private static final long serialVersionUID = 1L; + + public ListOption baselearnersOption = new ListOption("baseClassifiers", 'b', + "The classifiers the ensemble consists of.", + new ClassOption("learner", ' ', "", Classifier.class, + "trees.HoeffdingTree"), + new Option[] { + new ClassOption("", ' ', "", Classifier.class, "bayes.NaiveBayes"), + new ClassOption("", ' ', "", Classifier.class, + "functions.Perceptron"), + new ClassOption("", ' ', "", Classifier.class, "functions.SGD"), + new ClassOption("", ' ', "", Classifier.class, "lazy.kNN"), + new ClassOption("", ' ', "", Classifier.class, + "trees.HoeffdingTree") }, + ','); + + public IntOption gracePerionOption = new IntOption("gracePeriod", 'g', + "How many instances before we reevalate the best classifier", 1, 1, + Integer.MAX_VALUE); + + public IntOption activeClassifiersOption = new IntOption("activeClassifiers", + 'k', "The number of active classifiers (used for voting)", 1, 1, + Integer.MAX_VALUE); + + public FlagOption weightClassifiersOption = new FlagOption( + "weightClassifiers", 'p', + "Uses online performance estimation to weight the classifiers"); + + protected Classifier[] ensemble; + + protected double[] historyTotal; + + protected Integer instancesSeen; + + List topK; + + @Override + public String getPurposeString() { + return "The model-free heterogeneous ensemble as presented in " + + "'Having a Blast: Meta-Learning and Heterogeneous Ensembles " + + "for Data Streams' (ICDM 2015)."; + } + + public int getEnsembleSize() { + // mainly for testing, @throws exception if called before initialization + return this.ensemble.length; + } + + public String getMemberCliString(int idx) { + // mainly for testing, @pre: idx < 
getEnsembleSize() + return this.ensemble[idx].getCLICreationString(Classifier.class); + } + + @Override + public double[] getVotesForInstance(Instance inst) { + double[] votes = new double[inst.classAttribute().numValues()]; + + for (int i = 0; i < topK.size(); ++i) { + double[] memberVotes = normalize( + ensemble[topK.get(i)].getVotesForInstance(inst)); + double weight = 1.0; + + if (weightClassifiersOption.isSet()) { + weight = historyTotal[topK.get(i)]; + } + + // make internal classifiers so-called "hard classifiers" + votes[maxIndex(memberVotes)] += 1.0 * weight; + } + + return votes; + } + + @Override + public void setModelContext(InstancesHeader ih) { + super.setModelContext(ih); + + for (int i = 0; i < this.ensemble.length; ++i) { + this.ensemble[i].setModelContext(ih); + } + } + + @Override + public boolean isRandomizable() { + return false; + } + + @Override + public void getModelDescription(StringBuilder arg0, int arg1) { + // TODO Auto-generated method stub + + } + + @Override + protected Measurement[] getModelMeasurementsImpl() { + // TODO Auto-generated method stub + return null; + } + + @Override + public void prepareForUseImpl(TaskMonitor monitor, + ObjectRepository repository) { + + Option[] learnerOptions = this.baselearnersOption.getList(); + this.ensemble = new Classifier[learnerOptions.length]; + for (int i = 0; i < learnerOptions.length; i++) { + monitor.setCurrentActivity("Materializing learner " + (i + 1) + "...", + -1.0); + this.ensemble[i] = (Classifier) ((ClassOption) learnerOptions[i]) + .materializeObject(monitor, repository); + if (monitor.taskShouldAbort()) { + return; + } + monitor.setCurrentActivity("Preparing learner " + (i + 1) + "...", -1.0); + this.ensemble[i].prepareForUse(monitor, repository); + if (monitor.taskShouldAbort()) { + return; + } + } + super.prepareForUseImpl(monitor, repository); + + topK = topK(historyTotal, activeClassifiersOption.getValue()); + } + + protected static List topK(double[] scores, int k) { + 
double[] scoresWorking = Arrays.copyOf(scores, scores.length); + + List topK = new ArrayList(); + + for (int i = 0; i < k; ++i) { + int bestIdx = maxIndex(scoresWorking); + topK.add(bestIdx); + scoresWorking[bestIdx] = -1; + } + + return topK; + } + + protected static int maxIndex(double[] scores) { + int bestIdx = 0; + for (int i = 1; i < scores.length; ++i) { + if (scores[i] > scores[bestIdx]) { + bestIdx = i; + } + } + return bestIdx; + } + + protected static double[] normalize(double[] input) { + double sum = 0.0; + for (int i = 0; i < input.length; ++i) { + sum += input[i]; + } + for (int i = 0; i < input.length; ++i) { + input[i] /= sum; + } + return input; + } +} diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/IParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/IParameter.java new file mode 100755 index 000000000..3c3d9af66 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/IParameter.java @@ -0,0 +1,16 @@ +package moa.classifiers.meta.AutoML; + +// interface allows us to maintain a single list of parameters +public interface IParameter { + public void sampleNewConfig(double lambda, double reset, int verbose); + + public IParameter copy(); + + public String getCLIString(); + + public String getCLIValueString(); + + public double getValue(); + + public String getParameter(); +} diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/IntegerParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/IntegerParameter.java new file mode 100755 index 000000000..517e8fe6b --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/IntegerParameter.java @@ -0,0 +1,89 @@ +package moa.classifiers.meta.AutoML; + +import com.yahoo.labs.samoa.instances.Attribute; + +// the representation of an integer parameter +public class IntegerParameter implements IParameter { + private String parameter; + private int value; + private int[] range; + private double std; + private Attribute attribute; + private boolean optimise; + + 
public IntegerParameter(IntegerParameter x) { + this.parameter = x.parameter; + this.value = x.value; + this.attribute = x.attribute;// new Attribute(x.parameter); + this.optimise = x.optimise; + + if(this.optimise){ + this.range = x.range.clone(); + this.std = x.std; + } + } + + public IntegerParameter(ParameterConfiguration x) { + this.parameter = x.parameter; + this.value = (int)(double) x.value; // TODO fix casts + + this.attribute = new Attribute(x.parameter); + this.optimise = x.optimise; + + if(this.optimise){ + this.range = new int[x.range.length]; + for (int i = 0; i < x.range.length; i++) { + range[i] = (int) (double) x.range[i]; + } + this.std = (this.range[1] - this.range[0]) / 2; + } + } + + public IntegerParameter copy() { + return new IntegerParameter(this); + } + + public String getCLIString() { + return ("-" + this.parameter + " " + this.value); + } + + public String getCLIValueString() { + return ("" + this.value); + } + + public double getValue() { + return this.value; + } + + public void setValue(int value){ + this.value = value; + } + + public String getParameter() { + return this.parameter; + } + + public void sampleNewConfig(double lambda, double reset, int verbose) { + if(!this.optimise){ + return; + } + if (Math.random() < reset) { + this.std = (this.range[1] - this.range[0]) / 2; + } + + // update configuration + // for integer features use truncated normal distribution + TruncatedNormal trncnormal = new TruncatedNormal(this.value, this.std, this.range[0], this.range[1]); + int newValue = (int) Math.round(trncnormal.sample()); + if (verbose >= 3) { + System.out.println("Sample new configuration for integer parameter -" + this.parameter + " with mean: " + + this.value + ", std: " + this.std + ", lb: " + this.range[0] + ", ub: " + this.range[1] + + "\t=>\t -" + this.parameter + " " + newValue); + } + + this.value = newValue; + + // adapt distribution + this.std = this.std * Math.pow(2, -1 * lambda); + } +} \ No newline at end of file diff 
--git a/moa/src/main/java/moa/classifiers/meta/AutoML/NumericalParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/NumericalParameter.java new file mode 100755 index 000000000..8d4fb6fbd --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/NumericalParameter.java @@ -0,0 +1,93 @@ +package moa.classifiers.meta.AutoML; + +import com.yahoo.labs.samoa.instances.Attribute; + +// the representation of a numerical / real parameter +public class NumericalParameter implements IParameter { + private String parameter; + private double value; + private double[] range; + private double std; + private Attribute attribute; + private boolean optimise; + + public NumericalParameter(NumericalParameter x) { + this.parameter = x.parameter; + this.value = x.value; + this.attribute = new Attribute(x.parameter); + this.optimise = x.optimise; + + if(this.optimise){ + this.range = x.range.clone(); + this.std = x.std; + } + + } + + public NumericalParameter(ParameterConfiguration x) { + this.parameter = x.parameter; + this.value = (double) x.value; + this.attribute = new Attribute(x.parameter); + this.optimise = x.optimise; + + if(this.optimise){ + this.range = new double[x.range.length]; + for (int i = 0; i < x.range.length; i++) { + range[i] = (double) x.range[i]; + } + this.std = (this.range[1] - this.range[0]) / 2; + } + } + + public NumericalParameter copy() { + return new NumericalParameter(this); + } + + public String getCLIString() { + return ("-" + this.parameter + " " + this.value); + } + + public String getCLIValueString() { + return ("" + this.value); + } + + public double getValue() { + return this.value; + } + + public String getParameter() { + return this.parameter; + } + + public void sampleNewConfig(double lambda, double reset, int verbose) { + + if(!this.optimise){ + return; + } + + // trying to balanced exploitation vs exploration by resetting the std + if (Math.random() < reset) { + this.std = (this.range[1] - this.range[0]) / 2; + } + this.std = 
this.std * Math.pow(2, -1 * lambda); + + // update configuration + // for numeric features use truncated normal distribution + TruncatedNormal trncnormal = new TruncatedNormal(this.value, this.std, this.range[0], this.range[1]); + double newValue = trncnormal.sample(); + + if (verbose >= 3) { + System.out.println("Sample new configuration for numerical parameter -" + this.parameter + " with mean: " + + this.value + ", std: " + this.std + ", lb: " + this.range[0] + ", ub: " + this.range[1] + + "\t=>\t -" + this.parameter + " " + newValue); + } + + this.value = newValue; + + // adapt distribution + // this.std = this.std * (Math.pow((1.0 / nbNewConfigurations), (1.0 / + // nbVariable))); + + + } +} \ No newline at end of file diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/OrdinalParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/OrdinalParameter.java new file mode 100755 index 000000000..9a445541e --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/OrdinalParameter.java @@ -0,0 +1,98 @@ +package moa.classifiers.meta.AutoML; + +import com.yahoo.labs.samoa.instances.Attribute; + +// the representation of an integer parameter +public class OrdinalParameter implements IParameter { + private String parameter; + private String value; + private int numericValue; + private String[] range; + private double std; + private Attribute attribute; + private boolean optimise; + + // copy constructor + public OrdinalParameter(OrdinalParameter x) { + this.parameter = x.parameter; + this.value = x.value; + this.numericValue = x.numericValue; + this.attribute = x.attribute; + this.optimise = x.optimise; + + if(this.optimise){ + this.range = x.range.clone(); + this.std = x.std; + } + } + + // init constructor + public OrdinalParameter(ParameterConfiguration x) { + this.parameter = x.parameter; + this.value = String.valueOf(x.value); + this.attribute = new Attribute(x.parameter); + this.optimise = x.optimise; + + if(this.optimise){ + this.range = new 
String[x.range.length]; + for (int i = 0; i < x.range.length; i++) { + range[i] = String.valueOf(x.range[i]); + if (this.range[i].equals(this.value)) { + this.numericValue = i; // get index of init value + } + } + this.std = (this.range.length - 0) / 2; + } + + } + + public OrdinalParameter copy() { + return new OrdinalParameter(this); + } + + public String getCLIString() { + return ("-" + this.parameter + " " + this.value); + } + + public String getCLIValueString() { + return ("" + this.value); + } + + public double getValue() { + return this.numericValue; + } + + public String getParameter() { + return this.parameter; + } + + public void sampleNewConfig(double lambda, double reset, int verbose) { + + if(!this.optimise){ + return; + } + + // update configuration + if (Math.random() < reset) { + double upper = (double) (this.range.length - 1); + this.std = upper / 2; + } + + // treat index of range as integer parameter + TruncatedNormal trncnormal = new TruncatedNormal(this.numericValue, this.std, 0.0, (double) (this.range.length - 1)); // limits are the indexes of the range + int newValue = (int) Math.round(trncnormal.sample()); + + if (verbose >= 3) { + System.out.println("Sample new configuration for ordinal parameter -" + this.parameter + " with mean: " + + this.numericValue + ", std: " + this.std + ", lb: " + 0 + ", ub: " + (this.range.length - 1) + + "\t=>\t -" + this.parameter + " " + this.range[newValue] + " (" + newValue + ")"); + } + + this.numericValue = newValue; + this.value = this.range[this.numericValue]; + + // adapt distribution + this.std = this.std * Math.pow(2, -1 * lambda); + } + +} \ No newline at end of file diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java b/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java new file mode 100755 index 000000000..31d6224c6 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java @@ -0,0 +1,56 @@ +package moa.classifiers.meta.AutoML; + 
+import umontreal.iro.lecuyer.probdist.NormalDist; +import java.util.Random; + +public class TruncatedNormal { + + double mean; + double sd; + double lb; + double ub; + double cdf_a; + double Z; + + // https://en.wikipedia.org/wiki/Truncated_normal_distribution + TruncatedNormal(double mean, double sd, double lb, double ub){ + this.mean = mean; + this.sd = sd; + this.lb = lb; + this.ub = ub; + + this.cdf_a = NormalDist.cdf01((lb - mean)/sd); + double cdf_b = NormalDist.cdf01((ub - mean)/sd); + this.Z = cdf_b - cdf_a; + } + + public double sample() { + //TODO This is the simple sampling strategy. Faster approaches are available + Random random = new Random(); + double val = random.nextDouble() * Z + cdf_a; + return mean + sd * NormalDist.inverseF01(val); + } + + + + + public static void main(String[] args) { + TruncatedNormal trncnorm = new TruncatedNormal(0,10,-5,5); + double min=Double.MAX_VALUE; + double max=-Double.MAX_VALUE; + double sum=0.0; + for(int i=0; i<100000; i++) { + double val = trncnorm.sample(); + sum += val; + if(val > max) max=val; + if(val < min) min=val; + System.out.println(val); + } + + System.out.println("Min: " + min); + System.out.println("Max: " + max); + System.out.println("Mean: " + sum / 100000); + + } + +} diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/settings.json b/moa/src/main/java/moa/classifiers/meta/AutoML/settings.json new file mode 100755 index 000000000..50b1eda78 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/settings.json @@ -0,0 +1,52 @@ +{ + + "windowSize" : 1000, + "ensembleSize" : 10, + "newConfigurations" : 10, + "keepCurrentModel" : true, + "lambda" : 0.05, + "preventAlgorithmDeath" : true, + "keepGlobalIncumbent" : true, + "keepAlgorithmIncumbents" : true, + "keepInitialConfigurations" : true, + "useTestEnsemble" : true, + "resetProbability" : 0.01, + "numberOfCores" : 1, + "performanceMeasureMaximisation": true, + + "algorithms": [ + { + "algorithm": "lazy.kNN", + "parameters": [ + 
{"parameter": "k", "type":"integer", "value":10, "range":[2,30]} + // {"parameter": "w", "type":"integer", "value":1000, "range":[500,2000]} + ] + } + , + { + "algorithm": "trees.HoeffdingTree", + "parameters": [ + {"parameter": "g", "type":"integer", "value":200, "range":[10, 200]}, + {"parameter": "c", "type":"float", "value":0.01, "range":[0, 1]} + ] + } + , + { + "algorithm": "lazy.kNNwithPAWandADWIN", + "parameters": [ + {"parameter": "k", "type":"integer", "value":10, "range":[2,30]} + // {"parameter": "w", "type":"numeric", "value":1000, "range":[1000,5000]}, + //{"parameter": "m", "type":"boolean", "value":"false"} + ] + } + , + { + "algorithm": "trees.HoeffdingAdaptiveTree", + "parameters": [ + {"parameter": "g", "type":"integer", "value":200, "range":[10, 200]}, + {"parameter": "c", "type":"float", "value":0.01, "range":[0, 1]} + ] + } + + ] +} From 6d1326091ad714634f6b527df75cc283b5c95c3e Mon Sep 17 00:00:00 2001 From: Heitor Murilo Gomes Date: Sat, 7 Sep 2024 06:19:31 +1200 Subject: [PATCH 25/31] update to autoclass --- .../java/moa/classifiers/lazy/RW_kNN.java | 4 ++-- .../classifiers/meta/AutoML/Algorithm.java | 2 +- .../classifiers/meta/AutoML/AutoClass.java | 5 ++++- .../meta/AutoML/TruncatedNormal.java | 20 +++++++++++++------ 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java b/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java index 4d9601502..09ad98294 100755 --- a/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java +++ b/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java @@ -99,7 +99,7 @@ public void trainOnInstanceImpl(Instance inst) { this.reservoir = new Instances(inst.dataset()); } if (this.limitOptionReservoir.getValue() <= this.reservoir.numInstances()) { - int replaceIndex = r.nextInt(this.limitOptionReservoir.getValue() - 1); + int replaceIndex = this.classifierRandom.nextInt(this.limitOptionReservoir.getValue() - 1); this.reservoir.set(replaceIndex, inst); } else 
this.reservoir.add(inst); @@ -155,6 +155,6 @@ public void getModelDescription(StringBuilder out, int indent) { } public boolean isRandomizable() { - return false; + return true; } } \ No newline at end of file diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/Algorithm.java b/moa/src/main/java/moa/classifiers/meta/AutoML/Algorithm.java index 4add3725f..441b68f59 100755 --- a/moa/src/main/java/moa/classifiers/meta/AutoML/Algorithm.java +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/Algorithm.java @@ -163,7 +163,7 @@ public void adjustAlgorithm(boolean keepCurrentModel, int verbose) { // these changes do not transfer over directly since all algorithms cache the // option values. Therefore we try to adjust the cached values if possible try { - ((AbstractClassifier) this.classifier).adjustParameters(); +// ((AbstractClassifier) this.classifier).adjustParameters(); if (verbose >= 2) { System.out.println("Changed: " + this.classifier.getCLICreationString(Classifier.class)); } diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java index c9b39aa68..4702989fa 100755 --- a/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java @@ -256,7 +256,8 @@ protected void evaluatePerformance() { + "\t => \t performance: " + performance); } - String algorithm = this.ensemble.get(i).classifier.getPurposeString(); +// String algorithm = this.ensemble.get(i).classifier.getPurposeString(); + String algorithm = this.ensemble.get(i).classifier.getClass().getName(); if (!bestPerformanceIdxMap.containsKey(algorithm) || performance > bestPerformanceValMap.get(algorithm)) { bestPerformanceValMap.put(algorithm, performance); // best performance per algorithm bestPerformanceIdxMap.put(algorithm, i); // index of best performance per algorithm @@ -695,3 +696,5 @@ public Integer call() throws Exception { return 0; } } + } // Close the 
EnsembleRunnable class + diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java b/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java index 31d6224c6..192ae1577 100755 --- a/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java @@ -1,6 +1,7 @@ package moa.classifiers.meta.AutoML; -import umontreal.iro.lecuyer.probdist.NormalDist; +import org.apache.commons.math3.distribution.NormalDistribution; +//import umontreal.iro.lecuyer.probdist.NormalDist; import java.util.Random; public class TruncatedNormal { @@ -11,16 +12,22 @@ public class TruncatedNormal { double ub; double cdf_a; double Z; - + + private NormalDistribution normalDist; + // https://en.wikipedia.org/wiki/Truncated_normal_distribution TruncatedNormal(double mean, double sd, double lb, double ub){ this.mean = mean; this.sd = sd; this.lb = lb; this.ub = ub; - - this.cdf_a = NormalDist.cdf01((lb - mean)/sd); - double cdf_b = NormalDist.cdf01((ub - mean)/sd); + + + normalDist = new NormalDistribution(0, 1); // Standard normal distribution + this.cdf_a = normalDist.cumulativeProbability((lb - mean) / sd); + double cdf_b = normalDist.cumulativeProbability((ub - mean) / sd); +// this.cdf_a = NormalDist.cdf01((lb - mean)/sd); +// double cdf_b = NormalDist.cdf01((ub - mean)/sd); this.Z = cdf_b - cdf_a; } @@ -28,7 +35,8 @@ public double sample() { //TODO This is the simple sampling strategy. 
Faster approaches are available Random random = new Random(); double val = random.nextDouble() * Z + cdf_a; - return mean + sd * NormalDist.inverseF01(val); + return mean + sd * normalDist.inverseCumulativeProbability(val); +// return mean + sd * NormalDistribution.inverseF01(val); } From c8a2118bfcee338f2f8864584c14219c62caf9db Mon Sep 17 00:00:00 2001 From: Heitor Murilo Gomes Date: Sat, 7 Sep 2024 22:23:30 +1200 Subject: [PATCH 26/31] fix: update to make it compatible with capymoa --- .../classifiers/meta/AutoML/AutoClass.java | 21 +++++++++++++++++++ .../AutoML/HeterogeneousEnsembleAbstract.java | 12 ++++++----- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java index 4702989fa..94f151fda 100755 --- a/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java @@ -5,6 +5,7 @@ import com.yahoo.labs.samoa.instances.DenseInstance; import com.yahoo.labs.samoa.instances.Instance; import com.yahoo.labs.samoa.instances.Instances; +import com.yahoo.labs.samoa.instances.InstancesHeader; import moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.classifiers.meta.AdaptiveRandomForestRegressor; @@ -44,6 +45,13 @@ public class AutoClass extends HeterogeneousEnsembleAbstract implements MultiCla public FileOption fileOption = new FileOption("ConfigurationFile", 'f', "Configuration file in json format.", "/Users/mbahri/Desktop/Dell/moa/src/main/java/moa/classifiers/meta/AutoML/settings.json", ".json", false); + @Override + public String getPurposeString() { + return "Autoclass: Automl for data stream classification. "+ + "Bahri, Maroua, and Nikolaos Georgantas. " + + "2023 IEEE International Conference on Big Data (BigData). 
IEEE, 2023."; + } + public void init() { this.fileOption.getFile(); } @@ -631,6 +639,19 @@ protected Measurement[] getModelMeasurementsImpl() { public void getModelDescription(StringBuilder out, int indent) { // TODO Auto-generated method stub } + + @Override + public void setModelContext(InstancesHeader ih) { + super.setModelContext(ih); + +// This will cause issues in case setModelContext is invoked before the ensemble has been created. +// It is likely safe to not perform this action due to how the context can be acquired later by the learners. +// However, it is worth reviewing this in the future. See also HeterogeneousEnsembleAbstract.setModelContext +// for (int i = 0; i < this.ensemble.size(); ++i) { +// this.ensemble.get(i).classifier.setModelContext(ih); +// } + } + @Override public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java b/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java index 7a17f8d64..3bcbdbafe 100755 --- a/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java +++ b/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java @@ -84,7 +84,7 @@ public abstract class HeterogeneousEnsembleAbstract extends AbstractClassifier i ','); public IntOption gracePerionOption = new IntOption("gracePeriod", 'g', - "How many instances before we reevalate the best classifier", 1, 1, + "How many instances before we re-evaluate the best classifier", 1, 1, Integer.MAX_VALUE); public IntOption activeClassifiersOption = new IntOption("activeClassifiers", @@ -143,10 +143,12 @@ public double[] getVotesForInstance(Instance inst) { @Override public void setModelContext(InstancesHeader ih) { super.setModelContext(ih); - - for (int i = 0; i < this.ensemble.length; ++i) { - this.ensemble[i].setModelContext(ih); - } +// This will cause issues in case setModelContext is invoked 
before the ensemble has been created. +// It is likely safe to not perform this action due to how the context can be acquired later by the learners. +// However, it is worth reviewing this in the future. See also AutoClass.setModelContext +// for (int i = 0; i < this.ensemble.length; ++i) { +// this.ensemble[i].setModelContext(ih); +// } } @Override From 382df4d69163512a19424653bbd9ac503471bd44 Mon Sep 17 00:00:00 2001 From: Spencer Sun Date: Wed, 19 Feb 2025 11:53:43 +1300 Subject: [PATCH 27/31] fix: debugging SRP_MB and TstThnTrn class --- .../classifiers/meta/minibatch/StreamingRandomPatchesMB.java | 1 + .../main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/StreamingRandomPatchesMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/StreamingRandomPatchesMB.java index b5bb9a6e3..1c151c848 100644 --- a/moa/src/main/java/moa/classifiers/meta/minibatch/StreamingRandomPatchesMB.java +++ b/moa/src/main/java/moa/classifiers/meta/minibatch/StreamingRandomPatchesMB.java @@ -287,6 +287,7 @@ protected void initEnsemble(Instance instance) { break; case StreamingRandomPatchesMB.TRAIN_RANDOM_SUBSPACES: case StreamingRandomPatchesMB.TRAIN_RANDOM_PATCHES: + if (this.subspaces.isEmpty()) break; int selectedValue = this.classifierRandom.nextInt(subspaces.size()); ArrayList subsetOfFeatures = this.subspaces.get(selectedValue); subsetOfFeatures.add(instance.classIndex()); diff --git a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java index e0c69a5bd..e16d43791 100644 --- a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java +++ b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java @@ -25,6 +25,7 @@ import moa.capabilities.Capability; import moa.capabilities.ImmutableCapabilities; +import moa.classifiers.AbstractClassifierMiniBatch; import 
moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.core.Example; @@ -217,6 +218,9 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } + if (learner instanceof AbstractClassifierMiniBatch) { + ((AbstractClassifierMiniBatch) learner).trainingHasEnded(); + } return learningCurve; } From 12ad55ea9e2a755a52c862571f223fddd7d9471f Mon Sep 17 00:00:00 2001 From: heymarco Date: Wed, 12 Mar 2025 22:37:28 +0100 Subject: [PATCH 28/31] Added PLASTIC and the necessary adaptations to existing classes (#23) * Added PLASTIC and the necessary adaptations to existing classes * Added PLASTIC and PLASTIC-A. Now also supports NB and NBA as leaf prediction options --- ...GaussianNumericAttributeClassObserver.java | 18 +- .../NominalAttributeClassObserver.java | 33 +- .../NominalAttributeBinaryTest.java | 6 +- .../NumericAttributeBinaryTest.java | 10 +- .../main/java/moa/classifiers/trees/EFDT.java | 2716 ++++++++--------- .../java/moa/classifiers/trees/PLASTIC.java | 192 ++ .../java/moa/classifiers/trees/PLASTICA.java | 192 ++ .../CustomADWINChangeDetector.java | 27 + .../trees/plastic_util/CustomEFDTNode.java | 780 +++++ .../trees/plastic_util/CustomHTNode.java | 80 + .../trees/plastic_util/EFHATNode.java | 174 ++ .../trees/plastic_util/MappedTree.java | 409 +++ .../plastic_util/MeasuresNumberOfLeaves.java | 5 + .../plastic_util/PerformsTreeRevision.java | 5 + .../trees/plastic_util/PlasticBranch.java | 63 + .../trees/plastic_util/PlasticNode.java | 398 +++ .../plastic_util/PlasticTreeElement.java | 40 + .../trees/plastic_util/Restructurer.java | 315 ++ .../plastic_util/SuccessorIdentifier.java | 149 + .../trees/plastic_util/Successors.java | 221 ++ 20 files changed, 4457 insertions(+), 1376 deletions(-) create mode 100644 moa/src/main/java/moa/classifiers/trees/PLASTIC.java create mode 100644 moa/src/main/java/moa/classifiers/trees/PLASTICA.java create 
mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/CustomADWINChangeDetector.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/CustomEFDTNode.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/CustomHTNode.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/EFHATNode.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/MappedTree.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/MeasuresNumberOfLeaves.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/PerformsTreeRevision.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticBranch.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticNode.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticTreeElement.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/Restructurer.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/SuccessorIdentifier.java create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/Successors.java diff --git a/moa/src/main/java/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java b/moa/src/main/java/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java index dbc0e0e17..a727c4066 100644 --- a/moa/src/main/java/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java +++ b/moa/src/main/java/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java @@ -15,7 +15,7 @@ * * You should have received a copy of the GNU General Public License * along with this program. If not, see . 
- * + * */ package moa.classifiers.core.attributeclassobservers; @@ -81,7 +81,7 @@ public void observeAttributeClass(double attVal, int classVal, double weight) { @Override public double probabilityOfAttributeValueGivenClass(double attVal, - int classVal) { + int classVal) { GaussianEstimator obs = this.attValDistPerClass.get(classVal); return obs != null ? obs.probabilityDensity(attVal) : 0.0; } @@ -99,12 +99,24 @@ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( if ((bestSuggestion == null) || (merit > bestSuggestion.merit)) { bestSuggestion = new AttributeSplitSuggestion( new NumericAttributeBinaryTest(attIndex, splitValue, - true), postSplitDists, merit); + true), postSplitDists, merit); } } return bestSuggestion; } + /* Used by PLASTIC during restructuring when forcing a leaf split becomes necessary */ + public AttributeSplitSuggestion forceSplit( + SplitCriterion criterion, double[] preSplitDist, int attIndex, double threshold) { + AttributeSplitSuggestion bestSuggestion = null; + double[][] postSplitDists = getClassDistsResultingFromBinarySplit(threshold); + double merit = criterion.getMeritOfSplit(preSplitDist, + postSplitDists); + bestSuggestion = new AttributeSplitSuggestion( + new NumericAttributeBinaryTest(attIndex, threshold, true), postSplitDists, merit); + return bestSuggestion; + } + public double[] getSplitPointSuggestions() { Set suggestedSplitValues = new TreeSet(); double minValue = Double.POSITIVE_INFINITY; diff --git a/moa/src/main/java/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java b/moa/src/main/java/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java index ce8021017..c5de2406e 100644 --- a/moa/src/main/java/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java +++ b/moa/src/main/java/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java @@ -15,7 +15,7 @@ * * You should have received a copy of the GNU General 
Public License * along with this program. If not, see . - * + * */ package moa.classifiers.core.attributeclassobservers; @@ -68,7 +68,7 @@ public void observeAttributeClass(double attVal, int classVal, double weight) { @Override public double probabilityOfAttributeValueGivenClass(double attVal, - int classVal) { + int classVal) { DoubleVector obs = this.attValDistPerClass.get(classVal); return obs != null ? (obs.getValue((int) attVal) + 1.0) / (obs.sumOfValues() + obs.numValues()) : 0.0; @@ -109,6 +109,33 @@ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( return bestSuggestion; } + /* Used by PLASTIC during restructuring when forcing a leaf split becomes necessary */ + public AttributeSplitSuggestion forceSplit( + SplitCriterion criterion, double[] preSplitDist, int attIndex, boolean binary, Double splitValue) { + AttributeSplitSuggestion bestSuggestion; + int maxAttValsObserved = getMaxAttValsObserved(); + if (!binary) { + double[][] postSplitDists = getClassDistsResultingFromMultiwaySplit(maxAttValsObserved); + double merit = criterion.getMeritOfSplit(preSplitDist, + postSplitDists); + bestSuggestion = new AttributeSplitSuggestion( + new NominalAttributeMultiwayTest(attIndex), postSplitDists, + merit); + return bestSuggestion; + } + assert splitValue != null: "Split value is null"; + if (splitValue >= maxAttValsObserved) { + return null; + } + double[][] postSplitDists = getClassDistsResultingFromBinarySplit(splitValue.intValue()); + double merit = criterion.getMeritOfSplit(preSplitDist, + postSplitDists); + bestSuggestion = new AttributeSplitSuggestion( + new NominalAttributeBinaryTest(attIndex, splitValue.intValue()), + postSplitDists, merit); + return bestSuggestion; + } + public int getMaxAttValsObserved() { int maxAttValsObserved = 0; for (DoubleVector attValDist : this.attValDistPerClass) { @@ -157,7 +184,7 @@ public double[][] getClassDistsResultingFromBinarySplit(int valIndex) { } } return new double[][]{equalsDist.getArrayRef(), - 
notEqualDist.getArrayRef()}; + notEqualDist.getArrayRef()}; } @Override diff --git a/moa/src/main/java/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java b/moa/src/main/java/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java index 73498ff6b..397d308b4 100644 --- a/moa/src/main/java/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java +++ b/moa/src/main/java/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java @@ -15,7 +15,7 @@ * * You should have received a copy of the GNU General Public License * along with this program. If not, see . - * + * */ package moa.classifiers.core.conditionaltests; @@ -48,6 +48,10 @@ public int branchForInstance(Instance inst) { return inst.isMissing(instAttIndex) ? -1 : ((int) inst.value(instAttIndex) == this.attValue ? 0 : 1); } + public double getValue() { + return attValue; + } + @Override public String describeConditionForBranch(int branch, InstancesHeader context) { if ((branch == 0) || (branch == 1)) { diff --git a/moa/src/main/java/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java b/moa/src/main/java/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java index d8cb5e8c3..5acb11610 100644 --- a/moa/src/main/java/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java +++ b/moa/src/main/java/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java @@ -15,7 +15,7 @@ * * You should have received a copy of the GNU General Public License * along with this program. If not, see . 
- * + * */ package moa.classifiers.core.conditionaltests; @@ -39,16 +39,20 @@ public class NumericAttributeBinaryTest extends InstanceConditionalBinaryTest { protected boolean equalsPassesTest; public NumericAttributeBinaryTest(int attIndex, double attValue, - boolean equalsPassesTest) { + boolean equalsPassesTest) { this.attIndex = attIndex; this.attValue = attValue; this.equalsPassesTest = equalsPassesTest; } + public double getValue() { + return attValue; + } + @Override public int branchForInstance(Instance inst) { int instAttIndex = this.attIndex ; // < inst.classIndex() ? this.attIndex - // : this.attIndex + 1; + // : this.attIndex + 1; if (inst.isMissing(instAttIndex)) { return -1; } diff --git a/moa/src/main/java/moa/classifiers/trees/EFDT.java b/moa/src/main/java/moa/classifiers/trees/EFDT.java index d900ed250..314c6adfc 100644 --- a/moa/src/main/java/moa/classifiers/trees/EFDT.java +++ b/moa/src/main/java/moa/classifiers/trees/EFDT.java @@ -78,1538 +78,1522 @@ public class EFDT extends AbstractClassifier implements MultiClassClassifier { - private static final long serialVersionUID = 2L; - - public IntOption reEvalPeriodOption = new IntOption( - "reevaluationPeriod", - 'R', - "The number of instances an internal node should observe between re-evaluation attempts.", - 2000, 0, Integer.MAX_VALUE); - - public IntOption maxByteSizeOption = new IntOption("maxByteSize", 'm', - "Maximum memory consumed by the tree.", 33554432, 0, - Integer.MAX_VALUE); - - /* - * public MultiChoiceOption numericEstimatorOption = new MultiChoiceOption( - * "numericEstimator", 'n', "Numeric estimator to use.", new String[]{ - * "GAUSS10", "GAUSS100", "GK10", "GK100", "GK1000", "VFML10", "VFML100", - * "VFML1000", "BINTREE"}, new String[]{ "Gaussian approximation evaluating - * 10 splitpoints", "Gaussian approximation evaluating 100 splitpoints", - * "Greenwald-Khanna quantile summary with 10 tuples", "Greenwald-Khanna - * quantile summary with 100 tuples", "Greenwald-Khanna 
quantile summary - * with 1000 tuples", "VFML method with 10 bins", "VFML method with 100 - * bins", "VFML method with 1000 bins", "Exhaustive binary tree"}, 0); - */ - public ClassOption numericEstimatorOption = new ClassOption("numericEstimator", - 'n', "Numeric estimator to use.", NumericAttributeClassObserver.class, - "GaussianNumericAttributeClassObserver"); - - public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator", - 'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class, - "NominalAttributeClassObserver"); - - public IntOption memoryEstimatePeriodOption = new IntOption( - "memoryEstimatePeriod", 'e', - "How many instances between memory consumption checks.", 1000000, - 0, Integer.MAX_VALUE); - - public IntOption gracePeriodOption = new IntOption( - "gracePeriod", - 'g', - "The number of instances a leaf should observe between split attempts.", - 200, 0, Integer.MAX_VALUE); - - public ClassOption splitCriterionOption = new ClassOption("splitCriterion", - 's', "Split criterion to use.", SplitCriterion.class, - "InfoGainSplitCriterion"); - - public FloatOption splitConfidenceOption = new FloatOption( - "splitConfidence", - 'c', - "The allowable error in split decision, values closer to 0 will take longer to decide.", - 0.0000001, 0.0, 1.0); - - public FloatOption tieThresholdOption = new FloatOption("tieThreshold", - 't', "Threshold below which a split will be forced to break ties.", - 0.05, 0.0, 1.0); - - public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b', - "Only allow binary splits."); - - public FlagOption stopMemManagementOption = new FlagOption( - "stopMemManagement", 'z', - "Stop growing as soon as memory limit is hit."); - - public FlagOption removePoorAttsOption = new FlagOption("removePoorAtts", - 'r', "Disable poor attributes."); - - public FlagOption noPrePruneOption = new FlagOption("noPrePrune", 'p', - "Disable pre-pruning."); - - public MultiChoiceOption leafpredictionOption = new 
MultiChoiceOption( - "leafprediction", 'l', "Leaf prediction to use.", new String[]{ - "MC", "NB", "NBAdaptive"}, new String[]{ - "Majority class", - "Naive Bayes", - "Naive Bayes Adaptive"}, 2); - - public IntOption nbThresholdOption = new IntOption( - "nbThreshold", - 'q', - "The number of instances a leaf should observe before permitting Naive Bayes.", - 0, 0, Integer.MAX_VALUE); - - protected Node treeRoot = null; - - protected int decisionNodeCount; - - protected int activeLeafNodeCount; - - protected int inactiveLeafNodeCount; - - protected double inactiveLeafByteSizeEstimate; - - protected double activeLeafByteSizeEstimate; - - protected double byteSizeEstimateOverheadFraction; - - protected boolean growthAllowed; - - protected int numInstances = 0; - - protected int splitCount = 0; - - @Override - public String getPurposeString() { - return "Hoeffding Tree or VFDT."; - } - - public long calcByteSize() { - long size = SizeOf.sizeOf(this); - if (this.treeRoot != null) { - size += this.treeRoot.calcByteSizeIncludingSubtree(); - } - return size; - } - - @Override - public long measureByteSize() { - return calcByteSize(); - } - - @Override - public void resetLearningImpl() { - this.treeRoot = null; - this.decisionNodeCount = 0; - this.activeLeafNodeCount = 0; - this.inactiveLeafNodeCount = 0; - this.inactiveLeafByteSizeEstimate = 0.0; - this.activeLeafByteSizeEstimate = 0.0; - this.byteSizeEstimateOverheadFraction = 1.0; - this.growthAllowed = true; - if (this.leafpredictionOption.getChosenIndex() > 0) { - this.removePoorAttsOption = null; - } - } - - @Override - public double[] getVotesForInstance(Instance inst) { - if (this.treeRoot != null) { - FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, - null, -1); - Node leafNode = foundNode.node; - if (leafNode == null) { - leafNode = foundNode.parent; - } - return leafNode.getClassVotes(inst, this); - } - else { - int numClasses = inst.dataset().numClasses(); - return new double[numClasses]; - } - } - 
- @Override - protected Measurement[] getModelMeasurementsImpl() { - FoundNode[] learningNodes = findLearningNodes(); - - return new Measurement[]{ - - new Measurement("tree size (nodes)", this.decisionNodeCount - + this.activeLeafNodeCount + this.inactiveLeafNodeCount), - new Measurement("tree size (leaves)", learningNodes.length), - new Measurement("active learning leaves", - this.activeLeafNodeCount), - new Measurement("tree depth", measureTreeDepth()), - new Measurement("active leaf byte size estimate", - this.activeLeafByteSizeEstimate), - new Measurement("inactive leaf byte size estimate", - this.inactiveLeafByteSizeEstimate), - new Measurement("byte size estimate overhead", - this.byteSizeEstimateOverheadFraction), - new Measurement("splits", - this.splitCount)}; - } - - public int measureTreeDepth() { - if (this.treeRoot != null) { - return this.treeRoot.subtreeDepth(); - } - return 0; - } - - @Override - public void getModelDescription(StringBuilder out, int indent) { - this.treeRoot.describeSubtree(this, out, indent); - } - - @Override - public boolean isRandomizable() { - return false; - } - - public static double computeHoeffdingBound(double range, double confidence, - double n) { - return Math.sqrt(((range * range) * Math.log(1.0 / confidence)) - / (2.0 * n)); - } - - protected AttributeClassObserver newNominalClassObserver() { - AttributeClassObserver nominalClassObserver = (AttributeClassObserver) getPreparedClassOption(this.nominalEstimatorOption); - return (AttributeClassObserver) nominalClassObserver.copy(); - } - - protected AttributeClassObserver newNumericClassObserver() { - AttributeClassObserver numericClassObserver = (AttributeClassObserver) getPreparedClassOption(this.numericEstimatorOption); - return (AttributeClassObserver) numericClassObserver.copy(); - } - - public void enforceTrackerLimit() { - if ((this.inactiveLeafNodeCount > 0) - || ((this.activeLeafNodeCount * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount - * 
this.inactiveLeafByteSizeEstimate) - * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue())) { - if (this.stopMemManagementOption.isSet()) { - this.growthAllowed = false; - return; - } - FoundNode[] learningNodes = findLearningNodes(); - Arrays.sort(learningNodes, new Comparator() { - - @Override - public int compare(FoundNode fn1, FoundNode fn2) { - return Double.compare(fn1.node.calculatePromise(), fn2.node.calculatePromise()); - } - }); - int maxActive = 0; - while (maxActive < learningNodes.length) { - maxActive++; - if ((maxActive * this.activeLeafByteSizeEstimate + (learningNodes.length - maxActive) - * this.inactiveLeafByteSizeEstimate) - * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue()) { - maxActive--; - break; - } - } - int cutoff = learningNodes.length - maxActive; - for (int i = 0; i < cutoff; i++) { - if (learningNodes[i].node instanceof ActiveLearningNode) { - deactivateLearningNode( - (ActiveLearningNode) learningNodes[i].node, - learningNodes[i].parent, - learningNodes[i].parentBranch); - } - } - for (int i = cutoff; i < learningNodes.length; i++) { - if (learningNodes[i].node instanceof InactiveLearningNode) { - activateLearningNode( - (InactiveLearningNode) learningNodes[i].node, - learningNodes[i].parent, - learningNodes[i].parentBranch); - } - } - } - } - - public void estimateModelByteSizes() { - FoundNode[] learningNodes = findLearningNodes(); - long totalActiveSize = 0; - long totalInactiveSize = 0; - for (FoundNode foundNode : learningNodes) { - if (foundNode.node instanceof ActiveLearningNode) { - totalActiveSize += SizeOf.fullSizeOf(foundNode.node); - } - else { - totalInactiveSize += SizeOf.fullSizeOf(foundNode.node); - } - } - if (totalActiveSize > 0) { - this.activeLeafByteSizeEstimate = (double) totalActiveSize - / this.activeLeafNodeCount; - } - if (totalInactiveSize > 0) { - this.inactiveLeafByteSizeEstimate = (double) totalInactiveSize - / this.inactiveLeafNodeCount; - } - long 
actualModelSize = this.measureByteSize(); - double estimatedModelSize = (this.activeLeafNodeCount - * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount - * this.inactiveLeafByteSizeEstimate); - this.byteSizeEstimateOverheadFraction = actualModelSize - / estimatedModelSize; - if (actualModelSize > this.maxByteSizeOption.getValue()) { - enforceTrackerLimit(); - } - } - - public void deactivateAllLeaves() { - FoundNode[] learningNodes = findLearningNodes(); - for (FoundNode learningNode : learningNodes) { - if (learningNode.node instanceof ActiveLearningNode) { - deactivateLearningNode( - (ActiveLearningNode) learningNode.node, - learningNode.parent, learningNode.parentBranch); - } - } - } - - protected void deactivateLearningNode(ActiveLearningNode toDeactivate, - SplitNode parent, int parentBranch) { - Node newLeaf = new InactiveLearningNode(toDeactivate.getObservedClassDistribution()); - if (parent == null) { - this.treeRoot = newLeaf; - } - else { - parent.setChild(parentBranch, newLeaf); - } - this.activeLeafNodeCount--; - this.inactiveLeafNodeCount++; - } - - protected void activateLearningNode(InactiveLearningNode toActivate, - SplitNode parent, int parentBranch) { - Node newLeaf = newLearningNode(toActivate.getObservedClassDistribution()); - if (parent == null) { - this.treeRoot = newLeaf; - } - else { - parent.setChild(parentBranch, newLeaf); - } - this.activeLeafNodeCount++; - this.inactiveLeafNodeCount--; - } - - protected FoundNode[] findLearningNodes() { - List foundList = new LinkedList<>(); - findLearningNodes(this.treeRoot, null, -1, foundList); - return foundList.toArray(new FoundNode[foundList.size()]); - } - - protected void findLearningNodes(Node node, SplitNode parent, - int parentBranch, List found) { - if (node != null) { - if (node instanceof LearningNode) { - found.add(new FoundNode(node, parent, parentBranch)); - } - if (node instanceof SplitNode) { - SplitNode splitNode = (SplitNode) node; - for (int i = 0; i < 
splitNode.numChildren(); i++) { - findLearningNodes(splitNode.getChild(i), splitNode, i, - found); - } - } - } - } - - protected void attemptToSplit(ActiveLearningNode node, SplitNode parent, - int parentIndex) { - - if (!node.observedClassDistributionIsPure()) { - node.addToSplitAttempts(1); // even if we don't actually attempt to split, we've computed infogains - - - SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(this.splitCriterionOption); - AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, this); - Arrays.sort(bestSplitSuggestions); - boolean shouldSplit = false; - - for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) { - - if (bestSplitSuggestion.splitTest != null) { - if (!node.getInfogainSum().containsKey((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]))) { - node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), 0.0); - } - double currentSum = node.getInfogainSum().get((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0])); - node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestion.merit); - } - - else { // handle the null attribute - double currentSum = node.getInfogainSum().get(-1); // null split - node.getInfogainSum().put(-1, Math.max(0.0, currentSum + bestSplitSuggestion.merit)); - assert node.getInfogainSum().get(-1) >= 0.0 : "Negative infogain shouldn't be possible here."; - } - - } - - if (bestSplitSuggestions.length < 2) { - shouldSplit = bestSplitSuggestions.length > 0; - } - - else { - double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()), - this.splitConfidenceOption.getValue(), node.getWeightSeen()); - AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; - - double bestSuggestionAverageMerit; - double currentAverageMerit = 
node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); - - // because this is an unsplit leaf. current average merit should be always zero on the null split. - - if (bestSuggestion.splitTest == null) { // if you have a null split - bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); - } - else { - bestSuggestionAverageMerit = node.getInfogainSum().get((bestSuggestion.splitTest.getAttsTestDependsOn()[0])) / node.getNumSplitAttempts(); - } - - if (bestSuggestion.merit < 1e-10) { - shouldSplit = false; // we don't use average here - } - - else if ((bestSuggestionAverageMerit - currentAverageMerit) > - hoeffdingBound - || (hoeffdingBound < this.tieThresholdOption.getValue())) { - if (bestSuggestionAverageMerit - currentAverageMerit < hoeffdingBound) { - // Placeholder to list this possibility - } - shouldSplit = true; - } - - if (shouldSplit) { - for (Integer i : node.usedNominalAttributes) { - if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i) { - shouldSplit = false; - break; - } - } - } - - // } - if ((this.removePoorAttsOption != null) - && this.removePoorAttsOption.isSet()) { - Set poorAtts = new HashSet<>(); - // scan 1 - add any poor to set - for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) { - if (bestSplitSuggestion.splitTest != null) { - int[] splitAtts = bestSplitSuggestion.splitTest.getAttsTestDependsOn(); - if (splitAtts.length == 1) { - if (bestSuggestion.merit - - bestSplitSuggestion.merit > hoeffdingBound) { - poorAtts.add(splitAtts[0]); - } - } - } - } - // scan 2 - remove good ones from set - for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) { - if (bestSplitSuggestion.splitTest != null) { - int[] splitAtts = bestSplitSuggestion.splitTest.getAttsTestDependsOn(); - if (splitAtts.length == 1) { - if (bestSuggestion.merit - - bestSplitSuggestion.merit < hoeffdingBound) { - poorAtts.remove(splitAtts[0]); - } - } - } - } - for (int poorAtt : poorAtts) { - 
node.disableAttribute(poorAtt); - } - } - } - if (shouldSplit) { - splitCount++; - - AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1]; - if (splitDecision.splitTest == null) { - // preprune - null wins - deactivateLearningNode(node, parent, parentIndex); - } - else { - Node newSplit = newSplitNode(splitDecision.splitTest, - node.getObservedClassDistribution(), splitDecision.numSplits()); - ((EFDTSplitNode) newSplit).attributeObservers = node.attributeObservers; // copy the attribute observers - newSplit.setInfogainSum(node.getInfogainSum()); // transfer infogain history, leaf to split - - for (int i = 0; i < splitDecision.numSplits(); i++) { - - double[] j = splitDecision.resultingClassDistributionFromSplit(i); - - Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i)); - - if (splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class - || splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class) { - newChild.usedNominalAttributes = new ArrayList<>(node.usedNominalAttributes); //deep copy - newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]); - // no nominal attribute should be split on more than once in the path - } - ((EFDTSplitNode) newSplit).setChild(i, newChild); - } - this.activeLeafNodeCount--; - this.decisionNodeCount++; - this.activeLeafNodeCount += splitDecision.numSplits(); - if (parent == null) { - this.treeRoot = newSplit; - } - else { - parent.setChild(parentIndex, newSplit); - } - - } - // manage memory - enforceTrackerLimit(); - } - } - } - - @Override - public void trainOnInstanceImpl(Instance inst) { - - if (this.treeRoot == null) { - this.treeRoot = newLearningNode(); - ((EFDTNode) this.treeRoot).setRoot(true); - this.activeLeafNodeCount = 1; - } - - FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1); - Node leafNode = foundNode.node; - - if (leafNode == null) { - leafNode = 
newLearningNode(); - foundNode.parent.setChild(foundNode.parentBranch, leafNode); - this.activeLeafNodeCount++; - } - - ((EFDTNode) this.treeRoot).learnFromInstance(inst, this, null, -1); - - numInstances++; - } - - - protected LearningNode newLearningNode() { - return new EFDTLearningNode(new double[0]); - } - - protected LearningNode newLearningNode(double[] initialClassObservations) { - return new EFDTLearningNode(initialClassObservations); - } - - protected SplitNode newSplitNode(InstanceConditionalTest splitTest, - double[] classObservations, int size) { - return new EFDTSplitNode(splitTest, classObservations, size); - } - - protected SplitNode newSplitNode(InstanceConditionalTest splitTest, - double[] classObservations) { - return new EFDTSplitNode(splitTest, classObservations); - } - - private int argmax(double[] array) { - - double max = array[0]; - int maxarg = 0; - - for (int i = 1; i < array.length; i++) { - - if (array[i] > max) { - max = array[i]; - maxarg = i; - } - } - return maxarg; - } - - public interface EFDTNode { - - boolean isRoot(); - - void setRoot(boolean isRoot); - - void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch); - - void setParent(EFDTSplitNode parent); - - EFDTSplitNode getParent(); - - } - - public static class FoundNode { - - public Node node; - - public SplitNode parent; - - public int parentBranch; - - public FoundNode(Node node, SplitNode parent, int parentBranch) { - this.node = node; - this.parent = parent; - this.parentBranch = parentBranch; - } - } - - public static class Node extends AbstractMOAObject { + private static final long serialVersionUID = 2L; + + public IntOption reEvalPeriodOption = new IntOption( + "reevaluationPeriod", + 'R', + "The number of instances an internal node should observe between re-evaluation attempts.", + 2000, 0, Integer.MAX_VALUE); + + public IntOption maxByteSizeOption = new IntOption("maxByteSize", 'm', + "Maximum memory consumed by the tree.", 33554432, 
0, + Integer.MAX_VALUE); + + /* + * public MultiChoiceOption numericEstimatorOption = new MultiChoiceOption( + * "numericEstimator", 'n', "Numeric estimator to use.", new String[]{ + * "GAUSS10", "GAUSS100", "GK10", "GK100", "GK1000", "VFML10", "VFML100", + * "VFML1000", "BINTREE"}, new String[]{ "Gaussian approximation evaluating + * 10 splitpoints", "Gaussian approximation evaluating 100 splitpoints", + * "Greenwald-Khanna quantile summary with 10 tuples", "Greenwald-Khanna + * quantile summary with 100 tuples", "Greenwald-Khanna quantile summary + * with 1000 tuples", "VFML method with 10 bins", "VFML method with 100 + * bins", "VFML method with 1000 bins", "Exhaustive binary tree"}, 0); + */ + public ClassOption numericEstimatorOption = new ClassOption("numericEstimator", + 'n', "Numeric estimator to use.", NumericAttributeClassObserver.class, + "GaussianNumericAttributeClassObserver"); - private HashMap infogainSum; + public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator", + 'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class, + "NominalAttributeClassObserver"); - private int numSplitAttempts = 0; + public IntOption memoryEstimatePeriodOption = new IntOption( + "memoryEstimatePeriod", 'e', + "How many instances between memory consumption checks.", 1000000, + 0, Integer.MAX_VALUE); - private static final long serialVersionUID = 1L; + public IntOption gracePeriodOption = new IntOption( + "gracePeriod", + 'g', + "The number of instances a leaf should observe between split attempts.", + 200, 0, Integer.MAX_VALUE); - protected DoubleVector observedClassDistribution; + public ClassOption splitCriterionOption = new ClassOption("splitCriterion", + 's', "Split criterion to use.", SplitCriterion.class, + "InfoGainSplitCriterion"); - protected DoubleVector classDistributionAtTimeOfCreation; + public FloatOption splitConfidenceOption = new FloatOption( + "splitConfidence", + 'c', + "The allowable error in split decision, values 
closer to 0 will take longer to decide.", + 0.0000001, 0.0, 1.0); - protected int nodeTime; + public FloatOption tieThresholdOption = new FloatOption("tieThreshold", + 't', "Threshold below which a split will be forced to break ties.", + 0.05, 0.0, 1.0); - protected List usedNominalAttributes = new ArrayList<>(); + public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b', + "Only allow binary splits."); - public Node(double[] classObservations) { - this.observedClassDistribution = new DoubleVector(classObservations); - this.classDistributionAtTimeOfCreation = new DoubleVector(classObservations); - this.infogainSum = new HashMap<>(); - this.infogainSum.put(-1, 0.0); // Initialize for null split + public FlagOption stopMemManagementOption = new FlagOption( + "stopMemManagement", 'z', + "Stop growing as soon as memory limit is hit."); - } + public FlagOption removePoorAttsOption = new FlagOption("removePoorAtts", + 'r', "Disable poor attributes."); - public int getNumSplitAttempts() { - return numSplitAttempts; - } + public FlagOption noPrePruneOption = new FlagOption("noPrePrune", 'p', + "Disable pre-pruning."); - public void addToSplitAttempts(int i) { - numSplitAttempts += i; - } + public MultiChoiceOption leafpredictionOption = new MultiChoiceOption( + "leafprediction", 'l', "Leaf prediction to use.", new String[]{ + "MC", "NB", "NBAdaptive"}, new String[]{ + "Majority class", + "Naive Bayes", + "Naive Bayes Adaptive"}, 2); - public HashMap getInfogainSum() { - return infogainSum; - } + public IntOption nbThresholdOption = new IntOption( + "nbThreshold", + 'q', + "The number of instances a leaf should observe before permitting Naive Bayes.", + 0, 0, Integer.MAX_VALUE); - public void setInfogainSum(HashMap igs) { - infogainSum = igs; - } + protected Node treeRoot = null; - public long calcByteSize() { - return (SizeOf.sizeOf(this) + SizeOf.fullSizeOf(this.observedClassDistribution)); - } + protected int decisionNodeCount; - public long 
calcByteSizeIncludingSubtree() { - return calcByteSize(); - } + protected int activeLeafNodeCount; - public boolean isLeaf() { - return true; - } + protected int inactiveLeafNodeCount; - public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, - int parentBranch) { - return new FoundNode(this, parent, parentBranch); - } + protected double inactiveLeafByteSizeEstimate; - public double[] getObservedClassDistribution() { - return this.observedClassDistribution.getArrayCopy(); - } + protected double activeLeafByteSizeEstimate; - public double[] getClassVotes(Instance inst, EFDT ht) { - return this.observedClassDistribution.getArrayCopy(); - } + protected double byteSizeEstimateOverheadFraction; - public double[] getClassDistributionAtTimeOfCreation() { - return this.classDistributionAtTimeOfCreation.getArrayCopy(); - } + protected boolean growthAllowed; - public boolean observedClassDistributionIsPure() { - return this.observedClassDistribution.numNonZeroEntries() < 2; - } + protected int numInstances = 0; - public void describeSubtree(EFDT ht, StringBuilder out, - int indent) { - StringUtils.appendIndented(out, indent, "Leaf "); - out.append(ht.getClassNameString()); - out.append(" = "); - out.append(ht.getClassLabelString(this.observedClassDistribution.maxIndex())); - out.append(" weights: "); - this.observedClassDistribution.getSingleLineDescription(out, - ht.treeRoot.observedClassDistribution.numValues()); - StringUtils.appendNewline(out); - } + protected int splitCount = 0; - public int subtreeDepth() { - return 0; + @Override + public String getPurposeString() { + return "Hoeffding Tree or VFDT."; } - public double calculatePromise() { - double totalSeen = this.observedClassDistribution.sumOfValues(); - return totalSeen > 0.0 ? 
(totalSeen - this.observedClassDistribution.getValue(this.observedClassDistribution.maxIndex())) - : 0.0; + public int calcByteSize() { + int size = (int) SizeOf.sizeOf(this); + if (this.treeRoot != null) { + size += this.treeRoot.calcByteSizeIncludingSubtree(); + } + return size; } @Override - public void getDescription(StringBuilder sb, int indent) { - describeSubtree(null, sb, indent); + public int measureByteSize() { + return calcByteSize(); } - } - - public static class SplitNode extends Node { - - private static final long serialVersionUID = 1L; - - protected InstanceConditionalTest splitTest; - - protected AutoExpandVector children; // = new AutoExpandVector(); @Override - public long calcByteSize() { - return super.calcByteSize() - + SizeOf.sizeOf(this.children) + SizeOf.fullSizeOf(this.splitTest); + public void resetLearningImpl() { + this.treeRoot = null; + this.decisionNodeCount = 0; + this.activeLeafNodeCount = 0; + this.inactiveLeafNodeCount = 0; + this.inactiveLeafByteSizeEstimate = 0.0; + this.activeLeafByteSizeEstimate = 0.0; + this.byteSizeEstimateOverheadFraction = 1.0; + this.growthAllowed = true; + if (this.leafpredictionOption.getChosenIndex() > 0) { + this.removePoorAttsOption = null; + } } @Override - public long calcByteSizeIncludingSubtree() { - long byteSize = calcByteSize(); - for (Node child : this.children) { - if (child != null) { - byteSize += child.calcByteSizeIncludingSubtree(); - } - } - return byteSize; - } - - public SplitNode(InstanceConditionalTest splitTest, - double[] classObservations, int size) { - super(classObservations); - this.splitTest = splitTest; - this.children = new AutoExpandVector<>(size); - } - - public SplitNode(InstanceConditionalTest splitTest, - double[] classObservations) { - super(classObservations); - this.splitTest = splitTest; - this.children = new AutoExpandVector<>(); - } - - - public int numChildren() { - return this.children.size(); - } - - public void setChild(int index, Node child) { - if 
((this.splitTest.maxBranches() >= 0) - && (index >= this.splitTest.maxBranches())) { - throw new IndexOutOfBoundsException(); - } - this.children.set(index, child); - } - - public Node getChild(int index) { - return this.children.get(index); - } - - public int instanceChildIndex(Instance inst) { - return this.splitTest.branchForInstance(inst); + public double[] getVotesForInstance(Instance inst) { + if (this.treeRoot != null) { + FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, + null, -1); + Node leafNode = foundNode.node; + if (leafNode == null) { + leafNode = foundNode.parent; + } + return leafNode.getClassVotes(inst, this); + } else { + int numClasses = inst.dataset().numClasses(); + return new double[numClasses]; + } } @Override - public boolean isLeaf() { - return false; + protected Measurement[] getModelMeasurementsImpl() { + FoundNode[] learningNodes = findLearningNodes(); + + return new Measurement[]{ + + new Measurement("tree size (nodes)", this.decisionNodeCount + + this.activeLeafNodeCount + this.inactiveLeafNodeCount), + new Measurement("tree size (leaves)", learningNodes.length), + new Measurement("active learning leaves", + this.activeLeafNodeCount), + new Measurement("tree depth", measureTreeDepth()), + new Measurement("active leaf byte size estimate", + this.activeLeafByteSizeEstimate), + new Measurement("inactive leaf byte size estimate", + this.inactiveLeafByteSizeEstimate), + new Measurement("byte size estimate overhead", + this.byteSizeEstimateOverheadFraction), + new Measurement("splits", + this.splitCount)}; + } + + public int measureTreeDepth() { + if (this.treeRoot != null) { + return this.treeRoot.subtreeDepth(); + } + return 0; } @Override - public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, - int parentBranch) { - - //System.err.println("OVERRIDING "); - - int childIndex = instanceChildIndex(inst); - if (childIndex >= 0) { - Node child = getChild(childIndex); - if (child != null) { - return 
child.filterInstanceToLeaf(inst, this, childIndex); - } - return new FoundNode(null, this, childIndex); - } - return new FoundNode(this, parent, parentBranch); + public void getModelDescription(StringBuilder out, int indent) { + this.treeRoot.describeSubtree(this, out, indent); } @Override - public void describeSubtree(EFDT ht, StringBuilder out, - int indent) { - for (int branch = 0; branch < numChildren(); branch++) { - Node child = getChild(branch); - if (child != null) { - StringUtils.appendIndented(out, indent, "if "); - out.append(this.splitTest.describeConditionForBranch(branch, - ht.getModelContext())); - out.append(": "); - StringUtils.appendNewline(out); - child.describeSubtree(ht, out, indent + 2); - } - } + public boolean isRandomizable() { + return false; + } + + public static double computeHoeffdingBound(double range, double confidence, + double n) { + return Math.sqrt(((range * range) * Math.log(1.0 / confidence)) + / (2.0 * n)); + } + + protected AttributeClassObserver newNominalClassObserver() { + AttributeClassObserver nominalClassObserver = (AttributeClassObserver) getPreparedClassOption(this.nominalEstimatorOption); + return (AttributeClassObserver) nominalClassObserver.copy(); + } + + protected AttributeClassObserver newNumericClassObserver() { + AttributeClassObserver numericClassObserver = (AttributeClassObserver) getPreparedClassOption(this.numericEstimatorOption); + return (AttributeClassObserver) numericClassObserver.copy(); + } + + public void enforceTrackerLimit() { + if ((this.inactiveLeafNodeCount > 0) + || ((this.activeLeafNodeCount * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount + * this.inactiveLeafByteSizeEstimate) + * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue())) { + if (this.stopMemManagementOption.isSet()) { + this.growthAllowed = false; + return; + } + FoundNode[] learningNodes = findLearningNodes(); + Arrays.sort(learningNodes, new Comparator() { + + @Override + public int 
compare(FoundNode fn1, FoundNode fn2) { + return Double.compare(fn1.node.calculatePromise(), fn2.node.calculatePromise()); + } + }); + int maxActive = 0; + while (maxActive < learningNodes.length) { + maxActive++; + if ((maxActive * this.activeLeafByteSizeEstimate + (learningNodes.length - maxActive) + * this.inactiveLeafByteSizeEstimate) + * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue()) { + maxActive--; + break; + } + } + int cutoff = learningNodes.length - maxActive; + for (int i = 0; i < cutoff; i++) { + if (learningNodes[i].node instanceof ActiveLearningNode) { + deactivateLearningNode( + (ActiveLearningNode) learningNodes[i].node, + learningNodes[i].parent, + learningNodes[i].parentBranch); + } + } + for (int i = cutoff; i < learningNodes.length; i++) { + if (learningNodes[i].node instanceof InactiveLearningNode) { + activateLearningNode( + (InactiveLearningNode) learningNodes[i].node, + learningNodes[i].parent, + learningNodes[i].parentBranch); + } + } + } + } + + public void estimateModelByteSizes() { + FoundNode[] learningNodes = findLearningNodes(); + long totalActiveSize = 0; + long totalInactiveSize = 0; + for (FoundNode foundNode : learningNodes) { + if (foundNode.node instanceof ActiveLearningNode) { + totalActiveSize += SizeOf.fullSizeOf(foundNode.node); + } else { + totalInactiveSize += SizeOf.fullSizeOf(foundNode.node); + } + } + if (totalActiveSize > 0) { + this.activeLeafByteSizeEstimate = (double) totalActiveSize + / this.activeLeafNodeCount; + } + if (totalInactiveSize > 0) { + this.inactiveLeafByteSizeEstimate = (double) totalInactiveSize + / this.inactiveLeafNodeCount; + } + int actualModelSize = this.measureByteSize(); + double estimatedModelSize = (this.activeLeafNodeCount + * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount + * this.inactiveLeafByteSizeEstimate); + this.byteSizeEstimateOverheadFraction = actualModelSize + / estimatedModelSize; + if (actualModelSize > 
this.maxByteSizeOption.getValue()) { + enforceTrackerLimit(); + } + } + + public void deactivateAllLeaves() { + FoundNode[] learningNodes = findLearningNodes(); + for (FoundNode learningNode : learningNodes) { + if (learningNode.node instanceof ActiveLearningNode) { + deactivateLearningNode( + (ActiveLearningNode) learningNode.node, + learningNode.parent, learningNode.parentBranch); + } + } + } + + protected void deactivateLearningNode(ActiveLearningNode toDeactivate, + SplitNode parent, int parentBranch) { + Node newLeaf = new InactiveLearningNode(toDeactivate.getObservedClassDistribution()); + if (parent == null) { + this.treeRoot = newLeaf; + } else { + parent.setChild(parentBranch, newLeaf); + } + this.activeLeafNodeCount--; + this.inactiveLeafNodeCount++; + } + + protected void activateLearningNode(InactiveLearningNode toActivate, + SplitNode parent, int parentBranch) { + Node newLeaf = newLearningNode(toActivate.getObservedClassDistribution()); + if (parent == null) { + this.treeRoot = newLeaf; + } else { + parent.setChild(parentBranch, newLeaf); + } + this.activeLeafNodeCount++; + this.inactiveLeafNodeCount--; + } + + protected FoundNode[] findLearningNodes() { + List foundList = new LinkedList<>(); + findLearningNodes(this.treeRoot, null, -1, foundList); + return foundList.toArray(new FoundNode[foundList.size()]); + } + + protected void findLearningNodes(Node node, SplitNode parent, + int parentBranch, List found) { + if (node != null) { + if (node instanceof LearningNode) { + found.add(new FoundNode(node, parent, parentBranch)); + } + if (node instanceof SplitNode) { + SplitNode splitNode = (SplitNode) node; + for (int i = 0; i < splitNode.numChildren(); i++) { + findLearningNodes(splitNode.getChild(i), splitNode, i, + found); + } + } + } } - @Override - public int subtreeDepth() { - int maxChildDepth = 0; - for (Node child : this.children) { - if (child != null) { - int depth = child.subtreeDepth(); - if (depth > maxChildDepth) { - maxChildDepth = depth; 
- } - } - } - return maxChildDepth + 1; + protected void attemptToSplit(ActiveLearningNode node, SplitNode parent, + int parentIndex) { + + if (!node.observedClassDistributionIsPure()) { + node.addToSplitAttempts(1); // even if we don't actually attempt to split, we've computed infogains + + + SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(this.splitCriterionOption); + AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, this); + Arrays.sort(bestSplitSuggestions); + boolean shouldSplit = false; + + for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) { + + if (bestSplitSuggestion.splitTest != null) { + if (!node.getInfogainSum().containsKey((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]))) { + node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), 0.0); + } + double currentSum = node.getInfogainSum().get((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0])); + node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestion.merit); + } else { // handle the null attribute + double currentSum = node.getInfogainSum().get(-1); // null split + node.getInfogainSum().put(-1, Math.max(0.0, currentSum + bestSplitSuggestion.merit)); + assert node.getInfogainSum().get(-1) >= 0.0 : "Negative infogain shouldn't be possible here."; + } + + } + + if (bestSplitSuggestions.length < 2) { + shouldSplit = bestSplitSuggestions.length > 0; + } else { + double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()), + this.splitConfidenceOption.getValue(), node.getWeightSeen()); + AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + + double bestSuggestionAverageMerit; + double currentAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); + + // because this is an unsplit leaf. 
current average merit should be always zero on the null split. + + if (bestSuggestion.splitTest == null) { // if you have a null split + bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); + } else { + bestSuggestionAverageMerit = node.getInfogainSum().get((bestSuggestion.splitTest.getAttsTestDependsOn()[0])) / node.getNumSplitAttempts(); + } + + if (bestSuggestion.merit < 1e-10) { + shouldSplit = false; // we don't use average here + } else if ((bestSuggestionAverageMerit - currentAverageMerit) > + hoeffdingBound + || (hoeffdingBound < this.tieThresholdOption.getValue())) { + if (bestSuggestionAverageMerit - currentAverageMerit < hoeffdingBound) { + // Placeholder to list this possibility + } + shouldSplit = true; + } + + if (shouldSplit) { + for (Integer i : node.usedNominalAttributes) { + if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i) { + shouldSplit = false; + break; + } + } + } + + // } + if ((this.removePoorAttsOption != null) + && this.removePoorAttsOption.isSet()) { + Set poorAtts = new HashSet<>(); + // scan 1 - add any poor to set + for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) { + if (bestSplitSuggestion.splitTest != null) { + int[] splitAtts = bestSplitSuggestion.splitTest.getAttsTestDependsOn(); + if (splitAtts.length == 1) { + if (bestSuggestion.merit + - bestSplitSuggestion.merit > hoeffdingBound) { + poorAtts.add(splitAtts[0]); + } + } + } + } + // scan 2 - remove good ones from set + for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) { + if (bestSplitSuggestion.splitTest != null) { + int[] splitAtts = bestSplitSuggestion.splitTest.getAttsTestDependsOn(); + if (splitAtts.length == 1) { + if (bestSuggestion.merit + - bestSplitSuggestion.merit < hoeffdingBound) { + poorAtts.remove(splitAtts[0]); + } + } + } + } + for (int poorAtt : poorAtts) { + node.disableAttribute(poorAtt); + } + } + } + if (shouldSplit) { + splitCount++; + + 
AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + if (splitDecision.splitTest == null) { + // preprune - null wins + deactivateLearningNode(node, parent, parentIndex); + } else { + Node newSplit = newSplitNode(splitDecision.splitTest, + node.getObservedClassDistribution(), splitDecision.numSplits()); + ((EFDTSplitNode) newSplit).attributeObservers = node.attributeObservers; // copy the attribute observers + newSplit.setInfogainSum(node.getInfogainSum()); // transfer infogain history, leaf to split + + for (int i = 0; i < splitDecision.numSplits(); i++) { + + double[] j = splitDecision.resultingClassDistributionFromSplit(i); + + Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i)); + + if (splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class + || splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class) { + newChild.usedNominalAttributes = new ArrayList<>(node.usedNominalAttributes); //deep copy + newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]); + // no nominal attribute should be split on more than once in the path + } + ((EFDTSplitNode) newSplit).setChild(i, newChild); + } + this.activeLeafNodeCount--; + this.decisionNodeCount++; + this.activeLeafNodeCount += splitDecision.numSplits(); + if (parent == null) { + this.treeRoot = newSplit; + } else { + parent.setChild(parentIndex, newSplit); + } + + } + // manage memory + enforceTrackerLimit(); + } + } } - } + @Override + public void trainOnInstanceImpl(Instance inst) { - public class EFDTSplitNode extends SplitNode implements EFDTNode { - - /** - * - */ - - private boolean isRoot; + if (this.treeRoot == null) { + this.treeRoot = newLearningNode(); + ((EFDTNode) this.treeRoot).setRoot(true); + this.activeLeafNodeCount = 1; + } - private EFDTSplitNode parent = null; + FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1); + Node leafNode = 
foundNode.node; - private static final long serialVersionUID = 1L; + if (leafNode == null) { + leafNode = newLearningNode(); + foundNode.parent.setChild(foundNode.parentBranch, leafNode); + this.activeLeafNodeCount++; + } - protected AutoExpandVector attributeObservers; + ((EFDTNode) this.treeRoot).learnFromInstance(inst, this, null, -1); - public EFDTSplitNode(InstanceConditionalTest splitTest, double[] classObservations, int size) { - super(splitTest, classObservations, size); + numInstances++; } - public EFDTSplitNode(InstanceConditionalTest splitTest, double[] classObservations) { - super(splitTest, classObservations); - } - @Override - public boolean isRoot() { - return isRoot; + protected LearningNode newLearningNode() { + return new EFDTLearningNode(new double[0], this.leafpredictionOption.getChosenLabel()); } - @Override - public void setRoot(boolean isRoot) { - this.isRoot = isRoot; + protected LearningNode newLearningNode(double[] initialClassObservations) { + return new EFDTLearningNode(initialClassObservations, this.leafpredictionOption.getChosenLabel()); } - public void killSubtree(EFDT ht) { - for (Node child : this.children) { - if (child != null) { - - //Recursive delete of SplitNodes - if (child instanceof SplitNode) { - ((EFDTSplitNode) child).killSubtree(ht); - } - else if (child instanceof ActiveLearningNode) { - ht.activeLeafNodeCount--; - } - else if (child instanceof InactiveLearningNode) { - ht.inactiveLeafNodeCount--; - } - } - } + protected SplitNode newSplitNode(InstanceConditionalTest splitTest, + double[] classObservations, int size) { + return new EFDTSplitNode(splitTest, classObservations, size); } - - // DRY Don't Repeat Yourself... code duplicated from ActiveLearningNode in VFDT.java. However, this is the most practical way to share stand-alone. 
- public AttributeSplitSuggestion[] getBestSplitSuggestions( - SplitCriterion criterion, EFDT ht) { - List bestSuggestions = new LinkedList<>(); - double[] preSplitDist = this.observedClassDistribution.getArrayCopy(); - if (!ht.noPrePruneOption.isSet()) { - // add null split as an option - bestSuggestions.add(new AttributeSplitSuggestion(null, - new double[0][], criterion.getMeritOfSplit( - preSplitDist, new double[][]{preSplitDist}))); - } - for (int i = 0; i < this.attributeObservers.size(); i++) { - AttributeClassObserver obs = this.attributeObservers.get(i); - if (obs != null) { - AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion, - preSplitDist, i, ht.binarySplitsOption.isSet()); - if (bestSuggestion != null) { - bestSuggestions.add(bestSuggestion); - } - } - } - return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]); + protected SplitNode newSplitNode(InstanceConditionalTest splitTest, + double[] classObservations) { + return new EFDTSplitNode(splitTest, classObservations); } + private int argmax(double[] array) { - @Override - public void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch) { - - nodeTime++; - //// Update node statistics and class distribution - - this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); // update prior (predictor) - - for (int i = 0; i < inst.numAttributes() - 1; i++) { //update likelihood - int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); - AttributeClassObserver obs = this.attributeObservers.get(i); - if (obs == null) { - obs = inst.attribute(instAttIndex).isNominal() ? ht.newNominalClassObserver() : ht.newNumericClassObserver(); - this.attributeObservers.set(i, obs); - } - obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight()); - } - - // check if a better split is available. if so, chop the tree at this point, copying likelihood. 
predictors for children are from parent likelihood. - if (ht.numInstances % ht.reEvalPeriodOption.getValue() == 0) { - this.reEvaluateBestSplit(this, parent, parentBranch); - } + double max = array[0]; + int maxarg = 0; - int childBranch = this.instanceChildIndex(inst); - Node child = this.getChild(childBranch); - - if (child != null) { - ((EFDTNode) child).learnFromInstance(inst, ht, this, childBranch); - } + for (int i = 1; i < array.length; i++) { + if (array[i] > max) { + max = array[i]; + maxarg = i; + } + } + return maxarg; } - protected void reEvaluateBestSplit(EFDTSplitNode node, EFDTSplitNode parent, - int parentIndex) { - + public interface EFDTNode { - node.addToSplitAttempts(1); + boolean isRoot(); - // EFDT must transfer over gain averages when replacing a node: leaf to split, split to leaf, or split to split - // It must replace split nodes with leaves if null wins + void setRoot(boolean isRoot); + void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch); - // node is a reference to this anyway... why have it at all? 
+ void setParent(EFDTSplitNode parent); - int currentSplit = -1; - // and if we always choose to maintain tree structure + EFDTSplitNode getParent(); - //lets first find out X_a, the current split + } - if (this.splitTest != null) { - currentSplit = this.splitTest.getAttsTestDependsOn()[0]; - // given the current implementations in MOA, we're only ever expecting one int to be returned - } + public static class FoundNode { - //compute Hoeffding bound - SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(EFDT.this.splitCriterionOption); - double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getClassDistributionAtTimeOfCreation()), - EFDT.this.splitConfidenceOption.getValue(), node.observedClassDistribution.sumOfValues()); + public Node node; - // get best split suggestions - AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, EFDT.this); - Arrays.sort(bestSplitSuggestions); + public SplitNode parent; - // get the best suggestion - AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + public int parentBranch; + public FoundNode(Node node, SplitNode parent, int parentBranch) { + this.node = node; + this.parent = parent; + this.parentBranch = parentBranch; + } + } - for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) { + public static class Node extends AbstractMOAObject { - if (bestSplitSuggestion.splitTest != null) { - if (!node.getInfogainSum().containsKey((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]))) { - node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), 0.0); - } - double currentSum = node.getInfogainSum().get((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0])); - node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestion.merit); - } + private HashMap infogainSum; - else { // handle the null 
attribute. this is fine to do- it'll always average zero, and we will use this later to potentially burn bad splits. - double currentSum = node.getInfogainSum().get(-1); // null split - node.getInfogainSum().put(-1, currentSum + bestSplitSuggestion.merit); - } + private int numSplitAttempts = 0; - } + private static final long serialVersionUID = 1L; - // get the average merit for best and current splits + protected DoubleVector observedClassDistribution; - double bestSuggestionAverageMerit; - double currentAverageMerit; + protected DoubleVector classDistributionAtTimeOfCreation; - if (bestSuggestion.splitTest == null) { // best is null - bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); - } - else { + protected int nodeTime; - bestSuggestionAverageMerit = node.getInfogainSum().get(bestSuggestion.splitTest.getAttsTestDependsOn()[0]) / node.getNumSplitAttempts(); - } + protected List usedNominalAttributes = new ArrayList<>(); - if (node.splitTest == null) { // current is null- shouldn't happen, check for robustness - currentAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); - } - else { - currentAverageMerit = node.getInfogainSum().get(node.splitTest.getAttsTestDependsOn()[0]) / node.getNumSplitAttempts(); - } + public Node(double[] classObservations) { + this.observedClassDistribution = new DoubleVector(classObservations); + this.classDistributionAtTimeOfCreation = new DoubleVector(classObservations); + this.infogainSum = new HashMap<>(); + this.infogainSum.put(-1, 0.0); // Initialize for null split - double tieThreshold = EFDT.this.tieThresholdOption.getValue(); + } - // compute the average deltaG - double deltaG = bestSuggestionAverageMerit - currentAverageMerit; + public int getNumSplitAttempts() { + return numSplitAttempts; + } - if (deltaG > hoeffdingBound - || (hoeffdingBound < tieThreshold && deltaG > tieThreshold / 2)) { + public void addToSplitAttempts(int i) { + numSplitAttempts += i; + } - 
System.err.println(numInstances); + public HashMap getInfogainSum() { + return infogainSum; + } - AttributeSplitSuggestion splitDecision = bestSuggestion; + public void setInfogainSum(HashMap igs) { + infogainSum = igs; + } - // if null split wins - if (splitDecision.splitTest == null) { + public int calcByteSize() { + return (int) (SizeOf.sizeOf(this) + SizeOf.fullSizeOf(this.observedClassDistribution)); + } - node.killSubtree(EFDT.this); - EFDTLearningNode replacement = (EFDTLearningNode) newLearningNode(); - replacement.setInfogainSum(node.getInfogainSum()); // transfer infogain history, split to replacement leaf - if (node.getParent() != null) { - node.getParent().setChild(parentIndex, replacement); - } - else { - assert (node.isRoot()); - node.setRoot(true); - } - } + public int calcByteSizeIncludingSubtree() { + return calcByteSize(); + } - else { + public boolean isLeaf() { + return true; + } - Node newSplit = newSplitNode(splitDecision.splitTest, - node.getObservedClassDistribution(), splitDecision.numSplits()); + public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, + int parentBranch) { + return new FoundNode(this, parent, parentBranch); + } - ((EFDTSplitNode) newSplit).attributeObservers = node.attributeObservers; // copy the attribute observers - newSplit.setInfogainSum(node.getInfogainSum()); // transfer infogain history, split to replacement split + public double[] getObservedClassDistribution() { + return this.observedClassDistribution.getArrayCopy(); + } - if (node.splitTest == splitDecision.splitTest - && node.splitTest.getClass() == NumericAttributeBinaryTest.class && - (argmax(splitDecision.resultingClassDistributions[0]) == argmax(node.getChild(0).getObservedClassDistribution()) - || argmax(splitDecision.resultingClassDistributions[1]) == argmax(node.getChild(1).getObservedClassDistribution())) - ) { - // change split but don't destroy the subtrees - for (int i = 0; i < splitDecision.numSplits(); i++) { - ((EFDTSplitNode) 
newSplit).setChild(i, this.getChild(i)); - } + public double[] getClassVotes(Instance inst, EFDT ht) { + return this.observedClassDistribution.getArrayCopy(); + } - } - else { + public double[] getClassDistributionAtTimeOfCreation() { + return this.classDistributionAtTimeOfCreation.getArrayCopy(); + } - // otherwise, torch the subtree and split on the new best attribute. + public boolean observedClassDistributionIsPure() { + return this.observedClassDistribution.numNonZeroEntries() < 2; + } - this.killSubtree(EFDT.this); + public void describeSubtree(EFDT ht, StringBuilder out, + int indent) { + StringUtils.appendIndented(out, indent, "Leaf "); + out.append(ht.getClassNameString()); + out.append(" = "); + out.append(ht.getClassLabelString(this.observedClassDistribution.maxIndex())); + out.append(" weights: "); + this.observedClassDistribution.getSingleLineDescription(out, + ht.treeRoot.observedClassDistribution.numValues()); + StringUtils.appendNewline(out); + } - for (int i = 0; i < splitDecision.numSplits(); i++) { + public int subtreeDepth() { + return 0; + } + + public double calculatePromise() { + double totalSeen = this.observedClassDistribution.sumOfValues(); + return totalSeen > 0.0 ? 
(totalSeen - this.observedClassDistribution.getValue(this.observedClassDistribution.maxIndex())) + : 0.0; + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + describeSubtree(null, sb, indent); + } + } + + public static class SplitNode extends Node { + + private static final long serialVersionUID = 1L; + + protected InstanceConditionalTest splitTest; + + protected AutoExpandVector children; // = new AutoExpandVector(); + + @Override + public int calcByteSize() { + return super.calcByteSize() + + (int) (SizeOf.sizeOf(this.children) + SizeOf.fullSizeOf(this.splitTest)); + } + + @Override + public int calcByteSizeIncludingSubtree() { + int byteSize = calcByteSize(); + for (Node child : this.children) { + if (child != null) { + byteSize += child.calcByteSizeIncludingSubtree(); + } + } + return byteSize; + } + + public SplitNode(InstanceConditionalTest splitTest, + double[] classObservations, int size) { + super(classObservations); + this.splitTest = splitTest; + this.children = new AutoExpandVector<>(size); + } + + public SplitNode(InstanceConditionalTest splitTest, + double[] classObservations) { + super(classObservations); + this.splitTest = splitTest; + this.children = new AutoExpandVector<>(); + } + + + public int numChildren() { + return this.children.size(); + } + + public void setChild(int index, Node child) { + if ((this.splitTest.maxBranches() >= 0) + && (index >= this.splitTest.maxBranches())) { + throw new IndexOutOfBoundsException(); + } + this.children.set(index, child); + } + + public Node getChild(int index) { + return this.children.get(index); + } + + public int instanceChildIndex(Instance inst) { + return this.splitTest.branchForInstance(inst); + } + + @Override + public boolean isLeaf() { + return false; + } + + @Override + public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, + int parentBranch) { + + //System.err.println("OVERRIDING "); + + int childIndex = instanceChildIndex(inst); + if (childIndex >= 0) 
{ + Node child = getChild(childIndex); + if (child != null) { + return child.filterInstanceToLeaf(inst, this, childIndex); + } + return new FoundNode(null, this, childIndex); + } + return new FoundNode(this, parent, parentBranch); + } + + @Override + public void describeSubtree(EFDT ht, StringBuilder out, + int indent) { + for (int branch = 0; branch < numChildren(); branch++) { + Node child = getChild(branch); + if (child != null) { + StringUtils.appendIndented(out, indent, "if "); + out.append(this.splitTest.describeConditionForBranch(branch, + ht.getModelContext())); + out.append(": "); + StringUtils.appendNewline(out); + child.describeSubtree(ht, out, indent + 2); + } + } + } + + @Override + public int subtreeDepth() { + int maxChildDepth = 0; + for (Node child : this.children) { + if (child != null) { + int depth = child.subtreeDepth(); + if (depth > maxChildDepth) { + maxChildDepth = depth; + } + } + } + return maxChildDepth + 1; + } + } + + + public class EFDTSplitNode extends SplitNode implements EFDTNode { + + /** + * + */ + + private boolean isRoot; + + private EFDTSplitNode parent = null; + + private static final long serialVersionUID = 1L; + + protected AutoExpandVector attributeObservers; + + public EFDTSplitNode(InstanceConditionalTest splitTest, double[] classObservations, int size) { + super(splitTest, classObservations, size); + } + + public EFDTSplitNode(InstanceConditionalTest splitTest, double[] classObservations) { + super(splitTest, classObservations); + } + + @Override + public boolean isRoot() { + return isRoot; + } + + @Override + public void setRoot(boolean isRoot) { + this.isRoot = isRoot; + } + + public void killSubtree(EFDT ht) { + for (Node child : this.children) { + if (child != null) { + + //Recursive delete of SplitNodes + if (child instanceof SplitNode) { + ((EFDTSplitNode) child).killSubtree(ht); + } else if (child instanceof ActiveLearningNode) { + ht.activeLeafNodeCount--; + } else if (child instanceof InactiveLearningNode) { + 
ht.inactiveLeafNodeCount--; + } + } + } + } + + + // DRY Don't Repeat Yourself... code duplicated from ActiveLearningNode in VFDT.java. However, this is the most practical way to share stand-alone. + public AttributeSplitSuggestion[] getBestSplitSuggestions( + SplitCriterion criterion, EFDT ht) { + List bestSuggestions = new LinkedList<>(); + double[] preSplitDist = this.observedClassDistribution.getArrayCopy(); + if (!ht.noPrePruneOption.isSet()) { + // add null split as an option + bestSuggestions.add(new AttributeSplitSuggestion(null, + new double[0][], criterion.getMeritOfSplit( + preSplitDist, new double[][]{preSplitDist}))); + } + for (int i = 0; i < this.attributeObservers.size(); i++) { + AttributeClassObserver obs = this.attributeObservers.get(i); + if (obs != null) { + AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion, + preSplitDist, i, ht.binarySplitsOption.isSet()); + if (bestSuggestion != null) { + bestSuggestions.add(bestSuggestion); + } + } + } + return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]); + } + + + @Override + public void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch) { + + nodeTime++; + //// Update node statistics and class distribution - double[] j = splitDecision.resultingClassDistributionFromSplit(i); + this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); // update prior (predictor) - Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i)); + for (int i = 0; i < inst.numAttributes() - 1; i++) { //update likelihood + int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); + AttributeClassObserver obs = this.attributeObservers.get(i); + if (obs == null) { + obs = inst.attribute(instAttIndex).isNominal() ? 
ht.newNominalClassObserver() : ht.newNumericClassObserver(); + this.attributeObservers.set(i, obs); + } + obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight()); + } - if (splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class - || splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class) { - newChild.usedNominalAttributes = new ArrayList<>(node.usedNominalAttributes); //deep copy - newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]); - // no nominal attribute should be split on more than once in the path - } - ((EFDTSplitNode) newSplit).setChild(i, newChild); - } - - EFDT.this.activeLeafNodeCount--; - EFDT.this.decisionNodeCount++; - EFDT.this.activeLeafNodeCount += splitDecision.numSplits(); + // check if a better split is available. if so, chop the tree at this point, copying likelihood. predictors for children are from parent likelihood. + if (ht.numInstances % ht.reEvalPeriodOption.getValue() == 0) { + this.reEvaluateBestSplit(this, parent, parentBranch); + } - } + int childBranch = this.instanceChildIndex(inst); + Node child = this.getChild(childBranch); + if (child != null) { + ((EFDTNode) child).learnFromInstance(inst, ht, this, childBranch); + } - if (parent == null) { - ((EFDTNode) newSplit).setRoot(true); - ((EFDTNode) newSplit).setParent(null); - EFDT.this.treeRoot = newSplit; - } - else { - ((EFDTNode) newSplit).setRoot(false); - ((EFDTNode) newSplit).setParent(parent); - parent.setChild(parentIndex, newSplit); - } - } - } - } + } - @Override - public void setParent(EFDTSplitNode parent) { - this.parent = parent; - } + protected void reEvaluateBestSplit(EFDTSplitNode node, EFDTSplitNode parent, + int parentIndex) { - @Override - public EFDTSplitNode getParent() { - return this.parent; - } - } - public static abstract class LearningNode extends Node { + node.addToSplitAttempts(1); + + // EFDT must transfer over gain averages when replacing a node: leaf 
to split, split to leaf, or split to split + // It must replace split nodes with leaves if null wins + + + // node is a reference to this anyway... why have it at all? + + int currentSplit = -1; + // and if we always choose to maintain tree structure + + //lets first find out X_a, the current split + + if (this.splitTest != null) { + currentSplit = this.splitTest.getAttsTestDependsOn()[0]; + // given the current implementations in MOA, we're only ever expecting one int to be returned + } - private static final long serialVersionUID = 1L; + //compute Hoeffding bound + SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(EFDT.this.splitCriterionOption); + double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getClassDistributionAtTimeOfCreation()), + EFDT.this.splitConfidenceOption.getValue(), node.observedClassDistribution.sumOfValues()); - public LearningNode(double[] initialClassObservations) { - super(initialClassObservations); - } + // get best split suggestions + AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, EFDT.this); + Arrays.sort(bestSplitSuggestions); - public abstract void learnFromInstance(Instance inst, EFDT ht); - } + // get the best suggestion + AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; - public class EFDTLearningNode extends LearningNodeNBAdaptive implements EFDTNode { + for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) { - private boolean isRoot; + if (bestSplitSuggestion.splitTest != null) { + if (!node.getInfogainSum().containsKey((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]))) { + node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), 0.0); + } + double currentSum = node.getInfogainSum().get((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0])); + 
node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestion.merit); + } else { // handle the null attribute. this is fine to do- it'll always average zero, and we will use this later to potentially burn bad splits. + double currentSum = node.getInfogainSum().get(-1); // null split + node.getInfogainSum().put(-1, currentSum + bestSplitSuggestion.merit); + } - private EFDTSplitNode parent = null; + } - public EFDTLearningNode(double[] initialClassObservations) { - super(initialClassObservations); - } + // get the average merit for best and current splits + double bestSuggestionAverageMerit; + double currentAverageMerit; - /** - * - */ - private static final long serialVersionUID = -2525042202040084035L; + if (bestSuggestion.splitTest == null) { // best is null + bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); + } else { - @Override - public boolean isRoot() { - return isRoot; - } + bestSuggestionAverageMerit = node.getInfogainSum().get(bestSuggestion.splitTest.getAttsTestDependsOn()[0]) / node.getNumSplitAttempts(); + } - @Override - public void setRoot(boolean isRoot) { - this.isRoot = isRoot; - } + if (node.splitTest == null) { // current is null- shouldn't happen, check for robustness + currentAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); + } else { + currentAverageMerit = node.getInfogainSum().get(node.splitTest.getAttsTestDependsOn()[0]) / node.getNumSplitAttempts(); + } - @Override - public void learnFromInstance(Instance inst, EFDT ht) { - super.learnFromInstance(inst, ht); + double tieThreshold = EFDT.this.tieThresholdOption.getValue(); - } + // compute the average deltaG + double deltaG = bestSuggestionAverageMerit - currentAverageMerit; - @Override - public void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch) { - learnFromInstance(inst, ht); - - if (ht.growthAllowed) { - ActiveLearningNode 
activeLearningNode = this; - double weightSeen = activeLearningNode.getWeightSeen(); - if (activeLearningNode.nodeTime % ht.gracePeriodOption.getValue() == 0) { - attemptToSplit(activeLearningNode, parent, - parentBranch); - activeLearningNode.setWeightSeenAtLastSplitEvaluation(weightSeen); - } - } - } + if (deltaG > hoeffdingBound + || (hoeffdingBound < tieThreshold && deltaG > tieThreshold / 2)) { - @Override - public void setParent(EFDTSplitNode parent) { - this.parent = parent; - } + System.err.println(numInstances); - @Override - public EFDTSplitNode getParent() { - return this.parent; - } + AttributeSplitSuggestion splitDecision = bestSuggestion; - } + // if null split wins + if (splitDecision.splitTest == null) { - public static class InactiveLearningNode extends LearningNode { + node.killSubtree(EFDT.this); + EFDTLearningNode replacement = (EFDTLearningNode) newLearningNode(); + replacement.setInfogainSum(node.getInfogainSum()); // transfer infogain history, split to replacement leaf + if (node.getParent() != null) { + node.getParent().setChild(parentIndex, replacement); + } else { + assert (node.isRoot()); + node.setRoot(true); + } + } else { - private static final long serialVersionUID = 1L; + Node newSplit = newSplitNode(splitDecision.splitTest, + node.getObservedClassDistribution(), splitDecision.numSplits()); - public InactiveLearningNode(double[] initialClassObservations) { - super(initialClassObservations); - } + ((EFDTSplitNode) newSplit).attributeObservers = node.attributeObservers; // copy the attribute observers + newSplit.setInfogainSum(node.getInfogainSum()); // transfer infogain history, split to replacement split - @Override - public void learnFromInstance(Instance inst, EFDT ht) { - this.observedClassDistribution.addToValue((int) inst.classValue(), - inst.weight()); - } - } + if (node.splitTest == splitDecision.splitTest + && node.splitTest.getClass() == NumericAttributeBinaryTest.class && + 
(argmax(splitDecision.resultingClassDistributions[0]) == argmax(node.getChild(0).getObservedClassDistribution()) + || argmax(splitDecision.resultingClassDistributions[1]) == argmax(node.getChild(1).getObservedClassDistribution())) + ) { + // change split but don't destroy the subtrees + for (int i = 0; i < splitDecision.numSplits(); i++) { + ((EFDTSplitNode) newSplit).setChild(i, this.getChild(i)); + } - public static class ActiveLearningNode extends LearningNode { + } else { - private static final long serialVersionUID = 1L; + // otherwise, torch the subtree and split on the new best attribute. - protected double weightSeenAtLastSplitEvaluation; + this.killSubtree(EFDT.this); - protected AutoExpandVector attributeObservers = new AutoExpandVector<>(); + for (int i = 0; i < splitDecision.numSplits(); i++) { - protected boolean isInitialized; + double[] j = splitDecision.resultingClassDistributionFromSplit(i); - public ActiveLearningNode(double[] initialClassObservations) { - super(initialClassObservations); - this.weightSeenAtLastSplitEvaluation = getWeightSeen(); - this.isInitialized = false; - } + Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i)); - @Override - public long calcByteSize() { - return super.calcByteSize() - + (SizeOf.fullSizeOf(this.attributeObservers)); - } + if (splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class + || splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class) { + newChild.usedNominalAttributes = new ArrayList<>(node.usedNominalAttributes); //deep copy + newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]); + // no nominal attribute should be split on more than once in the path + } + ((EFDTSplitNode) newSplit).setChild(i, newChild); + } - @Override - public void learnFromInstance(Instance inst, EFDT ht) { - nodeTime++; - - if (this.isInitialized) { - this.attributeObservers = new AutoExpandVector<>(inst.numAttributes()); - 
this.isInitialized = true; - } - this.observedClassDistribution.addToValue((int) inst.classValue(), - inst.weight()); - for (int i = 0; i < inst.numAttributes() - 1; i++) { - int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); - AttributeClassObserver obs = this.attributeObservers.get(i); - if (obs == null) { - obs = inst.attribute(instAttIndex).isNominal() ? ht.newNominalClassObserver() : ht.newNumericClassObserver(); - this.attributeObservers.set(i, obs); - } - obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight()); - } - } + EFDT.this.activeLeafNodeCount--; + EFDT.this.decisionNodeCount++; + EFDT.this.activeLeafNodeCount += splitDecision.numSplits(); - public double getWeightSeen() { - return this.observedClassDistribution.sumOfValues(); - } + } - public double getWeightSeenAtLastSplitEvaluation() { - return this.weightSeenAtLastSplitEvaluation; - } - public void setWeightSeenAtLastSplitEvaluation(double weight) { - this.weightSeenAtLastSplitEvaluation = weight; - } + if (parent == null) { + ((EFDTNode) newSplit).setRoot(true); + ((EFDTNode) newSplit).setParent(null); + EFDT.this.treeRoot = newSplit; + } else { + ((EFDTNode) newSplit).setRoot(false); + ((EFDTNode) newSplit).setParent(parent); + parent.setChild(parentIndex, newSplit); + } + } + } + } - public AttributeSplitSuggestion[] getBestSplitSuggestions( - SplitCriterion criterion, EFDT ht) { - List bestSuggestions = new LinkedList<>(); - double[] preSplitDist = this.observedClassDistribution.getArrayCopy(); - if (!ht.noPrePruneOption.isSet()) { - // add null split as an option - bestSuggestions.add(new AttributeSplitSuggestion(null, - new double[0][], criterion.getMeritOfSplit( - preSplitDist, new double[][]{preSplitDist}))); - } - for (int i = 0; i < this.attributeObservers.size(); i++) { - AttributeClassObserver obs = this.attributeObservers.get(i); - if (obs != null) { - AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion, 
- preSplitDist, i, ht.binarySplitsOption.isSet()); - if (bestSuggestion != null) { - bestSuggestions.add(bestSuggestion); - } - } - } - return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]); - } + @Override + public void setParent(EFDTSplitNode parent) { + this.parent = parent; + } - public void disableAttribute(int attIndex) { - this.attributeObservers.set(attIndex, - new NullAttributeClassObserver()); + @Override + public EFDTSplitNode getParent() { + return this.parent; + } } - } - public static class LearningNodeNB extends ActiveLearningNode { + public static abstract class LearningNode extends Node { - private static final long serialVersionUID = 1L; + private static final long serialVersionUID = 1L; - public LearningNodeNB(double[] initialClassObservations) { - super(initialClassObservations); - } - - @Override - public double[] getClassVotes(Instance inst, EFDT ht) { - if (getWeightSeen() >= ht.nbThresholdOption.getValue()) { - return NaiveBayes.doNaiveBayesPrediction(inst, - this.observedClassDistribution, - this.attributeObservers); - } - return super.getClassVotes(inst, ht); - } + public LearningNode(double[] initialClassObservations) { + super(initialClassObservations); + } - @Override - public void disableAttribute(int attIndex) { - // should not disable poor atts - they are used in NB calc + public abstract void learnFromInstance(Instance inst, EFDT ht); } - } - public static class LearningNodeNBAdaptive extends LearningNodeNB { - private static final long serialVersionUID = 1L; + public class EFDTLearningNode extends LearningNodeNBAdaptive implements EFDTNode { - protected double mcCorrectWeight = 0.0; + private boolean isRoot; - protected double nbCorrectWeight = 0.0; + private EFDTSplitNode parent = null; - public LearningNodeNBAdaptive(double[] initialClassObservations) { - super(initialClassObservations); - } + private String predictionType; - @Override - public void learnFromInstance(Instance inst, EFDT ht) { - int 
trueClass = (int) inst.classValue(); - if (this.observedClassDistribution.maxIndex() == trueClass) { - this.mcCorrectWeight += inst.weight(); - } - if (Utils.maxIndex(NaiveBayes.doNaiveBayesPrediction(inst, - this.observedClassDistribution, this.attributeObservers)) == trueClass) { - this.nbCorrectWeight += inst.weight(); - } - super.learnFromInstance(inst, ht); - } + public EFDTLearningNode(double[] initialClassObservations, String predictionType) { + super(initialClassObservations); + this.predictionType = predictionType; + } - @Override - public double[] getClassVotes(Instance inst, EFDT ht) { - if (this.mcCorrectWeight > this.nbCorrectWeight) { - return this.observedClassDistribution.getArrayCopy(); - } - return NaiveBayes.doNaiveBayesPrediction(inst, - this.observedClassDistribution, this.attributeObservers); - } - } - static class VFDT extends EFDT { + /** + * + */ + private static final long serialVersionUID = -2525042202040084035L; - @Override - protected void attemptToSplit(ActiveLearningNode node, SplitNode parent, - int parentIndex) { - if (!node.observedClassDistributionIsPure()) { - - - SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(this.splitCriterionOption); - AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, this); - - Arrays.sort(bestSplitSuggestions); - boolean shouldSplit = false; - - for (int i = 0; i < bestSplitSuggestions.length; i++){ - - node.addToSplitAttempts(1); // even if we don't actually attempt to split, we've computed infogains - - - if (bestSplitSuggestions[i].splitTest != null){ - if (!node.getInfogainSum().containsKey((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]))) - { - node.getInfogainSum().put((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]), 0.0); - } - double currentSum = node.getInfogainSum().get((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0])); - 
node.getInfogainSum().put((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestions[i].merit); - } - - else { // handle the null attribute - double currentSum = node.getInfogainSum().get(-1); // null split - node.getInfogainSum().put(-1, currentSum + Math.max(0.0, bestSplitSuggestions[i].merit)); - assert node.getInfogainSum().get(-1) >= 0.0 : "Negative infogain shouldn't be possible here."; - } - - } - - if (bestSplitSuggestions.length < 2) { - shouldSplit = bestSplitSuggestions.length > 0; - } - - else { - - - double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()), - this.splitConfidenceOption.getValue(), node.getWeightSeen()); - - AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; - AttributeSplitSuggestion secondBestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 2]; - - - double bestSuggestionAverageMerit = 0.0; - double secondBestSuggestionAverageMerit = 0.0; - - if(bestSuggestion.splitTest == null){ // if you have a null split - bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); - } else{ - bestSuggestionAverageMerit = node.getInfogainSum().get((bestSuggestion.splitTest.getAttsTestDependsOn()[0])) / node.getNumSplitAttempts(); - } - - if(secondBestSuggestion.splitTest == null){ // if you have a null split - secondBestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); - } else{ - secondBestSuggestionAverageMerit = node.getInfogainSum().get((secondBestSuggestion.splitTest.getAttsTestDependsOn()[0])) / node.getNumSplitAttempts(); - } - - //comment this if statement to get VFDT bug - if(bestSuggestion.merit < 1e-10){ // we don't use average here - shouldSplit = false; - } - - else - if ((bestSuggestionAverageMerit - secondBestSuggestionAverageMerit > hoeffdingBound) - || (hoeffdingBound < this.tieThresholdOption.getValue())) - { - 
shouldSplit = true; - } - - if(shouldSplit){ - for(Integer i : node.usedNominalAttributes){ - if(bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i){ - shouldSplit = false; - break; - } - } - } - - // } - if ((this.removePoorAttsOption != null) - && this.removePoorAttsOption.isSet()) { - Set poorAtts = new HashSet(); - // scan 1 - add any poor to set - for (int i = 0; i < bestSplitSuggestions.length; i++) { - if (bestSplitSuggestions[i].splitTest != null) { - int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn(); - if (splitAtts.length == 1) { - if (bestSuggestion.merit - - bestSplitSuggestions[i].merit > hoeffdingBound) { - poorAtts.add(new Integer(splitAtts[0])); - } - } - } - } - // scan 2 - remove good ones from set - for (int i = 0; i < bestSplitSuggestions.length; i++) { - if (bestSplitSuggestions[i].splitTest != null) { - int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn(); - if (splitAtts.length == 1) { - if (bestSuggestion.merit - - bestSplitSuggestions[i].merit < hoeffdingBound) { - poorAtts.remove(new Integer(splitAtts[0])); - } - } - } - } - for (int poorAtt : poorAtts) { - node.disableAttribute(poorAtt); - } - } - } - if (shouldSplit) { - splitCount++; - - AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1]; - if (splitDecision.splitTest == null) { - // preprune - null wins - deactivateLearningNode(node, parent, parentIndex); - } else { - SplitNode newSplit = newSplitNode(splitDecision.splitTest, - node.getObservedClassDistribution(), splitDecision.numSplits()); - for (int i = 0; i < splitDecision.numSplits(); i++) { - - double[] j = splitDecision.resultingClassDistributionFromSplit(i); - - Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i)); - - if(splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class - ||splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class){ - newChild.usedNominalAttributes 
= new ArrayList(node.usedNominalAttributes); //deep copy - newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]); - // no nominal attribute should be split on more than once in the path - } - newSplit.setChild(i, newChild); - } - this.activeLeafNodeCount--; - this.decisionNodeCount++; - this.activeLeafNodeCount += splitDecision.numSplits(); - if (parent == null) { - this.treeRoot = newSplit; - } else { - parent.setChild(parentIndex, newSplit); - } - - } - - // manage memory - enforceTrackerLimit(); - } - } - } + @Override + public boolean isRoot() { + return isRoot; + } - @Override - protected LearningNode newLearningNode() { - return newLearningNode(new double[0]); - } + @Override + public void setRoot(boolean isRoot) { + this.isRoot = isRoot; + } - @Override - protected LearningNode newLearningNode(double[] initialClassObservations) { - LearningNode ret; - int predictionOption = this.leafpredictionOption.getChosenIndex(); - if (predictionOption == 0) { //MC - ret = new ActiveLearningNode(initialClassObservations); - } else if (predictionOption == 1) { //NB - ret = new LearningNodeNB(initialClassObservations); - } else { //NBAdaptive - ret = new LearningNodeNBAdaptive(initialClassObservations); - } - return ret; - } + @Override + public void learnFromInstance(Instance inst, EFDT ht) { + super.learnFromInstance(inst, ht); - @Override - protected SplitNode newSplitNode(InstanceConditionalTest splitTest, - double[] classObservations, int size) { - return new SplitNode(splitTest, classObservations, size); - } + } - @Override - protected SplitNode newSplitNode(InstanceConditionalTest splitTest, - double[] classObservations) { - return new SplitNode(splitTest, classObservations); - } + @Override + public void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch) { + learnFromInstance(inst, ht); - @Override - public void trainOnInstanceImpl(Instance inst) { - //System.err.println(i++); - if (this.treeRoot == 
null) { - this.treeRoot = newLearningNode(); - this.activeLeafNodeCount = 1; - } - FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1); - Node leafNode = foundNode.node; - - if (leafNode == null) { - leafNode = newLearningNode(); - foundNode.parent.setChild(foundNode.parentBranch, leafNode); - this.activeLeafNodeCount++; - } - - if (leafNode instanceof LearningNode) { - LearningNode learningNode = (LearningNode) leafNode; - learningNode.learnFromInstance(inst, this); - if (this.growthAllowed - && (learningNode instanceof ActiveLearningNode)) { - ActiveLearningNode activeLearningNode = (ActiveLearningNode) learningNode; - double weightSeen = activeLearningNode.getWeightSeen(); - if (activeLearningNode.nodeTime % this.gracePeriodOption.getValue() == 0) { - attemptToSplit(activeLearningNode, foundNode.parent, - foundNode.parentBranch); - activeLearningNode.setWeightSeenAtLastSplitEvaluation(weightSeen); - } - } - } - - if (this.trainingWeightSeenByModel - % this.memoryEstimatePeriodOption.getValue() == 0) { - estimateModelByteSizes(); - } - - numInstances++; + if (ht.growthAllowed) { + ActiveLearningNode activeLearningNode = this; + double weightSeen = activeLearningNode.getWeightSeen(); + if (activeLearningNode.nodeTime % ht.gracePeriodOption.getValue() == 0) { + attemptToSplit(activeLearningNode, parent, + parentBranch); + activeLearningNode.setWeightSeenAtLastSplitEvaluation(weightSeen); + } + } + } + + @Override + public void setParent(EFDTSplitNode parent) { + this.parent = parent; + } + + @Override + public EFDTSplitNode getParent() { + return this.parent; + } + + @Override + public double[] getClassVotes(Instance inst, EFDT ht) { + if (this.predictionType.equals("MC")) + return this.observedClassDistribution.getArrayCopy(); + else if (this.predictionType.equals("NB")) + return NaiveBayes.doNaiveBayesPrediction(inst, + this.observedClassDistribution, this.attributeObservers); + // Naive Bayes Adaptive + if (this.mcCorrectWeight > 
this.nbCorrectWeight) { + return this.observedClassDistribution.getArrayCopy(); + } + return NaiveBayes.doNaiveBayesPrediction(inst, + this.observedClassDistribution, this.attributeObservers); + } + } + + public static class InactiveLearningNode extends LearningNode { + + private static final long serialVersionUID = 1L; + + public InactiveLearningNode(double[] initialClassObservations) { + super(initialClassObservations); + } + + @Override + public void learnFromInstance(Instance inst, EFDT ht) { + this.observedClassDistribution.addToValue((int) inst.classValue(), + inst.weight()); + } + } + + public static class ActiveLearningNode extends LearningNode { + + private static final long serialVersionUID = 1L; + + protected double weightSeenAtLastSplitEvaluation; + + protected AutoExpandVector attributeObservers = new AutoExpandVector<>(); + + protected boolean isInitialized; + + public ActiveLearningNode(double[] initialClassObservations) { + super(initialClassObservations); + this.weightSeenAtLastSplitEvaluation = getWeightSeen(); + this.isInitialized = false; + } + + @Override + public int calcByteSize() { + return super.calcByteSize() + + (int) (SizeOf.fullSizeOf(this.attributeObservers)); + } + + @Override + public void learnFromInstance(Instance inst, EFDT ht) { + nodeTime++; + + if (this.isInitialized) { + this.attributeObservers = new AutoExpandVector<>(inst.numAttributes()); + this.isInitialized = true; + } + this.observedClassDistribution.addToValue((int) inst.classValue(), + inst.weight()); + for (int i = 0; i < inst.numAttributes() - 1; i++) { + int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); + AttributeClassObserver obs = this.attributeObservers.get(i); + if (obs == null) { + obs = inst.attribute(instAttIndex).isNominal() ? 
ht.newNominalClassObserver() : ht.newNumericClassObserver(); + this.attributeObservers.set(i, obs); + } + obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight()); + } + } + + public double getWeightSeen() { + return this.observedClassDistribution.sumOfValues(); + } + + public double getWeightSeenAtLastSplitEvaluation() { + return this.weightSeenAtLastSplitEvaluation; + } + + public void setWeightSeenAtLastSplitEvaluation(double weight) { + this.weightSeenAtLastSplitEvaluation = weight; + } + + public AttributeSplitSuggestion[] getBestSplitSuggestions( + SplitCriterion criterion, EFDT ht) { + List bestSuggestions = new LinkedList<>(); + double[] preSplitDist = this.observedClassDistribution.getArrayCopy(); + if (!ht.noPrePruneOption.isSet()) { + // add null split as an option + bestSuggestions.add(new AttributeSplitSuggestion(null, + new double[0][], criterion.getMeritOfSplit( + preSplitDist, new double[][]{preSplitDist}))); + } + for (int i = 0; i < this.attributeObservers.size(); i++) { + AttributeClassObserver obs = this.attributeObservers.get(i); + if (obs != null) { + AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion, + preSplitDist, i, ht.binarySplitsOption.isSet()); + if (bestSuggestion != null) { + bestSuggestions.add(bestSuggestion); + } + } + } + return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]); + } + + public void disableAttribute(int attIndex) { + this.attributeObservers.set(attIndex, + new NullAttributeClassObserver()); + } + } + + public static class LearningNodeNB extends ActiveLearningNode { + + private static final long serialVersionUID = 1L; + + public LearningNodeNB(double[] initialClassObservations) { + super(initialClassObservations); + } + + @Override + public double[] getClassVotes(Instance inst, EFDT ht) { + if (getWeightSeen() >= ht.nbThresholdOption.getValue()) { + return NaiveBayes.doNaiveBayesPrediction(inst, + 
this.observedClassDistribution, + this.attributeObservers); + } + return super.getClassVotes(inst, ht); + } + + @Override + public void disableAttribute(int attIndex) { + // should not disable poor atts - they are used in NB calc + } + } + + public static class LearningNodeNBAdaptive extends LearningNodeNB { + + private static final long serialVersionUID = 1L; + + protected double mcCorrectWeight = 0.0; + + protected double nbCorrectWeight = 0.0; + + public LearningNodeNBAdaptive(double[] initialClassObservations) { + super(initialClassObservations); + } + + @Override + public void learnFromInstance(Instance inst, EFDT ht) { + int trueClass = (int) inst.classValue(); + if (this.observedClassDistribution.maxIndex() == trueClass) { + this.mcCorrectWeight += inst.weight(); + } + if (Utils.maxIndex(NaiveBayes.doNaiveBayesPrediction(inst, + this.observedClassDistribution, this.attributeObservers)) == trueClass) { + this.nbCorrectWeight += inst.weight(); + } + super.learnFromInstance(inst, ht); + } + + @Override + public double[] getClassVotes(Instance inst, EFDT ht) { + if (this.mcCorrectWeight > this.nbCorrectWeight) { + return this.observedClassDistribution.getArrayCopy(); + } + return NaiveBayes.doNaiveBayesPrediction(inst, + this.observedClassDistribution, this.attributeObservers); + } + } + + static class VFDT extends EFDT { + + @Override + protected void attemptToSplit(ActiveLearningNode node, SplitNode parent, + int parentIndex) { + if (!node.observedClassDistributionIsPure()) { + + + SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(this.splitCriterionOption); + AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, this); + + Arrays.sort(bestSplitSuggestions); + boolean shouldSplit = false; + + for (int i = 0; i < bestSplitSuggestions.length; i++) { + + node.addToSplitAttempts(1); // even if we don't actually attempt to split, we've computed infogains + + + if (bestSplitSuggestions[i].splitTest != 
null) { + if (!node.getInfogainSum().containsKey((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]))) { + node.getInfogainSum().put((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]), 0.0); + } + double currentSum = node.getInfogainSum().get((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0])); + node.getInfogainSum().put((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestions[i].merit); + } else { // handle the null attribute + double currentSum = node.getInfogainSum().get(-1); // null split + node.getInfogainSum().put(-1, currentSum + Math.max(0.0, bestSplitSuggestions[i].merit)); + assert node.getInfogainSum().get(-1) >= 0.0 : "Negative infogain shouldn't be possible here."; + } + + } + + if (bestSplitSuggestions.length < 2) { + shouldSplit = bestSplitSuggestions.length > 0; + } else { + + + double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()), + this.splitConfidenceOption.getValue(), node.getWeightSeen()); + + AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + AttributeSplitSuggestion secondBestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 2]; + + + double bestSuggestionAverageMerit = 0.0; + double secondBestSuggestionAverageMerit = 0.0; + + if (bestSuggestion.splitTest == null) { // if you have a null split + bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); + } else { + bestSuggestionAverageMerit = node.getInfogainSum().get((bestSuggestion.splitTest.getAttsTestDependsOn()[0])) / node.getNumSplitAttempts(); + } + + if (secondBestSuggestion.splitTest == null) { // if you have a null split + secondBestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts(); + } else { + secondBestSuggestionAverageMerit = node.getInfogainSum().get((secondBestSuggestion.splitTest.getAttsTestDependsOn()[0])) / 
node.getNumSplitAttempts(); + } + + //comment this if statement to get VFDT bug + if (bestSuggestion.merit < 1e-10) { // we don't use average here + shouldSplit = false; + } else if ((bestSuggestionAverageMerit - secondBestSuggestionAverageMerit > hoeffdingBound) + || (hoeffdingBound < this.tieThresholdOption.getValue())) { + shouldSplit = true; + } + + if (shouldSplit) { + for (Integer i : node.usedNominalAttributes) { + if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i) { + shouldSplit = false; + break; + } + } + } + + // } + if ((this.removePoorAttsOption != null) + && this.removePoorAttsOption.isSet()) { + Set poorAtts = new HashSet(); + // scan 1 - add any poor to set + for (int i = 0; i < bestSplitSuggestions.length; i++) { + if (bestSplitSuggestions[i].splitTest != null) { + int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn(); + if (splitAtts.length == 1) { + if (bestSuggestion.merit + - bestSplitSuggestions[i].merit > hoeffdingBound) { + poorAtts.add(new Integer(splitAtts[0])); + } + } + } + } + // scan 2 - remove good ones from set + for (int i = 0; i < bestSplitSuggestions.length; i++) { + if (bestSplitSuggestions[i].splitTest != null) { + int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn(); + if (splitAtts.length == 1) { + if (bestSuggestion.merit + - bestSplitSuggestions[i].merit < hoeffdingBound) { + poorAtts.remove(new Integer(splitAtts[0])); + } + } + } + } + for (int poorAtt : poorAtts) { + node.disableAttribute(poorAtt); + } + } + } + if (shouldSplit) { + splitCount++; + + AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + if (splitDecision.splitTest == null) { + // preprune - null wins + deactivateLearningNode(node, parent, parentIndex); + } else { + SplitNode newSplit = newSplitNode(splitDecision.splitTest, + node.getObservedClassDistribution(), splitDecision.numSplits()); + for (int i = 0; i < splitDecision.numSplits(); i++) { + + double[] j = 
splitDecision.resultingClassDistributionFromSplit(i); + + Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i)); + + if (splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class + || splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class) { + newChild.usedNominalAttributes = new ArrayList(node.usedNominalAttributes); //deep copy + newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]); + // no nominal attribute should be split on more than once in the path + } + newSplit.setChild(i, newChild); + } + this.activeLeafNodeCount--; + this.decisionNodeCount++; + this.activeLeafNodeCount += splitDecision.numSplits(); + if (parent == null) { + this.treeRoot = newSplit; + } else { + parent.setChild(parentIndex, newSplit); + } + + } + + // manage memory + enforceTrackerLimit(); + } + } + } + + @Override + protected LearningNode newLearningNode() { + return newLearningNode(new double[0]); + } + + @Override + protected LearningNode newLearningNode(double[] initialClassObservations) { + LearningNode ret; + int predictionOption = this.leafpredictionOption.getChosenIndex(); + if (predictionOption == 0) { //MC + ret = new ActiveLearningNode(initialClassObservations); + } else if (predictionOption == 1) { //NB + ret = new LearningNodeNB(initialClassObservations); + } else { //NBAdaptive + ret = new LearningNodeNBAdaptive(initialClassObservations); + } + return ret; + } + + @Override + protected SplitNode newSplitNode(InstanceConditionalTest splitTest, + double[] classObservations, int size) { + return new SplitNode(splitTest, classObservations, size); + } + + @Override + protected SplitNode newSplitNode(InstanceConditionalTest splitTest, + double[] classObservations) { + return new SplitNode(splitTest, classObservations); + } + + @Override + public void trainOnInstanceImpl(Instance inst) { + //System.err.println(i++); + if (this.treeRoot == null) { + this.treeRoot = newLearningNode(); + 
this.activeLeafNodeCount = 1; + } + FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1); + Node leafNode = foundNode.node; + + if (leafNode == null) { + leafNode = newLearningNode(); + foundNode.parent.setChild(foundNode.parentBranch, leafNode); + this.activeLeafNodeCount++; + } + + if (leafNode instanceof LearningNode) { + LearningNode learningNode = (LearningNode) leafNode; + learningNode.learnFromInstance(inst, this); + if (this.growthAllowed + && (learningNode instanceof ActiveLearningNode)) { + ActiveLearningNode activeLearningNode = (ActiveLearningNode) learningNode; + double weightSeen = activeLearningNode.getWeightSeen(); + if (activeLearningNode.nodeTime % this.gracePeriodOption.getValue() == 0) { + attemptToSplit(activeLearningNode, foundNode.parent, + foundNode.parentBranch); + activeLearningNode.setWeightSeenAtLastSplitEvaluation(weightSeen); + } + } + } + + if (this.trainingWeightSeenByModel + % this.memoryEstimatePeriodOption.getValue() == 0) { + estimateModelByteSizes(); + } + + numInstances++; + } } - } } diff --git a/moa/src/main/java/moa/classifiers/trees/PLASTIC.java b/moa/src/main/java/moa/classifiers/trees/PLASTIC.java new file mode 100644 index 000000000..d11327224 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/trees/PLASTIC.java @@ -0,0 +1,192 @@ +package moa.classifiers.trees; + +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.FloatOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.MultiChoiceOption; +import com.yahoo.labs.samoa.instances.Instance; +import moa.classifiers.AbstractClassifier; +import moa.classifiers.core.attributeclassobservers.DiscreteAttributeClassObserver; +import moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver; +import moa.classifiers.core.splitcriteria.SplitCriterion; +import moa.classifiers.MultiClassClassifier; +import moa.classifiers.trees.plastic_util.MeasuresNumberOfLeaves; +import 
moa.classifiers.trees.plastic_util.PerformsTreeRevision; +import moa.classifiers.trees.plastic_util.PlasticNode; +import moa.core.DoubleVector; +import moa.core.Measurement; +import moa.options.ClassOption; + +import java.util.ArrayList; + +/** + * PLASTIC + * + *

Restructures the subtrees that would have been pruned by EFDT.

+ * + *

See details in:
Marco Heyden, Heitor Murilo Gomes, Edouard Fouché, Bernhard Pfahringer, Klemens Böhm: + * Leveraging Plasticity in Incremental Decision Trees. ECML/PKDD (5) 2024: 38-54

+ * + * @author Marco Heyden (marco dot heyden at kit dot edu) + * @version $Revision: 1 $ + */ +public class PLASTIC extends AbstractClassifier implements MultiClassClassifier, PerformsTreeRevision, MeasuresNumberOfLeaves { + + PlasticNode root; + int seenItems = 0; + + public IntOption gracePeriodOption = new IntOption( + "gracePeriod", + 'g', + "The number of instances a leaf should observe between split attempts.", + 200, 0, Integer.MAX_VALUE); + + public IntOption reEvalPeriodOption = new IntOption( + "reevaluationPeriod", + 'R', + "The number of instances an internal node should observe between re-evaluation attempts.", + 200, 0, Integer.MAX_VALUE); + + public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator", + 'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class, + "NominalAttributeClassObserver"); + + public ClassOption splitCriterionOption = new ClassOption("splitCriterion", + 's', "Split criterion to use.", SplitCriterion.class, + "InfoGainSplitCriterion"); + + public FloatOption splitConfidenceOption = new FloatOption( + "splitConfidence", + 'c', + "The allowable error in split decision when using fixed confidence. 
Values closer to 0 will take longer to decide.", + 0.0000001, 0.0, 1.0); + + public FloatOption tieThresholdOption = new FloatOption("tieThreshold", + 't', "Threshold below which a split will be forced to break ties.", + 0.05, 0.0, 1.0); + + public FloatOption tieThresholdReevalOption = new FloatOption("tieThresholdReevaluation", + 'T', "Threshold below which a split will be forced to break ties during reevaluation.", + 0.05, 0.0, 1.0); + + public FloatOption relMinDeltaG = new FloatOption("relMinDeltaG", + 'G', "Relative minimum information gain to split a tie during reevaluation.", + 0.5, 0.0, 1.0); + + public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b', + "Only allow binary splits."); + + public MultiChoiceOption leafpredictionOption = new MultiChoiceOption( + "leafprediction", 'l', "Leaf prediction to use.", new String[]{ + "MC", "NB", "NBA"}, new String[]{ + "Majority class", "Naive Bayes", "Naive Bayes Adaptive"}, 2); + + public IntOption maxDepthOption = new IntOption( + "maxDepth", + 'D', + "Maximum allowed depth of tree.", + 20, 0, Integer.MAX_VALUE); + + public IntOption maxBranchLengthOption = new IntOption( + "maxBranchLength", + 'B', + "Maximum allowed length of branches during restructuring.", + 5, 1, Integer.MAX_VALUE); + + /** + * Creates and configures the root node of the tree + *

+ * The root is the only access point for the main PLASTIC class. All subtrees etc will simply connect to the root. + *

+ * @return the created root node + **/ + PlasticNode createRoot() { + return new PlasticNode( + (SplitCriterion) getPreparedClassOption(splitCriterionOption), + gracePeriodOption.getValue(), + splitConfidenceOption.getValue(), + 0.2, + false, + leafpredictionOption.getChosenLabel(), + reEvalPeriodOption.getValue(), + 0, + maxDepthOption.getValue(), + tieThresholdOption.getValue(), + tieThresholdReevalOption.getValue(), + 0.5, + binarySplitsOption.isSet(), + true, + (NominalAttributeClassObserver) getPreparedClassOption(nominalEstimatorOption), + new DoubleVector(), + new ArrayList<>(), + maxBranchLengthOption.getValue(), + 0.05, + -1 + ); + } + + /** + * Trains the tree on the provided instance + * @param inst The instance to train on + **/ + @Override + public void trainOnInstanceImpl(Instance inst) { + if (root == null) + root = createRoot(); + root.learnInstance(inst, seenItems); + seenItems++; + } + + @Override + public double[] getVotesForInstance(Instance inst) { + if (root == null) { + root = createRoot(); + return new double[inst.numClasses()]; + } + return root.predict(inst); + } + + @Override + public void resetLearningImpl() { + root = null; + seenItems = 0; + } + + /** + * Checks if a tree revision (e.g., pruning or restructuring) was performed at the current time step + * + * @return if a tree revision was performed + **/ + @Override + public boolean didPerformTreeRevision() { + return root.didPerformTreeRevision(); + } + + /** + * Returns the number of leaves of the tree. + * + *

+ * Note that calling this function repeatedly causes some overhead + * as the function traverses the tree recursively. + *

+ * + * @return the number of leaves of the tree + **/ + @Override + public int getLeafNumber() { + return root.getLeafNumber(); + } + + @Override + protected Measurement[] getModelMeasurementsImpl() { + return new Measurement[0]; + } + + @Override + public void getModelDescription(StringBuilder out, int indent) {} + + @Override + public boolean isRandomizable() { + return false; + } +} diff --git a/moa/src/main/java/moa/classifiers/trees/PLASTICA.java b/moa/src/main/java/moa/classifiers/trees/PLASTICA.java new file mode 100644 index 000000000..b3cc82056 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/trees/PLASTICA.java @@ -0,0 +1,192 @@ +/* + * PLASTICA.java + * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand + * @author Manuel Baena (mbaena@lcc.uma.es) + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. 
+ */ +package moa.classifiers.trees; + +import com.github.javacliparser.FloatOption; +import com.yahoo.labs.samoa.instances.Instance; +import moa.capabilities.CapabilitiesHandler; +import moa.capabilities.Capability; +import moa.capabilities.ImmutableCapabilities; +import moa.classifiers.AbstractClassifier; +import moa.classifiers.Classifier; +import moa.classifiers.MultiClassClassifier; +import moa.classifiers.trees.plastic_util.CustomADWINChangeDetector; +import moa.core.Measurement; +import moa.core.Utils; +import moa.options.ClassOption; + +import java.util.LinkedList; +import java.util.List; + +/** + * PLASTIC-A + * + *

Combination of PLASTIC and a drift detector at the root. Will grow a background tree after a change and + * replace the current tree once the BG tree is more accurate.

+ * + *

See details in:
Marco Heyden, Heitor Murilo Gomes, Edouard Fouché, Bernhard Pfahringer, Klemens Böhm: + * Leveraging Plasticity in Incremental Decision Trees. ECML/PKDD (5) 2024: 38-54

+ * + * @author Marco Heyden (marco dot heyden at kit dot edu) + * @version $Revision: 1 $ + */ +public class PLASTICA extends AbstractClassifier implements MultiClassClassifier, + CapabilitiesHandler { + + @Override + public String getPurposeString() { + return "Classifier that grows a background tree when a change is detected in accuracy."; + } + + public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', + "Classifier to train.", Classifier.class, "trees.PLASTIC"); + + public FloatOption confidenceOption = new FloatOption( + "Confidence", + 'c', + "Confidence at which the current learner will be replaced.", + 0.05, 0.0, 1.0); + + protected Classifier classifier; + + protected Classifier newclassifier; + + protected CustomADWINChangeDetector mainLearnerChangeDetector; + protected CustomADWINChangeDetector bgLearnerChangeDetector; + + private double confidence; + protected boolean newClassifierReset; + + @Override + public void resetLearningImpl() { + this.classifier = ((Classifier) getPreparedClassOption(this.baseLearnerOption)).copy(); + this.classifier.resetLearning(); + mainLearnerChangeDetector = new CustomADWINChangeDetector(); + bgLearnerChangeDetector = new CustomADWINChangeDetector(); + this.newClassifierReset = false; + confidence = confidenceOption.getValue(); + } + + protected int changeDetected = 0; + + protected int warningDetected = 0; + + @Override + public void trainOnInstanceImpl(Instance inst) { + //this.numberInstances++; + int trueClass = (int) inst.classValue(); + boolean prediction; + if (Utils.maxIndex(this.classifier.getVotesForInstance(inst)) == trueClass) { + prediction = true; + } else { + prediction = false; + } + + mainLearnerChangeDetector.input(prediction ? 0.0 : 1.0); + + if (this.newclassifier != null) { + if (Utils.maxIndex(this.newclassifier.getVotesForInstance(inst)) == trueClass) { + prediction = true; + } else { + prediction = false; + } + bgLearnerChangeDetector.input(prediction ? 
0.0 : 1.0); + } + + if (this.mainLearnerChangeDetector.getChange() && newclassifier == null) { + makeNewClassifier(); + } + else if (this.bgLearnerChangeDetector.getChange()) { + makeNewClassifier(); + } + + if (mainLearnerChangeDetector.getWidth() > 200 && bgLearnerChangeDetector.getWidth() > 200) { + double oldErrorRate = mainLearnerChangeDetector.getEstimation(); + double oldWS = mainLearnerChangeDetector.getWidth(); + double altErrorRate = bgLearnerChangeDetector.getEstimation(); + double altWS = bgLearnerChangeDetector.getWidth(); + + double Bound = computeBound(oldErrorRate, oldWS, altWS); + + if (Bound < oldErrorRate - altErrorRate) { + classifier = newclassifier; + newclassifier = null; + bgLearnerChangeDetector.resetLearning(); + } + else if (Bound < altErrorRate - oldErrorRate) { + // Erase alternate tree + newclassifier = null; + bgLearnerChangeDetector.resetLearning(); + } + } + + this.classifier.trainOnInstance(inst); + if (newclassifier != null) + newclassifier.trainOnInstance(inst); + } + + private double computeBound(double oldErrorRate, double oldWS, double altWS) { + double fDelta = confidenceOption.getValue(); + double fN = 1.0 / altWS + 1.0 / oldWS; + return Math.sqrt(2.0 * oldErrorRate * (1.0 - oldErrorRate) * Math.log(2.0 / fDelta) * fN); + } + + public double[] getVotesForInstance(Instance inst) { + return this.classifier.getVotesForInstance(inst); + } + + @Override + public boolean isRandomizable() { + return true; + } + + @Override + public void getModelDescription(StringBuilder out, int indent) { + ((AbstractClassifier) this.classifier).getModelDescription(out, indent); + } + + @Override + protected Measurement[] getModelMeasurementsImpl() { + List measurementList = new LinkedList(); + measurementList.add(new Measurement("Change detected", this.changeDetected)); + measurementList.add(new Measurement("Warning detected", this.warningDetected)); + Measurement[] modelMeasurements = ((AbstractClassifier) this.classifier).getModelMeasurements(); 
+ if (modelMeasurements != null) { + for (Measurement measurement : modelMeasurements) { + measurementList.add(measurement); + } + } + this.changeDetected = 0; + this.warningDetected = 0; + return measurementList.toArray(new Measurement[measurementList.size()]); + } + + @Override + public ImmutableCapabilities defineImmutableCapabilities() { + if (this.getClass() == PLASTICA.class) + return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE); + else + return new ImmutableCapabilities(Capability.VIEW_STANDARD); + } + + private void makeNewClassifier() { + newclassifier = ((Classifier) getPreparedClassOption(this.baseLearnerOption)).copy(); + bgLearnerChangeDetector.resetLearning(); + } +} diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomADWINChangeDetector.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomADWINChangeDetector.java new file mode 100644 index 000000000..8b428e841 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomADWINChangeDetector.java @@ -0,0 +1,27 @@ +package moa.classifiers.trees.plastic_util; + +import moa.classifiers.core.driftdetection.ADWINChangeDetector; + +public class CustomADWINChangeDetector extends ADWINChangeDetector { + private boolean hadChange = false; + + @Override + public boolean getChange() { + boolean hadChangeCpy = hadChange; + hadChange = false; + return hadChangeCpy; + } + + @Override + public void input(double inputValue) { + super.input(inputValue); + if (!hadChange) + hadChange = super.getChange(); + } + + public int getWidth() { + if (adwin == null) + return 0; + return adwin.getWidth(); + } +} \ No newline at end of file diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomEFDTNode.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomEFDTNode.java new file mode 100644 index 000000000..2ae55406f --- /dev/null +++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomEFDTNode.java @@ -0,0 +1,780 
@@ +package moa.classifiers.trees.plastic_util; + +import com.yahoo.labs.samoa.instances.Attribute; +import com.yahoo.labs.samoa.instances.Instance; +import moa.AbstractMOAObject; +import moa.classifiers.bayes.NaiveBayes; +import moa.classifiers.core.AttributeSplitSuggestion; +import moa.classifiers.core.attributeclassobservers.AttributeClassObserver; +import moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver; +import moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver; +import moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver; +import moa.classifiers.core.conditionaltests.InstanceConditionalTest; +import moa.classifiers.core.conditionaltests.NominalAttributeBinaryTest; +import moa.classifiers.core.conditionaltests.NominalAttributeMultiwayTest; +import moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest; +import moa.classifiers.core.splitcriteria.SplitCriterion; +import moa.classifiers.trees.EFDT; +import moa.core.AutoExpandVector; +import moa.core.DoubleVector; +import moa.core.Utils; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.*; + +class CustomEFDTNode extends AbstractMOAObject implements PerformsTreeRevision, MeasuresNumberOfLeaves { + protected final int gracePeriod; + protected final SplitCriterion splitCriterion; + protected final Double confidence; + protected final boolean useAdaptiveConfidence; + protected final Double adaptiveConfidence; + protected final String leafPrediction; + protected final Integer minSamplesReevaluate; + protected Integer depth; + protected final Integer maxDepth; + protected final Double tau; + protected final Double tauReevaluate; + protected final Double relMinDeltaG; + protected final boolean binaryOnly; + protected List usedNominalAttributes; + protected HashMap infogainSum = new HashMap<>(); + private InstanceConditionalTest splitTest; + protected int numSplitAttempts = 0; + protected final 
NominalAttributeClassObserver nominalObserverBlueprint; + protected final GaussianNumericAttributeClassObserver numericObserverBlueprint = new GaussianNumericAttributeClassObserver(); + + protected DoubleVector classDistributionAtTimeOfCreation; + protected DoubleVector observedClassDistribution; + protected AutoExpandVector attributeObservers = new AutoExpandVector<>(); + protected Double seenWeight = 0.0; + protected int nodeTime = 0; + protected Successors successors; + protected Attribute splitAttribute; + protected final boolean noPrePrune; + protected int blockedAttributeIndex; + protected boolean performedTreeRevision = false; + protected double mcCorrectWeight = 0.0; + protected double nbCorrectWeight = 0.0; + + public CustomEFDTNode(SplitCriterion splitCriterion, + int gracePeriod, + Double confidence, + Double adaptiveConfidence, + boolean useAdaptiveConfidence, + String leafPrediction, + Integer minSamplesReevaluate, + Integer depth, + Integer maxDepth, + Double tau, + Double tauReevaluate, + Double relMinDeltaG, + boolean binaryOnly, + boolean noPrePrune, + NominalAttributeClassObserver nominalObserverBlueprint, + DoubleVector observedClassDistribution, + List usedNominalAttributes, + int blockedAttributeIndex) { + this.gracePeriod = gracePeriod; + this.splitCriterion = splitCriterion; + this.confidence = confidence; + this.adaptiveConfidence = adaptiveConfidence; + this.useAdaptiveConfidence = useAdaptiveConfidence; + this.leafPrediction = leafPrediction; + this.minSamplesReevaluate = minSamplesReevaluate; + this.depth = depth; + this.maxDepth = maxDepth; + this.tau = tau; + this.tauReevaluate = tauReevaluate; + this.relMinDeltaG = relMinDeltaG; + this.binaryOnly = binaryOnly; + this.noPrePrune = noPrePrune; + this.usedNominalAttributes = usedNominalAttributes != null ? usedNominalAttributes : new LinkedList<>(); + this.nominalObserverBlueprint = nominalObserverBlueprint; + this.observedClassDistribution = observedClassDistribution != null ? 
observedClassDistribution : new DoubleVector(); + this.infogainSum.put(-1, 0.0); // Initialize for null split + classDistributionAtTimeOfCreation = new DoubleVector(this.observedClassDistribution); + this.blockedAttributeIndex = blockedAttributeIndex; + } + + protected double computeHoeffdingBound() { + double range = splitCriterion.getRangeOfMerit(observedClassDistribution.getArrayCopy()); + double n = seenWeight; + return Math.sqrt(((range * range) * Math.log(1.0 / currentConfidence())) + / (2.0 * n)); + } + + /** + * Gets the index of the attribute in the instance, + * given the index of the attribute in the learner. + * + * @param index the index of the attribute in the learner + * @param inst the instance + * @return the index in the instance + */ + protected static int modelAttIndexToInstanceAttIndex(int index, + Instance inst) { + return inst.classIndex() > index ? index : index + 1; + } + + public double[] getObservedClassDistribution() { + return observedClassDistribution.getArrayCopy(); + } + + public Successors getSuccessors() { + return successors; + } + + public Attribute getSplitAttribute() { + return splitAttribute; + } + + public Integer getSplitAttributeIndex() { + if (splitTest == null) + return -1; + return splitTest.getAttsTestDependsOn()[0]; + } + + public InstanceConditionalTest getSplitTest() { + return splitTest; + } + + public boolean setSplitTest(InstanceConditionalTest newTest) { + if (newTest == null) { + splitTest = null; + return true; + } + if (successors != null) { + if (successors.isNominal() && successors.isBinary()) { + if (!(newTest instanceof NominalAttributeBinaryTest)) + return false; + } + if (successors.isNominal() && !successors.isBinary()) { + if (!(newTest instanceof NominalAttributeMultiwayTest)) { + return false; + } + } + if (!successors.isNominal()) { + if (!(newTest instanceof NumericAttributeBinaryTest) || !successors.isBinary()) + return false; + } + } + splitTest = newTest; + return true; + } + + public Integer 
getDepth() { + return depth; + } + + public List getUsedNominalAttributes() { + return usedNominalAttributes; + } + + Double currentConfidence() { + if (!useAdaptiveConfidence) + return confidence; + return adaptiveConfidence * Math.exp(-numSplitAttempts); + } + + public AttributeSplitSuggestion[] getBestSplitSuggestions(SplitCriterion criterion) { + List bestSuggestions = new LinkedList<>(); + double[] preSplitDist = observedClassDistribution.getArrayCopy(); + if (!noPrePrune) { + // add null split as an option + bestSuggestions.add(new AttributeSplitSuggestion(null, + new double[0][], criterion.getMeritOfSplit( + preSplitDist, new double[][]{preSplitDist}))); + } + for (int i = 0; i < attributeObservers.size(); i++) { + AttributeClassObserver obs = attributeObservers.get(i); + if (obs != null) { + AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion( + criterion, preSplitDist, i, binaryOnly + ); + if (bestSuggestion != null) { + bestSuggestions.add(bestSuggestion); + } + } + } + return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]); + } + + /** + * Train on the provided instance. Re-evaluate the tree + * + * @param instance the instance to train on + * @param totalNumInstances the total number of instances observed so far. 
(not used in EFDT but might be overwritten by subclasses) + **/ + public void learnInstance(Instance instance, int totalNumInstances) { + seenWeight += instance.weight(); + nodeTime++; + updateStatistics(instance); + updateObservers(instance); + + if (isLeaf() && nodeTime % gracePeriod == 0) { + attemptInitialSplit(instance); + } + if (!isLeaf() && nodeTime % minSamplesReevaluate == 0) { + reevaluateSplit(instance); + } + if (!isLeaf()) { + propagateToSuccessors(instance, totalNumInstances); + } + + if (!"MC".equals(leafPrediction)) + updateNaiveBayesAdaptive(instance); + } + + public void updateNaiveBayesAdaptive(Instance inst) { + int trueClass = (int) inst.classValue(); + if (this.observedClassDistribution.maxIndex() == trueClass) { + this.mcCorrectWeight += inst.weight(); + } + if (Utils.maxIndex(NaiveBayes.doNaiveBayesPrediction(inst, + this.observedClassDistribution, this.attributeObservers)) == trueClass) { + this.nbCorrectWeight += inst.weight(); + } + } + + /** + * Predict the provided instance + *

+ * Traverses the tree until reaching a leaf and then returns the class votes of that leaf. + *

+ * + * @param instance the instance to predict + * @return the class votes + **/ + public double[] predict(Instance instance) { + if (!isLeaf()) { + CustomEFDTNode successor = getSuccessor(instance); + if (successor == null) + return getClassVotes(); + return successor.predict(instance); + } + if ("MC".equals(leafPrediction)) + return getClassVotes(); + else if ("NB".equals(leafPrediction)) + return NaiveBayes.doNaiveBayesPrediction(instance, this.observedClassDistribution, this.attributeObservers); + else + return doNBAdaptive(instance); + } + + private double[] doNBAdaptive(Instance instance) { + if (this.mcCorrectWeight > this.nbCorrectWeight) { + return this.observedClassDistribution.getArrayCopy(); + } + return NaiveBayes.doNaiveBayesPrediction(instance, + this.observedClassDistribution, this.attributeObservers); + } + + /** + * Finds the successor of the current node based on the provided instance. + * + * @param instance the instance + * @return the successor node if it exists. Else, returns null. 
+ **/ + CustomEFDTNode getSuccessor(Instance instance) { + if (isLeaf()) + return null; + Double attVal = instance.value(splitAttribute); + return successors.getSuccessorNode(attVal); + } + + /** + * Attempts to split this leaf + * + * @param instance the current instance + **/ + protected void attemptInitialSplit(Instance instance) { + if (depth >= maxDepth) { + return; + } + if (isPure()) + return; + + numSplitAttempts++; + + AttributeSplitSuggestion[] bestSuggestions = getBestSplitSuggestions(splitCriterion); + Arrays.sort(bestSuggestions); + AttributeSplitSuggestion xBest = bestSuggestions[bestSuggestions.length - 1]; + xBest = replaceBestSuggestionIfAttributeIsBlocked(xBest, bestSuggestions, blockedAttributeIndex); + + if (!shouldSplitLeaf(bestSuggestions, currentConfidence(), observedClassDistribution)) + return; + if (xBest.splitTest == null) { + // preprune - null wins + System.out.println("preprune - null wins"); + killSubtree(); + resetSplitAttribute(); + return; + } + int instanceIndex = modelAttIndexToInstanceAttIndex(xBest.splitTest.getAttsTestDependsOn()[0], instance); + Attribute newSplitAttribute = instance.attribute(instanceIndex); + makeSplit(newSplitAttribute, xBest); + classDistributionAtTimeOfCreation = new DoubleVector(observedClassDistribution.getArrayCopy()); + } + + /** + * Reevaluates a split decision that was already made + * + * @param instance the current instance + **/ + protected void reevaluateSplit(Instance instance) { + numSplitAttempts++; + + AttributeSplitSuggestion[] bestSuggestions = getBestSplitSuggestions(splitCriterion); + Arrays.sort(bestSuggestions); + if (bestSuggestions.length == 0) + return; + + // get best split suggestions + AttributeSplitSuggestion[] bestSplitSuggestions = getBestSplitSuggestions(splitCriterion); + Arrays.sort(bestSplitSuggestions); + AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + + double bestSuggestionAverageMerit = bestSuggestion.splitTest == 
null ? 0.0 : bestSuggestion.merit; + double currentAverageMerit = getCurrentSuggestionAverageMerit(bestSuggestions); + double deltaG = bestSuggestionAverageMerit - currentAverageMerit; + double eps = computeHoeffdingBound(); + + if (deltaG > eps || (eps < tauReevaluate && deltaG > tauReevaluate * relMinDeltaG)) { + + if (bestSuggestion.splitTest == null) { + System.out.println("preprune - null wins"); + killSubtree(); + resetSplitAttribute(); + } else { + boolean doResplit = true; + if ( + getSplitTest() instanceof NumericAttributeBinaryTest + && getSplitTest().getAttsTestDependsOn()[0] == bestSuggestion.splitTest.getAttsTestDependsOn()[0] + ) { + Set keys = successors.getKeyset(); + for (SuccessorIdentifier key: keys) { + if (key.isLower()) { + if (argmax(bestSuggestion.resultingClassDistributions[0]) == argmax(successors.getSuccessorNode(key).observedClassDistribution.getArrayRef())) { + doResplit = false; + break; + } + } + else { + if (argmax(bestSuggestion.resultingClassDistributions[1]) == argmax(successors.getSuccessorNode(key).observedClassDistribution.getArrayRef())) { + doResplit = false; + break; + } + } + } + } + performedTreeRevision = true; + if (!doResplit) { + NumericAttributeBinaryTest test = (NumericAttributeBinaryTest) bestSuggestion.splitTest; + successors.adjustThreshold(test.getSplitValue()); + setSplitTest(bestSuggestion.splitTest); + nodeTime = 0; + seenWeight = 0.0; + } + else { + int instanceIndex = modelAttIndexToInstanceAttIndex(bestSuggestion.splitTest.getAttsTestDependsOn()[0], instance); + Attribute newSplitAttribute = instance.attribute(instanceIndex); + makeSplit(newSplitAttribute, bestSuggestion); + nodeTime = 0; + seenWeight = 0.0; + } + } + } + } + + /** + * Initializes the successor nodes when performing a split + * + * @param xBest the split suggestion for the best split + * @param splitAttribute the attribute to split on + * @return if the initialization was successful. 
+ **/ + protected boolean initializeSuccessors(AttributeSplitSuggestion xBest, Attribute splitAttribute) { + boolean isNominal = splitAttribute.isNominal(); + boolean isBinary = !(xBest.splitTest instanceof NominalAttributeMultiwayTest); + Double splitValue = null; + if (isNominal && isBinary) + splitValue = ((NominalAttributeBinaryTest) xBest.splitTest).getValue(); + else if (!isNominal) + splitValue = ((NumericAttributeBinaryTest) xBest.splitTest).getSplitValue(); + + Integer splitAttributeIndex = xBest.splitTest.getAttsTestDependsOn()[0]; + if (splitAttribute.isNominal()) { + if (!isBinary) { + for (int i = 0; i < xBest.numSplits(); i++) { + double[] stats = xBest.resultingClassDistributionFromSplit(i); + if (stats.length == 0) + continue; + CustomEFDTNode s = newNode( + depth + 1, + new DoubleVector(stats), + getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex)); + boolean success = successors.addSuccessorNominalMultiway((double) i, s); + if (!success) { + successors = null; + return false; + } + } + return !isLeaf(); + } else { + double[] stats1 = xBest.resultingClassDistributionFromSplit(0); + double[] stats2 = xBest.resultingClassDistributionFromSplit(1); + CustomEFDTNode s1 = newNode(depth + 1, new DoubleVector(stats1), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex)); + CustomEFDTNode s2 = newNode(depth + 1, new DoubleVector(stats2), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex)); + boolean success = successors.addSuccessorNominalBinary(splitValue, s1); + success = success && successors.addDefaultSuccessorNominalBinary(s2); + if (!success) { + successors = null; + return false; + } + return !isLeaf(); + } + } else { + boolean success = successors.addSuccessorNumeric( + splitValue, + newNode(depth + 1, new DoubleVector(xBest.resultingClassDistributionFromSplit(0)), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex)), + true + ); + success = success && 
successors.addSuccessorNumeric( + splitValue, + newNode(depth + 1, new DoubleVector(xBest.resultingClassDistributionFromSplit(1)), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex)), + false + ); + if (!success) { + successors = null; + return false; + } + return !isLeaf(); + } + } + + /** + * Sets the split attribute for this node + * + * @param xBest the split suggestion for the best split + * @param splitAttribute the attribute to split on + **/ + protected void setSplitAttribute(AttributeSplitSuggestion xBest, Attribute splitAttribute) { + this.splitAttribute = splitAttribute; + setSplitTest(xBest.splitTest); + } + + protected void resetSplitAttribute() { + splitAttribute = null; + splitTest = null; + } + + /** + * Kills the subtree by removing the link to the successors + */ + protected void killSubtree() { + successors = null; + } + + double[] getClassVotes() { + return observedClassDistribution.getArrayCopy(); + } + + /** + * Updates the class statistics (which class occurred how often) + * @param instance the current instance + */ + protected void updateStatistics(Instance instance) { + observedClassDistribution.addToValue((int) instance.classValue(), instance.weight()); + } + + /** + * Propagates the instance down the tree + * @param instance the current instance + * @param totalNumInstances the number of instances seen so far + */ + protected void propagateToSuccessors(Instance instance, int totalNumInstances) { + Double attValue = instance.value(splitAttribute); + CustomEFDTNode successor = successors.getSuccessorNode(attValue); + if (successor == null) + successor = addSuccessor(instance); + if (successor != null) + successor.learnInstance(instance, totalNumInstances); + } + + /** + * Add a successor to the children + * @param instance the current instance + * @return the creates successor. 
Null, if the successor could not be created + */ + protected CustomEFDTNode addSuccessor(Instance instance) { + List usedNomAttributes = new ArrayList<>(usedNominalAttributes); //deep copy + CustomEFDTNode successor = newNode(depth + 1, null, usedNomAttributes); + double value = instance.value(splitAttribute); + if (splitAttribute.isNominal()) { + if (!successors.isBinary()) { + boolean success = successors.addSuccessorNominalMultiway(value, successor); + return success ? successor : null; + } else { + boolean success = successors.addSuccessorNominalBinary(value, successor); + if (!success) // this is the case if the split is binary nominal but the "left" successor exists. + success = successors.addDefaultSuccessorNominalBinary(successor); + return success ? successor : null; + } + } else { + if (successors.lowerIsMissing()) { + boolean success = successors.addSuccessorNumeric(value, successor, true); + return success ? successor : null; + } else if (successors.upperIsMissing()) { + boolean success = successors.addSuccessorNumeric(value, successor, false); + return success ? successor : null; + } + } + return null; + } + + protected CustomEFDTNode newNode(int depth, DoubleVector classDistribution, List usedNominalAttributes) { + return new CustomEFDTNode( + splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence, + leafPrediction, minSamplesReevaluate, depth, maxDepth, + tau, tauReevaluate, relMinDeltaG, binaryOnly, noPrePrune, nominalObserverBlueprint, + classDistribution, usedNominalAttributes, -1 // we don't block attributes in EFDT + ); + } + + protected void updateObservers(Instance instance) { + for (int i = 0; i < instance.numAttributes() - 1; i++) { //update likelihood + int instAttIndex = modelAttIndexToInstanceAttIndex(i, instance); + AttributeClassObserver obs = this.attributeObservers.get(i); + if (obs == null) { + obs = instance.attribute(instAttIndex).isNominal() ? 
newNominalClassObserver() : newNumericClassObserver(); + this.attributeObservers.set(i, obs); + } + obs.observeAttributeClass(instance.value(instAttIndex), (int) instance.classValue(), instance.weight()); + } + } + + /** + * If this node is a leaf node + * @return true if the node is a leaf + */ + boolean isLeaf() { + if (successors == null) + return true; + return successors.size() == 0; + } + + /** + * If this node only has seen instances of the same class + * @return true if the number of observed classes is less than 2 + */ + boolean isPure() { + return observedClassDistribution.numNonZeroEntries() < 2; + } + + protected NominalAttributeClassObserver newNominalClassObserver() { + return (NominalAttributeClassObserver) nominalObserverBlueprint.copy(); + } + + protected NumericAttributeClassObserver newNumericClassObserver() { + return (NumericAttributeClassObserver) numericObserverBlueprint.copy(); + } + + protected List getUsedNominalAttributesForSuccessor(Attribute splitAttribute, Integer splitAttributeIndex) { + List usedNomAttributesCpy = new ArrayList<>(usedNominalAttributes); //deep copy + if (splitAttribute.isNominal()) + usedNomAttributesCpy.add(splitAttributeIndex); + return usedNomAttributesCpy; + } + + protected void updateInfogainSum(AttributeSplitSuggestion[] suggestions) { + for (AttributeSplitSuggestion sugg : suggestions) { + if (sugg.splitTest != null) { + if (!infogainSum.containsKey((sugg.splitTest.getAttsTestDependsOn()[0]))) { + infogainSum.put((sugg.splitTest.getAttsTestDependsOn()[0]), 0.0); + } + double currentSum = infogainSum.get((sugg.splitTest.getAttsTestDependsOn()[0])); + infogainSum.put((sugg.splitTest.getAttsTestDependsOn()[0]), currentSum + sugg.merit); + } else { // handle the null attribute + double currentSum = infogainSum.get(-1); // null split + infogainSum.put(-1, Math.max(0.0, currentSum + sugg.merit)); + assert infogainSum.get(-1) >= 0.0 : "Negative infogain shouldn't be possible here."; + } + } + } + + protected boolean 
shouldSplitLeaf(AttributeSplitSuggestion[] suggestions, + double confidence, + DoubleVector observedClassDistribution + ) { + boolean shouldSplit = false; + if (suggestions.length < 2) { + shouldSplit = suggestions.length > 0; + } else { + AttributeSplitSuggestion bestSuggestion = suggestions[suggestions.length - 1]; + + double bestSuggestionAverageMerit = bestSuggestion.merit; + double currentAverageMerit = 0.0; + double eps = computeHoeffdingBound(); + + shouldSplit = bestSuggestionAverageMerit - currentAverageMerit > eps || eps < tau; + if (bestSuggestion.merit < 1e-10) + shouldSplit = false; // we don't use average here + + if (shouldSplit) { + for (Integer i : usedNominalAttributes) { + if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i) { + shouldSplit = false; + break; + } + } + } + } + return shouldSplit; + } + + /** + * Get the merit if the current split + * @param suggestions the suggestions for the possible splits + * @return the merit of the current split + */ + double getCurrentSuggestionAverageMerit(AttributeSplitSuggestion[] suggestions) { + double merit = 0.0; + if (splitTest != null) { + if (splitTest instanceof NominalAttributeMultiwayTest) { + for (AttributeSplitSuggestion s: suggestions) { + if (s.splitTest == null) + continue; + if (s.splitTest.getAttsTestDependsOn()[0] == getSplitAttributeIndex()) { + merit = s.merit; + break; + } + } + } + else if (splitTest instanceof NominalAttributeBinaryTest) { + double currentValue = successors.getReferenceValue(); + NominalAttributeClassObserver obs = (NominalAttributeClassObserver) attributeObservers.get(getSplitAttributeIndex()); + AttributeSplitSuggestion xCurrent = obs.forceSplit(splitCriterion, observedClassDistribution.getArrayCopy(), getSplitAttributeIndex(), true, currentValue); + merit = xCurrent == null ? 0.0 : xCurrent.merit; + if (xCurrent != null) + merit = xCurrent.splitTest == null ? 
0.0 : xCurrent.merit; + } + else if (splitTest instanceof NumericAttributeBinaryTest) { + double currentThreshold = successors.getReferenceValue(); + GaussianNumericAttributeClassObserver obs = (GaussianNumericAttributeClassObserver) attributeObservers.get(getSplitAttributeIndex()); + AttributeSplitSuggestion xCurrent = obs.forceSplit(splitCriterion, observedClassDistribution.getArrayCopy(), getSplitAttributeIndex(), currentThreshold); + merit = xCurrent == null ? 0.0 : xCurrent.merit; + if (xCurrent != null) + merit = xCurrent.splitTest == null ? 0.0 : xCurrent.merit; + } + } + return merit == Double.NEGATIVE_INFINITY ? 0.0 : merit; + } + + double getSuggestionAverageMerit(InstanceConditionalTest splitTest) { + double averageMerit; + + if (splitTest == null) { + averageMerit = infogainSum.get(-1) / Math.max(numSplitAttempts, 1.0); + } else { + Integer key = splitTest.getAttsTestDependsOn()[0]; + if (!infogainSum.containsKey(key)) { + infogainSum.put(key, 0.0); + } + averageMerit = infogainSum.get(key) / Math.max(numSplitAttempts, 1.0); + } + return averageMerit; + } + + int argmax(double[] array) { + double max = array[0]; + int maxarg = 0; + + for (int i = 1; i < array.length; i++) { + + if (array[i] > max) { + max = array[i]; + maxarg = i; + } + } + return maxarg; + } + + public int getSubtreeDepth() { + if (isLeaf()) + return depth; + Set succDepths = new HashSet<>(); + for (CustomEFDTNode successor: successors.getAllSuccessors()) { + succDepths.add(successor.getSubtreeDepth()); + } + return Collections.max(succDepths); + } + + /** + * Replaces the best split suggestion if the node is not allowed to split on that attribute. + * This is the case when the parent node splits on that attribute. It prevents splitting the space into thinner and thinner slices. 
+ * @param bestSuggestion the best split suggestion + * @param suggestions all suggestions + * @param blockedAttributeIndex the attribute index of the blocked attribute + * @return + */ + AttributeSplitSuggestion replaceBestSuggestionIfAttributeIsBlocked(AttributeSplitSuggestion bestSuggestion, AttributeSplitSuggestion[] suggestions, int blockedAttributeIndex) { + if (suggestions.length == 0) + return null; + if (bestSuggestion.splitTest == null) + return bestSuggestion; + if (suggestions.length == 1) + return bestSuggestion; + if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == blockedAttributeIndex) { + ArrayUtils.remove(suggestions, suggestions.length - 1); + return suggestions[suggestions.length - 1]; + } + return bestSuggestion; + } + + protected boolean makeSplit(Attribute splitAttribute, AttributeSplitSuggestion suggestion) { + boolean isNominal = splitAttribute.isNominal(); + boolean isBinary = !(suggestion.splitTest instanceof NominalAttributeMultiwayTest); + Double splitValue = null; + if (isNominal && isBinary) + splitValue = ((NominalAttributeBinaryTest) suggestion.splitTest).getValue(); + else if (!isNominal) + splitValue = ((NumericAttributeBinaryTest) suggestion.splitTest).getSplitValue(); + + successors = new Successors(isBinary, !isNominal, splitValue); + + setSplitAttribute(suggestion, splitAttribute); + return initializeSuccessors(suggestion, splitAttribute); + } + + /** + * Checks if the subtree of this node performed split revision + * @return true if any node in the subtree performed a split revision + */ + @Override + public boolean didPerformTreeRevision() { + boolean didRevise = performedTreeRevision; + performedTreeRevision = false; + if (isLeaf()) { + return didRevise; + } + for (CustomEFDTNode child: successors.getAllSuccessors()) { + didRevise |= child.didPerformTreeRevision(); + } + return didRevise; + } + + @Override + public int getLeafNumber() { + if (isLeaf()) + return 1; + int sum = 0; + for (CustomEFDTNode s: 
successors.getAllSuccessors()) { + sum += s.getLeafNumber(); + } + return sum; + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + + } +} diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomHTNode.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomHTNode.java new file mode 100644 index 000000000..9ab193dec --- /dev/null +++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomHTNode.java @@ -0,0 +1,80 @@ +package moa.classifiers.trees.plastic_util; + +import com.yahoo.labs.samoa.instances.Instance; +import moa.classifiers.core.AttributeSplitSuggestion; +import moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver; +import moa.classifiers.core.splitcriteria.SplitCriterion; +import moa.core.DoubleVector; + +import java.util.LinkedList; +import java.util.List; + +public class CustomHTNode extends CustomEFDTNode { + + public CustomHTNode(SplitCriterion splitCriterion, + int gracePeriod, + Double confidence, + Double adaptiveConfidence, + boolean useAdaptiveConfidence, + String leafPrediction, + Integer depth, + Integer maxDepth, + Double tau, + boolean binaryOnly, + boolean noPrePrune, + NominalAttributeClassObserver nominalObserverBlueprint, + DoubleVector observedClassDistribution, + List usedNominalAttributes, + int blockedAttributeIndex) { + super(splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence, leafPrediction, + Integer.MAX_VALUE, depth, maxDepth, tau, 0.0, 0.0, binaryOnly, noPrePrune, + nominalObserverBlueprint, observedClassDistribution, usedNominalAttributes, blockedAttributeIndex); + + } + + @Override + protected void reevaluateSplit(Instance instance) { + } + + @Override + protected CustomHTNode newNode(int depth, DoubleVector classDistribution, List usedNominalAttributes) { + return new CustomHTNode( + splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence, + leafPrediction, depth, maxDepth, + tau, 
binaryOnly, noPrePrune, nominalObserverBlueprint, + classDistribution, new LinkedList<>(), -1 // we don't block attributes in HT + ); + } + + @Override + protected boolean shouldSplitLeaf(AttributeSplitSuggestion[] suggestions, + double confidence, + DoubleVector observedClassDistribution + ) { + boolean shouldSplit; + if (suggestions.length < 2) { + shouldSplit = suggestions.length > 0; + } else { + AttributeSplitSuggestion bestSuggestion = suggestions[suggestions.length - 1]; + AttributeSplitSuggestion secondBestSuggestion = suggestions[suggestions.length - 2]; + + double bestSuggestionAverageMerit = bestSuggestion.merit; + double currentAverageMerit = secondBestSuggestion.merit; + double eps = computeHoeffdingBound(); + + shouldSplit = bestSuggestionAverageMerit - currentAverageMerit > eps || eps < tau; + if (bestSuggestion.merit < 1e-10) + shouldSplit = false; // we don't use average here + + if (shouldSplit) { + for (Integer i : usedNominalAttributes) { + if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i) { + shouldSplit = false; + break; + } + } + } + } + return shouldSplit; + } +} diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/EFHATNode.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/EFHATNode.java new file mode 100644 index 000000000..c76596c0c --- /dev/null +++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/EFHATNode.java @@ -0,0 +1,174 @@ +package moa.classifiers.trees.plastic_util; + +import com.yahoo.labs.samoa.instances.Instance; +import moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver; +import moa.classifiers.core.splitcriteria.SplitCriterion; +import moa.core.DoubleVector; + +import java.util.LinkedList; +import java.util.List; + +public class EFHATNode extends CustomEFDTNode { + + private CustomADWINChangeDetector changeDetector; // we need to access the width of adwin to compute switch significance. This is not possible with the default adwin change detector class. 
+ private EFHATNode backgroundLearner; + private LinkedList predictions = new LinkedList<>(); + private boolean isBackgroundLearner = false; + + public EFHATNode(SplitCriterion splitCriterion, + int gracePeriod, + Double confidence, + Double adaptiveConfidence, + boolean useAdaptiveConfidence, + String leafPrediction, + Integer minSamplesReevaluate, + Integer depth, + Integer maxDepth, + Double tau, + boolean binaryOnly, + boolean noPrePrune, + NominalAttributeClassObserver nominalObserverBlueprint, + DoubleVector observedClassDistribution, + List usedNominalAttributes, + int blockedAttributeIndex, + CustomADWINChangeDetector changeDetector) { + super(splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence, leafPrediction, + minSamplesReevaluate, depth, maxDepth, tau, 0.0, 0.0, binaryOnly, noPrePrune, + nominalObserverBlueprint, observedClassDistribution, usedNominalAttributes, blockedAttributeIndex); + this.changeDetector = changeDetector == null ? new CustomADWINChangeDetector() : changeDetector; + } + + public EFHATNode(SplitCriterion splitCriterion, + int gracePeriod, + Double confidence, + Double adaptiveConfidence, + boolean useAdaptiveConfidence, + String leafPrediction, + Integer minSamplesReevaluate, + Integer depth, + Integer maxDepth, + Double tau, + boolean binaryOnly, + boolean noPrePrune, + NominalAttributeClassObserver nominalObserverBlueprint, + DoubleVector observedClassDistribution, + List usedNominalAttributes, + int blockedAttributeIndex, + CustomADWINChangeDetector changeDetector, + boolean isBackgroundLearner) { + super(splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence, leafPrediction, + minSamplesReevaluate, depth, maxDepth, tau, 0.0, 0.0, binaryOnly, noPrePrune, + nominalObserverBlueprint, observedClassDistribution, usedNominalAttributes, blockedAttributeIndex); + this.changeDetector = changeDetector == null ? 
new CustomADWINChangeDetector() : changeDetector; + this.isBackgroundLearner = isBackgroundLearner; + } + + @Override + public double[] predict(Instance instance) { + double[] pred = super.predict(instance); + if (pred.length > 0) + predictions.add((double) argmax(pred)); + if (backgroundLearner != null) + backgroundLearner.predict(instance); + return pred; + } + + @Override + public void learnInstance(Instance instance, int totalNumInstances) { + seenWeight += instance.weight(); + nodeTime++; + updateStatistics(instance); + updateObservers(instance); + updateChangeDetector(instance.classValue()); + + if (backgroundLearner != null) + backgroundLearner.learnInstance(instance, totalNumInstances); + + if (isLeaf() && nodeTime % gracePeriod == 0) + attemptInitialSplit(instance); + + if (!isLeaf() +// && nodeTime % minSamplesReevaluate == 0 + ) + hatGrow(); + + if (!isLeaf()) //Do NOT! put this in the upper (!isleaf()) block. This is not the same since we might kill the subtree during reevaluation! 
+ propagateToSuccessors(instance, totalNumInstances); + } + + private void hatGrow() { + if (changeDetector.getChange()) { + backgroundLearner = newNode(depth, new DoubleVector(), new LinkedList<>(usedNominalAttributes)); + backgroundLearner.isBackgroundLearner = true; + } + else if (backgroundLearner != null) { + // adopted from HAT implementation + if (changeDetector.getWidth() > minSamplesReevaluate && backgroundLearner.changeDetector.getWidth() > minSamplesReevaluate) { + double oldErrorRate = changeDetector.getEstimation(); + double oldWS = changeDetector.getWidth(); + double altErrorRate = backgroundLearner.changeDetector.getEstimation(); + double altWS = backgroundLearner.changeDetector.getWidth(); + + double fDelta = .05; + //if (gNumAlts>0) fDelta=fDelta/gNumAlts; +// double fN = 1.0 / w2 + 1.0 / w1; + double fN = 1.0 / altWS + 1.0 / oldWS; + double Bound = (double) Math.sqrt((double) 2.0 * oldErrorRate * (1.0 - oldErrorRate) * Math.log(2.0 / fDelta) * fN); + +// double significance = switchSignificance(e1, e2, w1, w2); + if (Bound < oldErrorRate - altErrorRate) { + performedTreeRevision = true; + makeBackgroundLearnerMainLearner(); + } + else if (Bound < altErrorRate - oldErrorRate) { + // Erase alternate tree + backgroundLearner = null; + } + } + } + } + + @Override + protected EFHATNode newNode(int depth, DoubleVector classDistribution, List usedNominalAttributes) { + return new EFHATNode( + splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence, + leafPrediction, minSamplesReevaluate, depth, maxDepth, + tau, binaryOnly, noPrePrune, nominalObserverBlueprint, + classDistribution, new LinkedList<>(), + -1, // we don't block attributes in HT + (CustomADWINChangeDetector) changeDetector.copy(), + isBackgroundLearner + ); + } + + private void makeBackgroundLearnerMainLearner() { + splitAttribute = backgroundLearner.splitAttribute; + successors = backgroundLearner.successors; + observedClassDistribution = 
backgroundLearner.observedClassDistribution; + classDistributionAtTimeOfCreation = backgroundLearner.classDistributionAtTimeOfCreation; + attributeObservers = backgroundLearner.attributeObservers; + seenWeight = backgroundLearner.seenWeight; + nodeTime = backgroundLearner.nodeTime; + numSplitAttempts = backgroundLearner.numSplitAttempts; + setSplitTest(backgroundLearner.getSplitTest()); + backgroundLearner = null; + } + + private void updateChangeDetector(double label) { + if (predictions.size() == 0) + return; + Double pred = predictions.removeFirst(); + if (pred == null) + return; + changeDetector.input(pred == label ? 0.0 : 1.0); // monitoring error rate, not accuracy. + } + + private void resetIsBackgroundLearnerInSubtree() { + isBackgroundLearner = false; + if (isLeaf()) + return; + for (CustomEFDTNode n: successors.getAllSuccessors()) { + ((EFHATNode) n).resetIsBackgroundLearnerInSubtree(); + } + } +} diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/MappedTree.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/MappedTree.java new file mode 100644 index 000000000..7a678f7b0 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/MappedTree.java @@ -0,0 +1,409 @@ +package moa.classifiers.trees.plastic_util; + +import com.yahoo.labs.samoa.instances.Attribute; +import moa.classifiers.core.conditionaltests.InstanceConditionalTest; + +import java.util.*; + +class MappedTree implements Iterator { + + private LinkedList branchQueue; + private final LinkedList finishedBranches = new LinkedList<>(); + + private final Attribute splitAttribute; + private final int splitAttributeIndex; + private final Double splitValue; + private final int maxBranchLength; + + public MappedTree(PlasticNode root, Attribute splitAttribute, int splitAttributeIndex, Double splitValue, int maxBranchLength) { + branchQueue = disconnectRoot(root); + + this.splitAttribute = splitAttribute; + this.splitAttributeIndex = splitAttributeIndex; + 
this.splitValue = splitValue; + this.maxBranchLength = maxBranchLength; + } + + + @Override + public boolean hasNext() { + return !(branchQueue.size() == 0 && finishedBranches.size() == 0); + } + + @Override + public PlasticBranch next() { + if (finishedBranches.size() > 0) { + return finishedBranches.removeFirst(); + } + else { + while (finishedBranches.size() == 0) { + mapBranches(branchQueue, splitAttribute, splitAttributeIndex, splitValue, maxBranchLength); + } + return finishedBranches.removeFirst(); + } + } + + private void mapBranches( + LinkedList branches, + Attribute swapAttribute, + int swapAttributeIndex, + Double splitValue, + int maxBranchLength + ) { + int numBranches = branches.size(); + for (int i = 0; i < numBranches; i++) { + PlasticBranch branch = branches.removeFirst(); + boolean branchIsFinished = getEndConditionForBranch(branch, swapAttribute, swapAttributeIndex, splitValue, maxBranchLength); + if (branchIsFinished) { + expandBranch(branch, swapAttribute, swapAttributeIndex, splitValue); + List decoupledBranches = decoupleLastNode(branch); + decoupledBranches.forEach(b -> modifyBranch(b, swapAttribute)); + finishedBranches.addAll(decoupledBranches); + branchQueue = branches; + return; + } + PlasticTreeElement lastElement = branch.getLast(); + PlasticBranch branchExtensions = disconnectSuccessors(lastElement, swapAttribute); + for (PlasticTreeElement extension: branchExtensions.getBranchRef()) { + PlasticBranch extendedBranch = new PlasticBranch((LinkedList) branch.getBranchRef().clone()); + extendedBranch.getBranchRef().add(extension); + branches.add(extendedBranch); + } + } + branchQueue = branches; + } + + private boolean getEndConditionForBranch(PlasticBranch branch, Attribute swapAttribute, int swapAttributeIndex, Double splitValue, int maxBranchLength) { + CustomEFDTNode lastNodeOfBranch = branch.getLast().getNode(); + boolean isSwapAttribute = lastNodeOfBranch.getSplitAttribute() == swapAttribute; + boolean isLeaf = 
lastNodeOfBranch.isLeaf(); + if (isSwapAttribute || isLeaf) + return true; + return branch.getBranchRef().size() >= maxBranchLength; + } + + private PlasticBranch disconnectSuccessors(PlasticTreeElement ancestor, Attribute swapAttribute) { + PlasticBranch branchExtensions = new PlasticBranch(); + PlasticNode ancestorNode = ancestor.getNode(); + SuccessorIdentifier key = ancestor.getKey(); + PlasticNode successor = (PlasticNode) ancestorNode.getSuccessors().getSuccessorNode(key); + + boolean endCondition = successor.isLeaf() || successor.getSplitAttribute() == swapAttribute; + + if (endCondition) { + PlasticTreeElement element = new PlasticTreeElement(successor, null); + branchExtensions.getBranchRef().add(element); + } + else { + for (SuccessorIdentifier successorSuccessorKey: successor.getSuccessors().getKeyset()) { + PlasticTreeElement element = new PlasticTreeElement(successor, successorSuccessorKey); + branchExtensions.getBranchRef().add(element); + } + } + return branchExtensions; + } + + private LinkedList disconnectRoot(PlasticNode root) { + Successors successors = root.getSuccessors(); + if (successors == null) { + return null; + } + if (successors.size() == 0) { + return null; + } + Set keys = successors.getKeyset(); + LinkedList branches = new LinkedList<>(); + for (SuccessorIdentifier key: keys) { + PlasticTreeElement newElement = new PlasticTreeElement(root, key); + PlasticBranch newBranch = new PlasticBranch(); + newBranch.getBranchRef().add(newElement); + branches.add(newBranch); + } + return branches; + } + + private void expandBranch(PlasticBranch branch, Attribute splitAttribute, int splitAttributeIndex, Double splitValue) { + if (splitAttribute.isNominal()) + splitLastNodeIfRequiredCategorical(branch, splitAttribute, splitAttributeIndex, splitValue); + else + splitLastNodeIfRequiredNumeric(branch, splitAttribute, splitAttributeIndex, splitValue); + } + + private LinkedList decoupleLastNode(PlasticBranch branch) { + PlasticNode lastNode = 
branch.getLast().getNode(); + LinkedList decoupledBranches = new LinkedList<>(); + Successors lastNodeSuccessors = lastNode.getSuccessors(); + + for (SuccessorIdentifier key: lastNodeSuccessors.getKeyset()) { + PlasticBranch branchCopy = branch.copy(); //TODO: Doublecheck if this is fine. + PlasticTreeElement replacedEndElement = new PlasticTreeElement(branchCopy.getLast().getNode(), key); + branchCopy.getBranchRef().removeLast(); + branchCopy.getBranchRef().add(replacedEndElement); + PlasticTreeElement finalElement = new PlasticTreeElement((PlasticNode) lastNodeSuccessors.getSuccessorNode(key), null); + branchCopy.getBranchRef().add(finalElement); + decoupledBranches.add(branchCopy); + } + + return decoupledBranches; + } + + private void splitLastNodeIfRequiredCategorical(PlasticBranch branch, + Attribute swapAttribute, + int swapAttributeIndex, + Double splitValue) { + PlasticNode lastNode = branch.getLast().getNode(); + + boolean shouldBeBinary = splitValue != null; + boolean splitAttributesMatch = lastNode.splitAttribute == swapAttribute; + + if (lastNode.isLeaf() || !splitAttributesMatch) { // Option 1: the split attributes don't match + lastNode.forceSplit( + swapAttribute, + swapAttributeIndex, + splitValue, + shouldBeBinary + ); + if (lastNode.isLeaf()) { + System.out.println("Error"); + } + lastNode.successors.getAllSuccessors().forEach(s -> ((PlasticNode) s).setIsArtificial()); + return; + } + + boolean isBinary = lastNode.successors.isBinary(); + + if (!isBinary && !shouldBeBinary) // Option 2: the split attributes match and the splits are multiway (this is really the best case possible) + // do nothing + return; + + if (isBinary && shouldBeBinary) { // Option 3: the split attributes match and also both splits should be binary + if (splitValue.equals(lastNode.successors.getReferenceValue())) + // do nothing + return; + Successors lastNodeSuccessors = new Successors(lastNode.getSuccessors(), true); + + // Corresponds to transformation 3 in the paper + 
Attribute lastNodeSplitAttribute = lastNode.getSplitAttribute(); + int lastNodeSplitAttributeIndex = lastNode.getSplitAttributeIndex(); + InstanceConditionalTest splitTest = lastNode.getSplitTest(); + lastNode.successors = null; + lastNode.forceSplit( + swapAttribute, + swapAttributeIndex, + splitValue, + shouldBeBinary + ); + PlasticNode defaultSuccessor = (PlasticNode) lastNode.getSuccessors().getSuccessorNode(SuccessorIdentifier.DEFAULT_NOMINAL_VALUE); + defaultSuccessor.transferSplit( + lastNodeSuccessors, lastNodeSplitAttribute, lastNodeSplitAttributeIndex, splitTest + ); + return; + } + + if (!isBinary && shouldBeBinary) { // Option 4: The split attributes match and the current split is multiway while the old one was binary. In this case, we do something similar to the numeric splits + // Corresponds to transformation 2 in the paper + Successors lastNodeSuccessors = new Successors(lastNode.getSuccessors(), true); + Attribute lastNodeSplitAttribute = lastNode.getSplitAttribute(); + int lastNodeSplitAttributeIndex = lastNode.getSplitAttributeIndex(); + InstanceConditionalTest splitTest = lastNode.getSplitTest(); + lastNode.successors = null; + lastNode.forceSplit( + swapAttribute, + swapAttributeIndex, + splitValue, + shouldBeBinary + ); + PlasticNode defaultSuccessor = (PlasticNode) lastNode.getSuccessors().getSuccessorNode(SuccessorIdentifier.DEFAULT_NOMINAL_VALUE); +// if (lastNodeSuccessors.contains(splitValue)) { +// SuccessorIdentifier foundKey = null; +// for (SuccessorIdentifier k: lastNodeSuccessors.getKeyset()) { +// if (k.getSelectorValue() == splitValue) { +// foundKey = k; +// break; +// } +// } +// if (foundKey != null) +// lastNodeSuccessors.removeSuccessor(foundKey); +// } + defaultSuccessor.transferSplit( + lastNodeSuccessors, lastNodeSplitAttribute, lastNodeSplitAttributeIndex, splitTest + ); + return; + } + + if (isBinary && !shouldBeBinary) { // Option 5: The split is binary but should be multiway. 
In this case, we force the split and then use the old subtree of the left branch of the old subtree. + // Corresponds to transformation 1 in the paper + SuccessorIdentifier keyToPreviousSuccessor = new SuccessorIdentifier(false, splitValue, splitValue, false); + PlasticNode previousSuccessor = (PlasticNode) lastNode.getSuccessors().getSuccessorNode(keyToPreviousSuccessor); + + lastNode.successors = null; + lastNode.forceSplit( + swapAttribute, + swapAttributeIndex, + splitValue, + shouldBeBinary + ); + lastNode.successors.removeSuccessor(keyToPreviousSuccessor); + lastNode.successors.addSuccessor(previousSuccessor, keyToPreviousSuccessor); + return; + } + + else if (lastNode.isLeaf()) { + System.out.println("Do nothing"); + } + } + + private void splitLastNodeIfRequiredNumeric(PlasticBranch branch, + Attribute swapAttribute, + int swapAttributeIndex, + Double splitValue) { + assert splitValue != null; + PlasticNode lastNode = branch.getLast().getNode(); + boolean forceSplit = lastNode.isLeaf() || lastNode.splitAttribute != swapAttribute; + + if (forceSplit) { + lastNode.forceSplit( + swapAttribute, + swapAttributeIndex, + splitValue, + true + ); + lastNode.getSuccessors().getAllSuccessors().forEach(s -> ((PlasticNode) s).setIsArtificial()); + return; + } + + Double oldThreshold = lastNode.getSuccessors().getReferenceValue(); + if (splitValue.equals(oldThreshold)) + return; // do nothing + + if (lastNode.getSuccessors().size() > 1) { + updateThreshold(lastNode, swapAttributeIndex, splitValue); + return; + } + + lastNode.forceSplit( + swapAttribute, + swapAttributeIndex, + splitValue, + true + ); + lastNode.getSuccessors().getAllSuccessors().forEach(s -> ((PlasticNode) s).setIsArtificial()); + } + + private void updateThreshold(PlasticNode node, int splitAttributeIndex, double splitValue) { + Double oldThreshold = node.getSuccessors().getReferenceValue(); + + SuccessorIdentifier leftKey = new SuccessorIdentifier(true, oldThreshold, oldThreshold, true); + 
SuccessorIdentifier rightKey = new SuccessorIdentifier(true, oldThreshold, oldThreshold, false); + PlasticNode succ1 = (PlasticNode) node.getSuccessors().getSuccessorNode(leftKey); + PlasticNode succ2 = (PlasticNode) node.getSuccessors().getSuccessorNode(rightKey); + Successors newSuccessors = new Successors(true, true, splitValue); + if (succ1 != null) + newSuccessors.addSuccessorNumeric(splitValue, succ1, true); + if (succ2 != null) + newSuccessors.addSuccessorNumeric(splitValue, succ2, false); + node.successors = newSuccessors; + + if (node.isLeaf()) + return; + + for (SuccessorIdentifier key: node.getSuccessors().getKeyset()) { + PlasticNode s = (PlasticNode) node.getSuccessors().getSuccessorNode(key); + removeUnreachableSubtree(s, splitAttributeIndex, splitValue, key.isLower()); + } + } + + private void removeUnreachableSubtree(PlasticNode node, int splitAttributeIndex, double threshold, boolean isLower) { + if (node.isLeaf()) + return; + + if (node.getSplitAttributeIndex() != splitAttributeIndex) { + for (CustomEFDTNode successor: node.getSuccessors().getAllSuccessors()) { + removeUnreachableSubtree((PlasticNode) successor, splitAttributeIndex, threshold, isLower); + } + return; + } + + Set keysToRemove = new HashSet<>(); + for (SuccessorIdentifier key: node.getSuccessors().getKeyset()) { + assert key.isNumeric(); + if (isLower) { + if (!key.isLower() && key.getSelectorValue() >= threshold) { + keysToRemove.add(key); + } + } + else { + if (key.isLower() && key.getSelectorValue() <= threshold) { + keysToRemove.add(key); + } + } + } + for (SuccessorIdentifier key: keysToRemove) { + node.getSuccessors().removeSuccessor(key); + } + + if (!node.isLeaf()) { + node.getSuccessors().getAllSuccessors().forEach(s -> removeUnreachableSubtree((PlasticNode) s, splitAttributeIndex, threshold, isLower)); + } + } + + private void modifyBranch(PlasticBranch branch, Attribute splitAttribute) { + putLastElementToFront(branch, splitAttribute); + resetSuccessorsInBranch(branch); + 
setRestructuredFlagInBranch(branch); + setDepth(branch); + } + + private void putLastElementToFront(PlasticBranch branch, Attribute splitAttribute) { + PlasticTreeElement oldFirstBranchElement = branch.getBranchRef().getFirst(); + PlasticTreeElement newFirstBranchElement = branch.getBranchRef().remove(branch.getBranchRef().size() - 2); + + branch.getBranchRef().addFirst(newFirstBranchElement); + if (splitAttribute != newFirstBranchElement.getNode().splitAttribute) { + System.out.println(branch.getDescription()); + } + + PlasticNode oldFirstNode = oldFirstBranchElement.getNode(); + PlasticNode newFirstNode = newFirstBranchElement.getNode(); + + //TODO not sure this is actually required! I think it could be sufficient to just change the successors when building the tree in a later step. + newFirstNode.observedClassDistribution = oldFirstNode.observedClassDistribution; + newFirstNode.depth = oldFirstNode.depth; + newFirstNode.attributeObservers = oldFirstNode.attributeObservers; + } + + private void resetSuccessorsInBranch(PlasticBranch branch) { + if (branch.getBranchRef().size() == 1) + return; + int i = 0; + for (PlasticTreeElement item: branch.getBranchRef()) { + if (i == branch.getBranchRef().size() - 1) + break; + item.getNode().successors = new Successors(item.getNode().getSuccessors(), false); + i++; + } + } + + private void setRestructuredFlagInBranch(PlasticBranch branch) { + if (branch.getBranchRef().size() <= 2) + return; + int i = 0; + for (PlasticTreeElement item: branch.getBranchRef()) { + if (i == 0 || i == branch.getBranchRef().size() - 1) { + i++; + continue; + } + item.getNode().setRestructuredFlag(); + i++; + } + } + + private void setDepth(PlasticBranch branch) { + PlasticNode firstNode = branch.getBranchRef().getFirst().getNode(); + int i = 0; + for (PlasticTreeElement item: branch.getBranchRef()) { + item.getNode().depth = firstNode.getDepth() + i; + i++; + } + } +} diff --git 
a/moa/src/main/java/moa/classifiers/trees/plastic_util/MeasuresNumberOfLeaves.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/MeasuresNumberOfLeaves.java new file mode 100644 index 000000000..3873c5812 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/MeasuresNumberOfLeaves.java @@ -0,0 +1,5 @@ +package moa.classifiers.trees.plastic_util; + +public interface MeasuresNumberOfLeaves { + public int getLeafNumber(); +} diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/PerformsTreeRevision.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/PerformsTreeRevision.java new file mode 100644 index 000000000..6e87a7986 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/PerformsTreeRevision.java @@ -0,0 +1,5 @@ +package moa.classifiers.trees.plastic_util; + +public interface PerformsTreeRevision { + boolean didPerformTreeRevision(); +} diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticBranch.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticBranch.java new file mode 100644 index 000000000..00dab0ef4 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticBranch.java @@ -0,0 +1,63 @@ +package moa.classifiers.trees.plastic_util; + +import java.util.ArrayList; +import java.util.LinkedList; + +class PlasticBranch implements Comparable { + private LinkedList branch = new LinkedList<>(); + + public PlasticBranch(){} + + public PlasticBranch(LinkedList branch) { + this.branch = branch; + } + + public LinkedList getBranchRef() { + return branch; + } + + public String getDescription() { + if (branch == null) { + return "Branch is null"; + } + StringBuilder s = new StringBuilder(); + int i = 0; + for (PlasticTreeElement e: branch) { + i++; + s.append(e.getDescription()).append(i == branch.size() ? 
package moa.classifiers.trees.plastic_util;

import com.yahoo.labs.samoa.instances.Attribute;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.core.AttributeSplitSuggestion;
import moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver;
import moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver;
import moa.classifiers.core.conditionaltests.InstanceConditionalTest;
import moa.classifiers.core.conditionaltests.NominalAttributeBinaryTest;
import moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest;
import moa.classifiers.core.splitcriteria.SplitCriterion;
import moa.core.AutoExpandVector;
import moa.core.DoubleVector;

import java.util.*;

/**
 * Decision-tree node used by the PLASTIC classifier.
 *
 * <p>Extends {@link CustomEFDTNode} with the machinery required for split
 * re-evaluation and branch restructuring: forced splits on a given attribute,
 * dummy successors that are pruned after reordering, bookkeeping flags
 * (restructured / artificial / dummy), and a {@link Restructurer} that rebuilds
 * the subtree when a better split attribute wins the re-evaluation.
 */
public class PlasticNode extends CustomEFDTNode {

    // Split attributes used anywhere in this subtree. Only refreshed by
    // collectChildrenSplitAttributes(); may be stale between calls.
    private Set childrenSplitAttributes;
    // True while this node took part in a restructuring; cleared by resetRestructuredFlag().
    private boolean nodeGotRestructured = false;
    // True if this node was created artificially during restructuring (vs. a normal split).
    private boolean isArtificial = false;
    // Performs the actual subtree reordering; shared with copies (see copy constructor).
    private final Restructurer restructurer;
    // Maximum branch length considered during restructuring (PLASTIC hyperparameter).
    protected final int maxBranchLength;
    // Numeric thresholds moved further than this during restructuring trigger a statistics reset.
    protected final double acceptedNumericThresholdDeviation;
    // Placeholder nodes created by forceSplit(); pruned in Restructurer.finalProcessing().
    private boolean isDummy = false;

    /** Marks this node as having taken part in a restructuring. */
    protected void setRestructuredFlag() {
        nodeGotRestructured = true;
    }

    /** Clears the restructured marker. */
    protected void resetRestructuredFlag() {
        nodeGotRestructured = false;
    }

    /** @return true if this node took part in a restructuring since the last reset */
    protected boolean getRestructuredFlag() {
        return nodeGotRestructured;
    }

    /**
     * Creates a new PLASTIC node.
     *
     * Most parameters are forwarded verbatim to {@link CustomEFDTNode}; the
     * PLASTIC-specific ones are {@code maxBranchLength} and
     * {@code acceptedNumericThresholdDeviation} (see field docs above).
     */
    public PlasticNode(
            SplitCriterion splitCriterion, int gracePeriod, Double confidence, Double adaptiveConfidence,
            boolean useAdaptiveConfidence, String leafPrediction, Integer minSamplesReevaluate, Integer depth,
            Integer maxDepth, Double tau, Double tauReevaluate, Double relMinDeltaG, boolean binaryOnly,
            boolean noPrePrune, NominalAttributeClassObserver nominalObserverBlueprint,
            DoubleVector observedClassDistribution, List usedNominalAttributes,
            int maxBranchLength, double acceptedNumericThresholdDeviation, int blockedAttributeIndex
    ) {
        super(splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence, leafPrediction,
                minSamplesReevaluate, depth, maxDepth, tau, tauReevaluate, relMinDeltaG, binaryOnly, noPrePrune,
                nominalObserverBlueprint, observedClassDistribution, usedNominalAttributes, blockedAttributeIndex);
        this.maxBranchLength = maxBranchLength;
        this.acceptedNumericThresholdDeviation = acceptedNumericThresholdDeviation;
        restructurer = new Restructurer(maxBranchLength, acceptedNumericThresholdDeviation);
    }

    /**
     * Copy constructor. Performs a deep copy of the successor structure and the
     * split test, but shares the (stateless) restructurer with {@code other}.
     */
    public PlasticNode(PlasticNode other) {
        super((SplitCriterion) other.splitCriterion.copy(), other.gracePeriod, other.confidence,
                other.adaptiveConfidence, other.useAdaptiveConfidence, other.leafPrediction,
                other.minSamplesReevaluate, other.depth, other.maxDepth, other.tau, other.tauReevaluate,
                other.relMinDeltaG, other.binaryOnly, other.noPrePrune, other.nominalObserverBlueprint,
                (DoubleVector) other.observedClassDistribution.copy(), other.usedNominalAttributes,
                other.blockedAttributeIndex);
        this.acceptedNumericThresholdDeviation = other.acceptedNumericThresholdDeviation;
        this.maxBranchLength = other.maxBranchLength;
        if (other.successors != null)
            this.successors = new Successors(other.successors, true);
        if (other.getSplitTest() != null)
            setSplitTest((InstanceConditionalTest) other.getSplitTest().copy());
        // NOTE(review): this copies THIS node's own infogainSum instead of
        // other.infogainSum -- looks like a bug; confirm intended behavior.
        this.infogainSum = new HashMap<>(infogainSum);
        this.numSplitAttempts = other.numSplitAttempts;
        this.classDistributionAtTimeOfCreation = other.classDistributionAtTimeOfCreation;
        this.nodeTime = other.nodeTime;
        this.splitAttribute = other.splitAttribute;
        this.seenWeight = other.seenWeight;
        this.isArtificial = other.isArtificial;
        if (other.attributeObservers != null)
            this.attributeObservers = (AutoExpandVector) other.attributeObservers.copy();
        restructurer = other.restructurer;
        blockedAttributeIndex = other.blockedAttributeIndex;
    }

    /**
     * In some cases during restructuring, we create dummy nodes that are pruned
     * after restructuring has finished.
     * @return true if the node is a dummy node
     */
    public boolean isDummy() {
        return isDummy;
    }

    /**
     * Creates and attaches a new successor that matches the attribute value of
     * {@code instance} under this node's current split.
     *
     * @param instance the instance whose attribute value selects the new branch
     * @return the newly attached successor, or null if it could not be attached
     *         (e.g. the corresponding branch already exists)
     */
    @Override
    protected PlasticNode addSuccessor(Instance instance) {
        List usedNomAttributes = new ArrayList<>(usedNominalAttributes); //deep copy
        PlasticNode successor = newNode(depth + 1, new DoubleVector(), usedNomAttributes);
        double value = instance.value(splitAttribute);
        if (splitAttribute.isNominal()) {
            if (!successors.isBinary()) {
                boolean success = successors.addSuccessorNominalMultiway(value, successor);
                return success ? successor : null;
            } else {
                boolean success = successors.addSuccessorNominalBinary(value, successor);
                if (!success) // this is the case if the split is binary nominal but the "left" successor exists.
                    success = successors.addDefaultSuccessorNominalBinary(successor);
                return success ? successor : null;
            }
        } else {
            NumericAttributeBinaryTest test = (NumericAttributeBinaryTest) getSplitTest();
            if (successors.lowerIsMissing()) {
                boolean success = successors.addSuccessorNumeric(test.getValue(), successor, true);
                return success ? successor : null;
            } else if (successors.upperIsMissing()) {
                boolean success = successors.addSuccessorNumeric(test.getValue(), successor, false);
                return success ? successor : null;
            }
        }
        // Both numeric branches already exist (or the nominal add failed above).
        return null;
    }

    /**
     * Factory for child nodes; forwards all hyperparameters of this node and
     * blocks re-splitting on this node's own split attribute.
     */
    @Override
    protected PlasticNode newNode(int depth, DoubleVector classDistribution, List usedNominalAttributes) {
        return new PlasticNode(
                splitCriterion, gracePeriod, confidence, adaptiveConfidence, useAdaptiveConfidence,
                leafPrediction, minSamplesReevaluate, depth, maxDepth,
                tau, tauReevaluate, relMinDeltaG, binaryOnly, noPrePrune, nominalObserverBlueprint,
                classDistribution, usedNominalAttributes, maxBranchLength, acceptedNumericThresholdDeviation, getSplitAttributeIndex()
        );
    }

    /**
     * Collect the split attributes of the children and this node.
     * For performance reasons, this updates the `childrenSplitAttributes` property at every node.
     * @return the set of attributes the children and this node split on.
     */
    public Set collectChildrenSplitAttributes() {
        childrenSplitAttributes = new HashSet<>();
        if (isLeaf()) {
            // we have no split attribute
            return childrenSplitAttributes;
        }
        // add the split attribute of this node to the set
        childrenSplitAttributes.add(splitAttribute);
        for (CustomEFDTNode successor : successors.getAllSuccessors()) {
            // add the split attributes of the subtree
            childrenSplitAttributes.addAll(((PlasticNode) successor).collectChildrenSplitAttributes());
        }
        return childrenSplitAttributes;
    }

    /**
     * Simply returns `childrenSplitAttributes`. In doubt, call `collectChildrenSplitAttributes` before accessing the property.
     * @return the set of attributes the children and this node split on.
     */
    public Set getChildrenSplitAttributes() {
        return childrenSplitAttributes;
    }

    /**
     * Transfers the split from another node to this node.
     * @param successors the successors of the other node
     * @param splitAttribute the split attribute of the other node
     * @param splitAttributeIndex the split attribute index of the other node
     *        (NOTE(review): currently unused by this method -- confirm)
     * @param splitTest the split test of the other node
     */
    public void transferSplit(Successors successors,
                              Attribute splitAttribute,
                              int splitAttributeIndex,
                              InstanceConditionalTest splitTest) {
        this.successors = successors;
        this.splitAttribute = splitAttribute;
        setSplitTest(splitTest);
    }

    /**
     * Increases the depth property in all subtree nodes by 1.
     * Used when a new level is inserted above this node during restructuring.
     */
    protected void incrementDepthInSubtree() {
        depth++;
        if (isLeaf())
            return;
        for (SuccessorIdentifier key : successors.getKeyset()) {
            PlasticNode successor = (PlasticNode) successors.getSuccessorNode(key);
            successor.incrementDepthInSubtree();
        }
    }

    /**
     * Replaces every attribute observer with a fresh one of the same kind
     * (nominal vs. numeric), discarding all sufficient statistics.
     */
    protected void resetObservers() {
        AutoExpandVector newObservers = new AutoExpandVector<>();
        for (AttributeClassObserver observer : attributeObservers) {
            if (observer.getClass() == nominalObserverBlueprint.getClass())
                newObservers.add(newNominalClassObserver());
            else
                newObservers.add(newNumericClassObserver());
        }
        attributeObservers = newObservers;
    }

    /**
     * If the node was artificially created during restructuring or if it originated from a 'normal' split
     * @return true if the node was created artificially
     */
    public boolean isArtificial() {
        return isArtificial;
    }

    /** Marks this node as artificially created during restructuring. */
    public void setIsArtificial() {
        isArtificial = true;
    }

    /** Sets the artificial marker explicitly. */
    public void setIsArtificial(boolean val) {
        isArtificial = val;
    }

    /**
     * Forces a split of this node. This is required during restructuring to make sure the branch contains the desired split attribute.
     * See step 3 of the algorithm.
     * If the observer cannot produce a valid split suggestion, artificial (dummy)
     * successors are attached instead; dummies are pruned after restructuring.
     * @param splitAttribute the attribute to split on
     * @param splitAttributeIndex the index of the attribute to split on
     * @param splitValue the value of the split (e.g., for numerical splits or binary-nominal)
     * @param isBinary flag if the split is binary or multiway
     * @return true, if the split was successful
     */
    protected boolean forceSplit(Attribute splitAttribute, int splitAttributeIndex, Double splitValue, boolean isBinary) {
        AttributeClassObserver observer = attributeObservers.get(splitAttributeIndex);
        if (observer == null) {
            // no statistics observed for this attribute yet -- create an empty observer
            observer = splitAttribute.isNominal() ? newNominalClassObserver() : newNumericClassObserver();
            this.attributeObservers.set(splitAttributeIndex, observer);
        }

        boolean success;
        if (splitAttribute.isNominal()) {
            NominalAttributeClassObserver nominalObserver = (NominalAttributeClassObserver) observer;
            AttributeSplitSuggestion suggestion = nominalObserver.forceSplit(
                    splitCriterion, observedClassDistribution.getArrayCopy(), splitAttributeIndex, isBinary, splitValue
            );
            if (suggestion != null) {
                success = makeSplit(splitAttribute, suggestion);
            } else
                success = false;

            if (!success) {
                // Fall back to an artificial split with placeholder successors.
                successors = new Successors(isBinary, splitAttribute.isNumeric(), splitValue);
                this.splitAttribute = splitAttribute;
                setSplitTest(suggestion == null ? null : suggestion.splitTest);

                if (!isBinary) {
                    PlasticNode dummyNode = newNode(depth, new DoubleVector(), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex));
                    SuccessorIdentifier dummyKey = new SuccessorIdentifier(splitAttribute.isNumeric(), 0.0, 0.0, false);
                    success = successors.addSuccessor(dummyNode, dummyKey); // will be pruned later on.
                    dummyNode.isDummy = true;
                    return success;
                } else {
                    // binary nominal: one branch for the split value, one default branch
                    PlasticNode a = newNode(depth + 1, new DoubleVector(), new LinkedList<>(usedNominalAttributes));
                    PlasticNode b = newNode(depth + 1, new DoubleVector(), new LinkedList<>(usedNominalAttributes));
                    SuccessorIdentifier keyA = new SuccessorIdentifier(splitAttribute.isNumeric(), splitValue, splitValue, false);
                    successors.addSuccessor(a, keyA);
                    successors.addDefaultSuccessorNominalBinary(b);
                    return true;
                }
            }
        } else {
            GaussianNumericAttributeClassObserver numericObserver = (GaussianNumericAttributeClassObserver) observer;
            AttributeSplitSuggestion suggestion = numericObserver.forceSplit(
                    splitCriterion, observedClassDistribution.getArrayCopy(), splitAttributeIndex, splitValue
            );
            if (suggestion != null) {
                success = makeSplit(splitAttribute, suggestion);
            } else
                success = false;

            if (!success) {
                successors = new Successors(isBinary, splitAttribute.isNumeric(), splitValue);
                this.splitAttribute = splitAttribute;
                setSplitTest(suggestion == null ? null : suggestion.splitTest);

                // NOTE(review): this loop runs exactly once (i < 1), so only the
                // lower dummy successor is created -- confirm whether the upper
                // branch was intended as well (bound of 2).
                for (int i = 0; i < 1; i++) {
                    PlasticNode dummyNode = newNode(depth, new DoubleVector(), getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex));
                    SuccessorIdentifier dummyKey = new SuccessorIdentifier(splitAttribute.isNumeric(), splitValue, splitValue, i == 0);
                    success = successors.addSuccessor(dummyNode, dummyKey); // will be pruned later on.
                    dummyNode.isDummy = true;
                    if (!success)
                        break;
                }
            }
        }

        return success;
    }

    /**
     * Reevaluate the split and restructure if needed.
     * Uses the Hoeffding bound (plus the tau tie-break) to decide whether the
     * currently best suggestion should replace the active split; if so, tries a
     * PLASTIC reordering first and only re-splits from scratch if that fails.
     * @param instance the current instance
     */
    @Override
    protected void reevaluateSplit(Instance instance) {
        if (isPure())
            return;

        numSplitAttempts++;

        AttributeSplitSuggestion[] bestSuggestions = getBestSplitSuggestions(splitCriterion);
        Arrays.sort(bestSuggestions);
        if (bestSuggestions.length == 0)
            return;
        updateInfogainSum(bestSuggestions);

        // NOTE(review): suggestions are computed and sorted a second time here;
        // presumably redundant with bestSuggestions above -- confirm.
        AttributeSplitSuggestion[] bestSplitSuggestions = getBestSplitSuggestions(splitCriterion);
        Arrays.sort(bestSplitSuggestions);
        AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];

        double bestSuggestionAverageMerit = bestSuggestion.splitTest == null ? 0.0 : bestSuggestion.merit;
        double currentAverageMerit = getCurrentSuggestionAverageMerit(bestSuggestions);
        double deltaG = bestSuggestionAverageMerit - currentAverageMerit;
        double eps = computeHoeffdingBound();

        if (deltaG > eps || (eps < tauReevaluate && deltaG > tauReevaluate * relMinDeltaG)) {

            if (bestSuggestion.splitTest == null) {
                // the null (no-split) suggestion wins: prune this subtree back to a leaf
                System.out.println("preprune - null wins");
                killSubtree();
                resetSplitAttribute();
                return;
            }

            Attribute newSplitAttribute = instance.attribute(bestSuggestion.splitTest.getAttsTestDependsOn()[0]);
            boolean success = false;
            performedTreeRevision = true;
            if (maxBranchLength > 1) {
                success = performReordering(bestSuggestion, newSplitAttribute);
                if (success)
                    setSplitAttribute(bestSuggestion, newSplitAttribute);
            }
            if (!success) {
                // restructuring disabled or failed: perform a plain re-split
                makeSplit(newSplitAttribute, bestSuggestion);
            }
            nodeTime = 0;
            seenWeight = 0.0;
        }
    }

    /**
     * Perform the restructuring and replace the subtree with the restructured subtree
     * @param xBest the suggestion for the best split
     * @param splitAttribute the attribute of the best split
     * @return true if restructuring was successful
     */
    private boolean performReordering(AttributeSplitSuggestion xBest, Attribute splitAttribute) {
        Double splitValue = null;
        InstanceConditionalTest test = xBest.splitTest;
        if (test instanceof NominalAttributeBinaryTest)
            splitValue = ((NominalAttributeBinaryTest) test).getValue();
        else if (test instanceof NumericAttributeBinaryTest)
            splitValue = ((NumericAttributeBinaryTest) test).getValue();

        PlasticNode restructuredNode = restructurer.restructure(this, xBest, splitAttribute, splitValue);

        if (restructuredNode != null)
            successors = restructuredNode.getSuccessors();

        return restructuredNode != null;
    }

    /**
     * Recomputes `usedNominalAttributes` for every node below this one after the
     * subtree's root split changed.
     */
    protected void updateUsedNominalAttributesInSubtree(Attribute splitAttribute, Integer splitAttributeIndex) {
        if (isLeaf())
            return;
        for (CustomEFDTNode successor : successors.getAllSuccessors()) {
            PlasticNode s = (PlasticNode) successor;
            s.usedNominalAttributes = getUsedNominalAttributesForSuccessor(splitAttribute, splitAttributeIndex);
            s.updateUsedNominalAttributesInSubtree(splitAttribute, splitAttributeIndex);
        }
    }

    /**
     * @return the set of majority-class labels over all leaves of this subtree;
     *         a singleton set means every leaf predicts the same class.
     */
    protected Set getMajorityVotesOfLeaves() {
        Set majorityVotes = new HashSet<>();
        if (isLeaf()) {
            if (observedClassDistribution.numValues() == 0)
                return majorityVotes;
            majorityVotes.add((double) argmax(observedClassDistribution.getArrayRef()));
            return majorityVotes;
        }
        for (CustomEFDTNode s : getSuccessors().getAllSuccessors()) {
            majorityVotes.addAll(((PlasticNode) s).getMajorityVotesOfLeaves());
        }
        return majorityVotes;
    }

    /**
     * Overwrites the class statistics of this node and resets the naive-Bayes /
     * majority-class performance counters.
     */
    protected void setObservedClassDistribution(double[] newDistribution) {
        observedClassDistribution = new DoubleVector(newDistribution);
        classDistributionAtTimeOfCreation = new DoubleVector(newDistribution);
        mcCorrectWeight = 0.0;
        nbCorrectWeight = 0.0;
    }

    /** Clears the class statistics of this node (empty distributions, zeroed counters). */
    protected void resetObservedClassDistribution() {
        observedClassDistribution = new DoubleVector();
        classDistributionAtTimeOfCreation = new DoubleVector();
        mcCorrectWeight = 0.0;
        nbCorrectWeight = 0.0;
    }
}
--git a/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticTreeElement.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticTreeElement.java new file mode 100644 index 000000000..16eccecef --- /dev/null +++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticTreeElement.java @@ -0,0 +1,40 @@ +package moa.classifiers.trees.plastic_util; + +class PlasticTreeElement { + private PlasticNode node; + private SuccessorIdentifier key; + + public PlasticTreeElement(PlasticNode node, SuccessorIdentifier key) { + this.node = node; + this.key = key; + } + + public PlasticTreeElement(PlasticTreeElement other) { + this.node = other.node; + this.key = other.key; + } + + public PlasticNode getNode() { + return node; + } + + public SuccessorIdentifier getKey() { + return key; + } + + public String getDescription() { + String blueprint = "%s%s -- %s"; + return String.format(blueprint, + node.splitAttribute != null ? node.splitAttribute.toString() : "L", + node.isArtificial() ? "*" : "", + key != null ? Double.toString(key.getReferencevalue()) : (node.isLeaf() ? "X" : "...") + ); + } + + public PlasticTreeElement copy() { + PlasticNode nodeCpy; + nodeCpy = new PlasticNode(node); + SuccessorIdentifier keyCpy = key != null ? 
package moa.classifiers.trees.plastic_util;

import com.yahoo.labs.samoa.instances.Attribute;
import moa.AbstractMOAObject;
import moa.classifiers.core.AttributeSplitSuggestion;
import moa.classifiers.core.conditionaltests.NominalAttributeBinaryTest;
import moa.classifiers.core.conditionaltests.NominalAttributeMultiwayTest;
import moa.core.DoubleVector;

import java.util.*;

/**
 * Rebuilds a PLASTIC subtree so that a new winning split attribute becomes the
 * root of the subtree, while reusing as much of the existing structure and
 * statistics as possible. The heavy lifting (decomposing the subtree into
 * branches that all start with the desired attribute) is done by MappedTree;
 * this class reassembles those branches and cleans up afterwards.
 */
class Restructurer extends AbstractMOAObject {
    // Maximum branch length considered during decomposition (PLASTIC hyperparameter).
    private final int maxBranchLength;
    // Numeric threshold moves larger than this force a statistics reset in the subtree.
    private final double acceptedThresholdDeviation;

    public Restructurer(int maxBranchLength,
                        double acceptedNumericThresholdDeviation) {
        this.maxBranchLength = maxBranchLength;
        acceptedThresholdDeviation = acceptedNumericThresholdDeviation;
    }

    /**
     * Restructures {@code node} so that {@code splitAttribute} becomes its split.
     *
     * Shortcut cases: if the node already splits on the same attribute with the
     * same (binary-nominal) value, the node is returned unchanged; if only a
     * numeric threshold moved, the threshold is updated in place.
     *
     * @param node the subtree root to restructure
     * @param suggestion the winning split suggestion
     * @param splitAttribute the attribute of the winning split
     * @param splitValue the split value (null for multiway nominal splits)
     * @return the restructured subtree root, or null if restructuring is not possible
     */
    public PlasticNode restructure(PlasticNode node, AttributeSplitSuggestion suggestion, Attribute splitAttribute, Double splitValue) {
        boolean isBinary = !(suggestion.splitTest instanceof NominalAttributeMultiwayTest);
        int splitAttributeIndex = suggestion.splitTest.getAttsTestDependsOn()[0];

        boolean checkSucceeds = checkPreconditions(node, splitAttribute, splitValue, isBinary);

        if (!checkSucceeds)
            return null;

        // NOTE(review): attribute identity is compared with == (reference
        // equality) throughout; this assumes Attribute instances are shared.
        if (splitAttribute == node.splitAttribute && isBinary) {
            assert splitValue != null;
            Double currentNominalBinarysplitValue = node.getSuccessors().getReferenceValue();
            if (currentNominalBinarysplitValue.equals(splitValue))
                return node; // same binary split already in place -- nothing to do
        }

        if (node.splitAttribute.isNumeric() && splitAttribute.isNumeric()) {
            assert splitValue != null;
            Double currentSplitValue = node.getSuccessors().getReferenceValue();
            if (node.splitAttribute == splitAttribute) {
                // same numeric attribute: only the threshold moved
                if (!currentSplitValue.equals(splitValue))
                    updateThreshold(node, splitAttributeIndex, splitValue);
                return node;
            }
        }

//        node.collectChildrenSplitAttributes();
        MappedTree mappedTree = new MappedTree(node, splitAttribute, splitAttributeIndex, splitValue, maxBranchLength);
        PlasticNode newRoot = reassembleTree(mappedTree);

        newRoot.setSplitAttribute(suggestion, splitAttribute);
        newRoot.updateUsedNominalAttributesInSubtree(splitAttribute, splitAttributeIndex);

        // Reset counters in restructured nodes
        newRoot.getSuccessors().getAllSuccessors().forEach(s -> cleanupSubtree((PlasticNode) s));

        // Initialize the statistics of the root's direct successors
        List sortedKeys = new LinkedList<>(newRoot.getSuccessors().getKeyset());
        Collections.sort(sortedKeys);
        for (SuccessorIdentifier key: sortedKeys) {
            PlasticNode successor = (PlasticNode) newRoot.getSuccessors().getSuccessorNode(key);
            if (splitAttribute.isNominal()) {
                double selectorValue = key.getSelectorValue();
                if (selectorValue == SuccessorIdentifier.DEFAULT_NOMINAL_VALUE) {
                    // default branch of a binary nominal split gets the "right" distribution
                    assert isBinary;
                    successor.setObservedClassDistribution(suggestion.resultingClassDistributions[1]);
                }
                else if (selectorValue < suggestion.numSplits()) {
                    successor.setObservedClassDistribution(suggestion.resultingClassDistributionFromSplit((int) selectorValue));
                }
                else {
                    // attribute value never seen by the suggestion: start empty
                    successor.resetObservedClassDistribution();
                }
            }
            else {
                if (key.isLower()) {
                    successor.setObservedClassDistribution(suggestion.resultingClassDistributions[0]);
                }
                else {
                    successor.setObservedClassDistribution(suggestion.resultingClassDistributions[1]);
                }
            }
        }

        finalProcessing(node);
        return newRoot;
    }

    /**
     * Sanity checks before restructuring. Currently only logs situations that
     * should be impossible (re-splitting on the identical nominal test).
     * @return false if restructuring must be aborted (only for leaves)
     */
    private boolean checkPreconditions(PlasticNode node, Attribute splitAttribute, Double splitValue, boolean isBinary) {
        if (node.isLeaf())
            return false;
        if (splitAttribute.isNominal()) {
            if (node.getSplitTest() instanceof NominalAttributeBinaryTest && isBinary) {
                if (
                        ((NominalAttributeBinaryTest) node.getSplitTest()).getValue() == splitValue
                                && splitAttribute == node.splitAttribute
                ) {
                    System.err.println("This should never be triggered. A binary re-split with the same attribute and split value should never happen");
                }
            }
            else if (node.getSplitTest() instanceof NominalAttributeMultiwayTest) {
                if (splitAttribute == node.splitAttribute && !isBinary)
                    System.err.println("This should never be triggered. A multiway re-split on the same nominal attribute should never happen");
            }
        }
        return true;
    }

    /**
     * Moves a numeric split threshold in place: re-keys the two successors to the
     * new threshold, prunes subtree parts that became unreachable, and -- if the
     * threshold moved further than the accepted deviation -- flags the subtree
     * for a statistics reset.
     */
    private void updateThreshold(PlasticNode node, int splitAttributeIndex, double splitValue) {
        Double oldThreshold = node.getSuccessors().getReferenceValue();

        SuccessorIdentifier leftKey = new SuccessorIdentifier(true, oldThreshold, oldThreshold, true);
        SuccessorIdentifier rightKey = new SuccessorIdentifier(true, oldThreshold, oldThreshold, false);
        PlasticNode succ1 = (PlasticNode) node.getSuccessors().getSuccessorNode(leftKey);
        PlasticNode succ2 = (PlasticNode) node.getSuccessors().getSuccessorNode(rightKey);
        Successors newSuccessors = new Successors(true, true, splitValue);
        if (succ1 != null)
            newSuccessors.addSuccessorNumeric(splitValue, succ1, true);
        if (succ2 != null)
            newSuccessors.addSuccessorNumeric(splitValue, succ2, false);
        node.successors = newSuccessors;

        // NOTE(review): if both successors were null the node is now an empty
        // (leaf-like) split and we return here -- confirm this is intended.
        if (node.isLeaf())
            return;

        for (SuccessorIdentifier key: node.getSuccessors().getKeyset()) {
            PlasticNode s = (PlasticNode) node.getSuccessors().getSuccessorNode(key);
            removeUnreachableSubtree(s, splitAttributeIndex, splitValue, key.isLower());
        }

        if (Math.abs(splitValue - oldThreshold) > acceptedThresholdDeviation) {
            setRestructuredFlagInSubtree(node);
        }
    }

    /**
     * Reassembles a tree from an already materialized list of branches.
     * NOTE(review): this overload appears to be unused (the MappedTree overload
     * below is called from restructure) -- confirm before removing.
     */
    private PlasticNode reassembleTree(LinkedList mappedTree) {
        if (mappedTree.size() == 0) {
            System.out.println("MappedTree is empty");
        }

        PlasticNode root = mappedTree.getFirst().getBranchRef().getFirst().getNode();
        for (PlasticBranch branch: mappedTree) {
            PlasticNode currentNode = root;

            int depth = 0;
            for (PlasticTreeElement thisElement: branch.getBranchRef()) {
                if (depth == branch.getBranchRef().size() - 1)
                    break; // the last element is the branch's leaf; nothing to link

                PlasticNode thisNode = thisElement.getNode();
                SuccessorIdentifier thisKey = thisElement.getKey();
                if (currentNode.getSplitAttribute() == thisNode.getSplitAttribute()) {
                    if (currentNode.getSuccessors().contains(thisKey)) {
                        // path already exists from a previous branch: descend
                        currentNode = (PlasticNode) currentNode.getSuccessors().getSuccessorNode(thisKey);
                    }
                    else {
                        // graft the remainder of this branch onto the tree
                        PlasticNode newSuccessor = branch.getBranchRef().get(depth + 1).getNode();
                        boolean success = currentNode.getSuccessors().addSuccessor(newSuccessor, thisKey);
                        assert success;
                        currentNode = newSuccessor;
                    }
                }
                depth++;
            }
        }
        return root;
    }

    /**
     * Reassembles a tree by consuming branches lazily from a MappedTree
     * iterator. The first branch's first node becomes the new root; every
     * further branch is merged into the structure built so far.
     */
    private PlasticNode reassembleTree(MappedTree mappedTree) {
        if (!mappedTree.hasNext()) {
            System.out.println("MappedTree is empty");
        }

        PlasticNode root = null;
        while (mappedTree.hasNext()) {
            PlasticBranch branch = mappedTree.next();
            if (root == null)
                root = branch.getBranchRef().getFirst().getNode();

            PlasticNode currentNode = root;

            int depth = 0;
            for (PlasticTreeElement thisElement: branch.getBranchRef()) {
                if (depth == branch.getBranchRef().size() - 1)
                    break; // the last element is the branch's leaf; nothing to link

                PlasticNode thisNode = thisElement.getNode();
                SuccessorIdentifier thisKey = thisElement.getKey();
                if (currentNode.getSplitAttribute() == thisNode.getSplitAttribute()) {
                    if (currentNode.getSuccessors().contains(thisKey)) {
                        // path already exists from a previous branch: descend
                        currentNode = (PlasticNode) currentNode.getSuccessors().getSuccessorNode(thisKey);
                    }
                    else {
                        // graft the remainder of this branch onto the tree
                        PlasticNode newSuccessor = branch.getBranchRef().get(depth + 1).getNode();
                        boolean success = currentNode.getSuccessors().addSuccessor(newSuccessor, thisKey);
                        assert success;
                        currentNode = newSuccessor;
                    }
                }
                depth++;
            }
        }
        return root;
    }

    /**
     * Resets statistics in every subtree node that was flagged as restructured:
     * observers, weights, timers and (for inner nodes) class distributions.
     */
    private void cleanupSubtree(PlasticNode node) {
        if (!node.getRestructuredFlag())
            return;
        if (!node.isLeaf()) {
            node.resetObservedClassDistribution();
        }
        node.resetObservers();
        node.seenWeight = 0.0;
        node.nodeTime = 0;
        node.numSplitAttempts = 0;
        if (!node.isLeaf())
            node.successors.getAllSuccessors().forEach(s -> cleanupSubtree((PlasticNode) s));
    }

    /**
     * Post-restructuring pass over the OLD subtree root: removes dummy nodes,
     * collapses subtrees that became pure or exceeded the maximum depth back
     * into leaves, and clears the artificial marker.
     */
    private void finalProcessing(PlasticNode node) {
        node.setIsArtificial(false);
        if (node.isLeaf()) {
            return;
        }

        Set keys = new HashSet<>(node.getSuccessors().getKeyset());
        boolean allSuccessorsArePure = true;
        for (SuccessorIdentifier key : keys) {
            PlasticNode thisNode = (PlasticNode) node.getSuccessors().getSuccessorNode(key);
            if (thisNode.isDummy()) {
                node.getSuccessors().removeSuccessor(key);
            }
            if (!thisNode.isPure())
                allSuccessorsArePure = false;
        }

        if (node.isLeaf() || node.depth >= node.maxDepth) {
            // all successors were dummies, or the node is too deep: collapse to a leaf
            node.setObservedClassDistribution(collectStatsFromSuccessors(node).getArrayCopy());
            node.killSubtree();
            node.resetSplitAttribute();
            return;
        }

        if ((allSuccessorsArePure && node.getMajorityVotesOfLeaves().size() <= 1)) {
            // the split no longer separates anything: collapse to a leaf
            node.setObservedClassDistribution(collectStatsFromSuccessors(node).getArrayCopy());
            node.killSubtree();
            node.resetSplitAttribute();
            return;
        }

        for (SuccessorIdentifier key : node.getSuccessors().getKeyset()) {
            PlasticNode successor = (PlasticNode) node.getSuccessors().getSuccessorNode(key);
            finalProcessing(successor);
        }
    }

    /**
     * Sums the class distributions of the node's direct successors (a leaf
     * contributes its own distribution). Used when collapsing a subtree.
     */
    private DoubleVector collectStatsFromSuccessors(CustomEFDTNode node) {
        if (node.isLeaf()) {
            return node.observedClassDistribution;
        }
        else {
            DoubleVector stats = new DoubleVector();
            for (CustomEFDTNode successor : node.getSuccessors().getAllSuccessors()) {
                DoubleVector fromSuccessor = successor.observedClassDistribution; //collectStatsFromSuccessors(successor);
                stats.addValues(fromSuccessor);
            }
            return stats;
        }
    }

    /**
     * After a numeric threshold moved, removes branches that can no longer be
     * reached: inside the {@code isLower} side of the new threshold, successors
     * testing the same attribute on the wrong side of {@code threshold} are dropped.
     */
    private void removeUnreachableSubtree(PlasticNode node, int splitAttributeIndex, double threshold, boolean isLower) {
        if (node.isLeaf())
            return;

        if (node.getSplitAttributeIndex() != splitAttributeIndex) {
            // different attribute: just recurse
            for (CustomEFDTNode successor: node.getSuccessors().getAllSuccessors()) {
                removeUnreachableSubtree((PlasticNode) successor, splitAttributeIndex, threshold, isLower);
            }
            return;
        }

        Set keysToRemove = new HashSet<>();
        for (SuccessorIdentifier key: node.getSuccessors().getKeyset()) {
            assert key.isNumeric();
            if (isLower) {
                if (!key.isLower() && key.getSelectorValue() >= threshold) {
                    keysToRemove.add(key);
                }
            }
            else {
                if (key.isLower() && key.getSelectorValue() <= threshold) {
                    keysToRemove.add(key);
                }
            }
        }
        for (SuccessorIdentifier key: keysToRemove) {
            node.getSuccessors().removeSuccessor(key);
        }

        if (!node.isLeaf()) {
            node.getSuccessors().getAllSuccessors().forEach(s -> removeUnreachableSubtree((PlasticNode) s, splitAttributeIndex, threshold, isLower));
        }
    }

    /** Flags every inner node of the subtree as restructured (leaves are skipped). */
    private void setRestructuredFlagInSubtree(PlasticNode node) {
        if (node.isLeaf())
            return;
        node.setRestructuredFlag();
        node.getSuccessors().getAllSuccessors().forEach(s -> setRestructuredFlagInSubtree((PlasticNode) s));
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {}
}
package moa.classifiers.trees.plastic_util;

import moa.AbstractMOAObject;

import java.util.Objects;

/**
 * Key identifying one successor branch of a split.
 *
 * <p>For numeric splits the key is (referenceValue = threshold, isLower flag);
 * for nominal splits the key is the attribute value itself, with
 * {@link #DEFAULT_NOMINAL_VALUE} marking the "everything else" branch of a
 * binary nominal split.
 *
 * <p>Note: {@code equals(Object)} intentionally also accepts a {@link Double}
 * (a raw attribute value) so that keys can be looked up by value; this breaks
 * symmetry with {@code Double.equals} and is relied upon by
 * {@code Successors.getSuccessorKey}.
 */
class SuccessorIdentifier extends AbstractMOAObject implements Comparable<SuccessorIdentifier> {
    /** Selector value reserved for the default branch of a binary nominal split. */
    public static final double DEFAULT_NOMINAL_VALUE = -1.0;
    private final boolean isNumeric;
    private final Double selectorValue;
    private final Double referenceValue;
    private final boolean isLower;
    // Cached once at construction; safe because all contributing fields are final.
    private final int hashCode;

    /** Copy constructor. */
    public SuccessorIdentifier(SuccessorIdentifier other) {
        this.isLower = other.isLower;
        this.isNumeric = other.isNumeric;
        this.selectorValue = other.selectorValue;
        this.referenceValue = other.referenceValue;
        hashCode = toString().hashCode();
    }

    /**
     * @param isNumeric true for a numeric split branch
     * @param referenceValue the split's reference (threshold / split value); falls back to selectorValue if null
     * @param selectorValue the value selecting this branch
     * @param isLower for numeric splits: true if this is the "&lt;=" branch
     */
    public SuccessorIdentifier(boolean isNumeric, Double referenceValue, Double selectorValue, boolean isLower) {
        this.isNumeric = isNumeric;
        this.isLower = isLower;
        this.selectorValue = selectorValue;

        if (referenceValue == null)
            this.referenceValue = selectorValue;
        else
            this.referenceValue = referenceValue;

        hashCode = toString().hashCode();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (o == null)
            return false;
        if (getClass() == o.getClass()) {
            SuccessorIdentifier that = (SuccessorIdentifier) o;
            boolean equal = isNumeric() == that.isNumeric();
            equal &= selectorValue.equals(that.getSelectorValue());
            equal &= referenceValue.equals(that.getReferencevalue());
            if (isNumeric) {
                equal &= isLower() == that.isLower();
            }
            return equal;
        }
        // FIX: previously the unconditional cast below threw ClassCastException
        // for any type other than SuccessorIdentifier or Double, violating the
        // equals contract (it must return false for unrelated types).
        if (!(o instanceof Double))
            return false;
        Double that = (Double) o;
        if (isNumeric)
            return containsNumericAttribute(that);
        else {
            boolean result = Objects.equals(selectorValue, that);
            if (!result) {
                // use the default successor if the reference value is not `that` and the selector value is the selectorvalue of the default successor
                result = Objects.equals(selectorValue, DEFAULT_NOMINAL_VALUE) && !Objects.equals(referenceValue, that);
            }
            return result;
        }
    }

    @Override
    public int hashCode() {
        return hashCode;
    }

    /**
     * Orders keys of the same split: numeric keys by their isLower flag,
     * nominal keys by selector value with the default branch sorting last.
     * Keys of incomparable kinds compare as equal (0).
     */
    @Override
    public int compareTo(SuccessorIdentifier other) {
        if (isNumeric != other.isNumeric)
            return 0;
        if (isNumeric) {
            if (!Objects.equals(referenceValue, other.referenceValue))
                return 0;
            if (isLower == other.isLower)
                return 0;
            return isLower ? -1 : 1;
        }
        else {
            // FIX: guard both selector values; they can only be null for dummy
            // splits (pruned after reordering), but unboxing a null below would NPE.
            if (selectorValue == null || other.selectorValue == null)
                return 0;
            // FIX: the old code decided the default-branch ordering from
            // `referenceValue == DEFAULT_NOMINAL_VALUE`, which is (almost) never
            // true and made the comparison asymmetric (both orders returned -1),
            // risking "Comparison method violates its general contract". The
            // default branch now consistently sorts last.
            boolean thisIsDefault = selectorValue == DEFAULT_NOMINAL_VALUE;
            boolean otherIsDefault = other.selectorValue == DEFAULT_NOMINAL_VALUE;
            if (thisIsDefault || otherIsDefault) {
                if (thisIsDefault && otherIsDefault)
                    return 0;
                return thisIsDefault ? 1 : -1;
            }
            if (selectorValue.equals(other.selectorValue))
                return 0;
            if (selectorValue < other.selectorValue)
                return -1;
            return 1;
        }
    }

    /** @return the value selecting this branch (may be null for dummy splits) */
    public Double getSelectorValue() {
        return selectorValue;
    }

    /** @return the split's reference value (numeric threshold / nominal split value) */
    public Double getReferencevalue() {
        return referenceValue;
    }

    /** @return true if a nominal branch matches the given attribute value exactly */
    private boolean matchesCategoricalValue(double attValue) {
        if (isNumeric)
            return false;
        return selectorValue == attValue;
    }

    /** @return true if a numeric attribute value falls on this key's side of the threshold */
    private boolean containsNumericAttribute(double attValue) {
        if (!isNumeric)
            return false;
        return isLower ? attValue <= selectorValue : attValue > selectorValue;
    }

    /** @return the sibling key (other side of a numeric threshold), or null for nominal keys */
    public SuccessorIdentifier getOther() {
        if (!isNumeric)
            return null;
        return new SuccessorIdentifier(isNumeric, selectorValue, selectorValue, !isLower);
    }

    public boolean isNumeric() {
        return isNumeric;
    }

    /** @return the isLower flag; always false for nominal keys */
    public boolean isLower() {
        if (!isNumeric)
            return false;
        return isLower;
    }

    public Double getValue() {
        return selectorValue;
    }

    /** Canonical encoding of all identifying fields; also feeds hashCode. */
    public String toString() {
        if (isNumeric) {
            String s = "%b%f%f%b";
            return String.format(s, true, referenceValue, selectorValue, isLower);
        }
        else {
            String s = "%b%f%f";
            return String.format(s, false, referenceValue, selectorValue);
        }
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {

    }
}
package moa.classifiers.trees.plastic_util;

import moa.AbstractMOAObject;

import java.util.Collection;
import java.util.HashMap;
import java.util.Set;

/**
 * Container for the successor branches of a split node.
 *
 * <p>Maintains the mapping {@link SuccessorIdentifier} -> {@link CustomEFDTNode}
 * together with the split's shape (binary vs. multiway, numeric vs. nominal)
 * and its reference value (numeric threshold / binary-nominal split value).
 * The {@code add*} methods enforce the shape invariants and return false
 * instead of throwing when an addition is not permitted.
 */
class Successors extends AbstractMOAObject {
    // Numeric threshold or binary-nominal split value; may start null and is
    // set by the first successful add.
    private Double referenceValue;
    private HashMap<SuccessorIdentifier, CustomEFDTNode> successors = new HashMap<>();

    private final boolean isBinarySplit;
    private final boolean isNumericSplit;

    /**
     * Copy constructor.
     * @param transferNodes if true, the successor map is copied (nodes themselves are shared)
     */
    public Successors(Successors other, boolean transferNodes) {
        isBinarySplit = other.isBinary();
        isNumericSplit = !other.isNominal();
        referenceValue = other.getReferenceValue();
        if (transferNodes) {
            successors = new HashMap<>(other.successors);
        }
    }

    public Successors(boolean isBinarySplit, boolean isNumericSplit, Double splitValue) {
        this.isBinarySplit = isBinarySplit;
        this.isNumericSplit = isNumericSplit;
        this.referenceValue = splitValue;
    }

    /**
     * Adds a successor under an explicit key, enforcing the split shape.
     * @return true if the successor was added
     */
    protected boolean addSuccessor(CustomEFDTNode node, SuccessorIdentifier key) {
        if (node == null)
            return false;
        if (isNumericSplit != key.isNumeric())
            return false;
        if (successors.size() >= 2 && isBinary()) {
            return false;
        }
        if (successors.containsKey(key)) {
            return false;
        }
        successors.put(key, node);
        return true;
    }

    /**
     * Adds one side of a numeric split. The first successful add fixes the
     * threshold ({@code referenceValue}); later adds must use the same value.
     * @return true if the successor was added
     */
    public boolean addSuccessorNumeric(Double attValue, CustomEFDTNode n, boolean isLower) {
        if (n == null)
            return false;
        if (!isNumericSplit)
            return false;
        if (successors.size() >= 2)
            return false;
        if (referenceValue != null && !referenceValue.equals(attValue))
            return false;

        SuccessorIdentifier id = new SuccessorIdentifier(true, attValue, attValue, isLower);
        if (successors.containsKey(id))
            return false;

        referenceValue = attValue;
        successors.put(id, n);
        return true;
    }

    /**
     * Adds the value-branch of a binary nominal split.
     * @return true if the successor was added
     */
    public boolean addSuccessorNominalBinary(Double attValue, CustomEFDTNode n) {
        if (n == null)
            return false;
        if (isNumericSplit)
            return false;
        if (!isBinarySplit)
            return false;
        if (successors.size() >= 2)
            return false;
        if (successors.size() == 1) {
            SuccessorIdentifier key = successors.keySet().iterator().next(); // get key of existing successor
            if (key.getValue() == SuccessorIdentifier.DEFAULT_NOMINAL_VALUE) { // check if the key is the default key for nominal values
                // FIX: referenceValue can still be null here (default successor added
                // first); the old code then threw NPE on referenceValue.equals(...).
                if (referenceValue == null || !referenceValue.equals(attValue))
                    return false; // only add the successor if the split's referenceValue matches the provided value
            }
        }

        SuccessorIdentifier id = new SuccessorIdentifier(false, attValue, attValue, false);
        if (successors.containsKey(id))
            return false;

        referenceValue = attValue;
        successors.put(id, n);
        return true;
    }

    /**
     * Adds the "everything else" branch of a binary nominal split.
     * @return true if the successor was added
     */
    public boolean addDefaultSuccessorNominalBinary(CustomEFDTNode n) {
        if (n == null)
            return false;
        if (isNumericSplit)
            return false;
        if (!isBinarySplit)
            return false;
        if (successors.size() >= 2)
            return false;

        SuccessorIdentifier id = new SuccessorIdentifier(false, referenceValue, SuccessorIdentifier.DEFAULT_NOMINAL_VALUE, false);
        if (successors.containsKey(id))
            return false;

        successors.put(id, n);
        return true;
    }

    /**
     * Adds one branch of a multiway nominal split.
     * @return true if the successor was added
     */
    public boolean addSuccessorNominalMultiway(Double attValue, CustomEFDTNode n) {
        if (n == null)
            return false;
        if (isNumericSplit)
            return false;
        if (isBinarySplit)
            return false;

        SuccessorIdentifier id = new SuccessorIdentifier(false, attValue, attValue, false);
        if (successors.containsKey(id))
            return false;

        successors.put(id, n);
        return true;
    }

    /** @return the successor stored under the exact key, or null */
    public CustomEFDTNode getSuccessorNode(SuccessorIdentifier key) {
        return successors.get(key);
    }

    /**
     * Looks up the successor matching a raw attribute value, relying on
     * SuccessorIdentifier.equals(Double) (threshold / default-branch semantics).
     * @return the matching successor, or null
     */
    public CustomEFDTNode getSuccessorNode(Double attributeValue) {
        for (SuccessorIdentifier s : successors.keySet()) {
            if (s.equals(attributeValue))
                return successors.get(s);
        }
        return null;
    }

    /**
     * Finds the stored key matching {@code key} (a SuccessorIdentifier or a raw
     * Double attribute value).
     * A linear scan is required because the match uses SuccessorIdentifier's
     * asymmetric equals(Double), which a hash lookup cannot honor.
     * @return the stored key, or null
     */
    public SuccessorIdentifier getSuccessorKey(Object key) {
        for (SuccessorIdentifier successorKey : successors.keySet()) {
            if (successorKey.equals(key)) {
                return successorKey;
            }
        }
        return null;
    }

    public boolean isNominal() {
        return !isNumericSplit;
    }

    public boolean isBinary() {
        return isBinarySplit;
    }

    /** @return true if {@code key} is a SuccessorIdentifier present in the map */
    public boolean contains(Object key) {
        // FIX: the previous unconditional cast threw ClassCastException for
        // foreign key types; a non-SuccessorIdentifier can never be contained.
        return key instanceof SuccessorIdentifier && successors.containsKey(key);
    }

    public Double getReferenceValue() {
        return referenceValue;
    }

    public int size() {
        return successors.size();
    }

    /**
     * For a numeric split with exactly one branch, returns the key of the
     * missing sibling branch.
     * @return the missing key, or null if zero or both branches exist (or the split is nominal)
     */
    public SuccessorIdentifier getMissingKey() {
        // FIX: the previous check (size() > 1) let an EMPTY map fall through to
        // toArray()[0], throwing ArrayIndexOutOfBoundsException.
        if (successors.size() != 1)
            return null;
        SuccessorIdentifier someKey = successors.keySet().iterator().next();
        return someKey.getOther(); // null for nominal splits
    }

    /** @return true if exactly the lower numeric branch is missing */
    public boolean lowerIsMissing() {
        SuccessorIdentifier key = getMissingKey();
        if (key == null)
            return false;
        return key.isLower();
    }

    /** @return true if exactly the upper numeric branch is missing */
    public boolean upperIsMissing() {
        SuccessorIdentifier key = getMissingKey();
        if (key == null)
            return false;
        return !key.isLower();
    }

    /** Re-keys all numeric successors to a new threshold, keeping their isLower sides. */
    public void adjustThreshold(double newThreshold) {
        HashMap<SuccessorIdentifier, CustomEFDTNode> newSuccessors = new HashMap<>();
        for (SuccessorIdentifier oldId: successors.keySet()) {
            SuccessorIdentifier newId = new SuccessorIdentifier(true, newThreshold, newThreshold, oldId.isLower());
            newSuccessors.put(newId, successors.get(oldId));
        }
        referenceValue = newThreshold;
        successors = newSuccessors;
    }

    /** @return a live view of all successor nodes */
    public Collection<CustomEFDTNode> getAllSuccessors() {
        return successors.values();
    }

    /** @return a live view of all successor keys */
    public Set<SuccessorIdentifier> getKeyset() {
        return successors.keySet();
    }

    /** Puts (or overwrites) a successor without any shape checks. Restructuring only. */
    protected void forceSuccessorForKey(SuccessorIdentifier key, CustomEFDTNode node) {
        successors.put(key, node);
    }

    /** Removes and returns the successor stored under {@code key}, or null. */
    protected CustomEFDTNode removeSuccessor(SuccessorIdentifier key) {
        return successors.remove(key);
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {

    }
}
remove '.DS_Store' --- .DS_Store | Bin 6148 -> 0 bytes .gitignore | 2 ++ moa/.DS_Store | Bin 6148 -> 0 bytes moa/src/.DS_Store | Bin 6148 -> 0 bytes moa/src/main/.DS_Store | Bin 6148 -> 0 bytes moa/src/main/java/.DS_Store | Bin 6148 -> 0 bytes moa/src/main/java/moa/.DS_Store | Bin 6148 -> 0 bytes moa/src/main/java/moa/classifiers/.DS_Store | Bin 6148 -> 0 bytes moa/src/main/java/moa/classifiers/meta/.DS_Store | Bin 6148 -> 0 bytes .../java/moa/classifiers/meta/AutoML/.DS_Store | Bin 6148 -> 0 bytes 10 files changed, 2 insertions(+) delete mode 100644 .DS_Store delete mode 100644 moa/.DS_Store delete mode 100644 moa/src/.DS_Store delete mode 100644 moa/src/main/.DS_Store delete mode 100644 moa/src/main/java/.DS_Store delete mode 100644 moa/src/main/java/moa/.DS_Store delete mode 100644 moa/src/main/java/moa/classifiers/.DS_Store delete mode 100644 moa/src/main/java/moa/classifiers/meta/.DS_Store delete mode 100644 moa/src/main/java/moa/classifiers/meta/AutoML/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index a361f1af412f356a6f96409579fabdc215d28871..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK%}T>S5Z-O8O({YS3VK`cS}?7mEnY&bFJMFuDm5WRgK4%jtvQrJ?)pN$h|lB9 z?glL8EMjM1_nY6{><8H&#u)b&QI|2BF=jzSJW~a6G=`M>+c=G6?#V?OWvbTK0jp)TM)tvS*&UqqJL0(StU6*j=s6v6 zdeU31TGrm-(fQTrC7Go1P4mftZY4VgOLzyxEa%mqrHM?Rz*AmunFpinIe7tZqJNSi4XWZ3DJuyHGtTRy8rj6(SCHyj#kNou# zvWNj<;GZ$T8$*BS!J^FB`eS)`)(U8k&`>b1Km`Q!wMzgPxR30sppFaFAS5Z-O8O({YS3VK`cS}?7mEnY&bFJMFuDm5`hgE3p0)*MP9cYPsW#OHBl zcLNr47O^w1`_1oe_JiyXV~l(AsLPnm7_*=ua#Sh=-Ibx5Nk-&2Mm7l(8G`i@PE72t z1Acp*Wo*hG!uH#Rn#R?})*cm9*i{aG-ZXI?P9LF-b=BrNqHyo$&9 z$lgAa=`4uT@l+MW(HK&0uj4e5xhLmol&M-@2dtLW8rge`MR#!0?})>`v+Rh)pyzbN z@lkKNY*{<|2d5XK=j0`oZ<Q5`*}*SVI^(WJ>WKkjV3mQoHf=or&*7J;eB`f| zkVOm-1OJQx-Wd8r4;E$4)*s8mvsOU6hlYZA1u7t*uUrDazfwk#I#3d-k#p8au}sPqARMShV-NDL4I!~ij{ zwhY)KLDbfk0a`dQKn(oE0PYVG8lr14H>kG`==84kzE}$afv$QxfXMS UI12i8IUrpGG$GUx1HZt)2XazMhyVZp diff --git 
a/moa/src/main/.DS_Store b/moa/src/main/.DS_Store deleted file mode 100644 index 11fd813eb4998949bc63e3738d6f353653be2d1a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK%}T>S5Z<-XrW7Fu1-&hJEtpo(7B8XJ7cim+m70*E(U>hw+8#^Bqp z>ww>Gu}2oL2@Ae{e>hCyEO*^^zE-i;H>!5kuG_c%gPi(#kdL!&FuF$TLdqm8^&q^A z2h+Z@bt=<5h||GH6~xg1Qf{u|G?LS<9H&vHYJDBBYj&;g?9OJ*&T+dT4%+U#A!ePH z+Ym>Gt@*rWZ}07&ocEuS=TyFEMmZ3!WYb^?ub_M`XyT93M5g!PEAz`dLSldzAO?tm zm1V#j33h#D8K8v|1H`~j4B-ACpdoq&3yo^)fDW(E7`G5nK*zTPqA=(gEHpv{gzHj3 zUCPZ9gX?ne3zO#=EHvtJ#?{O)j+wc9yl^!;_=QSm+|x)sF+dC~GEmW`gXjMl{4z@) z`HLlF5d*})KVyJfJ-^q5MVYhp+w$^d6KoLS6G4Klvd;se>N`e3Y diff --git a/moa/src/main/java/.DS_Store b/moa/src/main/java/.DS_Store deleted file mode 100644 index a361f1af412f356a6f96409579fabdc215d28871..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK%}T>S5Z-O8O({YS3VK`cS}?7mEnY&bFJMFuDm5WRgK4%jtvQrJ?)pN$h|lB9 z?glL8EMjM1_nY6{><8H&#u)b&QI|2BF=jzSJW~a6G=`M>+c=G6?#V?OWvbTK0jp)TM)tvS*&UqqJL0(StU6*j=s6v6 zdeU31TGrm-(fQTrC7Go1P4mftZY4VgOLzyxEa%mqrHM?Rz*AmunFpinIe7tZqJNSi4XWZ3DJuyHGtTRy8rj6(SCHyj#kNou# zvWNj<;GZ$T8$*BS!J^FB`eS)`)(U8k&`>b1Km`Q!wMzgPxR30sppFaFAmd~S83}}NkJ^Xf zLNq)6BLnpA%5Vn}_>kc5`~4+Bnhb(S1sL%j45Ba@)oLFimoKcX7oDP0c5b{!m3kw8 zG)~(7@QP;VN=3oU_JfPCpLQ#oCn_HKVcZ|;gs|7gkn77Z?y0n`#&Ivvxt>{YN=~U; z*_ll0t)pgD?l)`Gs+_bMwW>TkXiTRiXKQ!w_^kUBJ;&I+s6k$>kO{%b03}MpIFKwJ_F*9h=LFk?FId*4ZZzw|Vj((}bLAVBa zWCoalc?R-kTA}`by8Qk>pTsj}fEidR21KFbb=p{x?X3&NQLmM#x2Pl(ml^y@K|{A< gjHOn*jj9Fxk_<%GVrCFMDEuLyY2bkw_)`Wx0mQOS?EnA( diff --git a/moa/src/main/java/moa/classifiers/.DS_Store b/moa/src/main/java/moa/classifiers/.DS_Store deleted file mode 100644 index 11c33a3d2a008291ef4ab231397b4bad43844b67..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK-AcnS6i&A3GKSC#1-%P+JFuIg8{U*TU%-l9sLYlPEq2Y=I(sn&z1J7=MSLF5 zNm6k*Z$;cWkbLJiX+CIv7-PIU8#Ebn7-Iq&B1dI~pnGMgWRnp&juB?laTu!*>^B?x 
z>ww>0VIwwVDGR=Se>je#Y3jP~yi~4iY*s~8)Wx0us4_nd(%G~dOm5J+R4NV&JqWL& zVb*uH&s35IQ8Jw9f@m;=l-uhl8K|tQX31cxYkdfwk#I#3d-k#p8ZJ@tKkG`=)WEM|O44#3kyG=UL1R U;wb3X<$!b%(1cJ&4EzEEAIz{yi~s-t diff --git a/moa/src/main/java/moa/classifiers/meta/.DS_Store b/moa/src/main/java/moa/classifiers/meta/.DS_Store deleted file mode 100644 index 76d826fc257b1b5bee5cc826571beff1b9796070..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK%}xR_5N-jXgqU#9L~l*Jk_c)P;$^}0rWe=fK@GAlfsMi_%4v(C6E&hi zzt-qo$lj?CfjhMu?wQ{l*7GX|B8(b-*zBnU-)ln1<+&evVptP{(Ce#QM-Q+p%hvPj zqfx1{S1$1Fa&cVXqsmsXz;}1H#^Wqo-Pqhesow^jP~2-KDSUKFmJLo}0>;dyf>z17 z5uNTd;Ny5l*tYp-+0)B}w(Q61e2n*FOt=00I(JSl3`E$1xu?(m?vc<$29N<{U{MU1 zBZ%b|wIHs83?Ku4h5>m#a8QYs!C0d@I-pS{0ALE-O2C%2gv{XvErYQ}7y+R=6;P*A zQ({n^4t{RpEQ7H|oldAJKB!rlnhJ%g)nR_F!U?rBVv7tQ1G5YybvG~X|HH5A|Jfw$ zAp^+3zhZ!f#&a?oo~DKakB_?7~Ox{4uIUd3ymO2E&h0caVFHG&5O Oe*_c_*dPN7W#AQWt7v!t diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/.DS_Store b/moa/src/main/java/moa/classifiers/meta/AutoML/.DS_Store deleted file mode 100644 index 5008ddfcf53c02e82d7eee2e57c38e5672ef89f6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0 Date: Fri, 3 Oct 2025 13:08:06 +1300 Subject: [PATCH 30/31] fix compiler errors --- moa/src/main/java/moa/classifiers/trees/EFDT.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/moa/src/main/java/moa/classifiers/trees/EFDT.java b/moa/src/main/java/moa/classifiers/trees/EFDT.java index 314c6adfc..1d9a33c1e 100644 --- a/moa/src/main/java/moa/classifiers/trees/EFDT.java +++ b/moa/src/main/java/moa/classifiers/trees/EFDT.java @@ -194,7 +194,7 @@ public int calcByteSize() { } @Override - public int measureByteSize() { + public long measureByteSize() { return calcByteSize(); } @@ -350,7 
+350,7 @@ public void estimateModelByteSizes() { this.inactiveLeafByteSizeEstimate = (double) totalInactiveSize / this.inactiveLeafNodeCount; } - int actualModelSize = this.measureByteSize(); + long actualModelSize = this.measureByteSize(); double estimatedModelSize = (this.activeLeafNodeCount * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount * this.inactiveLeafByteSizeEstimate); From e7900a26df8c64827378e65d5560b68f52967918 Mon Sep 17 00:00:00 2001 From: Anton Lee Date: Fri, 3 Oct 2025 13:05:35 +1300 Subject: [PATCH 31/31] fix storePredictions and storeY --- .github/workflows/capymoa.yml | 35 +++++ .gitignore | 28 ++++ README.md | 2 + .../evaluation/EfficientEvaluationLoops.java | 131 ++++++++++++------ 4 files changed, 150 insertions(+), 46 deletions(-) create mode 100644 .github/workflows/capymoa.yml diff --git a/.github/workflows/capymoa.yml b/.github/workflows/capymoa.yml new file mode 100644 index 000000000..58df5b7f6 --- /dev/null +++ b/.github/workflows/capymoa.yml @@ -0,0 +1,35 @@ +name: Package Jar for CapyMOA + +on: + push: + branches: [ master, capymoa ] + pull_request: + branches: [ master, capymoa ] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: maven + + - name: Build with Maven + working-directory: ./moa + # no tests + run: mvn -B package --file pom.xml -DskipTests + + # Upload jar file as artifact + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: moa-jar + path: ./moa/target/moa-*-jar-with-dependencies.jar + if-no-files-found: error + retention-days: 7 diff --git a/.gitignore b/.gitignore index 07110325e..df512d98c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,34 @@ *.iml *~ *.bak +.DS_Store +.settings +.project # MacOS folder metadata .DS_Store + +# Compiled class file +*.class + +# Log file +*.log + +# BlueJ files +*.ctxt + +# Mobile Tools for 
Java (J2ME) +.mtj.tmp/ + +# Package Files # +*.jar +*.war +*.nar +*.ear +*.zip +*.tar.gz +*.rar + +# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml +hs_err_pid* +replay_pid* diff --git a/README.md b/README.md index c80de92c9..740d525de 100755 --- a/README.md +++ b/README.md @@ -34,6 +34,8 @@ If you want to refer to MOA in a publication, please cite the following JMLR pap ## Building MOA for CapyMOA +You can now upload the build artifact from the GitHub Actions workflow. **Make sure to unzip it!** + > These steps assume you have Java installed and maven installed. If you don't > have maven installed, you can download it from > [here](https://maven.apache.org/download.cgi). You can achieve the same diff --git a/moa/src/main/java/moa/evaluation/EfficientEvaluationLoops.java b/moa/src/main/java/moa/evaluation/EfficientEvaluationLoops.java index 719c18ead..73a1b405c 100644 --- a/moa/src/main/java/moa/evaluation/EfficientEvaluationLoops.java +++ b/moa/src/main/java/moa/evaluation/EfficientEvaluationLoops.java @@ -4,8 +4,8 @@ import moa.classifiers.SemiSupervisedLearner; import moa.classifiers.semisupervised.ClusterAndLabelClassifier; import moa.core.Example; -import moa.core.InstanceExample; import moa.core.Measurement; +import moa.core.Utils; import moa.learners.Learner; import moa.streams.ArffFileStream; import moa.streams.ExampleStream; @@ -26,30 +26,39 @@ public class EfficientEvaluationLoops { public static class PrequentialResult { public ArrayList windowedResults; public double[] cumulativeResults; - public ArrayList targets; - public ArrayList predictions; - + public ArrayList targets; + public ArrayList predictions; public HashMap otherMeasurements; - public PrequentialResult(ArrayList windowedResults, double[] cumulativeResults) { - this.windowedResults = windowedResults; - this.cumulativeResults = cumulativeResults; - this.targets = null; - this.predictions = null; - } - - public PrequentialResult(ArrayList 
windowedResults, double[] cumulativeResults, - ArrayList targets, ArrayList predictions) { + public PrequentialResult( + ArrayList windowedResults, + double[] cumulativeResults, + ArrayList targets, + ArrayList predictions, + HashMap otherMeasurements + ) { this.windowedResults = windowedResults; this.cumulativeResults = cumulativeResults; this.targets = targets; this.predictions = predictions; + this.otherMeasurements = otherMeasurements; } - public PrequentialResult(ArrayList windowedResults, double[] cumulativeResults, - HashMap otherMeasurements) { - this(windowedResults, cumulativeResults); - this.otherMeasurements = otherMeasurements; + public PrequentialResult( + ArrayList windowedResults, + double[] cumulativeResults, + ArrayList targets, + ArrayList predictions + ) { + this(windowedResults, cumulativeResults, targets, predictions, null); + } + + public PrequentialResult( + ArrayList windowedResults, + double[] cumulativeResults, + HashMap otherMeasurements + ) { + this(windowedResults, cumulativeResults, null, null, otherMeasurements); } } @@ -65,11 +74,13 @@ public PrequentialResult(ArrayList windowedResults, double[] cumulativ * @param windowedEvaluator * @param maxInstances * @param windowSize + * @param storeY + * @param storePredictions * @return the return has to be an ArrayList because we don't know ahead of time how many windows will be produced */ public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Learner learner, - LearningPerformanceEvaluator basicEvaluator, - LearningPerformanceEvaluator windowedEvaluator, + LearningPerformanceEvaluator> basicEvaluator, + LearningPerformanceEvaluator> windowedEvaluator, long maxInstances, long windowSize, boolean storeY, boolean storePredictions) { int instancesProcessed = 0; @@ -78,30 +89,31 @@ public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Lear stream.restart(); ArrayList windowed_results = new ArrayList<>(); - ArrayList targetValues = new ArrayList<>(); - 
ArrayList predictions = new ArrayList<>(); + ArrayList targetValues = new ArrayList<>(); + ArrayList predictions = new ArrayList<>(); while (stream.hasMoreInstances() && (maxInstances == -1 || instancesProcessed < maxInstances)) { Example instance = stream.nextInstance(); - if (storeY) - targetValues.add(instance.getData().classValue()); double[] prediction = learner.getVotesForInstance(instance); + + // Update evaluators and store predictions if requested if (basicEvaluator != null) basicEvaluator.addResult(instance, prediction); if (windowedEvaluator != null) windowedEvaluator.addResult(instance, prediction); - if (storePredictions) - predictions.add(prediction.length == 0? 0 : prediction[0]); + predictions.add(Utils.maxIndex(prediction)); + if (storeY) + targetValues.add((int)Math.round(instance.getData().classValue())); learner.trainOnInstance(instance); - instancesProcessed++; + // Store windowed results if requested if (windowedEvaluator != null) if (instancesProcessed % windowSize == 0) { Measurement[] measurements = windowedEvaluator.getPerformanceMeasurements(); @@ -128,22 +140,30 @@ public static PrequentialResult PrequentialEvaluation(ExampleStream stream, Lear for (int i = 0; i < cumulative_results.length; ++i) cumulative_results[i] = measurements[i].getValue(); } - if (!storePredictions && !storeY) - return new PrequentialResult(windowed_results, cumulative_results); - else - return new PrequentialResult(windowed_results, cumulative_results, targetValues, predictions); + + return new PrequentialResult( + windowed_results, + cumulative_results, + targetValues, + predictions + ); } - public static PrequentialResult PrequentialSSLEvaluation(ExampleStream stream, Learner learner, - LearningPerformanceEvaluator basicEvaluator, - LearningPerformanceEvaluator windowedEvaluator, - long maxInstances, - long windowSize, - long initialWindowSize, - long delayLength, - double labelProbability, - int randomSeed, - boolean debugPseudoLabels) { + public static 
PrequentialResult PrequentialSSLEvaluation( + ExampleStream> stream, + Learner learner, + LearningPerformanceEvaluator basicEvaluator, + LearningPerformanceEvaluator windowedEvaluator, + long maxInstances, + long windowSize, + long initialWindowSize, + long delayLength, + double labelProbability, + int randomSeed, + boolean debugPseudoLabels, + boolean storeY, + boolean storePredictions + ) { // int delayLength = this.delayLengthOption.getValue(); // double labelProbability = this.labelProbabilityOption.getValue(); @@ -161,11 +181,13 @@ public static PrequentialResult PrequentialSSLEvaluation(ExampleStream stream, L ArrayList windowed_results = new ArrayList<>(); + ArrayList targetValues = new ArrayList<>(); + ArrayList predictions = new ArrayList<>(); HashMap other_measures = new HashMap<>(); // The buffer is a list of tuples. The first element is the index when // it should be emitted. The second element is the instance itself. - List> delayBuffer = new ArrayList>(); + List>> delayBuffer = new ArrayList>>(); while (stream.hasMoreInstances() && (maxInstances == -1 || instancesProcessed < maxInstances)) { @@ -178,8 +200,8 @@ public static PrequentialResult PrequentialSSLEvaluation(ExampleStream stream, L learner.trainOnInstance(delayedExample); } - Example instance = stream.nextInstance(); - Example unlabeledExample = instance.copy(); + Example instance = stream.nextInstance(); + Example unlabeledExample = instance.copy(); int trueClass = (int) ((Instance) instance.getData()).classValue(); // In case it is set, then the label is not removed. 
We want to pass the @@ -218,6 +240,10 @@ public static PrequentialResult PrequentialSSLEvaluation(ExampleStream stream, L basicEvaluator.addResult(instance, prediction); if (windowedEvaluator != null) windowedEvaluator.addResult(instance, prediction); + if (storeY) + targetValues.add((int)Math.round(instance.getData().classValue())); + if (storePredictions) + predictions.add(Utils.maxIndex(prediction)); int pseudoLabel = -1; // TRAIN @@ -227,7 +253,7 @@ public static PrequentialResult PrequentialSSLEvaluation(ExampleStream stream, L // System.out.println("[TRAIN_UNLABELED][DELAYED] " + unlabeledExample.getData().toString()); pseudoLabel = ((SemiSupervisedLearner) learner).trainOnUnlabeledInstance((Instance) unlabeledExample.getData()); } - delayBuffer.add(new MutablePair(1 + instancesProcessed + delayLength, instance)); + delayBuffer.add(new MutablePair<>(1 + instancesProcessed + delayLength, instance)); } else if (is_labeled) { // System.out.println("[TRAIN] " + instance.getData().toString()); // The instance will be labeled and is not delayed e.g delayLength = -1 @@ -276,7 +302,15 @@ public static PrequentialResult PrequentialSSLEvaluation(ExampleStream stream, L other_measures.put("num_correct_pseudo_labeled", (double) numCorrectPseudoLabeled); other_measures.put("num_instances_tested", (double) numInstancesTested); other_measures.put("pseudo_label_accuracy", (double) numCorrectPseudoLabeled/numInstancesTested); - return new PrequentialResult(windowed_results, cumulative_results, other_measures); + + + return new PrequentialResult( + windowed_results, + cumulative_results, + targetValues, + predictions, + other_measures + ); } /******************************************************************************************************************/ @@ -320,7 +354,12 @@ private static void testPrequentialSSL(String file_path, Learner learner, windowSize, initialWindowSize, delayLength, - labelProbability, 1, true); + labelProbability, + 1, + true, + false, + false + ); 
// Record the end time long endTime = System.currentTimeMillis();