[SPARK-6025] Add helper method evaluateEachIteration to extract learning curve
MechCoder committed Mar 5, 2015
1 parent 1f1fccc commit dbda033
Showing 7 changed files with 125 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/mllib-ensembles.md
@@ -464,8 +464,8 @@ first one being the training dataset and the second being the validation dataset
The training is stopped when the improvement in the validation error is not more than a certain tolerance
(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error
decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
-and the user is advised to set a large enough negative tolerance and examine the validation curve to tune the number of
-iterations.
+and the user is advised to set a large enough negative tolerance and examine the validation curve using `evaluateEachIteration`
+(which gives the error or loss per iteration) to tune the number of iterations.

### Examples

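To make the doc change above concrete, here is a minimal usage sketch (not part of the patch) of the workflow it describes: train a model, then call `evaluateEachIteration` on held-out data to choose the number of iterations. The dataset names and parameter values are placeholders, and the sketch follows this patch's API, in which `evaluateEachIteration` lives on the `GradientBoostedTrees` runner.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.rdd.RDD

// Hypothetical helper: trainingData and validationData are RDD[LabeledPoint] prepared elsewhere.
def tuneNumIterations(
    trainingData: RDD[LabeledPoint],
    validationData: RDD[LabeledPoint]): Int = {
  val boostingStrategy = BoostingStrategy.defaultParams("Regression")
  boostingStrategy.numIterations = 100 // deliberately large; the curve tells us where to stop

  // In this patch, evaluateEachIteration is defined on the GradientBoostedTrees runner,
  // which caches the fitted trees and their weights once run() has been called.
  val gbt = new GradientBoostedTrees(boostingStrategy)
  val model = gbt.run(trainingData) // model can still be used for prediction as usual

  // errors(i) is the validation loss of the ensemble built from the first i + 1 trees.
  val errors = gbt.evaluateEachIteration(validationData)
  errors.indexOf(errors.min) + 1 // number of trees with the lowest validation loss
}
```

Inspecting this curve is exactly what the new doc text suggests when a single `validationTol` threshold is not enough.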
mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -52,14 +53,18 @@ import org.apache.spark.storage.StorageLevel
class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
extends Serializable with Logging {

private val numIterations = boostingStrategy.numIterations
private var baseLearners = new Array[DecisionTreeModel](numIterations)
private var baseLearnerWeights = new Array[Double](numIterations)

/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return a gradient boosted trees model that can be used for prediction
*/
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
-algo match {
+val fitGradientBoostingModel = algo match {
case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
@@ -69,6 +74,42 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
baseLearners = fitGradientBoostingModel.trees
baseLearnerWeights = fitGradientBoostingModel.treeWeights
fitGradientBoostingModel
}

/**
* Method to compute error or loss for every iteration of gradient boosting.
* @param data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @param loss: evaluation metric that defaults to boostingStrategy.loss
* @return an array with index i having the losses or errors for the ensemble
* containing trees 1 to i + 1
*/
def evaluateEachIteration(
data: RDD[LabeledPoint],
loss: Loss = boostingStrategy.loss) : Array[Double] = {

val algo = boostingStrategy.treeStrategy.algo
val remappedData = algo match {
case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
case _ => data
}
val initialTree = baseLearners(0)
val evaluationArray = Array.fill(numIterations)(0.0)

// Initial weight is 1.0
var predictionRDD = remappedData.map(i => initialTree.predict(i.features))
evaluationArray(0) = loss.computeError(remappedData, predictionRDD)

(1 until numIterations).map {nTree =>
predictionRDD = (remappedData zip predictionRDD) map {
case (point, pred) =>
pred + baseLearners(nTree).predict(point.features) * baseLearnerWeights(nTree)
}
evaluationArray(nTree) = loss.computeError(remappedData, predictionRDD)
}
evaluationArray
}

/**
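The core of `evaluateEachIteration` above is the incremental update F_m(x) = F_{m-1}(x) + w_m * h_m(x): the predictions for the m-tree ensemble are built from the (m-1)-tree predictions, so each iteration costs one extra tree evaluation per point instead of re-scoring the whole ensemble. A Spark-free sketch of that bookkeeping (not from the patch), with hypothetical "trees" represented as plain functions:

```scala
// Spark-free illustration of the incremental bookkeeping used by evaluateEachIteration:
// the stage-m prediction reuses the stage-(m - 1) prediction instead of re-scoring
// all m trees for every point.
object IncrementalEvalSketch {
  // Hypothetical "trees": each maps a single feature value to a prediction.
  val trees: Array[Double => Double] = Array(x => x, x => 0.5 * x, x => -0.1 * x)
  val weights: Array[Double] = Array(1.0, 0.8, 0.8) // first weight is 1.0, as in the diff

  def meanSquaredError(labels: Array[Double], preds: Array[Double]): Double =
    labels.zip(preds).map { case (y, p) => (y - p) * (y - p) }.sum / labels.length

  def main(args: Array[String]): Unit = {
    val xs = Array(1.0, 2.0, 3.0)
    val ys = Array(1.1, 2.3, 2.8)

    val errors = new Array[Double](trees.length)
    var preds = xs.map(trees(0))          // iteration 0: first tree only
    errors(0) = meanSquaredError(ys, preds)

    for (m <- 1 until trees.length) {
      preds = xs.zip(preds).map { case (x, p) => p + weights(m) * trees(m)(x) }
      errors(m) = meanSquaredError(ys, preds)
    }
    println(errors.mkString(", "))        // loss of the 1-, 2-, and 3-tree ensembles
  }
}
```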
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -61,4 +61,23 @@ object AbsoluteError extends Loss {
math.abs(err)
}.mean()
}

/**
* Method to calculate loss when the predictions are already known.
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
* predicted values from previously fit trees.
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @param prediction: RDD[Double] of predicted labels.
* @return Mean absolute error of model on data
*/
override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
val errorAcrossSamples = (data zip prediction) map {
case (yTrue, yPred) => {
val err = yTrue.label - yPred
math.abs(err)
}
}
errorAcrossSamples.mean()
}

}
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -66,4 +66,23 @@ object LogLoss extends Loss {
2.0 * MLUtils.log1pExp(-margin)
}.mean()
}

/**
* Method to calculate loss when the predictions are already known.
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
* predicted values from previously fit trees.
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @param prediction: RDD[Double] of predicted labels.
* @return Mean log loss of model on data
*/
override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
val errorAcrossSamples = (data zip prediction) map {
case (yTrue, yPred) =>
val margin = 2.0 * yTrue.label * yPred
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
}
errorAcrossSamples.mean()
}

}
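An aside on the "more numerically stable" comment above (not part of the patch): computing log(1 + exp(-margin)) directly overflows once exp(-margin) exceeds the double range, while the log1p-based form does not. Since MLUtils.log1pExp may not be public outside MLlib, the sketch writes out the same formula locally:

```scala
// Local re-implementation of the log1pExp formula, for illustration only.
def log1pExp(x: Double): Double =
  if (x > 0) x + math.log1p(math.exp(-x)) else math.log1p(math.exp(x))

val margin = -1000.0                               // a badly mis-classified point
val naive = 2.0 * math.log(1 + math.exp(-margin))  // exp(1000) overflows: Infinity
val stable = 2.0 * log1pExp(-margin)               // ≈ 2000.0, the intended log loss
```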
10 changes: 10 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -49,4 +49,14 @@ trait Loss extends Serializable {
*/
def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double

/**
* Method to calculate loss when the predictions are already known.
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
* predicted values from previously fit trees.
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @param prediction: RDD[Double] of predicted labels.
* @return Measure of model error on data
*/
def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]) : Double

}
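To show how the two `computeError` overloads relate, here is a minimal sketch (not part of the patch) using `AbsoluteError`; the fitted `model` and the `data` RDD are assumed to exist already:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.loss.AbsoluteError
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.rdd.RDD

// Assumed to exist already: a fitted ensemble and an evaluation dataset.
def compareOverloads(model: GradientBoostedTreesModel, data: RDD[LabeledPoint]): Unit = {
  // Overload 1: the loss re-predicts every point through the full ensemble.
  val errFromModel = AbsoluteError.computeError(model, data)

  // Overload 2: predictions are computed once and passed in. evaluateEachIteration
  // relies on this form so that iteration m can reuse the predictions of iteration m - 1.
  val predictions = data.map(p => model.predict(p.features))
  val errFromPredictions = AbsoluteError.computeError(data, predictions)

  // For the same model and data the two agree (up to floating point).
  assert(math.abs(errFromModel - errFromPredictions) < 1e-9)
}
```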
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -61,4 +61,22 @@ object SquaredError extends Loss {
err * err
}.mean()
}

/**
* Method to calculate loss when the predictions are already known.
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
* predicted values from previously fit trees.
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @param prediction: RDD[Double] of predicted labels.
* @return Mean squared error of model on data
*/
override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
val errorAcrossSamples = (data zip prediction) map {
case (yTrue, yPred) =>
val err = yPred - yTrue.label
err * err
}
errorAcrossSamples.mean()
}

}
mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -175,10 +175,12 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
.runWithValidation(trainRdd, validateRdd)
-assert(gbtValidate.numTrees !== numIterations)
+val numTrees = gbtValidate.numTrees
+assert(numTrees !== numIterations)

// Test that it performs better on the validation dataset.
-val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+val gbtModel = new GradientBoostedTrees(boostingStrategy)
+val gbt = gbtModel.run(trainRdd)
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -188,6 +190,17 @@ }
}
}
assert(errorWithValidation <= errorWithoutValidation)

// Test that results from evaluateEachIteration comply with runWithValidation.
// Note that convergenceTol is set to 0.0
val evaluationArray = gbtModel.evaluateEachIteration(validateRdd)
assert(evaluationArray.length === numIterations)
assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
var i = 1
while (i < numTrees) {
assert(evaluationArray(i) < evaluationArray(i - 1))
i += 1
}
}
}
}
