Q5`*}*SVI^(WJ>WKkjV3mQoHf=or&*7J;eB`f|
zkVOm-1OJQx-Wd8r4;E$4)*s8mvsOU6hlYZA1u7t*uUrDazfwk#I#3d-k#p8au}sPqARMShV-NDL4I!~ij{
zwhY)KLDbfk0a`dQKn(oE0PYVG8lr14H>kG`==84kzE}$afv$QxfXMS
UI12i8IUrpGG$GUx1HZt)2XazMhyVZp
literal 0
HcmV?d00001
diff --git a/moa/src/main/.DS_Store b/moa/src/main/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..11fd813eb4998949bc63e3738d6f353653be2d1a
GIT binary patch
literal 6148
zcmeHK%}T>S5Z<-XrW7Fu1-&hJEtpo(7B8XJ7cim+m70*E(U>hw+8#^Bqp
z>ww>Gu}2oL2@Ae{e>hCyEO*^^zE-i;H>!5kuG_c%gPi(#kdL!&FuF$TLdqm8^&q^A
z2h+Z@bt=<5h||GH6~xg1Qf{u|G?LS<9H&vHYJDBBYj&;g?9OJ*&T+dT4%+U#A!ePH
z+Ym>Gt@*rWZ}07&ocEuS=TyFEMmZ3!WYb^?ub_M`XyT93M5g!PEAz`dLSldzAO?tm
zm1V#j33h#D8K8v|1H`~j4B-ACpdoq&3yo^)fDW(E7`G5nK*zTPqA=(gEHpv{gzHj3
zUCPZ9gX?ne3zO#=EHvtJ#?{O)j+wc9yl^!;_=QSm+|x)sF+dC~GEmW`gXjMl{4z@)
z`HLlF5d*})KVyJfJ-^q5MVYhp+w$^d6KoLS6G4Klvd;se>N`e3Y
literal 0
HcmV?d00001
diff --git a/moa/src/main/java/.DS_Store b/moa/src/main/java/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..a361f1af412f356a6f96409579fabdc215d28871
GIT binary patch
literal 6148
zcmeHK%}T>S5Z-O8O({YS3VK`cS}?7mEnY&bFJMFuDm5WRgK4%jtvQrJ?)pN$h|lB9
z?glL8EMjM1_nY6{><8Hu)b&QI|2BF=jzSJW~a6G=`M>+c=G6?#V?OWvbTK0jp)TM)tvS*&UqqJL0(StU6*j=s6v6
zdeU31TGrm-(fQTrC7Go1P4mftZY4VgOLzyxEa%mqrHM?Rz*AmunFpinIe7tZqJNSi4XWZ3DJuyHGtTRy8rj6(SCHyj#kNou#
zvWNj<;GZ$T8$*BS!J^FB`eS)`)(U8k&`>b1Km`Q!wMzgPxR30sppFaFAmd~S83}}NkJ^Xf
zLNq)6BLnpA%5Vn}_>kc5`~4+Bnhb(S1sL%j45Ba@)oLFimoKcX7oDP0c5b{!m3kw8
zG)~(7@QP;VN=3oU_JfPCpLQ#oCn_HKVcZ|;gs|7gkn77Z?y0n`#&Ivvxt>{YN=~U;
z*_ll0t)pgD?l)`Gs+_bMwW>TkXiTRiXKQ!w_^kUBJ;&I+s6k$>kO{%b03}MpIFKwJ_F*9h=LFk?FId*4ZZzw|Vj((}bLAVBa
zWCoalc?R-kTA}`by8Qk>pTsj}fEidR21KFbb=p{x?X3&NQLmM#x2Pl(ml^y@K|{A<
gjHOn*jj9Fxk_<%GVrCFMDEuLyY2bkw_)`Wx0mQOS?EnA(
literal 0
HcmV?d00001
diff --git a/moa/src/main/java/moa/classifiers/.DS_Store b/moa/src/main/java/moa/classifiers/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..11c33a3d2a008291ef4ab231397b4bad43844b67
GIT binary patch
literal 6148
zcmeHK-AcnS6i&A3GKSC#1-%P+JFuIg8{U*TU%-l9sLYlPEq2Y=I(sn&z1J7=MSLF5
zNm6k*Z$;cWkbLJiX+CIv7-PIU8#Ebn7-Iq&B1dI~pnGMgWRnp&juB?laTu!*>^B?x
z>ww>0VIwwVDGR=Se>je#Y3jP~yi~4iY*s~8)Wx0us4_nd(%G~dOm5J+R4NV&JqWL&
zVb*uH&s35IQ8Jw9f@m;=l-uhl8K|tQX31cxYkdfwk#I#3d-k#p8ZJ@tKkG`=)WEM|O44#3kyG=UL1R
U;wb3X<$!b%(1cJ&4EzEEAIz{yi~s-t
literal 0
HcmV?d00001
diff --git a/moa/src/main/java/moa/classifiers/meta/.DS_Store b/moa/src/main/java/moa/classifiers/meta/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..76d826fc257b1b5bee5cc826571beff1b9796070
GIT binary patch
literal 6148
zcmeHK%}xR_5N-jXgqU#9L~l*Jk_c)P;$^}0rWe=fK@GAlfsMi_%4v(C6E&hi
zzt-qo$lj?CfjhMu?wQ{l*7GX|B8(b-*zBnU-)ln1<+&evVptP{(Ce#QM-Q+p%hvPj
zqfx1{S1$1Fa&cVXqsmsXz;}1H#^Wqo-Pqhesow^jP~2-KDSUKFmJLo}0>;dyf>z17
z5uNTd;Ny5l*tYp-+0)B}w(Q61e2n*FOt=00I(JSl3`E$1xu?(m?vc<$29N<{U{MU1
zBZ%b|wIHs83?Ku4h5>m#a8QYs!C0d@I-pS{0ALE-O2C%2gv{XvErYQ}7y+R=6;P*A
zQ({n^4t{RpEQ7H|oldAJKB!rlnhJ%g)nR_F!U?rBVv7tQ1G5YybvG~X|HH5A|Jfw$
zAp^+3zhZ!f#&a?oo~DKakB_?7~Ox{4uIUd3ymO2E&h0caVFHG&5O
Oe*_c_*dPN7W#AQWt7v!t
literal 0
HcmV?d00001
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/.DS_Store b/moa/src/main/java/moa/classifiers/meta/AutoML/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
GIT binary patch
literal 6148
zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3
zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ
zLs35+`xjp>T0= 2){
+ System.out.println("Copy failed for " + x.classifier.getCLICreationString(Classifier.class) + "! Reinitialise instead.");
+ }
+ this.classifier = x.classifier; // keep the old algorithm for now
+ keepCurrentModel = false;
+ }
+ } else{
+ this.classifier = x.classifier; // keep the old algorithm for now
+ }
+
+ adjustAlgorithm(keepCurrentModel, verbose);
+ }
+
+ // init constructor
+ public Algorithm(AlgorithmConfiguration x) {
+
+ this.algorithm = x.algorithm;
+ this.parameters = new IParameter[x.parameters.length];
+ this.preventRemoval = false;
+ this.isDefault = true;
+
+ this.attributes = new Attribute[x.parameters.length];
+ for (int i = 0; i < x.parameters.length; i++) {
+
+
+ ParameterConfiguration paramConfig = x.parameters[i];
+ if (paramConfig.type.equals("numeric") || paramConfig.type.equals("float") || paramConfig.type.equals("real")) {
+ NumericalParameter param = new NumericalParameter(paramConfig);
+ this.parameters[i] = param;
+ this.attributes[i] = new Attribute(param.getParameter());
+ } else if (paramConfig.type.equals("integer")) {
+ IntegerParameter param = new IntegerParameter(paramConfig);
+ this.parameters[i] = param;
+ this.attributes[i] = new Attribute(param.getParameter());
+ } else if (paramConfig.type.equals("nominal") || paramConfig.type.equals("categorical") || paramConfig.type.equals("factor")) {
+ CategoricalParameter param = new CategoricalParameter(paramConfig);
+ this.parameters[i] = param;
+ this.attributes[i] = new Attribute(param.getParameter(), Arrays.asList(param.getRange()));
+ } else if (paramConfig.type.equals("boolean") || paramConfig.type.equals("flag")) {
+ BooleanParameter param = new BooleanParameter(paramConfig);
+ this.parameters[i] = param;
+ this.attributes[i] = new Attribute(param.getParameter(), Arrays.asList(param.getRange()));
+ } else if (paramConfig.type.equals("ordinal")) {
+ OrdinalParameter param = new OrdinalParameter(paramConfig);
+ this.parameters[i] = param;
+ this.attributes[i] = new Attribute(param.getParameter());
+ } else {
+ throw new RuntimeException("Unknown parameter type: '" + paramConfig.type
+ + "'. Available options are 'numeric', 'integer', 'nominal', 'boolean' or 'ordinal'");
+ }
+ }
+ init();
+ }
+
+ // initialise a new algorithm using the Command Line Interface (CLI)
+ public void init() {
+ // construct CLI string from settings, e.g. denstream.WithDBSCAN -e 0.08 -b 0.3
+ StringBuilder commandLine = new StringBuilder();
+ commandLine.append(this.algorithm); // first the algorithm class
+ for (IParameter param : this.parameters) {
+ commandLine.append(" ");
+ commandLine.append(param.getCLIString());
+ }
+
+ // create new classifier from CLI string
+ ClassOption opt = new ClassOption("", ' ', "", Classifier.class, commandLine.toString());
+ this.classifier = (AbstractClassifier) opt.materializeObject(null, null);
+ this.classifier.prepareForUse();
+ }
+
+ // sample a new configuration based on the current one
+ public void adjustAlgorithm(boolean keepCurrentModel, int verbose) {
+
+ if (keepCurrentModel) {
+ // Option 1: keep the old state and just change parameter
+ StringBuilder commandLine = new StringBuilder();
+ for (IParameter param : this.parameters) {
+ commandLine.append(param.getCLIString());
+ }
+
+ Options opts = this.classifier.getOptions();
+ for (IParameter param : this.parameters) {
+ Option opt = opts.getOption(param.getParameter().charAt(0));
+ opt.setValueViaCLIString(param.getCLIValueString());
+ }
+
+ // these changes do not transfer over directly since all algorithms cache the
+ // option values. Therefore we try to adjust the cached values if possible
+ try {
+ ((AbstractClassifier) this.classifier).adjustParameters();
+ if (verbose >= 2) {
+ System.out.println("Changed: " + this.classifier.getCLICreationString(Classifier.class));
+ }
+ } catch (UnsupportedOperationException e) {
+ if (verbose >= 2) {
+ System.out.println("Cannot change parameters of " + this.algorithm + " on the fly, reset instead.");
+ }
+ adjustAlgorithm(false, verbose);
+ }
+ } else{
+ // Option 2: reinitialise the entire state
+ this.init();
+ if (verbose >= 2) {
+ System.out.println("Initialise: " + this.classifier.getCLICreationString(Classifier.class));
+ }
+
+ if (verbose >= 2) {
+ System.out.println("Train with existing classifiers.");
+ }
+
+ }
+ }
+
+ // returns the parameter values as an array
+ public double[] getParamVector(int padding) {
+ double[] params = new double[this.parameters.length + padding];
+ int pos = 0;
+ for (IParameter param : this.parameters) {
+ params[pos++] = param.getValue();
+ }
+ return params;
+ }
+}
\ No newline at end of file
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java
new file mode 100755
index 000000000..c9b39aa68
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java
@@ -0,0 +1,697 @@
+package moa.classifiers.meta.AutoML;
+
+import com.github.javacliparser.FileOption;
+import com.google.gson.Gson;
+import com.yahoo.labs.samoa.instances.DenseInstance;
+import com.yahoo.labs.samoa.instances.Instance;
+import com.yahoo.labs.samoa.instances.Instances;
+import moa.classifiers.Classifier;
+import moa.classifiers.MultiClassClassifier;
+import moa.classifiers.meta.AdaptiveRandomForestRegressor;
+import moa.core.DoubleVector;
+import moa.core.Measurement;
+import moa.core.ObjectRepository;
+import moa.tasks.TaskMonitor;
+import java.io.BufferedReader;
+import java.io.FileNotFoundException;
+import java.io.FileReader;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+
+ public class AutoClass extends HeterogeneousEnsembleAbstract implements MultiClassClassifier {
+
+ private static final long serialVersionUID = 1L;
+
+ int instancesSeen;
+ int iter;
+ public int bestModel;
+ public ArrayList ensemble;
+ public ArrayList candidateEnsemble;
+ public Instances windowPoints;
+ HashMap ARFregs = new HashMap();
+ GeneralConfiguration settings;
+ ArrayList performanceMeasures;
+ int verbose = 0;
+ protected ExecutorService executor;
+ int numberOfCores;
+ int corrects;
+ protected boolean[][] onlineHistory;
+ // the file option dialogue in the UI
+ public FileOption fileOption = new FileOption("ConfigurationFile", 'f', "Configuration file in json format.",
+ "/Users/mbahri/Desktop/Dell/moa/src/main/java/moa/classifiers/meta/AutoML/settings.json", ".json", false);
+
+ public void init() {
+ this.fileOption.getFile();
+ }
+
+
+ @Override
+ public boolean isRandomizable() {
+ return false;
+ }
+
+ // @Override
+ //public double[] getVotesForInstance(Instance inst) {
+ //return null; }
+
+
+ @Override
+ public void resetLearningImpl() {
+ this.performanceMeasures = new ArrayList<>();
+ this.onlineHistory = new boolean[this.settings.ensembleSize][this.settings.windowSize];
+ this.historyTotal = new double[this.settings.ensembleSize];
+ this.instancesSeen = 0;
+ this.corrects = 0;
+ this.bestModel = 0;
+ this.iter = 0;
+ this.windowPoints = null ;
+
+ // reset ARFrefs
+ for (AdaptiveRandomForestRegressor ARFreg : this.ARFregs.values()) {
+ ARFreg.resetLearning();
+ }
+
+ // reset individual classifiers
+ for (int i = 0; i < this.ensemble.size(); i++) {
+ //this.ensemble.get(i).init();
+ this.ensemble.get(i).classifier.resetLearning();
+ }
+
+ if (this.settings.numberOfCores == -1) {
+ this.numberOfCores = Runtime.getRuntime().availableProcessors();
+ } else {
+ this.numberOfCores = this.settings.numberOfCores;
+ }
+ this.executor = Executors.newFixedThreadPool(this.numberOfCores);
+ }
+
+ @Override
+ public void trainOnInstanceImpl(Instance inst) {
+
+ if (this.windowPoints == null) {
+ this.windowPoints = new Instances(inst.dataset());
+ }
+
+ if (this.settings.windowSize <= this.windowPoints.numInstances()) {
+ this.windowPoints.delete(0);
+ }
+ this.windowPoints.add(inst); // remember points of the current window
+ this.instancesSeen++;
+
+ int wValue = this.settings.windowSize;
+
+ for (int i = 0; i < this.ensemble.size(); i++) {
+ // Online performance estimation
+ double [] votes = ensemble.get(i).classifier.getVotesForInstance(inst);
+ boolean correct = (maxIndex(votes) * 1.0 == inst.classValue());
+
+
+ // Maroua: boolean correct = (votes[maxIndex(votes)] * 1.0 == inst.classValue());
+ if (correct && ! this.onlineHistory[i][instancesSeen % wValue]) {
+ // performance estimation increases
+ this.onlineHistory[i][instancesSeen % wValue] = true;
+
+ this.historyTotal[i] += 1.0/wValue;
+
+ } else if (!correct && this.onlineHistory[i][instancesSeen % wValue]) {
+ // performance estimation decreases
+ this.onlineHistory[i][instancesSeen % wValue] = false;
+ this.historyTotal[i] -= 1.0/wValue;
+
+ } else {
+ // nothing happens
+ }
+ }
+
+ if (this.numberOfCores == 1) {
+ // train all models with the instance
+ this.performanceMeasures = new ArrayList(this.ensemble.size());
+ double bestPerformance = Double.NEGATIVE_INFINITY;
+ for (int i = 0; i < this.ensemble.size(); i++) {
+ this.ensemble.get(i).classifier.trainOnInstance(inst);
+
+ // To extract the best performing mode
+ double performance = this.historyTotal[i];
+ this.performanceMeasures.add(performance);
+ if (performance > bestPerformance) {
+ this.bestModel = i;
+ bestPerformance = performance;
+ }
+ }
+
+ if (this.settings.useTestEnsemble && this.candidateEnsemble.size() > 0) {
+ // train all models with the instance
+ for (int i = 0; i < this.candidateEnsemble.size(); i++) { //Maroua; cétait Ensemble
+ this.candidateEnsemble.get(i).classifier.trainOnInstance(inst);
+ }
+ }
+ } else {
+ //EnsembleClassifierAbstractAUTO.EnsembleRunnable
+ ArrayList trainers = new ArrayList();
+ for (int i = 0; i < this.ensemble.size(); i++) {
+ EnsembleRunnable trainer = new EnsembleRunnable(this.ensemble.get(i).classifier, inst);
+ trainers.add(trainer);
+ }
+ if (this.settings.useTestEnsemble && this.candidateEnsemble.size() > 0) {
+ // train all models with the instance
+ for (int i = 0; i < this.candidateEnsemble.size(); i++) {
+ EnsembleRunnable trainer = new EnsembleRunnable(this.candidateEnsemble.get(i).classifier, inst);
+ trainers.add(trainer);
+ }
+ }
+ try {
+ this.executor.invokeAll(trainers);
+ } catch (InterruptedException ex) {
+ throw new RuntimeException("Could not call invokeAll() on training threads.");
+ }
+ }
+
+ // every windowSize, we update the configurations
+ if (this.instancesSeen % this.settings.windowSize == 0) {
+
+ if (this.verbose >= 1) {
+ System.out.println(" ");
+ System.out.println("-------------- Processed " + instancesSeen + " Instances --------------");
+ }
+
+ updateConfiguration(); // update configuration
+ }
+
+ }
+
+ /**
+ * Returns votes using the best performing method
+ * @param inst
+ * @return votes
+ */
+ @Override
+ public double[] getVotesForInstance(Instance inst) {
+
+
+ return this.ensemble.get(this.bestModel).classifier.getVotesForInstance(inst);
+// DoubleVector combinedVote = new DoubleVector();
+//
+// for(int i = 0 ; i < this.ensemble.size() ; ++i) {
+// DoubleVector vote = new DoubleVector(this.ensemble.get(i).classifier.getVotesForInstance(inst));
+// if (vote.sumOfValues() > 0.0) {
+// vote.normalize();
+//
+// combinedVote.addValues(vote);
+// }
+// }
+// return combinedVote.getArrayRef();
+ }
+
+
+ protected void updateConfiguration() {
+ // init evaluation measure
+ if (this.verbose >= 2) {
+ System.out.println(" ");
+ System.out.println("---- Evaluate performance of current ensemble:");
+ }
+ evaluatePerformance();
+
+ if (this.settings.useTestEnsemble) {
+ promoteCandidatesIntoEnsemble();
+ }
+
+ if (this.verbose >= 1) {
+ System.out.println("Classifier " + this.bestModel + " ("
+ + this.ensemble.get(this.bestModel).classifier.getCLICreationString(Classifier.class)
+ + ") is the active classifier with performance: " + this.performanceMeasures.get(this.bestModel));
+ }
+
+ generateNewConfigurations();
+
+ // this.windowPoints.delete(); // flush the current window
+ this.iter++;
+ }
+
+ protected void evaluatePerformance() {
+
+ HashMap bestPerformanceValMap = new HashMap();
+ HashMap bestPerformanceIdxMap = new HashMap();
+ HashMap algorithmCount = new HashMap();
+
+ this.performanceMeasures = new ArrayList(this.ensemble.size());
+ double bestPerformance = Double.NEGATIVE_INFINITY;
+ for (int i = 0; i < this.ensemble.size(); i++) {
+
+ // predict performance just for evaluation
+ predictPerformance(this.ensemble.get(i));
+
+ double performance = this.historyTotal[i];
+ this.performanceMeasures.add(performance);
+ if (performance > bestPerformance) {
+ this.bestModel = i;
+ bestPerformance = performance;
+ }
+
+ if (this.verbose >= 1) {
+ System.out.println(i + ") " + this.ensemble.get(i).classifier.getCLICreationString(Classifier.class)
+ + "\t => \t performance: " + performance);
+ }
+
+ String algorithm = this.ensemble.get(i).classifier.getPurposeString();
+ if (!bestPerformanceIdxMap.containsKey(algorithm) || performance > bestPerformanceValMap.get(algorithm)) {
+ bestPerformanceValMap.put(algorithm, performance); // best performance per algorithm
+ bestPerformanceIdxMap.put(algorithm, i); // index of best performance per algorithm
+ }
+
+ // number of instances per algorithm in ensemble
+
+ algorithmCount.put(algorithm, algorithmCount.getOrDefault(algorithm, 0) + 1);
+ trainRegressor(this.ensemble.get(i), performance);
+ }
+
+ updateRemovalFlags(bestPerformanceValMap, bestPerformanceIdxMap, algorithmCount);
+ }
+
+
+ /**
+ * Computes the accuracy of a learner for a given window of instances.
+ * @param algorithm classifier to compute error
+ * @return the computed accuracy.
+ */
+ protected double computePerformanceMeasure(Algorithm algorithm) {
+ double acc = 0;
+ this.trainOnChunk(algorithm);
+ for (int i = 0; i < this.windowPoints.numInstances(); i++) {
+ try {
+
+ double[] votes = algorithm.classifier.getVotesForInstance(this.windowPoints.instance(i));
+ boolean correct = (maxIndex(votes)* 1.0 == this.windowPoints.instance(i).classValue());
+
+ if (correct){
+ acc += 1.0/this.windowPoints.numInstances();
+ }else
+ acc -= 1.0/this.windowPoints.numInstances();
+ // algorithm.classifier.trainOnInstance(this.windowPoints.instance(i));....
+ } catch (Exception e) {
+ System.out.println("computePerformanceMeasure Error");
+ }
+ }
+ algorithm.performanceMeasure = acc;
+ return acc;
+ }
+ /**
+ * Trains a classifier on the most recent window of data.
+ *
+ * @param algorithm
+ * Classifier being trained.
+ */
+ private void trainOnChunk(Algorithm algorithm) {
+ for (int i = 0; i < this.windowPoints.numInstances(); i++) {
+ algorithm.classifier.trainOnInstance(this.windowPoints.instance(i));
+ }
+ }
+
+
+ protected void promoteCandidatesIntoEnsemble() {
+ for (int i = 0; i < this.candidateEnsemble.size(); i++) {
+
+ Algorithm newAlgorithm = this.candidateEnsemble.get(i);
+
+ // predict performance just for evaluation
+ predictPerformance(newAlgorithm);
+
+ // evaluate
+ double performance = computePerformanceMeasure(newAlgorithm);
+
+ if (this.verbose >= 1) {
+ System.out.println("Test " + i + ") " + newAlgorithm.classifier.getCLICreationString(Classifier.class)
+ + "\t => \t Performance: " + performance);
+ }
+
+ // replace if better than existing
+
+ if (this.ensemble.size() < this.settings.ensembleSize) {
+ if (this.verbose >= 1) {
+ System.out.println("Promote " + newAlgorithm.classifier.getCLICreationString(Classifier.class)
+ + " from test ensemble to the ensemble as new configuration");
+ }
+
+ this.performanceMeasures.add(newAlgorithm.performanceMeasure);
+
+ this.ensemble.add(newAlgorithm);
+
+ } else if (performance > AutoClass.getWorstSolution(this.performanceMeasures)) {
+
+ HashMap replace = getReplaceMap(this.performanceMeasures);
+
+ if (replace.size() == 0) {
+ return;
+ }
+
+ int replaceIdx = AutoClass.sampleProportionally(replace,
+ !this.settings.performanceMeasureMaximisation); // false
+ if (this.verbose >= 1) {
+ System.out.println("Promote " + newAlgorithm.classifier.getCLICreationString(Classifier.class)
+ + " from test ensemble to the ensemble by replacing " + replaceIdx);
+ }
+
+ // update performance measure
+ this.performanceMeasures.set(replaceIdx, newAlgorithm.performanceMeasure);
+
+ // replace in ensemble
+
+ this.ensemble.set(replaceIdx, newAlgorithm);
+ }
+
+ }
+ }
+
+ protected void trainRegressor(Algorithm algortihm, double performance) {
+ double[] params = algortihm.getParamVector(1);
+ params[params.length - 1] = performance; // add performance as class
+ Instance inst = new DenseInstance(1.0, params);
+
+ // add header to dataset TODO: do we need an attribute for the class label?
+ Instances dataset = new Instances(null, algortihm.attributes, 0);
+ dataset.setClassIndex(dataset.numAttributes()); // set class index to our performance feature
+ inst.setDataset(dataset);
+
+ // train adaptive random forest regressor based on performance of model
+ this.ARFregs.get(algortihm.algorithm).trainOnInstanceImpl(inst);
+ }
+
+ protected void updateRemovalFlags(HashMap bestPerformanceValMap,
+ HashMap bestPerformanceIdxMap, HashMap algorithmCount) {
+
+ // reset flags
+ for (Algorithm algorithm : ensemble) {
+ algorithm.preventRemoval = false;
+ }
+
+ // only keep best overall algorithm
+ if (this.settings.keepGlobalIncumbent) {
+ this.ensemble.get(this.bestModel).preventRemoval = true;
+ }
+
+ // keep best instance per algorithm
+ if (this.settings.keepAlgorithmIncumbents) {
+ for (int idx : bestPerformanceIdxMap.values()) {
+ this.ensemble.get(idx).preventRemoval = true;
+ }
+ }
+
+ // keep all default configurations
+ if (this.settings.keepInitialConfigurations) {
+ for (Algorithm algorithm : this.ensemble) {
+ if (algorithm.isDefault) {
+ algorithm.preventRemoval = true;
+ }
+ }
+ }
+
+ // keep at least one instance per algorithm
+ if (this.settings.preventAlgorithmDeath) {
+ for (Algorithm algorithm : this.ensemble) {
+ if (algorithmCount.get(algorithm.algorithm) == 1) {
+ algorithm.preventRemoval = true;
+ }
+ }
+ }
+ }
+
+ // predict performance of new configuration
+ protected void generateNewConfigurations() {
+
+ // get performance values
+ if (this.settings.useTestEnsemble) {
+ candidateEnsemble.clear();
+ }
+
+ for (int z = 0; z < this.settings.newConfigurations; z++) {
+
+ if (this.verbose == 2) {
+ System.out.println(" ");
+ System.out.println("---- Sample new configuration " + z + ":");
+ }
+
+ int parentIdx = sampleParent(this.performanceMeasures);
+ Algorithm newAlgorithm = sampleNewConfiguration(parentIdx);
+
+ if (this.settings.useTestEnsemble) {
+ if (this.verbose >= 1) {
+ System.out.println("Based on " + parentIdx + " add "
+ + newAlgorithm.classifier.getCLICreationString(Classifier.class) + " to test ensemble");
+ }
+ candidateEnsemble.add(newAlgorithm);
+ } else {
+ double prediction = predictPerformance(newAlgorithm);
+
+ if (this.verbose >= 1) {
+ System.out.println("Based on " + parentIdx + " predict: "
+ + newAlgorithm.classifier.getCLICreationString(Classifier.class) + "\t => \t Performance: "
+ + prediction);
+ }
+
+ // the random forest only works with at least two training samples
+ if (Double.isNaN(prediction)) {
+ return;
+ }
+
+ // if we still have open slots in the ensemble (not full)
+ if (this.ensemble.size() < this.settings.ensembleSize) {
+ if (this.verbose >= 1) {
+ System.out.println("Add configuration as new algorithm.");
+ }
+
+ // add to ensemble
+ this.ensemble.add(newAlgorithm);
+
+ // update current performance with the prediction
+ this.performanceMeasures.add(prediction);
+ } else if (prediction > AutoClass.getWorstSolution(this.performanceMeasures)) {
+ // if the predicted performance is better than the one we have in the ensemble
+ HashMap replace = getReplaceMap(this.performanceMeasures);
+
+ if (replace.size() == 0) {
+ return;
+ }
+
+ int replaceIdx = AutoClass.sampleProportionally(replace,
+ !this.settings.performanceMeasureMaximisation); // false
+
+ if (this.verbose >= 1) {
+ System.out.println("Replace algorithm: " + replaceIdx);
+ }
+
+ // update current performance with the prediction
+ this.performanceMeasures.set(replaceIdx, prediction);
+
+ // replace in ensemble
+ this.ensemble.set(replaceIdx, newAlgorithm);
+ }
+ }
+
+ }
+
+ }
+
+ protected int sampleParent(ArrayList performM) {
+ // copy existing classifier configuration
+ HashMap parents = new HashMap();
+ for (int i = 0; i < performM.size(); i++) {
+ parents.put(i, performM.get(i));
+ }
+ int parentIdx = AutoClass.sampleProportionally(parents,
+ this.settings.performanceMeasureMaximisation); // true
+
+ return parentIdx;
+ }
+
+ protected Algorithm sampleNewConfiguration(int parentIdx) {
+
+ if (this.verbose >= 2) {
+ System.out.println("Selected Configuration " + parentIdx + " as parent: "
+ + this.ensemble.get(parentIdx).classifier.getCLICreationString(Classifier.class));
+ }
+ Algorithm newAlgorithm = new Algorithm(this.ensemble.get(parentIdx), this.settings.lambda,
+ this.settings.resetProbability, this.settings.keepCurrentModel, this.verbose);
+
+ return newAlgorithm;
+ }
+
+ protected double predictPerformance(Algorithm newAlgorithm) {
+ // create instance from new configuration
+ double[] params = newAlgorithm.getParamVector(0);
+ Instance newInst = new DenseInstance(1.0, params);
+ Instances newDataset = new Instances(null, newAlgorithm.attributes, 0);
+ newDataset.setClassIndex(newDataset.numAttributes());
+ newInst.setDataset(newDataset);
+
+ // predict the performance of the new configuration using the trained adaptive
+ // random forest
+ double prediction = this.ARFregs.get(newAlgorithm.algorithm).getVotesForInstance(newInst)[0];
+
+ newAlgorithm.prediction = prediction; // remember prediction
+
+ return prediction;
+ }
+
+ // get mapping of algorithms and their performance that could be removed
+ HashMap getReplaceMap(ArrayList performM) {
+ HashMap replace = new HashMap();
+
+ double worst = AutoClass.getWorstSolution(performM);
+
+ // replace solutions that cannot get worse first
+ if (worst <= -1.0) {
+ for (int i = 0; i < this.ensemble.size(); i++) {
+ if (performM.get(i) <= -1.0 && !this.ensemble.get(i).preventRemoval) {
+ replace.put(i, performM.get(i));
+ }
+ }
+ }
+
+ if (replace.size() == 0) {
+ for (int i = 0; i < this.ensemble.size(); i++) {
+ if (!this.ensemble.get(i).preventRemoval) {
+ replace.put(i, performM.get(i));
+ }
+ }
+ }
+
+ return replace;
+ }
+
+ // get lowest value in arraylist
+ static double getWorstSolution(ArrayList values) {
+
+ double min = Double.POSITIVE_INFINITY;
+ for (int i = 0; i < values.size(); i++) {
+ if (values.get(i) < min) {
+ min = values.get(i);
+ }
+ }
+ return (min);
+ }
+
+ static int sampleProportionally(HashMap values, boolean maximisation) {
+
+ // if we want to sample lower values with higher probability, we invert here
+ if (!maximisation) {
+ HashMap vals = new HashMap(values.size());
+
+ for (int i : values.keySet()) {
+ vals.put(i, -1 * values.get(i));
+ }
+ return (AutoClass.rouletteWheelSelection(vals));
+ }
+
+ return (AutoClass.rouletteWheelSelection(values));
+ }
+
+ // sample an index from a list of values, proportionally to the respective value
+ static int rouletteWheelSelection(HashMap values) {
+
+ // get min
+ double minVal = Double.POSITIVE_INFINITY;
+ for (Double value : values.values()) {
+ if (value < minVal) {
+ minVal = value;
+ }
+ }
+
+ // to have a positive range we shift here
+ double shift = Math.abs(minVal) - minVal;
+
+ double completeWeight = 0.0;
+ for (Double value : values.values()) {
+ completeWeight += value + shift;
+ }
+
+ // sample random number within range of total weight
+ double r = Math.random() * completeWeight;
+ double countWeight = 0.0;
+
+ for (int j : values.keySet()) {
+ countWeight += values.get(j) + shift;
+ if (countWeight >= r) {
+ return j;
+ }
+ }
+ throw new RuntimeException("Sampling failed");
+ }
+
+ @Override
+ protected Measurement[] getModelMeasurementsImpl() {
+ // TODO Auto-generated method stub
+ return null;
+ }
+
+ @Override
+ public void getModelDescription(StringBuilder out, int indent) {
+ // TODO Auto-generated method stub
+ }
+ @Override
+ public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+
+ try {
+
+ // read settings from json
+ BufferedReader bufferedReader = new BufferedReader(new FileReader(fileOption.getValue()));
+ Gson gson = new Gson();
+ // store settings in dedicated class structure
+ this.settings = gson.fromJson(bufferedReader, GeneralConfiguration.class);
+
+ this.instancesSeen = 0;
+ this.bestModel = 0;
+ this.iter = 0;
+ this.windowPoints = null;
+
+ // create the ensemble
+ this.ensemble = new ArrayList(this.settings.ensembleSize);
+ // copy and initialise the provided starting configurations in the ensemble
+ for (int i = 0; i < this.settings.algorithms.length; i++) {
+ this.ensemble.add(new Algorithm(this.settings.algorithms[i]));
+ }
+
+ if (this.settings.useTestEnsemble) {
+ this.candidateEnsemble = new ArrayList(this.settings.newConfigurations);
+ }
+
+ // create one regressor per algorithm
+ for (int i = 0; i < this.settings.algorithms.length; i++) {
+ AdaptiveRandomForestRegressor ARFreg = new AdaptiveRandomForestRegressor();
+ ARFreg.prepareForUse();
+ this.ARFregs.put(this.settings.algorithms[i].algorithm, ARFreg);
+ }
+
+ } catch (FileNotFoundException e) {
+ e.printStackTrace();
+ }
+
+ super.prepareForUseImpl(monitor, repository);
+
+ }
+
+ // Modified from:
+ // https://github.com/Waikato/moa/blob/master/moa/src/main/java/moa/classifiers/meta/AdaptiveRandomForest.java#L157
+ // Helper class for parallelisation
+ protected class EnsembleRunnable implements Runnable, Callable {
+ final private Classifier classifier;
+ final private Instance instance;
+
+ public EnsembleRunnable(Classifier classifier, Instance instance) {
+ this.classifier = classifier;
+ this.instance = instance;
+ }
+
+ @Override
+ public void run() {
+ classifier.trainOnInstance(this.instance);
+ }
+
+ @Override
+ public Integer call() throws Exception {
+ run();
+ return 0;
+ }
+ }
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/BooleanParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/BooleanParameter.java
new file mode 100755
index 000000000..4342668e7
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/BooleanParameter.java
@@ -0,0 +1,127 @@
+package moa.classifiers.meta.AutoML;
+
+import com.yahoo.labs.samoa.instances.Attribute;
+import java.util.ArrayList;
+import java.util.HashMap;
+
+// the representation of a boolean / binary / flag parameter
+public class BooleanParameter implements IParameter {
+ private String parameter;
+ private int numericValue;
+ private String value;
+ private String[] range = { "false", "true" };
+ private Attribute attribute;
+ private ArrayList probabilities;
+ private boolean optimise;
+
+ public BooleanParameter(BooleanParameter x) {
+ this.parameter = x.parameter;
+ this.numericValue = x.numericValue;
+ this.value = x.value;
+ this.attribute = x.attribute;
+ this.optimise = x.optimise;
+
+ if(this.optimise){
+ this.range = x.range.clone();
+ this.probabilities = new ArrayList(x.probabilities);
+ }
+
+ }
+
+ public BooleanParameter(ParameterConfiguration x) {
+ this.parameter = x.parameter;
+ this.value = String.valueOf(x.value);
+ for (int i = 0; i < this.range.length; i++) {
+ if (this.range[i].equals(this.value)) {
+ this.numericValue = i; // get index of init value
+ }
+ }
+ this.attribute = new Attribute(x.parameter);
+ this.optimise = x.optimise;
+
+ if(this.optimise){
+ this.probabilities = new ArrayList(2);
+ for (int i = 0; i < 2; i++) {
+ this.probabilities.add(0.5); // equal probabilities
+ }
+ }
+ }
+
+ public BooleanParameter copy() {
+ return new BooleanParameter(this);
+ }
+
+ public String getCLIString() {
+ // if option is set
+ if (this.numericValue == 1) {
+ return ("-" + this.parameter); // only the parameter
+ }
+ return "";
+ }
+
+ public String getCLIValueString() {
+ if (this.numericValue == 1) {
+ return ("");
+ } else {
+ return (null);
+ }
+ }
+
+ public double getValue() {
+ return this.numericValue;
+ }
+
+ public String getParameter() {
+ return this.parameter;
+ }
+
+ public String[] getRange() {
+ return this.range;
+ }
+
+ public void sampleNewConfig(double lambda, double reset, int verbose) {
+
+ if(!this.optimise){
+ return;
+ }
+
+ if (Math.random() < reset) {
+ for (int i = 0; i < this.probabilities.size(); i++) {
+ this.probabilities.set(i, 1.0/this.probabilities.size());
+ }
+ }
+
+ HashMap map = new HashMap();
+ for (int i = 0; i < this.probabilities.size(); i++) {
+ map.put(i, this.probabilities.get(i));
+ }
+
+ // update configuration
+ this.numericValue = AutoClass.sampleProportionally(map, true);
+ String newValue = this.range[this.numericValue];
+ if (verbose >= 3) {
+ System.out
+ .print("Sample new configuration for boolean parameter -" + this.parameter + " with probabilities");
+ for (int i = 0; i < this.probabilities.size(); i++) {
+ System.out.print(" " + this.probabilities.get(i));
+ }
+ System.out.println("\t=>\t -" + this.parameter + " " + newValue);
+ }
+ this.value = newValue;
+
+ // adapt distribution
+ // this.probabilities.set(this.numericValue,
+ // this.probabilities.get(this.numericValue) + (1.0/iter));
+ this.probabilities.set(this.numericValue,
+ this.probabilities.get(this.numericValue) * (2 - Math.pow(2, -1 * lambda)));
+
+ // divide by sum
+ double sum = 0.0;
+ for (int i = 0; i < this.probabilities.size(); i++) {
+ sum += this.probabilities.get(i);
+ }
+ for (int i = 0; i < this.probabilities.size(); i++) {
+ this.probabilities.set(i, this.probabilities.get(i) / sum);
+ }
+ }
+}
\ No newline at end of file
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/CategoricalParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/CategoricalParameter.java
new file mode 100755
index 000000000..e85e672c8
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/CategoricalParameter.java
@@ -0,0 +1,125 @@
+package moa.classifiers.meta.AutoML;
+
+import com.yahoo.labs.samoa.instances.Attribute;
+import moa.classifiers.meta.AutoML.IParameter;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+
+// the representation of a categorical / nominal parameter
+public class CategoricalParameter implements IParameter {
+ private String parameter;
+ private int numericValue;
+ private String value;
+ private String[] range;
+ private Attribute attribute;
+ private ArrayList probabilities;
+ private boolean optimise;
+
+ public CategoricalParameter(CategoricalParameter x) {
+ this.parameter = x.parameter;
+ this.numericValue = x.numericValue;
+ this.value = x.value;
+ this.attribute = x.attribute;
+ this.optimise = x.optimise;
+
+ if(this.optimise){
+ this.range = x.range.clone();
+ this.probabilities = new ArrayList(x.probabilities);
+ }
+ }
+
+ public CategoricalParameter(ParameterConfiguration x) {
+ this.parameter = x.parameter;
+ this.value = String.valueOf(x.value);
+ this.attribute = new Attribute(x.parameter, Arrays.asList(range));
+ this.optimise = x.optimise;
+
+ if(this.optimise){
+ this.range = new String[x.range.length];
+ for (int i = 0; i < x.range.length; i++) {
+ range[i] = String.valueOf(x.range[i]);
+ if (this.range[i].equals(this.value)) {
+ this.numericValue = i; // get index of init value
+ }
+ }
+ this.probabilities = new ArrayList(x.range.length);
+ for (int i = 0; i < x.range.length; i++) {
+ this.probabilities.add(1.0 / x.range.length); // equal probabilities
+ }
+ }
+ }
+
+ public CategoricalParameter copy() {
+ return new CategoricalParameter(this);
+ }
+
+ public String getCLIString() {
+ return ("-" + this.parameter + " " + this.value);
+ }
+
+ public String getCLIValueString() {
+ return ("" + this.value);
+ }
+
+ public double getValue() {
+ return this.numericValue;
+ }
+
+ public String getParameter() {
+ return this.parameter;
+ }
+
+ public String[] getRange() {
+ return this.range;
+ }
+
+ public void sampleNewConfig(double lambda, double reset, int verbose) {
+
+ if(!this.optimise){
+ return;
+ }
+
+ if (Math.random() < reset) {
+ for (int i = 0; i < this.probabilities.size(); i++) {
+ this.probabilities.set(i, 1.0/this.probabilities.size());
+ }
+ }
+
+ HashMap map = new HashMap();
+ for (int i = 0; i < this.probabilities.size(); i++) {
+ map.put(i, this.probabilities.get(i));
+ }
+ // update configuration
+ this.numericValue = AutoClass.sampleProportionally(map, true);
+ String newValue = this.range[this.numericValue];
+
+ if (verbose >= 3) {
+ System.out
+ .print("Sample new configuration for nominal parameter -" + this.parameter + "with probabilities");
+ for (int i = 0; i < this.probabilities.size(); i++) {
+ System.out.print(" " + this.probabilities.get(i));
+ }
+ System.out.println("\t=>\t -" + this.parameter + " " + newValue);
+ }
+ this.value = newValue;
+
+ // adapt distribution
+ // TODO not directly transferable from irace: (1-((iter -1) / maxIter))
+ // this.probabilities.set(this.numericValue,
+ // this.probabilities.get(this.numericValue) + (1.0/iter));
+ this.probabilities.set(this.numericValue,
+ this.probabilities.get(this.numericValue) * (2 - Math.pow(2, -1 * lambda)));
+
+ // divide by sum (TODO is this even necessary with our proportional sampling
+ // strategy?)
+ double sum = 0.0;
+ for (int i = 0; i < this.probabilities.size(); i++) {
+ sum += this.probabilities.get(i);
+ }
+ for (int i = 0; i < this.probabilities.size(); i++) {
+ this.probabilities.set(i, this.probabilities.get(i) / sum);
+ }
+ }
+}
\ No newline at end of file
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java b/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java
new file mode 100755
index 000000000..7a17f8d64
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java
@@ -0,0 +1,228 @@
+/*
+ * HeterogeneousEnsembleAbstract.java
+ * Copyright (C) 2017 University of Waikato, Hamilton, New Zealand
+ * @author Jan N. van Rijn (j.n.van.rijn@liacs.leidenuniv.nl)
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+package moa.classifiers.meta.AutoML;
+
+import com.github.javacliparser.FlagOption;
+import com.github.javacliparser.IntOption;
+import com.github.javacliparser.ListOption;
+import com.github.javacliparser.Option;
+import com.yahoo.labs.samoa.instances.Instance;
+import com.yahoo.labs.samoa.instances.InstancesHeader;
+import moa.classifiers.AbstractClassifier;
+import moa.classifiers.Classifier;
+import moa.classifiers.MultiClassClassifier;
+import moa.core.Measurement;
+import moa.core.ObjectRepository;
+import moa.options.ClassOption;
+import moa.tasks.TaskMonitor;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * BLAST (Best Last) for Heterogeneous Ensembles Abstract Base Class
+ *
+ *
+ * Given a set of (heterogeneous) classifiers, BLAST builds an ensemble, and
+ * determines the weights of all ensemble members based on their performance on
+ * recent observed instances. Used as Abstact Base Class for
+ * HeterogeneousEnsembleBlast and HeterogeneousEnsembleBlastFadingFactors.
+ *
+ *
+ *
+ * J. N. van Rijn, G. Holmes, B. Pfahringer, J. Vanschoren. Having a Blast:
+ * Meta-Learning and Heterogeneous Ensembles for Data Streams. In 2015 IEEE
+ * International Conference on Data Mining, pages 1003-1008. IEEE, 2015.
+ *
+ *
+ *
+ * Parameters:
+ *
+ *
+ * - -b : Comma-separated string of classifiers
+ * - -g : Grace period (1 = optimal)
+ * - -k : Number of active classifiers
+ *
+ * Maroua: Created for EnsembleClassifierAbstractAuto2
+ * @author Jan N. van Rijn (j.n.van.rijn@liacs.leidenuniv.nl)
+ * @version $Revision: 1 $
+ */
+public abstract class HeterogeneousEnsembleAbstract extends AbstractClassifier implements MultiClassClassifier {
+
+ private static final long serialVersionUID = 1L;
+
+ public ListOption baselearnersOption = new ListOption("baseClassifiers", 'b',
+ "The classifiers the ensemble consists of.",
+ new ClassOption("learner", ' ', "", Classifier.class,
+ "trees.HoeffdingTree"),
+ new Option[] {
+ new ClassOption("", ' ', "", Classifier.class, "bayes.NaiveBayes"),
+ new ClassOption("", ' ', "", Classifier.class,
+ "functions.Perceptron"),
+ new ClassOption("", ' ', "", Classifier.class, "functions.SGD"),
+ new ClassOption("", ' ', "", Classifier.class, "lazy.kNN"),
+ new ClassOption("", ' ', "", Classifier.class,
+ "trees.HoeffdingTree") },
+ ',');
+
+ public IntOption gracePerionOption = new IntOption("gracePeriod", 'g',
+ "How many instances before we reevalate the best classifier", 1, 1,
+ Integer.MAX_VALUE);
+
+ public IntOption activeClassifiersOption = new IntOption("activeClassifiers",
+ 'k', "The number of active classifiers (used for voting)", 1, 1,
+ Integer.MAX_VALUE);
+
+ public FlagOption weightClassifiersOption = new FlagOption(
+ "weightClassifiers", 'p',
+ "Uses online performance estimation to weight the classifiers");
+
+ protected Classifier[] ensemble;
+
+ protected double[] historyTotal;
+
+ protected Integer instancesSeen;
+
+ List topK;
+
+ @Override
+ public String getPurposeString() {
+ return "The model-free heterogeneous ensemble as presented in "
+ + "'Having a Blast: Meta-Learning and Heterogeneous Ensembles "
+ + "for Data Streams' (ICDM 2015).";
+ }
+
+ public int getEnsembleSize() {
+ // mainly for testing, @throws exception if called before initialization
+ return this.ensemble.length;
+ }
+
+ public String getMemberCliString(int idx) {
+ // mainly for testing, @pre: idx < getEnsembleSize()
+ return this.ensemble[idx].getCLICreationString(Classifier.class);
+ }
+
+ @Override
+ public double[] getVotesForInstance(Instance inst) {
+ double[] votes = new double[inst.classAttribute().numValues()];
+
+ for (int i = 0; i < topK.size(); ++i) {
+ double[] memberVotes = normalize(
+ ensemble[topK.get(i)].getVotesForInstance(inst));
+ double weight = 1.0;
+
+ if (weightClassifiersOption.isSet()) {
+ weight = historyTotal[topK.get(i)];
+ }
+
+ // make internal classifiers so-called "hard classifiers"
+ votes[maxIndex(memberVotes)] += 1.0 * weight;
+ }
+
+ return votes;
+ }
+
+ @Override
+ public void setModelContext(InstancesHeader ih) {
+ super.setModelContext(ih);
+
+ for (int i = 0; i < this.ensemble.length; ++i) {
+ this.ensemble[i].setModelContext(ih);
+ }
+ }
+
+ @Override
+ public boolean isRandomizable() {
+ return false;
+ }
+
+ @Override
+ public void getModelDescription(StringBuilder arg0, int arg1) {
+ // TODO Auto-generated method stub
+
+ }
+
+ @Override
+ protected Measurement[] getModelMeasurementsImpl() {
+ // TODO Auto-generated method stub
+ return null;
+ }
+
+ @Override
+ public void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+
+ Option[] learnerOptions = this.baselearnersOption.getList();
+ this.ensemble = new Classifier[learnerOptions.length];
+ for (int i = 0; i < learnerOptions.length; i++) {
+ monitor.setCurrentActivity("Materializing learner " + (i + 1) + "...",
+ -1.0);
+ this.ensemble[i] = (Classifier) ((ClassOption) learnerOptions[i])
+ .materializeObject(monitor, repository);
+ if (monitor.taskShouldAbort()) {
+ return;
+ }
+ monitor.setCurrentActivity("Preparing learner " + (i + 1) + "...", -1.0);
+ this.ensemble[i].prepareForUse(monitor, repository);
+ if (monitor.taskShouldAbort()) {
+ return;
+ }
+ }
+ super.prepareForUseImpl(monitor, repository);
+
+ topK = topK(historyTotal, activeClassifiersOption.getValue());
+ }
+
+ protected static List topK(double[] scores, int k) {
+ double[] scoresWorking = Arrays.copyOf(scores, scores.length);
+
+ List topK = new ArrayList();
+
+ for (int i = 0; i < k; ++i) {
+ int bestIdx = maxIndex(scoresWorking);
+ topK.add(bestIdx);
+ scoresWorking[bestIdx] = -1;
+ }
+
+ return topK;
+ }
+
+ protected static int maxIndex(double[] scores) {
+ int bestIdx = 0;
+ for (int i = 1; i < scores.length; ++i) {
+ if (scores[i] > scores[bestIdx]) {
+ bestIdx = i;
+ }
+ }
+ return bestIdx;
+ }
+
+ protected static double[] normalize(double[] input) {
+ double sum = 0.0;
+ for (int i = 0; i < input.length; ++i) {
+ sum += input[i];
+ }
+ for (int i = 0; i < input.length; ++i) {
+ input[i] /= sum;
+ }
+ return input;
+ }
+}
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/IParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/IParameter.java
new file mode 100755
index 000000000..3c3d9af66
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/IParameter.java
@@ -0,0 +1,16 @@
+package moa.classifiers.meta.AutoML;
+
+// interface allows us to maintain a single list of parameters
+public interface IParameter {
+ public void sampleNewConfig(double lambda, double reset, int verbose);
+
+ public IParameter copy();
+
+ public String getCLIString();
+
+ public String getCLIValueString();
+
+ public double getValue();
+
+ public String getParameter();
+}
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/IntegerParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/IntegerParameter.java
new file mode 100755
index 000000000..517e8fe6b
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/IntegerParameter.java
@@ -0,0 +1,89 @@
+package moa.classifiers.meta.AutoML;
+
+import com.yahoo.labs.samoa.instances.Attribute;
+
+// the representation of an integer parameter
+public class IntegerParameter implements IParameter {
+ private String parameter;
+ private int value;
+ private int[] range;
+ private double std;
+ private Attribute attribute;
+ private boolean optimise;
+
+ public IntegerParameter(IntegerParameter x) {
+ this.parameter = x.parameter;
+ this.value = x.value;
+ this.attribute = x.attribute;// new Attribute(x.parameter);
+ this.optimise = x.optimise;
+
+ if(this.optimise){
+ this.range = x.range.clone();
+ this.std = x.std;
+ }
+ }
+
+ public IntegerParameter(ParameterConfiguration x) {
+ this.parameter = x.parameter;
+ this.value = (int)(double) x.value; // TODO fix casts
+
+ this.attribute = new Attribute(x.parameter);
+ this.optimise = x.optimise;
+
+ if(this.optimise){
+ this.range = new int[x.range.length];
+ for (int i = 0; i < x.range.length; i++) {
+ range[i] = (int) (double) x.range[i];
+ }
+ this.std = (this.range[1] - this.range[0]) / 2;
+ }
+ }
+
+ public IntegerParameter copy() {
+ return new IntegerParameter(this);
+ }
+
+ public String getCLIString() {
+ return ("-" + this.parameter + " " + this.value);
+ }
+
+ public String getCLIValueString() {
+ return ("" + this.value);
+ }
+
+ public double getValue() {
+ return this.value;
+ }
+
+ public void setValue(int value){
+ this.value = value;
+ }
+
+ public String getParameter() {
+ return this.parameter;
+ }
+
+ public void sampleNewConfig(double lambda, double reset, int verbose) {
+ if(!this.optimise){
+ return;
+ }
+ if (Math.random() < reset) {
+ this.std = (this.range[1] - this.range[0]) / 2;
+ }
+
+ // update configuration
+ // for integer features use truncated normal distribution
+ TruncatedNormal trncnormal = new TruncatedNormal(this.value, this.std, this.range[0], this.range[1]);
+ int newValue = (int) Math.round(trncnormal.sample());
+ if (verbose >= 3) {
+ System.out.println("Sample new configuration for integer parameter -" + this.parameter + " with mean: "
+ + this.value + ", std: " + this.std + ", lb: " + this.range[0] + ", ub: " + this.range[1]
+ + "\t=>\t -" + this.parameter + " " + newValue);
+ }
+
+ this.value = newValue;
+
+ // adapt distribution
+ this.std = this.std * Math.pow(2, -1 * lambda);
+ }
+}
\ No newline at end of file
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/NumericalParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/NumericalParameter.java
new file mode 100755
index 000000000..8d4fb6fbd
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/NumericalParameter.java
@@ -0,0 +1,93 @@
+package moa.classifiers.meta.AutoML;
+
+import com.yahoo.labs.samoa.instances.Attribute;
+
+// the representation of a numerical / real parameter
+public class NumericalParameter implements IParameter {
+ private String parameter;
+ private double value;
+ private double[] range;
+ private double std;
+ private Attribute attribute;
+ private boolean optimise;
+
+ public NumericalParameter(NumericalParameter x) {
+ this.parameter = x.parameter;
+ this.value = x.value;
+ this.attribute = new Attribute(x.parameter);
+ this.optimise = x.optimise;
+
+ if(this.optimise){
+ this.range = x.range.clone();
+ this.std = x.std;
+ }
+
+ }
+
+ public NumericalParameter(ParameterConfiguration x) {
+ this.parameter = x.parameter;
+ this.value = (double) x.value;
+ this.attribute = new Attribute(x.parameter);
+ this.optimise = x.optimise;
+
+ if(this.optimise){
+ this.range = new double[x.range.length];
+ for (int i = 0; i < x.range.length; i++) {
+ range[i] = (double) x.range[i];
+ }
+ this.std = (this.range[1] - this.range[0]) / 2;
+ }
+ }
+
+ public NumericalParameter copy() {
+ return new NumericalParameter(this);
+ }
+
+ public String getCLIString() {
+ return ("-" + this.parameter + " " + this.value);
+ }
+
+ public String getCLIValueString() {
+ return ("" + this.value);
+ }
+
+ public double getValue() {
+ return this.value;
+ }
+
+ public String getParameter() {
+ return this.parameter;
+ }
+
+ public void sampleNewConfig(double lambda, double reset, int verbose) {
+
+ if(!this.optimise){
+ return;
+ }
+
+ // trying to balanced exploitation vs exploration by resetting the std
+ if (Math.random() < reset) {
+ this.std = (this.range[1] - this.range[0]) / 2;
+ }
+ this.std = this.std * Math.pow(2, -1 * lambda);
+
+ // update configuration
+ // for numeric features use truncated normal distribution
+ TruncatedNormal trncnormal = new TruncatedNormal(this.value, this.std, this.range[0], this.range[1]);
+ double newValue = trncnormal.sample();
+
+ if (verbose >= 3) {
+ System.out.println("Sample new configuration for numerical parameter -" + this.parameter + " with mean: "
+ + this.value + ", std: " + this.std + ", lb: " + this.range[0] + ", ub: " + this.range[1]
+ + "\t=>\t -" + this.parameter + " " + newValue);
+ }
+
+ this.value = newValue;
+
+ // adapt distribution
+ // this.std = this.std * (Math.pow((1.0 / nbNewConfigurations), (1.0 /
+ // nbVariable)));
+
+
+ }
+}
\ No newline at end of file
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/OrdinalParameter.java b/moa/src/main/java/moa/classifiers/meta/AutoML/OrdinalParameter.java
new file mode 100755
index 000000000..9a445541e
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/OrdinalParameter.java
@@ -0,0 +1,98 @@
+package moa.classifiers.meta.AutoML;
+
+import com.yahoo.labs.samoa.instances.Attribute;
+
+// the representation of an integer parameter
+public class OrdinalParameter implements IParameter {
+ private String parameter;
+ private String value;
+ private int numericValue;
+ private String[] range;
+ private double std;
+ private Attribute attribute;
+ private boolean optimise;
+
+ // copy constructor
+ public OrdinalParameter(OrdinalParameter x) {
+ this.parameter = x.parameter;
+ this.value = x.value;
+ this.numericValue = x.numericValue;
+ this.attribute = x.attribute;
+ this.optimise = x.optimise;
+
+ if(this.optimise){
+ this.range = x.range.clone();
+ this.std = x.std;
+ }
+ }
+
+ // init constructor
+ public OrdinalParameter(ParameterConfiguration x) {
+ this.parameter = x.parameter;
+ this.value = String.valueOf(x.value);
+ this.attribute = new Attribute(x.parameter);
+ this.optimise = x.optimise;
+
+ if(this.optimise){
+ this.range = new String[x.range.length];
+ for (int i = 0; i < x.range.length; i++) {
+ range[i] = String.valueOf(x.range[i]);
+ if (this.range[i].equals(this.value)) {
+ this.numericValue = i; // get index of init value
+ }
+ }
+ this.std = (this.range.length - 0) / 2;
+ }
+
+ }
+
+ public OrdinalParameter copy() {
+ return new OrdinalParameter(this);
+ }
+
+ public String getCLIString() {
+ return ("-" + this.parameter + " " + this.value);
+ }
+
+ public String getCLIValueString() {
+ return ("" + this.value);
+ }
+
+ public double getValue() {
+ return this.numericValue;
+ }
+
+ public String getParameter() {
+ return this.parameter;
+ }
+
+ public void sampleNewConfig(double lambda, double reset, int verbose) {
+
+ if(!this.optimise){
+ return;
+ }
+
+ // update configuration
+ if (Math.random() < reset) {
+ double upper = (double) (this.range.length - 1);
+ this.std = upper / 2;
+ }
+
+ // treat index of range as integer parameter
+ TruncatedNormal trncnormal = new TruncatedNormal(this.numericValue, this.std, 0.0, (double) (this.range.length - 1)); // limits are the indexes of the range
+ int newValue = (int) Math.round(trncnormal.sample());
+
+ if (verbose >= 3) {
+ System.out.println("Sample new configuration for ordinal parameter -" + this.parameter + " with mean: "
+ + this.numericValue + ", std: " + this.std + ", lb: " + 0 + ", ub: " + (this.range.length - 1)
+ + "\t=>\t -" + this.parameter + " " + this.range[newValue] + " (" + newValue + ")");
+ }
+
+ this.numericValue = newValue;
+ this.value = this.range[this.numericValue];
+
+ // adapt distribution
+ this.std = this.std * Math.pow(2, -1 * lambda);
+ }
+
+}
\ No newline at end of file
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java b/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java
new file mode 100755
index 000000000..31d6224c6
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java
@@ -0,0 +1,56 @@
+package moa.classifiers.meta.AutoML;
+
+import umontreal.iro.lecuyer.probdist.NormalDist;
+import java.util.Random;
+
+public class TruncatedNormal {
+
+ double mean;
+ double sd;
+ double lb;
+ double ub;
+ double cdf_a;
+ double Z;
+
+ // https://en.wikipedia.org/wiki/Truncated_normal_distribution
+ TruncatedNormal(double mean, double sd, double lb, double ub){
+ this.mean = mean;
+ this.sd = sd;
+ this.lb = lb;
+ this.ub = ub;
+
+ this.cdf_a = NormalDist.cdf01((lb - mean)/sd);
+ double cdf_b = NormalDist.cdf01((ub - mean)/sd);
+ this.Z = cdf_b - cdf_a;
+ }
+
+ public double sample() {
+ //TODO This is the simple sampling strategy. Faster approaches are available
+ Random random = new Random();
+ double val = random.nextDouble() * Z + cdf_a;
+ return mean + sd * NormalDist.inverseF01(val);
+ }
+
+
+
+
+ public static void main(String[] args) {
+ TruncatedNormal trncnorm = new TruncatedNormal(0,10,-5,5);
+ double min=Double.MAX_VALUE;
+ double max=-Double.MAX_VALUE;
+ double sum=0.0;
+ for(int i=0; i<100000; i++) {
+ double val = trncnorm.sample();
+ sum += val;
+ if(val > max) max=val;
+ if(val < min) min=val;
+ System.out.println(val);
+ }
+
+ System.out.println("Min: " + min);
+ System.out.println("Max: " + max);
+ System.out.println("Mean: " + sum / 100000);
+
+ }
+
+}
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/settings.json b/moa/src/main/java/moa/classifiers/meta/AutoML/settings.json
new file mode 100755
index 000000000..50b1eda78
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/settings.json
@@ -0,0 +1,52 @@
+{
+
+ "windowSize" : 1000,
+ "ensembleSize" : 10,
+ "newConfigurations" : 10,
+ "keepCurrentModel" : true,
+ "lambda" : 0.05,
+ "preventAlgorithmDeath" : true,
+ "keepGlobalIncumbent" : true,
+ "keepAlgorithmIncumbents" : true,
+ "keepInitialConfigurations" : true,
+ "useTestEnsemble" : true,
+ "resetProbability" : 0.01,
+ "numberOfCores" : 1,
+ "performanceMeasureMaximisation": true,
+
+ "algorithms": [
+ {
+ "algorithm": "lazy.kNN",
+ "parameters": [
+ {"parameter": "k", "type":"integer", "value":10, "range":[2,30]}
+ // {"parameter": "w", "type":"integer", "value":1000, "range":[500,2000]}
+ ]
+ }
+ ,
+ {
+ "algorithm": "trees.HoeffdingTree",
+ "parameters": [
+ {"parameter": "g", "type":"integer", "value":200, "range":[10, 200]},
+ {"parameter": "c", "type":"float", "value":0.01, "range":[0, 1]}
+ ]
+ }
+ ,
+ {
+ "algorithm": "lazy.kNNwithPAWandADWIN",
+ "parameters": [
+ {"parameter": "k", "type":"integer", "value":10, "range":[2,30]}
+ // {"parameter": "w", "type":"numeric", "value":1000, "range":[1000,5000]},
+ //{"parameter": "m", "type":"boolean", "value":"false"}
+ ]
+ }
+ ,
+ {
+ "algorithm": "trees.HoeffdingAdaptiveTree",
+ "parameters": [
+ {"parameter": "g", "type":"integer", "value":200, "range":[10, 200]},
+ {"parameter": "c", "type":"float", "value":0.01, "range":[0, 1]}
+ ]
+ }
+
+ ]
+}
From eb132c7c00ac7c053e96679757f0fb045663c38b Mon Sep 17 00:00:00 2001
From: Heitor Murilo Gomes
Date: Sat, 7 Sep 2024 06:19:31 +1200
Subject: [PATCH 25/31] update to autoclass
---
.../java/moa/classifiers/lazy/RW_kNN.java | 4 ++--
.../classifiers/meta/AutoML/Algorithm.java | 2 +-
.../classifiers/meta/AutoML/AutoClass.java | 5 ++++-
.../meta/AutoML/TruncatedNormal.java | 20 +++++++++++++------
4 files changed, 21 insertions(+), 10 deletions(-)
diff --git a/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java b/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java
index 4d9601502..09ad98294 100755
--- a/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java
+++ b/moa/src/main/java/moa/classifiers/lazy/RW_kNN.java
@@ -99,7 +99,7 @@ public void trainOnInstanceImpl(Instance inst) {
this.reservoir = new Instances(inst.dataset());
}
if (this.limitOptionReservoir.getValue() <= this.reservoir.numInstances()) {
- int replaceIndex = r.nextInt(this.limitOptionReservoir.getValue() - 1);
+ int replaceIndex = this.classifierRandom.nextInt(this.limitOptionReservoir.getValue() - 1);
this.reservoir.set(replaceIndex, inst);
} else
this.reservoir.add(inst);
@@ -155,6 +155,6 @@ public void getModelDescription(StringBuilder out, int indent) {
}
public boolean isRandomizable() {
- return false;
+ return true;
}
}
\ No newline at end of file
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/Algorithm.java b/moa/src/main/java/moa/classifiers/meta/AutoML/Algorithm.java
index 4add3725f..441b68f59 100755
--- a/moa/src/main/java/moa/classifiers/meta/AutoML/Algorithm.java
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/Algorithm.java
@@ -163,7 +163,7 @@ public void adjustAlgorithm(boolean keepCurrentModel, int verbose) {
// these changes do not transfer over directly since all algorithms cache the
// option values. Therefore we try to adjust the cached values if possible
try {
- ((AbstractClassifier) this.classifier).adjustParameters();
+// ((AbstractClassifier) this.classifier).adjustParameters();
if (verbose >= 2) {
System.out.println("Changed: " + this.classifier.getCLICreationString(Classifier.class));
}
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java
index c9b39aa68..4702989fa 100755
--- a/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java
@@ -256,7 +256,8 @@ protected void evaluatePerformance() {
+ "\t => \t performance: " + performance);
}
- String algorithm = this.ensemble.get(i).classifier.getPurposeString();
+// String algorithm = this.ensemble.get(i).classifier.getPurposeString();
+ String algorithm = this.ensemble.get(i).classifier.getClass().getName();
if (!bestPerformanceIdxMap.containsKey(algorithm) || performance > bestPerformanceValMap.get(algorithm)) {
bestPerformanceValMap.put(algorithm, performance); // best performance per algorithm
bestPerformanceIdxMap.put(algorithm, i); // index of best performance per algorithm
@@ -695,3 +696,5 @@ public Integer call() throws Exception {
return 0;
}
}
+ } // Close the EnsembleRunnable class
+
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java b/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java
index 31d6224c6..192ae1577 100755
--- a/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/TruncatedNormal.java
@@ -1,6 +1,7 @@
package moa.classifiers.meta.AutoML;
-import umontreal.iro.lecuyer.probdist.NormalDist;
+import org.apache.commons.math3.distribution.NormalDistribution;
+//import umontreal.iro.lecuyer.probdist.NormalDist;
import java.util.Random;
public class TruncatedNormal {
@@ -11,16 +12,22 @@ public class TruncatedNormal {
double ub;
double cdf_a;
double Z;
-
+
+ private NormalDistribution normalDist;
+
// https://en.wikipedia.org/wiki/Truncated_normal_distribution
TruncatedNormal(double mean, double sd, double lb, double ub){
this.mean = mean;
this.sd = sd;
this.lb = lb;
this.ub = ub;
-
- this.cdf_a = NormalDist.cdf01((lb - mean)/sd);
- double cdf_b = NormalDist.cdf01((ub - mean)/sd);
+
+
+ normalDist = new NormalDistribution(0, 1); // Standard normal distribution
+ this.cdf_a = normalDist.cumulativeProbability((lb - mean) / sd);
+ double cdf_b = normalDist.cumulativeProbability((ub - mean) / sd);
+// this.cdf_a = NormalDist.cdf01((lb - mean)/sd);
+// double cdf_b = NormalDist.cdf01((ub - mean)/sd);
this.Z = cdf_b - cdf_a;
}
@@ -28,7 +35,8 @@ public double sample() {
//TODO This is the simple sampling strategy. Faster approaches are available
Random random = new Random();
double val = random.nextDouble() * Z + cdf_a;
- return mean + sd * NormalDist.inverseF01(val);
+ return mean + sd * normalDist.inverseCumulativeProbability(val);
+// return mean + sd * NormalDistribution.inverseF01(val);
}
From 4cb717680f65ab7f9a29e4c7b48dbca7a0f242e5 Mon Sep 17 00:00:00 2001
From: Heitor Murilo Gomes
Date: Sat, 7 Sep 2024 22:23:30 +1200
Subject: [PATCH 26/31] fix: update to make it compatible with capymoa
---
.../classifiers/meta/AutoML/AutoClass.java | 21 +++++++++++++++++++
.../AutoML/HeterogeneousEnsembleAbstract.java | 12 ++++++-----
2 files changed, 28 insertions(+), 5 deletions(-)
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java
index 4702989fa..94f151fda 100755
--- a/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/AutoClass.java
@@ -5,6 +5,7 @@
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
+import com.yahoo.labs.samoa.instances.InstancesHeader;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.meta.AdaptiveRandomForestRegressor;
@@ -44,6 +45,13 @@ public class AutoClass extends HeterogeneousEnsembleAbstract implements MultiCla
public FileOption fileOption = new FileOption("ConfigurationFile", 'f', "Configuration file in json format.",
"/Users/mbahri/Desktop/Dell/moa/src/main/java/moa/classifiers/meta/AutoML/settings.json", ".json", false);
+ @Override
+ public String getPurposeString() {
+ return "Autoclass: Automl for data stream classification. "+
+ "Bahri, Maroua, and Nikolaos Georgantas. " +
+ "2023 IEEE International Conference on Big Data (BigData). IEEE, 2023.";
+ }
+
public void init() {
this.fileOption.getFile();
}
@@ -631,6 +639,19 @@ protected Measurement[] getModelMeasurementsImpl() {
public void getModelDescription(StringBuilder out, int indent) {
// TODO Auto-generated method stub
}
+
+ @Override
+ public void setModelContext(InstancesHeader ih) {
+ super.setModelContext(ih);
+
+// This will cause issues in case setModelContext is invoked before the ensemble has been created.
+// It is likely safe to not perform this action due to how the context can be acquired later by the learners.
+// However, it is worth reviewing this in the future. See also HeterogeneousEnsembleAbstract.setModelContext
+// for (int i = 0; i < this.ensemble.size(); ++i) {
+// this.ensemble.get(i).classifier.setModelContext(ih);
+// }
+ }
+
@Override
public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
diff --git a/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java b/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java
index 7a17f8d64..3bcbdbafe 100755
--- a/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java
+++ b/moa/src/main/java/moa/classifiers/meta/AutoML/HeterogeneousEnsembleAbstract.java
@@ -84,7 +84,7 @@ public abstract class HeterogeneousEnsembleAbstract extends AbstractClassifier i
',');
public IntOption gracePerionOption = new IntOption("gracePeriod", 'g',
- "How many instances before we reevalate the best classifier", 1, 1,
+ "How many instances before we re-evaluate the best classifier", 1, 1,
Integer.MAX_VALUE);
public IntOption activeClassifiersOption = new IntOption("activeClassifiers",
@@ -143,10 +143,12 @@ public double[] getVotesForInstance(Instance inst) {
@Override
public void setModelContext(InstancesHeader ih) {
super.setModelContext(ih);
-
- for (int i = 0; i < this.ensemble.length; ++i) {
- this.ensemble[i].setModelContext(ih);
- }
+// This will cause issues in case setModelContext is invoked before the ensemble has been created.
+// It is likely safe to not perform this action due to how the context can be acquired later by the learners.
+// However, it is worth reviewing this in the future. See also AutoClass.setModelContext
+// for (int i = 0; i < this.ensemble.length; ++i) {
+// this.ensemble[i].setModelContext(ih);
+// }
}
@Override
From e658b0268f2297448436b78afe2df9758bb8a7bb Mon Sep 17 00:00:00 2001
From: Spencer Sun
Date: Wed, 19 Feb 2025 11:53:43 +1300
Subject: [PATCH 27/31] fix: debugging SRP_MB and TstThnTrn class
---
.../classifiers/meta/minibatch/StreamingRandomPatchesMB.java | 1 +
.../main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java | 4 ++++
2 files changed, 5 insertions(+)
diff --git a/moa/src/main/java/moa/classifiers/meta/minibatch/StreamingRandomPatchesMB.java b/moa/src/main/java/moa/classifiers/meta/minibatch/StreamingRandomPatchesMB.java
index b5bb9a6e3..1c151c848 100644
--- a/moa/src/main/java/moa/classifiers/meta/minibatch/StreamingRandomPatchesMB.java
+++ b/moa/src/main/java/moa/classifiers/meta/minibatch/StreamingRandomPatchesMB.java
@@ -287,6 +287,7 @@ protected void initEnsemble(Instance instance) {
break;
case StreamingRandomPatchesMB.TRAIN_RANDOM_SUBSPACES:
case StreamingRandomPatchesMB.TRAIN_RANDOM_PATCHES:
+ if (this.subspaces.isEmpty()) break;
int selectedValue = this.classifierRandom.nextInt(subspaces.size());
ArrayList subsetOfFeatures = this.subspaces.get(selectedValue);
subsetOfFeatures.add(instance.classIndex());
diff --git a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java
index e0c69a5bd..e16d43791 100644
--- a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java
+++ b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java
@@ -25,6 +25,7 @@
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
+import moa.classifiers.AbstractClassifierMiniBatch;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.Example;
@@ -217,6 +218,9 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
if (immediateResultStream != null) {
immediateResultStream.close();
}
+ if (learner instanceof AbstractClassifierMiniBatch) {
+ ((AbstractClassifierMiniBatch) learner).trainingHasEnded();
+ }
return learningCurve;
}
From f4c4524ccf690f83c920930da7b4d44a79785967 Mon Sep 17 00:00:00 2001
From: heymarco
Date: Wed, 12 Mar 2025 22:37:28 +0100
Subject: [PATCH 28/31] Added PLASTIC and the necessary adaptations to existing
classes (#23)
* Added PLASTIC and the necessary adaptations to existing classes
* Added PLASTIC and PLASTIC-A. Now also supports NB and NBA as leaf prediction options
---
...GaussianNumericAttributeClassObserver.java | 18 +-
.../NominalAttributeClassObserver.java | 33 +-
.../NominalAttributeBinaryTest.java | 6 +-
.../NumericAttributeBinaryTest.java | 10 +-
.../main/java/moa/classifiers/trees/EFDT.java | 2716 ++++++++---------
.../java/moa/classifiers/trees/PLASTIC.java | 192 ++
.../java/moa/classifiers/trees/PLASTICA.java | 192 ++
.../CustomADWINChangeDetector.java | 27 +
.../trees/plastic_util/CustomEFDTNode.java | 780 +++++
.../trees/plastic_util/CustomHTNode.java | 80 +
.../trees/plastic_util/EFHATNode.java | 174 ++
.../trees/plastic_util/MappedTree.java | 409 +++
.../plastic_util/MeasuresNumberOfLeaves.java | 5 +
.../plastic_util/PerformsTreeRevision.java | 5 +
.../trees/plastic_util/PlasticBranch.java | 63 +
.../trees/plastic_util/PlasticNode.java | 398 +++
.../plastic_util/PlasticTreeElement.java | 40 +
.../trees/plastic_util/Restructurer.java | 315 ++
.../plastic_util/SuccessorIdentifier.java | 149 +
.../trees/plastic_util/Successors.java | 221 ++
20 files changed, 4457 insertions(+), 1376 deletions(-)
create mode 100644 moa/src/main/java/moa/classifiers/trees/PLASTIC.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/PLASTICA.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/CustomADWINChangeDetector.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/CustomEFDTNode.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/CustomHTNode.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/EFHATNode.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/MappedTree.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/MeasuresNumberOfLeaves.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/PerformsTreeRevision.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticBranch.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticNode.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/PlasticTreeElement.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/Restructurer.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/SuccessorIdentifier.java
create mode 100644 moa/src/main/java/moa/classifiers/trees/plastic_util/Successors.java
diff --git a/moa/src/main/java/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java b/moa/src/main/java/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java
index dbc0e0e17..a727c4066 100644
--- a/moa/src/main/java/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java
+++ b/moa/src/main/java/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java
@@ -15,7 +15,7 @@
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
- *
+ *
*/
package moa.classifiers.core.attributeclassobservers;
@@ -81,7 +81,7 @@ public void observeAttributeClass(double attVal, int classVal, double weight) {
@Override
public double probabilityOfAttributeValueGivenClass(double attVal,
- int classVal) {
+ int classVal) {
GaussianEstimator obs = this.attValDistPerClass.get(classVal);
return obs != null ? obs.probabilityDensity(attVal) : 0.0;
}
@@ -99,12 +99,24 @@ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
if ((bestSuggestion == null) || (merit > bestSuggestion.merit)) {
bestSuggestion = new AttributeSplitSuggestion(
new NumericAttributeBinaryTest(attIndex, splitValue,
- true), postSplitDists, merit);
+ true), postSplitDists, merit);
}
}
return bestSuggestion;
}
+ /* Used by PLASTIC during restructuring when forcing a leaf split becomes necessary */
+ public AttributeSplitSuggestion forceSplit(
+ SplitCriterion criterion, double[] preSplitDist, int attIndex, double threshold) {
+ AttributeSplitSuggestion bestSuggestion = null;
+ double[][] postSplitDists = getClassDistsResultingFromBinarySplit(threshold);
+ double merit = criterion.getMeritOfSplit(preSplitDist,
+ postSplitDists);
+ bestSuggestion = new AttributeSplitSuggestion(
+ new NumericAttributeBinaryTest(attIndex, threshold, true), postSplitDists, merit);
+ return bestSuggestion;
+ }
+
public double[] getSplitPointSuggestions() {
Set suggestedSplitValues = new TreeSet();
double minValue = Double.POSITIVE_INFINITY;
diff --git a/moa/src/main/java/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java b/moa/src/main/java/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java
index ce8021017..c5de2406e 100644
--- a/moa/src/main/java/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java
+++ b/moa/src/main/java/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java
@@ -15,7 +15,7 @@
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
- *
+ *
*/
package moa.classifiers.core.attributeclassobservers;
@@ -68,7 +68,7 @@ public void observeAttributeClass(double attVal, int classVal, double weight) {
@Override
public double probabilityOfAttributeValueGivenClass(double attVal,
- int classVal) {
+ int classVal) {
DoubleVector obs = this.attValDistPerClass.get(classVal);
return obs != null ? (obs.getValue((int) attVal) + 1.0)
/ (obs.sumOfValues() + obs.numValues()) : 0.0;
@@ -109,6 +109,33 @@ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
return bestSuggestion;
}
+ /* Used by PLASTIC during restructuring when forcing a leaf split becomes necessary */
+ public AttributeSplitSuggestion forceSplit(
+ SplitCriterion criterion, double[] preSplitDist, int attIndex, boolean binary, Double splitValue) {
+ AttributeSplitSuggestion bestSuggestion;
+ int maxAttValsObserved = getMaxAttValsObserved();
+ if (!binary) {
+ double[][] postSplitDists = getClassDistsResultingFromMultiwaySplit(maxAttValsObserved);
+ double merit = criterion.getMeritOfSplit(preSplitDist,
+ postSplitDists);
+ bestSuggestion = new AttributeSplitSuggestion(
+ new NominalAttributeMultiwayTest(attIndex), postSplitDists,
+ merit);
+ return bestSuggestion;
+ }
+ assert splitValue != null: "Split value is null";
+ if (splitValue >= maxAttValsObserved) {
+ return null;
+ }
+ double[][] postSplitDists = getClassDistsResultingFromBinarySplit(splitValue.intValue());
+ double merit = criterion.getMeritOfSplit(preSplitDist,
+ postSplitDists);
+ bestSuggestion = new AttributeSplitSuggestion(
+ new NominalAttributeBinaryTest(attIndex, splitValue.intValue()),
+ postSplitDists, merit);
+ return bestSuggestion;
+ }
+
public int getMaxAttValsObserved() {
int maxAttValsObserved = 0;
for (DoubleVector attValDist : this.attValDistPerClass) {
@@ -157,7 +184,7 @@ public double[][] getClassDistsResultingFromBinarySplit(int valIndex) {
}
}
return new double[][]{equalsDist.getArrayRef(),
- notEqualDist.getArrayRef()};
+ notEqualDist.getArrayRef()};
}
@Override
diff --git a/moa/src/main/java/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java b/moa/src/main/java/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java
index 73498ff6b..397d308b4 100644
--- a/moa/src/main/java/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java
+++ b/moa/src/main/java/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java
@@ -15,7 +15,7 @@
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
- *
+ *
*/
package moa.classifiers.core.conditionaltests;
@@ -48,6 +48,10 @@ public int branchForInstance(Instance inst) {
return inst.isMissing(instAttIndex) ? -1 : ((int) inst.value(instAttIndex) == this.attValue ? 0 : 1);
}
+ public double getValue() {
+ return attValue;
+ }
+
@Override
public String describeConditionForBranch(int branch, InstancesHeader context) {
if ((branch == 0) || (branch == 1)) {
diff --git a/moa/src/main/java/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java b/moa/src/main/java/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java
index d8cb5e8c3..5acb11610 100644
--- a/moa/src/main/java/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java
+++ b/moa/src/main/java/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java
@@ -15,7 +15,7 @@
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
- *
+ *
*/
package moa.classifiers.core.conditionaltests;
@@ -39,16 +39,20 @@ public class NumericAttributeBinaryTest extends InstanceConditionalBinaryTest {
protected boolean equalsPassesTest;
public NumericAttributeBinaryTest(int attIndex, double attValue,
- boolean equalsPassesTest) {
+ boolean equalsPassesTest) {
this.attIndex = attIndex;
this.attValue = attValue;
this.equalsPassesTest = equalsPassesTest;
}
+ public double getValue() {
+ return attValue;
+ }
+
@Override
public int branchForInstance(Instance inst) {
int instAttIndex = this.attIndex ; // < inst.classIndex() ? this.attIndex
- // : this.attIndex + 1;
+ // : this.attIndex + 1;
if (inst.isMissing(instAttIndex)) {
return -1;
}
diff --git a/moa/src/main/java/moa/classifiers/trees/EFDT.java b/moa/src/main/java/moa/classifiers/trees/EFDT.java
index d900ed250..314c6adfc 100644
--- a/moa/src/main/java/moa/classifiers/trees/EFDT.java
+++ b/moa/src/main/java/moa/classifiers/trees/EFDT.java
@@ -78,1538 +78,1522 @@
public class EFDT extends AbstractClassifier implements MultiClassClassifier {
- private static final long serialVersionUID = 2L;
-
- public IntOption reEvalPeriodOption = new IntOption(
- "reevaluationPeriod",
- 'R',
- "The number of instances an internal node should observe between re-evaluation attempts.",
- 2000, 0, Integer.MAX_VALUE);
-
- public IntOption maxByteSizeOption = new IntOption("maxByteSize", 'm',
- "Maximum memory consumed by the tree.", 33554432, 0,
- Integer.MAX_VALUE);
-
- /*
- * public MultiChoiceOption numericEstimatorOption = new MultiChoiceOption(
- * "numericEstimator", 'n', "Numeric estimator to use.", new String[]{
- * "GAUSS10", "GAUSS100", "GK10", "GK100", "GK1000", "VFML10", "VFML100",
- * "VFML1000", "BINTREE"}, new String[]{ "Gaussian approximation evaluating
- * 10 splitpoints", "Gaussian approximation evaluating 100 splitpoints",
- * "Greenwald-Khanna quantile summary with 10 tuples", "Greenwald-Khanna
- * quantile summary with 100 tuples", "Greenwald-Khanna quantile summary
- * with 1000 tuples", "VFML method with 10 bins", "VFML method with 100
- * bins", "VFML method with 1000 bins", "Exhaustive binary tree"}, 0);
- */
- public ClassOption numericEstimatorOption = new ClassOption("numericEstimator",
- 'n', "Numeric estimator to use.", NumericAttributeClassObserver.class,
- "GaussianNumericAttributeClassObserver");
-
- public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator",
- 'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class,
- "NominalAttributeClassObserver");
-
- public IntOption memoryEstimatePeriodOption = new IntOption(
- "memoryEstimatePeriod", 'e',
- "How many instances between memory consumption checks.", 1000000,
- 0, Integer.MAX_VALUE);
-
- public IntOption gracePeriodOption = new IntOption(
- "gracePeriod",
- 'g',
- "The number of instances a leaf should observe between split attempts.",
- 200, 0, Integer.MAX_VALUE);
-
- public ClassOption splitCriterionOption = new ClassOption("splitCriterion",
- 's', "Split criterion to use.", SplitCriterion.class,
- "InfoGainSplitCriterion");
-
- public FloatOption splitConfidenceOption = new FloatOption(
- "splitConfidence",
- 'c',
- "The allowable error in split decision, values closer to 0 will take longer to decide.",
- 0.0000001, 0.0, 1.0);
-
- public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
- 't', "Threshold below which a split will be forced to break ties.",
- 0.05, 0.0, 1.0);
-
- public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b',
- "Only allow binary splits.");
-
- public FlagOption stopMemManagementOption = new FlagOption(
- "stopMemManagement", 'z',
- "Stop growing as soon as memory limit is hit.");
-
- public FlagOption removePoorAttsOption = new FlagOption("removePoorAtts",
- 'r', "Disable poor attributes.");
-
- public FlagOption noPrePruneOption = new FlagOption("noPrePrune", 'p',
- "Disable pre-pruning.");
-
- public MultiChoiceOption leafpredictionOption = new MultiChoiceOption(
- "leafprediction", 'l', "Leaf prediction to use.", new String[]{
- "MC", "NB", "NBAdaptive"}, new String[]{
- "Majority class",
- "Naive Bayes",
- "Naive Bayes Adaptive"}, 2);
-
- public IntOption nbThresholdOption = new IntOption(
- "nbThreshold",
- 'q',
- "The number of instances a leaf should observe before permitting Naive Bayes.",
- 0, 0, Integer.MAX_VALUE);
-
- protected Node treeRoot = null;
-
- protected int decisionNodeCount;
-
- protected int activeLeafNodeCount;
-
- protected int inactiveLeafNodeCount;
-
- protected double inactiveLeafByteSizeEstimate;
-
- protected double activeLeafByteSizeEstimate;
-
- protected double byteSizeEstimateOverheadFraction;
-
- protected boolean growthAllowed;
-
- protected int numInstances = 0;
-
- protected int splitCount = 0;
-
- @Override
- public String getPurposeString() {
- return "Hoeffding Tree or VFDT.";
- }
-
- public long calcByteSize() {
- long size = SizeOf.sizeOf(this);
- if (this.treeRoot != null) {
- size += this.treeRoot.calcByteSizeIncludingSubtree();
- }
- return size;
- }
-
- @Override
- public long measureByteSize() {
- return calcByteSize();
- }
-
- @Override
- public void resetLearningImpl() {
- this.treeRoot = null;
- this.decisionNodeCount = 0;
- this.activeLeafNodeCount = 0;
- this.inactiveLeafNodeCount = 0;
- this.inactiveLeafByteSizeEstimate = 0.0;
- this.activeLeafByteSizeEstimate = 0.0;
- this.byteSizeEstimateOverheadFraction = 1.0;
- this.growthAllowed = true;
- if (this.leafpredictionOption.getChosenIndex() > 0) {
- this.removePoorAttsOption = null;
- }
- }
-
- @Override
- public double[] getVotesForInstance(Instance inst) {
- if (this.treeRoot != null) {
- FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst,
- null, -1);
- Node leafNode = foundNode.node;
- if (leafNode == null) {
- leafNode = foundNode.parent;
- }
- return leafNode.getClassVotes(inst, this);
- }
- else {
- int numClasses = inst.dataset().numClasses();
- return new double[numClasses];
- }
- }
-
- @Override
- protected Measurement[] getModelMeasurementsImpl() {
- FoundNode[] learningNodes = findLearningNodes();
-
- return new Measurement[]{
-
- new Measurement("tree size (nodes)", this.decisionNodeCount
- + this.activeLeafNodeCount + this.inactiveLeafNodeCount),
- new Measurement("tree size (leaves)", learningNodes.length),
- new Measurement("active learning leaves",
- this.activeLeafNodeCount),
- new Measurement("tree depth", measureTreeDepth()),
- new Measurement("active leaf byte size estimate",
- this.activeLeafByteSizeEstimate),
- new Measurement("inactive leaf byte size estimate",
- this.inactiveLeafByteSizeEstimate),
- new Measurement("byte size estimate overhead",
- this.byteSizeEstimateOverheadFraction),
- new Measurement("splits",
- this.splitCount)};
- }
-
- public int measureTreeDepth() {
- if (this.treeRoot != null) {
- return this.treeRoot.subtreeDepth();
- }
- return 0;
- }
-
- @Override
- public void getModelDescription(StringBuilder out, int indent) {
- this.treeRoot.describeSubtree(this, out, indent);
- }
-
- @Override
- public boolean isRandomizable() {
- return false;
- }
-
- public static double computeHoeffdingBound(double range, double confidence,
- double n) {
- return Math.sqrt(((range * range) * Math.log(1.0 / confidence))
- / (2.0 * n));
- }
-
- protected AttributeClassObserver newNominalClassObserver() {
- AttributeClassObserver nominalClassObserver = (AttributeClassObserver) getPreparedClassOption(this.nominalEstimatorOption);
- return (AttributeClassObserver) nominalClassObserver.copy();
- }
-
- protected AttributeClassObserver newNumericClassObserver() {
- AttributeClassObserver numericClassObserver = (AttributeClassObserver) getPreparedClassOption(this.numericEstimatorOption);
- return (AttributeClassObserver) numericClassObserver.copy();
- }
-
- public void enforceTrackerLimit() {
- if ((this.inactiveLeafNodeCount > 0)
- || ((this.activeLeafNodeCount * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount
- * this.inactiveLeafByteSizeEstimate)
- * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue())) {
- if (this.stopMemManagementOption.isSet()) {
- this.growthAllowed = false;
- return;
- }
- FoundNode[] learningNodes = findLearningNodes();
- Arrays.sort(learningNodes, new Comparator() {
-
- @Override
- public int compare(FoundNode fn1, FoundNode fn2) {
- return Double.compare(fn1.node.calculatePromise(), fn2.node.calculatePromise());
- }
- });
- int maxActive = 0;
- while (maxActive < learningNodes.length) {
- maxActive++;
- if ((maxActive * this.activeLeafByteSizeEstimate + (learningNodes.length - maxActive)
- * this.inactiveLeafByteSizeEstimate)
- * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue()) {
- maxActive--;
- break;
- }
- }
- int cutoff = learningNodes.length - maxActive;
- for (int i = 0; i < cutoff; i++) {
- if (learningNodes[i].node instanceof ActiveLearningNode) {
- deactivateLearningNode(
- (ActiveLearningNode) learningNodes[i].node,
- learningNodes[i].parent,
- learningNodes[i].parentBranch);
- }
- }
- for (int i = cutoff; i < learningNodes.length; i++) {
- if (learningNodes[i].node instanceof InactiveLearningNode) {
- activateLearningNode(
- (InactiveLearningNode) learningNodes[i].node,
- learningNodes[i].parent,
- learningNodes[i].parentBranch);
- }
- }
- }
- }
-
- public void estimateModelByteSizes() {
- FoundNode[] learningNodes = findLearningNodes();
- long totalActiveSize = 0;
- long totalInactiveSize = 0;
- for (FoundNode foundNode : learningNodes) {
- if (foundNode.node instanceof ActiveLearningNode) {
- totalActiveSize += SizeOf.fullSizeOf(foundNode.node);
- }
- else {
- totalInactiveSize += SizeOf.fullSizeOf(foundNode.node);
- }
- }
- if (totalActiveSize > 0) {
- this.activeLeafByteSizeEstimate = (double) totalActiveSize
- / this.activeLeafNodeCount;
- }
- if (totalInactiveSize > 0) {
- this.inactiveLeafByteSizeEstimate = (double) totalInactiveSize
- / this.inactiveLeafNodeCount;
- }
- long actualModelSize = this.measureByteSize();
- double estimatedModelSize = (this.activeLeafNodeCount
- * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount
- * this.inactiveLeafByteSizeEstimate);
- this.byteSizeEstimateOverheadFraction = actualModelSize
- / estimatedModelSize;
- if (actualModelSize > this.maxByteSizeOption.getValue()) {
- enforceTrackerLimit();
- }
- }
-
- public void deactivateAllLeaves() {
- FoundNode[] learningNodes = findLearningNodes();
- for (FoundNode learningNode : learningNodes) {
- if (learningNode.node instanceof ActiveLearningNode) {
- deactivateLearningNode(
- (ActiveLearningNode) learningNode.node,
- learningNode.parent, learningNode.parentBranch);
- }
- }
- }
-
- protected void deactivateLearningNode(ActiveLearningNode toDeactivate,
- SplitNode parent, int parentBranch) {
- Node newLeaf = new InactiveLearningNode(toDeactivate.getObservedClassDistribution());
- if (parent == null) {
- this.treeRoot = newLeaf;
- }
- else {
- parent.setChild(parentBranch, newLeaf);
- }
- this.activeLeafNodeCount--;
- this.inactiveLeafNodeCount++;
- }
-
- protected void activateLearningNode(InactiveLearningNode toActivate,
- SplitNode parent, int parentBranch) {
- Node newLeaf = newLearningNode(toActivate.getObservedClassDistribution());
- if (parent == null) {
- this.treeRoot = newLeaf;
- }
- else {
- parent.setChild(parentBranch, newLeaf);
- }
- this.activeLeafNodeCount++;
- this.inactiveLeafNodeCount--;
- }
-
- protected FoundNode[] findLearningNodes() {
- List foundList = new LinkedList<>();
- findLearningNodes(this.treeRoot, null, -1, foundList);
- return foundList.toArray(new FoundNode[foundList.size()]);
- }
-
- protected void findLearningNodes(Node node, SplitNode parent,
- int parentBranch, List found) {
- if (node != null) {
- if (node instanceof LearningNode) {
- found.add(new FoundNode(node, parent, parentBranch));
- }
- if (node instanceof SplitNode) {
- SplitNode splitNode = (SplitNode) node;
- for (int i = 0; i < splitNode.numChildren(); i++) {
- findLearningNodes(splitNode.getChild(i), splitNode, i,
- found);
- }
- }
- }
- }
-
- protected void attemptToSplit(ActiveLearningNode node, SplitNode parent,
- int parentIndex) {
-
- if (!node.observedClassDistributionIsPure()) {
- node.addToSplitAttempts(1); // even if we don't actually attempt to split, we've computed infogains
-
-
- SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(this.splitCriterionOption);
- AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, this);
- Arrays.sort(bestSplitSuggestions);
- boolean shouldSplit = false;
-
- for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) {
-
- if (bestSplitSuggestion.splitTest != null) {
- if (!node.getInfogainSum().containsKey((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]))) {
- node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), 0.0);
- }
- double currentSum = node.getInfogainSum().get((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]));
- node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestion.merit);
- }
-
- else { // handle the null attribute
- double currentSum = node.getInfogainSum().get(-1); // null split
- node.getInfogainSum().put(-1, Math.max(0.0, currentSum + bestSplitSuggestion.merit));
- assert node.getInfogainSum().get(-1) >= 0.0 : "Negative infogain shouldn't be possible here.";
- }
-
- }
-
- if (bestSplitSuggestions.length < 2) {
- shouldSplit = bestSplitSuggestions.length > 0;
- }
-
- else {
- double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()),
- this.splitConfidenceOption.getValue(), node.getWeightSeen());
- AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
-
- double bestSuggestionAverageMerit;
- double currentAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
-
- // because this is an unsplit leaf. current average merit should be always zero on the null split.
-
- if (bestSuggestion.splitTest == null) { // if you have a null split
- bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
- }
- else {
- bestSuggestionAverageMerit = node.getInfogainSum().get((bestSuggestion.splitTest.getAttsTestDependsOn()[0])) / node.getNumSplitAttempts();
- }
-
- if (bestSuggestion.merit < 1e-10) {
- shouldSplit = false; // we don't use average here
- }
-
- else if ((bestSuggestionAverageMerit - currentAverageMerit) >
- hoeffdingBound
- || (hoeffdingBound < this.tieThresholdOption.getValue())) {
- if (bestSuggestionAverageMerit - currentAverageMerit < hoeffdingBound) {
- // Placeholder to list this possibility
- }
- shouldSplit = true;
- }
-
- if (shouldSplit) {
- for (Integer i : node.usedNominalAttributes) {
- if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i) {
- shouldSplit = false;
- break;
- }
- }
- }
-
- // }
- if ((this.removePoorAttsOption != null)
- && this.removePoorAttsOption.isSet()) {
- Set poorAtts = new HashSet<>();
- // scan 1 - add any poor to set
- for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) {
- if (bestSplitSuggestion.splitTest != null) {
- int[] splitAtts = bestSplitSuggestion.splitTest.getAttsTestDependsOn();
- if (splitAtts.length == 1) {
- if (bestSuggestion.merit
- - bestSplitSuggestion.merit > hoeffdingBound) {
- poorAtts.add(splitAtts[0]);
- }
- }
- }
- }
- // scan 2 - remove good ones from set
- for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) {
- if (bestSplitSuggestion.splitTest != null) {
- int[] splitAtts = bestSplitSuggestion.splitTest.getAttsTestDependsOn();
- if (splitAtts.length == 1) {
- if (bestSuggestion.merit
- - bestSplitSuggestion.merit < hoeffdingBound) {
- poorAtts.remove(splitAtts[0]);
- }
- }
- }
- }
- for (int poorAtt : poorAtts) {
- node.disableAttribute(poorAtt);
- }
- }
- }
- if (shouldSplit) {
- splitCount++;
-
- AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1];
- if (splitDecision.splitTest == null) {
- // preprune - null wins
- deactivateLearningNode(node, parent, parentIndex);
- }
- else {
- Node newSplit = newSplitNode(splitDecision.splitTest,
- node.getObservedClassDistribution(), splitDecision.numSplits());
- ((EFDTSplitNode) newSplit).attributeObservers = node.attributeObservers; // copy the attribute observers
- newSplit.setInfogainSum(node.getInfogainSum()); // transfer infogain history, leaf to split
-
- for (int i = 0; i < splitDecision.numSplits(); i++) {
-
- double[] j = splitDecision.resultingClassDistributionFromSplit(i);
-
- Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i));
-
- if (splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class
- || splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class) {
- newChild.usedNominalAttributes = new ArrayList<>(node.usedNominalAttributes); //deep copy
- newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]);
- // no nominal attribute should be split on more than once in the path
- }
- ((EFDTSplitNode) newSplit).setChild(i, newChild);
- }
- this.activeLeafNodeCount--;
- this.decisionNodeCount++;
- this.activeLeafNodeCount += splitDecision.numSplits();
- if (parent == null) {
- this.treeRoot = newSplit;
- }
- else {
- parent.setChild(parentIndex, newSplit);
- }
-
- }
- // manage memory
- enforceTrackerLimit();
- }
- }
- }
-
- @Override
- public void trainOnInstanceImpl(Instance inst) {
-
- if (this.treeRoot == null) {
- this.treeRoot = newLearningNode();
- ((EFDTNode) this.treeRoot).setRoot(true);
- this.activeLeafNodeCount = 1;
- }
-
- FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
- Node leafNode = foundNode.node;
-
- if (leafNode == null) {
- leafNode = newLearningNode();
- foundNode.parent.setChild(foundNode.parentBranch, leafNode);
- this.activeLeafNodeCount++;
- }
-
- ((EFDTNode) this.treeRoot).learnFromInstance(inst, this, null, -1);
-
- numInstances++;
- }
-
-
- protected LearningNode newLearningNode() {
- return new EFDTLearningNode(new double[0]);
- }
-
- protected LearningNode newLearningNode(double[] initialClassObservations) {
- return new EFDTLearningNode(initialClassObservations);
- }
-
- protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
- double[] classObservations, int size) {
- return new EFDTSplitNode(splitTest, classObservations, size);
- }
-
- protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
- double[] classObservations) {
- return new EFDTSplitNode(splitTest, classObservations);
- }
-
- private int argmax(double[] array) {
-
- double max = array[0];
- int maxarg = 0;
-
- for (int i = 1; i < array.length; i++) {
-
- if (array[i] > max) {
- max = array[i];
- maxarg = i;
- }
- }
- return maxarg;
- }
-
- public interface EFDTNode {
-
- boolean isRoot();
-
- void setRoot(boolean isRoot);
-
- void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch);
-
- void setParent(EFDTSplitNode parent);
-
- EFDTSplitNode getParent();
-
- }
-
- public static class FoundNode {
-
- public Node node;
-
- public SplitNode parent;
-
- public int parentBranch;
-
- public FoundNode(Node node, SplitNode parent, int parentBranch) {
- this.node = node;
- this.parent = parent;
- this.parentBranch = parentBranch;
- }
- }
-
- public static class Node extends AbstractMOAObject {
+ private static final long serialVersionUID = 2L;
+
+ public IntOption reEvalPeriodOption = new IntOption(
+ "reevaluationPeriod",
+ 'R',
+ "The number of instances an internal node should observe between re-evaluation attempts.",
+ 2000, 0, Integer.MAX_VALUE);
+
+ public IntOption maxByteSizeOption = new IntOption("maxByteSize", 'm',
+ "Maximum memory consumed by the tree.", 33554432, 0,
+ Integer.MAX_VALUE);
+
+ /*
+ * public MultiChoiceOption numericEstimatorOption = new MultiChoiceOption(
+ * "numericEstimator", 'n', "Numeric estimator to use.", new String[]{
+ * "GAUSS10", "GAUSS100", "GK10", "GK100", "GK1000", "VFML10", "VFML100",
+ * "VFML1000", "BINTREE"}, new String[]{ "Gaussian approximation evaluating
+ * 10 splitpoints", "Gaussian approximation evaluating 100 splitpoints",
+ * "Greenwald-Khanna quantile summary with 10 tuples", "Greenwald-Khanna
+ * quantile summary with 100 tuples", "Greenwald-Khanna quantile summary
+ * with 1000 tuples", "VFML method with 10 bins", "VFML method with 100
+ * bins", "VFML method with 1000 bins", "Exhaustive binary tree"}, 0);
+ */
+ public ClassOption numericEstimatorOption = new ClassOption("numericEstimator",
+ 'n', "Numeric estimator to use.", NumericAttributeClassObserver.class,
+ "GaussianNumericAttributeClassObserver");
- private HashMap infogainSum;
+ public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator",
+ 'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class,
+ "NominalAttributeClassObserver");
- private int numSplitAttempts = 0;
+ public IntOption memoryEstimatePeriodOption = new IntOption(
+ "memoryEstimatePeriod", 'e',
+ "How many instances between memory consumption checks.", 1000000,
+ 0, Integer.MAX_VALUE);
- private static final long serialVersionUID = 1L;
+ public IntOption gracePeriodOption = new IntOption(
+ "gracePeriod",
+ 'g',
+ "The number of instances a leaf should observe between split attempts.",
+ 200, 0, Integer.MAX_VALUE);
- protected DoubleVector observedClassDistribution;
+ public ClassOption splitCriterionOption = new ClassOption("splitCriterion",
+ 's', "Split criterion to use.", SplitCriterion.class,
+ "InfoGainSplitCriterion");
- protected DoubleVector classDistributionAtTimeOfCreation;
+ public FloatOption splitConfidenceOption = new FloatOption(
+ "splitConfidence",
+ 'c',
+ "The allowable error in split decision, values closer to 0 will take longer to decide.",
+ 0.0000001, 0.0, 1.0);
- protected int nodeTime;
+ public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
+ 't', "Threshold below which a split will be forced to break ties.",
+ 0.05, 0.0, 1.0);
- protected List usedNominalAttributes = new ArrayList<>();
+ public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b',
+ "Only allow binary splits.");
- public Node(double[] classObservations) {
- this.observedClassDistribution = new DoubleVector(classObservations);
- this.classDistributionAtTimeOfCreation = new DoubleVector(classObservations);
- this.infogainSum = new HashMap<>();
- this.infogainSum.put(-1, 0.0); // Initialize for null split
+ public FlagOption stopMemManagementOption = new FlagOption(
+ "stopMemManagement", 'z',
+ "Stop growing as soon as memory limit is hit.");
- }
+ public FlagOption removePoorAttsOption = new FlagOption("removePoorAtts",
+ 'r', "Disable poor attributes.");
- public int getNumSplitAttempts() {
- return numSplitAttempts;
- }
+ public FlagOption noPrePruneOption = new FlagOption("noPrePrune", 'p',
+ "Disable pre-pruning.");
- public void addToSplitAttempts(int i) {
- numSplitAttempts += i;
- }
+ public MultiChoiceOption leafpredictionOption = new MultiChoiceOption(
+ "leafprediction", 'l', "Leaf prediction to use.", new String[]{
+ "MC", "NB", "NBAdaptive"}, new String[]{
+ "Majority class",
+ "Naive Bayes",
+ "Naive Bayes Adaptive"}, 2);
- public HashMap getInfogainSum() {
- return infogainSum;
- }
+ public IntOption nbThresholdOption = new IntOption(
+ "nbThreshold",
+ 'q',
+ "The number of instances a leaf should observe before permitting Naive Bayes.",
+ 0, 0, Integer.MAX_VALUE);
- public void setInfogainSum(HashMap igs) {
- infogainSum = igs;
- }
+ protected Node treeRoot = null;
- public long calcByteSize() {
- return (SizeOf.sizeOf(this) + SizeOf.fullSizeOf(this.observedClassDistribution));
- }
+ protected int decisionNodeCount;
- public long calcByteSizeIncludingSubtree() {
- return calcByteSize();
- }
+ protected int activeLeafNodeCount;
- public boolean isLeaf() {
- return true;
- }
+ protected int inactiveLeafNodeCount;
- public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent,
- int parentBranch) {
- return new FoundNode(this, parent, parentBranch);
- }
+ protected double inactiveLeafByteSizeEstimate;
- public double[] getObservedClassDistribution() {
- return this.observedClassDistribution.getArrayCopy();
- }
+ protected double activeLeafByteSizeEstimate;
- public double[] getClassVotes(Instance inst, EFDT ht) {
- return this.observedClassDistribution.getArrayCopy();
- }
+ protected double byteSizeEstimateOverheadFraction;
- public double[] getClassDistributionAtTimeOfCreation() {
- return this.classDistributionAtTimeOfCreation.getArrayCopy();
- }
+ protected boolean growthAllowed;
- public boolean observedClassDistributionIsPure() {
- return this.observedClassDistribution.numNonZeroEntries() < 2;
- }
+ protected int numInstances = 0;
- public void describeSubtree(EFDT ht, StringBuilder out,
- int indent) {
- StringUtils.appendIndented(out, indent, "Leaf ");
- out.append(ht.getClassNameString());
- out.append(" = ");
- out.append(ht.getClassLabelString(this.observedClassDistribution.maxIndex()));
- out.append(" weights: ");
- this.observedClassDistribution.getSingleLineDescription(out,
- ht.treeRoot.observedClassDistribution.numValues());
- StringUtils.appendNewline(out);
- }
+ protected int splitCount = 0;
- public int subtreeDepth() {
- return 0;
+ @Override
+ public String getPurposeString() {
+ return "Hoeffding Tree or VFDT.";
}
- public double calculatePromise() {
- double totalSeen = this.observedClassDistribution.sumOfValues();
- return totalSeen > 0.0 ? (totalSeen - this.observedClassDistribution.getValue(this.observedClassDistribution.maxIndex()))
- : 0.0;
+ public int calcByteSize() {
+ int size = (int) SizeOf.sizeOf(this);
+ if (this.treeRoot != null) {
+ size += this.treeRoot.calcByteSizeIncludingSubtree();
+ }
+ return size;
}
@Override
- public void getDescription(StringBuilder sb, int indent) {
- describeSubtree(null, sb, indent);
+ public int measureByteSize() {
+ return calcByteSize();
}
- }
-
- public static class SplitNode extends Node {
-
- private static final long serialVersionUID = 1L;
-
- protected InstanceConditionalTest splitTest;
-
- protected AutoExpandVector children; // = new AutoExpandVector();
@Override
- public long calcByteSize() {
- return super.calcByteSize()
- + SizeOf.sizeOf(this.children) + SizeOf.fullSizeOf(this.splitTest);
+ public void resetLearningImpl() {
+ this.treeRoot = null;
+ this.decisionNodeCount = 0;
+ this.activeLeafNodeCount = 0;
+ this.inactiveLeafNodeCount = 0;
+ this.inactiveLeafByteSizeEstimate = 0.0;
+ this.activeLeafByteSizeEstimate = 0.0;
+ this.byteSizeEstimateOverheadFraction = 1.0;
+ this.growthAllowed = true;
+ if (this.leafpredictionOption.getChosenIndex() > 0) {
+ this.removePoorAttsOption = null;
+ }
}
@Override
- public long calcByteSizeIncludingSubtree() {
- long byteSize = calcByteSize();
- for (Node child : this.children) {
- if (child != null) {
- byteSize += child.calcByteSizeIncludingSubtree();
- }
- }
- return byteSize;
- }
-
- public SplitNode(InstanceConditionalTest splitTest,
- double[] classObservations, int size) {
- super(classObservations);
- this.splitTest = splitTest;
- this.children = new AutoExpandVector<>(size);
- }
-
- public SplitNode(InstanceConditionalTest splitTest,
- double[] classObservations) {
- super(classObservations);
- this.splitTest = splitTest;
- this.children = new AutoExpandVector<>();
- }
-
-
- public int numChildren() {
- return this.children.size();
- }
-
- public void setChild(int index, Node child) {
- if ((this.splitTest.maxBranches() >= 0)
- && (index >= this.splitTest.maxBranches())) {
- throw new IndexOutOfBoundsException();
- }
- this.children.set(index, child);
- }
-
- public Node getChild(int index) {
- return this.children.get(index);
- }
-
- public int instanceChildIndex(Instance inst) {
- return this.splitTest.branchForInstance(inst);
+ public double[] getVotesForInstance(Instance inst) {
+ if (this.treeRoot != null) {
+ FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst,
+ null, -1);
+ Node leafNode = foundNode.node;
+ if (leafNode == null) {
+ leafNode = foundNode.parent;
+ }
+ return leafNode.getClassVotes(inst, this);
+ } else {
+ int numClasses = inst.dataset().numClasses();
+ return new double[numClasses];
+ }
}
@Override
- public boolean isLeaf() {
- return false;
+ protected Measurement[] getModelMeasurementsImpl() {
+ FoundNode[] learningNodes = findLearningNodes();
+
+ return new Measurement[]{
+
+ new Measurement("tree size (nodes)", this.decisionNodeCount
+ + this.activeLeafNodeCount + this.inactiveLeafNodeCount),
+ new Measurement("tree size (leaves)", learningNodes.length),
+ new Measurement("active learning leaves",
+ this.activeLeafNodeCount),
+ new Measurement("tree depth", measureTreeDepth()),
+ new Measurement("active leaf byte size estimate",
+ this.activeLeafByteSizeEstimate),
+ new Measurement("inactive leaf byte size estimate",
+ this.inactiveLeafByteSizeEstimate),
+ new Measurement("byte size estimate overhead",
+ this.byteSizeEstimateOverheadFraction),
+ new Measurement("splits",
+ this.splitCount)};
+ }
+
+ public int measureTreeDepth() {
+ if (this.treeRoot != null) {
+ return this.treeRoot.subtreeDepth();
+ }
+ return 0;
}
@Override
- public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent,
- int parentBranch) {
-
- //System.err.println("OVERRIDING ");
-
- int childIndex = instanceChildIndex(inst);
- if (childIndex >= 0) {
- Node child = getChild(childIndex);
- if (child != null) {
- return child.filterInstanceToLeaf(inst, this, childIndex);
- }
- return new FoundNode(null, this, childIndex);
- }
- return new FoundNode(this, parent, parentBranch);
+ public void getModelDescription(StringBuilder out, int indent) {
+ this.treeRoot.describeSubtree(this, out, indent);
}
@Override
- public void describeSubtree(EFDT ht, StringBuilder out,
- int indent) {
- for (int branch = 0; branch < numChildren(); branch++) {
- Node child = getChild(branch);
- if (child != null) {
- StringUtils.appendIndented(out, indent, "if ");
- out.append(this.splitTest.describeConditionForBranch(branch,
- ht.getModelContext()));
- out.append(": ");
- StringUtils.appendNewline(out);
- child.describeSubtree(ht, out, indent + 2);
- }
- }
+ public boolean isRandomizable() {
+ return false;
+ }
+
+ public static double computeHoeffdingBound(double range, double confidence,
+ double n) {
+ return Math.sqrt(((range * range) * Math.log(1.0 / confidence))
+ / (2.0 * n));
+ }
+
+ protected AttributeClassObserver newNominalClassObserver() {
+ AttributeClassObserver nominalClassObserver = (AttributeClassObserver) getPreparedClassOption(this.nominalEstimatorOption);
+ return (AttributeClassObserver) nominalClassObserver.copy();
+ }
+
+ protected AttributeClassObserver newNumericClassObserver() {
+ AttributeClassObserver numericClassObserver = (AttributeClassObserver) getPreparedClassOption(this.numericEstimatorOption);
+ return (AttributeClassObserver) numericClassObserver.copy();
+ }
+
+ public void enforceTrackerLimit() {
+ if ((this.inactiveLeafNodeCount > 0)
+ || ((this.activeLeafNodeCount * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount
+ * this.inactiveLeafByteSizeEstimate)
+ * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue())) {
+ if (this.stopMemManagementOption.isSet()) {
+ this.growthAllowed = false;
+ return;
+ }
+ FoundNode[] learningNodes = findLearningNodes();
+ Arrays.sort(learningNodes, new Comparator() {
+
+ @Override
+ public int compare(FoundNode fn1, FoundNode fn2) {
+ return Double.compare(fn1.node.calculatePromise(), fn2.node.calculatePromise());
+ }
+ });
+ int maxActive = 0;
+ while (maxActive < learningNodes.length) {
+ maxActive++;
+ if ((maxActive * this.activeLeafByteSizeEstimate + (learningNodes.length - maxActive)
+ * this.inactiveLeafByteSizeEstimate)
+ * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue()) {
+ maxActive--;
+ break;
+ }
+ }
+ int cutoff = learningNodes.length - maxActive;
+ for (int i = 0; i < cutoff; i++) {
+ if (learningNodes[i].node instanceof ActiveLearningNode) {
+ deactivateLearningNode(
+ (ActiveLearningNode) learningNodes[i].node,
+ learningNodes[i].parent,
+ learningNodes[i].parentBranch);
+ }
+ }
+ for (int i = cutoff; i < learningNodes.length; i++) {
+ if (learningNodes[i].node instanceof InactiveLearningNode) {
+ activateLearningNode(
+ (InactiveLearningNode) learningNodes[i].node,
+ learningNodes[i].parent,
+ learningNodes[i].parentBranch);
+ }
+ }
+ }
+ }
+
+ public void estimateModelByteSizes() {
+ FoundNode[] learningNodes = findLearningNodes();
+ long totalActiveSize = 0;
+ long totalInactiveSize = 0;
+ for (FoundNode foundNode : learningNodes) {
+ if (foundNode.node instanceof ActiveLearningNode) {
+ totalActiveSize += SizeOf.fullSizeOf(foundNode.node);
+ } else {
+ totalInactiveSize += SizeOf.fullSizeOf(foundNode.node);
+ }
+ }
+ if (totalActiveSize > 0) {
+ this.activeLeafByteSizeEstimate = (double) totalActiveSize
+ / this.activeLeafNodeCount;
+ }
+ if (totalInactiveSize > 0) {
+ this.inactiveLeafByteSizeEstimate = (double) totalInactiveSize
+ / this.inactiveLeafNodeCount;
+ }
+ int actualModelSize = this.measureByteSize();
+ double estimatedModelSize = (this.activeLeafNodeCount
+ * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount
+ * this.inactiveLeafByteSizeEstimate);
+ this.byteSizeEstimateOverheadFraction = actualModelSize
+ / estimatedModelSize;
+ if (actualModelSize > this.maxByteSizeOption.getValue()) {
+ enforceTrackerLimit();
+ }
+ }
+
+ public void deactivateAllLeaves() {
+ FoundNode[] learningNodes = findLearningNodes();
+ for (FoundNode learningNode : learningNodes) {
+ if (learningNode.node instanceof ActiveLearningNode) {
+ deactivateLearningNode(
+ (ActiveLearningNode) learningNode.node,
+ learningNode.parent, learningNode.parentBranch);
+ }
+ }
+ }
+
+ protected void deactivateLearningNode(ActiveLearningNode toDeactivate,
+ SplitNode parent, int parentBranch) {
+ Node newLeaf = new InactiveLearningNode(toDeactivate.getObservedClassDistribution());
+ if (parent == null) {
+ this.treeRoot = newLeaf;
+ } else {
+ parent.setChild(parentBranch, newLeaf);
+ }
+ this.activeLeafNodeCount--;
+ this.inactiveLeafNodeCount++;
+ }
+
+ protected void activateLearningNode(InactiveLearningNode toActivate,
+ SplitNode parent, int parentBranch) {
+ Node newLeaf = newLearningNode(toActivate.getObservedClassDistribution());
+ if (parent == null) {
+ this.treeRoot = newLeaf;
+ } else {
+ parent.setChild(parentBranch, newLeaf);
+ }
+ this.activeLeafNodeCount++;
+ this.inactiveLeafNodeCount--;
+ }
+
+ protected FoundNode[] findLearningNodes() {
+ List foundList = new LinkedList<>();
+ findLearningNodes(this.treeRoot, null, -1, foundList);
+ return foundList.toArray(new FoundNode[foundList.size()]);
+ }
+
+ protected void findLearningNodes(Node node, SplitNode parent,
+ int parentBranch, List found) {
+ if (node != null) {
+ if (node instanceof LearningNode) {
+ found.add(new FoundNode(node, parent, parentBranch));
+ }
+ if (node instanceof SplitNode) {
+ SplitNode splitNode = (SplitNode) node;
+ for (int i = 0; i < splitNode.numChildren(); i++) {
+ findLearningNodes(splitNode.getChild(i), splitNode, i,
+ found);
+ }
+ }
+ }
}
- @Override
- public int subtreeDepth() {
- int maxChildDepth = 0;
- for (Node child : this.children) {
- if (child != null) {
- int depth = child.subtreeDepth();
- if (depth > maxChildDepth) {
- maxChildDepth = depth;
- }
- }
- }
- return maxChildDepth + 1;
+ protected void attemptToSplit(ActiveLearningNode node, SplitNode parent,
+ int parentIndex) {
+
+ if (!node.observedClassDistributionIsPure()) {
+ node.addToSplitAttempts(1); // even if we don't actually attempt to split, we've computed infogains
+
+
+ SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(this.splitCriterionOption);
+ AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, this);
+ Arrays.sort(bestSplitSuggestions);
+ boolean shouldSplit = false;
+
+ for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) {
+
+ if (bestSplitSuggestion.splitTest != null) {
+ if (!node.getInfogainSum().containsKey((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]))) {
+ node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), 0.0);
+ }
+ double currentSum = node.getInfogainSum().get((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]));
+ node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestion.merit);
+ } else { // handle the null attribute
+ double currentSum = node.getInfogainSum().get(-1); // null split
+ node.getInfogainSum().put(-1, Math.max(0.0, currentSum + bestSplitSuggestion.merit));
+ assert node.getInfogainSum().get(-1) >= 0.0 : "Negative infogain shouldn't be possible here.";
+ }
+
+ }
+
+ if (bestSplitSuggestions.length < 2) {
+ shouldSplit = bestSplitSuggestions.length > 0;
+ } else {
+ double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()),
+ this.splitConfidenceOption.getValue(), node.getWeightSeen());
+ AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
+
+ double bestSuggestionAverageMerit;
+ double currentAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
+
+ // because this is an unsplit leaf. current average merit should be always zero on the null split.
+
+ if (bestSuggestion.splitTest == null) { // if you have a null split
+ bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
+ } else {
+ bestSuggestionAverageMerit = node.getInfogainSum().get((bestSuggestion.splitTest.getAttsTestDependsOn()[0])) / node.getNumSplitAttempts();
+ }
+
+ if (bestSuggestion.merit < 1e-10) {
+ shouldSplit = false; // we don't use average here
+ } else if ((bestSuggestionAverageMerit - currentAverageMerit) >
+ hoeffdingBound
+ || (hoeffdingBound < this.tieThresholdOption.getValue())) {
+ if (bestSuggestionAverageMerit - currentAverageMerit < hoeffdingBound) {
+ // Placeholder to list this possibility
+ }
+ shouldSplit = true;
+ }
+
+ if (shouldSplit) {
+ for (Integer i : node.usedNominalAttributes) {
+ if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i) {
+ shouldSplit = false;
+ break;
+ }
+ }
+ }
+
+ // }
+ if ((this.removePoorAttsOption != null)
+ && this.removePoorAttsOption.isSet()) {
+ Set poorAtts = new HashSet<>();
+ // scan 1 - add any poor to set
+ for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) {
+ if (bestSplitSuggestion.splitTest != null) {
+ int[] splitAtts = bestSplitSuggestion.splitTest.getAttsTestDependsOn();
+ if (splitAtts.length == 1) {
+ if (bestSuggestion.merit
+ - bestSplitSuggestion.merit > hoeffdingBound) {
+ poorAtts.add(splitAtts[0]);
+ }
+ }
+ }
+ }
+ // scan 2 - remove good ones from set
+ for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) {
+ if (bestSplitSuggestion.splitTest != null) {
+ int[] splitAtts = bestSplitSuggestion.splitTest.getAttsTestDependsOn();
+ if (splitAtts.length == 1) {
+ if (bestSuggestion.merit
+ - bestSplitSuggestion.merit < hoeffdingBound) {
+ poorAtts.remove(splitAtts[0]);
+ }
+ }
+ }
+ }
+ for (int poorAtt : poorAtts) {
+ node.disableAttribute(poorAtt);
+ }
+ }
+ }
+ if (shouldSplit) {
+ splitCount++;
+
+ AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1];
+ if (splitDecision.splitTest == null) {
+ // preprune - null wins
+ deactivateLearningNode(node, parent, parentIndex);
+ } else {
+ Node newSplit = newSplitNode(splitDecision.splitTest,
+ node.getObservedClassDistribution(), splitDecision.numSplits());
+ ((EFDTSplitNode) newSplit).attributeObservers = node.attributeObservers; // copy the attribute observers
+ newSplit.setInfogainSum(node.getInfogainSum()); // transfer infogain history, leaf to split
+
+ for (int i = 0; i < splitDecision.numSplits(); i++) {
+
+ double[] j = splitDecision.resultingClassDistributionFromSplit(i);
+
+ Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i));
+
+ if (splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class
+ || splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class) {
+ newChild.usedNominalAttributes = new ArrayList<>(node.usedNominalAttributes); //deep copy
+ newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]);
+ // no nominal attribute should be split on more than once in the path
+ }
+ ((EFDTSplitNode) newSplit).setChild(i, newChild);
+ }
+ this.activeLeafNodeCount--;
+ this.decisionNodeCount++;
+ this.activeLeafNodeCount += splitDecision.numSplits();
+ if (parent == null) {
+ this.treeRoot = newSplit;
+ } else {
+ parent.setChild(parentIndex, newSplit);
+ }
+
+ }
+ // manage memory
+ enforceTrackerLimit();
+ }
+ }
}
- }
+ @Override
+ public void trainOnInstanceImpl(Instance inst) {
- public class EFDTSplitNode extends SplitNode implements EFDTNode {
-
- /**
- *
- */
-
- private boolean isRoot;
+ if (this.treeRoot == null) {
+ this.treeRoot = newLearningNode();
+ ((EFDTNode) this.treeRoot).setRoot(true);
+ this.activeLeafNodeCount = 1;
+ }
- private EFDTSplitNode parent = null;
+ FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
+ Node leafNode = foundNode.node;
- private static final long serialVersionUID = 1L;
+ if (leafNode == null) {
+ leafNode = newLearningNode();
+ foundNode.parent.setChild(foundNode.parentBranch, leafNode);
+ this.activeLeafNodeCount++;
+ }
- protected AutoExpandVector attributeObservers;
+ ((EFDTNode) this.treeRoot).learnFromInstance(inst, this, null, -1);
- public EFDTSplitNode(InstanceConditionalTest splitTest, double[] classObservations, int size) {
- super(splitTest, classObservations, size);
+ numInstances++;
}
- public EFDTSplitNode(InstanceConditionalTest splitTest, double[] classObservations) {
- super(splitTest, classObservations);
- }
- @Override
- public boolean isRoot() {
- return isRoot;
+ protected LearningNode newLearningNode() {
+ return new EFDTLearningNode(new double[0], this.leafpredictionOption.getChosenLabel());
}
- @Override
- public void setRoot(boolean isRoot) {
- this.isRoot = isRoot;
+ protected LearningNode newLearningNode(double[] initialClassObservations) {
+ return new EFDTLearningNode(initialClassObservations, this.leafpredictionOption.getChosenLabel());
}
- public void killSubtree(EFDT ht) {
- for (Node child : this.children) {
- if (child != null) {
-
- //Recursive delete of SplitNodes
- if (child instanceof SplitNode) {
- ((EFDTSplitNode) child).killSubtree(ht);
- }
- else if (child instanceof ActiveLearningNode) {
- ht.activeLeafNodeCount--;
- }
- else if (child instanceof InactiveLearningNode) {
- ht.inactiveLeafNodeCount--;
- }
- }
- }
+ protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
+ double[] classObservations, int size) {
+ return new EFDTSplitNode(splitTest, classObservations, size);
}
-
- // DRY Don't Repeat Yourself... code duplicated from ActiveLearningNode in VFDT.java. However, this is the most practical way to share stand-alone.
- public AttributeSplitSuggestion[] getBestSplitSuggestions(
- SplitCriterion criterion, EFDT ht) {
- List bestSuggestions = new LinkedList<>();
- double[] preSplitDist = this.observedClassDistribution.getArrayCopy();
- if (!ht.noPrePruneOption.isSet()) {
- // add null split as an option
- bestSuggestions.add(new AttributeSplitSuggestion(null,
- new double[0][], criterion.getMeritOfSplit(
- preSplitDist, new double[][]{preSplitDist})));
- }
- for (int i = 0; i < this.attributeObservers.size(); i++) {
- AttributeClassObserver obs = this.attributeObservers.get(i);
- if (obs != null) {
- AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion,
- preSplitDist, i, ht.binarySplitsOption.isSet());
- if (bestSuggestion != null) {
- bestSuggestions.add(bestSuggestion);
- }
- }
- }
- return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]);
+ protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
+ double[] classObservations) {
+ return new EFDTSplitNode(splitTest, classObservations);
}
+ private int argmax(double[] array) {
- @Override
- public void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch) {
-
- nodeTime++;
- //// Update node statistics and class distribution
-
- this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); // update prior (predictor)
-
- for (int i = 0; i < inst.numAttributes() - 1; i++) { //update likelihood
- int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
- AttributeClassObserver obs = this.attributeObservers.get(i);
- if (obs == null) {
- obs = inst.attribute(instAttIndex).isNominal() ? ht.newNominalClassObserver() : ht.newNumericClassObserver();
- this.attributeObservers.set(i, obs);
- }
- obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight());
- }
-
- // check if a better split is available. if so, chop the tree at this point, copying likelihood. predictors for children are from parent likelihood.
- if (ht.numInstances % ht.reEvalPeriodOption.getValue() == 0) {
- this.reEvaluateBestSplit(this, parent, parentBranch);
- }
+ double max = array[0];
+ int maxarg = 0;
- int childBranch = this.instanceChildIndex(inst);
- Node child = this.getChild(childBranch);
-
- if (child != null) {
- ((EFDTNode) child).learnFromInstance(inst, ht, this, childBranch);
- }
+ for (int i = 1; i < array.length; i++) {
+ if (array[i] > max) {
+ max = array[i];
+ maxarg = i;
+ }
+ }
+ return maxarg;
}
- protected void reEvaluateBestSplit(EFDTSplitNode node, EFDTSplitNode parent,
- int parentIndex) {
-
+ public interface EFDTNode {
- node.addToSplitAttempts(1);
+ boolean isRoot();
- // EFDT must transfer over gain averages when replacing a node: leaf to split, split to leaf, or split to split
- // It must replace split nodes with leaves if null wins
+ void setRoot(boolean isRoot);
+ void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch);
- // node is a reference to this anyway... why have it at all?
+ void setParent(EFDTSplitNode parent);
- int currentSplit = -1;
- // and if we always choose to maintain tree structure
+ EFDTSplitNode getParent();
- //lets first find out X_a, the current split
+ }
- if (this.splitTest != null) {
- currentSplit = this.splitTest.getAttsTestDependsOn()[0];
- // given the current implementations in MOA, we're only ever expecting one int to be returned
- }
+ public static class FoundNode {
- //compute Hoeffding bound
- SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(EFDT.this.splitCriterionOption);
- double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getClassDistributionAtTimeOfCreation()),
- EFDT.this.splitConfidenceOption.getValue(), node.observedClassDistribution.sumOfValues());
+ public Node node;
- // get best split suggestions
- AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, EFDT.this);
- Arrays.sort(bestSplitSuggestions);
+ public SplitNode parent;
- // get the best suggestion
- AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
+ public int parentBranch;
+ public FoundNode(Node node, SplitNode parent, int parentBranch) {
+ this.node = node;
+ this.parent = parent;
+ this.parentBranch = parentBranch;
+ }
+ }
- for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) {
+ public static class Node extends AbstractMOAObject {
- if (bestSplitSuggestion.splitTest != null) {
- if (!node.getInfogainSum().containsKey((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]))) {
- node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), 0.0);
- }
- double currentSum = node.getInfogainSum().get((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]));
- node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestion.merit);
- }
+ private HashMap infogainSum;
- else { // handle the null attribute. this is fine to do- it'll always average zero, and we will use this later to potentially burn bad splits.
- double currentSum = node.getInfogainSum().get(-1); // null split
- node.getInfogainSum().put(-1, currentSum + bestSplitSuggestion.merit);
- }
+ private int numSplitAttempts = 0;
- }
+ private static final long serialVersionUID = 1L;
- // get the average merit for best and current splits
+ protected DoubleVector observedClassDistribution;
- double bestSuggestionAverageMerit;
- double currentAverageMerit;
+ protected DoubleVector classDistributionAtTimeOfCreation;
- if (bestSuggestion.splitTest == null) { // best is null
- bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
- }
- else {
+ protected int nodeTime;
- bestSuggestionAverageMerit = node.getInfogainSum().get(bestSuggestion.splitTest.getAttsTestDependsOn()[0]) / node.getNumSplitAttempts();
- }
+ protected List usedNominalAttributes = new ArrayList<>();
- if (node.splitTest == null) { // current is null- shouldn't happen, check for robustness
- currentAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
- }
- else {
- currentAverageMerit = node.getInfogainSum().get(node.splitTest.getAttsTestDependsOn()[0]) / node.getNumSplitAttempts();
- }
+ public Node(double[] classObservations) {
+ this.observedClassDistribution = new DoubleVector(classObservations);
+ this.classDistributionAtTimeOfCreation = new DoubleVector(classObservations);
+ this.infogainSum = new HashMap<>();
+ this.infogainSum.put(-1, 0.0); // Initialize for null split
- double tieThreshold = EFDT.this.tieThresholdOption.getValue();
+ }
- // compute the average deltaG
- double deltaG = bestSuggestionAverageMerit - currentAverageMerit;
+ public int getNumSplitAttempts() {
+ return numSplitAttempts;
+ }
- if (deltaG > hoeffdingBound
- || (hoeffdingBound < tieThreshold && deltaG > tieThreshold / 2)) {
+ public void addToSplitAttempts(int i) {
+ numSplitAttempts += i;
+ }
- System.err.println(numInstances);
+ public HashMap getInfogainSum() {
+ return infogainSum;
+ }
- AttributeSplitSuggestion splitDecision = bestSuggestion;
+ public void setInfogainSum(HashMap igs) {
+ infogainSum = igs;
+ }
- // if null split wins
- if (splitDecision.splitTest == null) {
+ public int calcByteSize() {
+ return (int) (SizeOf.sizeOf(this) + SizeOf.fullSizeOf(this.observedClassDistribution));
+ }
- node.killSubtree(EFDT.this);
- EFDTLearningNode replacement = (EFDTLearningNode) newLearningNode();
- replacement.setInfogainSum(node.getInfogainSum()); // transfer infogain history, split to replacement leaf
- if (node.getParent() != null) {
- node.getParent().setChild(parentIndex, replacement);
- }
- else {
- assert (node.isRoot());
- node.setRoot(true);
- }
- }
+ public int calcByteSizeIncludingSubtree() {
+ return calcByteSize();
+ }
- else {
+ public boolean isLeaf() {
+ return true;
+ }
- Node newSplit = newSplitNode(splitDecision.splitTest,
- node.getObservedClassDistribution(), splitDecision.numSplits());
+ public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent,
+ int parentBranch) {
+ return new FoundNode(this, parent, parentBranch);
+ }
- ((EFDTSplitNode) newSplit).attributeObservers = node.attributeObservers; // copy the attribute observers
- newSplit.setInfogainSum(node.getInfogainSum()); // transfer infogain history, split to replacement split
+ public double[] getObservedClassDistribution() {
+ return this.observedClassDistribution.getArrayCopy();
+ }
- if (node.splitTest == splitDecision.splitTest
- && node.splitTest.getClass() == NumericAttributeBinaryTest.class &&
- (argmax(splitDecision.resultingClassDistributions[0]) == argmax(node.getChild(0).getObservedClassDistribution())
- || argmax(splitDecision.resultingClassDistributions[1]) == argmax(node.getChild(1).getObservedClassDistribution()))
- ) {
- // change split but don't destroy the subtrees
- for (int i = 0; i < splitDecision.numSplits(); i++) {
- ((EFDTSplitNode) newSplit).setChild(i, this.getChild(i));
- }
+ public double[] getClassVotes(Instance inst, EFDT ht) {
+ return this.observedClassDistribution.getArrayCopy();
+ }
- }
- else {
+ public double[] getClassDistributionAtTimeOfCreation() {
+ return this.classDistributionAtTimeOfCreation.getArrayCopy();
+ }
- // otherwise, torch the subtree and split on the new best attribute.
+ public boolean observedClassDistributionIsPure() {
+ return this.observedClassDistribution.numNonZeroEntries() < 2;
+ }
- this.killSubtree(EFDT.this);
+ public void describeSubtree(EFDT ht, StringBuilder out,
+ int indent) {
+ StringUtils.appendIndented(out, indent, "Leaf ");
+ out.append(ht.getClassNameString());
+ out.append(" = ");
+ out.append(ht.getClassLabelString(this.observedClassDistribution.maxIndex()));
+ out.append(" weights: ");
+ this.observedClassDistribution.getSingleLineDescription(out,
+ ht.treeRoot.observedClassDistribution.numValues());
+ StringUtils.appendNewline(out);
+ }
- for (int i = 0; i < splitDecision.numSplits(); i++) {
+ public int subtreeDepth() {
+ return 0;
+ }
+
+ public double calculatePromise() {
+ double totalSeen = this.observedClassDistribution.sumOfValues();
+ return totalSeen > 0.0 ? (totalSeen - this.observedClassDistribution.getValue(this.observedClassDistribution.maxIndex()))
+ : 0.0;
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ describeSubtree(null, sb, indent);
+ }
+ }
+
+ public static class SplitNode extends Node {
+
+ private static final long serialVersionUID = 1L;
+
+ protected InstanceConditionalTest splitTest;
+
+ protected AutoExpandVector children; // = new AutoExpandVector();
+
+ @Override
+ public int calcByteSize() {
+ return super.calcByteSize()
+ + (int) (SizeOf.sizeOf(this.children) + SizeOf.fullSizeOf(this.splitTest));
+ }
+
+ @Override
+ public int calcByteSizeIncludingSubtree() {
+ int byteSize = calcByteSize();
+ for (Node child : this.children) {
+ if (child != null) {
+ byteSize += child.calcByteSizeIncludingSubtree();
+ }
+ }
+ return byteSize;
+ }
+
+ public SplitNode(InstanceConditionalTest splitTest,
+ double[] classObservations, int size) {
+ super(classObservations);
+ this.splitTest = splitTest;
+ this.children = new AutoExpandVector<>(size);
+ }
+
+ public SplitNode(InstanceConditionalTest splitTest,
+ double[] classObservations) {
+ super(classObservations);
+ this.splitTest = splitTest;
+ this.children = new AutoExpandVector<>();
+ }
+
+
+ public int numChildren() {
+ return this.children.size();
+ }
+
+ public void setChild(int index, Node child) {
+ if ((this.splitTest.maxBranches() >= 0)
+ && (index >= this.splitTest.maxBranches())) {
+ throw new IndexOutOfBoundsException();
+ }
+ this.children.set(index, child);
+ }
+
+ public Node getChild(int index) {
+ return this.children.get(index);
+ }
+
+ public int instanceChildIndex(Instance inst) {
+ return this.splitTest.branchForInstance(inst);
+ }
+
+ @Override
+ public boolean isLeaf() {
+ return false;
+ }
+
+ @Override
+ public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent,
+ int parentBranch) {
+
+ //System.err.println("OVERRIDING ");
+
+ int childIndex = instanceChildIndex(inst);
+ if (childIndex >= 0) {
+ Node child = getChild(childIndex);
+ if (child != null) {
+ return child.filterInstanceToLeaf(inst, this, childIndex);
+ }
+ return new FoundNode(null, this, childIndex);
+ }
+ return new FoundNode(this, parent, parentBranch);
+ }
+
+ @Override
+ public void describeSubtree(EFDT ht, StringBuilder out,
+ int indent) {
+ for (int branch = 0; branch < numChildren(); branch++) {
+ Node child = getChild(branch);
+ if (child != null) {
+ StringUtils.appendIndented(out, indent, "if ");
+ out.append(this.splitTest.describeConditionForBranch(branch,
+ ht.getModelContext()));
+ out.append(": ");
+ StringUtils.appendNewline(out);
+ child.describeSubtree(ht, out, indent + 2);
+ }
+ }
+ }
+
+ @Override
+ public int subtreeDepth() {
+ int maxChildDepth = 0;
+ for (Node child : this.children) {
+ if (child != null) {
+ int depth = child.subtreeDepth();
+ if (depth > maxChildDepth) {
+ maxChildDepth = depth;
+ }
+ }
+ }
+ return maxChildDepth + 1;
+ }
+ }
+
+
+ public class EFDTSplitNode extends SplitNode implements EFDTNode {
+
+ /**
+ *
+ */
+
+ private boolean isRoot;
+
+ private EFDTSplitNode parent = null;
+
+ private static final long serialVersionUID = 1L;
+
+ protected AutoExpandVector attributeObservers;
+
+ public EFDTSplitNode(InstanceConditionalTest splitTest, double[] classObservations, int size) {
+ super(splitTest, classObservations, size);
+ }
+
+ public EFDTSplitNode(InstanceConditionalTest splitTest, double[] classObservations) {
+ super(splitTest, classObservations);
+ }
+
+ @Override
+ public boolean isRoot() {
+ return isRoot;
+ }
+
+ @Override
+ public void setRoot(boolean isRoot) {
+ this.isRoot = isRoot;
+ }
+
+ public void killSubtree(EFDT ht) {
+ for (Node child : this.children) {
+ if (child != null) {
+
+ // Recursively delete any SplitNode descendants
+ if (child instanceof SplitNode) {
+ ((EFDTSplitNode) child).killSubtree(ht);
+ } else if (child instanceof ActiveLearningNode) {
+ ht.activeLeafNodeCount--;
+ } else if (child instanceof InactiveLearningNode) {
+ ht.inactiveLeafNodeCount--;
+ }
+ }
+ }
+ }
+
+
+ // DRY violation: duplicated from ActiveLearningNode in VFDT.java; duplicating it here is the most practical way to keep this class stand-alone.
+ public AttributeSplitSuggestion[] getBestSplitSuggestions(
+ SplitCriterion criterion, EFDT ht) {
+ List bestSuggestions = new LinkedList<>();
+ double[] preSplitDist = this.observedClassDistribution.getArrayCopy();
+ if (!ht.noPrePruneOption.isSet()) {
+ // add null split as an option
+ bestSuggestions.add(new AttributeSplitSuggestion(null,
+ new double[0][], criterion.getMeritOfSplit(
+ preSplitDist, new double[][]{preSplitDist})));
+ }
+ for (int i = 0; i < this.attributeObservers.size(); i++) {
+ AttributeClassObserver obs = this.attributeObservers.get(i);
+ if (obs != null) {
+ AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion,
+ preSplitDist, i, ht.binarySplitsOption.isSet());
+ if (bestSuggestion != null) {
+ bestSuggestions.add(bestSuggestion);
+ }
+ }
+ }
+ return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]);
+ }
+
+
+ @Override
+ public void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch) {
+
+ nodeTime++;
+ //// Update node statistics and class distribution
- double[] j = splitDecision.resultingClassDistributionFromSplit(i);
+ this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); // update prior (predictor)
- Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i));
+ for (int i = 0; i < inst.numAttributes() - 1; i++) { //update likelihood
+ int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
+ AttributeClassObserver obs = this.attributeObservers.get(i);
+ if (obs == null) {
+ obs = inst.attribute(instAttIndex).isNominal() ? ht.newNominalClassObserver() : ht.newNumericClassObserver();
+ this.attributeObservers.set(i, obs);
+ }
+ obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight());
+ }
- if (splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class
- || splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class) {
- newChild.usedNominalAttributes = new ArrayList<>(node.usedNominalAttributes); //deep copy
- newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]);
- // no nominal attribute should be split on more than once in the path
- }
- ((EFDTSplitNode) newSplit).setChild(i, newChild);
- }
-
- EFDT.this.activeLeafNodeCount--;
- EFDT.this.decisionNodeCount++;
- EFDT.this.activeLeafNodeCount += splitDecision.numSplits();
+ // check if a better split is available. if so, chop the tree at this point, copying likelihood. predictors for children are from parent likelihood.
+ if (ht.numInstances % ht.reEvalPeriodOption.getValue() == 0) {
+ this.reEvaluateBestSplit(this, parent, parentBranch);
+ }
- }
+ int childBranch = this.instanceChildIndex(inst);
+ Node child = this.getChild(childBranch);
+ if (child != null) {
+ ((EFDTNode) child).learnFromInstance(inst, ht, this, childBranch);
+ }
- if (parent == null) {
- ((EFDTNode) newSplit).setRoot(true);
- ((EFDTNode) newSplit).setParent(null);
- EFDT.this.treeRoot = newSplit;
- }
- else {
- ((EFDTNode) newSplit).setRoot(false);
- ((EFDTNode) newSplit).setParent(parent);
- parent.setChild(parentIndex, newSplit);
- }
- }
- }
- }
+ }
- @Override
- public void setParent(EFDTSplitNode parent) {
- this.parent = parent;
- }
+ protected void reEvaluateBestSplit(EFDTSplitNode node, EFDTSplitNode parent,
+ int parentIndex) {
- @Override
- public EFDTSplitNode getParent() {
- return this.parent;
- }
- }
- public static abstract class LearningNode extends Node {
+ node.addToSplitAttempts(1);
+
+ // EFDT must transfer over gain averages when replacing a node: leaf to split, split to leaf, or split to split
+ // It must replace split nodes with leaves if null wins
+
+
+ // NOTE(review): the 'node' parameter is always a reference to 'this'; it is redundant and could be removed.
+
+ int currentSplit = -1;
+ // and if we always choose to maintain tree structure
+
+ //lets first find out X_a, the current split
+
+ if (this.splitTest != null) {
+ currentSplit = this.splitTest.getAttsTestDependsOn()[0];
+ // given the current implementations in MOA, we're only ever expecting one int to be returned
+ }
- private static final long serialVersionUID = 1L;
+ //compute Hoeffding bound
+ SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(EFDT.this.splitCriterionOption);
+ double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getClassDistributionAtTimeOfCreation()),
+ EFDT.this.splitConfidenceOption.getValue(), node.observedClassDistribution.sumOfValues());
- public LearningNode(double[] initialClassObservations) {
- super(initialClassObservations);
- }
+ // get best split suggestions
+ AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, EFDT.this);
+ Arrays.sort(bestSplitSuggestions);
- public abstract void learnFromInstance(Instance inst, EFDT ht);
- }
+ // get the best suggestion
+ AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
- public class EFDTLearningNode extends LearningNodeNBAdaptive implements EFDTNode {
+ for (AttributeSplitSuggestion bestSplitSuggestion : bestSplitSuggestions) {
- private boolean isRoot;
+ if (bestSplitSuggestion.splitTest != null) {
+ if (!node.getInfogainSum().containsKey((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]))) {
+ node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), 0.0);
+ }
+ double currentSum = node.getInfogainSum().get((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]));
+ node.getInfogainSum().put((bestSplitSuggestion.splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestion.merit);
+ } else { // handle the null attribute. this is fine to do- it'll always average zero, and we will use this later to potentially burn bad splits.
+ double currentSum = node.getInfogainSum().get(-1); // null split
+ node.getInfogainSum().put(-1, currentSum + bestSplitSuggestion.merit);
+ }
- private EFDTSplitNode parent = null;
+ }
- public EFDTLearningNode(double[] initialClassObservations) {
- super(initialClassObservations);
- }
+ // get the average merit for best and current splits
+ double bestSuggestionAverageMerit;
+ double currentAverageMerit;
- /**
- *
- */
- private static final long serialVersionUID = -2525042202040084035L;
+ if (bestSuggestion.splitTest == null) { // best is null
+ bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
+ } else {
- @Override
- public boolean isRoot() {
- return isRoot;
- }
+ bestSuggestionAverageMerit = node.getInfogainSum().get(bestSuggestion.splitTest.getAttsTestDependsOn()[0]) / node.getNumSplitAttempts();
+ }
- @Override
- public void setRoot(boolean isRoot) {
- this.isRoot = isRoot;
- }
+ if (node.splitTest == null) { // current is null- shouldn't happen, check for robustness
+ currentAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
+ } else {
+ currentAverageMerit = node.getInfogainSum().get(node.splitTest.getAttsTestDependsOn()[0]) / node.getNumSplitAttempts();
+ }
- @Override
- public void learnFromInstance(Instance inst, EFDT ht) {
- super.learnFromInstance(inst, ht);
+ double tieThreshold = EFDT.this.tieThresholdOption.getValue();
- }
+ // compute the average deltaG
+ double deltaG = bestSuggestionAverageMerit - currentAverageMerit;
- @Override
- public void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch) {
- learnFromInstance(inst, ht);
-
- if (ht.growthAllowed) {
- ActiveLearningNode activeLearningNode = this;
- double weightSeen = activeLearningNode.getWeightSeen();
- if (activeLearningNode.nodeTime % ht.gracePeriodOption.getValue() == 0) {
- attemptToSplit(activeLearningNode, parent,
- parentBranch);
- activeLearningNode.setWeightSeenAtLastSplitEvaluation(weightSeen);
- }
- }
- }
+ if (deltaG > hoeffdingBound
+ || (hoeffdingBound < tieThreshold && deltaG > tieThreshold / 2)) {
- @Override
- public void setParent(EFDTSplitNode parent) {
- this.parent = parent;
- }
+ System.err.println(numInstances);
- @Override
- public EFDTSplitNode getParent() {
- return this.parent;
- }
+ AttributeSplitSuggestion splitDecision = bestSuggestion;
- }
+ // if null split wins
+ if (splitDecision.splitTest == null) {
- public static class InactiveLearningNode extends LearningNode {
+ node.killSubtree(EFDT.this);
+ EFDTLearningNode replacement = (EFDTLearningNode) newLearningNode();
+ replacement.setInfogainSum(node.getInfogainSum()); // transfer infogain history, split to replacement leaf
+ if (node.getParent() != null) {
+ node.getParent().setChild(parentIndex, replacement);
+ } else {
+ assert (node.isRoot());
+ node.setRoot(true);
+ }
+ } else {
- private static final long serialVersionUID = 1L;
+ Node newSplit = newSplitNode(splitDecision.splitTest,
+ node.getObservedClassDistribution(), splitDecision.numSplits());
- public InactiveLearningNode(double[] initialClassObservations) {
- super(initialClassObservations);
- }
+ ((EFDTSplitNode) newSplit).attributeObservers = node.attributeObservers; // copy the attribute observers
+ newSplit.setInfogainSum(node.getInfogainSum()); // transfer infogain history, split to replacement split
- @Override
- public void learnFromInstance(Instance inst, EFDT ht) {
- this.observedClassDistribution.addToValue((int) inst.classValue(),
- inst.weight());
- }
- }
+ if (node.splitTest == splitDecision.splitTest
+ && node.splitTest.getClass() == NumericAttributeBinaryTest.class &&
+ (argmax(splitDecision.resultingClassDistributions[0]) == argmax(node.getChild(0).getObservedClassDistribution())
+ || argmax(splitDecision.resultingClassDistributions[1]) == argmax(node.getChild(1).getObservedClassDistribution()))
+ ) {
+ // change split but don't destroy the subtrees
+ for (int i = 0; i < splitDecision.numSplits(); i++) {
+ ((EFDTSplitNode) newSplit).setChild(i, this.getChild(i));
+ }
- public static class ActiveLearningNode extends LearningNode {
+ } else {
- private static final long serialVersionUID = 1L;
+ // otherwise, torch the subtree and split on the new best attribute.
- protected double weightSeenAtLastSplitEvaluation;
+ this.killSubtree(EFDT.this);
- protected AutoExpandVector attributeObservers = new AutoExpandVector<>();
+ for (int i = 0; i < splitDecision.numSplits(); i++) {
- protected boolean isInitialized;
+ double[] j = splitDecision.resultingClassDistributionFromSplit(i);
- public ActiveLearningNode(double[] initialClassObservations) {
- super(initialClassObservations);
- this.weightSeenAtLastSplitEvaluation = getWeightSeen();
- this.isInitialized = false;
- }
+ Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i));
- @Override
- public long calcByteSize() {
- return super.calcByteSize()
- + (SizeOf.fullSizeOf(this.attributeObservers));
- }
+ if (splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class
+ || splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class) {
+ newChild.usedNominalAttributes = new ArrayList<>(node.usedNominalAttributes); //deep copy
+ newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]);
+ // no nominal attribute should be split on more than once in the path
+ }
+ ((EFDTSplitNode) newSplit).setChild(i, newChild);
+ }
- @Override
- public void learnFromInstance(Instance inst, EFDT ht) {
- nodeTime++;
-
- if (this.isInitialized) {
- this.attributeObservers = new AutoExpandVector<>(inst.numAttributes());
- this.isInitialized = true;
- }
- this.observedClassDistribution.addToValue((int) inst.classValue(),
- inst.weight());
- for (int i = 0; i < inst.numAttributes() - 1; i++) {
- int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
- AttributeClassObserver obs = this.attributeObservers.get(i);
- if (obs == null) {
- obs = inst.attribute(instAttIndex).isNominal() ? ht.newNominalClassObserver() : ht.newNumericClassObserver();
- this.attributeObservers.set(i, obs);
- }
- obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight());
- }
- }
+ EFDT.this.activeLeafNodeCount--;
+ EFDT.this.decisionNodeCount++;
+ EFDT.this.activeLeafNodeCount += splitDecision.numSplits();
- public double getWeightSeen() {
- return this.observedClassDistribution.sumOfValues();
- }
+ }
- public double getWeightSeenAtLastSplitEvaluation() {
- return this.weightSeenAtLastSplitEvaluation;
- }
- public void setWeightSeenAtLastSplitEvaluation(double weight) {
- this.weightSeenAtLastSplitEvaluation = weight;
- }
+ if (parent == null) {
+ ((EFDTNode) newSplit).setRoot(true);
+ ((EFDTNode) newSplit).setParent(null);
+ EFDT.this.treeRoot = newSplit;
+ } else {
+ ((EFDTNode) newSplit).setRoot(false);
+ ((EFDTNode) newSplit).setParent(parent);
+ parent.setChild(parentIndex, newSplit);
+ }
+ }
+ }
+ }
- public AttributeSplitSuggestion[] getBestSplitSuggestions(
- SplitCriterion criterion, EFDT ht) {
- List bestSuggestions = new LinkedList<>();
- double[] preSplitDist = this.observedClassDistribution.getArrayCopy();
- if (!ht.noPrePruneOption.isSet()) {
- // add null split as an option
- bestSuggestions.add(new AttributeSplitSuggestion(null,
- new double[0][], criterion.getMeritOfSplit(
- preSplitDist, new double[][]{preSplitDist})));
- }
- for (int i = 0; i < this.attributeObservers.size(); i++) {
- AttributeClassObserver obs = this.attributeObservers.get(i);
- if (obs != null) {
- AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion,
- preSplitDist, i, ht.binarySplitsOption.isSet());
- if (bestSuggestion != null) {
- bestSuggestions.add(bestSuggestion);
- }
- }
- }
- return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]);
- }
+ @Override
+ public void setParent(EFDTSplitNode parent) {
+ this.parent = parent;
+ }
- public void disableAttribute(int attIndex) {
- this.attributeObservers.set(attIndex,
- new NullAttributeClassObserver());
+ @Override
+ public EFDTSplitNode getParent() {
+ return this.parent;
+ }
}
- }
- public static class LearningNodeNB extends ActiveLearningNode {
+ public static abstract class LearningNode extends Node {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public LearningNodeNB(double[] initialClassObservations) {
- super(initialClassObservations);
- }
-
- @Override
- public double[] getClassVotes(Instance inst, EFDT ht) {
- if (getWeightSeen() >= ht.nbThresholdOption.getValue()) {
- return NaiveBayes.doNaiveBayesPrediction(inst,
- this.observedClassDistribution,
- this.attributeObservers);
- }
- return super.getClassVotes(inst, ht);
- }
+ public LearningNode(double[] initialClassObservations) {
+ super(initialClassObservations);
+ }
- @Override
- public void disableAttribute(int attIndex) {
- // should not disable poor atts - they are used in NB calc
+ public abstract void learnFromInstance(Instance inst, EFDT ht);
}
- }
- public static class LearningNodeNBAdaptive extends LearningNodeNB {
- private static final long serialVersionUID = 1L;
+ public class EFDTLearningNode extends LearningNodeNBAdaptive implements EFDTNode {
- protected double mcCorrectWeight = 0.0;
+ private boolean isRoot;
- protected double nbCorrectWeight = 0.0;
+ private EFDTSplitNode parent = null;
- public LearningNodeNBAdaptive(double[] initialClassObservations) {
- super(initialClassObservations);
- }
+ private String predictionType;
- @Override
- public void learnFromInstance(Instance inst, EFDT ht) {
- int trueClass = (int) inst.classValue();
- if (this.observedClassDistribution.maxIndex() == trueClass) {
- this.mcCorrectWeight += inst.weight();
- }
- if (Utils.maxIndex(NaiveBayes.doNaiveBayesPrediction(inst,
- this.observedClassDistribution, this.attributeObservers)) == trueClass) {
- this.nbCorrectWeight += inst.weight();
- }
- super.learnFromInstance(inst, ht);
- }
+ public EFDTLearningNode(double[] initialClassObservations, String predictionType) {
+ super(initialClassObservations);
+ this.predictionType = predictionType;
+ }
- @Override
- public double[] getClassVotes(Instance inst, EFDT ht) {
- if (this.mcCorrectWeight > this.nbCorrectWeight) {
- return this.observedClassDistribution.getArrayCopy();
- }
- return NaiveBayes.doNaiveBayesPrediction(inst,
- this.observedClassDistribution, this.attributeObservers);
- }
- }
- static class VFDT extends EFDT {
+ /**
+ *
+ */
+ private static final long serialVersionUID = -2525042202040084035L;
- @Override
- protected void attemptToSplit(ActiveLearningNode node, SplitNode parent,
- int parentIndex) {
- if (!node.observedClassDistributionIsPure()) {
-
-
- SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(this.splitCriterionOption);
- AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, this);
-
- Arrays.sort(bestSplitSuggestions);
- boolean shouldSplit = false;
-
- for (int i = 0; i < bestSplitSuggestions.length; i++){
-
- node.addToSplitAttempts(1); // even if we don't actually attempt to split, we've computed infogains
-
-
- if (bestSplitSuggestions[i].splitTest != null){
- if (!node.getInfogainSum().containsKey((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0])))
- {
- node.getInfogainSum().put((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]), 0.0);
- }
- double currentSum = node.getInfogainSum().get((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]));
- node.getInfogainSum().put((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestions[i].merit);
- }
-
- else { // handle the null attribute
- double currentSum = node.getInfogainSum().get(-1); // null split
- node.getInfogainSum().put(-1, currentSum + Math.max(0.0, bestSplitSuggestions[i].merit));
- assert node.getInfogainSum().get(-1) >= 0.0 : "Negative infogain shouldn't be possible here.";
- }
-
- }
-
- if (bestSplitSuggestions.length < 2) {
- shouldSplit = bestSplitSuggestions.length > 0;
- }
-
- else {
-
-
- double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()),
- this.splitConfidenceOption.getValue(), node.getWeightSeen());
-
- AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
- AttributeSplitSuggestion secondBestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 2];
-
-
- double bestSuggestionAverageMerit = 0.0;
- double secondBestSuggestionAverageMerit = 0.0;
-
- if(bestSuggestion.splitTest == null){ // if you have a null split
- bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
- } else{
- bestSuggestionAverageMerit = node.getInfogainSum().get((bestSuggestion.splitTest.getAttsTestDependsOn()[0])) / node.getNumSplitAttempts();
- }
-
- if(secondBestSuggestion.splitTest == null){ // if you have a null split
- secondBestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
- } else{
- secondBestSuggestionAverageMerit = node.getInfogainSum().get((secondBestSuggestion.splitTest.getAttsTestDependsOn()[0])) / node.getNumSplitAttempts();
- }
-
- //comment this if statement to get VFDT bug
- if(bestSuggestion.merit < 1e-10){ // we don't use average here
- shouldSplit = false;
- }
-
- else
- if ((bestSuggestionAverageMerit - secondBestSuggestionAverageMerit > hoeffdingBound)
- || (hoeffdingBound < this.tieThresholdOption.getValue()))
- {
- shouldSplit = true;
- }
-
- if(shouldSplit){
- for(Integer i : node.usedNominalAttributes){
- if(bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i){
- shouldSplit = false;
- break;
- }
- }
- }
-
- // }
- if ((this.removePoorAttsOption != null)
- && this.removePoorAttsOption.isSet()) {
- Set poorAtts = new HashSet();
- // scan 1 - add any poor to set
- for (int i = 0; i < bestSplitSuggestions.length; i++) {
- if (bestSplitSuggestions[i].splitTest != null) {
- int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn();
- if (splitAtts.length == 1) {
- if (bestSuggestion.merit
- - bestSplitSuggestions[i].merit > hoeffdingBound) {
- poorAtts.add(new Integer(splitAtts[0]));
- }
- }
- }
- }
- // scan 2 - remove good ones from set
- for (int i = 0; i < bestSplitSuggestions.length; i++) {
- if (bestSplitSuggestions[i].splitTest != null) {
- int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn();
- if (splitAtts.length == 1) {
- if (bestSuggestion.merit
- - bestSplitSuggestions[i].merit < hoeffdingBound) {
- poorAtts.remove(new Integer(splitAtts[0]));
- }
- }
- }
- }
- for (int poorAtt : poorAtts) {
- node.disableAttribute(poorAtt);
- }
- }
- }
- if (shouldSplit) {
- splitCount++;
-
- AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1];
- if (splitDecision.splitTest == null) {
- // preprune - null wins
- deactivateLearningNode(node, parent, parentIndex);
- } else {
- SplitNode newSplit = newSplitNode(splitDecision.splitTest,
- node.getObservedClassDistribution(), splitDecision.numSplits());
- for (int i = 0; i < splitDecision.numSplits(); i++) {
-
- double[] j = splitDecision.resultingClassDistributionFromSplit(i);
-
- Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i));
-
- if(splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class
- ||splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class){
- newChild.usedNominalAttributes = new ArrayList(node.usedNominalAttributes); //deep copy
- newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]);
- // no nominal attribute should be split on more than once in the path
- }
- newSplit.setChild(i, newChild);
- }
- this.activeLeafNodeCount--;
- this.decisionNodeCount++;
- this.activeLeafNodeCount += splitDecision.numSplits();
- if (parent == null) {
- this.treeRoot = newSplit;
- } else {
- parent.setChild(parentIndex, newSplit);
- }
-
- }
-
- // manage memory
- enforceTrackerLimit();
- }
- }
- }
+ @Override
+ public boolean isRoot() {
+ return isRoot;
+ }
- @Override
- protected LearningNode newLearningNode() {
- return newLearningNode(new double[0]);
- }
+ @Override
+ public void setRoot(boolean isRoot) {
+ this.isRoot = isRoot;
+ }
- @Override
- protected LearningNode newLearningNode(double[] initialClassObservations) {
- LearningNode ret;
- int predictionOption = this.leafpredictionOption.getChosenIndex();
- if (predictionOption == 0) { //MC
- ret = new ActiveLearningNode(initialClassObservations);
- } else if (predictionOption == 1) { //NB
- ret = new LearningNodeNB(initialClassObservations);
- } else { //NBAdaptive
- ret = new LearningNodeNBAdaptive(initialClassObservations);
- }
- return ret;
- }
+ @Override
+ public void learnFromInstance(Instance inst, EFDT ht) {
+ super.learnFromInstance(inst, ht);
- @Override
- protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
- double[] classObservations, int size) {
- return new SplitNode(splitTest, classObservations, size);
- }
+ }
- @Override
- protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
- double[] classObservations) {
- return new SplitNode(splitTest, classObservations);
- }
+ @Override
+ public void learnFromInstance(Instance inst, EFDT ht, EFDTSplitNode parent, int parentBranch) {
+ learnFromInstance(inst, ht);
- @Override
- public void trainOnInstanceImpl(Instance inst) {
- //System.err.println(i++);
- if (this.treeRoot == null) {
- this.treeRoot = newLearningNode();
- this.activeLeafNodeCount = 1;
- }
- FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
- Node leafNode = foundNode.node;
-
- if (leafNode == null) {
- leafNode = newLearningNode();
- foundNode.parent.setChild(foundNode.parentBranch, leafNode);
- this.activeLeafNodeCount++;
- }
-
- if (leafNode instanceof LearningNode) {
- LearningNode learningNode = (LearningNode) leafNode;
- learningNode.learnFromInstance(inst, this);
- if (this.growthAllowed
- && (learningNode instanceof ActiveLearningNode)) {
- ActiveLearningNode activeLearningNode = (ActiveLearningNode) learningNode;
- double weightSeen = activeLearningNode.getWeightSeen();
- if (activeLearningNode.nodeTime % this.gracePeriodOption.getValue() == 0) {
- attemptToSplit(activeLearningNode, foundNode.parent,
- foundNode.parentBranch);
- activeLearningNode.setWeightSeenAtLastSplitEvaluation(weightSeen);
- }
- }
- }
-
- if (this.trainingWeightSeenByModel
- % this.memoryEstimatePeriodOption.getValue() == 0) {
- estimateModelByteSizes();
- }
-
- numInstances++;
+ if (ht.growthAllowed) {
+ ActiveLearningNode activeLearningNode = this;
+ double weightSeen = activeLearningNode.getWeightSeen();
+ if (activeLearningNode.nodeTime % ht.gracePeriodOption.getValue() == 0) {
+ attemptToSplit(activeLearningNode, parent,
+ parentBranch);
+ activeLearningNode.setWeightSeenAtLastSplitEvaluation(weightSeen);
+ }
+ }
+ }
+
+ @Override
+ public void setParent(EFDTSplitNode parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public EFDTSplitNode getParent() {
+ return this.parent;
+ }
+
+ @Override
+ public double[] getClassVotes(Instance inst, EFDT ht) {
+ if (this.predictionType.equals("MC"))
+ return this.observedClassDistribution.getArrayCopy();
+ else if (this.predictionType.equals("NB"))
+ return NaiveBayes.doNaiveBayesPrediction(inst,
+ this.observedClassDistribution, this.attributeObservers);
+ // Naive Bayes Adaptive
+ if (this.mcCorrectWeight > this.nbCorrectWeight) {
+ return this.observedClassDistribution.getArrayCopy();
+ }
+ return NaiveBayes.doNaiveBayesPrediction(inst,
+ this.observedClassDistribution, this.attributeObservers);
+ }
+ }
+
+ public static class InactiveLearningNode extends LearningNode {
+
+ private static final long serialVersionUID = 1L;
+
+ public InactiveLearningNode(double[] initialClassObservations) {
+ super(initialClassObservations);
+ }
+
+ @Override
+ public void learnFromInstance(Instance inst, EFDT ht) {
+ this.observedClassDistribution.addToValue((int) inst.classValue(),
+ inst.weight());
+ }
+ }
+
+    /**
+     * A growing leaf: maintains one statistics observer per attribute so that
+     * candidate splits can be evaluated against a split criterion.
+     */
+    public static class ActiveLearningNode extends LearningNode {
+
+        private static final long serialVersionUID = 1L;
+
+        /** Total instance weight seen when splits were last evaluated. */
+        protected double weightSeenAtLastSplitEvaluation;
+
+        /** Per-attribute class observers, created lazily as attributes are seen. */
+        protected AutoExpandVector<AttributeClassObserver> attributeObservers = new AutoExpandVector<>();
+
+        protected boolean isInitialized;
+
+        public ActiveLearningNode(double[] initialClassObservations) {
+            super(initialClassObservations);
+            this.weightSeenAtLastSplitEvaluation = getWeightSeen();
+            this.isInitialized = false;
+        }
+
+        @Override
+        public int calcByteSize() {
+            return super.calcByteSize()
+                    + (int) (SizeOf.fullSizeOf(this.attributeObservers));
+        }
+
+        @Override
+        public void learnFromInstance(Instance inst, EFDT ht) {
+            nodeTime++;
+
+            // BUG FIX: the guard used to read "if (this.isInitialized)", which
+            // can never fire (the flag starts false and is only set true inside
+            // this branch), so the observer vector was never pre-sized.
+            if (!this.isInitialized) {
+                this.attributeObservers = new AutoExpandVector<>(inst.numAttributes());
+                this.isInitialized = true;
+            }
+            this.observedClassDistribution.addToValue((int) inst.classValue(),
+                    inst.weight());
+            // The last model attribute is the class; observe all predictive ones.
+            for (int i = 0; i < inst.numAttributes() - 1; i++) {
+                int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
+                AttributeClassObserver obs = this.attributeObservers.get(i);
+                if (obs == null) {
+                    obs = inst.attribute(instAttIndex).isNominal() ? ht.newNominalClassObserver() : ht.newNumericClassObserver();
+                    this.attributeObservers.set(i, obs);
+                }
+                obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight());
+            }
+        }
+
+        /** @return the total weight of all instances this leaf has observed */
+        public double getWeightSeen() {
+            return this.observedClassDistribution.sumOfValues();
+        }
+
+        public double getWeightSeenAtLastSplitEvaluation() {
+            return this.weightSeenAtLastSplitEvaluation;
+        }
+
+        public void setWeightSeenAtLastSplitEvaluation(double weight) {
+            this.weightSeenAtLastSplitEvaluation = weight;
+        }
+
+        /**
+         * Evaluates each attribute observer and collects the best split each
+         * one can offer, plus (unless pre-pruning is disabled) the null split.
+         *
+         * @param criterion the split criterion used to score candidate splits
+         * @param ht        the tree, queried for the binary-splits option
+         * @return the best suggestion per attribute (unsorted)
+         */
+        public AttributeSplitSuggestion[] getBestSplitSuggestions(
+                SplitCriterion criterion, EFDT ht) {
+            List<AttributeSplitSuggestion> bestSuggestions = new LinkedList<>();
+            double[] preSplitDist = this.observedClassDistribution.getArrayCopy();
+            if (!ht.noPrePruneOption.isSet()) {
+                // add null split as an option
+                bestSuggestions.add(new AttributeSplitSuggestion(null,
+                        new double[0][], criterion.getMeritOfSplit(
+                        preSplitDist, new double[][]{preSplitDist})));
+            }
+            for (int i = 0; i < this.attributeObservers.size(); i++) {
+                AttributeClassObserver obs = this.attributeObservers.get(i);
+                if (obs != null) {
+                    AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion,
+                            preSplitDist, i, ht.binarySplitsOption.isSet());
+                    if (bestSuggestion != null) {
+                        bestSuggestions.add(bestSuggestion);
+                    }
+                }
+            }
+            return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]);
+        }
+
+        /** Replaces the observer for a poor attribute with a null observer. */
+        public void disableAttribute(int attIndex) {
+            this.attributeObservers.set(attIndex,
+                    new NullAttributeClassObserver());
+        }
+    }
+
+    /** Leaf that predicts with naive Bayes once enough weight has been seen. */
+    public static class LearningNodeNB extends ActiveLearningNode {
+
+        private static final long serialVersionUID = 1L;
+
+        public LearningNodeNB(double[] initialClassObservations) {
+            super(initialClassObservations);
+        }
+
+        @Override
+        public double[] getClassVotes(Instance inst, EFDT ht) {
+            // Below the configured weight threshold, fall back to the
+            // majority-class vote of the parent class.
+            if (getWeightSeen() < ht.nbThresholdOption.getValue()) {
+                return super.getClassVotes(inst, ht);
+            }
+            return NaiveBayes.doNaiveBayesPrediction(inst,
+                    this.observedClassDistribution,
+                    this.attributeObservers);
+        }
+
+        @Override
+        public void disableAttribute(int attIndex) {
+            // Intentionally a no-op: naive Bayes needs every attribute observer,
+            // so "poor" attributes must not be disabled here.
+        }
+    }
+
+    /**
+     * Naive-Bayes-adaptive leaf: tracks how much weight the majority-class and
+     * naive Bayes predictors would each have classified correctly, and votes
+     * with whichever strategy is currently ahead.
+     */
+    public static class LearningNodeNBAdaptive extends LearningNodeNB {
+
+        private static final long serialVersionUID = 1L;
+
+        // Cumulative weight each strategy would have classified correctly.
+        protected double mcCorrectWeight = 0.0;
+
+        protected double nbCorrectWeight = 0.0;
+
+        public LearningNodeNBAdaptive(double[] initialClassObservations) {
+            super(initialClassObservations);
+        }
+
+        @Override
+        public void learnFromInstance(Instance inst, EFDT ht) {
+            final int trueClass = (int) inst.classValue();
+            final double weight = inst.weight();
+            // Score both predictors on this instance before learning from it.
+            if (this.observedClassDistribution.maxIndex() == trueClass) {
+                this.mcCorrectWeight += weight;
+            }
+            double[] nbVotes = NaiveBayes.doNaiveBayesPrediction(inst,
+                    this.observedClassDistribution, this.attributeObservers);
+            if (Utils.maxIndex(nbVotes) == trueClass) {
+                this.nbCorrectWeight += weight;
+            }
+            super.learnFromInstance(inst, ht);
+        }
+
+        @Override
+        public double[] getClassVotes(Instance inst, EFDT ht) {
+            final boolean majorityClassWins = this.mcCorrectWeight > this.nbCorrectWeight;
+            return majorityClassWins
+                    ? this.observedClassDistribution.getArrayCopy()
+                    : NaiveBayes.doNaiveBayesPrediction(inst,
+                            this.observedClassDistribution, this.attributeObservers);
+        }
+    }
+
+    /**
+     * VFDT-style variant of EFDT: leaves are split with the classic
+     * Hoeffding-bound test on average infogain. NOTE(review): unlike EFDT
+     * proper, this subclass appears not to re-evaluate existing splits —
+     * confirm against the enclosing class.
+     */
+    static class VFDT extends EFDT {
+
+        /**
+         * Attempts to split the given active leaf, comparing the best and
+         * second-best candidates' average merit against the Hoeffding bound.
+         */
+        @Override
+        protected void attemptToSplit(ActiveLearningNode node, SplitNode parent,
+                int parentIndex) {
+            // A pure leaf (only one observed class) can never gain from a split.
+            if (!node.observedClassDistributionIsPure()) {
+
+
+                SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(this.splitCriterionOption);
+                AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, this);
+
+                // Sort ascending by merit: the best suggestion ends up last.
+                Arrays.sort(bestSplitSuggestions);
+                boolean shouldSplit = false;
+
+                // Accumulate per-attribute infogain sums so the decision below
+                // can use average merit over all split attempts. The "null
+                // split" (don't split) is keyed under -1.
+                for (int i = 0; i < bestSplitSuggestions.length; i++) {
+
+                    node.addToSplitAttempts(1); // even if we don't actually attempt to split, we've computed infogains
+
+
+                    if (bestSplitSuggestions[i].splitTest != null) {
+                        if (!node.getInfogainSum().containsKey((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]))) {
+                            node.getInfogainSum().put((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]), 0.0);
+                        }
+                        double currentSum = node.getInfogainSum().get((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]));
+                        node.getInfogainSum().put((bestSplitSuggestions[i].splitTest.getAttsTestDependsOn()[0]), currentSum + bestSplitSuggestions[i].merit);
+                    } else { // handle the null attribute
+                        double currentSum = node.getInfogainSum().get(-1); // null split
+                        node.getInfogainSum().put(-1, currentSum + Math.max(0.0, bestSplitSuggestions[i].merit));
+                        assert node.getInfogainSum().get(-1) >= 0.0 : "Negative infogain shouldn't be possible here.";
+                    }
+
+                }
+
+                if (bestSplitSuggestions.length < 2) {
+                    // With at most one candidate there is nothing to compare against.
+                    shouldSplit = bestSplitSuggestions.length > 0;
+                } else {
+
+
+                    double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()),
+                            this.splitConfidenceOption.getValue(), node.getWeightSeen());
+
+                    AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
+                    AttributeSplitSuggestion secondBestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 2];
+
+                    // Average merit per attempt (not the instantaneous merit)
+                    // drives the comparison; -1 again denotes the null split.
+                    double bestSuggestionAverageMerit = 0.0;
+                    double secondBestSuggestionAverageMerit = 0.0;
+
+                    if (bestSuggestion.splitTest == null) { // if you have a null split
+                        bestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
+                    } else {
+                        bestSuggestionAverageMerit = node.getInfogainSum().get((bestSuggestion.splitTest.getAttsTestDependsOn()[0])) / node.getNumSplitAttempts();
+                    }
+
+                    if (secondBestSuggestion.splitTest == null) { // if you have a null split
+                        secondBestSuggestionAverageMerit = node.getInfogainSum().get(-1) / node.getNumSplitAttempts();
+                    } else {
+                        secondBestSuggestionAverageMerit = node.getInfogainSum().get((secondBestSuggestion.splitTest.getAttsTestDependsOn()[0])) / node.getNumSplitAttempts();
+                    }
+
+                    //comment this if statement to get VFDT bug
+                    if (bestSuggestion.merit < 1e-10) { // we don't use average here
+                        shouldSplit = false;
+                    } else if ((bestSuggestionAverageMerit - secondBestSuggestionAverageMerit > hoeffdingBound)
+                            || (hoeffdingBound < this.tieThresholdOption.getValue())) {
+                        shouldSplit = true;
+                    }
+
+                    // Never split twice on the same nominal attribute along a path.
+                    if (shouldSplit) {
+                        for (Integer i : node.usedNominalAttributes) {
+                            if (bestSuggestion.splitTest.getAttsTestDependsOn()[0] == i) {
+                                shouldSplit = false;
+                                break;
+                            }
+                        }
+                    }
+
+                    // }
+                    // Optionally disable observers for attributes whose merit
+                    // trails the best by more than the Hoeffding bound.
+                    // NOTE(review): "new Integer(...)" below is deprecated;
+                    // Integer.valueOf would be preferred.
+                    if ((this.removePoorAttsOption != null)
+                            && this.removePoorAttsOption.isSet()) {
+                        Set poorAtts = new HashSet();
+                        // scan 1 - add any poor to set
+                        for (int i = 0; i < bestSplitSuggestions.length; i++) {
+                            if (bestSplitSuggestions[i].splitTest != null) {
+                                int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn();
+                                if (splitAtts.length == 1) {
+                                    if (bestSuggestion.merit
+                                            - bestSplitSuggestions[i].merit > hoeffdingBound) {
+                                        poorAtts.add(new Integer(splitAtts[0]));
+                                    }
+                                }
+                            }
+                        }
+                        // scan 2 - remove good ones from set
+                        for (int i = 0; i < bestSplitSuggestions.length; i++) {
+                            if (bestSplitSuggestions[i].splitTest != null) {
+                                int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn();
+                                if (splitAtts.length == 1) {
+                                    if (bestSuggestion.merit
+                                            - bestSplitSuggestions[i].merit < hoeffdingBound) {
+                                        poorAtts.remove(new Integer(splitAtts[0]));
+                                    }
+                                }
+                            }
+                        }
+                        for (int poorAtt : poorAtts) {
+                            node.disableAttribute(poorAtt);
+                        }
+                    }
+                }
+                if (shouldSplit) {
+                    splitCount++;
+
+                    AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1];
+                    if (splitDecision.splitTest == null) {
+                        // preprune - null wins
+                        deactivateLearningNode(node, parent, parentIndex);
+                    } else {
+                        SplitNode newSplit = newSplitNode(splitDecision.splitTest,
+                                node.getObservedClassDistribution(), splitDecision.numSplits());
+                        for (int i = 0; i < splitDecision.numSplits(); i++) {
+
+                            // NOTE(review): 'j' is computed but never used.
+                            double[] j = splitDecision.resultingClassDistributionFromSplit(i);
+
+                            Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i));
+
+                            if (splitDecision.splitTest.getClass() == NominalAttributeBinaryTest.class
+                                    || splitDecision.splitTest.getClass() == NominalAttributeMultiwayTest.class) {
+                                newChild.usedNominalAttributes = new ArrayList(node.usedNominalAttributes); //deep copy
+                                newChild.usedNominalAttributes.add(splitDecision.splitTest.getAttsTestDependsOn()[0]);
+                                // no nominal attribute should be split on more than once in the path
+                            }
+                            newSplit.setChild(i, newChild);
+                        }
+                        this.activeLeafNodeCount--;
+                        this.decisionNodeCount++;
+                        this.activeLeafNodeCount += splitDecision.numSplits();
+                        if (parent == null) {
+                            this.treeRoot = newSplit;
+                        } else {
+                            parent.setChild(parentIndex, newSplit);
+                        }
+
+                    }
+
+                    // manage memory
+                    enforceTrackerLimit();
+                }
+            }
+        }
+
+        @Override
+        protected LearningNode newLearningNode() {
+            // Fresh leaf with empty class counts.
+            return newLearningNode(new double[0]);
+        }
+
+        /** Creates a leaf of the configured prediction type (MC / NB / NBAdaptive). */
+        @Override
+        protected LearningNode newLearningNode(double[] initialClassObservations) {
+            LearningNode ret;
+            int predictionOption = this.leafpredictionOption.getChosenIndex();
+            if (predictionOption == 0) { //MC
+                ret = new ActiveLearningNode(initialClassObservations);
+            } else if (predictionOption == 1) { //NB
+                ret = new LearningNodeNB(initialClassObservations);
+            } else { //NBAdaptive
+                ret = new LearningNodeNBAdaptive(initialClassObservations);
+            }
+            return ret;
+        }
+
+        @Override
+        protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
+                double[] classObservations, int size) {
+            return new SplitNode(splitTest, classObservations, size);
+        }
+
+        @Override
+        protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
+                double[] classObservations) {
+            return new SplitNode(splitTest, classObservations);
+        }
+
+        /**
+         * Routes the instance to a leaf, trains the leaf, and periodically
+         * (every gracePeriod node-time steps) attempts a split there.
+         */
+        @Override
+        public void trainOnInstanceImpl(Instance inst) {
+            //System.err.println(i++);
+            if (this.treeRoot == null) {
+                this.treeRoot = newLearningNode();
+                this.activeLeafNodeCount = 1;
+            }
+            FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
+            Node leafNode = foundNode.node;
+
+            // A missing leaf can occur on a previously unseen branch; grow one.
+            if (leafNode == null) {
+                leafNode = newLearningNode();
+                foundNode.parent.setChild(foundNode.parentBranch, leafNode);
+                this.activeLeafNodeCount++;
+            }
+
+            if (leafNode instanceof LearningNode) {
+                LearningNode learningNode = (LearningNode) leafNode;
+                learningNode.learnFromInstance(inst, this);
+                if (this.growthAllowed
+                        && (learningNode instanceof ActiveLearningNode)) {
+                    ActiveLearningNode activeLearningNode = (ActiveLearningNode) learningNode;
+                    double weightSeen = activeLearningNode.getWeightSeen();
+                    if (activeLearningNode.nodeTime % this.gracePeriodOption.getValue() == 0) {
+                        attemptToSplit(activeLearningNode, foundNode.parent,
+                                foundNode.parentBranch);
+                        activeLearningNode.setWeightSeenAtLastSplitEvaluation(weightSeen);
+                    }
+                }
+            }
+
+            // Periodically re-estimate memory usage of the model.
+            if (this.trainingWeightSeenByModel
+                    % this.memoryEstimatePeriodOption.getValue() == 0) {
+                estimateModelByteSizes();
+            }
+
+            numInstances++;
+        }
}
- }
}
diff --git a/moa/src/main/java/moa/classifiers/trees/PLASTIC.java b/moa/src/main/java/moa/classifiers/trees/PLASTIC.java
new file mode 100644
index 000000000..d11327224
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/PLASTIC.java
@@ -0,0 +1,192 @@
+package moa.classifiers.trees;
+
+import com.github.javacliparser.FlagOption;
+import com.github.javacliparser.FloatOption;
+import com.github.javacliparser.IntOption;
+import com.github.javacliparser.MultiChoiceOption;
+import com.yahoo.labs.samoa.instances.Instance;
+import moa.classifiers.AbstractClassifier;
+import moa.classifiers.core.attributeclassobservers.DiscreteAttributeClassObserver;
+import moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver;
+import moa.classifiers.core.splitcriteria.SplitCriterion;
+import moa.classifiers.MultiClassClassifier;
+import moa.classifiers.trees.plastic_util.MeasuresNumberOfLeaves;
+import moa.classifiers.trees.plastic_util.PerformsTreeRevision;
+import moa.classifiers.trees.plastic_util.PlasticNode;
+import moa.core.DoubleVector;
+import moa.core.Measurement;
+import moa.options.ClassOption;
+
+import java.util.ArrayList;
+
+/**
+ * PLASTIC
+ *
+ * Restructures the subtrees that would have been pruned by EFDT.
+ *
+ * See details in:
+ * Marco Heyden, Heitor Murilo Gomes, Edouard Fouché, Bernhard Pfahringer, Klemens Böhm:
+ * Leveraging Plasticity in Incremental Decision Trees. ECML/PKDD (5) 2024: 38-54
+ *
+ * @author Marco Heyden (marco dot heyden at kit dot edu)
+ * @version $Revision: 1 $
+ */
+public class PLASTIC extends AbstractClassifier implements MultiClassClassifier, PerformsTreeRevision, MeasuresNumberOfLeaves {
+
+    /** Root of the decision tree; null until the first train/predict call. */
+    PlasticNode root;
+
+    /** Number of training instances seen so far; passed to the root on each update. */
+    int seenItems = 0;
+
+    public IntOption gracePeriodOption = new IntOption(
+            "gracePeriod",
+            'g',
+            "The number of instances a leaf should observe between split attempts.",
+            200, 0, Integer.MAX_VALUE);
+
+    public IntOption reEvalPeriodOption = new IntOption(
+            "reevaluationPeriod",
+            'R',
+            "The number of instances an internal node should observe between re-evaluation attempts.",
+            200, 0, Integer.MAX_VALUE);
+
+    public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator",
+            'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class,
+            "NominalAttributeClassObserver");
+
+    public ClassOption splitCriterionOption = new ClassOption("splitCriterion",
+            's', "Split criterion to use.", SplitCriterion.class,
+            "InfoGainSplitCriterion");
+
+    public FloatOption splitConfidenceOption = new FloatOption(
+            "splitConfidence",
+            'c',
+            "The allowable error in split decision when using fixed confidence. Values closer to 0 will take longer to decide.",
+            0.0000001, 0.0, 1.0);
+
+    public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
+            't', "Threshold below which a split will be forced to break ties.",
+            0.05, 0.0, 1.0);
+
+    public FloatOption tieThresholdReevalOption = new FloatOption("tieThresholdReevaluation",
+            'T', "Threshold below which a split will be forced to break ties during reevaluation.",
+            0.05, 0.0, 1.0);
+
+    public FloatOption relMinDeltaG = new FloatOption("relMinDeltaG",
+            'G', "Relative minimum information gain to split a tie during reevaluation.",
+            0.5, 0.0, 1.0);
+
+    public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b',
+            "Only allow binary splits.");
+
+    public MultiChoiceOption leafpredictionOption = new MultiChoiceOption(
+            "leafprediction", 'l', "Leaf prediction to use.", new String[]{
+            "MC", "NB", "NBA"}, new String[]{
+            "Majority class", "Naive Bayes", "Naive Bayes Adaptive"}, 2);
+
+    public IntOption maxDepthOption = new IntOption(
+            "maxDepth",
+            'D',
+            "Maximum allowed depth of tree.",
+            20, 0, Integer.MAX_VALUE);
+
+    public IntOption maxBranchLengthOption = new IntOption(
+            "maxBranchLength",
+            'B',
+            "Maximum allowed length of branches during restructuring.",
+            5, 1, Integer.MAX_VALUE);
+
+    /**
+     * Creates and configures the root node of the tree.
+     *
+     * The root is the only access point for the main PLASTIC class. All subtrees etc. will simply connect to the root.
+     *
+     * @return the created root node
+     **/
+    PlasticNode createRoot() {
+        return new PlasticNode(
+                (SplitCriterion) getPreparedClassOption(splitCriterionOption),
+                gracePeriodOption.getValue(),
+                splitConfidenceOption.getValue(),
+                0.2, // NOTE(review): hard-coded adaptive confidence — confirm against PlasticNode's signature
+                false,
+                leafpredictionOption.getChosenLabel(),
+                reEvalPeriodOption.getValue(),
+                0, // depth of the root node
+                maxDepthOption.getValue(),
+                tieThresholdOption.getValue(),
+                tieThresholdReevalOption.getValue(),
+                relMinDeltaG.getValue(), // BUG FIX: option was declared but a hard-coded 0.5 was passed, making it dead
+                binarySplitsOption.isSet(),
+                true,
+                (NominalAttributeClassObserver) getPreparedClassOption(nominalEstimatorOption),
+                new DoubleVector(),
+                new ArrayList<>(),
+                maxBranchLengthOption.getValue(),
+                0.05, // NOTE(review): hard-coded constant — confirm its meaning against PlasticNode
+                -1
+        );
+    }
+
+    /**
+     * Trains the tree on the provided instance.
+     *
+     * @param inst The instance to train on
+     **/
+    @Override
+    public void trainOnInstanceImpl(Instance inst) {
+        if (root == null)
+            root = createRoot();
+        root.learnInstance(inst, seenItems);
+        seenItems++;
+    }
+
+    /**
+     * Predicts class votes for the instance; returns all-zero votes (and lazily
+     * creates the root) when no instance has been trained on yet.
+     **/
+    @Override
+    public double[] getVotesForInstance(Instance inst) {
+        if (root == null) {
+            root = createRoot();
+            return new double[inst.numClasses()];
+        }
+        return root.predict(inst);
+    }
+
+    @Override
+    public void resetLearningImpl() {
+        root = null;
+        seenItems = 0;
+    }
+
+    /**
+     * Checks if a tree revision (e.g., pruning or restructuring) was performed at the current time step.
+     *
+     * @return if a tree revision was performed
+     **/
+    @Override
+    public boolean didPerformTreeRevision() {
+        // BUG FIX: guard against calls before the first training instance (root was dereferenced unconditionally).
+        return root != null && root.didPerformTreeRevision();
+    }
+
+    /**
+     * Returns the number of leaves of the tree.
+     *
+     * Note that calling this function repeatedly causes some overhead
+     * as the function traverses the tree recursively.
+     *
+     * @return the number of leaves, or 0 if the tree has not been created yet
+     **/
+    @Override
+    public int getLeafNumber() {
+        // BUG FIX: null-safe; previously threw NPE before the first training instance.
+        return root == null ? 0 : root.getLeafNumber();
+    }
+
+    @Override
+    protected Measurement[] getModelMeasurementsImpl() {
+        return new Measurement[0];
+    }
+
+    @Override
+    public void getModelDescription(StringBuilder out, int indent) {}
+
+    @Override
+    public boolean isRandomizable() {
+        return false;
+    }
+}
diff --git a/moa/src/main/java/moa/classifiers/trees/PLASTICA.java b/moa/src/main/java/moa/classifiers/trees/PLASTICA.java
new file mode 100644
index 000000000..b3cc82056
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/PLASTICA.java
@@ -0,0 +1,192 @@
+/*
+ * DriftDetectionMethodClassifier.java
+ * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
+ * @author Manuel Baena (mbaena@lcc.uma.es)
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package moa.classifiers.trees;
+
+import com.github.javacliparser.FloatOption;
+import com.yahoo.labs.samoa.instances.Instance;
+import moa.capabilities.CapabilitiesHandler;
+import moa.capabilities.Capability;
+import moa.capabilities.ImmutableCapabilities;
+import moa.classifiers.AbstractClassifier;
+import moa.classifiers.Classifier;
+import moa.classifiers.MultiClassClassifier;
+import moa.classifiers.trees.plastic_util.CustomADWINChangeDetector;
+import moa.core.Measurement;
+import moa.core.Utils;
+import moa.options.ClassOption;
+
+import java.util.LinkedList;
+import java.util.List;
+
+/**
+ * PLASTIC-A
+ *
+ * Combination of PLASTIC and a drift detector at the root. Will grow a background tree after a change and
+ * replace the current tree once the BG tree is more accurate.
+ *
+ * See details in:
+ * Marco Heyden, Heitor Murilo Gomes, Edouard Fouché, Bernhard Pfahringer, Klemens Böhm:
+ * Leveraging Plasticity in Incremental Decision Trees. ECML/PKDD (5) 2024: 38-54
+ *
+ * @author Marco Heyden (marco dot heyden at kit dot edu)
+ * @version $Revision: 1 $
+ */
+public class PLASTICA extends AbstractClassifier implements MultiClassClassifier,
+        CapabilitiesHandler {
+
+    @Override
+    public String getPurposeString() {
+        return "Classifier that grows a background tree when a change is detected in accuracy.";
+    }
+
+    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
+            "Classifier to train.", Classifier.class, "trees.PLASTIC");
+
+    public FloatOption confidenceOption = new FloatOption(
+            "Confidence",
+            'c',
+            "Confidence at which the current learner will be replaced.",
+            0.05, 0.0, 1.0);
+
+    /** The currently deployed learner. */
+    protected Classifier classifier;
+
+    /** Background learner grown after a drift signal; null when none is active. */
+    protected Classifier newclassifier;
+
+    // Drift detectors fed with the 0/1 error stream of each learner.
+    protected CustomADWINChangeDetector mainLearnerChangeDetector;
+    protected CustomADWINChangeDetector bgLearnerChangeDetector;
+
+    private double confidence;
+    protected boolean newClassifierReset;
+
+    @Override
+    public void resetLearningImpl() {
+        this.classifier = ((Classifier) getPreparedClassOption(this.baseLearnerOption)).copy();
+        this.classifier.resetLearning();
+        mainLearnerChangeDetector = new CustomADWINChangeDetector();
+        bgLearnerChangeDetector = new CustomADWINChangeDetector();
+        this.newClassifierReset = false;
+        confidence = confidenceOption.getValue();
+    }
+
+    /** Drift count since the last measurement report. */
+    protected int changeDetected = 0;
+
+    /** Warning count since the last measurement report (currently never raised). */
+    protected int warningDetected = 0;
+
+    /**
+     * Feeds both drift detectors, spawns/promotes/discards the background
+     * learner as appropriate, then trains the active learner(s).
+     *
+     * @param inst the instance to train on
+     **/
+    @Override
+    public void trainOnInstanceImpl(Instance inst) {
+        int trueClass = (int) inst.classValue();
+
+        // Feed the main detector with the deployed learner's 0/1 error.
+        boolean mainCorrect = Utils.maxIndex(this.classifier.getVotesForInstance(inst)) == trueClass;
+        mainLearnerChangeDetector.input(mainCorrect ? 0.0 : 1.0);
+
+        if (this.newclassifier != null) {
+            boolean bgCorrect = Utils.maxIndex(this.newclassifier.getVotesForInstance(inst)) == trueClass;
+            bgLearnerChangeDetector.input(bgCorrect ? 0.0 : 1.0);
+        }
+
+        // Start (or restart) a background learner when drift is signalled.
+        if (this.mainLearnerChangeDetector.getChange() && newclassifier == null) {
+            this.changeDetected++; // BUG FIX: counter was reported in measurements but never incremented
+            makeNewClassifier();
+        } else if (this.bgLearnerChangeDetector.getChange()) {
+            this.changeDetected++;
+            makeNewClassifier();
+        }
+
+        // Once both error windows are large enough, compare error rates and
+        // keep whichever learner is significantly better.
+        if (mainLearnerChangeDetector.getWidth() > 200 && bgLearnerChangeDetector.getWidth() > 200) {
+            double oldErrorRate = mainLearnerChangeDetector.getEstimation();
+            double oldWS = mainLearnerChangeDetector.getWidth();
+            double altErrorRate = bgLearnerChangeDetector.getEstimation();
+            double altWS = bgLearnerChangeDetector.getWidth();
+
+            double bound = computeBound(oldErrorRate, oldWS, altWS);
+
+            if (bound < oldErrorRate - altErrorRate) {
+                // Background learner is significantly better: promote it.
+                classifier = newclassifier;
+                newclassifier = null;
+                bgLearnerChangeDetector.resetLearning();
+            } else if (bound < altErrorRate - oldErrorRate) {
+                // Erase alternate tree
+                newclassifier = null;
+                bgLearnerChangeDetector.resetLearning();
+            }
+        }
+
+        this.classifier.trainOnInstance(inst);
+        if (newclassifier != null) {
+            newclassifier.trainOnInstance(inst);
+        }
+    }
+
+    /**
+     * Hoeffding-style bound on the difference of the two error rates at the
+     * configured confidence, given both window sizes.
+     **/
+    private double computeBound(double oldErrorRate, double oldWS, double altWS) {
+        double fDelta = confidenceOption.getValue();
+        double fN = 1.0 / altWS + 1.0 / oldWS;
+        return Math.sqrt(2.0 * oldErrorRate * (1.0 - oldErrorRate) * Math.log(2.0 / fDelta) * fN);
+    }
+
+    @Override
+    public double[] getVotesForInstance(Instance inst) {
+        // Predictions always come from the deployed learner, never the background one.
+        return this.classifier.getVotesForInstance(inst);
+    }
+
+    @Override
+    public boolean isRandomizable() {
+        return true;
+    }
+
+    @Override
+    public void getModelDescription(StringBuilder out, int indent) {
+        ((AbstractClassifier) this.classifier).getModelDescription(out, indent);
+    }
+
+    @Override
+    protected Measurement[] getModelMeasurementsImpl() {
+        // Typed list (was raw List/LinkedList).
+        List<Measurement> measurementList = new LinkedList<>();
+        measurementList.add(new Measurement("Change detected", this.changeDetected));
+        measurementList.add(new Measurement("Warning detected", this.warningDetected));
+        Measurement[] modelMeasurements = ((AbstractClassifier) this.classifier).getModelMeasurements();
+        if (modelMeasurements != null) {
+            for (Measurement measurement : modelMeasurements) {
+                measurementList.add(measurement);
+            }
+        }
+        // Counters are reset after every report.
+        this.changeDetected = 0;
+        this.warningDetected = 0;
+        return measurementList.toArray(new Measurement[measurementList.size()]);
+    }
+
+    @Override
+    public ImmutableCapabilities defineImmutableCapabilities() {
+        if (this.getClass() == PLASTICA.class)
+            return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
+        else
+            return new ImmutableCapabilities(Capability.VIEW_STANDARD);
+    }
+
+    /** Creates a fresh background learner and clears its drift detector. */
+    private void makeNewClassifier() {
+        newclassifier = ((Classifier) getPreparedClassOption(this.baseLearnerOption)).copy();
+        bgLearnerChangeDetector.resetLearning();
+    }
+}
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomADWINChangeDetector.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomADWINChangeDetector.java
new file mode 100644
index 000000000..8b428e841
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomADWINChangeDetector.java
@@ -0,0 +1,27 @@
+package moa.classifiers.trees.plastic_util;
+
+import moa.classifiers.core.driftdetection.ADWINChangeDetector;
+
+/**
+ * ADWIN change detector that latches change events: once the underlying
+ * detector signals a change, {@code getChange()} reports it exactly once
+ * and then clears the latch.
+ */
+public class CustomADWINChangeDetector extends ADWINChangeDetector {
+    private boolean hadChange = false;
+
+    @Override
+    public boolean getChange() {
+        // Consume the latched flag: report it once, then clear it.
+        if (!hadChange) {
+            return false;
+        }
+        hadChange = false;
+        return true;
+    }
+
+    @Override
+    public void input(double inputValue) {
+        super.input(inputValue);
+        // Latch the detector's change signal until somebody reads it.
+        hadChange = hadChange || super.getChange();
+    }
+
+    /** @return the current ADWIN window width, or 0 if ADWIN does not exist yet */
+    public int getWidth() {
+        return adwin == null ? 0 : adwin.getWidth();
+    }
+}
\ No newline at end of file
diff --git a/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomEFDTNode.java b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomEFDTNode.java
new file mode 100644
index 000000000..2ae55406f
--- /dev/null
+++ b/moa/src/main/java/moa/classifiers/trees/plastic_util/CustomEFDTNode.java
@@ -0,0 +1,780 @@
+package moa.classifiers.trees.plastic_util;
+
+import com.yahoo.labs.samoa.instances.Attribute;
+import com.yahoo.labs.samoa.instances.Instance;
+import moa.AbstractMOAObject;
+import moa.classifiers.bayes.NaiveBayes;
+import moa.classifiers.core.AttributeSplitSuggestion;
+import moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
+import moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver;
+import moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver;
+import moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver;
+import moa.classifiers.core.conditionaltests.InstanceConditionalTest;
+import moa.classifiers.core.conditionaltests.NominalAttributeBinaryTest;
+import moa.classifiers.core.conditionaltests.NominalAttributeMultiwayTest;
+import moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest;
+import moa.classifiers.core.splitcriteria.SplitCriterion;
+import moa.classifiers.trees.EFDT;
+import moa.core.AutoExpandVector;
+import moa.core.DoubleVector;
+import moa.core.Utils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.util.*;
+
+class CustomEFDTNode extends AbstractMOAObject implements PerformsTreeRevision, MeasuresNumberOfLeaves {
+ protected final int gracePeriod;
+ protected final SplitCriterion splitCriterion;
+ protected final Double confidence;
+ protected final boolean useAdaptiveConfidence;
+ protected final Double adaptiveConfidence;
+ protected final String leafPrediction;
+ protected final Integer minSamplesReevaluate;
+ protected Integer depth;
+ protected final Integer maxDepth;
+ protected final Double tau;
+ protected final Double tauReevaluate;
+ protected final Double relMinDeltaG;
+ protected final boolean binaryOnly;
+ protected List usedNominalAttributes;
+ protected HashMap infogainSum = new HashMap<>();
+ private InstanceConditionalTest splitTest;
+ protected int numSplitAttempts = 0;
+ protected final NominalAttributeClassObserver nominalObserverBlueprint;
+ protected final GaussianNumericAttributeClassObserver numericObserverBlueprint = new GaussianNumericAttributeClassObserver();
+
+ protected DoubleVector classDistributionAtTimeOfCreation;
+ protected DoubleVector observedClassDistribution;
+ protected AutoExpandVector attributeObservers = new AutoExpandVector<>();
+ protected Double seenWeight = 0.0;
+ protected int nodeTime = 0;
+ protected Successors successors;
+ protected Attribute splitAttribute;
+ protected final boolean noPrePrune;
+ protected int blockedAttributeIndex;
+ protected boolean performedTreeRevision = false;
+ protected double mcCorrectWeight = 0.0;
+ protected double nbCorrectWeight = 0.0;
+
+ public CustomEFDTNode(SplitCriterion splitCriterion,
+ int gracePeriod,
+ Double confidence,
+ Double adaptiveConfidence,
+ boolean useAdaptiveConfidence,
+ String leafPrediction,
+ Integer minSamplesReevaluate,
+ Integer depth,
+ Integer maxDepth,
+ Double tau,
+ Double tauReevaluate,
+ Double relMinDeltaG,
+ boolean binaryOnly,
+ boolean noPrePrune,
+ NominalAttributeClassObserver nominalObserverBlueprint,
+ DoubleVector observedClassDistribution,
+ List