Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions moa/src/main/java/moa/classifiers/meta/StreamingRandomPatches.java
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ public void trainOnInstanceImpl(Instance instance) {
this.ensemble[i].evaluator.addResult(example, vote.getArrayRef());
// Train using random subspaces without resampling, i.e. all instances are used for training.
if(this.trainingMethodOption.getChosenIndex() == TRAIN_RANDOM_SUBSPACES) {
this.ensemble[i].trainOnInstance(instance,1, this.instancesSeen, this.classifierRandom);
this.ensemble[i].trainOnInstance(instance,1, this.instancesSeen, this.classifierRandom, true);
}
// Train using random patches or resampling, thus we simulate online bagging with poisson(lambda=...)
else {
int k = MiscUtils.poisson(this.lambdaOption.getValue(), this.classifierRandom);
if (k > 0) {
double weight = k;
this.ensemble[i].trainOnInstance(instance, weight, this.instancesSeen, this.classifierRandom);
this.ensemble[i].trainOnInstance(instance, weight, this.instancesSeen, this.classifierRandom, true);
}
}
}
Expand Down Expand Up @@ -202,7 +202,7 @@ protected Measurement[] getModelMeasurementsImpl() {
return null;
}

protected void initEnsemble(Instance instance) {
public void initEnsemble(Instance instance) {
// Init the ensemble.
int ensembleSize = this.ensembleSizeOption.getValue();
this.ensemble = new StreamingRandomPatchesClassifier[ensembleSize];
Expand Down Expand Up @@ -326,7 +326,7 @@ public Classifier[] getSublearners() {
}

public static ArrayList<ArrayList<Integer>> localRandomKCombinations(int k, int length,
int nCombinations, Random random) {
int nCombinations, Random random) {
ArrayList<ArrayList<Integer>> combinations = new ArrayList<>();
for(int i = 0 ; i < nCombinations ; ++i) {
ArrayList<Integer> combination = new ArrayList<>();
Expand Down Expand Up @@ -364,7 +364,7 @@ public static ArrayList<ArrayList<Integer>> allKCombinations(int k, int length)
}

// Inner class representing the base learner of SRP.
protected class StreamingRandomPatchesClassifier {
public class StreamingRandomPatchesClassifier {
public int indexOriginal;
public long createdOn;
public Classifier classifier;
Expand Down Expand Up @@ -512,7 +512,8 @@ public void reset(Instance instance, long instancesSeen, Random random) {
this.classifier.resetLearning();
this.evaluator.reset();
this.createdOn = instancesSeen;
this.driftDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.driftOption)).copy();
if(this.driftOption != null)
this.driftDetectionMethod = ((ChangeDetector) getPreparedClassOption(this.driftOption)).copy();

if(this.subset != null) {
ArrayList<Integer> fIndexes = this.applySubsetResetStrategy(instance, random);
Expand All @@ -530,7 +531,8 @@ public void reset(Instance instance, long instancesSeen, Random random) {
}
}

public void trainOnInstance(Instance instance, double weight, long instancesSeen, Random random) {
public void trainOnInstance(Instance instance, double weight, long instancesSeen,
Random random, boolean updateDriftDetector) {
boolean correctlyClassifies;
// The subset object will be null if we are training with all features
if(this.subset != null) {
Expand All @@ -541,18 +543,18 @@ public void trainOnInstance(Instance instance, double weight, long instancesSeen
this.classifier.trainOnInstance(this.subset.get(0));
correctlyClassifies = this.classifier.correctlyClassifies(this.subset.get(0));
if(this.bkgLearner != null)
this.bkgLearner.trainOnInstance(instance, weight, instancesSeen, random);
this.bkgLearner.trainOnInstance(instance, weight, instancesSeen, random, updateDriftDetector);
}
else {
Instance weightedInstance = instance.copy();
weightedInstance.setWeight(instance.weight() * weight);
this.classifier.trainOnInstance(weightedInstance);
correctlyClassifies = this.classifier.correctlyClassifies(instance);
if(this.bkgLearner != null)
this.bkgLearner.trainOnInstance(instance, weight, instancesSeen, random);
this.bkgLearner.trainOnInstance(instance, weight, instancesSeen, random, updateDriftDetector);
}

if(!this.disableDriftDetector && !this.isBackgroundLearner) {
if(!this.disableDriftDetector && !this.isBackgroundLearner && updateDriftDetector) {

// Check for warning only if useBkgLearner is active
if (!this.disableBkgLearner) {
Expand Down Expand Up @@ -613,4 +615,8 @@ public double[] getVotesForInstance(Instance instance) {
return vote.getArrayRef();
}
}

public StreamingRandomPatchesClassifier[] getEnsembleMembers() {
return this.ensemble;
}
}
Loading