gov.sandia.cognition.learning.experiment
Class CrossFoldCreator<DataType>

java.lang.Object
  extended by gov.sandia.cognition.util.AbstractCloneableSerializable
      extended by gov.sandia.cognition.util.AbstractRandomized
          extended by gov.sandia.cognition.learning.experiment.CrossFoldCreator<DataType>
Type Parameters:
DataType - The type of data to create the folds for.
All Implemented Interfaces:
ValidationFoldCreator<DataType,DataType>, CloneableSerializable, Randomized, Serializable, Cloneable

public class CrossFoldCreator<DataType>
extends AbstractRandomized
implements ValidationFoldCreator<DataType,DataType>

The CrossFoldCreator implements a validation fold creator that creates folds for a typical k-fold cross-validation experiment. That is, it splits the data into k folds, where each item appears in the testing set of exactly one fold and in the training sets of the remaining k - 1 folds. In the limit where k equals the size of the data, this becomes leave-one-out cross-validation. In practice, k-fold cross-validation is typically used when leave-one-out cross-validation is too costly to run, with k set to a much smaller value.
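The splitting scheme described above can be sketched in plain Java. This is a minimal illustration, not the Foundry implementation: the `KFoldSketch` class, the `Fold` record, and the `kFoldSplit` method are hypothetical names, and dealing shuffled items round-robin into folds is an assumed strategy. It requires Java 16 or later for records.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical k-fold splitter illustrating the scheme (not the Foundry code). */
final class KFoldSketch {

    /** Hypothetical pair of training and testing lists for one fold. */
    record Fold<T>(List<T> training, List<T> testing) {}

    /** Assumes 2 <= k <= data.size(); larger k would leave some testing sets empty. */
    static <T> List<Fold<T>> kFoldSplit(Collection<? extends T> data, int k, Random random) {
        // Shuffle a copy so the caller's collection is left untouched.
        List<T> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, random);

        // Deal each shuffled item into one of k buckets, round-robin.
        List<List<T>> buckets = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            buckets.add(new ArrayList<>());
        }
        for (int i = 0; i < shuffled.size(); i++) {
            buckets.get(i % k).add(shuffled.get(i));
        }

        // Fold f tests on bucket f and trains on the other k - 1 buckets.
        List<Fold<T>> folds = new ArrayList<>();
        for (int f = 0; f < k; f++) {
            List<T> training = new ArrayList<>();
            for (int b = 0; b < k; b++) {
                if (b != f) {
                    training.addAll(buckets.get(b));
                }
            }
            folds.add(new Fold<>(training, buckets.get(f)));
        }
        return folds;
    }
}
```

Because every item is placed in exactly one bucket, it is tested exactly once and trained on in the other k - 1 folds. Splitting 10 items into 3 folds this way gives testing sets of sizes 4, 3, and 3.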

Since:
2.0
Author:
Justin Basilico
See Also:
Serialized Form

Field Summary
static int DEFAULT_NUM_FOLDS
          The default number of folds is 10.
protected  int numFolds
          The number of folds to create.
 
Fields inherited from class gov.sandia.cognition.util.AbstractRandomized
random
 
Constructor Summary
CrossFoldCreator()
          Creates a new instance of CrossFoldCreator with a default number of folds (10) and a default Random number generator.
CrossFoldCreator(int numFolds)
          Creates a new CrossFoldCreator with the given number of folds.
CrossFoldCreator(int numFolds, Random random)
          Creates a new CrossFoldCreator with the given number of folds and random number generator.
 
Method Summary
protected static void checkNumFolds(int numFolds)
          Checks the given number of folds to make sure that it is greater than 1.
 List<PartitionedDataset<DataType>> createFolds(Collection<? extends DataType> data)
          Creates the requested number of cross-validation folds from the given data.
static <DataType> List<PartitionedDataset<DataType>> createFolds(Collection<? extends DataType> data, int numFolds, Random random)
          Creates the requested number of cross-validation folds from the given data.
 int getNumFolds()
          Gets the number of folds to create.
 void setNumFolds(int numFolds)
          Sets the number of folds to create.
 
Methods inherited from class gov.sandia.cognition.util.AbstractRandomized
clone, getRandom, setRandom
 
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

DEFAULT_NUM_FOLDS

public static final int DEFAULT_NUM_FOLDS
The default number of folds is 10.

See Also:
Constant Field Values

numFolds

protected int numFolds
The number of folds to create.

Constructor Detail

CrossFoldCreator

public CrossFoldCreator()
Creates a new instance of CrossFoldCreator with a default number of folds (10) and a default Random number generator.


CrossFoldCreator

public CrossFoldCreator(int numFolds)
Creates a new CrossFoldCreator with the given number of folds.

Parameters:
numFolds - The number of folds to create.

CrossFoldCreator

public CrossFoldCreator(int numFolds,
                        Random random)
Creates a new CrossFoldCreator with the given number of folds and random number generator.

Parameters:
numFolds - The number of folds to create.
random - The random number generator to use.

Method Detail

createFolds

public List<PartitionedDataset<DataType>> createFolds(Collection<? extends DataType> data)
Creates the requested number of cross-validation folds from the given data. The number of folds returned is the minimum of the requested number of folds and the size of the data, since it cannot create more folds than there are data elements.

Specified by:
createFolds in interface ValidationFoldCreator<DataType,DataType>
Parameters:
data - The data to create the folds for.
Returns:
The created cross-validation folds.
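The "minimum of the requested folds and the data size" rule can be made concrete with a small sketch. The `FoldCountSketch` class and its methods are hypothetical, and the testing-set sizes assume items are dealt round-robin into folds, which the documentation does not state.

```java
/** Hypothetical helpers illustrating how many folds can be produced and how large they are. */
final class FoldCountSketch {

    /** At most one fold per element: the effective count is min(requested, size). */
    static int effectiveNumFolds(int requestedFolds, int dataSize) {
        return Math.min(requestedFolds, dataSize);
    }

    /**
     * Testing-set sizes when dataSize items are dealt round-robin into the
     * effective number of folds (an assumed strategy); sizes differ by at most one.
     */
    static int[] testingSetSizes(int requestedFolds, int dataSize) {
        int folds = effectiveNumFolds(requestedFolds, dataSize);
        int[] sizes = new int[folds];
        for (int i = 0; i < dataSize; i++) {
            sizes[i % folds]++;
        }
        return sizes;
    }
}
```

Under these assumptions, requesting 10 folds over 4 elements yields 4 folds with one testing element each, while 10 folds over 25 elements yields testing sets of sizes 3, 3, 3, 3, 3, 2, 2, 2, 2, 2.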

createFolds

public static <DataType> List<PartitionedDataset<DataType>> createFolds(Collection<? extends DataType> data,
                                                                        int numFolds,
                                                                        Random random)
Creates the requested number of cross-validation folds from the given data. The number of folds returned is the minimum of the requested number of folds and the size of the data, since it cannot create more folds than there are data elements.

Type Parameters:
DataType - The type of data to create folds over.
Parameters:
data - The data to create the folds for.
numFolds - The number of folds to create.
random - The random number generator to use.
Returns:
The created cross-validation folds.
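Because fold assignment depends on the random number generator, passing a `Random` built from a fixed seed should make the folds reproducible across runs, assuming the implementation draws only from the supplied generator. The hypothetical `SeededShuffleSketch` below shows the underlying idea with `Collections.shuffle`: two generators created with the same seed produce the same ordering.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical illustration of why the static overload takes a Random. */
final class SeededShuffleSketch {

    /** Returns a shuffled copy of data; the input collection is not modified. */
    static <T> List<T> shuffledCopy(Collection<? extends T> data, Random random) {
        List<T> copy = new ArrayList<>(data);
        // The permutation is fully determined by the generator's state.
        Collections.shuffle(copy, random);
        return copy;
    }
}
```

For example, `shuffledCopy(data, new Random(7))` returns the same list every time it is called on the same input, so folds derived from that order would also repeat.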

getNumFolds

public int getNumFolds()
Gets the number of folds to create.

Returns:
The number of folds to create.

setNumFolds

public void setNumFolds(int numFolds)
Sets the number of folds to create. The number of folds must be greater than one.

Parameters:
numFolds - The number of folds to create.

checkNumFolds

protected static void checkNumFolds(int numFolds)
Checks the given number of folds to make sure that it is greater than 1.

Parameters:
numFolds - The number of folds.
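A single fold would put all of the data in its testing set and leave its training set empty, which is why the fold count must be greater than 1. The sketch below shows one way such a check might look. The `NumFoldsCheckSketch` class is hypothetical, and throwing `IllegalArgumentException` is an assumption, since the documentation does not name the exception type.

```java
/** Hypothetical version of the "greater than 1" check described for checkNumFolds. */
final class NumFoldsCheckSketch {

    /**
     * Rejects fold counts below 2. Throwing IllegalArgumentException is an
     * assumption; the documented method does not name its exception type.
     */
    static void checkNumFolds(int numFolds) {
        if (numFolds <= 1) {
            throw new IllegalArgumentException(
                "numFolds must be greater than 1, got " + numFolds);
        }
    }
}
```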