gov.sandia.cognition.learning.algorithm.ensemble
Class AbstractBaggingLearner<InputType,OutputType,MemberType,EnsembleType extends Evaluator<? super InputType,? extends OutputType>>

java.lang.Object
  extended by gov.sandia.cognition.util.AbstractCloneableSerializable
      extended by gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
          extended by gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm<ResultType>
              extended by gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner<Collection<? extends InputOutputPair<? extends InputType,OutputType>>,ResultType>
                  extended by gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner<InputType,OutputType,EnsembleType>
                      extended by gov.sandia.cognition.learning.algorithm.ensemble.AbstractBaggingLearner<InputType,OutputType,MemberType,EnsembleType>
Type Parameters:
InputType - The input type for supervised learning. Passed on to the internal learning algorithm. Also the input type for the learned ensemble.
OutputType - The output type for supervised learning. Passed on to the internal learning algorithm. Also the output type of the learned ensemble.
MemberType - The type of ensemble member created by the inner learning algorithm. Usually an evaluator.
EnsembleType - The type of ensemble that the algorithm fills with ensemble members.
All Implemented Interfaces:
AnytimeAlgorithm<EnsembleType>, IterativeAlgorithm, StoppableAlgorithm, AnytimeBatchLearner<Collection<? extends InputOutputPair<? extends InputType,OutputType>>,EnsembleType>, BatchLearner<Collection<? extends InputOutputPair<? extends InputType,OutputType>>,EnsembleType>, BatchLearnerContainer<BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType>>, SupervisedBatchLearner<InputType,OutputType,EnsembleType>, CloneableSerializable, Randomized, Serializable, Cloneable
Direct Known Subclasses:
BaggingCategorizerLearner, BaggingRegressionLearner, BinaryBaggingLearner

@PublicationReference(title="Bagging Predictors",
                      author="Leo Breiman",
                      year=1996,
                      type=Journal,
                      publication="Machine Learning",
                      pages={123,140},
                      url="http://www.springerlink.com/index/L4780124W2874025.pdf")
public abstract class AbstractBaggingLearner<InputType,OutputType,MemberType,EnsembleType extends Evaluator<? super InputType,? extends OutputType>>
extends AbstractAnytimeSupervisedBatchLearner<InputType,OutputType,EnsembleType>
implements Randomized, BatchLearnerContainer<BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType>>

Learns an ensemble by training a new ensemble member on each iteration from a random sample of the data, drawn with replacement (duplicates allowed). The sample size is some percentage of the size of the data (defaults to 100%). The random sample is referred to as a bag. Each learned ensemble member is given equal weight. The idea is that by randomly sampling from the data and learning ensemble members with an algorithm that has high variance with respect to the input data (such as a decision tree), one can improve the performance of that algorithm. By default, the algorithm runs for maxIterations steps, creating that many ensemble members. However, one can also use the out-of-bag (OOB) error on each iteration as a stopping criterion. The OOB error is determined by looking at performance on the examples that the learned member has not seen, that is, the examples left out of its bag.
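
The procedure described above can be sketched in plain Java. This is a minimal illustration under stated assumptions, not the library's implementation: the class BaggingSketch, its methods, the stand-in "base learner" (which just takes the mean of its bag), and the rounding rule for the bag size are all hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

// Illustrative sketch only: these names are hypothetical, not library API.
final class BaggingSketch {

    /** Draws sampleCount items uniformly at random, with replacement. */
    static List<Double> sampleBag(List<Double> data, int sampleCount, Random random) {
        List<Double> bag = new ArrayList<>(sampleCount);
        for (int i = 0; i < sampleCount; i++) {
            bag.add(data.get(random.nextInt(data.size())));
        }
        return bag;
    }

    /** Stand-in base learner: the "member" is just the mean of its bag. */
    static double learnMember(List<Double> bag) {
        double sum = 0.0;
        for (double value : bag) {
            sum += value;
        }
        return sum / bag.size();
    }

    /** Trains memberCount members, each on its own freshly drawn bag. */
    static List<Double> learnEnsemble(List<Double> data, int memberCount,
            double percentToSample, Random random) {
        // Assumed conversion from a fraction (1.0 = 100%) to a bag size.
        int sampleCount = Math.max(1, (int) Math.round(percentToSample * data.size()));
        List<Double> members = new ArrayList<>(memberCount);
        for (int i = 0; i < memberCount; i++) {
            members.add(learnMember(sampleBag(data, sampleCount, random)));
        }
        return members;
    }

    /** Equal-weight combination: every member contributes 1/n of the output. */
    static double combine(List<Double> members) {
        double sum = 0.0;
        for (double member : members) {
            sum += member;
        }
        return sum / members.size();
    }

    public static void main(String[] args) {
        List<Double> data = Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0);
        List<Double> members = learnEnsemble(data, 100, 1.0, new Random(0));
        System.out.println("Ensemble output: " + combine(members));
    }
}
```

A real base learner would return a model rather than a number; the mean is used here only to keep the sketch short and self-contained.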

Since:
3.3.3
Author:
Justin Basilico
See Also:
Serialized Form

Field Summary
protected  ArrayList<InputOutputPair<? extends InputType,OutputType>> bag
          The current bag of data.
protected  int[] dataInBag
          The number of times each example appears in the current bag (zero means the example is not in the bag).
protected  ArrayList<? extends InputOutputPair<? extends InputType,OutputType>> dataList
          The data stored for efficient random access.
static int DEFAULT_MAX_ITERATIONS
          The default maximum number of iterations is 100.
static double DEFAULT_PERCENT_TO_SAMPLE
          The default percent to sample is 1.0 (which represents 100%).
protected  EnsembleType ensemble
          The ensemble being created by the learner.
protected  BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType> learner
          The learner to use to create the ensemble member for each iteration.
protected  double percentToSample
          The percentage of the data to sample with replacement on each iteration.
protected  Random random
          The random number generator to use.
 
Fields inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
data, keepGoing
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
maxIterations
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
DEFAULT_ITERATION, iteration
 
Constructor Summary
AbstractBaggingLearner()
          Creates a new instance of AbstractBaggingLearner.
AbstractBaggingLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType> learner)
          Creates a new instance of AbstractBaggingLearner.
AbstractBaggingLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType> learner, int maxIterations, double percentToSample, Random random)
          Creates a new instance of AbstractBaggingLearner.
 
Method Summary
protected abstract  void addEnsembleMember(MemberType member)
          Adds a new member to the ensemble.
protected  void cleanupAlgorithm()
          Called to clean up the learning algorithm's state after learning has finished.
protected abstract  EnsembleType createInitialEnsemble()
          Create the initial, empty ensemble for the algorithm to use.
protected  void fillBag(int sampleCount)
          Fills the internal bag field by sampling the given number of samples.
 ArrayList<InputOutputPair<? extends InputType,OutputType>> getBag()
          Gets the most recently created bag.
 int[] getDataInBag()
          Gets the array of counts of the number of samples of each example in the current bag.
 ArrayList<? extends InputOutputPair<? extends InputType,OutputType>> getDataList()
          Gets the data the learner is using as an array list.
 EnsembleType getEnsemble()
          Gets the ensemble created by this learner.
 BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType> getLearner()
          Gets the learner used to learn each ensemble member.
 double getPercentToSample()
          Gets the percentage of the total data to sample on each iteration.
 Random getRandom()
          Gets the random number generator used by this object.
 EnsembleType getResult()
          Gets the ensemble created by this learner.
protected  boolean initializeAlgorithm()
          Called to initialize the learning algorithm's state based on the data that is stored in the data field.
protected  void setBag(ArrayList<InputOutputPair<? extends InputType,OutputType>> bag)
          Sets the most recently created bag.
protected  void setDataInBag(int[] dataInBag)
          Sets the array of counts of the number of samples of each example in the current bag.
protected  void setDataList(ArrayList<? extends InputOutputPair<? extends InputType,OutputType>> dataList)
          Sets the data the learner is using as an array list.
protected  void setEnsemble(EnsembleType ensemble)
          Sets the ensemble created by this learner.
 void setLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType> learner)
          Sets the learner used to learn each ensemble member.
 void setPercentToSample(double percentToSample)
          Sets the percentage of the data to sample (with replacement) on each iteration.
 void setRandom(Random random)
          Sets the random number generator used by this object.
protected  boolean step()
          Called to take a single step of the learning algorithm.
 
Methods inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
clone, getData, getKeepGoing, learn, setData, setKeepGoing, stop
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
getMaxIterations, isResultValid, setMaxIterations
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
 
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface gov.sandia.cognition.learning.algorithm.BatchLearner
learn
 
Methods inherited from interface gov.sandia.cognition.util.CloneableSerializable
clone
 
Methods inherited from interface gov.sandia.cognition.algorithm.AnytimeAlgorithm
getMaxIterations, setMaxIterations
 
Methods inherited from interface gov.sandia.cognition.algorithm.IterativeAlgorithm
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener
 
Methods inherited from interface gov.sandia.cognition.algorithm.StoppableAlgorithm
isResultValid
 

Field Detail

DEFAULT_MAX_ITERATIONS

public static final int DEFAULT_MAX_ITERATIONS
The default maximum number of iterations is 100.

See Also:
Constant Field Values

DEFAULT_PERCENT_TO_SAMPLE

public static final double DEFAULT_PERCENT_TO_SAMPLE
The default percent to sample is 1.0 (which represents 100%).

See Also:
Constant Field Values

learner

protected BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType> learner
The learner to use to create the ensemble member for each iteration.


percentToSample

protected double percentToSample
The percentage of the data to sample with replacement on each iteration. Must be positive. Represented as a floating point number with 1.0 meaning 100%.


random

protected Random random
The random number generator to use.


ensemble

protected transient EnsembleType ensemble
The ensemble being created by the learner.


dataList

protected transient ArrayList<? extends InputOutputPair<? extends InputType,OutputType>> dataList
The data stored for efficient random access.


dataInBag

protected transient int[] dataInBag
The number of times each example appears in the current bag. A count of zero means the example is not in the bag.


bag

protected transient ArrayList<InputOutputPair<? extends InputType,OutputType>> bag
The current bag of data.

Constructor Detail

AbstractBaggingLearner

public AbstractBaggingLearner()
Creates a new instance of AbstractBaggingLearner.


AbstractBaggingLearner

public AbstractBaggingLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType> learner)
Creates a new instance of AbstractBaggingLearner.

Parameters:
learner - The learner to use to create the ensemble member on each iteration.

AbstractBaggingLearner

public AbstractBaggingLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType> learner,
                              int maxIterations,
                              double percentToSample,
                              Random random)
Creates a new instance of AbstractBaggingLearner.

Parameters:
learner - The learner to use to create the ensemble member on each iteration.
maxIterations - The maximum number of iterations to run for, which is also the number of learners to create.
percentToSample - The percentage of the total size of the data to sample on each iteration. Must be positive.
random - The random number generator to use.
Method Detail

initializeAlgorithm

protected boolean initializeAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to initialize the learning algorithm's state based on the data that is stored in the data field. The return value indicates if the algorithm can be run or not based on the initialization.

Specified by:
initializeAlgorithm in class AbstractAnytimeBatchLearner<Collection<? extends InputOutputPair<? extends InputType,OutputType>>,EnsembleType>
Returns:
True if the learning algorithm can be run and false if it cannot.

step

protected boolean step()
Description copied from class: AbstractAnytimeBatchLearner
Called to take a single step of the learning algorithm.

Specified by:
step in class AbstractAnytimeBatchLearner<Collection<? extends InputOutputPair<? extends InputType,OutputType>>,EnsembleType>
Returns:
True if another step can be taken and false if the algorithm should halt.

createInitialEnsemble

protected abstract EnsembleType createInitialEnsemble()
Create the initial, empty ensemble for the algorithm to use.

Returns:
A new ensemble for the algorithm to use.

addEnsembleMember

protected abstract void addEnsembleMember(MemberType member)
Adds a new member to the ensemble.

Parameters:
member - The new member to add to the ensemble.
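
The two abstract hooks, createInitialEnsemble and addEnsembleMember, follow a template-method shape: the base class decides when members are created and the subclass decides how they are stored. The sketch below illustrates that division of labor with a tiny stand-in hierarchy. TinyEnsembleBuilder, ListEnsembleBuilder, and the build method are hypothetical names chosen for the example; they are not part of this package, and the real class obtains its members from the wrapped learner rather than from a prepared list.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Illustrative sketch only: a tiny stand-in, not the library's class hierarchy.
abstract class TinyEnsembleBuilder<MemberType, EnsembleType> {

    /** The ensemble being filled, mirroring the protected ensemble field. */
    protected EnsembleType ensemble;

    /** Subclass hook: create the initial, empty ensemble. */
    protected abstract EnsembleType createInitialEnsemble();

    /** Subclass hook: store one newly learned member in the ensemble. */
    protected abstract void addEnsembleMember(MemberType member);

    /** Hypothetical driver: creates the ensemble, then adds each member in turn. */
    final EnsembleType build(List<MemberType> learnedMembers) {
        ensemble = createInitialEnsemble();
        for (MemberType member : learnedMembers) {
            addEnsembleMember(member);
        }
        return ensemble;
    }
}

// A concrete subclass whose "ensemble" is simply a list of member names.
final class ListEnsembleBuilder extends TinyEnsembleBuilder<String, List<String>> {

    @Override
    protected List<String> createInitialEnsemble() {
        return new ArrayList<>();
    }

    @Override
    protected void addEnsembleMember(String member) {
        ensemble.add(member);
    }
}

final class EnsembleHooksDemo {
    public static void main(String[] args) {
        List<String> result = new ListEnsembleBuilder().build(Arrays.asList("tree-1", "tree-2"));
        System.out.println(result);
    }
}
```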

fillBag

protected void fillBag(int sampleCount)
Fills the internal bag field by sampling the given number of samples.

Parameters:
sampleCount - The number to sample.
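
One plausible way to fill a bag while tracking per-example counts is sketched below. It is an assumption-laden illustration, not the library's code: the class FillBagSketch and its outOfBag helper are hypothetical, the data is a list of strings for brevity, and the choice to reset the counts on every call is a guess. The field names bag, dataList, and dataInBag are borrowed from this class only to show how they could relate.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

// Illustrative sketch only: hypothetical class, not the library's implementation.
final class FillBagSketch {

    final List<String> dataList;   // Data held for random access by index.
    final Random random;
    List<String> bag;              // The most recently created bag.
    int[] dataInBag;               // dataInBag[i] = times example i is in the bag.

    FillBagSketch(List<String> dataList, Random random) {
        this.dataList = dataList;
        this.random = random;
    }

    /** Samples sampleCount examples with replacement and records the counts. */
    void fillBag(int sampleCount) {
        bag = new ArrayList<>(sampleCount);
        dataInBag = new int[dataList.size()]; // Assumption: counts restart per bag.
        for (int i = 0; i < sampleCount; i++) {
            int index = random.nextInt(dataList.size());
            bag.add(dataList.get(index));
            dataInBag[index]++;
        }
    }

    /** Hypothetical helper: examples with a zero count are out-of-bag. */
    List<String> outOfBag() {
        List<String> result = new ArrayList<>();
        for (int i = 0; i < dataInBag.length; i++) {
            if (dataInBag[i] == 0) {
                result.add(dataList.get(i));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        FillBagSketch sketch = new FillBagSketch(
            Arrays.asList("a", "b", "c", "d", "e"), new Random(1));
        sketch.fillBag(5);
        System.out.println("bag = " + sketch.bag);
        System.out.println("counts = " + Arrays.toString(sketch.dataInBag));
        System.out.println("out-of-bag = " + sketch.outOfBag());
    }
}
```

Keeping counts rather than a boolean flag makes it cheap to find the out-of-bag examples, which are the ones an OOB error estimate would be computed on.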

cleanupAlgorithm

protected void cleanupAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to clean up the learning algorithm's state after learning has finished.

Specified by:
cleanupAlgorithm in class AbstractAnytimeBatchLearner<Collection<? extends InputOutputPair<? extends InputType,OutputType>>,EnsembleType>
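
The initializeAlgorithm, step, and cleanupAlgorithm hooks suggest a driver loop along the following lines. This is only a sketch of how such hooks are commonly wired together, assuming the loop stops when step returns false or when the iteration count reaches maxIterations; the actual control flow lives in the superclasses and may differ. TinyAnytimeLoop, CountingLoop, and run are hypothetical names.

```java
// Illustrative sketch only: hypothetical classes, not the library's control flow.
abstract class TinyAnytimeLoop {

    protected int maxIterations;
    protected int iteration;

    /** Returns true if the algorithm can be run. */
    protected abstract boolean initializeAlgorithm();

    /** Returns true if another step can be taken. */
    protected abstract boolean step();

    /** Releases any per-run state. */
    protected abstract void cleanupAlgorithm();

    final void run() {
        iteration = 0;
        if (!initializeAlgorithm()) {
            return; // Assumption: nothing to clean up if initialization fails.
        }
        boolean keepGoing = true;
        while (keepGoing && iteration < maxIterations) {
            iteration++;
            keepGoing = step();
        }
        cleanupAlgorithm();
    }
}

// Counts steps and asks to halt once stopAfter steps have been taken.
final class CountingLoop extends TinyAnytimeLoop {

    final int stopAfter;
    int steps;
    boolean cleanedUp;

    CountingLoop(int maxIterations, int stopAfter) {
        this.maxIterations = maxIterations;
        this.stopAfter = stopAfter;
    }

    @Override
    protected boolean initializeAlgorithm() {
        steps = 0;
        cleanedUp = false;
        return maxIterations > 0;
    }

    @Override
    protected boolean step() {
        steps++;
        return steps < stopAfter;
    }

    @Override
    protected void cleanupAlgorithm() {
        cleanedUp = true;
    }

    public static void main(String[] args) {
        CountingLoop loop = new CountingLoop(5, 100);
        loop.run();
        System.out.println("steps taken: " + loop.steps); // 5, capped by maxIterations
    }
}
```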

getResult

public EnsembleType getResult()
Gets the ensemble created by this learner.

Specified by:
getResult in interface AnytimeAlgorithm<EnsembleType>
Returns:
The ensemble created by this learner.

getLearner

public BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType> getLearner()
Gets the learner used to learn each ensemble member.

Specified by:
getLearner in interface BatchLearnerContainer<BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType>>
Returns:
The learner used for each ensemble member.

setLearner

public void setLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType,OutputType>>,? extends MemberType> learner)
Sets the learner used to learn each ensemble member. Must be a supervised learning algorithm that takes in a collection of input-output pairs of the given data types and produces an evaluator for those data types.

Parameters:
learner - The learner used for each ensemble member.

getPercentToSample

public double getPercentToSample()
Gets the percentage of the total data to sample on each iteration.

Returns:
The percentage of the total data to sample on each iteration.

setPercentToSample

public void setPercentToSample(double percentToSample)
Sets the percentage of the data to sample (with replacement) on each iteration. Must be greater than zero. The percent is represented as a floating point number with 1.0 representing 100%.

Parameters:
percentToSample - The percent of the data to sample on each iteration. Must be greater than zero. Defaults to 100%.
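
As a rough illustration of the "greater than zero" constraint and of how a fraction such as 1.0 might turn into a number of examples, consider the sketch below. The helper names are hypothetical, and both the exception type and the rounding rule (round to nearest, at least one example) are assumptions; the documentation above does not specify them.

```java
// Illustrative sketch only: hypothetical helper, not the library's code.
final class PercentToSampleSketch {

    /** Rejects values that are not strictly positive (NaN included). */
    static void checkPercentToSample(double percentToSample) {
        if (!(percentToSample > 0.0)) {
            throw new IllegalArgumentException(
                "percentToSample must be greater than zero: " + percentToSample);
        }
    }

    /** Converts the fraction (1.0 = 100%) into a number of examples to draw. */
    static int sampleCount(double percentToSample, int dataSize) {
        checkPercentToSample(percentToSample);
        // Assumption: round to nearest and always draw at least one example.
        return Math.max(1, (int) Math.round(percentToSample * dataSize));
    }

    public static void main(String[] args) {
        System.out.println(sampleCount(1.0, 10)); // 10
        System.out.println(sampleCount(0.5, 10)); // 5
        System.out.println(sampleCount(2.0, 10)); // 20
    }
}
```

Because the only stated requirement is a positive value, a setting above 1.0 would, under this reading, simply produce a bag larger than the data set, which sampling with replacement permits.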

getRandom

public Random getRandom()
Description copied from interface: Randomized
Gets the random number generator used by this object.

Specified by:
getRandom in interface Randomized
Returns:
The random number generator used by this object.

setRandom

public void setRandom(Random random)
Description copied from interface: Randomized
Sets the random number generator used by this object.

Specified by:
setRandom in interface Randomized
Parameters:
random - The random number generator for this object to use.
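
Because the bags are drawn from the supplied java.util.Random, passing a generator with a fixed seed is the usual way to make sampling repeatable. The sketch below shows the underlying property using only the JDK; SeededSamplingSketch and sampleIndices are hypothetical names, and the sketch does not call this class.

```java
import java.util.Arrays;
import java.util.Random;

// Illustrative sketch only: demonstrates seeded sampling with the JDK alone.
final class SeededSamplingSketch {

    /** Draws sampleCount indices in [0, dataSize) with replacement. */
    static int[] sampleIndices(int dataSize, int sampleCount, Random random) {
        int[] indices = new int[sampleCount];
        for (int i = 0; i < sampleCount; i++) {
            indices[i] = random.nextInt(dataSize);
        }
        return indices;
    }

    public static void main(String[] args) {
        int[] first = sampleIndices(10, 10, new Random(42));
        int[] second = sampleIndices(10, 10, new Random(42));
        // Two generators created with the same seed yield the same sequence.
        System.out.println(Arrays.equals(first, second)); // true
    }
}
```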

getEnsemble

public EnsembleType getEnsemble()
Gets the ensemble created by this learner.

Returns:
The ensemble created by this learner.

setEnsemble

protected void setEnsemble(EnsembleType ensemble)
Sets the ensemble created by this learner.

Parameters:
ensemble - The ensemble created by this learner.

getDataList

public ArrayList<? extends InputOutputPair<? extends InputType,OutputType>> getDataList()
Gets the data the learner is using as an array list.

Returns:
The data as an array list.

setDataList

protected void setDataList(ArrayList<? extends InputOutputPair<? extends InputType,OutputType>> dataList)
Sets the data the learner is using as an array list.

Parameters:
dataList - The data as an array list.

getDataInBag

public int[] getDataInBag()
Gets the array of counts of the number of samples of each example in the current bag.

Returns:
The bag counts.

setDataInBag

protected void setDataInBag(int[] dataInBag)
Sets the array of counts of the number of samples of each example in the current bag.

Parameters:
dataInBag - The bag counts.

getBag

public ArrayList<InputOutputPair<? extends InputType,OutputType>> getBag()
Gets the most recently created bag.

Returns:
The most recently created bag.

setBag

protected void setBag(ArrayList<InputOutputPair<? extends InputType,OutputType>> bag)
Sets the most recently created bag.

Parameters:
bag - The most recently created bag.