gov.sandia.cognition.learning.algorithm.tree
Class CategorizationTreeLearner<InputType,OutputType>

java.lang.Object
  extended by gov.sandia.cognition.util.AbstractCloneableSerializable
      extended by gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
          extended by gov.sandia.cognition.learning.algorithm.tree.AbstractDecisionTreeLearner<InputType,OutputType>
              extended by gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeLearner<InputType,OutputType>
Type Parameters:
InputType - The input type for the tree.
OutputType - The output type for the tree.
All Implemented Interfaces:
IterativeAlgorithm, BatchLearner<Collection<? extends InputOutputPair<? extends InputType,OutputType>>,CategorizationTree<InputType,OutputType>>, SupervisedBatchLearner<InputType,OutputType,CategorizationTree<InputType,OutputType>>, CloneableSerializable, Serializable, Cloneable

public class CategorizationTreeLearner<InputType,OutputType>
extends AbstractDecisionTreeLearner<InputType,OutputType>
implements SupervisedBatchLearner<InputType,OutputType,CategorizationTree<InputType,OutputType>>

The CategorizationTreeLearner class implements a supervised learning algorithm for learning a categorization tree.

Since:
2.0
Author:
Justin Basilico
See Also:
Serialized Form

Field Summary
static int DEFAULT_LEAF_COUNT_THRESHOLD
          The default threshold for making a leaf node based on count.
static int DEFAULT_MAX_DEPTH
          The default maximum depth to grow the tree to.
protected  int leafCountThreshold
          The threshold for making a node a leaf, determined by how many instances fall in the node.
protected  int maxDepth
          The maximum depth for the tree.
protected  Map<OutputType,Double> priors
          Prior probabilities for the different categories.
protected  Map<OutputType,Integer> trainCounts
          How often each category appears in training data.
 
Fields inherited from class gov.sandia.cognition.learning.algorithm.tree.AbstractDecisionTreeLearner
deciderLearner
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
DEFAULT_ITERATION, iteration
 
Constructor Summary
CategorizationTreeLearner()
          Creates a new instance of CategorizationTreeLearner.
CategorizationTreeLearner(DeciderLearner<? super InputType,OutputType,?,?> deciderLearner)
          Creates a new instance of CategorizationTreeLearner.
CategorizationTreeLearner(DeciderLearner<? super InputType,OutputType,?,?> deciderLearner, int leafCountThreshold, int maxDepth)
          Creates a new instance of CategorizationTreeLearner.
CategorizationTreeLearner(DeciderLearner<? super InputType,OutputType,?,?> deciderLearner, int leafCountThreshold, int maxDepth, Map<OutputType,Double> priors)
          Creates a new instance of CategorizationTreeLearner.
 
Method Summary
 int getLeafCountThreshold()
          Gets the leaf count threshold, which determines the number of instances at which to make a node into a leaf.
 int getMaxDepth()
          Gets the maximum depth to grow the tree.
static
<OutputType>
DefaultDataDistribution<OutputType>
getOutputCounts(Collection<? extends InputOutputPair<?,OutputType>> data)
          Creates a histogram of values based on the output values in the given collection of pairs.
 CategorizationTree<InputType,OutputType> learn(Collection<? extends InputOutputPair<? extends InputType,OutputType>> data)
          The learn method creates an object of ResultType using data of type DataType, using some form of "learning" algorithm.
protected  CategorizationTreeNode<InputType,OutputType,?> learnNode(Collection<? extends InputOutputPair<? extends InputType,OutputType>> data, AbstractDecisionTreeNode<InputType,OutputType,?> parent)
          Recursively learns the categorization tree using the given collection of data, returning the created node.
 void setCategoryPriors(Map<OutputType,Double> priors)
          Sets the prior category probabilities.
 void setLeafCountThreshold(int leafCountThreshold)
          Sets the leaf count threshold, which determines the number of instances at which to make a node into a leaf.
 void setMaxDepth(int maxDepth)
          Sets the maximum depth to grow the tree.
 
Methods inherited from class gov.sandia.cognition.learning.algorithm.tree.AbstractDecisionTreeLearner
areAllOutputsEqual, getDeciderLearner, learnChildNodes, setDeciderLearner, splitData
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
addIterativeAlgorithmListener, clone, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
 
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface gov.sandia.cognition.util.CloneableSerializable
clone
 

Field Detail

DEFAULT_LEAF_COUNT_THRESHOLD

public static final int DEFAULT_LEAF_COUNT_THRESHOLD
The default threshold for making a leaf node based on count.

See Also:
Constant Field Values

DEFAULT_MAX_DEPTH

public static final int DEFAULT_MAX_DEPTH
The default maximum depth to grow the tree to.

See Also:
Constant Field Values

leafCountThreshold

protected int leafCountThreshold
The threshold for making a node a leaf, determined by how many instances fall in the node.


maxDepth

protected int maxDepth
The maximum depth for the tree. Ignored if less than 1.


priors

protected Map<OutputType,Double> priors
Prior probabilities for the different categories. If null, the priors default to the category frequencies of the training data.


trainCounts

protected Map<OutputType,Integer> trainCounts
How often each category appears in training data.

Constructor Detail

CategorizationTreeLearner

public CategorizationTreeLearner()
Creates a new instance of CategorizationTreeLearner.


CategorizationTreeLearner

public CategorizationTreeLearner(DeciderLearner<? super InputType,OutputType,?,?> deciderLearner)
Creates a new instance of CategorizationTreeLearner.

Parameters:
deciderLearner - The learner for the decision function.

CategorizationTreeLearner

public CategorizationTreeLearner(DeciderLearner<? super InputType,OutputType,?,?> deciderLearner,
                                 int leafCountThreshold,
                                 int maxDepth)
Creates a new instance of CategorizationTreeLearner.

Parameters:
deciderLearner - The learner for the decision function.
leafCountThreshold - The leaf count threshold. Must be non-negative.
maxDepth - The maximum depth for the tree. Zero or less means no maximum depth.

CategorizationTreeLearner

public CategorizationTreeLearner(DeciderLearner<? super InputType,OutputType,?,?> deciderLearner,
                                 int leafCountThreshold,
                                 int maxDepth,
                                 Map<OutputType,Double> priors)
Creates a new instance of CategorizationTreeLearner.

Parameters:
deciderLearner - The learner for the decision function.
leafCountThreshold - The leaf count threshold. Must be non-negative.
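maxDepth - The maximum depth for the tree. Zero or less means no maximum depth.
priors - Prior probabilities for categories. If null, the priors default to the category frequencies of the training data. (See setCategoryPriors().)
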
Method Detail

learn

public CategorizationTree<InputType,OutputType> learn(Collection<? extends InputOutputPair<? extends InputType,OutputType>> data)
Description copied from interface: BatchLearner
The learn method creates an object of ResultType using data of type DataType, using some form of "learning" algorithm.

Specified by:
learn in interface BatchLearner<Collection<? extends InputOutputPair<? extends InputType,OutputType>>,CategorizationTree<InputType,OutputType>>
Parameters:
data - The data that the learning algorithm will use to create an object of ResultType.
Returns:
The object that is created based on the given data using the learning algorithm.

learnNode

protected CategorizationTreeNode<InputType,OutputType,?> learnNode(Collection<? extends InputOutputPair<? extends InputType,OutputType>> data,
                                                                   AbstractDecisionTreeNode<InputType,OutputType,?> parent)
Recursively learns the categorization tree using the given collection of data, returning the created node.

Specified by:
learnNode in class AbstractDecisionTreeLearner<InputType,OutputType>
Parameters:
data - The set of data to learn a node from.
parent - The parent node.
Returns:
The categorization tree node learned from the given data.

getOutputCounts

public static <OutputType> DefaultDataDistribution<OutputType> getOutputCounts(Collection<? extends InputOutputPair<?,OutputType>> data)
Creates a histogram of values based on the output values in the given collection of pairs.

Type Parameters:
OutputType - The type of the outputs to count over.
Parameters:
data - The data to create the output count histogram for.
Returns:
The output count histogram.
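
The histogram described above can be illustrated with a minimal, self-contained sketch. It does not use the library: the class OutputCountSketch and the method countOutputs are hypothetical names, java.util.Map.Entry stands in for InputOutputPair, and a plain Map stands in for DefaultDataDistribution. It only shows the idea of counting output values, not how getOutputCounts is actually implemented.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for illustration; not the library's implementation.
class OutputCountSketch
{
    /**
     * Counts how often each output value occurs, treating each Map.Entry
     * as an (input, output) pair.
     */
    static <O> Map<O, Integer> countOutputs(
        Collection<? extends Map.Entry<?, O>> data)
    {
        Map<O, Integer> counts = new LinkedHashMap<>();
        for (Map.Entry<?, O> pair : data)
        {
            // Add one to the running count for this pair's output value.
            counts.merge(pair.getValue(), 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args)
    {
        List<Map.Entry<Double, String>> data = new ArrayList<>();
        data.add(new SimpleEntry<>(0.2, "yes"));
        data.add(new SimpleEntry<>(0.9, "no"));
        data.add(new SimpleEntry<>(0.4, "yes"));
        System.out.println(countOutputs(data));
    }
}
```

The input values are ignored entirely, which matches the wildcard input type in the getOutputCounts signature.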

getLeafCountThreshold

public int getLeafCountThreshold()
Gets the leaf count threshold, which determines the number of instances at which to make a node into a leaf.

Returns:
The leaf count threshold.

setLeafCountThreshold

public void setLeafCountThreshold(int leafCountThreshold)
Sets the leaf count threshold, which determines the number of instances at which to make a node into a leaf.

Parameters:
leafCountThreshold - The leaf count threshold. Must be non-negative.

getMaxDepth

public int getMaxDepth()
Gets the maximum depth to grow the tree.

Returns:
The maximum depth to grow the tree. Zero or less means no maximum depth.

setMaxDepth

public void setMaxDepth(int maxDepth)
Sets the maximum depth to grow the tree.

Parameters:
maxDepth - The maximum depth to grow the tree. Zero or less means no maximum depth.
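
How the leaf count threshold and the maximum depth might combine into a stopping rule can be sketched as follows. This is an assumption-laden illustration, not the library's code: the class LeafRuleSketch and the method shouldMakeLeaf are hypothetical, and the exact comparisons (count at or below the threshold, depth at or above the maximum) are guesses. Only the rule that a maxDepth of zero or less disables the depth limit comes from the documentation above.

```java
// Hypothetical stand-in for illustration; not the library's implementation.
class LeafRuleSketch
{
    /**
     * Decides whether a node should become a leaf.
     *
     * @param instanceCount      Number of training instances in the node.
     * @param depth              Depth of the node in the tree.
     * @param leafCountThreshold The leaf count threshold (non-negative).
     * @param maxDepth           The maximum depth; zero or less means none.
     */
    static boolean shouldMakeLeaf(
        int instanceCount,
        int depth,
        int leafCountThreshold,
        int maxDepth)
    {
        // Assumption: "at or below the threshold" triggers a leaf.
        boolean tooFewInstances = instanceCount <= leafCountThreshold;

        // From the documentation: zero or less means no maximum depth.
        // Assumption: the limit applies once depth reaches maxDepth.
        boolean depthLimitReached = maxDepth > 0 && depth >= maxDepth;

        return tooFewInstances || depthLimitReached;
    }

    public static void main(String[] args)
    {
        // No depth limit and plenty of instances: keep splitting.
        System.out.println(shouldMakeLeaf(10, 50, 1, 0));
        // Depth limit of 3 reached: stop.
        System.out.println(shouldMakeLeaf(10, 3, 1, 3));
    }
}
```

A real learner would likely also stop when every instance in the node has the same output; the inherited areAllOutputsEqual method suggests such a check, though its use is not described here.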

setCategoryPriors

public void setCategoryPriors(Map<OutputType,Double> priors)

Sets the prior category probabilities. A higher prior probability for a category causes the tree learner to weight examples from that category more highly.

If the priors are not specified manually (through this method or by passing priors into the constructor), the prior probabilities default to the frequencies of the different categories in the training data.

Parameters:
priors - If null, use default prior probabilities. Otherwise, priors becomes the new prior weights. In the latter case, priors.keySet() must contain the same values as the possible categories in the data passed to the learn() method.
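
The default described above, priors equal to the category frequencies in the training data, can be sketched without the library. The class CategoryPriorSketch and both of its methods are hypothetical names used only for illustration. The sketch shows the frequency computation and the key-set requirement on manually supplied priors; it does not show how the learner applies the priors when weighting examples.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for illustration; not the library's implementation.
class CategoryPriorSketch
{
    /** Converts per-category training counts into relative frequencies. */
    static <O> Map<O, Double> frequencyPriors(Map<O, Integer> trainCounts)
    {
        int total = 0;
        for (int count : trainCounts.values())
        {
            total += count;
        }

        Map<O, Double> priors = new LinkedHashMap<>();
        if (total == 0)
        {
            // No training data, so there are no frequencies to report.
            return priors;
        }
        for (Map.Entry<O, Integer> entry : trainCounts.entrySet())
        {
            priors.put(entry.getKey(), entry.getValue() / (double) total);
        }
        return priors;
    }

    /** Checks that manual priors name exactly the categories in the data. */
    static <O> boolean coversSameCategories(
        Map<O, Double> priors,
        Map<O, Integer> trainCounts)
    {
        return priors.keySet().equals(trainCounts.keySet());
    }

    public static void main(String[] args)
    {
        Map<String, Integer> trainCounts = new LinkedHashMap<>();
        trainCounts.put("spam", 3);
        trainCounts.put("ham", 1);
        System.out.println(frequencyPriors(trainCounts));
    }
}
```

With manual priors, a check like coversSameCategories before calling learn() would catch a category that appears in the data but has no prior.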