java.lang.Object
gov.sandia.cognition.util.AbstractCloneableSerializable
gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm<ResultType>
gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner<DataType,HiddenMarkovModel<ObservationType>>
gov.sandia.cognition.learning.algorithm.hmm.AbstractBaumWelchAlgorithm<ObservationType,Collection<? extends ObservationType>>
gov.sandia.cognition.learning.algorithm.hmm.BaumWelchAlgorithm<ObservationType>
ObservationType
- Type of observations handled by the HMM.
@PublicationReference(author="Lawrence R. Rabiner", title="A tutorial on hidden Markov models and selected applications in speech recognition", type=Journal, year=1989, publication="Proceedings of the IEEE", pages={257,286}, url="http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf", notes="Rabiner\'s transition matrix is transposed from mine.")
public class BaumWelchAlgorithm<ObservationType>
Implements the Baum-Welch algorithm for Hidden Markov Models (HMMs), also known as the "forward-backward algorithm" or as the expectation-maximization (EM) algorithm applied to HMMs. This is the standard learning algorithm for HMMs. This implementation supports training on multiple sequences through the MultiCollection interface.
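Baum-Welch alternates an expectation step, which computes the posterior probability of each hidden state at each time step (the "gammas") with the forward-backward recursions, and a maximization step, which re-estimates the model from those posteriors. The sketch below illustrates only the expectation step for one sequence of discrete symbols, using a scaled forward-backward pass to avoid numerical underflow. It is a standalone illustration with plain arrays, not this class's API: the class name ForwardBackwardSketch, the discrete emission table b[state][symbol], and the a[to][from] transition layout (columns sum to 1, chosen to match the note that Rabiner's transition matrix is transposed relative to this library's) are all assumptions.

```java
import java.util.Arrays;

// Hypothetical standalone sketch; not part of the Cognitive Foundry API.
final class ForwardBackwardSketch {

    private ForwardBackwardSketch() {
    }

    /**
     * Returns gamma[t][i], the posterior probability of state i at step t.
     *
     * @param pi  initial state probabilities
     * @param a   transition matrix laid out as a[to][from] (assumed; columns sum to 1)
     * @param b   discrete emission table, b[state][symbol] (assumed)
     * @param obs observed symbol indices (must be non-empty)
     */
    static double[][] gammas(double[] pi, double[][] a, double[][] b, int[] obs) {
        int n = pi.length;
        int len = obs.length;
        double[][] alpha = new double[len][n];
        double[][] beta = new double[len][n];
        double[] scale = new double[len];

        // Forward pass: normalize alpha at each step to avoid underflow.
        for (int t = 0; t < len; t++) {
            for (int i = 0; i < n; i++) {
                double prior = 0.0;
                if (t == 0) {
                    prior = pi[i];
                } else {
                    for (int j = 0; j < n; j++) {
                        prior += a[i][j] * alpha[t - 1][j];
                    }
                }
                alpha[t][i] = prior * b[i][obs[t]];
                scale[t] += alpha[t][i];
            }
            for (int i = 0; i < n; i++) {
                alpha[t][i] /= scale[t];
            }
        }

        // Backward pass, reusing the forward scale factors.
        Arrays.fill(beta[len - 1], 1.0);
        for (int t = len - 2; t >= 0; t--) {
            for (int j = 0; j < n; j++) {
                double sum = 0.0;
                for (int i = 0; i < n; i++) {
                    sum += a[i][j] * b[i][obs[t + 1]] * beta[t + 1][i];
                }
                beta[t][j] = sum / scale[t + 1];
            }
        }

        // With this scaling, alpha[t][i] * beta[t][i] is already the posterior;
        // renormalizing only guards against rounding drift.
        double[][] gamma = new double[len][n];
        for (int t = 0; t < len; t++) {
            double total = 0.0;
            for (int i = 0; i < n; i++) {
                gamma[t][i] = alpha[t][i] * beta[t][i];
                total += gamma[t][i];
            }
            for (int i = 0; i < n; i++) {
                gamma[t][i] /= total;
            }
        }
        return gamma;
    }
}
```

For example, with two states, a = {{0.5, 0.5}, {0.5, 0.5}}, the identity emission table b = {{1, 0}, {0, 1}}, and observations {0, 1, 0}, the gammas come out one-hot ({1, 0}, {0, 1}, {1, 0}), because every symbol identifies its state exactly.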
Field Summary
Modifier and Type | Field and Description |
---|---|
protected MultiCollection<? extends ObservationType> | multicollection: The multi-collection of sequences. |
protected ArrayList<Vector> | sequenceGammas: The list of all gammas from each sequence. |
Fields inherited from class gov.sandia.cognition.learning.algorithm.hmm.AbstractBaumWelchAlgorithm |
---|
DEFAULT_MAX_ITERATIONS, DEFAULT_REESTIMATE_INITIAL_PROBABILITY, distributionLearner, initialGuess, lastLogLikelihood, PERFORMANCE_NAME, reestimateInitialProbabilities, result |
Fields inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner |
---|
data, keepGoing |
Fields inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm |
---|
maxIterations |
Fields inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm |
---|
DEFAULT_ITERATION, iteration |
Constructor Summary
Constructor and Description |
---|
BaumWelchAlgorithm(): Creates a new instance of BaumWelchAlgorithm. |
BaumWelchAlgorithm(HiddenMarkovModel<ObservationType> initialGuess, BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>,? extends ComputableDistribution<ObservationType>> distributionLearner, boolean reestimateInitialProbabilities): Creates a new instance of BaumWelchAlgorithm from an initial guess, a learner for the HMM's probability functions, and a flag controlling whether the initial probabilities are re-estimated. |
Method Summary
Modifier and Type | Method and Description |
---|---|
protected void | cleanupAlgorithm(): Called to clean up the learning algorithm's state after learning has finished. |
BaumWelchAlgorithm<ObservationType> | clone(): This makes public the clone method on the Object class and removes the exception that it throws. |
protected Pair<ArrayList<ArrayList<Vector>>,ArrayList<Matrix>> | computeSequenceParameters(): Computes the gammas and A matrices for each sequence. |
protected boolean | initializeAlgorithm(): Called to initialize the learning algorithm's state based on the data that is stored in the data field. |
HiddenMarkovModel<ObservationType> | learn(MultiCollection<ObservationType> data): Learns from multiple sequences of data. |
protected boolean | step(): Called to take a single step of the learning algorithm. |
protected Vector | updateInitialProbabilities(ArrayList<Vector> firstGammas): Updates the initial probabilities from the first gamma of each sequence. |
protected ArrayList<ProbabilityFunction<ObservationType>> | updateProbabilityFunctions(ArrayList<Vector> sequenceGammas): Updates the probability functions from the concatenated gammas of all sequences. |
protected double | updateSequenceLogLikelihoods(HiddenMarkovModel<ObservationType> hmm): Updates the internal sequence log-likelihoods for the given HMM. |
protected Matrix | updateTransitionMatrix(ArrayList<Matrix> sequenceTransitionMatrices): Computes an updated transition matrix from the scaled estimates of each sequence. |
Methods inherited from class gov.sandia.cognition.learning.algorithm.hmm.AbstractBaumWelchAlgorithm |
---|
getDistributionLearner, getInitialGuess, getLastLogLikelihood, getPerformance, getReestimateInitialProbabilities, getResult, setDistributionLearner, setInitialGuess, setReestimateInitialProbabilities |
Methods inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner |
---|
getData, getKeepGoing, learn, setData, setKeepGoing, stop |
Methods inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm |
---|
getMaxIterations, isResultValid, setMaxIterations |
Methods inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm |
---|
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners |
Methods inherited from class java.lang.Object |
---|
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
Methods inherited from interface gov.sandia.cognition.algorithm.AnytimeAlgorithm |
---|
getMaxIterations, setMaxIterations |
Methods inherited from interface gov.sandia.cognition.algorithm.IterativeAlgorithm |
---|
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener |
Methods inherited from interface gov.sandia.cognition.algorithm.StoppableAlgorithm |
---|
isResultValid |
Field Detail |
---|
protected transient MultiCollection<? extends ObservationType> multicollection
The multi-collection of sequences.
protected transient ArrayList<Vector> sequenceGammas
The list of all gammas from each sequence.
Constructor Detail |
---|
public BaumWelchAlgorithm()
Creates a new instance of BaumWelchAlgorithm.
public BaumWelchAlgorithm(HiddenMarkovModel<ObservationType> initialGuess, BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>,? extends ComputableDistribution<ObservationType>> distributionLearner, boolean reestimateInitialProbabilities)
Creates a new instance of BaumWelchAlgorithm.
initialGuess
- Initial guess for the iterations.
distributionLearner
- Learner for the probability functions of the HMM.
reestimateInitialProbabilities
- Flag to re-estimate the initial probability Vector.
Method Detail |
---|
public BaumWelchAlgorithm<ObservationType> clone()
Description copied from class: AbstractCloneableSerializable
This makes public the clone method on the Object class and removes the exception that it throws. Its default behavior is to automatically create a clone of the exact type of object that the clone is called on and to copy all primitives but to keep all references, which means it is a shallow copy.
Extensions of this class may want to override this method (but call super.clone()) to implement a "smart copy", that is, to target the most common use case for creating a copy of the object. Because the default behavior is a shallow copy, extending classes only need to handle fields that need a deeper copy (or those that need to be reset). Some of the methods in ObjectUtil may be helpful in implementing a custom clone method.
Note: The contract of this method is that you must use super.clone() as the basis for your implementation.
Specified by: clone in interface CloneableSerializable
Overrides: clone in class AbstractBaumWelchAlgorithm<ObservationType,Collection<? extends ObservationType>>
public HiddenMarkovModel<ObservationType> learn(MultiCollection<ObservationType> data)
Learns from multiple sequences of data.
data
- Multiple sequences of data against which to train.
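protected boolean initializeAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to initialize the learning algorithm's state based on the data that is stored in the data field.
Specified by: initializeAlgorithm in class AbstractAnytimeBatchLearner<Collection<? extends ObservationType>,HiddenMarkovModel<ObservationType>>
protected boolean step()
Description copied from class: AbstractAnytimeBatchLearner
Called to take a single step of the learning algorithm.
Specified by: step in class AbstractAnytimeBatchLearner<Collection<? extends ObservationType>,HiddenMarkovModel<ObservationType>>
protected void cleanupAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to clean up the learning algorithm's state after learning has finished.
Specified by: cleanupAlgorithm in class AbstractAnytimeBatchLearner<Collection<? extends ObservationType>,HiddenMarkovModel<ObservationType>>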
protected Pair<ArrayList<ArrayList<Vector>>,ArrayList<Matrix>> computeSequenceParameters()
Computes the gammas and A matrices for each sequence.
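One common reading of a per-sequence "A matrix" is the matrix of expected transition counts, obtained by summing the pairwise posteriors xi over the sequence. The sketch below assumes scaled forward-backward variables where alpha[t] is normalized to sum to 1 with normalizer scale[t], and beta[t] is divided by scale[t + 1]; under that scaling, xi needs only one extra division by scale[t + 1]. The class TransitionCountsSketch, the a[to][from] layout, and the discrete emission table are hypothetical, and whether this class stores exactly these counts is an assumption.

```java
// Hypothetical standalone sketch; not part of the Cognitive Foundry API.
final class TransitionCountsSketch {

    private TransitionCountsSketch() {
    }

    /**
     * Sums xi over one sequence: counts[to][from] is the expected number of
     * transitions from state "from" to state "to".
     *
     * @param alpha scaled forward variables, each row summing to 1
     * @param beta  scaled backward variables (beta[t] divided by scale[t + 1])
     * @param scale per-step forward normalizers
     * @param a     transition matrix laid out as a[to][from] (assumed)
     * @param b     discrete emission table, b[state][symbol] (assumed)
     * @param obs   observed symbol indices
     */
    static double[][] expectedTransitions(double[][] alpha, double[][] beta,
            double[] scale, double[][] a, double[][] b, int[] obs) {
        int n = a.length;
        double[][] counts = new double[n][n];
        for (int t = 0; t + 1 < obs.length; t++) {
            for (int from = 0; from < n; from++) {
                for (int to = 0; to < n; to++) {
                    // xi_t(from -> to) under the scaling described above.
                    counts[to][from] += alpha[t][from] * a[to][from]
                        * b[to][obs[t + 1]] * beta[t + 1][to] / scale[t + 1];
                }
            }
        }
        return counts;
    }
}
```

Because the xi values at each step form a joint distribution over (from, to), the entries of the returned matrix sum to obs.length - 1, which makes a quick sanity check.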
protected ArrayList<ProbabilityFunction<ObservationType>> updateProbabilityFunctions(ArrayList<Vector> sequenceGammas)
Updates the probability functions from the concatenated gammas of all sequences.
sequenceGammas
- Concatenated gammas from all sequences
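The constructor's distributionLearner accepts a collection of weighted observations, which suggests that each state's probability function is fit to all observations weighted by that state's gammas. For discrete symbols such a weighted fit reduces to normalized weighted counts. The sketch below assumes that discrete case; EmissionUpdateSketch and its parameters are hypothetical, and the real class delegates the fit to whatever distributionLearner it was given.

```java
// Hypothetical standalone sketch; not part of the Cognitive Foundry API.
final class EmissionUpdateSketch {

    private EmissionUpdateSketch() {
    }

    /**
     * Re-estimates b[state][symbol] from concatenated gammas and observations.
     *
     * @param gammas     gammas[t][state], concatenated over all sequences
     * @param obs        observed symbol indices, concatenated in the same order
     * @param numSymbols size of the discrete alphabet (assumed)
     */
    static double[][] update(double[][] gammas, int[] obs, int numSymbols) {
        int n = gammas[0].length;
        double[][] b = new double[n][numSymbols];
        double[] totals = new double[n];
        // Each observation contributes to every state, weighted by its gamma.
        for (int t = 0; t < obs.length; t++) {
            for (int state = 0; state < n; state++) {
                b[state][obs[t]] += gammas[t][state];
                totals[state] += gammas[t][state];
            }
        }
        // Normalize each state's weighted counts into a distribution.
        for (int state = 0; state < n; state++) {
            if (totals[state] > 0.0) {
                for (int k = 0; k < numSymbols; k++) {
                    b[state][k] /= totals[state];
                }
            }
        }
        return b;
    }
}
```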
protected Matrix updateTransitionMatrix(ArrayList<Matrix> sequenceTransitionMatrices)
Computes an updated transition matrix from the scaled estimates of each sequence.
sequenceTransitionMatrices
- Scaled estimates from each sequence
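A standard way to combine several sequences in Baum-Welch is to pool their expected transition counts before normalizing, so longer sequences carry proportionally more weight. The sketch below assumes each per-sequence estimate is a count matrix laid out [to][from], sums them, and normalizes each column to 1 (the transposed convention noted for this class). TransitionUpdateSketch is hypothetical, and whether this class pools its estimates exactly this way is an assumption.

```java
import java.util.List;

// Hypothetical standalone sketch; not part of the Cognitive Foundry API.
final class TransitionUpdateSketch {

    private TransitionUpdateSketch() {
    }

    /** Pools per-sequence counts[to][from] and normalizes each column. */
    static double[][] pool(List<double[][]> perSequenceCounts) {
        int n = perSequenceCounts.get(0).length;
        double[][] a = new double[n][n];
        for (double[][] counts : perSequenceCounts) {
            for (int to = 0; to < n; to++) {
                for (int from = 0; from < n; from++) {
                    a[to][from] += counts[to][from];
                }
            }
        }
        for (int from = 0; from < n; from++) {
            double total = 0.0;
            for (int to = 0; to < n; to++) {
                total += a[to][from];
            }
            // A column of zeros (a state never left) is left untouched.
            if (total > 0.0) {
                for (int to = 0; to < n; to++) {
                    a[to][from] /= total;
                }
            }
        }
        return a;
    }
}
```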
protected Vector updateInitialProbabilities(ArrayList<Vector> firstGammas)
Updates the initial probabilities from the first gamma of each sequence.
firstGammas
- The first gamma of each sequence
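Each first gamma is the posterior distribution over the starting state of one sequence, so a natural multi-sequence re-estimate of the initial probabilities is their average, which again sums to 1. The sketch below assumes that averaging rule; InitialProbabilitySketch is hypothetical, and the exact rule this class applies (and what it does when reestimateInitialProbabilities is false) is not stated here.

```java
import java.util.List;

// Hypothetical standalone sketch; not part of the Cognitive Foundry API.
final class InitialProbabilitySketch {

    private InitialProbabilitySketch() {
    }

    /** Averages the first gamma of each sequence into a new initial vector. */
    static double[] average(List<double[]> firstGammas) {
        int n = firstGammas.get(0).length;
        double[] pi = new double[n];
        for (double[] gamma : firstGammas) {
            for (int i = 0; i < n; i++) {
                pi[i] += gamma[i];
            }
        }
        for (int i = 0; i < n; i++) {
            pi[i] /= firstGammas.size();
        }
        return pi;
    }
}
```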
protected double updateSequenceLogLikelihoods(HiddenMarkovModel<ObservationType> hmm)
Updates the internal sequence log-likelihoods for the given HMM.
hmm
- Hidden Markov model to consider
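Assuming the returned double is the total log-likelihood of all sequences under hmm (an assumption; the page does not say), one standard way to compute it is a scaled forward pass per sequence: the log-likelihood of a sequence equals the sum of the logs of its per-step scale factors, which stays finite even when the raw probability would underflow. LogLikelihoodSketch, the discrete emission table, and the a[to][from] layout are hypothetical.

```java
import java.util.List;

// Hypothetical standalone sketch; not part of the Cognitive Foundry API.
final class LogLikelihoodSketch {

    private LogLikelihoodSketch() {
    }

    /** Sums log P(sequence | model) over all sequences via a scaled forward pass. */
    static double totalLogLikelihood(double[] pi, double[][] a, double[][] b,
            List<int[]> sequences) {
        int n = pi.length;
        double total = 0.0;
        for (int[] obs : sequences) {
            double[] alpha = new double[n];
            for (int t = 0; t < obs.length; t++) {
                double[] next = new double[n];
                double scale = 0.0;
                for (int i = 0; i < n; i++) {
                    double prior = 0.0;
                    if (t == 0) {
                        prior = pi[i];
                    } else {
                        // a is assumed to be laid out as a[to][from].
                        for (int j = 0; j < n; j++) {
                            prior += a[i][j] * alpha[j];
                        }
                    }
                    next[i] = prior * b[i][obs[t]];
                    scale += next[i];
                }
                // Each sequence's log-likelihood is the sum of the logs of its scale factors.
                total += Math.log(scale);
                for (int i = 0; i < n; i++) {
                    next[i] /= scale;
                }
                alpha = next;
            }
        }
        return total;
    }
}
```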