Commit
Combine regression and classification tests into a single one
MechCoder committed Feb 24, 2015
1 parent e4d799b commit 1bb21d4
Showing 4 changed files with 27 additions and 51 deletions.
2 changes: 1 addition & 1 deletion docs/mllib-ensembles.md
@@ -434,7 +434,7 @@ training. The method runWithValidation has been provided to make use of this opt
first one being the training dataset and the second being the validation dataset.

The training is stopped when the improvement in the validation error is not more than a certain tolerance
(supplied by the validationTol argument in BoostingStrategy). In practice, the validation error
(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error
decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
and the user is advised to set a large enough negative tolerance and examine the validation curve to tune the number of
iterations.
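
As an illustration of the documented behavior (not part of this diff), here is a minimal sketch of calling `runWithValidation` with a `validationTol`, assuming an active SparkContext `sc`; the dataset and parameter values are invented for the example:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo.Regression
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.SquaredError

// Toy datasets; any pair of RDD[LabeledPoint] works here.
val trainingData = sc.parallelize((1 to 100).map(i =>
  LabeledPoint(i.toDouble, Vectors.dense(i.toDouble))), 2)
val validationData = sc.parallelize((101 to 120).map(i =>
  LabeledPoint(i.toDouble, Vectors.dense(i.toDouble))), 2)

val numIterations = 20
val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
  categoricalFeaturesInfo = Map.empty)
// validationTol is the early-stopping tolerance described above.
val boostingStrategy =
  new BoostingStrategy(treeStrategy, SquaredError, numIterations, validationTol = 1e-5)

// Training data first, validation data second; training stops once the
// improvement in validation error drops below validationTol.
val model = new GradientBoostedTrees(boostingStrategy)
  .runWithValidation(trainingData, validationData)
```

When early stopping kicks in, the returned model can contain fewer trees than `numIterations`.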
@@ -236,7 +236,7 @@ object GradientBoostedTrees extends Logging {
boostingStrategy.treeStrategy.algo,
baseLearners.slice(0, bestM),
baseLearnerWeights.slice(0, bestM))
} else if (currentValidateError < bestValidateError){
} else if (currentValidateError < bestValidateError) {
bestValidateError = currentValidateError
bestM = m + 1
}
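
For context on the hunk above, the early-stopping bookkeeping tracks the best validation error seen so far together with the iteration that produced it, and stops once the error no longer improves by more than `validationTol`. Below is a simplified standalone sketch of that selection rule (not the MLlib source); the variable names mirror the ones in the diff:

```scala
// Given per-iteration validation errors, pick how many trees to keep.
def selectBestM(validationErrors: Seq[Double], validationTol: Double): Int = {
  var bestValidateError = Double.MaxValue
  var bestM = 0
  var m = 0
  var stop = false
  while (m < validationErrors.length && !stop) {
    val currentValidateError = validationErrors(m)
    if (bestValidateError - currentValidateError < validationTol) {
      stop = true // improvement fell below the tolerance: keep only bestM trees
    } else if (currentValidateError < bestValidateError) {
      bestValidateError = currentValidateError
      bestM = m + 1
    }
    m += 1
  }
  bestM
}

// Errors improve for four iterations, then flatten out: bestM == 4.
selectBestM(Seq(0.9, 0.6, 0.4, 0.3, 0.3, 0.31), validationTol = 1e-5)
```

In the actual implementation this bookkeeping happens inside the boosting loop, and the early-stopped model keeps only `baseLearners.slice(0, bestM)` together with the matching weights.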
@@ -38,7 +38,6 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* validation input between two iterations is less than the validationTol
* then stop. Ignored when [[run]] is used.
*/

@Experimental
case class BoostingStrategy(
// Required boosting parameters
@@ -159,62 +159,39 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
}
}

test("runWithValidation performs better on a validation dataset (Regression)") {
test("runWithValidation stops early and performs better on a validation dataset") {
// Set numIterations large enough so that it stops early.
val numIterations = 20
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)

val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
categoricalFeaturesInfo = Map.empty)
Array(SquaredError, AbsoluteError).foreach { error =>
val boostingStrategy =
new BoostingStrategy(treeStrategy, error, numIterations, validationTol = 0.0)

val gbtValidate = new GradientBoostedTrees(boostingStrategy).
runWithValidation(trainRdd, validateRdd)
assert(gbtValidate.numTrees !== numIterations)

val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
val errorWithoutValidation = error.computeError(gbt, validateRdd)
val errorWithValidation = error.computeError(gbtValidate, validateRdd)
assert(errorWithValidation < errorWithoutValidation)
val algos = Array(Regression, Regression, Classification)
val losses = Array(SquaredError, AbsoluteError, LogLoss)
(algos zip losses) map {
case (algo, loss) => {
val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
categoricalFeaturesInfo = Map.empty)
val boostingStrategy =
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
.runWithValidation(trainRdd, validateRdd)
assert(gbtValidate.numTrees !== numIterations)

// Test that it performs better on the validation dataset.
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
(loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
} else {
(loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
}
}
assert(errorWithValidation <= errorWithoutValidation)
}
}
}

test("runWithValidation performs better on a validation dataset (Classification)") {
// Set numIterations large enough so that it stops early.
val numIterations = 20
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)

val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2,
categoricalFeaturesInfo = Map.empty)
val boostingStrategy =
new BoostingStrategy(treeStrategy, LogLoss, numIterations, validationTol = 0.0)

// Test that it stops early.
val gbtValidate = new GradientBoostedTrees(boostingStrategy).
runWithValidation(trainRdd, validateRdd)
assert(gbtValidate.numTrees !== numIterations)

// Remap labels to {-1, 1}
val remappedInput = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))

// The error checked for internally in the GradientBoostedTrees is based on Regression.
// Hence for the validation model, the Classification error need not be strictly less than
// that done with validation.
val gbtValidateRegressor = new GradientBoostedTreesModel(
Regression, gbtValidate.trees, gbtValidate.treeWeights)
val errorWithValidation = LogLoss.computeError(gbtValidateRegressor, remappedInput)

val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
val gbtRegressor = new GradientBoostedTreesModel(Regression, gbt.trees, gbt.treeWeights)
val errorWithoutValidation = LogLoss.computeError(gbtRegressor, remappedInput)

assert(errorWithValidation < errorWithoutValidation)
}

}

private object GradientBoostedTreesSuite {
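
One detail worth noting in the combined test above: for the `Classification` case, the validation labels are remapped from {0, 1} to {-1, 1} before the `LogLoss` error is computed, just as the original classification-only test did. A minimal sketch of that remapping (the helper name is invented):

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Hypothetical helper mirroring the remapping in the test:
// a 0.0 label becomes -1.0 and a 1.0 label stays 1.0.
def remapToPlusMinusOne(data: RDD[LabeledPoint]): RDD[LabeledPoint] =
  data.map(x => new LabeledPoint(2 * x.label - 1, x.features))
```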

0 comments on commit 1bb21d4
