
Commit

Automatically cleanup checkpoint
witgo committed Apr 10, 2015
1 parent b5c51c8 commit c0087e0
Showing 3 changed files with 95 additions and 7 deletions.
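
The behavior this commit adds is opt-in through a new configuration key, spark.cleaner.referenceTracking.cleanCheckpoints, which defaults to false (see the RDDCheckpointData diff below). A minimal usage sketch under assumed conditions: a local master, a writable checkpoint directory, and illustrative data, paths, and object names not taken from this commit.

import org.apache.spark.{SparkConf, SparkContext}

object CheckpointCleanupExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("checkpoint-cleanup-example")
      // Opt in to the behavior added by this commit; it defaults to false.
      .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/spark-checkpoints")   // illustrative path

    var rdd = sc.parallelize(1 to 100).map(i => (i, i))
    rdd.checkpoint()
    rdd.cache()
    rdd.collect()   // materializes the RDD and writes its checkpoint files

    // Once the RDD becomes unreachable and a GC runs, the ContextCleaner
    // deletes the rdd-<id> directory under the checkpoint dir.
    rdd = null

    sc.stop()
  }
}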
22 changes: 21 additions & 1 deletion core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -22,7 +22,7 @@ import java.lang.ref.{ReferenceQueue, WeakReference}
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.{RDDCheckpointData, RDD}
import org.apache.spark.util.Utils

/**
@@ -33,6 +33,7 @@ private case class CleanRDD(rddId: Int) extends CleanupTask
private case class CleanShuffle(shuffleId: Int) extends CleanupTask
private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
private case class CleanAccum(accId: Long) extends CleanupTask
private case class CleanCheckpoint(rddId: Int) extends CleanupTask

/**
* A WeakReference associated with a CleanupTask.
@@ -139,6 +140,11 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
}

/** Register an RDDCheckpointData for cleanup when it is garbage collected. */
def registerRDDCheckpointDataForCleanup[T](rdd: RDD[_], parentId: Int) {
registerForCleanup(rdd, CleanCheckpoint(parentId))
}

/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
@@ -164,6 +170,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
case CleanAccum(accId) =>
doCleanupAccum(accId, blocking = blockOnCleanupTasks)
case CleanCheckpoint(rddId) =>
doCleanCheckpoint(rddId, blocking = blockOnCleanupTasks)
}
}
}
@@ -223,6 +231,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
}

/** Perform checkpoint cleanup. */
def doCleanCheckpoint(rddId: Int, blocking: Boolean) {
try {
logDebug("Cleaning rdd checkpoint data " + rddId)
RDDCheckpointData.clearRDDCheckpointData(sc, rddId, blocking)
logInfo("Cleaned rdd checkpoint data " + rddId)
}
catch {
case e: Exception => logError("Error cleaning rdd checkpoint data " + rddId, e)
}
}

private def blockManagerMaster = sc.env.blockManager.master
private def broadcastManager = sc.env.broadcastManager
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
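For context, ContextCleaner drives all of these cleanups through weak references: each registered object is wrapped in a weak reference tied to a ReferenceQueue, and once the object is garbage collected the associated task is polled off the queue and executed. Below is a minimal, self-contained sketch of that pattern; it is a simplification, not Spark's actual implementation, and names such as TaskWeakReference and WeakCleanupDemo are illustrative.

import java.lang.ref.{Reference, ReferenceQueue, WeakReference}
import scala.collection.mutable.ArrayBuffer

// Cleanup tasks mirror the ones in ContextCleaner; only the checkpoint case is shown.
sealed trait DemoCleanupTask
case class DemoCleanCheckpoint(rddId: Int) extends DemoCleanupTask

// A weak reference that remembers which task to run once its referent is collected.
class TaskWeakReference(val task: DemoCleanupTask, referent: AnyRef, queue: ReferenceQueue[AnyRef])
  extends WeakReference[AnyRef](referent, queue)

object WeakCleanupDemo {
  private val referenceQueue = new ReferenceQueue[AnyRef]
  // Keep strong references to the wrappers themselves, so they are not collected early.
  private val referenceBuffer = ArrayBuffer.empty[TaskWeakReference]

  def register(obj: AnyRef, task: DemoCleanupTask): Unit =
    referenceBuffer += new TaskWeakReference(task, obj, referenceQueue)

  // Poll the queue and "execute" any tasks whose referents have been garbage collected.
  def drain(): Unit = {
    var ref: Reference[_ <: AnyRef] = referenceQueue.poll()
    while (ref != null) {
      ref match {
        case r: TaskWeakReference =>
          referenceBuffer -= r
          r.task match {
            case DemoCleanCheckpoint(id) => println(s"would delete checkpoint files for rdd-$id")
          }
        case _ =>
      }
      ref = referenceQueue.poll()
    }
  }

  def main(args: Array[String]): Unit = {
    var holder: AnyRef = new Object
    register(holder, DemoCleanCheckpoint(42))
    holder = null      // drop the last strong reference to the registered object
    System.gc()        // only a hint: collection (and thus cleanup) is not guaranteed
    Thread.sleep(100)
    drain()
  }
}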
31 changes: 26 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -21,7 +21,7 @@ import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path

import org.apache.spark.{Logging, Partition, SerializableWritable, SparkException}
import org.apache.spark._
import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask}

/**
@@ -83,7 +83,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}

// Create the output path for the checkpoint
val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
throw new SparkException("Failed to create checkpoint path " + path)
@@ -92,8 +92,13 @@
// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableWritable(rdd.context.hadoopConfiguration))
rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
rdd.context.cleaner.foreach { cleaner =>
cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
}
}
rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
if (newRDD.partitions.length != rdd.partitions.length) {
throw new SparkException(
"Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " +
@@ -130,5 +135,21 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}

// Used for synchronization
private[spark] object RDDCheckpointData
private[spark] object RDDCheckpointData {
def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = {
if (sc.checkpointDir.isDefined) {
Some(new Path(sc.checkpointDir.get, "rdd-" + rddId))
} else {
None
}
}

def clearRDDCheckpointData(sc: SparkContext, rddId: Int, blocking: Boolean = true) = {
rddCheckpointDataPath(sc, rddId).foreach { path =>
val fs = path.getFileSystem(sc.hadoopConfiguration)
if (fs.exists(path)) {
fs.delete(path, true)
}
}
}
}
49 changes: 48 additions & 1 deletion core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -28,7 +28,8 @@ import org.scalatest.concurrent.{PatienceConfiguration, Eventually}
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.{RDDCheckpointData, RDD}
import org.apache.spark.storage._
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -205,6 +206,52 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
postGCTester.assertCleanup()
}

test("automatically cleanup checkpoint") {
val checkpointDir = java.io.File.createTempFile("temp", "")
checkpointDir.deleteOnExit()
checkpointDir.delete()
var rdd = newPairRDD
sc.setCheckpointDir(checkpointDir.toString)
rdd.checkpoint()
rdd.cache()
rdd.collect()
var rddId = rdd.id

// Confirm the checkpoint directory exists
assert(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).isDefined)
val path = RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get
val fs = path.getFileSystem(sc.hadoopConfiguration)
assert(fs.exists(path))

// The checkpoint is not cleaned by default (without the configuration set)
var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil)
rdd = null // Make RDD out of scope
runGC()
postGCTester.assertCleanup()
assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))

sc.stop()
val conf = new SparkConf().setMaster("local[2]").setAppName("cleanupCheckpoint").
set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
sc = new SparkContext(conf)
rdd = newPairRDD
sc.setCheckpointDir(checkpointDir.toString)
rdd.checkpoint()
rdd.cache()
rdd.collect()
rddId = rdd.id

// Confirm the checkpoint directory exists
assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))

// Test that GC causes checkpoint data cleanup after dereferencing the RDD
postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil)
rdd = null // Make RDD out of scope
runGC()
postGCTester.assertCleanup()
assert(!fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))
}

test("automatically cleanup RDD + shuffle + broadcast") {
val numRdds = 100
val numBroadcasts = 4 // Broadcasts are more costly
