Skip to content

Commit

Permalink
Code review follow up.
Browse files Browse the repository at this point in the history
  • Loading branch information
holdenk committed Apr 9, 2014
1 parent e8741a7 commit 5a33f1d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
}
}

/** Returns a new sampler built from the same `lb` and `ub` with the `complement` flag negated. */
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)

/** Returns a new sampler built from the same `lb`, `ub` and `complement` as this one. */
override def clone: BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, complement)
}

Expand Down
12 changes: 6 additions & 6 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,12 @@ object MLUtils {
*/
def kFold[T : ClassTag](rdd: RDD[T], folds: Int, seed: Int): List[Pair[RDD[T], RDD[T]]] = {
  // Converted to Float so the range bounds below use fractional, not integer, division.
  val foldsF = folds.toFloat
  1.to(folds).map { fold =>
    // One sampler per fold, built over the range (fold-1)/folds .. fold/folds.
    val sampler = new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF, complement = false)
    // Both RDDs are sampled with the same seed; the second uses the complemented sampler
    // derived from the first, so the two samplers always share identical bounds.
    // NOTE(review): which element of the pair is the smaller held-out slice depends on
    // BernoulliSampler's `complement` semantics, which are not visible here; confirm
    // that the `train`/`test` names match the intended roles.
    val train = new PartitionwiseSampledRDD(rdd, sampler, seed)
    val test = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed)
    (train, test)
  }.toList
}

/**
Expand Down

0 comments on commit 5a33f1d

Please sign in to comment.