gov.sandia.cognition.learning.algorithm.clustering
Class KMeansClusterer<DataType,ClusterType extends Cluster<DataType>>

java.lang.Object
  extended by gov.sandia.cognition.util.AbstractCloneableSerializable
      extended by gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
          extended by gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm<ResultType>
              extended by gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner<Collection<? extends DataType>,Collection<ClusterType>>
                  extended by gov.sandia.cognition.learning.algorithm.clustering.KMeansClusterer<DataType,ClusterType>
Type Parameters:
DataType - The type of the data to cluster. This is typically defined by the divergence function used.
ClusterType - The type of Cluster created by the algorithm. This is typically defined by the cluster creator function used.
All Implemented Interfaces:
AnytimeAlgorithm<Collection<ClusterType>>, IterativeAlgorithm, MeasurablePerformanceAlgorithm, StoppableAlgorithm, AnytimeBatchLearner<Collection<? extends DataType>,Collection<ClusterType>>, BatchLearner<Collection<? extends DataType>,Collection<ClusterType>>, BatchClusterer<DataType,ClusterType>, DivergenceFunctionContainer<ClusterType,DataType>, CloneableSerializable, Serializable, Cloneable
Direct Known Subclasses:
KMeansClustererWithRemoval, OptimizedKMeansClusterer, ParallelizedKMeansClusterer

@CodeReviews(reviews={@CodeReview(reviewer="Kevin R. Dixon",date="2008-10-06",changesNeeded=true,comments={"The constructors for this class are not user friendly.","I\'ve been trying to write a test GUI for k-means for over an hour and STILL can\'t figure out the combination of classes to configure the constructor.","Please make a constructor that configures the class with meaningful, user-friendly default arguments."}),@CodeReview(reviewer="Kevin R. Dixon",date="2008-07-22",changesNeeded=false,comments={"Changed the condition to be \'members.size() > 0\' instead of 1 in createClustersFromAssignments()","Cleaned up javadoc.","Code generally looks fine."})})
@PublicationReferences(references={@PublicationReference(author="Wikipedia",title="K-means algorithm",type=WebPage,year=2008,url="http://en.wikipedia.org/wiki/K-means_algorithm"),@PublicationReference(author="Matteo Matteucci",title="A Tutorial on Clustering Algorithms: k-means Demo",type=WebPage,year=2008,url="http://home.dei.polimi.it/matteucc/Clustering/tutorial_html/AppletKM.html")})
public class KMeansClusterer<DataType,ClusterType extends Cluster<DataType>>
extends AbstractAnytimeBatchLearner<Collection<? extends DataType>,Collection<ClusterType>>
implements BatchClusterer<DataType,ClusterType>, MeasurablePerformanceAlgorithm, DivergenceFunctionContainer<ClusterType,DataType>

The KMeansClusterer class implements the standard k-means (k-centroids) clustering algorithm.

Since:
1.0
Author:
Justin Basilico, Kevin R. Dixon
See Also:
Serialized Form

Field Summary
protected  int[] assignments
          The current assignments of elements to clusters.
protected  int[] clusterCounts
          The current number of elements assigned to each cluster.
protected  ArrayList<ClusterType> clusters
          The current set of clusters.
protected  ClusterCreator<ClusterType,DataType> creator
          The cluster creator for creating clusters.
static int DEFAULT_MAX_ITERATIONS
          The default maximum number of iterations is 1000.
static int DEFAULT_NUM_REQUESTED_CLUSTERS
          The default number of requested clusters is 10.
protected  ClusterDivergenceFunction<? super ClusterType,? super DataType> divergenceFunction
          The divergence function used to compare a cluster with a data element.
protected  FixedClusterInitializer<ClusterType,DataType> initializer
          The initializer for the algorithm.
protected  int numRequestedClusters
          The number of clusters requested.
 
Fields inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
data, keepGoing
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
maxIterations
 
Fields inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
DEFAULT_ITERATION, iteration
 
Constructor Summary
KMeansClusterer()
          Creates a new instance of KMeansClusterer with default parameters.
KMeansClusterer(int numRequestedClusters, int maxIterations, FixedClusterInitializer<ClusterType,DataType> initializer, ClusterDivergenceFunction<? super ClusterType,? super DataType> divergenceFunction, ClusterCreator<ClusterType,DataType> creator)
          Creates a new instance of KMeansClusterer using the given parameters.
 
Method Summary
protected  ArrayList<ArrayList<DataType>> assignDataFromIndices()
          Groups the data into one list per cluster, from which the clusters are then estimated.
protected  int[] assignDataToClusters(Collection<? extends DataType> data)
          Creates the cluster assignments given the current locations of the clusters.
protected  void cleanupAlgorithm()
          Called to clean up the learning algorithm's state after learning has finished.
 KMeansClusterer<DataType,ClusterType> clone()
          This makes public the clone method on the Object class and removes the exception that it throws.
protected  void createClustersFromAssignments()
          Creates the set of clusters using the current cluster assignments.
protected  int[] getAssignments()
          Getter for assignments
protected  int getClosestClusterIndex(DataType element)
          Gets the index of the closest cluster for the given element.
protected  ClusterType getCluster(int index)
          Gets the cluster for the given index.
protected  int[] getClusterCounts()
          Getter for clusterCounts
 ArrayList<ClusterType> getClusters()
          Getter for clusters
 ClusterCreator<ClusterType,DataType> getCreator()
          Gets the cluster creator.
 ClusterDivergenceFunction<? super ClusterType,? super DataType> getDivergenceFunction()
          Gets the divergence function used in clustering.
 FixedClusterInitializer<ClusterType,DataType> getInitializer()
          Gets the cluster initializer.
 int getNumChanged()
          Getter for numChanged
protected  int getNumClusters()
          Gets the actual number of clusters that were created.
 int getNumElements()
          Returns the number of elements
 int getNumRequestedClusters()
          Gets the number of clusters that were requested.
 NamedValue<Integer> getPerformance()
          Gets the performance, which is the number changed on the last iteration.
 ArrayList<ClusterType> getResult()
          Gets the current result of the algorithm.
protected  boolean initializeAlgorithm()
          Called to initialize the learning algorithm's state based on the data that is stored in the data field.
protected  boolean setAssignment(int elementIndex, int newClusterIndex)
          Sets the assignment of the given element to the new cluster index, updating the cluster counts as well.
protected  void setClusters(ArrayList<ClusterType> clusters)
          Sets the clusters.
 void setCreator(ClusterCreator<ClusterType,DataType> creator)
          Sets the cluster creator.
 void setData(Collection<? extends DataType> data)
          Sets the data to use for learning.
 void setDivergenceFunction(ClusterDivergenceFunction<? super ClusterType,? super DataType> divergenceFunction)
          Sets the divergence function.
 void setInitializer(FixedClusterInitializer<ClusterType,DataType> initializer)
          Sets the cluster initializer.
protected  void setNumChanged(int numChanged)
          Setter for numChanged
 void setNumRequestedClusters(int numRequestedClusters)
          Sets the number of requested clusters.
protected  boolean step()
          Do a step of the clustering algorithm.
 
Methods inherited from class gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
getData, getKeepGoing, learn, setKeepGoing, stop
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractAnytimeAlgorithm
getMaxIterations, isResultValid, setMaxIterations
 
Methods inherited from class gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm
addIterativeAlgorithmListener, fireAlgorithmEnded, fireAlgorithmStarted, fireStepEnded, fireStepStarted, getIteration, getListeners, removeIterativeAlgorithmListener, setIteration, setListeners
 
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface gov.sandia.cognition.learning.algorithm.BatchLearner
learn
 
Methods inherited from interface gov.sandia.cognition.algorithm.AnytimeAlgorithm
getMaxIterations, setMaxIterations
 
Methods inherited from interface gov.sandia.cognition.algorithm.IterativeAlgorithm
addIterativeAlgorithmListener, getIteration, removeIterativeAlgorithmListener
 
Methods inherited from interface gov.sandia.cognition.algorithm.StoppableAlgorithm
isResultValid
 

Field Detail

DEFAULT_NUM_REQUESTED_CLUSTERS

public static final int DEFAULT_NUM_REQUESTED_CLUSTERS
The default number of requested clusters is 10.

See Also:
Constant Field Values

DEFAULT_MAX_ITERATIONS

public static final int DEFAULT_MAX_ITERATIONS
The default maximum number of iterations is 1000.

See Also:
Constant Field Values

numRequestedClusters

protected int numRequestedClusters
The number of clusters requested.


initializer

protected FixedClusterInitializer<ClusterType,DataType> initializer
The initializer for the algorithm.


divergenceFunction

protected ClusterDivergenceFunction<? super ClusterType,? super DataType> divergenceFunction
The divergence function used to compare a cluster with a data element.


creator

protected ClusterCreator<ClusterType,DataType> creator
The cluster creator for creating clusters.


clusters

protected ArrayList<ClusterType> clusters
The current set of clusters.


assignments

protected int[] assignments
The current assignments of elements to clusters.


clusterCounts

protected int[] clusterCounts
The current number of elements assigned to each cluster.

Constructor Detail

KMeansClusterer

public KMeansClusterer()
Creates a new instance of KMeansClusterer with default parameters.


KMeansClusterer

public KMeansClusterer(int numRequestedClusters,
                       int maxIterations,
                       FixedClusterInitializer<ClusterType,DataType> initializer,
                       ClusterDivergenceFunction<? super ClusterType,? super DataType> divergenceFunction,
                       ClusterCreator<ClusterType,DataType> creator)
Creates a new instance of KMeansClusterer using the given parameters.

Parameters:
numRequestedClusters - The number of clusters requested (k).
maxIterations - Maximum number of iterations before stopping
initializer - The initializer for the clusters.
divergenceFunction - The divergence function.
creator - The cluster creator.
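
The roles of the five constructor arguments can be hard to see from the signature alone. The following is a minimal, self-contained sketch over plain doubles. It does not use the library's types: the class name KMeansRolesSketch and the functional interfaces standing in for the initializer, divergence function, and cluster creator are assumptions made for illustration only.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;

// Hypothetical stand-in, NOT the library API: a tiny k-means over doubles
// that mirrors the five constructor arguments documented above.
class KMeansRolesSketch {
    final int numRequestedClusters;                                     // k
    final int maxIterations;                                            // hard stop
    final BiFunction<Integer, List<Double>, List<Double>> initializer;  // (k, data) -> starting centers
    final BiFunction<Double, Double, Double> divergence;                // (center, element) -> cost
    final Function<List<Double>, Double> creator;                       // members -> new center

    KMeansRolesSketch(int numRequestedClusters, int maxIterations,
            BiFunction<Integer, List<Double>, List<Double>> initializer,
            BiFunction<Double, Double, Double> divergence,
            Function<List<Double>, Double> creator) {
        this.numRequestedClusters = numRequestedClusters;
        this.maxIterations = maxIterations;
        this.initializer = initializer;
        this.divergence = divergence;
        this.creator = creator;
    }

    List<Double> learn(List<Double> data) {
        // The initializer picks the starting centers from the data.
        List<Double> centers = new ArrayList<>(initializer.apply(numRequestedClusters, data));
        int[] assignments = new int[data.size()];
        Arrays.fill(assignments, -1); // -1 marks "not yet assigned" in this sketch

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            // Assignment phase: the divergence function picks the closest center.
            int numChanged = 0;
            for (int i = 0; i < data.size(); i++) {
                int best = 0;
                double bestCost = Double.POSITIVE_INFINITY;
                for (int c = 0; c < centers.size(); c++) {
                    double cost = divergence.apply(centers.get(c), data.get(i));
                    if (cost < bestCost) {
                        bestCost = cost;
                        best = c;
                    }
                }
                if (assignments[i] != best) {
                    assignments[i] = best;
                    numChanged++;
                }
            }
            if (numChanged == 0) {
                break; // no element moved, so the clustering is complete
            }
            // Update phase: the creator rebuilds each non-empty cluster from its members.
            for (int c = 0; c < centers.size(); c++) {
                List<Double> members = new ArrayList<>();
                for (int i = 0; i < data.size(); i++) {
                    if (assignments[i] == c) {
                        members.add(data.get(i));
                    }
                }
                if (!members.isEmpty()) {
                    centers.set(c, creator.apply(members));
                }
            }
        }
        return centers;
    }
}
```

In this sketch a caller might pass an initializer that takes the first k elements, absolute difference as the divergence, and the arithmetic mean as the creator. The real class expects the library's FixedClusterInitializer, ClusterDivergenceFunction, and ClusterCreator implementations in those positions.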
Method Detail

clone

public KMeansClusterer<DataType,ClusterType> clone()
Description copied from class: AbstractCloneableSerializable
This makes public the clone method on the Object class and removes the exception that it throws. Its default behavior is to automatically create a clone of the exact type of object that the clone is called on and to copy all primitives but to keep all references, which means it is a shallow copy. Extensions of this class may want to override this method (but call super.clone()) to implement a "smart copy", that is, a copy that targets the most common use case for copying the object. Because the default behavior is a shallow copy, extending classes only need to handle fields that need a deeper copy (or those that need to be reset). Some of the methods in ObjectUtil may be helpful in implementing a custom clone method. Note: The contract of this method is that you must use super.clone() as the basis for your implementation.

Specified by:
clone in interface CloneableSerializable
Overrides:
clone in class AbstractAnytimeBatchLearner<Collection<? extends DataType>,Collection<ClusterType>>
Returns:
A clone of this object.
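
The "smart copy" contract described above can be illustrated with a small stand-alone class. This is a sketch only: SmartCopySketch and its fields are hypothetical and are not part of the library, and it implements Cloneable directly rather than extending AbstractCloneableSerializable.

```java
import java.util.ArrayList;

// Hypothetical class, not from the library, illustrating the "smart copy" contract.
class SmartCopySketch implements Cloneable {
    int k = 3;                                        // primitive: copied by super.clone()
    ArrayList<double[]> centers = new ArrayList<>();  // reference: shared unless handled below

    @Override
    public SmartCopySketch clone() {
        try {
            // Start from super.clone() so the runtime type of the copy is preserved.
            SmartCopySketch copy = (SmartCopySketch) super.clone();
            // Deepen only the fields that must not be shared with the original.
            copy.centers = new ArrayList<>(this.centers.size());
            for (double[] center : this.centers) {
                copy.centers.add(center.clone());
            }
            return copy;
        } catch (CloneNotSupportedException e) {
            // Not expected here, because this class implements Cloneable.
            throw new AssertionError(e);
        }
    }
}
```

Changing an array inside the copy's centers list then leaves the original untouched, while fields that were not deepened would still be shared.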

initializeAlgorithm

protected boolean initializeAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to initialize the learning algorithm's state based on the data that is stored in the data field. The return value indicates if the algorithm can be run or not based on the initialization.

Specified by:
initializeAlgorithm in class AbstractAnytimeBatchLearner<Collection<? extends DataType>,Collection<ClusterType>>
Returns:
True if the learning algorithm can be run and false if it cannot.

step

protected boolean step()
Do a step of the clustering algorithm. Each step counts the number of elements that changed their cluster membership (see getNumChanged). If this count is zero, the clustering is complete.

Specified by:
step in class AbstractAnytimeBatchLearner<Collection<? extends DataType>,Collection<ClusterType>>
Returns:
true means keep going, false means stop clustering.

cleanupAlgorithm

protected void cleanupAlgorithm()
Description copied from class: AbstractAnytimeBatchLearner
Called to clean up the learning algorithm's state after learning has finished.

Specified by:
cleanupAlgorithm in class AbstractAnytimeBatchLearner<Collection<? extends DataType>,Collection<ClusterType>>

assignDataToClusters

protected int[] assignDataToClusters(Collection<? extends DataType> data)
Creates the cluster assignments given the current locations of the clusters.

Parameters:
data - Data to assign
Returns:
Assignments of the data to each of the k-clusters

setData

public void setData(Collection<? extends DataType> data)
Description copied from class: AbstractAnytimeBatchLearner
Sets the data to use for learning. This is set when learning starts and then cleared out once learning is finished.

Overrides:
setData in class AbstractAnytimeBatchLearner<Collection<? extends DataType>,Collection<ClusterType>>
Parameters:
data - The data to use for learning.

assignDataFromIndices

protected ArrayList<ArrayList<DataType>> assignDataFromIndices()
Groups the data into one list per cluster, from which the clusters are then estimated.

Returns:
A list containing, for each cluster, the list of data elements assigned to it.
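
As a rough illustration of this grouping, the sketch below builds one member list per cluster from an array of assignments. The class and method names are hypothetical, and the data, assignments, and cluster count are passed in explicitly here, whereas the documented method takes no arguments and presumably reads them from the object's fields.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not the library's implementation.
class GroupByAssignmentSketch {
    // assignments[i] is the cluster index of data.get(i); every value must lie in [0, numClusters).
    static <T> ArrayList<ArrayList<T>> group(List<T> data, int[] assignments, int numClusters) {
        ArrayList<ArrayList<T>> members = new ArrayList<ArrayList<T>>(numClusters);
        for (int c = 0; c < numClusters; c++) {
            members.add(new ArrayList<T>()); // one (possibly empty) list per cluster
        }
        for (int i = 0; i < data.size(); i++) {
            members.get(assignments[i]).add(data.get(i));
        }
        return members;
    }
}
```

Each inner list can then be handed to whatever routine estimates a cluster from its members.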

createClustersFromAssignments

protected void createClustersFromAssignments()
Creates the set of clusters using the current cluster assignments.


getClosestClusterIndex

protected int getClosestClusterIndex(DataType element)
Gets the index of the closest cluster for the given element.

Parameters:
element - The element to get the closest cluster for.
Returns:
The index of the closest cluster.
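
A closest-cluster lookup of this kind amounts to an argmin over the divergence values. The sketch below is an assumption about the general technique, not the library's code: the class name, the use of ToDoubleBiFunction in place of ClusterDivergenceFunction, the choice to break ties in favor of the lowest index, and the -1 result when no cluster qualifies are all illustrative.

```java
import java.util.List;
import java.util.function.ToDoubleBiFunction;

// Hypothetical helper, not the library's implementation.
class ClosestClusterSketch {
    static <C, D> int closestIndex(List<C> clusters, D element,
            ToDoubleBiFunction<? super C, ? super D> divergence) {
        int bestIndex = -1; // stays -1 if no cluster has a divergence below +infinity (assumption)
        double bestDivergence = Double.POSITIVE_INFINITY;
        for (int i = 0; i < clusters.size(); i++) {
            double d = divergence.applyAsDouble(clusters.get(i), element);
            // Strict comparison keeps the earliest cluster when two are equally close.
            if (d < bestDivergence) {
                bestDivergence = d;
                bestIndex = i;
            }
        }
        return bestIndex;
    }
}
```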

setAssignment

protected boolean setAssignment(int elementIndex,
                                int newClusterIndex)
Sets the assignment of the given element to the new cluster index, updating the cluster counts as well.

Parameters:
elementIndex - The index of the element.
newClusterIndex - The new cluster the element is assigned to.
Returns:
True if the assignment changed. Otherwise, false.
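
The bookkeeping this method describes, keeping the per-cluster counts in step with the assignments array, might look like the following. This is a sketch under assumptions: the class name is hypothetical, and the use of -1 to mean "unassigned" is a convention chosen for the example, not something the documentation states.

```java
import java.util.Arrays;

// Hypothetical bookkeeping class, not the library's implementation.
class AssignmentBookkeepingSketch {
    final int[] assignments;    // element index -> cluster index, -1 = unassigned (assumption)
    final int[] clusterCounts;  // cluster index -> number of assigned elements

    AssignmentBookkeepingSketch(int numElements, int numClusters) {
        assignments = new int[numElements];
        Arrays.fill(assignments, -1);
        clusterCounts = new int[numClusters];
    }

    boolean setAssignment(int elementIndex, int newClusterIndex) {
        int oldClusterIndex = assignments[elementIndex];
        if (oldClusterIndex == newClusterIndex) {
            return false; // nothing changed, counts stay as they are
        }
        if (oldClusterIndex >= 0) {
            clusterCounts[oldClusterIndex]--; // element leaves its old cluster
        }
        if (newClusterIndex >= 0) {
            clusterCounts[newClusterIndex]++; // element joins its new cluster
        }
        assignments[elementIndex] = newClusterIndex;
        return true;
    }
}
```

Summing the true results over one pass through the data gives the kind of "number changed" value that getNumChanged reports.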

getCluster

protected ClusterType getCluster(int index)
Gets the cluster for the given index.

Parameters:
index - The index of the cluster.
Returns:
The cluster for the given index.

getNumClusters

protected int getNumClusters()
Gets the actual number of clusters that were created.

Returns:
The actual number of clusters.

getNumRequestedClusters

public int getNumRequestedClusters()
Gets the number of clusters that were requested.

Returns:
The number of clusters that were requested.

getInitializer

public FixedClusterInitializer<ClusterType,DataType> getInitializer()
Gets the cluster initializer.

Returns:
The cluster initializer.

getDivergenceFunction

public ClusterDivergenceFunction<? super ClusterType,? super DataType> getDivergenceFunction()
Gets the divergence function used in clustering.

Specified by:
getDivergenceFunction in interface DivergenceFunctionContainer<ClusterType,DataType>
Returns:
The divergence function.

getCreator

public ClusterCreator<ClusterType,DataType> getCreator()
Gets the cluster creator.

Returns:
The cluster creator.

setNumRequestedClusters

public void setNumRequestedClusters(int numRequestedClusters)
Sets the number of requested clusters.

Parameters:
numRequestedClusters - The number of requested clusters.

setInitializer

public void setInitializer(FixedClusterInitializer<ClusterType,DataType> initializer)
Sets the cluster initializer.

Parameters:
initializer - The cluster initializer.

setDivergenceFunction

public void setDivergenceFunction(ClusterDivergenceFunction<? super ClusterType,? super DataType> divergenceFunction)
Sets the divergence function.

Parameters:
divergenceFunction - The divergence function.

setCreator

public void setCreator(ClusterCreator<ClusterType,DataType> creator)
Sets the cluster creator.

Parameters:
creator - The creator for clusters.

getNumElements

public int getNumElements()
Returns the number of elements

Returns:
number of elements being clustered

setClusters

protected void setClusters(ArrayList<ClusterType> clusters)
Sets the clusters.

Parameters:
clusters - The clusters.

getClusters

public ArrayList<ClusterType> getClusters()
Getter for clusters

Returns:
list of clusters in the algorithm

getResult

public ArrayList<ClusterType> getResult()
Description copied from interface: AnytimeAlgorithm
Gets the current result of the algorithm.

Specified by:
getResult in interface AnytimeAlgorithm<Collection<ClusterType>>
Returns:
Current result of the algorithm.

getAssignments

protected int[] getAssignments()
Getter for assignments

Returns:
The assignment of elements to clusters

getClusterCounts

protected int[] getClusterCounts()
Getter for clusterCounts

Returns:
counts for how many elements are assigned to each cluster

getNumChanged

public int getNumChanged()
Getter for numChanged

Returns:
The number of samples that changed assignment between iterations.

setNumChanged

protected void setNumChanged(int numChanged)
Setter for numChanged

Parameters:
numChanged - The number of samples that changed assignment between iterations.

getPerformance

public NamedValue<Integer> getPerformance()
Gets the performance, which is the number changed on the last iteration.

Specified by:
getPerformance in interface MeasurablePerformanceAlgorithm
Returns:
The performance of the algorithm.