gov.sandia.cognition.statistics.bayesian
Class ParallelDirichletProcessMixtureModel<ObservationType>

java.lang.Object
  extended by gov.sandia.cognition.util.AbstractCloneableSerializable
      extended by gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
          extended by gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm<ResultType>
              extended by gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner<Collection<? extends ObservationType>,DataDistribution<ParameterType>>
                  extended by gov.sandia.cognition.statistics.bayesian.AbstractMarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>
                      extended by gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel<ObservationType>
                          extended by gov.sandia.cognition.statistics.bayesian.ParallelDirichletProcessMixtureModel<ObservationType>
Type Parameters:
ObservationType - Type of observations handled by the algorithm
All Implemented Interfaces:
AnytimeAlgorithm<DataDistribution<DirichletProcessMixtureModel.Sample<ObservationType>>>, IterativeAlgorithm, ParallelAlgorithm, StoppableAlgorithm, AnytimeBatchLearner<Collection<? extends ObservationType>,DataDistribution<DirichletProcessMixtureModel.Sample<ObservationType>>>, BatchLearner<Collection<? extends ObservationType>,DataDistribution<DirichletProcessMixtureModel.Sample<ObservationType>>>, BayesianEstimator<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>,DataDistribution<DirichletProcessMixtureModel.Sample<ObservationType>>>, MarkovChainMonteCarlo<ObservationType,DirichletProcessMixtureModel.Sample<ObservationType>>, CloneableSerializable, Randomized, Serializable, Cloneable

public class ParallelDirichletProcessMixtureModel<ObservationType>
extends DirichletProcessMixtureModel<ObservationType>
implements ParallelAlgorithm

A parallelized version of vanilla Dirichlet Process Mixture Model (DPMM) learning. Specifically, this class parallelizes two steps of each iteration: the assignment of observations to clusters, and the Gibbs-sampling update of each cluster from its constituent observations.
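Both parallel steps follow the same fan-out/join shape: split the work into independent tasks, submit them to a ThreadPoolExecutor, and wait for all of them before moving on. The sketch below shows that generic pattern with plain java.util.concurrent. It is not the Foundry's implementation; the class name ChunkedParallelMap and the "one chunk per thread, sized by getMaximumPoolSize()" policy are hypothetical choices made for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

public class ChunkedParallelMap {

    /** Applies fn to every item, splitting the input into roughly one chunk per thread. */
    public static <T, R> List<R> parallelMap(
            List<T> items, Function<T, R> fn, ThreadPoolExecutor pool)
            throws InterruptedException, ExecutionException {
        int numThreads = pool.getMaximumPoolSize();
        int chunkSize = (items.size() + numThreads - 1) / numThreads;
        List<Callable<List<R>>> tasks = new ArrayList<>();
        for (int start = 0; start < items.size(); start += chunkSize) {
            List<T> chunk = items.subList(start, Math.min(start + chunkSize, items.size()));
            tasks.add(() -> {
                List<R> out = new ArrayList<>(chunk.size());
                for (T item : chunk) {
                    out.add(fn.apply(item));
                }
                return out;
            });
        }
        // invokeAll blocks until every task has finished (a barrier), and
        // returns the futures in the same order as the submitted tasks.
        List<R> results = new ArrayList<>(items.size());
        for (Future<List<R>> f : pool.invokeAll(tasks)) {
            results.addAll(f.get());
        }
        return results;
    }

    public static void main(String[] args) throws Exception {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                4, 4, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
        try {
            List<Integer> xs = List.of(1, 2, 3, 4, 5, 6, 7);
            System.out.println(parallelMap(xs, x -> x * x, pool));
        } finally {
            pool.shutdown();
        }
    }
}
```

Because invokeAll acts as a barrier, the assignment step is guaranteed to complete before any cluster is updated, which is the ordering a Gibbs sweep requires.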

Since:
3.0
Author:
Kevin R. Dixon
See Also:
Serialized Form

Nested Class Summary
protected  class ParallelDirichletProcessMixtureModel.ClusterUpdaterTask
          Task that updates the values of the clusters for Gibbs sampling
static class ParallelDirichletProcessMixtureModel.DPMMAssignments
          Assignments from the DPMM
protected  class ParallelDirichletProcessMixtureModel.ObservationAssignmentTask
          Task that assigns observations to cluster indices
 
Nested classes/interfaces inherited from class gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel
DirichletProcessMixtureModel.DPMMCluster<ObservationType>, DirichletProcessMixtureModel.DPMMLogConditional, DirichletProcessMixtureModel.MultivariateMeanCovarianceUpdater, DirichletProcessMixtureModel.MultivariateMeanUpdater, DirichletProcessMixtureModel.Sample<ObservationType>, DirichletProcessMixtureModel.Updater<ObservationType>
 
Field Summary
protected  ArrayList<ParallelDirichletProcessMixtureModel.ObservationAssignmentTask> assignmentTasks
          Tasks that assign observations to clusters
protected  ArrayList<ParallelDirichletProcessMixtureModel.ClusterUpdaterTask> clusterUpdaterTasks
          Tasks that update the values of the clusters for Gibbs sampling
 
Fields inherited from class gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel
alphaInverseSampler, clusterWeights, conditionalPriorPredictive, DEFAULT_ALPHA, DEFAULT_NUM_INITIAL_CLUSTERS, DEFAULT_REESTIMATE_ALPHA, etaSampler, initialAlpha, reestimateAlpha, updater
 
Fields inherited from class gov.sandia.cognition.statistics.bayesian.AbstractMarkovChainMonteCarlo
currentParameter, DEFAULT_NUM_SAMPLES, previousParameter, random
 
Fields inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
data, keepGoing
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
maxIterations
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
DEFAULT_ITERATION, iteration
 
Constructor Summary
ParallelDirichletProcessMixtureModel()
          Creates a new instance of ParallelDirichletProcessMixtureModel
 
Method Summary
protected  ArrayList<Collection<ObservationType>> assignObservationsToClusters(int K, DirichletProcessMixtureModel.DPMMLogConditional logConditional)
          Assigns observations to each of the K clusters, plus the as-yet-uncreated new cluster
 int getNumThreads()
          Gets the number of threads in the thread pool.
 ThreadPoolExecutor getThreadPool()
          Gets the thread pool for the algorithm to use.
 void setThreadPool(ThreadPoolExecutor threadPool)
          Sets the thread pool for the algorithm to use.
protected  ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> updateClusters(ArrayList<Collection<ObservationType>> clusterAssignments)
          Updates each cluster according to the data assigned to it
 
Methods inherited from class gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel
assignObservationToCluster, clone, createCluster, createInitialLearnedObject, getInitialAlpha, getNumInitialClusters, getReestimateAlpha, getUpdater, mcmcUpdate, setInitialAlpha, setNumInitialClusters, setReestimateAlpha, setUpdater, updateAlpha
 
Methods inherited from class gov.sandia.cognition.statistics.bayesian.AbstractMarkovChainMonteCarlo
cleanupAlgorithm, getBurnInIterations, getCurrentParameter, getIterationsPerSample, getPreviousParameter, getRandom, getResult, initializeAlgorithm, setBurnInIterations, setCurrentParameter, setIterationsPerSample, setRandom, setResult, step
 
Methods inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
getData, getKeepGoing, learn, setData, setKeepGoing, stop
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
getMaxIterations, isResultValid, setMaxIterations
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
 
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface gov.sandia.cognition.util.CloneableSerializable
clone
 
Methods inherited from interface gov.sandia.cognition.learning.algorithm.BatchLearner
learn
 
Methods inherited from interface gov.sandia.cognition.algorithm.AnytimeAlgorithm
getMaxIterations, setMaxIterations
 
Methods inherited from interface gov.sandia.cognition.algorithm.IterativeAlgorithm
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener
 
Methods inherited from interface gov.sandia.cognition.algorithm.StoppableAlgorithm
isResultValid, stop
 

Field Detail

assignmentTasks

protected transient ArrayList<ParallelDirichletProcessMixtureModel.ObservationAssignmentTask> assignmentTasks
Tasks that assign observations to clusters


clusterUpdaterTasks

protected transient ArrayList<ParallelDirichletProcessMixtureModel.ClusterUpdaterTask> clusterUpdaterTasks
Tasks that update the values of the clusters for Gibbs sampling

Constructor Detail

ParallelDirichletProcessMixtureModel

public ParallelDirichletProcessMixtureModel()
Creates a new instance of ParallelDirichletProcessMixtureModel

Method Detail

getNumThreads

public int getNumThreads()
Description copied from interface: ParallelAlgorithm
Gets the number of threads in the thread pool.

Specified by:
getNumThreads in interface ParallelAlgorithm
Returns:
Number of threads in the thread pool

getThreadPool

public ThreadPoolExecutor getThreadPool()
Description copied from interface: ParallelAlgorithm
Gets the thread pool for the algorithm to use.

Specified by:
getThreadPool in interface ParallelAlgorithm
Returns:
Thread pool used for parallelization.

setThreadPool

public void setThreadPool(ThreadPoolExecutor threadPool)
Description copied from interface: ParallelAlgorithm
Sets the thread pool for the algorithm to use.

Specified by:
setThreadPool in interface ParallelAlgorithm
Parameters:
threadPool - Thread pool used for parallelization.
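setThreadPool expects a ThreadPoolExecutor rather than a generic ExecutorService. A minimal sketch of building a fixed-size pool with the standard library follows. The helper name createFixedPool is hypothetical, and the assumption that getNumThreads() reports the pool's maximum size is not confirmed by this page.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class ThreadPoolSetup {

    /** Creates a fixed-size pool: core size == max size, with an unbounded work queue. */
    public static ThreadPoolExecutor createFixedPool(int numThreads) {
        if (numThreads < 1) {
            throw new IllegalArgumentException("numThreads must be >= 1");
        }
        return new ThreadPoolExecutor(
                numThreads, numThreads,
                0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<Runnable>());
    }

    public static void main(String[] args) {
        int n = Runtime.getRuntime().availableProcessors();
        ThreadPoolExecutor pool = createFixedPool(n);
        // With the Foundry on the classpath, the pool would then be handed over
        // via the documented setter, e.g. (hypothetical variable name):
        //   dpmm.setThreadPool(pool);
        System.out.println("threads = " + pool.getMaximumPoolSize());
        pool.shutdown();
    }
}
```

Setting the core and maximum sizes equal matters here: with an unbounded queue, a ThreadPoolExecutor never grows beyond its core size, so a larger maximum would overstate the real number of worker threads.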

assignObservationsToClusters

protected ArrayList<Collection<ObservationType>> assignObservationsToClusters(int K,
                                                                              DirichletProcessMixtureModel.DPMMLogConditional logConditional)
Description copied from class: DirichletProcessMixtureModel
Assigns observations to each of the K clusters, plus the as-yet-uncreated new cluster

Overrides:
assignObservationsToClusters in class DirichletProcessMixtureModel<ObservationType>
Parameters:
K - Number of clusters
logConditional - The log of the conditional.
Returns:
Assignments from observations to clusters
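To make the "K clusters plus the as-yet-uncreated new cluster" contract concrete, here is a sketch of a parallel assignment sweep under the textbook Chinese-restaurant-process conditional: an observation x joins existing cluster k with weight proportional to n_k * p(x | theta_k), and the new cluster (bucket K) with weight proportional to alpha * p(x), the prior predictive. Everything model-specific is an assumption for illustration: 1-D data, Gaussian likelihood with known variance, Normal prior on cluster means, and cluster means and counts held fixed during the sweep (which is what makes per-observation draws independent and safe to parallelize). Whether the library's ObservationAssignmentTask works this way is not stated on this page.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class ParallelAssignmentSketch {

    /** Log-density of a 1-D Gaussian. */
    static double logNormal(double x, double mean, double variance) {
        double d = x - mean;
        return -0.5 * (Math.log(2.0 * Math.PI * variance) + d * d / variance);
    }

    /** Samples an index from unnormalized log-weights using the log-sum-exp trick. */
    static int sampleLogWeights(double[] logW, Random rng) {
        double max = Double.NEGATIVE_INFINITY;
        for (double w : logW) max = Math.max(max, w);
        double[] p = new double[logW.length];
        double total = 0.0;
        for (int k = 0; k < logW.length; k++) {
            p[k] = Math.exp(logW[k] - max);
            total += p[k];
        }
        double u = rng.nextDouble() * total;
        for (int k = 0; k < p.length; k++) {
            u -= p[k];
            if (u < 0.0) return k;
        }
        return p.length - 1; // guard against floating-point round-off
    }

    /** Returns K + 1 buckets: one per existing cluster, plus bucket K for a new cluster. */
    static ArrayList<Collection<Double>> assign(
            List<Double> data, double[] means, int[] counts, double alpha,
            double variance, double priorMean, double priorVariance,
            ThreadPoolExecutor pool, long seed)
            throws InterruptedException, ExecutionException {
        final int K = means.length;
        int numTasks = pool.getMaximumPoolSize();
        int chunk = (data.size() + numTasks - 1) / numTasks;
        List<Callable<int[]>> tasks = new ArrayList<>();
        for (int start = 0, t = 0; start < data.size(); start += chunk, t++) {
            final int lo = start, hi = Math.min(start + chunk, data.size());
            final Random rng = new Random(seed + t); // one RNG per task: no shared mutable state
            tasks.add(() -> {
                int[] z = new int[hi - lo];
                double[] logW = new double[K + 1];
                for (int i = lo; i < hi; i++) {
                    double x = data.get(i);
                    for (int k = 0; k < K; k++) {
                        logW[k] = Math.log(counts[k]) + logNormal(x, means[k], variance);
                    }
                    // New cluster: alpha times the prior predictive N(priorMean, variance + priorVariance).
                    logW[K] = Math.log(alpha) + logNormal(x, priorMean, variance + priorVariance);
                    z[i - lo] = sampleLogWeights(logW, rng);
                }
                return z;
            });
        }
        ArrayList<Collection<Double>> buckets = new ArrayList<>();
        for (int k = 0; k <= K; k++) buckets.add(new ArrayList<>());
        int offset = 0;
        for (Future<int[]> f : pool.invokeAll(tasks)) { // futures arrive in task order
            int[] z = f.get();
            for (int j = 0; j < z.length; j++) buckets.get(z[j]).add(data.get(offset + j));
            offset += z.length;
        }
        return buckets;
    }

    public static void main(String[] args) throws Exception {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                2, 2, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
        try {
            List<Double> data = List.of(0.1, -0.3, 0.2, 9.7, 10.4, 10.1);
            ArrayList<Collection<Double>> buckets = assign(data,
                    new double[] {0.0, 10.0}, new int[] {3, 3}, 1.0, 1.0, 0.0, 100.0, pool, 42L);
            for (int k = 0; k < buckets.size(); k++) {
                System.out.println("bucket " + k + ": " + buckets.get(k));
            }
        } finally {
            pool.shutdown();
        }
    }
}
```

Working in log space and subtracting the maximum before exponentiating keeps the draw stable even when the likelihoods underflow as raw probabilities.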

updateClusters

protected ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> updateClusters(ArrayList<Collection<ObservationType>> clusterAssignments)
Description copied from class: DirichletProcessMixtureModel
Updates each cluster according to the data assigned to it

Overrides:
updateClusters in class DirichletProcessMixtureModel<ObservationType>
Parameters:
clusterAssignments - Observations assigned to each cluster
Returns:
Updated clusters, each containing an updated parameter estimate and weighted by the number of observations assigned to it
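Each cluster's update depends only on its own members, so the per-cluster Gibbs draws can run as independent tasks. The sketch below assumes the same hypothetical 1-D model as a conjugate example: Gaussian likelihood with known variance and a Normal prior on the mean. Each non-empty bucket yields a cluster whose mean is drawn from the Normal posterior and whose weight is its member count, and empty buckets are dropped. The Cluster type, the seeding scheme, and the dropping of empty buckets are illustrative assumptions, not the Foundry's DPMMCluster or its updater.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class ParallelClusterUpdateSketch {

    /** Hypothetical cluster: a sampled mean plus a weight equal to its member count. */
    public static final class Cluster {
        public final double mean;
        public final int weight;
        Cluster(double mean, int weight) { this.mean = mean; this.weight = weight; }
    }

    /** Draws a cluster mean from its conjugate Normal posterior (known likelihood variance). */
    static Cluster gibbsUpdate(Collection<Double> members, double variance,
                               double priorMean, double priorVariance, Random rng) {
        int n = members.size();
        double sum = 0.0;
        for (double x : members) sum += x;
        double precision = 1.0 / priorVariance + n / variance;          // posterior precision
        double postMean = (priorMean / priorVariance + sum / variance) / precision;
        double draw = postMean + rng.nextGaussian() * Math.sqrt(1.0 / precision);
        return new Cluster(draw, n);
    }

    /** Updates every non-empty cluster as its own task; empty buckets are dropped. */
    static ArrayList<Cluster> updateClusters(List<? extends Collection<Double>> assignments,
            double variance, double priorMean, double priorVariance,
            ThreadPoolExecutor pool, long seed) throws Exception {
        List<Callable<Cluster>> tasks = new ArrayList<>();
        for (int k = 0; k < assignments.size(); k++) {
            final Collection<Double> members = assignments.get(k);
            if (members.isEmpty()) continue;
            final Random rng = new Random(seed + k); // per-task RNG, no shared state
            tasks.add(() -> gibbsUpdate(members, variance, priorMean, priorVariance, rng));
        }
        ArrayList<Cluster> clusters = new ArrayList<>(tasks.size());
        for (Future<Cluster> f : pool.invokeAll(tasks)) clusters.add(f.get());
        return clusters;
    }

    public static void main(String[] args) throws Exception {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                2, 2, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
        try {
            List<List<Double>> buckets = List.of(
                    List.of(0.1, -0.2, 0.3), List.of(9.8, 10.1), List.of());
            for (Cluster c : updateClusters(buckets, 1.0, 0.0, 100.0, pool, 7L)) {
                System.out.printf("mean=%.3f weight=%d%n", c.mean, c.weight);
            }
        } finally {
            pool.shutdown();
        }
    }
}
```

Giving each task its own Random avoids contention on a shared generator and makes a run reproducible for a fixed seed, regardless of how the pool schedules the tasks.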