gov.sandia.cognition.learning.algorithm.regression
Class LogisticRegression

java.lang.Object
  extended by gov.sandia.cognition.util.AbstractCloneableSerializable
      extended by gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
          extended by gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm<ResultType>
              extended by gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner<Collection<? extends InputOutputPair<? extends InputType,OutputType>>,ResultType>
                  extended by gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner<Vectorizable,Double,LogisticRegression.Function>
                      extended by gov.sandia.cognition.learning.algorithm.regression.LogisticRegression
All Implemented Interfaces:
AnytimeAlgorithm<LogisticRegression.Function>, IterativeAlgorithm, StoppableAlgorithm, AnytimeBatchLearner<Collection<? extends InputOutputPair<? extends Vectorizable,Double>>,LogisticRegression.Function>, BatchLearner<Collection<? extends InputOutputPair<? extends Vectorizable,Double>>,LogisticRegression.Function>, SupervisedBatchLearner<Vectorizable,Double,LogisticRegression.Function>, CloneableSerializable, Serializable, Cloneable

@PublicationReferences(references={@PublicationReference(author="Tommi S. Jaakkola",title="Machine learning: lecture 5",type=WebPage,year=2004,url="http://www.ai.mit.edu/courses/6.867-f04/lectures/lecture-5-ho.pdf",notes="Good formulation of logistic regression on slides 15-20"),@PublicationReference(author={"Paul Komarek","Andrew Moore"},title="Making Logistic Regression A Core Data Mining Tool With TR-IRLS",publication="Proceedings of the 5th International Conference on Data Mining Machine Learning",type=Conference,year=2005,url="http://www.autonlab.org/autonweb/14717.html",notes="Good practical overview of logistic regression"),@PublicationReference(author="Christopher M. Bishop",title="Pattern Recognition and Machine Learning",type=Book,year=2006,pages={207,208},notes="Section 4.3.3")})
public class LogisticRegression
extends AbstractAnytimeSupervisedBatchLearner<Vectorizable,Double,LogisticRegression.Function>

Performs logistic regression by means of the iteratively reweighted least squares (IRLS) algorithm, where the logistic function has an explicit bias term and a diagonal L2 regularization term. When the regularization term is zero, this is equivalent to unregularized logistic regression. The targets for the data should be probabilities in the interval [0,1].
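For reference, the IRLS update can be written as a Newton step on the L2-penalized cross-entropy (see Bishop, Section 4.3.3, cited above). This is a sketch under stated assumptions. Each input x_n is augmented with a constant 1 to form \tilde{x}_n, so \theta = [w; b] holds the weights followed by the bias. The t_n \in [0,1] are the targets and \lambda \ge 0 is the regularization. The penalty is applied to every parameter, including the bias, which this page does not confirm.

```latex
\begin{aligned}
E(\theta) &= -\sum_{n=1}^{N} \bigl[t_n \ln y_n + (1 - t_n)\ln(1 - y_n)\bigr] + \tfrac{\lambda}{2}\,\theta^{\top}\theta,
  \qquad y_n = \sigma(\theta^{\top}\tilde{x}_n) = \frac{1}{1 + e^{-\theta^{\top}\tilde{x}_n}} \\
\nabla E &= \tilde{X}^{\top}(y - t) + \lambda\theta,
  \qquad H = \tilde{X}^{\top} R \tilde{X} + \lambda I,
  \qquad R = \operatorname{diag}\bigl(y_n(1 - y_n)\bigr) \\
\theta^{\text{new}} &= \theta - H^{-1}\nabla E
  = \bigl(\tilde{X}^{\top} R \tilde{X} + \lambda I\bigr)^{-1}\tilde{X}^{\top} R\, z,
  \qquad z = \tilde{X}\theta - R^{-1}(y - t)
\end{aligned}
```

Because R is recomputed from the current y_n on every iteration, each step is a weighted least-squares solve with fresh weights, hence "reweighted". Setting \lambda = 0 recovers the unregularized update.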

Since:
2.0
Author:
Kevin R. Dixon
See Also:
Serialized Form

Nested Class Summary
static class LogisticRegression.Function
          Class that is a linear discriminant, followed by a sigmoid function.
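The nested Function class composes a linear discriminant with a sigmoid. The following is a minimal illustration of that composition, not the library's class. The hypothetical LogisticFunctionSketch stores a weight array and an explicit bias and returns sigmoid(w . x + b) for plain double[] inputs. The learner itself is documented above as taking Vectorizable inputs and Double outputs, and the real class's internal layout is not described on this page.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not the library's LogisticRegression.Function):
 * a linear discriminant with an explicit bias, followed by a sigmoid.
 */
final class LogisticFunctionSketch {
    private final double[] weights; // one weight per input dimension
    private final double bias;      // explicit bias (intercept) term

    LogisticFunctionSketch(double[] weights, double bias) {
        this.weights = Arrays.copyOf(weights, weights.length);
        this.bias = bias;
    }

    /** Logistic sigmoid 1 / (1 + exp(-z)), mapping any real z into (0, 1). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Returns sigmoid(w . x + b), interpreted as a probability. */
    double evaluate(double[] x) {
        if (x.length != weights.length) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        double z = bias;
        for (int i = 0; i < x.length; i++) {
            z += weights[i] * x[i];
        }
        return sigmoid(z);
    }
}
```

For example, with weights {2.0, -1.0} and bias 0.5, the input {1.0, 3.0} gives a discriminant value of -0.5 and therefore a probability below one half.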
 
Field Summary
static int DEFAULT_MAX_ITERATIONS
          Default maximum number of iterations before stopping, 100.
static double DEFAULT_REGULARIZATION
          Default regularization, 0.0.
static double DEFAULT_TOLERANCE
          Default tolerance on the change in weights before stopping, 1.0E-10.
 
Fields inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
data, keepGoing
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
maxIterations
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
DEFAULT_ITERATION, iteration
 
Constructor Summary
LogisticRegression()
          Default constructor, with no regularization.
LogisticRegression(double regularization)
          Creates a new instance of LogisticRegression
LogisticRegression(double regularization, double tolerance, int maxIterations)
          Creates a new instance of LogisticRegression
 
Method Summary
protected  void cleanupAlgorithm()
          Called to clean up the learning algorithm's state after learning has finished.
 LogisticRegression clone()
          This makes public the clone method on the Object class and removes the exception that it throws.
 LogisticRegression.Function getObjectToOptimize()
          Getter for objectToOptimize
 double getRegularization()
          Getter for regularization
 LogisticRegression.Function getResult()
          Gets the current result of the algorithm.
 double getTolerance()
          Getter for tolerance
protected  boolean initializeAlgorithm()
          Called to initialize the learning algorithm's state based on the data that is stored in the data field.
 void setObjectToOptimize(LogisticRegression.Function objectToOptimize)
          Setter for objectToOptimize
 void setRegularization(double regularization)
          Setter for regularization
 void setResult(LogisticRegression.Function result)
          Setter for result
 void setTolerance(double tolerance)
          Setter for tolerance
protected  boolean step()
          Called to take a single step of the learning algorithm.
 
Methods inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
getData, getKeepGoing, learn, setData, setKeepGoing, stop
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
getMaxIterations, isResultValid, setMaxIterations
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
 
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface gov.sandia.cognition.learning.algorithm.BatchLearner
learn
 
Methods inherited from interface gov.sandia.cognition.algorithm.AnytimeAlgorithm
getMaxIterations, setMaxIterations
 
Methods inherited from interface gov.sandia.cognition.algorithm.IterativeAlgorithm
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener
 
Methods inherited from interface gov.sandia.cognition.algorithm.StoppableAlgorithm
isResultValid
 

Field Detail

DEFAULT_MAX_ITERATIONS

public static final int DEFAULT_MAX_ITERATIONS
Default maximum number of iterations before stopping, 100.

See Also:
Constant Field Values

DEFAULT_TOLERANCE

public static final double DEFAULT_TOLERANCE
Default tolerance on the change in weights before stopping, 1.0E-10.

See Also:
Constant Field Values

DEFAULT_REGULARIZATION

public static final double DEFAULT_REGULARIZATION
Default regularization, 0.0.

See Also:
Constant Field Values
Constructor Detail

LogisticRegression

public LogisticRegression()
Default constructor, with no regularization.


LogisticRegression

public LogisticRegression(double regularization)
Creates a new instance of LogisticRegression

Parameters:
regularization - L2 ridge regularization term; must be nonnegative. A value of zero is equivalent to unregularized regression.

LogisticRegression

public LogisticRegression(double regularization,
                          double tolerance,
                          int maxIterations)
Creates a new instance of LogisticRegression

Parameters:
regularization - L2 ridge regularization term; must be nonnegative. A value of zero is equivalent to unregularized regression.
tolerance - Tolerance on the change in weights before stopping; must be nonnegative.
maxIterations - Maximum number of iterations before stopping
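The tolerance and maxIterations parameters together decide when learning stops. Below is a minimal sketch of such a rule. It assumes the change in weights is measured with the Euclidean norm, which this page does not specify. StoppingRuleSketch and its methods are hypothetical names, not part of the library.

```java
/**
 * Hypothetical sketch of a stopping rule: keep iterating while fewer than
 * maxIterations steps have run and the weights still change by more than
 * tolerance. The Euclidean norm is an assumption.
 */
final class StoppingRuleSketch {
    private final double tolerance;
    private final int maxIterations;

    StoppingRuleSketch(double tolerance, int maxIterations) {
        if (tolerance < 0.0) {
            throw new IllegalArgumentException("tolerance must be nonnegative");
        }
        this.tolerance = tolerance;
        this.maxIterations = maxIterations;
    }

    /** Euclidean norm of the difference between successive parameter vectors. */
    static double weightChange(double[] previous, double[] current) {
        double sum = 0.0;
        for (int i = 0; i < current.length; i++) {
            double d = current[i] - previous[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Returns true if another step should be taken after completedSteps steps. */
    boolean keepGoing(int completedSteps, double[] previous, double[] current) {
        return completedSteps < maxIterations
            && weightChange(previous, current) > tolerance;
    }
}
```

With the documented defaults, new StoppingRuleSketch(1.0E-10, 100) keeps iterating until either 100 steps have run or successive weight vectors differ by at most 1.0E-10.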
Method Detail

clone

public LogisticRegression clone()
Description copied from class: AbstractCloneableSerializable
This makes public the clone method on the Object class and removes the exception that it throws. Its default behavior is to automatically create a clone of the exact type of object that the clone is called on and to copy all primitives but to keep all references, which means it is a shallow copy. Extensions of this class may want to override this method (but call super.clone()) to implement a "smart copy", that is, one that targets the most common use case for creating a copy of the object. Because the default behavior is a shallow copy, extending classes only need to handle fields that need a deeper copy (or those that need to be reset). Some of the methods in ObjectUtil may be helpful in implementing a custom clone method. Note: The contract of this method is that you must use super.clone() as the basis for your implementation.

Specified by:
clone in interface CloneableSerializable
Overrides:
clone in class AbstractAnytimeBatchLearner<Collection<? extends InputOutputPair<? extends Vectorizable,Double>>,LogisticRegression.Function>
Returns:
A clone of this object.

initializeAlgorithm

protected boolean initializeAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to initialize the learning algorithm's state based on the data that is stored in the data field. The return value indicates if the algorithm can be run or not based on the initialization.

Specified by:
initializeAlgorithm in class AbstractAnytimeBatchLearner<Collection<? extends InputOutputPair<? extends Vectorizable,Double>>,LogisticRegression.Function>
Returns:
True if the learning algorithm can be run and false if it cannot.

step

protected boolean step()
Description copied from class: AbstractAnytimeBatchLearner
Called to take a single step of the learning algorithm.

Specified by:
step in class AbstractAnytimeBatchLearner<Collection<? extends InputOutputPair<? extends Vectorizable,Double>>,LogisticRegression.Function>
Returns:
True if another step can be taken and false if the algorithm should halt.
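Under this contract, each call performs one update and reports whether to continue. The following is a self-contained sketch of such an update, using a regularized IRLS (Newton) step on plain Java arrays. IrlsSketch, step, fit, and solve are hypothetical names, not the library's. The sketch makes three assumptions. First, the penalty lambda applies to the bias as well as the weights. Second, convergence is measured with the Euclidean norm of the weight change. Third, the Newton system is solved by Gaussian elimination, which may differ from the solver the library uses.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of regularized IRLS (Newton's method) for logistic
 * regression with an explicit bias; not the library's step(). The L2 penalty
 * is applied to the bias as well as the weights (an assumption).
 */
final class IrlsSketch {

    /** One step: theta - (X^T R X + lambda I)^{-1} (X^T (y - t) + lambda theta). */
    static double[] step(double[][] xs, double[] ts, double[] theta, double lambda) {
        int p = theta.length;              // d weights followed by 1 bias
        double[][] h = new double[p][p];   // accumulates X^T R X + lambda I
        double[] g = new double[p];        // accumulates X^T (y - t) + lambda theta
        for (int n = 0; n < xs.length; n++) {
            double[] xa = Arrays.copyOf(xs[n], p);
            xa[p - 1] = 1.0;               // constant input, so theta[p - 1] is the bias
            double z = 0.0;
            for (int i = 0; i < p; i++) { z += theta[i] * xa[i]; }
            double y = 1.0 / (1.0 + Math.exp(-z)); // current predicted probability
            double r = y * (1.0 - y);              // IRLS weight for this sample
            for (int i = 0; i < p; i++) {
                g[i] += (y - ts[n]) * xa[i];
                for (int j = 0; j < p; j++) {
                    h[i][j] += r * xa[i] * xa[j];
                }
            }
        }
        for (int i = 0; i < p; i++) {      // diagonal L2 regularization
            g[i] += lambda * theta[i];
            h[i][i] += lambda;
        }
        double[] delta = solve(h, g);
        double[] next = new double[p];
        for (int i = 0; i < p; i++) { next[i] = theta[i] - delta[i]; }
        return next;
    }

    /** Repeats step() from zero parameters until the Euclidean change is at most tolerance. */
    static double[] fit(double[][] xs, double[] ts, double lambda,
                        double tolerance, int maxIterations) {
        double[] theta = new double[xs[0].length + 1];
        for (int iter = 0; iter < maxIterations; iter++) {
            double[] next = step(xs, ts, theta, lambda);
            double change = 0.0;
            for (int i = 0; i < theta.length; i++) {
                change += (next[i] - theta[i]) * (next[i] - theta[i]);
            }
            theta = next;
            if (Math.sqrt(change) <= tolerance) {
                break;
            }
        }
        return theta;
    }

    /** Solves a x = b by Gaussian elimination with partial pivoting; a and b are not modified. */
    static double[] solve(double[][] a, double[] b) {
        int n = b.length;
        double[][] m = new double[n][];
        for (int i = 0; i < n; i++) {
            m[i] = a[i].clone();
        }
        double[] v = b.clone();
        for (int col = 0; col < n; col++) {
            int pivot = col;
            for (int row = col + 1; row < n; row++) {
                if (Math.abs(m[row][col]) > Math.abs(m[pivot][col])) {
                    pivot = row;
                }
            }
            double[] rowTmp = m[col]; m[col] = m[pivot]; m[pivot] = rowTmp;
            double vTmp = v[col]; v[col] = v[pivot]; v[pivot] = vTmp;
            if (m[col][col] == 0.0) {
                throw new ArithmeticException("singular system");
            }
            for (int row = col + 1; row < n; row++) {
                double f = m[row][col] / m[col][col];
                for (int k = col; k < n; k++) {
                    m[row][k] -= f * m[col][k];
                }
                v[row] -= f * v[col];
            }
        }
        double[] x = new double[n];
        for (int i = n - 1; i >= 0; i--) {
            double s = v[i];
            for (int k = i + 1; k < n; k++) {
                s -= m[i][k] * x[k];
            }
            x[i] = s / m[i][i];
        }
        return x;
    }
}
```

Consider IrlsSketch.fit(new double[][] {{0.0}, {1.0}, {2.0}, {3.0}}, new double[] {0.1, 0.3, 0.7, 0.9}, 0.0, 1.0E-10, 100). This call fits soft targets in [0,1] using the documented default tolerance and iteration limit. Passing a positive lambda instead shrinks the norm of the fitted parameter vector.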

cleanupAlgorithm

protected void cleanupAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to clean up the learning algorithm's state after learning has finished.

Specified by:
cleanupAlgorithm in class AbstractAnytimeBatchLearner<Collection<? extends InputOutputPair<? extends Vectorizable,Double>>,LogisticRegression.Function>

getObjectToOptimize

public LogisticRegression.Function getObjectToOptimize()
Getter for objectToOptimize

Returns:
The object to optimize, used as a factory on successive runs of the algorithm.

setObjectToOptimize

public void setObjectToOptimize(LogisticRegression.Function objectToOptimize)
Setter for objectToOptimize

Parameters:
objectToOptimize - The object to optimize, used as a factory on successive runs of the algorithm.

getResult

public LogisticRegression.Function getResult()
Description copied from interface: AnytimeAlgorithm
Gets the current result of the algorithm.

Returns:
Current result of the algorithm.

setResult

public void setResult(LogisticRegression.Function result)
Setter for result

Parameters:
result - Return value from the algorithm

getTolerance

public double getTolerance()
Getter for tolerance

Returns:
Tolerance on the change in weights before stopping; must be nonnegative.

setTolerance

public void setTolerance(double tolerance)
Setter for tolerance

Parameters:
tolerance - Tolerance on the change in weights before stopping; must be nonnegative.

getRegularization

public double getRegularization()
Getter for regularization

Returns:
L2 ridge regularization term; must be nonnegative. A value of zero is equivalent to unregularized regression.

setRegularization

public void setRegularization(double regularization)
Setter for regularization

Parameters:
regularization - L2 ridge regularization term; must be nonnegative. A value of zero is equivalent to unregularized regression.