From 8f45533b9a5f7c3c1f46d0d15a9f1815fa6227d5 Mon Sep 17 00:00:00 2001
From: Siddharth Murching
Date: Thu, 23 Jun 2016 16:40:26 -0700
Subject: [PATCH 1/4] Fix typo in BaggedPoint.scala, add simple instrumentation to Random Forests

---
 .../spark/ml/tree/impl/BaggedPoint.scala   |  2 +-
 .../spark/ml/tree/impl/RandomForest.scala  | 23 +++++++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
index 4e372702f0c65..8cd53777f87de 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -28,7 +28,7 @@ import org.apache.spark.util.random.XORShiftRandom
  * particularly for bagging (e.g., for random forests).
  *
  * This holds one instance, as well as an array of weights which represent the (weighted)
- * number of times which this instance appears in each subsamplingRate.
+ * number of times which this instance appears in each subsample.
  * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
  * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
  *
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 71c8c42ce5eba..01459c61ca926 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -173,11 +173,26 @@ private[spark] object RandomForest extends Logging {
 
     timer.stop("init")
 
+    // Variables used to keep track of runtime statistics (i.e. how many
+    // nodes get processed in each group/iteration on average)
+    var totalNodesProcessed = 0
+    var numGroups = 0
+    var minNodesPerGroup = Int.MaxValue
+    var maxNodesPerGroup = 0
+
     while (nodeQueue.nonEmpty) {
       // Collect some nodes to split, and choose features for each node (if subsampling).
       // Each group of nodes may come from one or multiple trees, and at multiple levels.
       val (nodesForGroup, treeToNodeToIndexInfo) =
         RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+
+      // Update runtime statistics
+      val groupSize = nodesForGroup.values.map(_.length).sum
+      totalNodesProcessed += groupSize
+      numGroups += 1
+      minNodesPerGroup = Math.min(minNodesPerGroup, groupSize)
+      maxNodesPerGroup = Math.max(maxNodesPerGroup, groupSize)
+
       // Sanity check (should never occur):
       assert(nodesForGroup.nonEmpty,
         s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
@@ -196,6 +211,14 @@ private[spark] object RandomForest extends Logging {
     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")
 
+    // Print out runtime statistics
+    logInfo(s"Max nodes per group $maxNodesPerGroup")
+    logInfo(s"Min nodes per group $minNodesPerGroup")
+    logInfo(s"Processed $numGroups groups of nodes")
+    if (numGroups > 0) {
+      logInfo(s"Avg nodes per group ${totalNodesProcessed / numGroups.toDouble}")
+    }
+
     // Delete any remaining checkpoints used for node Id cache.
     if (nodeIdCache.nonEmpty) {
       try {

From bd7d24d4f5a79eca6ff9629706c254beba74bc45 Mon Sep 17 00:00:00 2001
From: Siddharth Murching
Date: Thu, 23 Jun 2016 17:40:02 -0700
Subject: [PATCH 2/4] Reorder instrumentation logging statements to look nicer

---
 .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 01459c61ca926..fe344c371207c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -212,11 +212,11 @@ private[spark] object RandomForest extends Logging {
     logInfo(s"$timer")
 
     // Print out runtime statistics
-    logInfo(s"Max nodes per group $maxNodesPerGroup")
-    logInfo(s"Min nodes per group $minNodesPerGroup")
     logInfo(s"Processed $numGroups groups of nodes")
+    logInfo(s"Max nodes per group: $maxNodesPerGroup")
+    logInfo(s"Min nodes per group: $minNodesPerGroup")
     if (numGroups > 0) {
-      logInfo(s"Avg nodes per group ${totalNodesProcessed / numGroups.toDouble}")
+      logInfo(s"Average nodes per group: ${totalNodesProcessed / numGroups.toDouble}")
     }
 
     // Delete any remaining checkpoints used for node Id cache.

From f5a6893a1314de5f6a33bd6fb912a77a6cb19fa1 Mon Sep 17 00:00:00 2001
From: Siddharth Murching
Date: Thu, 23 Jun 2016 21:57:28 -0700
Subject: [PATCH 3/4] Max/min nodes per group statistics don't make sense if no groups of nodes are processed; updated log statements to reflect this

---
 .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index fe344c371207c..7e5e1d74d1cce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -213,9 +213,9 @@ private[spark] object RandomForest extends Logging {
 
     // Print out runtime statistics
     logInfo(s"Processed $numGroups groups of nodes")
-    logInfo(s"Max nodes per group: $maxNodesPerGroup")
-    logInfo(s"Min nodes per group: $minNodesPerGroup")
     if (numGroups > 0) {
+      logInfo(s"Max nodes per group: $maxNodesPerGroup")
+      logInfo(s"Min nodes per group: $minNodesPerGroup")
       logInfo(s"Average nodes per group: ${totalNodesProcessed / numGroups.toDouble}")
     }
 

From 7fb031eff488ca657e89220193866af0b39a358a Mon Sep 17 00:00:00 2001
From: Siddharth Murching
Date: Thu, 23 Jun 2016 22:04:04 -0700
Subject: [PATCH 4/4] Remove spaces at end of line

---
 .../main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 7e5e1d74d1cce..2db42bd7df5c5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -215,7 +215,7 @@ private[spark] object RandomForest extends Logging {
     logInfo(s"Processed $numGroups groups of nodes")
     if (numGroups > 0) {
       logInfo(s"Max nodes per group: $maxNodesPerGroup")
-      logInfo(s"Min nodes per group: $minNodesPerGroup") 
+      logInfo(s"Min nodes per group: $minNodesPerGroup")
       logInfo(s"Average nodes per group: ${totalNodesProcessed / numGroups.toDouble}")
     }
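
For reference, below is a minimal, self-contained Scala sketch of the bookkeeping this series converges on after patch 4/4: it replays the per-group counters added in patch 1/4 and prints the summary in the order settled on by patches 2/4 and 3/4. The GroupStatsSketch object and the groupSizes sequence are hypothetical stand-ins for the surrounding RandomForest.run code and the per-iteration group sizes, and println stands in for logInfo; it is an illustration, not part of the patch.

object GroupStatsSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical group sizes: number of nodes selected in each group/iteration.
    val groupSizes = Seq(7, 3, 12, 5)

    // Counters mirroring those added in patch 1/4.
    var totalNodesProcessed = 0
    var numGroups = 0
    var minNodesPerGroup = Int.MaxValue
    var maxNodesPerGroup = 0

    groupSizes.foreach { groupSize =>
      totalNodesProcessed += groupSize
      numGroups += 1
      minNodesPerGroup = Math.min(minNodesPerGroup, groupSize)
      maxNodesPerGroup = Math.max(maxNodesPerGroup, groupSize)
    }

    // Reporting order from patches 2/4 and 3/4: max/min/average are only
    // meaningful when at least one group was processed.
    println(s"Processed $numGroups groups of nodes")
    if (numGroups > 0) {
      println(s"Max nodes per group: $maxNodesPerGroup")
      println(s"Min nodes per group: $minNodesPerGroup")
      println(s"Average nodes per group: ${totalNodesProcessed / numGroups.toDouble}")
    }
  }
}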