diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 1af465f5d1df4..902ba01b08ebb 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -434,7 +434,7 @@ training. The method runWithValidation has been provided to make use of this opt first one being the training dataset and the second being the validation dataset. The training is stopped when the improvement in the validation error is not more than a certain tolerance -(supplied by the validationTol argument in BoostingStrategy). In practice, the validation error +(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error decreases initially and later increases. There might be cases in which the validation error does not change monotonically, and the user is advised to set a large enough negative tolerance and examine the validation curve to to tune the number of iterations. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 65459707b9188..b4466ff40937f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -236,7 +236,7 @@ object GradientBoostedTrees extends Logging { boostingStrategy.treeStrategy.algo, baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM)) - } else if (currentValidateError < bestValidateError){ + } else if (currentValidateError < bestValidateError) { bestValidateError = currentValidateError bestM = m + 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 35b479fac5280..664c8df019233 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -38,7 +38,6 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * validation input between two iterations is less than the validationTol * then stop. Ignored when [[run]] is used. */ - @Experimental case class BoostingStrategy( // Required boosting parameters diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index b4732a381f54a..b437aeaaf0547 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -159,62 +159,39 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } } - test("runWithValidation performs better on a validation dataset (Regression)") { + test("runWithValidation stops early and performs better on a validation dataset") { // Set numIterations large enough so that it stops early. val numIterations = 20 val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - Array(SquaredError, AbsoluteError).foreach { error => - val boostingStrategy = - new BoostingStrategy(treeStrategy, error, numIterations, validationTol = 0.0) - - val gbtValidate = new GradientBoostedTrees(boostingStrategy). - runWithValidation(trainRdd, validateRdd) - assert(gbtValidate.numTrees !== numIterations) - - val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) - val errorWithoutValidation = error.computeError(gbt, validateRdd) - val errorWithValidation = error.computeError(gbtValidate, validateRdd) - assert(errorWithValidation < errorWithoutValidation) + val algos = Array(Regression, Regression, Classification) + val losses = Array(SquaredError, AbsoluteError, LogLoss) + (algos zip losses) map { + case (algo, loss) => { + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + assert(gbtValidate.numTrees !== numIterations) + + // Test that it performs better on the validation dataset. + val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) + } + } + assert(errorWithValidation <= errorWithoutValidation) + } } } - test("runWithValidation performs better on a validation dataset (Classification)") { - // Set numIterations large enough so that it stops early. - val numIterations = 20 - val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) - val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) - - val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - val boostingStrategy = - new BoostingStrategy(treeStrategy, LogLoss, numIterations, validationTol = 0.0) - - // Test that it stops early. - val gbtValidate = new GradientBoostedTrees(boostingStrategy). - runWithValidation(trainRdd, validateRdd) - assert(gbtValidate.numTrees !== numIterations) - - // Remap labels to {-1, 1} - val remappedInput = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - - // The error checked for internally in the GradientBoostedTrees is based on Regression. - // Hence for the validation model, the Classification error need not be strictly less than - // that done with validation. - val gbtValidateRegressor = new GradientBoostedTreesModel( - Regression, gbtValidate.trees, gbtValidate.treeWeights) - val errorWithValidation = LogLoss.computeError(gbtValidateRegressor, remappedInput) - - val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) - val gbtRegressor = new GradientBoostedTreesModel(Regression, gbt.trees, gbt.treeWeights) - val errorWithoutValidation = LogLoss.computeError(gbtRegressor, remappedInput) - - assert(errorWithValidation < errorWithoutValidation) - } - } private object GradientBoostedTreesSuite {