diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index deee3a9c6b2c5..fd941d2ebed8b 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -63,7 +63,10 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) } } - def cloneComplement() = new BernoulliSampler[T](lb, ub, !complement) + /** + * Return a sampler which is the complement of the range specified by the current sampler. + */ + def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) override def clone = new BernoulliSampler[T](lb, ub, complement) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 0fbb606267f9b..7112ada039de1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -21,14 +21,13 @@ import scala.reflect.ClassTag import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV, squaredDistance => breezeSquaredDistance} +import org.jblas.DoubleMatrix import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD import org.apache.spark.SparkContext._ import org.apache.spark.util.random.BernoulliSampler - -import org.jblas.DoubleMatrix import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -179,13 +178,13 @@ object MLUtils { } /** - * Return a k element list of pairs of RDDs with the first element of each pair + * Return a k element array of pairs of RDDs with the first element of each pair * containing the validation data, a unique 1/Kth of the data and the second - * element, the training data, contain the compliment of 
that. + * element, the training data, containing the complement of that. */ - def kFold[T : ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { + def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat - (1 to numFolds).map { fold => + (1 to numFolds).map { fold => val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, complement = false) val validation = new PartitionwiseSampledRDD(rdd, sampler, seed) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 92428ff6e44d8..1ce4948108a80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -120,25 +120,25 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { for (seed <- 1 to 5) { val foldedRdds = MLUtils.kFold(data, folds, seed) assert(foldedRdds.size === folds) - foldedRdds.map { case (test, train) => - val result = test.union(train).collect().sorted - val testSize = test.collect().size.toFloat - assert(testSize > 0, "empty test data") + foldedRdds.map { case (validation, training) => - val result = validation.union(training).collect().sorted - val validationSize = validation.collect().size.toFloat + assert(validationSize > 0, "empty validation data") val p = 1 / folds.toFloat // Within 3 standard deviations of the mean - val range = 3 * math.sqrt(100 * p * (1-p)) + val range = 3 * math.sqrt(100 * p * (1 - p)) val expected = 100 * p val lowerBound = expected - range val upperBound = expected + range - assert(testSize > lowerBound, - s"Test data ($testSize) smaller than expected ($lowerBound)" ) - assert(testSize < upperBound, - s"Test data ($testSize) larger than expected ($upperBound)" ) - assert(train.collect().size > 0, "empty training data") + assert(validationSize > lowerBound, + 
s"Validation data ($validationSize) smaller than expected ($lowerBound)" ) + assert(validationSize < upperBound, + s"Validation data ($validationSize) larger than expected ($upperBound)" ) + assert(training.collect().size > 0, "empty training data") assert(result === collectedData, - "Each training+test set combined should contain all of the data.") + "Each training+validation set combined should contain all of the data.") } - // K fold cross validation should only have each element in the test set exactly once + // K fold cross validation should only have each element in the validation set exactly once assert(foldedRdds.map(_._1).reduce((x,y) => x.union(y)).collect().sorted === data.collect().sorted) }