diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala index 4e372702f0c65..8cd53777f87de 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.random.XORShiftRandom * particularly for bagging (e.g., for random forests). * * This holds one instance, as well as an array of weights which represent the (weighted) - * number of times which this instance appears in each subsamplingRate. + * number of times which this instance appears in each subsample. * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 71c8c42ce5eba..2db42bd7df5c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -173,11 +173,26 @@ private[spark] object RandomForest extends Logging { timer.stop("init") + // Variables used to keep track of runtime statistics (e.g. how many + // nodes get processed in each group/iteration on average) + var totalNodesProcessed = 0 + var numGroups = 0 + var minNodesPerGroup = Int.MaxValue + var maxNodesPerGroup = 0 + while (nodeQueue.nonEmpty) { // Collect some nodes to split, and choose features for each node (if subsampling). // Each group of nodes may come from one or multiple trees, and at multiple levels. 
val (nodesForGroup, treeToNodeToIndexInfo) = RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + + // Update runtime statistics + val groupSize = nodesForGroup.values.map(_.length).sum + totalNodesProcessed += groupSize + numGroups += 1 + minNodesPerGroup = Math.min(minNodesPerGroup, groupSize) + maxNodesPerGroup = Math.max(maxNodesPerGroup, groupSize) + // Sanity check (should never occur): assert(nodesForGroup.nonEmpty, s"RandomForest selected empty nodesForGroup. Error for unknown reason.") @@ -196,6 +211,14 @@ private[spark] object RandomForest extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + // Print out runtime statistics + logInfo(s"Processed $numGroups groups of nodes") + if (numGroups > 0) { + logInfo(s"Max nodes per group: $maxNodesPerGroup") + logInfo(s"Min nodes per group: $minNodesPerGroup") + logInfo(s"Average nodes per group: ${totalNodesProcessed / numGroups.toDouble}") + } + // Delete any remaining checkpoints used for node Id cache. if (nodeIdCache.nonEmpty) { try {