gov.sandia.cognition.text.topic
Class ParallelLatentDirichletAllocationVectorGibbsSampler

java.lang.Object
  extended by gov.sandia.cognition.util.AbstractCloneableSerializable
      extended by gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
          extended by gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm<ResultType>
              extended by gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner<Collection<? extends Vectorizable>,LatentDirichletAllocationVectorGibbsSampler.Result>
                  extended by gov.sandia.cognition.text.topic.LatentDirichletAllocationVectorGibbsSampler
                      extended by gov.sandia.cognition.text.topic.ParallelLatentDirichletAllocationVectorGibbsSampler
All Implemented Interfaces:
AnytimeAlgorithm<LatentDirichletAllocationVectorGibbsSampler.Result>, IterativeAlgorithm, ParallelAlgorithm, StoppableAlgorithm, AnytimeBatchLearner<Collection<? extends Vectorizable>,LatentDirichletAllocationVectorGibbsSampler.Result>, BatchLearner<Collection<? extends Vectorizable>,LatentDirichletAllocationVectorGibbsSampler.Result>, CloneableSerializable, Randomized, Serializable, Cloneable

public class ParallelLatentDirichletAllocationVectorGibbsSampler
extends LatentDirichletAllocationVectorGibbsSampler
implements ParallelAlgorithm

A parallel implementation of LatentDirichletAllocationVectorGibbsSampler. It runs the Gibbs sampling for different documents concurrently on a thread pool.
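The per-document fan-out described above can be sketched with only the JDK. One task per document is submitted to a ThreadPoolExecutor, and the sweep blocks until every task has finished. This is a minimal illustration, not this class's actual implementation. The class name PerDocumentFanOut and the token-counting task body are hypothetical stand-ins for the real per-document Gibbs update.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

// Hypothetical sketch: fan out one task per document and wait for all of them.
public class PerDocumentFanOut {

    // Creates a fixed-size pool (core size == max size, unbounded queue).
    static ThreadPoolExecutor newFixedPool(int threads) {
        return new ThreadPoolExecutor(threads, threads, 0L, TimeUnit.MILLISECONDS,
            new LinkedBlockingQueue<Runnable>());
    }

    // Runs one task per document and returns the per-document results in order.
    // Placeholder work: each task sums the term counts of its document; a real
    // sampler would resample that document's topic assignments instead.
    static int[] processDocuments(ThreadPoolExecutor pool, int[][] documents) {
        List<Callable<Integer>> tasks = new ArrayList<>();
        for (int[] document : documents) {
            tasks.add(() -> Arrays.stream(document).sum());
        }
        try {
            // invokeAll blocks until every task completes, so a new sweep
            // cannot start while documents from this one are still running.
            List<Future<Integer>> futures = pool.invokeAll(tasks);
            int[] results = new int[documents.length];
            for (int i = 0; i < results.length; i++) {
                results[i] = futures.get(i).get();
            }
            return results;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while processing documents", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("A document task failed", e.getCause());
        }
    }

    public static void main(String[] args) {
        ThreadPoolExecutor pool = newFixedPool(4);
        try {
            int[][] documents = { {2, 0, 1}, {0, 5, 0}, {1, 1, 1} };
            System.out.println(Arrays.toString(processDocuments(pool, documents))); // [3, 5, 3]
        } finally {
            pool.shutdown();
        }
    }
}
```

Waiting on invokeAll keeps the sketch's sweeps synchronized: every document is processed once per sweep before the next sweep begins.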

Since:
3.3.2
Author:
Jason Shepherd
See Also:
Serialized Form

Nested Class Summary
protected  class ParallelLatentDirichletAllocationVectorGibbsSampler.DocumentSampleTask
          A document sampling task
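A document sampling task presumably resamples the topic of each term occurrence in one document. The sketch below shows the standard collapsed Gibbs conditional for LDA that such a task would draw from: p(z = k | rest) ∝ (n[d][k] + alpha) * (n[k][w] + beta) / (n[k] + V * beta), where V is the number of terms. The count arrays are named after the inherited fields documentTopicCount, topicTermCount, and topicTermSum. It is an assumption that the inherited sampleTopic method computes an equivalent distribution. The class CollapsedGibbsConditional and its methods are hypothetical.

```java
import java.util.Random;

// Hypothetical sketch of the collapsed Gibbs conditional for LDA.
public class CollapsedGibbsConditional {

    // Unnormalized weight of each topic k for one occurrence of `term` in `document`:
    //   (documentTopicCount[d][k] + alpha)
    //     * (topicTermCount[k][w] + beta) / (topicTermSum[k] + termCount * beta)
    // The counts must already exclude the occurrence being resampled.
    static double[] topicWeights(int document, int term, int[][] documentTopicCount,
            int[][] topicTermCount, int[] topicTermSum,
            double alpha, double beta, int termCount) {
        int topicCount = topicTermSum.length;
        double[] weights = new double[topicCount];
        for (int k = 0; k < topicCount; k++) {
            double documentPart = documentTopicCount[document][k] + alpha;
            double termPart = (topicTermCount[k][term] + beta)
                / (topicTermSum[k] + termCount * beta);
            weights[k] = documentPart * termPart;
        }
        return weights;
    }

    // Draws a topic index with probability proportional to its weight.
    static int sampleTopic(double[] weights, Random random) {
        double total = 0.0;
        for (double w : weights) {
            total += w;
        }
        double choice = random.nextDouble() * total;
        double cumulative = 0.0;
        for (int k = 0; k < weights.length; k++) {
            cumulative += weights[k];
            if (choice < cumulative) {
                return k;
            }
        }
        return weights.length - 1; // guards against floating-point round-off
    }

    public static void main(String[] args) {
        // Two topics, two distinct terms, one document; alpha = beta = 1.
        int[][] documentTopicCount = { {1, 0} };
        int[][] topicTermCount = { {1, 0}, {0, 0} };
        int[] topicTermSum = { 1, 0 };
        double[] weights = topicWeights(0, 0, documentTopicCount, topicTermCount,
            topicTermSum, 1.0, 1.0, 2);
        // Weights are 4/3 for topic 0 and 1/2 for topic 1.
        System.out.println(weights[0] + " " + weights[1]);
        System.out.println("sampled topic: " + sampleTopic(weights, new Random(42)));
    }
}
```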
 
Nested classes/interfaces inherited from class gov.sandia.cognition.text.topic.LatentDirichletAllocationVectorGibbsSampler
LatentDirichletAllocationVectorGibbsSampler.Result
 
Field Summary
 
Fields inherited from class gov.sandia.cognition.text.topic.LatentDirichletAllocationVectorGibbsSampler
alpha, beta, burnInIterations, DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_BURN_IN_ITERATIONS, DEFAULT_ITERATIONS_PER_SAMPLE, DEFAULT_MAX_ITERATIONS, DEFAULT_TOPIC_COUNT, documentCount, documentTopicCount, documentTopicSum, iterationsPerSample, occurrenceTopicAssignments, random, result, sampleCount, termCount, topicCount, topicTermCount, topicTermSum
 
Fields inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
data, keepGoing
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
maxIterations
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
DEFAULT_ITERATION, iteration
 
Constructor Summary
ParallelLatentDirichletAllocationVectorGibbsSampler()
          Creates a new ParallelLatentDirichletAllocationVectorGibbsSampler with default parameters.
ParallelLatentDirichletAllocationVectorGibbsSampler(int topicCount, double alpha, double beta, int maxIterations, int burnInIterations, int iterationsPerSample, Random random)
          Creates a new ParallelLatentDirichletAllocationVectorGibbsSampler with the given parameters.
 
Method Summary
protected  void cleanupAlgorithm()
          Called to clean up the learning algorithm's state after learning has finished.
 int getNumThreads()
          Gets the number of threads in the thread pool.
 ThreadPoolExecutor getThreadPool()
          Gets the thread pool for the algorithm to use.
 void setThreadPool(ThreadPoolExecutor threadPool)
          Sets the thread pool for the algorithm to use.
protected  boolean step()
          Called to take a single step of the learning algorithm.
 
Methods inherited from class gov.sandia.cognition.text.topic.LatentDirichletAllocationVectorGibbsSampler
getAlpha, getBeta, getBurnInIterations, getDocumentCount, getIterationsPerSample, getRandom, getResult, getTermCount, getTopicCount, initializeAlgorithm, readParameters, sampleTopic, setAlpha, setBeta, setBurnInIterations, setIterationsPerSample, setRandom, setTopicCount
 
Methods inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
clone, getData, getKeepGoing, learn, setData, setKeepGoing, stop
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
getMaxIterations, isResultValid, setMaxIterations
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
 
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface gov.sandia.cognition.util.CloneableSerializable
clone
 
Methods inherited from interface gov.sandia.cognition.algorithm.AnytimeAlgorithm
getMaxIterations, setMaxIterations
 
Methods inherited from interface gov.sandia.cognition.algorithm.IterativeAlgorithm
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener
 
Methods inherited from interface gov.sandia.cognition.algorithm.StoppableAlgorithm
isResultValid
 

Constructor Detail

ParallelLatentDirichletAllocationVectorGibbsSampler

public ParallelLatentDirichletAllocationVectorGibbsSampler()
Creates a new ParallelLatentDirichletAllocationVectorGibbsSampler with default parameters.


ParallelLatentDirichletAllocationVectorGibbsSampler

public ParallelLatentDirichletAllocationVectorGibbsSampler(int topicCount,
                                                           double alpha,
                                                           double beta,
                                                           int maxIterations,
                                                           int burnInIterations,
                                                           int iterationsPerSample,
                                                           Random random)
Creates a new ParallelLatentDirichletAllocationVectorGibbsSampler with the given parameters.

Parameters:
topicCount - The number of topics for the algorithm to create. Must be positive.
alpha - The alpha parameter controlling the Dirichlet distribution for the document-topic probabilities. It acts as a prior weight assigned to the document-topic counts. Must be positive.
beta - The beta parameter controlling the Dirichlet distribution for the topic-term probabilities. It acts as a prior weight assigned to the topic-term counts.
maxIterations - The maximum number of iterations to run for. Must be positive.
burnInIterations - The number of burn-in iterations for the Markov Chain Monte Carlo algorithm to run before sampling begins.
iterationsPerSample - The number of iterations of the Markov Chain Monte Carlo algorithm between samples (after the burn-in iterations).
random - The random number generator to use.
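To make the interplay of maxIterations, burnInIterations, and iterationsPerSample concrete, the sketch below lists the iterations at which a sample would be read under one plausible convention: iterations are numbered from 0, a sample is taken at the first post-burn-in iteration, and then one every iterationsPerSample iterations. The exact boundary convention used by the library is an assumption here and may differ by one. The class SampleSchedule is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of an MCMC burn-in/thinning schedule (0-based iterations).
public class SampleSchedule {

    // Returns the iterations in [0, maxIterations) at which a sample is read:
    // burnInIterations, burnInIterations + iterationsPerSample, ...
    static List<Integer> sampleIterations(int maxIterations, int burnInIterations,
            int iterationsPerSample) {
        if (maxIterations <= 0 || burnInIterations < 0 || iterationsPerSample <= 0) {
            throw new IllegalArgumentException("invalid schedule parameters");
        }
        List<Integer> iterations = new ArrayList<>();
        for (int i = burnInIterations; i < maxIterations; i += iterationsPerSample) {
            iterations.add(i);
        }
        return iterations;
    }

    public static void main(String[] args) {
        // 10 iterations, 4 of burn-in, then a sample every 2 iterations.
        System.out.println(sampleIterations(10, 4, 2)); // [4, 6, 8]
    }
}
```

With these settings only 3 of the 10 iterations contribute samples. A longer burn-in discards more of the chain's early, unconverged states, while a larger iterationsPerSample reduces the correlation between successive samples at the cost of collecting fewer of them.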

Method Detail

step

protected boolean step()
Description copied from class: AbstractAnytimeBatchLearner
Called to take a single step of the learning algorithm.

Overrides:
step in class LatentDirichletAllocationVectorGibbsSampler
Returns:
True if another step can be taken and false if the algorithm should halt.

cleanupAlgorithm

protected void cleanupAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to clean up the learning algorithm's state after learning has finished.

Overrides:
cleanupAlgorithm in class LatentDirichletAllocationVectorGibbsSampler
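The contract of step() and cleanupAlgorithm() can be illustrated with a simplified driver loop: steps are taken until step() returns false or maxIterations is reached, and cleanup runs once afterwards. This is a hypothetical sketch of the general anytime-learner pattern, not a copy of AbstractAnytimeBatchLearner.learn. The class names and the point at which the iteration counter is incremented are assumptions.

```java
// Hypothetical, simplified anytime learner illustrating the step()/cleanup contract.
public class StepContractDemo {

    abstract static class SketchAnytimeLearner {
        protected int iteration = 0;
        protected final int maxIterations;

        SketchAnytimeLearner(int maxIterations) {
            this.maxIterations = maxIterations;
        }

        // Returns true if another step can be taken, false if the algorithm should halt.
        protected abstract boolean step();

        // Called once after the loop ends, whether it stopped early or hit maxIterations.
        protected abstract void cleanupAlgorithm();

        public final void run() {
            boolean keepGoing = true;
            while (keepGoing && iteration < maxIterations) {
                iteration++;
                keepGoing = step();
            }
            cleanupAlgorithm();
        }
    }

    // Example learner that declares itself finished after `stopAfter` steps.
    static class CountingLearner extends SketchAnytimeLearner {
        final int stopAfter;
        int stepsTaken = 0;
        boolean cleanedUp = false;

        CountingLearner(int maxIterations, int stopAfter) {
            super(maxIterations);
            this.stopAfter = stopAfter;
        }

        @Override
        protected boolean step() {
            stepsTaken++;
            return stepsTaken < stopAfter;
        }

        @Override
        protected void cleanupAlgorithm() {
            cleanedUp = true;
        }
    }

    public static void main(String[] args) {
        CountingLearner early = new CountingLearner(10, 3);
        early.run();
        System.out.println(early.stepsTaken + " " + early.cleanedUp); // 3 true
        CountingLearner capped = new CountingLearner(2, 5);
        capped.run();
        System.out.println(capped.stepsTaken + " " + capped.cleanedUp); // 2 true
    }
}
```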

getThreadPool

public ThreadPoolExecutor getThreadPool()
Description copied from interface: ParallelAlgorithm
Gets the thread pool for the algorithm to use.

Specified by:
getThreadPool in interface ParallelAlgorithm
Returns:
Thread pool used for parallelization.

setThreadPool

public void setThreadPool(ThreadPoolExecutor threadPool)
Description copied from interface: ParallelAlgorithm
Sets the thread pool for the algorithm to use.

Specified by:
setThreadPool in interface ParallelAlgorithm
Parameters:
threadPool - Thread pool used for parallelization.

getNumThreads

public int getNumThreads()
Description copied from interface: ParallelAlgorithm
Gets the number of threads in the thread pool.

Specified by:
getNumThreads in interface ParallelAlgorithm
Returns:
Number of threads in the thread pool
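A class implementing ParallelAlgorithm exposes its pool through getThreadPool, setThreadPool, and getNumThreads. The sketch below shows one way such accessors might be written with the JDK alone. Two behaviors are assumptions rather than documented behavior of this class: a default pool sized to the number of available processors is created when none has been set, and getMaximumPoolSize() is reported as the thread count. PooledComponent is a hypothetical name.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

// Hypothetical holder for a thread pool, mirroring the get/set/getNumThreads trio.
public class PooledComponent {
    private ThreadPoolExecutor threadPool;

    // Returns the pool, lazily creating one thread per available processor
    // if none has been set (an assumed default).
    public ThreadPoolExecutor getThreadPool() {
        if (threadPool == null) {
            int n = Runtime.getRuntime().availableProcessors();
            threadPool = new ThreadPoolExecutor(n, n, 0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<Runnable>());
        }
        return threadPool;
    }

    public void setThreadPool(ThreadPoolExecutor threadPool) {
        this.threadPool = threadPool;
    }

    // Reports the pool's maximum size as its thread count (assumed convention).
    public int getNumThreads() {
        return getThreadPool().getMaximumPoolSize();
    }

    public static void main(String[] args) {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(3, 3, 0L,
            TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>());
        PooledComponent component = new PooledComponent();
        component.setThreadPool(pool);
        System.out.println(component.getNumThreads()); // 3
        // In this sketch, whoever created the pool is responsible for shutting it down.
        pool.shutdown();
    }
}
```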