gov.sandia.cognition.statistics.bayesian
Class DirichletProcessMixtureModel<ObservationType>

java.lang.Object
  extended by gov.sandia.cognition.util.AbstractCloneableSerializable
      extended by gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
          extended by gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm<ResultType>
              extended by gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner<Collection<? extends ObservationType>,DataDistribution<ParameterType>>
                  extended by gov.sandia.cognition.statistics.bayesian.AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
                      extended by gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel<ObservationType>
Type Parameters:
ObservationType - Type of observations handled by the mixture model
All Implemented Interfaces:
AnytimeAlgorithm<DataDistribution<DirichletProcessMixtureModel.Sample<ObservationType>>>, IterativeAlgorithm, StoppableAlgorithm, AnytimeBatchLearner<Collection<? extends ObservationType>,DataDistribution<DirichletProcessMixtureModel.Sample<ObservationType>>>, BatchLearner<Collection<? extends ObservationType>,DataDistribution<DirichletProcessMixtureModel.Sample<ObservationType>>>, BayesianEstimator<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>,DataDistribution<DirichletProcessMixtureModel.Sample<ObservationType>>>, MarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>, CloneableSerializable, Randomized, Serializable, Cloneable
Direct Known Subclasses:
ParallelDirichletProcessMixtureModel

@PublicationReferences(references={@PublicationReference(author="Radford M. Neal",title="Markov Chain Sampling Methods for Dirichlet Process Mixture Models",type=Journal,year=2000,publication="Journal of Computational and Graphical Statistics, Vol. 9, No. 2",pages={249,265},notes="Based in part on Algorithm 2 from Neal"),@PublicationReference(author={"Michael D. Escobar","Mike West"},title="Bayesian Density Estimation and Inference Using Mixtures",type=Journal,publication="Journal of the American Statistical Association",year=1995)})
public class DirichletProcessMixtureModel<ObservationType>
extends AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>

An implementation of Dirichlet Process clustering, which estimates the number of clusters and the centroids of the clusters from a set of data.

Since:
3.0
Author:
Kevin R. Dixon
See Also:
Serialized Form

Nested Class Summary
static class DirichletProcessMixtureModel.DPMMCluster<ObservationType>
          Cluster for a step in the DPMM
protected static class DirichletProcessMixtureModel.DPMMLogConditional
          Container for the log conditional likelihood
static class DirichletProcessMixtureModel.MultivariateMeanCovarianceUpdater
          Updater that creates specified clusters with distinct means and covariances
static class DirichletProcessMixtureModel.MultivariateMeanUpdater
          Updater that creates specified clusters with identical covariances
static class DirichletProcessMixtureModel.Sample<ObservationType>
          A sample from the Dirichlet Process Mixture Model.
static interface DirichletProcessMixtureModel.Updater<ObservationType>
          Updater for the DPMM
 
Field Summary
protected  GammaDistribution alphaInverseSampler
          Samples a new alpha-inverse.
protected  double[] clusterWeights
          Holds the cluster weights so that we don't have to re-allocate them on each mcmcUpdate step.
protected  ProbabilityFunction<ObservationType> conditionalPriorPredictive
          Base predictive distribution that determines the value of the new cluster weighting during the Gibbs sampling.
static double DEFAULT_ALPHA
          Default concentration parameter of the Dirichlet Process, 1.0.
static int DEFAULT_NUM_INITIAL_CLUSTERS
          Default number of initial clusters
static boolean DEFAULT_REESTIMATE_ALPHA
          Default setting of the flag to automatically re-estimate the alpha parameter.
protected  BetaDistribution etaSampler
          Creates a new value of "eta" which, in turn, helps sample a new alpha.
protected  double initialAlpha
          Initial value of alpha, the concentration parameter of the Dirichlet Process
protected  boolean reestimateAlpha
          Flag to automatically re-estimate the alpha parameter
protected  DirichletProcessMixtureModel.Updater<ObservationType> updater
          Creates the clusters and predictive prior distributions
 
Fields inherited from class gov.sandia.cognition.statistics.bayesian.AbstractMarkovChainMonteCarlo
currentParameter, DEFAULT_NUM_SAMPLES, previousParameter, random
 
Fields inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
data, keepGoing
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
maxIterations
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
DEFAULT_ITERATION, iteration
 
Constructor Summary
DirichletProcessMixtureModel()
          Creates a new instance of DirichletProcessMixtureModel
 
Method Summary
protected  ArrayList<Collection<ObservationType>> assignObservationsToClusters(int K, DirichletProcessMixtureModel.DPMMLogConditional logConditional)
          Assigns observations to each of the K clusters, plus the as-yet-uncreated new cluster
protected  int assignObservationToCluster(ObservationType observation, double[] weights, DirichletProcessMixtureModel.DPMMLogConditional logConditional)
          Probabilistically assigns an observation to a cluster
 DirichletProcessMixtureModel<ObservationType> clone()
          This makes public the clone method on the Object class and removes the exception that it throws.
protected  DirichletProcessMixtureModel.DPMMCluster<ObservationType> createCluster(Collection<ObservationType> clusterAssignment, DirichletProcessMixtureModel.Updater<ObservationType> localUpdater)
          Creates a cluster from the given cluster assignment
 DirichletProcessMixtureModel.Sample<ObservationType> createInitialLearnedObject()
          Creates the initial parameters from which to start the Markov chain.
 double getInitialAlpha()
          Getter for initialAlpha
 int getNumInitialClusters()
          Getter for numInitialClusters
 boolean getReestimateAlpha()
          Getter for reestimateAlpha
 DirichletProcessMixtureModel.Updater<ObservationType> getUpdater()
          Getter for updater
protected  void mcmcUpdate()
          Performs a valid MCMC update step.
 void setInitialAlpha(double initialAlpha)
          Setter for initialAlpha
 void setNumInitialClusters(int numInitialClusters)
          Setter for numInitialClusters
 void setReestimateAlpha(boolean reestimateAlpha)
          Setter for reestimateAlpha
 void setUpdater(DirichletProcessMixtureModel.Updater<ObservationType> updater)
          Setter for updater
protected  double updateAlpha(double alpha, int numObservations)
          Runs the Gibbs sampler for the concentration parameter, alpha, given the data.
protected  ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> updateClusters(ArrayList<Collection<ObservationType>> clusterAssignments)
          Update each cluster according to the data assigned to it
 
Methods inherited from class gov.sandia.cognition.statistics.bayesian.AbstractMarkovChainMonteCarlo
cleanupAlgorithm, getBurnInIterations, getCurrentParameter, getIterationsPerSample, getPreviousParameter, getRandom, getResult, initializeAlgorithm, setBurnInIterations, setCurrentParameter, setIterationsPerSample, setRandom, setResult, step
 
Methods inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
getData, getKeepGoing, learn, setData, setKeepGoing, stop
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
getMaxIterations, isResultValid, setMaxIterations
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
 
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface gov.sandia.cognition.learning.algorithm.BatchLearner
learn
 
Methods inherited from interface gov.sandia.cognition.algorithm.AnytimeAlgorithm
getMaxIterations, setMaxIterations
 
Methods inherited from interface gov.sandia.cognition.algorithm.IterativeAlgorithm
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener
 
Methods inherited from interface gov.sandia.cognition.algorithm.StoppableAlgorithm
isResultValid, stop
 

Field Detail

DEFAULT_ALPHA

public static final double DEFAULT_ALPHA
Default concentration parameter of the Dirichlet Process, 1.0.

See Also:
Constant Field Values
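The concentration parameter controls how readily the prior opens new clusters. As a rough illustration (standard Dirichlet Process arithmetic, not code from this library), in the Chinese restaurant process view the (i+1)-th observation starts a new cluster with probability alpha / (alpha + i), so the prior expected number of clusters after n observations is the sum of those terms. The class and method names below are hypothetical.

```java
// Hypothetical helper; not part of gov.sandia.cognition.
class ExpectedClusters {
    /** Prior E[K_n] = sum_{i=0}^{n-1} alpha / (alpha + i) under a DP(alpha). */
    static double expectedNumClusters(double alpha, int n) {
        if (alpha <= 0.0 || n < 0) {
            throw new IllegalArgumentException("need alpha > 0 and n >= 0");
        }
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            // Probability that observation i+1 opens a new cluster
            sum += alpha / (alpha + i);
        }
        return sum;
    }

    public static void main(String[] args) {
        // With alpha = 1.0 (the documented default), E[K_n] grows roughly like ln(n).
        for (int n : new int[] {10, 100, 1000}) {
            System.out.printf("n=%d  E[K]=%.2f%n", n, expectedNumClusters(1.0, n));
        }
    }
}
```

Larger values of alpha raise this prior expectation. When reestimateAlpha is enabled, alpha is presumably resampled from the data on each step rather than held at its initial value.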

DEFAULT_NUM_INITIAL_CLUSTERS

public static final int DEFAULT_NUM_INITIAL_CLUSTERS
Default number of initial clusters

See Also:
Constant Field Values

DEFAULT_REESTIMATE_ALPHA

public static final boolean DEFAULT_REESTIMATE_ALPHA
Default setting of the flag to automatically re-estimate the alpha parameter.

See Also:
Constant Field Values

updater

protected DirichletProcessMixtureModel.Updater<ObservationType> updater
Creates the clusters and predictive prior distributions


reestimateAlpha

protected boolean reestimateAlpha
Flag to automatically re-estimate the alpha parameter


initialAlpha

protected double initialAlpha
Initial value of alpha, the concentration parameter of the Dirichlet Process


conditionalPriorPredictive

protected transient ProbabilityFunction<ObservationType> conditionalPriorPredictive
Base predictive distribution that determines the value of the new cluster weighting during the Gibbs sampling.


clusterWeights

protected transient double[] clusterWeights
Holds the cluster weights so that we don't have to re-allocate them on each mcmcUpdate step.


etaSampler

protected transient BetaDistribution etaSampler
Creates a new value of "eta" which, in turn, helps sample a new alpha.


alphaInverseSampler

protected transient GammaDistribution alphaInverseSampler
Samples a new alpha-inverse.

Constructor Detail

DirichletProcessMixtureModel

public DirichletProcessMixtureModel()
Creates a new instance of DirichletProcessMixtureModel

Method Detail

clone

public DirichletProcessMixtureModel<ObservationType> clone()
Description copied from class: AbstractCloneableSerializable
This makes public the clone method on the Object class and removes the exception that it throws. Its default behavior is to automatically create a clone of the exact type of object that the clone is called on and to copy all primitives but to keep all references, which means it is a shallow copy. Extensions of this class may want to override this method (but call super.clone()) to implement a "smart copy", that is, to target the most common use case for creating a copy of the object. Because the default behavior is a shallow copy, extending classes only need to handle fields that need a deeper copy (or those that need to be reset). Some of the methods in ObjectUtil may be helpful in implementing a custom clone method. Note: The contract of this method is that you must use super.clone() as the basis for your implementation.

Specified by:
clone in interface CloneableSerializable
Overrides:
clone in class AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
Returns:
A clone of this object.

mcmcUpdate

protected void mcmcUpdate()
Description copied from class: AbstractMarkovChainMonteCarlo
Performs a valid MCMC update step. That is, the function is expected to modify the currentParameter member.

Specified by:
mcmcUpdate in class AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>

updateClusters

protected ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> updateClusters(ArrayList<Collection<ObservationType>> clusterAssignments)
Update each cluster according to the data assigned to it

Parameters:
clusterAssignments - Observations assigned to each cluster
Returns:
Clusters, each containing an updated parameter estimate and weighted by the number of observations assigned to it

assignObservationsToClusters

protected ArrayList<Collection<ObservationType>> assignObservationsToClusters(int K,
                                                                              DirichletProcessMixtureModel.DPMMLogConditional logConditional)
Assigns observations to each of the K clusters, plus the as-yet-uncreated new cluster

Parameters:
K - Number of clusters
logConditional - Container for the log conditional likelihood.
Returns:
Assignments from observations to clusters
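The bookkeeping behind assignObservationsToClusters can be sketched as follows, assuming the cluster index of every observation has already been sampled. Index K denotes the new cluster, so K + 1 collections are built. This is an illustrative sketch only; the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of grouping observations by sampled cluster index.
class ClusterBuckets {
    /**
     * Groups observations by cluster index. assignment[i] must lie in [0, K],
     * where K is the slot for the new, as-yet-uncreated cluster.
     */
    static <T> ArrayList<List<T>> bucket(List<T> observations, int[] assignment, int K) {
        if (observations.size() != assignment.length) {
            throw new IllegalArgumentException("one assignment per observation required");
        }
        ArrayList<List<T>> buckets = new ArrayList<>(K + 1);
        for (int k = 0; k <= K; k++) {
            buckets.add(new ArrayList<>());
        }
        for (int i = 0; i < assignment.length; i++) {
            buckets.get(assignment[i]).add(observations.get(i));
        }
        return buckets;
    }

    public static void main(String[] args) {
        List<String> obs = List.of("a", "b", "c", "d");
        // Two existing clusters (0 and 1) plus the new-cluster slot (2)
        System.out.println(bucket(obs, new int[] {0, 2, 0, 1}, 2)); // [[a, c], [d], [b]]
    }
}
```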

assignObservationToCluster

protected int assignObservationToCluster(ObservationType observation,
                                         double[] weights,
                                         DirichletProcessMixtureModel.DPMMLogConditional logConditional)
Probabilistically assigns an observation to a cluster

Parameters:
observation - Observation that we're assigning
weights - Placeholder for the weights that this method will compute
logConditional - Container for the log conditional likelihood.
Returns:
Index of the cluster to assign the observation to: in [0, K-1] for an existing cluster, or K for the as-yet-uncreated new cluster.
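Under Neal's Algorithm 2 with a conjugate prior, the probability of placing observation x in existing cluster k is proportional to n_k f(x | theta_k), and the probability of opening a new cluster is proportional to alpha times the prior predictive density of x (the role suggested by the conditionalPriorPredictive field). The sketch below samples such an index in log space and reuses a caller-supplied weights array, in the spirit of the weights parameter. It assumes the per-cluster log-likelihoods and the prior-predictive log-density are already computed, and that counts exclude the observation being reassigned. Names are hypothetical; this is not the library's code.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch; not the gov.sandia.cognition implementation.
class CrpAssignment {
    /**
     * @param logLik             log f(x | theta_k) for each of the K existing clusters
     * @param counts             number of other observations in each existing cluster
     * @param alpha              concentration parameter
     * @param logPriorPredictive log prior predictive density of x (new-cluster term)
     * @param weights            scratch array of length K + 1, overwritten
     * @return k in [0, K-1] for an existing cluster, or K for a new cluster
     */
    static int sample(double[] logLik, int[] counts, double alpha,
                      double logPriorPredictive, double[] weights, Random random) {
        int K = logLik.length;
        double max = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < K; k++) {
            weights[k] = Math.log(counts[k]) + logLik[k];
            max = Math.max(max, weights[k]);
        }
        weights[K] = Math.log(alpha) + logPriorPredictive;
        max = Math.max(max, weights[K]);

        // Shift by the max before exponentiating to avoid underflow,
        // storing running (unnormalized) cumulative weights.
        double total = 0.0;
        for (int k = 0; k <= K; k++) {
            total += Math.exp(weights[k] - max);
            weights[k] = total;
        }

        // Inverse-CDF draw over the K + 1 cumulative weights.
        double u = random.nextDouble() * total;
        for (int k = 0; k < K; k++) {
            if (u < weights[k]) {
                return k;
            }
        }
        return K;
    }

    public static void main(String[] args) {
        double[] logLik = {Math.log(0.2), Math.log(0.1)};
        int[] counts = {3, 1};
        double[] weights = new double[logLik.length + 1];
        int[] histogram = new int[weights.length];
        Random random = new Random(0);
        for (int t = 0; t < 10000; t++) {
            histogram[sample(logLik, counts, 1.0, Math.log(0.05), weights, random)]++;
        }
        // Unnormalized weights 0.6, 0.1, 0.05, so roughly 80% / 13% / 7%.
        System.out.println(Arrays.toString(histogram));
    }
}
```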

createCluster

protected DirichletProcessMixtureModel.DPMMCluster<ObservationType> createCluster(Collection<ObservationType> clusterAssignment,
                                                                                  DirichletProcessMixtureModel.Updater<ObservationType> localUpdater)
Creates a cluster from the given cluster assignment

Parameters:
clusterAssignment - Observations assigned to a particular cluster
localUpdater - Updater that recomputes the cluster parameters, needed to ensure thread safety in the parallel implementation
Returns:
Cluster that contains an updated parameter estimate and is weighted by the number of observations assigned to the cluster
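For a concrete picture of what an updated parameter estimate can mean, consider a one-dimensional Gaussian cluster with known noise variance and a Normal prior on its mean, a conjugate pair. The posterior for the mean is again Normal, and one common Gibbs choice is to draw the new mean from it while recording the cluster size as its weight. This is an assumed, simplified model: the library's MultivariateMeanUpdater and MultivariateMeanCovarianceUpdater work with multivariate data, and their exact update rules are not documented here. Names are hypothetical.

```java
import java.util.Collection;
import java.util.List;
import java.util.Random;

// Hypothetical sketch; a 1-D conjugate model, not the library's updaters.
class NormalMeanCluster {
    final double mean;  // mean sampled from the posterior
    final int weight;   // number of observations assigned to the cluster

    NormalMeanCluster(double mean, int weight) {
        this.mean = mean;
        this.weight = weight;
    }

    /**
     * Posterior of the cluster mean, assuming mean ~ N(priorMean, priorVar)
     * and x ~ N(mean, noiseVar) with noiseVar known. Returns {mean, variance}.
     */
    static double[] posterior(Collection<Double> assigned, double priorMean,
                              double priorVar, double noiseVar) {
        double sum = 0.0;
        for (double x : assigned) {
            sum += x;
        }
        double precision = 1.0 / priorVar + assigned.size() / noiseVar;
        double mean = (priorMean / priorVar + sum / noiseVar) / precision;
        return new double[] {mean, 1.0 / precision};
    }

    /** Draws a new mean from the posterior and weights the cluster by its size. */
    static NormalMeanCluster create(Collection<Double> assigned, double priorMean,
                                    double priorVar, double noiseVar, Random random) {
        double[] post = posterior(assigned, priorMean, priorVar, noiseVar);
        double sampled = post[0] + Math.sqrt(post[1]) * random.nextGaussian();
        return new NormalMeanCluster(sampled, assigned.size());
    }

    public static void main(String[] args) {
        List<Double> data = List.of(4.8, 5.1, 5.3, 4.9);
        NormalMeanCluster c = create(data, 0.0, 100.0, 1.0, new Random(1));
        System.out.println("mean=" + c.mean + " weight=" + c.weight);
    }
}
```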

updateAlpha

protected double updateAlpha(double alpha,
                             int numObservations)
Runs the Gibbs sampler for the concentration parameter, alpha, given the data.

Parameters:
alpha - Current value of the concentration parameter
numObservations - Number of observations we're sampling over
Returns:
Updated estimate of alpha
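Escobar and West (1995), cited above, describe an auxiliary-variable Gibbs step for alpha, which fits the roles suggested by the etaSampler (Beta) and alphaInverseSampler (Gamma) fields: draw eta ~ Beta(alpha + 1, n), then draw alpha from a two-component mixture of Gamma distributions whose rate is b - ln(eta). The sketch below implements that textbook step under stated assumptions: a Gamma(a, b) prior on alpha with shape a and rate b (the library's prior is not documented here), the current number of clusters k passed explicitly, and hand-written Marsaglia-Tsang Gamma and Gamma-ratio Beta samplers in place of the library's GammaDistribution and BetaDistribution classes. All names are hypothetical.

```java
import java.util.Random;

// Hypothetical sketch of the Escobar-West alpha update; not the library's code.
class AlphaSampler {
    /** Marsaglia-Tsang sampler for Gamma(shape, rate). */
    static double sampleGamma(double shape, double rate, Random random) {
        if (shape < 1.0) {
            // Boost: Gamma(shape) = Gamma(shape + 1) * U^(1/shape)
            double u = random.nextDouble();
            return sampleGamma(shape + 1.0, rate, random) * Math.pow(u, 1.0 / shape);
        }
        double d = shape - 1.0 / 3.0;
        double c = 1.0 / Math.sqrt(9.0 * d);
        while (true) {
            double x = random.nextGaussian();
            double v = 1.0 + c * x;
            if (v <= 0.0) {
                continue;
            }
            v = v * v * v;
            double u = random.nextDouble();
            if (u < 1.0 - 0.0331 * x * x * x * x
                    || Math.log(u) < 0.5 * x * x + d * (1.0 - v + Math.log(v))) {
                return d * v / rate;
            }
        }
    }

    /** Beta(a, b) as X / (X + Y) with X ~ Gamma(a, 1), Y ~ Gamma(b, 1). */
    static double sampleBeta(double a, double b, Random random) {
        double x = sampleGamma(a, 1.0, random);
        double y = sampleGamma(b, 1.0, random);
        return x / (x + y);
    }

    /** One Gibbs update of alpha given k clusters, n >= 1 observations, Gamma(a, b) prior. */
    static double updateAlpha(double alpha, int k, int n,
                              double priorShape, double priorRate, Random random) {
        if (n < 1 || k < 1) {
            throw new IllegalArgumentException("need n >= 1 and k >= 1");
        }
        double eta = sampleBeta(alpha + 1.0, n, random);
        double rate = priorRate - Math.log(eta);
        // Mixture odds pi / (1 - pi) = (a + k - 1) / (n (b - ln eta))
        double odds = (priorShape + k - 1.0) / (n * rate);
        double pi = odds / (1.0 + odds);
        double shape = (random.nextDouble() < pi) ? priorShape + k : priorShape + k - 1.0;
        return sampleGamma(shape, rate, random);
    }

    public static void main(String[] args) {
        Random random = new Random(0);
        double alpha = 1.0;
        for (int step = 0; step < 5; step++) {
            alpha = updateAlpha(alpha, 4, 50, 1.0, 1.0, random);
            System.out.println("alpha=" + alpha);
        }
    }
}
```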

createInitialLearnedObject

public DirichletProcessMixtureModel.Sample<ObservationType> createInitialLearnedObject()
Description copied from class: AbstractMarkovChainMonteCarlo
Creates the initial parameters from which to start the Markov chain.

Specified by:
createInitialLearnedObject in class AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
Returns:
initial parameters from which to start the Markov chain.
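One plausible way to start the chain, assumed here rather than taken from the library, is to scatter the observations uniformly at random over numInitialClusters clusters and drop any that end up empty. The sketch works on raw observation lists only; building the actual Sample object is out of scope, and the names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical initialization sketch; not the library's createInitialLearnedObject.
class InitialAssignment {
    /** Uniform random assignment to numInitialClusters clusters, empty ones removed. */
    static <T> List<List<T>> initialize(List<T> data, int numInitialClusters, Random random) {
        if (numInitialClusters < 1) {
            throw new IllegalArgumentException("need at least one cluster");
        }
        List<List<T>> clusters = new ArrayList<>();
        for (int k = 0; k < numInitialClusters; k++) {
            clusters.add(new ArrayList<>());
        }
        for (T x : data) {
            clusters.get(random.nextInt(numInitialClusters)).add(x);
        }
        // Every starting cluster should hold at least one observation.
        clusters.removeIf(List::isEmpty);
        return clusters;
    }

    public static void main(String[] args) {
        List<Integer> data = List.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        System.out.println(initialize(data, 3, new Random(5)));
    }
}
```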

getUpdater

public DirichletProcessMixtureModel.Updater<ObservationType> getUpdater()
Getter for updater

Returns:
Creates the clusters and predictive prior distributions

setUpdater

public void setUpdater(DirichletProcessMixtureModel.Updater<ObservationType> updater)
Setter for updater

Parameters:
updater - Creates the clusters and predictive prior distributions

getNumInitialClusters

public int getNumInitialClusters()
Getter for numInitialClusters

Returns:
Number of clusters to initialize

setNumInitialClusters

public void setNumInitialClusters(int numInitialClusters)
Setter for numInitialClusters

Parameters:
numInitialClusters - Number of clusters to initialize

getReestimateAlpha

public boolean getReestimateAlpha()
Getter for reestimateAlpha

Returns:
Flag to automatically re-estimate the alpha parameter

setReestimateAlpha

public void setReestimateAlpha(boolean reestimateAlpha)
Setter for reestimateAlpha

Parameters:
reestimateAlpha - Flag to automatically re-estimate the alpha parameter

getInitialAlpha

public double getInitialAlpha()
Getter for initialAlpha

Returns:
Initial value of alpha, the concentration parameter of the Dirichlet Process

setInitialAlpha

public void setInitialAlpha(double initialAlpha)
Setter for initialAlpha

Parameters:
initialAlpha - Initial value of alpha, the concentration parameter of the Dirichlet Process