[SPARK-5972] [MLlib] Cache residuals and gradient in GBT during training and validation #5330

Closed · wants to merge 7 commits

Changes from all commits
mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

@@ -157,7 +157,6 @@ object GradientBoostedTrees extends Logging {
       validationInput: RDD[LabeledPoint],
       boostingStrategy: BoostingStrategy,
       validate: Boolean): GradientBoostedTreesModel = {
-
     val timer = new TimeTracker()
     timer.start("total")
     timer.start("init")
@@ -192,20 +191,29 @@
     // Initialize tree
     timer.start("building tree 0")
     val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+    val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
-    baseLearnerWeights(0) = 1.0
-    val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
-    logDebug("error of gbt = " + loss.computeError(startingModel, input))
+    baseLearnerWeights(0) = firstTreeWeight
+    val startingModel = new GradientBoostedTreesModel(
+      Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1))
+
+    var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+    logDebug("error of gbt = " + predError.values.mean())
+
     // Note: A model of type regression is used since we require raw prediction
     timer.stop("building tree 0")
 
-    var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0
+    var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+      computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+    var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
     var bestM = 1
 
-    // psuedo-residual for second iteration
-    data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
-      point.features))
+    // pseudo-residual for second iteration
+    data = predError.zip(input).map { case ((pred, _), point) =>
+      LabeledPoint(-loss.gradient(pred, point.label), point.features)
+    }
 
     var m = 1
     while (m < numIterations) {
       timer.start(s"building tree $m")
@@ -222,15 +230,22 @@
       baseLearnerWeights(m) = learningRate
       // Note: A model of type regression is used since we require raw prediction
       val partialModel = new GradientBoostedTreesModel(
-        Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
-      logDebug("error of gbt = " + loss.computeError(partialModel, input))
+        Regression, baseLearners.slice(0, m + 1),
+        baseLearnerWeights.slice(0, m + 1))
+
+      predError = GradientBoostedTreesModel.updatePredictionError(
+        input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+      logDebug("error of gbt = " + predError.values.mean())
+
       if (validate) {
         // Stop training early if
         // 1. Reduction in error is less than the validationTol or
         // 2. If the error increases, that is if the model is overfit.
         // We want the model returned corresponding to the best validation error.
-        val currentValidateError = loss.computeError(partialModel, validationInput)
+
+        validatePredError = GradientBoostedTreesModel.updatePredictionError(
+          validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+        val currentValidateError = validatePredError.values.mean()
         if (bestValidateError - currentValidateError < validationTol) {
           return new GradientBoostedTreesModel(
             boostingStrategy.treeStrategy.algo,
@@ -242,8 +257,9 @@
         }
       }
       // Update data with pseudo-residuals
-      data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
-        point.features))
+      data = predError.zip(input).map { case ((pred, _), point) =>
+        LabeledPoint(-loss.gradient(pred, point.label), point.features)
+      }
       m += 1
     }
 
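The hunks above wire the cached (prediction, error) RDDs through the training loop, so pseudo-residuals and validation errors come from one incremental update per iteration instead of re-scoring the growing ensemble. A usage sketch of the validation path this code implements, not part of the PR: trainWithEarlyStopping is a hypothetical helper, and runWithValidation is assumed to be the public entry point that sets validate = true.

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.rdd.RDD

// Train with a held-out validation set so the early-stopping check above can
// return the model corresponding to the best validation error.
def trainWithEarlyStopping(
    trainData: RDD[LabeledPoint],
    validationData: RDD[LabeledPoint]): GradientBoostedTreesModel = {
  val boostingStrategy = BoostingStrategy.defaultParams("Regression")
  boostingStrategy.numIterations = 100
  new GradientBoostedTrees(boostingStrategy).runWithValidation(trainData, validationData)
}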
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala

@@ -37,14 +37,12 @@ object AbsoluteError extends Loss {
    * Method to calculate the gradients for the gradient boosting calculation for least
    * absolute error calculation.
    * The gradient with respect to F(x) is: sign(F(x) - y)
-   * @param model Ensemble model
-   * @param point Instance of the training dataset
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient
    */
-  override def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double = {
-    if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
+  override def gradient(prediction: Double, label: Double): Double = {
+    if (label - prediction < 0) 1.0 else -1.0
   }
 
   override def computeError(prediction: Double, label: Double): Double = {
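A quick sanity check of the new stateless signature, a sketch with illustrative values assuming a REPL with Spark on the classpath:

import org.apache.spark.mllib.tree.loss.AbsoluteError

// gradient(F(x), y) = sign(F(x) - y): negative when the model under-predicts,
// positive when it over-predicts, so -gradient steps toward the label.
val under = AbsoluteError.gradient(0.5, 2.0) // F(x) < y => -1.0
val over = AbsoluteError.gradient(3.0, 2.0)  // F(x) > y =>  1.0
assert(under == -1.0 && over == 1.0)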
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

@@ -39,15 +39,12 @@ object LogLoss extends Loss {
    * Method to calculate the loss gradients for the gradient boosting calculation for binary
    * classification
    * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x)))
-   * @param model Ensemble model
-   * @param point Instance of the training dataset
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient
    */
-  override def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double = {
-    val prediction = model.predict(point.features)
-    - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
+  override def gradient(prediction: Double, label: Double): Double = {
+    - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
   }
 
   override def computeError(prediction: Double, label: Double): Double = {
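The closed form can be checked numerically: for labels remapped to {-1, +1}, MLlib's GBT log loss is L(y, F) = 2 log(1 + exp(-2 y F)), and differentiating in F gives the formula above. A sketch with illustrative values:

import org.apache.spark.mllib.tree.loss.LogLoss

val y = 1.0   // label remapped to {-1, +1}
val f = 0.3   // raw ensemble prediction F(x)
val eps = 1e-6
def l(pred: Double): Double = 2.0 * math.log1p(math.exp(-2.0 * y * pred))
// A central finite difference of the loss should match the closed form.
val numericalGradient = (l(f + eps) - l(f - eps)) / (2.0 * eps)
assert(math.abs(LogLoss.gradient(f, y) - numericalGradient) < 1e-6)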
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala

@@ -31,13 +31,11 @@ trait Loss extends Serializable {
 
   /**
    * Method to calculate the gradients for the gradient boosting calculation.
-   * @param model Model of the weak learner.
-   * @param point Instance of the training dataset.
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient.
    */
-  def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double
+  def gradient(prediction: Double, label: Double): Double
 
   /**
    * Method to calculate error of the base learner for the gradient boosting calculation.
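Because the trait now consumes only (prediction, label) pairs, a caller can relabel training points with pseudo-residuals straight from a cached prediction RDD, with no ensemble re-evaluation. A sketch of the pattern used by the training loop above; relabelWithPseudoResiduals is a hypothetical helper name:

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.rdd.RDD

// Pair each point with its cached prediction and relabel it with the
// pseudo-residual -gradient(prediction, label) for the next boosting step.
def relabelWithPseudoResiduals(
    input: RDD[LabeledPoint],
    predError: RDD[(Double, Double)],
    loss: Loss): RDD[LabeledPoint] = {
  predError.zip(input).map { case ((pred, _), point) =>
    LabeledPoint(-loss.gradient(pred, point.label), point.features)
  }
}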
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala

@@ -37,14 +37,12 @@ object SquaredError extends Loss {
    * Method to calculate the gradients for the gradient boosting calculation for least
    * squares error calculation.
    * The gradient with respect to F(x) is: - 2 (y - F(x))
-   * @param model Ensemble model
-   * @param point Instance of the training dataset
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient
    */
-  override def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double = {
-    2.0 * (model.predict(point.features) - point.label)
+  override def gradient(prediction: Double, label: Double): Double = {
+    2.0 * (prediction - label)
   }
 
   override def computeError(prediction: Double, label: Double): Double = {
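A quick check with illustrative values, a sketch: the pseudo-residual -gradient is 2 (y - F(x)), proportional to the ordinary residual.

import org.apache.spark.mllib.tree.loss.SquaredError

val pred = 1.5
val label = 2.0
// gradient = 2 * (F(x) - y), so -gradient recovers twice the residual.
assert(-SquaredError.gradient(pred, label) == 2.0 * (label - pred))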
mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

@@ -130,42 +130,87 @@ class GradientBoostedTreesModel(
 
     val numIterations = trees.length
     val evaluationArray = Array.fill(numIterations)(0.0)
+    val localTreeWeights = treeWeights
 
-    var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
-      val pred = treeWeights(0) * trees(0).predict(i.features)
-      val error = loss.computeError(pred, i.label)
-      (pred, error)
-    }
+    var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError(
+      remappedData, localTreeWeights(0), trees(0), loss)
+
     evaluationArray(0) = predictionAndError.values.mean()
 
     // Avoid the model being copied across numIterations.
     val broadcastTrees = sc.broadcast(trees)
-    val broadcastWeights = sc.broadcast(treeWeights)
 
     (1 until numIterations).map { nTree =>

Review comment (Member): "map" --> "foreach"

       predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
         val currentTree = broadcastTrees.value(nTree)
-        val currentTreeWeight = broadcastWeights.value(nTree)
-        iter.map {
-          case (point, (pred, error)) => {
-            val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
-            val newError = loss.computeError(newPred, point.label)
-            (newPred, newError)
-          }
+        val currentTreeWeight = localTreeWeights(nTree)
+        iter.map { case (point, (pred, error)) =>
+          val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
+          val newError = loss.computeError(newPred, point.label)
+          (newPred, newError)
         }
       }
       evaluationArray(nTree) = predictionAndError.values.mean()
     }
 
     broadcastTrees.unpersist()
-    broadcastWeights.unpersist()
     evaluationArray
   }
 
 }
 
 object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
 
+  /**
+   * Compute the initial predictions and errors for a dataset for the first
+   * iteration of gradient boosting.
+   * @param data: training data.
+   * @param initTreeWeight: learning rate assigned to the first tree.
+   * @param initTree: first DecisionTreeModel.
+   * @param loss: evaluation metric.
+   * @return an RDD with each element being a zip of the prediction and error
+   *         corresponding to every sample.
+   */
+  def computeInitialPredictionAndError(
+      data: RDD[LabeledPoint],
+      initTreeWeight: Double,
+      initTree: DecisionTreeModel,
+      loss: Loss): RDD[(Double, Double)] = {
+    data.map { lp =>
+      val pred = initTreeWeight * initTree.predict(lp.features)
+      val error = loss.computeError(pred, lp.label)
+      (pred, error)
+    }
+  }
+
+  /**
+   * Update a zipped predictionError RDD
+   * (as obtained with computeInitialPredictionAndError)
+   * @param data: training data.
+   * @param predictionAndError: predictionError RDD
+   * @param treeWeight: Learning rate.
+   * @param tree: Tree using which the prediction and error should be updated.
+   * @param loss: evaluation metric.
+   * @return an RDD with each element being a zip of the prediction and error
+   *         corresponding to each sample.
+   */
+  def updatePredictionError(
+      data: RDD[LabeledPoint],
+      predictionAndError: RDD[(Double, Double)],
+      treeWeight: Double,
+      tree: DecisionTreeModel,
+      loss: Loss): RDD[(Double, Double)] = {
+
+    val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
+      iter.map {
+        case (lp, (pred, error)) => {
+          val newPred = pred + tree.predict(lp.features) * treeWeight
+          val newError = loss.computeError(newPred, lp.label)
+          (newPred, newError)
+        }
+      }
+    }
+    newPredError
+  }
+
   override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
     val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
     val classNameV1_0 = SaveLoadV1_0.thisClassName
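With the helpers above, evaluateEachIteration makes one pass over the data per tree, updating cached predictions instead of re-scoring the full ensemble at every iteration. A usage sketch; bestNumTrees is a hypothetical helper, and the model and data are assumed to exist:

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.loss.SquaredError
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.rdd.RDD

// Pick the iteration count with the lowest error on held-out data.
def bestNumTrees(model: GradientBoostedTreesModel, data: RDD[LabeledPoint]): Int = {
  val errors = model.evaluateEachIteration(data, SquaredError)
  errors.indexOf(errors.min) + 1 // errors(i) is the error using trees 0..i
}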