Switch FoldedRDD to use BernoulliSampler and PartitionwiseSampledRDD
holdenk committed Apr 9, 2014
1 parent 08f8e4d commit c0b7fa4
Showing 2 changed files with 23 additions and 28 deletions.
core/src/main/scala/org/apache/spark/rdd/FoldedRDD.scala (7 additions, 26 deletions)
@@ -24,6 +24,7 @@ import cern.jet.random.Poisson
import cern.jet.random.engine.DRand

import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.util.random.BernoulliSampler

private[spark]
class FoldedRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
@@ -32,24 +33,10 @@ class FoldedRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {

class FoldedRDD[T: ClassTag](
    prev: RDD[T],
-    fold: Int,
-    folds: Int,
+    fold: Float,
+    folds: Float,
    seed: Int)
-  extends RDD[T](prev) {
-
-  override def getPartitions: Array[Partition] = {
-    val rg = new Random(seed)
-    firstParent[T].partitions.map(x => new FoldedRDDPartition(x, rg.nextInt))
-  }
-
-  override def getPreferredLocations(split: Partition): Seq[String] =
-    firstParent[T].preferredLocations(split.asInstanceOf[FoldedRDDPartition].prev)
-
-  override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
-    val split = splitIn.asInstanceOf[FoldedRDDPartition]
-    val rand = new Random(split.seed)
-    firstParent[T].iterator(split.prev, context).filter(x => (rand.nextInt(folds) == fold-1))
-  }
+  extends PartitionwiseSampledRDD[T, T](prev, new BernoulliSampler((fold-1)/folds,fold/folds, false), seed) {
}

/**
@@ -58,14 +45,8 @@ class FoldedRDD[T: ClassTag](
*/
class CompositeFoldedRDD[T: ClassTag](
    prev: RDD[T],
-    fold: Int,
-    folds: Int,
+    fold: Float,
+    folds: Float,
    seed: Int)
-  extends FoldedRDD[T](prev, fold, folds, seed) {
-
-  override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
-    val split = splitIn.asInstanceOf[FoldedRDDPartition]
-    val rand = new Random(split.seed)
-    firstParent[T].iterator(split.prev, context).filter(x => (rand.nextInt(folds) != fold-1))
-  }
+  extends PartitionwiseSampledRDD[T, T](prev, new BernoulliSampler((fold-1)/folds, fold/folds, true), seed) {
}
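
Aside: the new FoldedRDD hands BernoulliSampler a lower and an upper bound, so fold k of n keeps the elements whose uniform random draw falls in [(k-1)/n, k/n); the trailing true in CompositeFoldedRDD is read here as a complement flag that keeps everything outside that interval. The standalone sketch below illustrates that interval logic with plain Scala collections only; the bound and complement semantics are assumptions inferred from the constructor calls above, not code from this commit.

// Minimal sketch of the fold-interval idea, assuming BernoulliSampler keeps
// elements whose uniform draw lands in [lb, ub) and that complement = true inverts it.
object FoldBoundsSketch {
  def bounds(fold: Float, folds: Float): (Float, Float) = ((fold - 1) / folds, fold / folds)

  def main(args: Array[String]): Unit = {
    val rng = new scala.util.Random(1)
    // One uniform draw per element; a fixed seed mirrors how FoldedRDD and
    // CompositeFoldedRDD can see complementary subsets of the same data.
    val draws = (1 to 100).map(x => x -> rng.nextDouble())
    val (lb, ub) = bounds(1, 2)
    val test  = draws.collect { case (x, r) if r >= lb && r < ub => x }
    val train = draws.collect { case (x, r) if r < lb || r >= ub => x }
    assert(test.size + train.size == 100)   // the two sets partition the input
    println(s"fold 1 of 2: test=${test.size}, train=${train.size}")
  }
}

Because both classes receive the same seed, the per-element draws line up, which is what makes a fold and its complement disjoint while together covering the whole input.
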
core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala (16 additions, 2 deletions)
@@ -513,14 +513,28 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}

test("FoldedRDD") {
val data = sc.parallelize(1 to 100, 2)
val lowerFoldedRdd = new FoldedRDD(data, 1, 2, 1)
val upperFoldedRdd = new FoldedRDD(data, 2, 2, 1)
val lowerCompositeFoldedRdd = new CompositeFoldedRDD(data, 1, 2, 1)
assert(lowerFoldedRdd.collect().sorted.size == 50)
assert(lowerCompositeFoldedRdd.collect().sorted.size == 50)
assert(lowerFoldedRdd.subtract(lowerCompositeFoldedRdd).collect().sorted ===
lowerFoldedRdd.collect().sorted)
assert(upperFoldedRdd.collect().sorted.size == 50)
}

test("kfoldRdd") {
val data = sc.parallelize(1 to 100, 2)
for (folds <- 1 to 10) {
val collectedData = data.collect().sorted
for (folds <- 2 to 10) {
for (seed <- 1 to 5) {
val foldedRdds = data.kFoldRdds(folds, seed)
assert(foldedRdds.size === folds)
foldedRdds.map{case (test, train) =>
assert(test.union(train).collect().sorted === data.collect().sorted,
val result = test.union(train).collect().sorted
assert(result === collectedData,
"Each training+test set combined contains all of the data")
}
// K fold cross validation should only have each element in the test set exactly once
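
The kfoldRdd test above exercises a data.kFoldRdds(folds, seed) helper that lives outside this diff. A hypothetical version built directly on the two classes in this commit could look like the sketch below; the object name and the exact pairing are assumptions for illustration, not the commit's implementation.

package org.apache.spark.rdd

import scala.reflect.ClassTag

// Hypothetical helper (not part of this diff): pair each test fold with its
// complement. Sharing the seed is what keeps each (test, train) pair disjoint.
object KFoldSketch {
  def kFoldRdds[T: ClassTag](rdd: RDD[T], folds: Int, seed: Int): Seq[(RDD[T], RDD[T])] =
    (1 to folds).map { fold =>
      (new FoldedRDD(rdd, fold, folds, seed),           // elements of fold k
       new CompositeFoldedRDD(rdd, fold, folds, seed))  // everything else
    }
}
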
