Commit bc99ac6
Refactor the method and stuff
MechCoder committed Mar 10, 2015
1 parent dbda033 commit bc99ac6
Showing 7 changed files with 75 additions and 81 deletions.

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.tree.configuration.BoostingStrategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.impl.TimeTracker
 import org.apache.spark.mllib.tree.impurity.Variance
-import org.apache.spark.mllib.tree.loss.Loss
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -53,18 +52,14 @@ import org.apache.spark.storage.StorageLevel
 class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
   extends Serializable with Logging {

-  private val numIterations = boostingStrategy.numIterations
-  private var baseLearners = new Array[DecisionTreeModel](numIterations)
-  private var baseLearnerWeights = new Array[Double](numIterations)
-
   /**
    * Method to train a gradient boosting model
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return a gradient boosted trees model that can be used for prediction
    */
   def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
     val algo = boostingStrategy.treeStrategy.algo
-    val fitGradientBoostingModel = algo match {
+    algo match {
       case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
       case Classification =>
         // Map labels to -1, +1 so binary classification can be treated as regression.
@@ -74,42 +69,6 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
     }
-    baseLearners = fitGradientBoostingModel.trees
-    baseLearnerWeights = fitGradientBoostingModel.treeWeights
-    fitGradientBoostingModel
   }

-  /**
-   * Method to compute error or loss for every iteration of gradient boosting.
-   * @param data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
-   * @param loss: evaluation metric that defaults to boostingStrategy.loss
-   * @return an array with index i having the losses or errors for the ensemble
-   *         containing trees 1 to i + 1
-   */
-  def evaluateEachIteration(
-      data: RDD[LabeledPoint],
-      loss: Loss = boostingStrategy.loss): Array[Double] = {
-
-    val algo = boostingStrategy.treeStrategy.algo
-    val remappedData = algo match {
-      case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
-      case _ => data
-    }
-    val initialTree = baseLearners(0)
-    val evaluationArray = Array.fill(numIterations)(0.0)
-
-    // Initial weight is 1.0
-    var predictionRDD = remappedData.map(i => initialTree.predict(i.features))
-    evaluationArray(0) = loss.computeError(remappedData, predictionRDD)
-
-    (1 until numIterations).map { nTree =>
-      predictionRDD = (remappedData zip predictionRDD) map {
-        case (point, pred) =>
-          pred + baseLearners(nTree).predict(point.features) * baseLearnerWeights(nTree)
-      }
-      evaluationArray(nTree) = loss.computeError(remappedData, predictionRDD)
-    }
-    evaluationArray
-  }

   /**
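For orientation, a minimal sketch of how the refactored run method is driven
from user code; the BoostingStrategy setup and the trainingData RDD below are
illustrative assumptions, not part of this commit:

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.rdd.RDD

    // trainingData: RDD[LabeledPoint] is assumed to be prepared elsewhere.
    def train(trainingData: RDD[LabeledPoint]) = {
      val boostingStrategy = BoostingStrategy.defaultParams("Regression")
      // run now returns the fitted model directly; the mutable baseLearners
      // and baseLearnerWeights fields it used to populate are gone.
      new GradientBoostedTrees(boostingStrategy).run(trainingData)
    }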

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -66,18 +66,13 @@ object AbsoluteError extends Loss {
    * Method to calculate loss when the predictions are already known.
    * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
    * predicted values from previously fit trees.
-   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @param prediction: RDD[Double] of predicted labels.
-   * @return Mean absolute error of model on data
+   * @param datum: LabeledPoint
+   * @param prediction: Predicted label.
+   * @return Absolute error of model on the given datapoint.
    */
-  override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
-    val errorAcrossSamples = (data zip prediction) map {
-      case (yTrue, yPred) =>
-        val err = yTrue.label - yPred
-        math.abs(err)
-    }
-    errorAcrossSamples.mean()
+  override def computeError(datum: LabeledPoint, prediction: Double): Double = {
+    val err = datum.label - prediction
+    math.abs(err)
   }

 }

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -71,18 +71,14 @@ object LogLoss extends Loss {
    * Method to calculate loss when the predictions are already known.
    * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
    * predicted values from previously fit trees.
-   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @param prediction: RDD[Double] of predicted labels.
-   * @return Mean log loss of model on data
+   * @param datum: LabeledPoint
+   * @param prediction: Predicted label.
+   * @return Log loss of model on the datapoint.
    */
-  override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
-    val errorAcrossSamples = (data zip prediction) map {
-      case (yTrue, yPred) =>
-        val margin = 2.0 * yTrue.label * yPred
-        // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
-        2.0 * MLUtils.log1pExp(-margin)
-    }
-    errorAcrossSamples.mean()
+  override def computeError(datum: LabeledPoint, prediction: Double): Double = {
+    val margin = 2.0 * datum.label * prediction
+    // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
+    2.0 * MLUtils.log1pExp(-margin)
   }

 }
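The comment above relies on MLUtils.log1pExp for numerical stability. As a
sketch of the standard trick (an illustration, not MLlib's actual
implementation): log(1 + exp(x)) overflows through exp for large positive x,
so the argument is shifted first:

    // log(1 + exp(x)) computed without overflow:
    // for x > 0, log(1 + exp(x)) = x + log(1 + exp(-x)), and exp(-x) <= 1.
    def log1pExp(x: Double): Double =
      if (x > 0) x + math.log1p(math.exp(-x))
      else math.log1p(math.exp(x))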

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -53,10 +53,10 @@ trait Loss extends Serializable {
    * Method to calculate loss when the predictions are already known.
    * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
    * predicted values from previously fit trees.
-   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @param prediction: RDD[Double] of predicted labels.
-   * @return Measure of model error on data
+   * @param datum: LabeledPoint
+   * @param prediction: Predicted label.
+   * @return Measure of model error on the datapoint.
    */
-  def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double
+  def computeError(datum: LabeledPoint, prediction: Double): Double

 }
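With the trait reduced to this per-datum contract, a dataset-level metric is
just a map over per-point errors followed by mean(). A sketch, assuming some
model exposing predict(Vector): Double, data: RDD[LabeledPoint], and an
implementation loss of the trait above:

    // Average the single-datum errors across the dataset.
    val meanError = data
      .map(point => loss.computeError(point, model.predict(point.features)))
      .mean()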

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -66,17 +66,13 @@ object SquaredError extends Loss {
    * Method to calculate loss when the predictions are already known.
    * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
    * predicted values from previously fit trees.
-   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @param prediction: RDD[Double] of predicted labels.
-   * @return Mean squared error of model on data
+   * @param datum: LabeledPoint
+   * @param prediction: Predicted label.
+   * @return Squared error of model on the given datapoint.
    */
-  override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
-    val errorAcrossSamples = (data zip prediction) map {
-      case (yTrue, yPred) =>
-        val err = yPred - yTrue.label
-        err * err
-    }
-    errorAcrossSamples.mean()
+  override def computeError(datum: LabeledPoint, prediction: Double): Double = {
+    val err = prediction - datum.label
+    err * err
   }

 }

mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -28,9 +28,11 @@ import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.tree.loss.Loss
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
@@ -108,6 +110,53 @@ class GradientBoostedTreesModel(
   }

   override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
+  /**
+   * Method to compute error or loss for every iteration of gradient boosting.
+   * @param data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param loss: evaluation metric.
+   * @return an array with index i having the losses or errors for the ensemble
+   *         containing trees 1 to i + 1
+   */
+  def evaluateEachIteration(
+      data: RDD[LabeledPoint],
+      loss: Loss): Array[Double] = {
+
+    val sc = data.sparkContext
+    val remappedData = algo match {
+      case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+      case _ => data
+    }
+    val initialTree = trees(0)
+    val numIterations = trees.length
+    val evaluationArray = Array.fill(numIterations)(0.0)
+
+    // Initial weight is 1.0
+    var predictionErrorModel = remappedData.map { i =>
+      val pred = initialTree.predict(i.features)
+      val error = loss.computeError(i, pred)
+      (pred, error)
+    }
+    evaluationArray(0) = predictionErrorModel.values.mean()
+
+    // Avoid the model being copied across numIterations.
+    val broadcastTrees = sc.broadcast(trees)
+    val broadcastWeights = sc.broadcast(treeWeights)
+
+    (1 until numIterations).map { nTree =>
+      predictionErrorModel = (remappedData zip predictionErrorModel) map {
+        case (point, (pred, error)) =>
+          val newPred = pred +
+            broadcastTrees.value(nTree).predict(point.features) * broadcastWeights.value(nTree)
+          val newError = loss.computeError(point, newPred)
+          (newPred, newError)
+      }
+      evaluationArray(nTree) = predictionErrorModel.values.mean()
+    }
+    evaluationArray
+  }
+
 }

 object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
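A hypothetical call to the new method, assuming a fitted model and a held-out
validationData: RDD[LabeledPoint]:

    import org.apache.spark.mllib.tree.loss.SquaredError

    val errors: Array[Double] = model.evaluateEachIteration(validationData, SquaredError)
    // errors(i) is the mean loss of the ensemble restricted to trees 0 .. i,
    // so its minimum suggests how many trees are worth keeping.
    val bestNumTrees = errors.indexOf(errors.min) + 1

Broadcasting trees and treeWeights once per call, instead of capturing the
whole model in every closure, keeps the per-iteration jobs from re-shipping
the ensemble; carrying (prediction, error) pairs forward means each tree is
evaluated exactly once per data point.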

mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -179,8 +179,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
     assert(numTrees !== numIterations)

     // Test that it performs better on the validation dataset.
-    val gbtModel = new GradientBoostedTrees(boostingStrategy)
-    val gbt = gbtModel.run(trainRdd)
+    val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
     val (errorWithoutValidation, errorWithValidation) = {
       if (algo == Classification) {
         val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -193,7 +192,7 @@

     // Test that results from evaluateEachIteration comply with runWithValidation.
     // Note that convergenceTol is set to 0.0
-    val evaluationArray = gbtModel.evaluateEachIteration(validateRdd)
+    val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
     assert(evaluationArray.length === numIterations)
     assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
     var i = 1