Skip to content

Commit

Permalink
[SPARK-3726] Allow sampling_rate not equal to 1.0
Browse files (browse the repository at this point in the history)
This reverts commit 6685b4494d2cb1ec72dbc540d2d747c75c6939ee.
  • Loading branch information
MechCoder committed Jan 17, 2015
1 parent 76389c5 commit a7bfc70
Showing 1 changed file with 3 additions and 9 deletions.
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,7 @@ private class RandomForest (
timer.start("init")

val retaggedInput = input.retag(classOf[LabeledPoint])
val subsample = strategy.subsamplingRate
val metadata =
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
logDebug("algo = " + strategy.algo)
@@ -140,6 +141,7 @@ private class RandomForest (
logDebug("maxBins = " + metadata.maxBins)
logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
logDebug("subsamplingRate = " + subsample)

// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
@@ -155,15 +157,7 @@ private class RandomForest (
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)

val (subsample, withReplacement) = {
// TODO: Have a stricter check for RF in the strategy
val isRandomForest = numTrees > 1
if (isRandomForest) {
(1.0, true)
} else {
(strategy.subsamplingRate, false)
}
}
val withReplacement = if (numTrees > 1) true else false

val baggedInput
= BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
Expand Down

0 comments on commit a7bfc70

Please sign in to comment.