gov.sandia.cognition.learning.algorithm.tree
Class RegressionTreeLearner<InputType>

java.lang.Object
  extended by gov.sandia.cognition.util.AbstractCloneableSerializable
      extended by gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
          extended by gov.sandia.cognition.learning.algorithm.tree.AbstractDecisionTreeLearner<InputType,Double>
              extended by gov.sandia.cognition.learning.algorithm.tree.RegressionTreeLearner<InputType>
Type Parameters:
InputType - The type of the input to the tree.
All Implemented Interfaces:
IterativeAlgorithm, BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,RegressionTree<InputType>>, SupervisedBatchLearner<InputType,Double,RegressionTree<InputType>>, CloneableSerializable, Serializable, Cloneable

public class RegressionTreeLearner<InputType>
extends AbstractDecisionTreeLearner<InputType,Double>
implements SupervisedBatchLearner<InputType,Double,RegressionTree<InputType>>

The RegressionTreeLearner class implements a learning algorithm for a regression tree that makes use of a decider learner and a regression learner. The tree grows as a decision tree until it reaches a leaf node, which is decided by how many instances fall into the node (the leaf count threshold) or by the maximum depth, and then learns a regression function at that leaf node.
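The growth procedure described above can be sketched in a short, self-contained example. This is a simplified illustration, not the library's implementation. The class `SimpleRegressionTree` and all of its members are hypothetical, and inputs are single `double` values. A squared-error threshold search stands in for the decider learner, and the mean of the outputs stands in for the regression learner. The sketch also assumes that the root has depth 1 and that a node becomes a leaf once its instance count is at or below the threshold.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified regression tree learner over scalar inputs.
// Decider step: a squared-error threshold search (stand-in for the decider learner).
// Leaf step: the mean of the outputs (stand-in for the regression learner).
class SimpleRegressionTree {
    // A leaf has no children and predicts a constant value.
    static final class Node {
        double threshold; // inputs <= threshold go to the left child
        Node left;
        Node right;
        double value;     // prediction stored at a leaf
        boolean isLeaf() { return left == null; }
    }

    private final int leafCountThreshold;
    private final int maxDepth; // ignored if less than 1
    private Node root;

    SimpleRegressionTree(int leafCountThreshold, int maxDepth) {
        this.leafCountThreshold = leafCountThreshold;
        this.maxDepth = maxDepth;
    }

    void learn(double[] xs, double[] ys) {
        List<Integer> all = new ArrayList<>();
        for (int i = 0; i < xs.length; i++) all.add(i);
        root = learnNode(xs, ys, all, 1); // assumes the root has depth 1
    }

    double evaluate(double x) {
        Node n = root;
        while (!n.isLeaf()) {
            n = x <= n.threshold ? n.left : n.right;
        }
        return n.value;
    }

    private Node learnNode(double[] xs, double[] ys, List<Integer> idx, int depth) {
        Node node = new Node();
        boolean depthReached = maxDepth >= 1 && depth >= maxDepth;
        if (idx.size() <= leafCountThreshold || depthReached || allEqual(ys, idx)) {
            node.value = mean(ys, idx); // fit the leaf's regression function
            return node;
        }
        // Try each observed input as a threshold; keep the lowest total squared error.
        double bestError = Double.POSITIVE_INFINITY;
        for (int i : idx) {
            List<Integer> left = new ArrayList<>();
            List<Integer> right = new ArrayList<>();
            split(xs, idx, xs[i], left, right);
            if (left.isEmpty() || right.isEmpty()) continue;
            double error = sse(ys, left) + sse(ys, right);
            if (error < bestError) {
                bestError = error;
                node.threshold = xs[i];
            }
        }
        if (bestError == Double.POSITIVE_INFINITY) { // all inputs equal: no split possible
            node.value = mean(ys, idx);
            return node;
        }
        List<Integer> left = new ArrayList<>();
        List<Integer> right = new ArrayList<>();
        split(xs, idx, node.threshold, left, right);
        node.left = learnNode(xs, ys, left, depth + 1);
        node.right = learnNode(xs, ys, right, depth + 1);
        return node;
    }

    private static void split(double[] xs, List<Integer> idx, double t,
                              List<Integer> left, List<Integer> right) {
        for (int j : idx) {
            if (xs[j] <= t) left.add(j); else right.add(j);
        }
    }

    private static boolean allEqual(double[] ys, List<Integer> idx) {
        for (int i : idx) {
            if (ys[i] != ys[idx.get(0)]) return false;
        }
        return true;
    }

    private static double mean(double[] ys, List<Integer> idx) {
        double sum = 0;
        for (int i : idx) sum += ys[i];
        return sum / idx.size();
    }

    private static double sse(double[] ys, List<Integer> idx) {
        double m = mean(ys, idx);
        double total = 0;
        for (int i : idx) total += (ys[i] - m) * (ys[i] - m);
        return total;
    }
}
```

For example, `new SimpleRegressionTree(1, 0)` trained on inputs 1 to 6 with outputs {1, 1, 1, 10, 10, 10} splits at 3 and predicts 1.0 for input 2.0 and 10.0 for input 5.5. With a maximum depth of 1, the root itself becomes a leaf that predicts the overall mean, 5.5.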

Since:
2.0
Author:
Justin Basilico
See Also:
Serialized Form

Field Summary
static int DEFAULT_LEAF_COUNT_THRESHOLD
          The default threshold for making a leaf node based on count.
static int DEFAULT_MAX_DEPTH
          The default maximum depth to grow the tree to.
protected  int leafCountThreshold
          The threshold for making a node a leaf, based on how many instances fall into the node.
protected  int maxDepth
          The maximum depth for the tree.
protected  BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,? extends Evaluator<? super InputType,Double>> regressionLearner
          The learning algorithm for the regression function.
 
Fields inherited from class gov.sandia.cognition.learning.algorithm.tree.AbstractDecisionTreeLearner
deciderLearner
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
DEFAULT_ITERATION, iteration
 
Constructor Summary
RegressionTreeLearner()
          Creates a new instance of RegressionTreeLearner.
RegressionTreeLearner(DeciderLearner<? super InputType,Double,?,?> deciderLearner, BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,? extends Evaluator<? super InputType,Double>> regressionLearner)
          Creates a new instance of RegressionTreeLearner.
RegressionTreeLearner(DeciderLearner<? super InputType,Double,?,?> deciderLearner, BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,? extends Evaluator<? super InputType,Double>> regressionLearner, int leafCountThreshold, int maxDepth)
          Creates a new instance of RegressionTreeLearner.
 
Method Summary
 int getLeafCountThreshold()
          Gets the leaf count threshold, which determines the number of elements at which to learn a regression function.
 int getMaxDepth()
          Gets the maximum depth to grow the tree.
 BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,? extends Evaluator<? super InputType,Double>> getRegressionLearner()
          Gets the regression learner that is to be used to fit a function approximator to the values in the tree.
 RegressionTree<InputType> learn(Collection<? extends InputOutputPair<? extends InputType,Double>> data)
          The learn method creates an object of ResultType using data of type DataType, using some form of "learning" algorithm.
protected  RegressionTreeNode<InputType,?> learnNode(Collection<? extends InputOutputPair<? extends InputType,Double>> data, AbstractDecisionTreeNode<InputType,Double,?> parent)
          Recursively learns the regression tree using the given collection of data, returning the created node.
 void setLeafCountThreshold(int leafCountThreshold)
          Sets the leaf count threshold, which determines the number of elements at which to learn a regression function.
 void setMaxDepth(int maxDepth)
          Sets the maximum depth to grow the tree.
 void setRegressionLearner(BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,? extends Evaluator<? super InputType,Double>> regressionLearner)
          Sets the regression learner that is to be used to fit a function approximator to the values in the tree.
 
Methods inherited from class gov.sandia.cognition.learning.algorithm.tree.AbstractDecisionTreeLearner
areAllOutputsEqual, getDeciderLearner, learnChildNodes, setDeciderLearner, splitData
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
addIterativeAlgorithmListener, clone, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
 
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface gov.sandia.cognition.util.CloneableSerializable
clone
 

Field Detail

DEFAULT_LEAF_COUNT_THRESHOLD

public static final int DEFAULT_LEAF_COUNT_THRESHOLD
The default threshold for making a leaf node based on count.

See Also:
Constant Field Values

DEFAULT_MAX_DEPTH

public static final int DEFAULT_MAX_DEPTH
The default maximum depth to grow the tree to.

See Also:
Constant Field Values

regressionLearner

protected BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,? extends Evaluator<? super InputType,Double>> regressionLearner
The learning algorithm for the regression function.


leafCountThreshold

protected int leafCountThreshold
The threshold for making a node a leaf, based on how many instances fall into the node.


maxDepth

protected int maxDepth
The maximum depth for the tree. Ignored if less than 1.
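The leafCountThreshold and maxDepth fields combine into a single stopping test. The sketch below is a hypothetical helper (`LeafRule.shouldMakeLeaf` is not part of the documented API). It assumes the root node has depth 1 and that a node whose instance count equals the threshold already becomes a leaf.

```java
// Hypothetical helper: combines the leaf-count and maximum-depth stopping rules.
// Assumptions: the root has depth 1; count <= threshold makes a leaf.
final class LeafRule {
    private LeafRule() {
    }

    static boolean shouldMakeLeaf(int count, int depth, int leafCountThreshold, int maxDepth) {
        // Few enough instances fell into the node: stop splitting and fit a regression function.
        if (count <= leafCountThreshold) {
            return true;
        }
        // A maxDepth below 1 is ignored, so only a positive limit can stop growth.
        return maxDepth >= 1 && depth >= maxDepth;
    }
}
```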

Constructor Detail

RegressionTreeLearner

public RegressionTreeLearner()
Creates a new instance of RegressionTreeLearner.


RegressionTreeLearner

public RegressionTreeLearner(DeciderLearner<? super InputType,Double,?,?> deciderLearner,
                             BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,? extends Evaluator<? super InputType,Double>> regressionLearner)
Creates a new instance of RegressionTreeLearner.

Parameters:
deciderLearner - The learner for the decision function.
regressionLearner - The learner for the regression function.

RegressionTreeLearner

public RegressionTreeLearner(DeciderLearner<? super InputType,Double,?,?> deciderLearner,
                             BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,? extends Evaluator<? super InputType,Double>> regressionLearner,
                             int leafCountThreshold,
                             int maxDepth)
Creates a new instance of RegressionTreeLearner.

Parameters:
deciderLearner - The learner for the decision function.
regressionLearner - The learner for the regression function.
leafCountThreshold - The leaf count threshold, which determines the number of elements at which to learn a regression function.
maxDepth - The maximum depth to learn the tree. Must be positive.
Method Detail

learn

public RegressionTree<InputType> learn(Collection<? extends InputOutputPair<? extends InputType,Double>> data)
Description copied from interface: BatchLearner
The learn method creates an object of ResultType using data of type DataType, using some form of "learning" algorithm.

Specified by:
learn in interface BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,RegressionTree<InputType>>
Parameters:
data - The data that the learning algorithm will use to create an object of ResultType.
Returns:
The object that is created based on the given data using the learning algorithm.

learnNode

protected RegressionTreeNode<InputType,?> learnNode(Collection<? extends InputOutputPair<? extends InputType,Double>> data,
                                                    AbstractDecisionTreeNode<InputType,Double,?> parent)
Recursively learns the regression tree using the given collection of data, returning the created node.

Specified by:
learnNode in class AbstractDecisionTreeLearner<InputType,Double>
Parameters:
data - The set of data to learn a node from.
parent - The parent node.
Returns:
The regression tree node learned from the given data.
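One plausible use of the parent argument is to work out how deep the new node sits in the tree, which a maximum-depth check needs. A minimal sketch of that bookkeeping follows, using a hypothetical `SketchNode` class rather than the library's node types and assuming that a node with a `null` parent is the root at depth 1.

```java
// Hypothetical node type: each node remembers its parent, mirroring the
// parent argument passed to learnNode. Assumes a null parent means the root,
// which is at depth 1.
final class SketchNode {
    private final SketchNode parent;

    SketchNode(SketchNode parent) {
        this.parent = parent;
    }

    int getDepth() {
        int depth = 1;
        // Count one extra level for every ancestor above this node.
        for (SketchNode n = parent; n != null; n = n.parent) {
            depth++;
        }
        return depth;
    }
}
```

Walking the parent chain costs time proportional to the depth; storing the depth in a field when each node is created would make the lookup constant-time.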

getRegressionLearner

public BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,? extends Evaluator<? super InputType,Double>> getRegressionLearner()
Gets the regression learner that is to be used to fit a function approximator to the values in the tree.

Returns:
The regression learner.

setRegressionLearner

public void setRegressionLearner(BatchLearner<Collection<? extends InputOutputPair<? extends InputType,Double>>,? extends Evaluator<? super InputType,Double>> regressionLearner)
Sets the regression learner that is to be used to fit a function approximator to the values in the tree.

Parameters:
regressionLearner - The regression learner.
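Because the regression learner is a separate property, the function fitted at each leaf is pluggable. As an illustration, the hypothetical `LeafRegressionLearner` interface below reduces that idea to arrays of scalar inputs and outputs (the real property is a `BatchLearner` over collections of `InputOutputPair`), and the hypothetical `LeastSquaresLineLearner` fits a line y = a + b * x by ordinary least squares. It assumes each leaf receives at least one example.

```java
import java.util.function.DoubleUnaryOperator;

// Hypothetical stand-in for the regressionLearner property: turns a leaf's
// training data into a function that the leaf then uses for prediction.
interface LeafRegressionLearner {
    DoubleUnaryOperator learn(double[] xs, double[] ys);
}

// Fits y = intercept + slope * x by ordinary least squares.
final class LeastSquaresLineLearner implements LeafRegressionLearner {
    @Override
    public DoubleUnaryOperator learn(double[] xs, double[] ys) {
        int n = xs.length; // assumed to be at least 1
        double meanX = 0;
        double meanY = 0;
        for (int i = 0; i < n; i++) {
            meanX += xs[i];
            meanY += ys[i];
        }
        meanX /= n;
        meanY /= n;
        double sxy = 0;
        double sxx = 0;
        for (int i = 0; i < n; i++) {
            sxy += (xs[i] - meanX) * (ys[i] - meanY);
            sxx += (xs[i] - meanX) * (xs[i] - meanX);
        }
        // If all inputs are equal the slope is undefined; fall back to the mean output.
        double slope = sxx == 0 ? 0 : sxy / sxx;
        double intercept = meanY - slope * meanX;
        return x -> intercept + slope * x;
    }
}
```

Fitting the points (0, 1), (1, 3), (2, 5) gives the line y = 1 + 2x, so the returned function maps 3.0 to 7.0. Swapping in a different implementation of the interface changes only what each leaf predicts, not how the tree is split.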

getLeafCountThreshold

public int getLeafCountThreshold()
Gets the leaf count threshold, which determines the number of elements at which to learn a regression function.

Returns:
The leaf count threshold.

setLeafCountThreshold

public void setLeafCountThreshold(int leafCountThreshold)
Sets the leaf count threshold, which determines the number of elements at which to learn a regression function.

Parameters:
leafCountThreshold - The leaf count threshold. Must be non-negative.
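The non-negativity constraint could be enforced in the setter as sketched below. `TreeSettings` is a hypothetical holder class, and reporting a violation with `IllegalArgumentException` is an assumption; this page does not say what the library does with a negative value.

```java
// Hypothetical holder for the leaf count threshold, enforcing the documented
// "must be non-negative" constraint. The exception type is an assumption.
final class TreeSettings {
    private int leafCountThreshold;

    int getLeafCountThreshold() {
        return leafCountThreshold;
    }

    void setLeafCountThreshold(int leafCountThreshold) {
        if (leafCountThreshold < 0) {
            throw new IllegalArgumentException(
                "leafCountThreshold must be non-negative, got " + leafCountThreshold);
        }
        this.leafCountThreshold = leafCountThreshold;
    }
}
```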

getMaxDepth

public int getMaxDepth()
Gets the maximum depth to grow the tree.

Returns:
The maximum depth to grow the tree. Zero or less means no maximum depth.

setMaxDepth

public void setMaxDepth(int maxDepth)
Sets the maximum depth to grow the tree.

Parameters:
maxDepth - The maximum depth to grow the tree. Zero or less means no maximum depth.