From d41902c085504ca30714d7665ee924c2c2a7fd91 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 29 Jul 2015 23:03:44 -0700 Subject: [PATCH] Oops, forgot to update an extra time in the checkpointer tests, after the last commit. I'll fix that. I'll also make some of the checkpointer methods protected, which I should have done before. --- .../spark/mllib/impl/PeriodicCheckpointer.scala | 10 +++++----- .../spark/mllib/impl/PeriodicGraphCheckpointer.scala | 12 +++++++----- .../spark/mllib/impl/PeriodicRDDCheckpointer.scala | 10 +++++----- .../mllib/impl/PeriodicGraphCheckpointerSuite.scala | 2 ++ .../mllib/impl/PeriodicRDDCheckpointerSuite.scala | 2 ++ 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala index a29bafed8d037..72d3aabc9b1f4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -106,22 +106,22 @@ private[mllib] abstract class PeriodicCheckpointer[T]( } /** Checkpoint the Dataset */ - def checkpoint(data: T): Unit + protected def checkpoint(data: T): Unit /** Return true iff the Dataset is checkpointed */ - def isCheckpointed(data: T): Boolean + protected def isCheckpointed(data: T): Boolean /** * Persist the Dataset. * Note: This should handle checking the current [[StorageLevel]] of the Dataset. */ - def persist(data: T): Unit + protected def persist(data: T): Unit /** Unpersist the Dataset */ - def unpersist(data: T): Unit + protected def unpersist(data: T): Unit /** Get list of checkpoint files for this given Dataset */ - def getCheckpointFiles(data: T): Iterable[String] + protected def getCheckpointFiles(data: T): Iterable[String] /** * Call this at the end to delete any remaining checkpoint files. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index bebd495d76081..11a059536c50c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -80,17 +80,19 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( sc: SparkContext) extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { - override def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint() + override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint() - override def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed + override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed - override def persist(data: Graph[VD, ED]): Unit = { + override protected def persist(data: Graph[VD, ED]): Unit = { if (data.vertices.getStorageLevel == StorageLevel.NONE) { data.persist() } } - override def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false) + override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false) - override def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = data.getCheckpointFiles + override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = { + data.getCheckpointFiles + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala index 42191fae74f7b..f31ed2aa90a64 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala @@ -79,19 +79,19 @@ private[mllib] class PeriodicRDDCheckpointer[T]( sc: SparkContext) extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) { - override def checkpoint(data: RDD[T]): Unit = data.checkpoint() + override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint() - override def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed + override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed - override def persist(data: RDD[T]): Unit = { + override protected def persist(data: RDD[T]): Unit = { if (data.getStorageLevel == StorageLevel.NONE) { data.persist() } } - override def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false) + override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false) - override def getCheckpointFiles(data: RDD[T]): Iterable[String] = { + override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = { data.getCheckpointFile.map(x => x) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index 993cc99435fe5..e331c75989187 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -36,6 +36,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo val graph1 = createGraph(sc) val checkpointer = new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) checkPersistence(graphsToCheck, 1) @@ -58,6 +59,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo val graph1 = createGraph(sc) val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) graph1.edges.count() graph1.vertices.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala index c1c8ff5f5e0e9..b2a459a68b5fa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala @@ -35,6 +35,7 @@ class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCont val rdd1 = createRDD(sc) val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext) + checkpointer.update(rdd1) rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) checkPersistence(rddsToCheck, 1) @@ -56,6 +57,7 @@ class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCont sc.setCheckpointDir(path) val rdd1 = createRDD(sc) val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext) + checkpointer.update(rdd1) rdd1.count() rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) checkCheckpoint(rddsToCheck, 1, checkpointInterval)