[SPARK-3723] [MLlib] Adding instrumentation to random forests #13881

@@ -28,7 +28,7 @@ import org.apache.spark.util.random.XORShiftRandom
  * particularly for bagging (e.g., for random forests).
  *
  * This holds one instance, as well as an array of weights which represent the (weighted)
- * number of times which this instance appears in each subsamplingRate.
+ * number of times which this instance appears in each subsample.
  * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
  * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
  *
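For context on the weights this doc comment describes: per-subsample counts like `(datum, [1, 0, 4])` arise from sampling with replacement, which can be approximated by drawing each count from a Poisson distribution. A minimal standalone sketch under that assumption; the `BaggedInstance`, `bag`, and `poissonDraw` names here are illustrative, not Spark's API:

```scala
import scala.util.Random

// Knuth's Poisson sampler: count how many uniform draws it takes for the
// running product to fall below e^(-mean).
def poissonDraw(mean: Double, rng: Random): Int = {
  val limit = math.exp(-mean)
  var count = -1
  var product = 1.0
  while (product > limit) {
    count += 1
    product *= rng.nextDouble()
  }
  count
}

// One datum plus its (integer) weight in each of the subsamples,
// mirroring the (datum, [1, 0, 4]) shape from the comment above.
case class BaggedInstance[D](datum: D, subsampleWeights: Array[Int])

// Draw one weight per subsample; subsamplingRate is the Poisson mean.
def bag[D](datum: D, numSubsamples: Int, subsamplingRate: Double,
    seed: Long): BaggedInstance[D] = {
  val rng = new Random(seed)
  BaggedInstance(datum, Array.fill(numSubsamples)(poissonDraw(subsamplingRate, rng)))
}
```

With 3 subsamples, `bag` yields exactly one non-negative count per subsample, matching the `[1, 0, 4]`-style array in the comment.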
@@ -173,11 +173,26 @@ private[spark] object RandomForest extends Logging {

     timer.stop("init")
 
+    // Variables used to keep track of runtime statistics (i.e. how many
+    // nodes get processed in each group/iteration on average)
+    var totalNodesProcessed = 0
+    var numGroups = 0
+    var minNodesPerGroup = Int.MaxValue
+    var maxNodesPerGroup = 0
+
     while (nodeQueue.nonEmpty) {
       // Collect some nodes to split, and choose features for each node (if subsampling).
       // Each group of nodes may come from one or multiple trees, and at multiple levels.
       val (nodesForGroup, treeToNodeToIndexInfo) =
         RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
 
+      // Update runtime statistics
+      val groupSize = nodesForGroup.values.map(_.length).sum
+      totalNodesProcessed += groupSize
+      numGroups += 1
+      minNodesPerGroup = Math.min(minNodesPerGroup, groupSize)
+      maxNodesPerGroup = Math.max(maxNodesPerGroup, groupSize)
+
       // Sanity check (should never occur):
       assert(nodesForGroup.nonEmpty,
         s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
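The bookkeeping this hunk adds is a running count/sum/min/max pattern over group sizes. Extracted into a standalone sketch (field names mirror the patch, but the `GroupStats` class itself is illustrative, not part of Spark):

```scala
// Running statistics over the number of nodes processed per group.
class GroupStats {
  var totalNodesProcessed = 0
  var numGroups = 0
  var minNodesPerGroup = Int.MaxValue
  var maxNodesPerGroup = 0

  // Called once per group with that group's node count.
  def update(groupSize: Int): Unit = {
    totalNodesProcessed += groupSize
    numGroups += 1
    minNodesPerGroup = math.min(minNodesPerGroup, groupSize)
    maxNodesPerGroup = math.max(maxNodesPerGroup, groupSize)
  }

  // Mean group size; only meaningful once numGroups > 0, which is why the
  // patch guards its logging with `if (numGroups > 0)`.
  def average: Double = totalNodesProcessed / numGroups.toDouble
}
```

Initializing the minimum to `Int.MaxValue` and the maximum to 0 means the first `update` overwrites both, so no separate "first group" case is needed.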
@@ -196,6 +211,14 @@ private[spark] object RandomForest extends Logging {
     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")
 
+    // Print out runtime statistics
+    logInfo(s"Processed $numGroups groups of nodes")
+    if (numGroups > 0) {
+      logInfo(s"Max nodes per group: $maxNodesPerGroup")
+      logInfo(s"Min nodes per group: $minNodesPerGroup")
+      logInfo(s"Average nodes per group: ${totalNodesProcessed / numGroups.toDouble}")
+    }
+
     // Delete any remaining checkpoints used for node Id cache.
     if (nodeIdCache.nonEmpty) {
       try {
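One subtlety in the logging hunk above: the average divides by `numGroups.toDouble`, which forces floating-point division; dividing two `Int`s in Scala would silently truncate. A minimal illustration of the difference:

```scala
// Int / Int truncates toward zero; converting either operand to Double
// gives the true mean, which is what the patch's log line relies on.
val totalNodesProcessed = 7
val numGroups = 2

val truncated = totalNodesProcessed / numGroups          // Int division
val average = totalNodesProcessed / numGroups.toDouble   // Double division
```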