-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-5436] [MLlib] Validate GradientBoostedTrees using runWithValidation #4677
Changes from 7 commits
77549a9
3e74372
55e5c3b
fad9b6e
b928a19
b48a70f
e4d799b
1bb21d4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) | |
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { | ||
val algo = boostingStrategy.treeStrategy.algo | ||
algo match { | ||
case Regression => GradientBoostedTrees.boost(input, boostingStrategy) | ||
case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false) | ||
case Classification => | ||
// Map labels to -1, +1 so binary classification can be treated as regression. | ||
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) | ||
GradientBoostedTrees.boost(remappedInput, boostingStrategy) | ||
GradientBoostedTrees.boost(remappedInput, | ||
remappedInput, boostingStrategy, validate=false) | ||
case _ => | ||
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") | ||
} | ||
|
@@ -76,8 +77,46 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) | |
def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { | ||
run(input.rdd) | ||
} | ||
} | ||
|
||
/** | ||
* Method to validate a gradient boosting model | ||
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fix param name |
||
* @param validationInput Validation dataset: | ||
RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. | ||
Should be different from and follow the same distribution as input. | ||
e.g., these two datasets could be created from an original dataset | ||
by using [[org.apache.spark.rdd.RDD.randomSplit()]] | ||
* @return a gradient boosted trees model that can be used for prediction | ||
*/ | ||
def runWithValidation( | ||
input: RDD[LabeledPoint], | ||
validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { | ||
val algo = boostingStrategy.treeStrategy.algo | ||
algo match { | ||
case Regression => GradientBoostedTrees.boost( | ||
input, validationInput, boostingStrategy, validate=true) | ||
case Classification => | ||
// Map labels to -1, +1 so binary classification can be treated as regression. | ||
val remappedInput = input.map( | ||
x => new LabeledPoint((x.label * 2) - 1, x.features)) | ||
val remappedValidationInput = validationInput.map( | ||
x => new LabeledPoint((x.label * 2) - 1, x.features)) | ||
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, | ||
validate=true) | ||
case _ => | ||
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") | ||
} | ||
} | ||
|
||
/** | ||
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]]. | ||
*/ | ||
def runWithValidation( | ||
input: JavaRDD[LabeledPoint], | ||
validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { | ||
runWithValidation(input.rdd, validationInput.rdd) | ||
} | ||
} | ||
|
||
object GradientBoostedTrees extends Logging { | ||
|
||
|
@@ -108,12 +147,16 @@ object GradientBoostedTrees extends Logging { | |
/** | ||
* Internal method for performing regression using trees as base learners. | ||
* @param input training dataset | ||
* @param validationInput validation dataset, ignored if validate is set to false. | ||
* @param boostingStrategy boosting parameters | ||
* @param validate whether or not to use the validation dataset. | ||
* @return a gradient boosted trees model that can be used for prediction | ||
*/ | ||
private def boost( | ||
input: RDD[LabeledPoint], | ||
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { | ||
validationInput: RDD[LabeledPoint], | ||
boostingStrategy: BoostingStrategy, | ||
validate: Boolean): GradientBoostedTreesModel = { | ||
|
||
val timer = new TimeTracker() | ||
timer.start("total") | ||
|
@@ -129,6 +172,7 @@ object GradientBoostedTrees extends Logging { | |
val learningRate = boostingStrategy.learningRate | ||
// Prepare strategy for individual trees, which use regression with variance impurity. | ||
val treeStrategy = boostingStrategy.treeStrategy.copy | ||
val validationTol = boostingStrategy.validationTol | ||
treeStrategy.algo = Regression | ||
treeStrategy.impurity = Variance | ||
treeStrategy.assertValid() | ||
|
@@ -152,13 +196,16 @@ object GradientBoostedTrees extends Logging { | |
baseLearnerWeights(0) = 1.0 | ||
val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) | ||
logDebug("error of gbt = " + loss.computeError(startingModel, input)) | ||
|
||
// Note: A model of type regression is used since we require raw prediction | ||
timer.stop("building tree 0") | ||
|
||
var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0 | ||
var bestM = 1 | ||
|
||
// pseudo-residual for second iteration | ||
data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), | ||
point.features)) | ||
|
||
var m = 1 | ||
while (m < numIterations) { | ||
timer.start(s"building tree $m") | ||
|
@@ -177,6 +224,23 @@ object GradientBoostedTrees extends Logging { | |
val partialModel = new GradientBoostedTreesModel( | ||
Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) | ||
logDebug("error of gbt = " + loss.computeError(partialModel, input)) | ||
|
||
if (validate) { | ||
// Stop training early if | ||
// 1. Reduction in error is less than the validationTol or | ||
// 2. If the error increases, that is if the model is overfit. | ||
// We want the model returned corresponding to the best validation error. | ||
val currentValidateError = loss.computeError(partialModel, validationInput) | ||
if (bestValidateError - currentValidateError < validationTol) { | ||
return new GradientBoostedTreesModel( | ||
boostingStrategy.treeStrategy.algo, | ||
baseLearners.slice(0, bestM), | ||
baseLearnerWeights.slice(0, bestM)) | ||
} else if (currentValidateError < bestValidateError) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. scala style: space before { |
||
bestValidateError = currentValidateError | ||
bestM = m + 1 | ||
} | ||
} | ||
// Update data with pseudo-residuals | ||
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), | ||
point.features)) | ||
|
@@ -191,4 +255,5 @@ object GradientBoostedTrees extends Logging { | |
new GradientBoostedTreesModel( | ||
boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,15 +34,20 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} | |
* weak hypotheses used in the final model. | ||
* @param learningRate Learning rate for shrinking the contribution of each estimator. The | ||
* learning rate should be in the interval (0, 1] | ||
* @param validationTol Useful when runWithValidation is used. If the error rate on the | ||
* validation input between two iterations is less than the validationTol | ||
* then stop. Ignored when [[run]] is used. | ||
*/ | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove empty line |
||
@Experimental | ||
case class BoostingStrategy( | ||
// Required boosting parameters | ||
@BeanProperty var treeStrategy: Strategy, | ||
@BeanProperty var loss: Loss, | ||
// Optional boosting parameters | ||
@BeanProperty var numIterations: Int = 100, | ||
@BeanProperty var learningRate: Double = 0.1) extends Serializable { | ||
@BeanProperty var learningRate: Double = 0.1, | ||
@BeanProperty var validationTol: Double = 1e-5) extends Serializable { | ||
|
||
/** | ||
* Check validity of parameters. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -158,6 +158,63 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { | |
} | ||
} | ||
} | ||
|
||
test("runWithValidation performs better on a validation dataset (Regression)") { | ||
// Set numIterations large enough so that it stops early. | ||
val numIterations = 20 | ||
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) | ||
val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) | ||
|
||
val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, | ||
categoricalFeaturesInfo = Map.empty) | ||
Array(SquaredError, AbsoluteError).foreach { error => | ||
val boostingStrategy = | ||
new BoostingStrategy(treeStrategy, error, numIterations, validationTol = 0.0) | ||
|
||
val gbtValidate = new GradientBoostedTrees(boostingStrategy). | ||
runWithValidation(trainRdd, validateRdd) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please put period (.) on line with runWithValidation:
|
||
assert(gbtValidate.numTrees !== numIterations) | ||
|
||
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) | ||
val errorWithoutValidation = error.computeError(gbt, validateRdd) | ||
val errorWithValidation = error.computeError(gbtValidate, validateRdd) | ||
assert(errorWithValidation < errorWithoutValidation) | ||
} | ||
} | ||
|
||
test("runWithValidation performs better on a validation dataset (Classification)") { | ||
// Set numIterations large enough so that it stops early. | ||
val numIterations = 20 | ||
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) | ||
val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) | ||
|
||
val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2, | ||
categoricalFeaturesInfo = Map.empty) | ||
val boostingStrategy = | ||
new BoostingStrategy(treeStrategy, LogLoss, numIterations, validationTol = 0.0) | ||
|
||
// Test that it stops early. | ||
val gbtValidate = new GradientBoostedTrees(boostingStrategy). | ||
runWithValidation(trainRdd, validateRdd) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please put period (.) on line with runWithValidation:
|
||
assert(gbtValidate.numTrees !== numIterations) | ||
|
||
// Remap labels to {-1, 1} | ||
val remappedInput = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) | ||
|
||
// The error checked for internally in the GradientBoostedTrees is based on Regression. | ||
// Hence for the validation model, the Classification error need not be strictly less than | ||
// that obtained without validation. | ||
val gbtValidateRegressor = new GradientBoostedTreesModel( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, I misunderstood this the first time you asked about it. It's weird to create a regression model and test using LogLoss. I would test on validateRdd, not on trainRdd. That's really all we need to check. And it should let you keep the model a Classification model. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have addressed all your comment except this. This test fails if I don't make this explicit conversion. I think what happens is the number of true labels classified is the same whether or not I run with validation in because of the dataset that is being tested here. i.e when I run without validation, there might be an increase in the validation error but there is no change in the number of labels that are predicted correctly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and I'm not sure it's that weird, because that is what is being done internally :P , unless you have other ideas to test this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I got confused about which dataset remappedInput was from. In that case, I think it's just a flaky test. I think it would be sufficient to check for error <= instead of <, especially since you are already checking that it stops early. |
||
Regression, gbtValidate.trees, gbtValidate.treeWeights) | ||
val errorWithValidation = LogLoss.computeError(gbtValidateRegressor, remappedInput) | ||
|
||
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) | ||
val gbtRegressor = new GradientBoostedTreesModel(Regression, gbt.trees, gbt.treeWeights) | ||
val errorWithoutValidation = LogLoss.computeError(gbtRegressor, remappedInput) | ||
|
||
assert(errorWithValidation < errorWithoutValidation) | ||
} | ||
|
||
} | ||
|
||
private object GradientBoostedTreesSuite { | ||
|
@@ -166,4 +223,6 @@ private object GradientBoostedTreesSuite { | |
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) | ||
|
||
val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) | ||
val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120) | ||
val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd keep the backticks:
validationTol
(same for BoostingStrategy). But the asterisks for bold are not needed.