Skip to content

Commit

Permalink
Oops, forgot to update an extra time in the checkpointer tests, after the last commit. I'll fix that. I'll also make some of the checkpointer methods protected, which I should have done before.
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Jul 30, 2015
1 parent 32b23b8 commit d41902c
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,22 @@ private[mllib] abstract class PeriodicCheckpointer[T](
}

/** Checkpoint the Dataset */
def checkpoint(data: T): Unit
protected def checkpoint(data: T): Unit

/** Return true iff the Dataset is checkpointed */
def isCheckpointed(data: T): Boolean
protected def isCheckpointed(data: T): Boolean

/**
* Persist the Dataset.
* Note: This should handle checking the current [[StorageLevel]] of the Dataset.
*/
def persist(data: T): Unit
protected def persist(data: T): Unit

/** Unpersist the Dataset */
def unpersist(data: T): Unit
protected def unpersist(data: T): Unit

/** Get list of checkpoint files for this given Dataset */
def getCheckpointFiles(data: T): Iterable[String]
protected def getCheckpointFiles(data: T): Iterable[String]

/**
* Call this at the end to delete any remaining checkpoint files.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,19 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](
sc: SparkContext)
extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {

override def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()
override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()

override def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed
override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed

override def persist(data: Graph[VD, ED]): Unit = {
override protected def persist(data: Graph[VD, ED]): Unit = {
if (data.vertices.getStorageLevel == StorageLevel.NONE) {
data.persist()
}
}

override def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)
override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)

override def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = data.getCheckpointFiles
override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = {
data.getCheckpointFiles
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,19 @@ private[mllib] class PeriodicRDDCheckpointer[T](
sc: SparkContext)
extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {

override def checkpoint(data: RDD[T]): Unit = data.checkpoint()
override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()

override def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed
override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed

override def persist(data: RDD[T]): Unit = {
override protected def persist(data: RDD[T]): Unit = {
if (data.getStorageLevel == StorageLevel.NONE) {
data.persist()
}
}

override def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)
override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)

override def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
data.getCheckpointFile.map(x => x)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
val graph1 = createGraph(sc)
val checkpointer =
new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
checkpointer.update(graph1)
graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
checkPersistence(graphsToCheck, 1)

Expand All @@ -58,6 +59,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
val graph1 = createGraph(sc)
val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
checkpointInterval, graph1.vertices.sparkContext)
checkpointer.update(graph1)
graph1.edges.count()
graph1.vertices.count()
graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCont

val rdd1 = createRDD(sc)
val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext)
checkpointer.update(rdd1)
rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
checkPersistence(rddsToCheck, 1)

Expand All @@ -56,6 +57,7 @@ class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCont
sc.setCheckpointDir(path)
val rdd1 = createRDD(sc)
val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext)
checkpointer.update(rdd1)
rdd1.count()
rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
checkCheckpoint(rddsToCheck, 1, checkpointInterval)
Expand Down

0 comments on commit d41902c

Please sign in to comment.