Skip to content

Commit

Permalink
[SPARK-5604][MLLIB] remove checkpointDir from trees
Browse files Browse the repository at this point in the history
This is the second part of SPARK-5604, which removes checkpointDir from tree strategies. Note that this is a break change. I will mention it in the migration guide.

Author: Xiangrui Meng <[email protected]>

Closes apache#4407 from mengxr/SPARK-5604-1 and squashes the following commits:

13a276d [Xiangrui Meng] remove checkpointDir from trees
  • Loading branch information
mengxr committed Feb 6, 2015
1 parent 7dc4965 commit 6b88825
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ object DecisionTreeRunner {
case Variance => impurity.Variance
}

params.checkpointDir.foreach(sc.setCheckpointDir)

val strategy
= new Strategy(
algo = params.algo,
Expand All @@ -282,7 +284,6 @@ object DecisionTreeRunner {
minInstancesPerNode = params.minInstancesPerNode,
minInfoGain = params.minInfoGain,
useNodeIdCache = params.useNodeIdCache,
checkpointDir = params.checkpointDir,
checkpointInterval = params.checkpointInterval)
if (params.numTrees == 1) {
val startTime = System.nanoTime()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ private class RandomForest (
Some(NodeIdCache.init(
data = baggedInput,
numTrees = numTrees,
checkpointDir = strategy.checkpointDir,
checkpointInterval = strategy.checkpointInterval,
initVal = 1))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* @param subsamplingRate Fraction of the training data used for learning decision tree.
* @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
* maintain a separate RDD of node Id cache for each row.
* @param checkpointDir If the node Id cache is used, it will help to checkpoint
* the node Id cache periodically. This is the checkpoint directory
* to be used for the node Id cache.
* @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
* E.g. 10 means that the cache will get checkpointed every 10 updates.
* E.g. 10 means that the cache will get checkpointed every 10 updates. If
* the checkpoint directory is not set in
* [[org.apache.spark.SparkContext]], this setting is ignored.
*/
@Experimental
class Strategy (
Expand All @@ -82,7 +81,6 @@ class Strategy (
@BeanProperty var maxMemoryInMB: Int = 256,
@BeanProperty var subsamplingRate: Double = 1,
@BeanProperty var useNodeIdCache: Boolean = false,
@BeanProperty var checkpointDir: Option[String] = None,
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {

def isMulticlassClassification =
Expand Down Expand Up @@ -165,7 +163,7 @@ class Strategy (
def copy: Strategy = {
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,12 @@ private[tree] case class NodeIndexUpdater(
* The nodeIdsForInstances RDD needs to be updated at each iteration.
* @param nodeIdsForInstances The initial values in the cache
* (should be an Array of all 1's (meaning the root nodes)).
* @param checkpointDir The checkpoint directory where
* the checkpointed files will be stored.
* @param checkpointInterval The checkpointing interval
* (how often should the cache be checkpointed.).
*/
@DeveloperApi
private[tree] class NodeIdCache(
var nodeIdsForInstances: RDD[Array[Int]],
val checkpointDir: Option[String],
val checkpointInterval: Int) {

// Keep a reference to a previous node Ids for instances.
Expand All @@ -91,12 +88,6 @@ private[tree] class NodeIdCache(
private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
private var rddUpdateCount = 0

// If a checkpoint directory is given, and there's no prior checkpoint directory,
// then set the checkpoint directory with the given one.
if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
}

/**
* Update the node index values in the cache.
* This updates the RDD and its lineage.
Expand Down Expand Up @@ -184,7 +175,6 @@ private[tree] object NodeIdCache {
* Initialize the node Id cache with initial node Id values.
* @param data The RDD of training rows.
* @param numTrees The number of trees that we want to create cache for.
* @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
* @param checkpointInterval The checkpointing interval
* (how often should the cache be checkpointed.).
* @param initVal The initial values in the cache.
Expand All @@ -193,12 +183,10 @@ private[tree] object NodeIdCache {
def init(
data: RDD[BaggedPoint[TreePoint]],
numTrees: Int,
checkpointDir: Option[String],
checkpointInterval: Int,
initVal: Int = 1): NodeIdCache = {
new NodeIdCache(
data.map(_ => Array.fill[Int](numTrees)(initVal)),
checkpointDir,
checkpointInterval)
}
}

0 comments on commit 6b88825

Please sign in to comment.