From 5a33f1d6d91e1ec1afb62c920abc0443f13725cf Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 23 Feb 2014 17:33:11 -0800 Subject: [PATCH] Code review follow up. --- .../org/apache/spark/util/random/RandomSampler.scala | 2 ++ .../scala/org/apache/spark/mllib/util/MLUtils.scala | 12 ++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 850c94bef9ea8..deee3a9c6b2c5 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -63,6 +63,8 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) } } + def cloneComplement() = new BernoulliSampler[T](lb, ub, !complement) + override def clone = new BernoulliSampler[T](lb, ub, complement) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index ef2b25a39b81e..4fe021a33eabc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -187,12 +187,12 @@ object MLUtils { */ def kFold[T : ClassTag](rdd: RDD[T], folds: Int, seed: Int): List[Pair[RDD[T], RDD[T]]] = { val foldsF = folds.toFloat - 1.to(folds).map(fold => (( - new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF, - complement = false), seed), - new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF, - complement = true), seed) - ))).toList + 1.to(folds).map { fold => + val sampler = new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF, complement = false) + val train = new PartitionwiseSampledRDD(rdd, sampler, seed) + val test = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed) + (train, test) + }.toList } /**