gov.sandia.cognition.learning.algorithm.hmm
Class BaumWelchAlgorithm<ObservationType>

java.lang.Object
  extended by gov.sandia.cognition.util.AbstractCloneableSerializable
      extended by gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
          extended by gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm<ResultType>
              extended by gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner<DataType,HiddenMarkovModel<ObservationType>>
                  extended by gov.sandia.cognition.learning.algorithm.hmm.AbstractBaumWelchAlgorithm<ObservationType,Collection<? extends ObservationType>>
                      extended by gov.sandia.cognition.learning.algorithm.hmm.BaumWelchAlgorithm<ObservationType>
Type Parameters:
ObservationType - Type of Observations handled by the HMM.
All Implemented Interfaces:
AnytimeAlgorithm<HiddenMarkovModel<ObservationType>>, IterativeAlgorithm, MeasurablePerformanceAlgorithm, StoppableAlgorithm, AnytimeBatchLearner<Collection<? extends ObservationType>,HiddenMarkovModel<ObservationType>>, BatchLearner<Collection<? extends ObservationType>,HiddenMarkovModel<ObservationType>>, CloneableSerializable, Serializable, Cloneable
Direct Known Subclasses:
ParallelBaumWelchAlgorithm

@PublicationReference(author="Lawrence R. Rabiner",
                      title="A tutorial on hidden Markov models and selected applications in speech recognition",
                      type=Journal,
                      year=1989,
                      publication="Proceedings of the IEEE",
                      pages={257,286},
                      url="http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf",
                      notes="Rabiner\'s transition matrix is transposed from mine.")
public class BaumWelchAlgorithm<ObservationType>
extends AbstractBaumWelchAlgorithm<ObservationType,Collection<? extends ObservationType>>

Implements the Baum-Welch algorithm for Hidden Markov Models (HMMs), also known as the "forward-backward algorithm". It is an instance of the expectation-maximization (EM) algorithm and is the standard learning algorithm for HMMs. This implementation supports training on multiple observation sequences through the MultiCollection interface.

See Also:
Serialized Form

Field Summary
protected  MultiCollection<? extends ObservationType> multicollection
          The multi-collection of sequences
protected  ArrayList<Vector> sequenceGammas
          The list of all gammas from each sequence
 
Fields inherited from class gov.sandia.cognition.learning.algorithm.hmm.AbstractBaumWelchAlgorithm
DEFAULT_MAX_ITERATIONS, DEFAULT_REESTIMATE_INITIAL_PROBABILITY, distributionLearner, initialGuess, lastLogLikelihood, PERFORMANCE_NAME, reestimateInitialProbabilities, result
 
Fields inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
data, keepGoing
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
maxIterations
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
DEFAULT_ITERATION, iteration
 
Constructor Summary
BaumWelchAlgorithm()
          Creates a new instance of BaumWelchAlgorithm
BaumWelchAlgorithm(HiddenMarkovModel<ObservationType> initialGuess, BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>,? extends ComputableDistribution<ObservationType>> distributionLearner, boolean reestimateInitialProbabilities)
          Creates a new instance of BaumWelchAlgorithm
 
Method Summary
protected  void cleanupAlgorithm()
          Called to clean up the learning algorithm's state after learning has finished.
 BaumWelchAlgorithm<ObservationType> clone()
          This makes public the clone method on the Object class and removes the exception that it throws.
protected  Pair<ArrayList<ArrayList<Vector>>,ArrayList<Matrix>> computeSequenceParameters()
          Computes the gammas and A (transition) matrices for each sequence.
protected  boolean initializeAlgorithm()
          Called to initialize the learning algorithm's state based on the data that is stored in the data field.
 HiddenMarkovModel<ObservationType> learn(MultiCollection<ObservationType> data)
          Allows the algorithm to learn against multiple sequences of data.
protected  boolean step()
          Called to take a single step of the learning algorithm.
protected  Vector updateInitialProbabilities(ArrayList<Vector> firstGammas)
          Updates the initial probabilities from the first gamma of each sequence
protected  ArrayList<ProbabilityFunction<ObservationType>> updateProbabilityFunctions(ArrayList<Vector> sequenceGammas)
          Updates the probability functions from the concatenated gammas from all sequences
protected  double updateSequenceLogLikelihoods(HiddenMarkovModel<ObservationType> hmm)
          Updates the internal sequence likelihoods for the given HMM
protected  Matrix updateTransitionMatrix(ArrayList<Matrix> sequenceTransitionMatrices)
          Computes an updated transition matrix from the scaled estimates
 
Methods inherited from class gov.sandia.cognition.learning.algorithm.hmm.AbstractBaumWelchAlgorithm
getDistributionLearner, getInitialGuess, getLastLogLikelihood, getPerformance, getReestimateInitialProbabilities, getResult, setDistributionLearner, setInitialGuess, setReestimateInitialProbabilities
 
Methods inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
getData, getKeepGoing, learn, setData, setKeepGoing, stop
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
getMaxIterations, isResultValid, setMaxIterations
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
 
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface gov.sandia.cognition.algorithm.AnytimeAlgorithm
getMaxIterations, setMaxIterations
 
Methods inherited from interface gov.sandia.cognition.algorithm.IterativeAlgorithm
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener
 
Methods inherited from interface gov.sandia.cognition.algorithm.StoppableAlgorithm
isResultValid
 

Field Detail

multicollection

protected transient MultiCollection<? extends ObservationType> multicollection
The multi-collection of sequences


sequenceGammas

protected transient ArrayList<Vector> sequenceGammas
The list of all gammas from each sequence

Constructor Detail

BaumWelchAlgorithm

public BaumWelchAlgorithm()
Creates a new instance of BaumWelchAlgorithm


BaumWelchAlgorithm

public BaumWelchAlgorithm(HiddenMarkovModel<ObservationType> initialGuess,
                          BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>,? extends ComputableDistribution<ObservationType>> distributionLearner,
                          boolean reestimateInitialProbabilities)
Creates a new instance of BaumWelchAlgorithm

Parameters:
initialGuess - Initial guess for the iterations.
distributionLearner - Learner for the Probability Functions of the HMM.
reestimateInitialProbabilities - Flag to re-estimate the initial probability Vector.
Method Detail

clone

public BaumWelchAlgorithm<ObservationType> clone()
Description copied from class: AbstractCloneableSerializable
This makes public the clone method on the Object class and removes the exception that it throws. Its default behavior is to automatically create a clone of the exact type of object that the clone is called on and to copy all primitives but to keep all references, which means it is a shallow copy. Extensions of this class may want to override this method (but call super.clone()) to implement a "smart copy", that is, a copy that targets the most common use case for copying the object. Because the default behavior is a shallow copy, extending classes only need to handle fields that need a deeper copy (or those that need to be reset). Some of the methods in ObjectUtil may be helpful in implementing a custom clone method. Note: The contract of this method is that you must use super.clone() as the basis for your implementation.

Specified by:
clone in interface CloneableSerializable
Overrides:
clone in class AbstractBaumWelchAlgorithm<ObservationType,Collection<? extends ObservationType>>
Returns:
A clone of this object.
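
The "smart copy" contract described above can be sketched in plain Java as follows. This is an illustration only, not the library's implementation: SmartCopySketch and its fields are hypothetical names, and the class uses java.lang.Cloneable directly rather than AbstractCloneableSerializable.

```java
import java.util.ArrayList;

// Hypothetical class illustrating the clone contract; not part of the library.
final class SmartCopySketch implements Cloneable {
    int iteration;                                   // primitive: copied by super.clone()
    ArrayList<double[]> gammas = new ArrayList<>();  // reference: needs a deeper copy

    @Override
    public SmartCopySketch clone() {
        try {
            // Start from the shallow copy, as the contract requires.
            SmartCopySketch copy = (SmartCopySketch) super.clone();
            // Deep-copy only the fields that must not be shared with the original.
            copy.gammas = new ArrayList<>();
            for (double[] g : this.gammas) {
                copy.gammas.add(g.clone());
            }
            return copy;
        } catch (CloneNotSupportedException e) {
            // Unreachable here because this class implements Cloneable.
            throw new AssertionError(e);
        }
    }
}
```

After cloning, changes to the copy's gammas list or its arrays do not affect the original, while primitive fields such as iteration carry over unchanged.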

learn

public HiddenMarkovModel<ObservationType> learn(MultiCollection<ObservationType> data)
Allows the algorithm to learn against multiple sequences of data.

Parameters:
data - Multiple sequences of data against which to train.
Returns:
HMM corresponding to the local maximum-likelihood estimate found by the Baum-Welch algorithm.

initializeAlgorithm

protected boolean initializeAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to initialize the learning algorithm's state based on the data that is stored in the data field. The return value indicates if the algorithm can be run or not based on the initialization.

Specified by:
initializeAlgorithm in class AbstractAnytimeBatchLearner<Collection<? extends ObservationType>,HiddenMarkovModel<ObservationType>>
Returns:
True if the learning algorithm can be run and false if it cannot.

step

protected boolean step()
Description copied from class: AbstractAnytimeBatchLearner
Called to take a single step of the learning algorithm.

Specified by:
step in class AbstractAnytimeBatchLearner<Collection<? extends ObservationType>,HiddenMarkovModel<ObservationType>>
Returns:
True if another step can be taken and false if the algorithm should halt.

cleanupAlgorithm

protected void cleanupAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to clean up the learning algorithm's state after learning has finished.

Specified by:
cleanupAlgorithm in class AbstractAnytimeBatchLearner<Collection<? extends ObservationType>,HiddenMarkovModel<ObservationType>>
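
The three hooks above (initializeAlgorithm, step, cleanupAlgorithm) suggest a template-method loop. The sketch below shows one plausible way a driver could sequence them. It is an assumption made for illustration, not the code of AbstractAnytimeBatchLearner: AnytimeLoopSketch, HalvingSketch, and run() are hypothetical names, and listener notifications are omitted.

```java
// Hypothetical driver showing one way the three hooks could be sequenced.
abstract class AnytimeLoopSketch<R> {
    protected int iteration = 0;
    protected int maxIterations = 100;

    protected abstract boolean initializeAlgorithm();
    protected abstract boolean step();
    protected abstract void cleanupAlgorithm();
    protected abstract R getResult();

    // Assumed control flow: initialize, step until told to stop, then clean up.
    public final R run() {
        if (!initializeAlgorithm()) {
            return null; // initialization reported that learning cannot run
        }
        boolean keepGoing = true;
        while (keepGoing && iteration < maxIterations) {
            iteration++;
            keepGoing = step();
        }
        cleanupAlgorithm();
        return getResult();
    }
}

// Toy learner: halves an error value until it drops to 0.1 or below.
final class HalvingSketch extends AnytimeLoopSketch<Double> {
    private double error;
    boolean cleanedUp = false;

    @Override
    protected boolean initializeAlgorithm() { error = 1.0; return true; }

    @Override
    protected boolean step() { error /= 2.0; return error > 0.1; }

    @Override
    protected void cleanupAlgorithm() { cleanedUp = true; }

    @Override
    protected Double getResult() { return error; }
}
```

Calling new HalvingSketch().run() steps until step() returns false, then cleans up once. The iteration cap plays the role that maxIterations plays in the inherited fields listed above.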

computeSequenceParameters

protected Pair<ArrayList<ArrayList<Vector>>,ArrayList<Matrix>> computeSequenceParameters()
Computes the gammas and A (transition) matrices for each sequence.

Returns:
Gammas and A (transition) matrices for each sequence

updateProbabilityFunctions

protected ArrayList<ProbabilityFunction<ObservationType>> updateProbabilityFunctions(ArrayList<Vector> sequenceGammas)
Updates the probability functions from the concatenated gammas from all sequences

Parameters:
sequenceGammas - Concatenated gammas from all sequences
Returns:
Maximum Likelihood probability functions
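
The constructor's distributionLearner accepts weighted values, which suggests that each observation is weighted by its gamma for a given state before the per-state distribution is fit. The sketch below illustrates that idea for discrete symbols only. It is an assumption for illustration, not the library's code: EmissionUpdateSketch and estimate are hypothetical names, and plain arrays stand in for Vector and ProbabilityFunction.

```java
// Hypothetical sketch of a gamma-weighted emission estimate for discrete symbols.
final class EmissionUpdateSketch {
    // gammas[t][i]: weight of state i at (concatenated) time step t.
    // obs[t]: index of the symbol observed at time step t.
    // Returns b[i][o], the weighted relative frequency of symbol o in state i.
    static double[][] estimate(double[][] gammas, int[] obs, int numSymbols) {
        int numStates = gammas[0].length;
        double[][] b = new double[numStates][numSymbols];
        // Accumulate each observation's weight into its symbol's bin, per state.
        for (int t = 0; t < obs.length; t++) {
            for (int i = 0; i < numStates; i++) {
                b[i][obs[t]] += gammas[t][i];
            }
        }
        // Normalize each state's row so that it sums to one.
        for (int i = 0; i < numStates; i++) {
            double rowSum = 0.0;
            for (double v : b[i]) {
                rowSum += v;
            }
            if (rowSum > 0.0) {
                for (int o = 0; o < numSymbols; o++) {
                    b[i][o] /= rowSum;
                }
            }
        }
        return b;
    }
}
```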

updateTransitionMatrix

protected Matrix updateTransitionMatrix(ArrayList<Matrix> sequenceTransitionMatrices)
Computes an updated transition matrix from the scaled estimates

Parameters:
sequenceTransitionMatrices - Scaled estimates from each sequence
Returns:
Overall Maximum Likelihood estimate of the transition matrix
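
One common way to combine per-sequence estimates is to sum them and renormalize. The sketch below does this under two assumptions that this page does not confirm: that the matrix is column-stochastic (entry [i][j] is the probability of moving to state i from state j, which would match the note that Rabiner's matrix is transposed), and that summation is the combining rule. TransitionUpdateSketch and combine are hypothetical names, and plain arrays stand in for Matrix.

```java
import java.util.List;

// Hypothetical sketch: pool per-sequence transition estimates into one matrix.
final class TransitionUpdateSketch {
    // Assumed convention: m[i][j] relates to P(next state i | current state j),
    // so every column of the result sums to one.
    static double[][] combine(List<double[][]> perSequence) {
        int n = perSequence.get(0).length;
        double[][] sum = new double[n][n];
        // Add up the unnormalized estimates from every sequence.
        for (double[][] m : perSequence) {
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    sum[i][j] += m[i][j];
                }
            }
        }
        // Rescale each column so that it sums to one.
        for (int j = 0; j < n; j++) {
            double colSum = 0.0;
            for (int i = 0; i < n; i++) {
                colSum += sum[i][j];
            }
            if (colSum > 0.0) {
                for (int i = 0; i < n; i++) {
                    sum[i][j] /= colSum;
                }
            }
        }
        return sum;
    }
}
```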

updateInitialProbabilities

protected Vector updateInitialProbabilities(ArrayList<Vector> firstGammas)
Updates the initial probabilities from the first gamma of each sequence

Parameters:
firstGammas - The first gamma of each sequence
Returns:
Updated initial probability Vector for the HMM.
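
A natural reading of this update is that the initial-state distribution is the normalized sum of each sequence's first gamma. The sketch below shows that calculation as an illustration under that assumption. It is not the library's code: InitialProbabilitySketch and average are hypothetical names, and plain arrays stand in for Vector.

```java
import java.util.List;

// Hypothetical sketch: estimate initial-state probabilities from first gammas.
final class InitialProbabilitySketch {
    // firstGammas holds one state-occupancy vector per sequence, taken at time 0.
    static double[] average(List<double[]> firstGammas) {
        int numStates = firstGammas.get(0).length;
        double[] pi = new double[numStates];
        // Sum the first gamma of every sequence, state by state.
        for (double[] gamma : firstGammas) {
            for (int i = 0; i < numStates; i++) {
                pi[i] += gamma[i];
            }
        }
        // Normalize so that the entries sum to one.
        double total = 0.0;
        for (double v : pi) {
            total += v;
        }
        if (total > 0.0) {
            for (int i = 0; i < numStates; i++) {
                pi[i] /= total;
            }
        }
        return pi;
    }
}
```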

updateSequenceLogLikelihoods

protected double updateSequenceLogLikelihoods(HiddenMarkovModel<ObservationType> hmm)
Updates the internal sequence likelihoods for the given HMM

Parameters:
hmm - Hidden Markov model to consider
Returns:
Log likelihood of the observation sequences given the HMM.
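
For independent sequences, the total log likelihood is the sum of the per-sequence log likelihoods, each of which can be computed with the scaled forward recursion. The sketch below illustrates this for a discrete-symbol HMM. It is a minimal example under stated assumptions, not the library's implementation: LogLikelihoodSketch and its methods are hypothetical names, plain arrays stand in for the HMM, and the transition matrix is assumed to be column-stochastic.

```java
import java.util.List;

// Hypothetical sketch: log likelihood of discrete sequences under an HMM.
final class LogLikelihoodSketch {
    // pi[i] = P(first state i); a[i][j] = P(next state i | current state j)
    // (assumed column convention); b[i][o] = P(symbol o | state i).
    static double logLikelihood(double[] pi, double[][] a, double[][] b, int[] obs) {
        int n = pi.length;
        double[] alpha = new double[n];
        double logLik = 0.0;
        for (int t = 0; t < obs.length; t++) {
            double[] next = new double[n];
            double scale = 0.0;
            for (int i = 0; i < n; i++) {
                double prior;
                if (t == 0) {
                    prior = pi[i];
                } else {
                    prior = 0.0;
                    for (int j = 0; j < n; j++) {
                        prior += a[i][j] * alpha[j];
                    }
                }
                next[i] = prior * b[i][obs[t]];
                scale += next[i];
            }
            if (scale == 0.0) {
                return Double.NEGATIVE_INFINITY; // sequence is impossible under the model
            }
            // Normalize alpha; the scale factor equals P(obs[t] | obs[0..t-1]).
            for (int i = 0; i < n; i++) {
                next[i] /= scale;
            }
            alpha = next;
            logLik += Math.log(scale);
        }
        return logLik;
    }

    // Independent sequences: log likelihoods add.
    static double total(double[] pi, double[][] a, double[][] b, List<int[]> sequences) {
        double sum = 0.0;
        for (int[] obs : sequences) {
            sum += logLikelihood(pi, a, b, obs);
        }
        return sum;
    }
}
```

Rescaling alpha at every step keeps the recursion from underflowing on long sequences, which is why the result is accumulated as a sum of logarithms rather than as a product of probabilities.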