Skip to content

Commit

Permalink
[SPARK-3726] [MLlib] Allow sampling_rate not equal to 1.0 in RandomFo…
Browse files Browse the repository at this point in the history
…rests

I've added support for sampling_rate not equal to 1.0 . I have two major questions.

1. A Scala style test is failing, since the number of parameters now exceed 10.
2. I would like suggestions to understand how to test this.

Author: MechCoder <[email protected]>

Closes #4073 from MechCoder/spark-3726 and squashes the following commits:

8012fb2 [MechCoder] Add test in Strategy
e0e0d9c [MechCoder] TST: Add better test
d1df1b2 [MechCoder] Add test to verify subsampling behavior
a7bfc70 [MechCoder] [SPARK-3726] Allow sampling_rate not equal to 1.0
  • Loading branch information
MechCoder authored and mengxr committed Jan 27, 2015
1 parent f2ba5c6 commit d6894b1
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ private class RandomForest (
logDebug("maxBins = " + metadata.maxBins)
logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
logDebug("subsamplingRate = " + strategy.subsamplingRate)

// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
Expand All @@ -155,19 +156,12 @@ private class RandomForest (
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)

val (subsample, withReplacement) = {
// TODO: Have a stricter check for RF in the strategy
val isRandomForest = numTrees > 1
if (isRandomForest) {
(1.0, true)
} else {
(strategy.subsamplingRate, false)
}
}
val withReplacement = if (numTrees > 1) true else false

val baggedInput
= BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
.persist(StorageLevel.MEMORY_AND_DISK)
= BaggedPoint.convertToBaggedRDD(treeInput,
strategy.subsamplingRate, numTrees,
withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)

// depth of the decision tree
val maxDepth = strategy.maxDepth
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ class Strategy (
s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
require(maxMemoryInMB <= 10240,
s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
require(subsamplingRate > 0 && subsamplingRate <= 1,
s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " +
s"$subsamplingRate")
}

/** Returns a shallow copy of this instance. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,22 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
featureSubsetStrategy = "sqrt", seed = 12345)
EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}

test("subsampling rate in RandomForest"){
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int],
useNodeIdCache = true)

val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
featureSubsetStrategy = "auto", seed = 123)
strategy.subsamplingRate = 0.5
val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
featureSubsetStrategy = "auto", seed = 123)
assert(rf1.toDebugString != rf2.toDebugString)
}

}


0 comments on commit d6894b1

Please sign in to comment.