
Commit

Minor
MechCoder committed Mar 15, 2015
1 parent 6e8aa10 commit 352001f
Showing 5 changed files with 14 additions and 38 deletions.
@@ -47,16 +47,8 @@ object AbsoluteError extends Loss {
     if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
   }
 
-  /**
-   * Method to calculate loss when the predictions are already known.
-   * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
-   * predicted values from previously fit trees.
-   * @param prediction Predicted label.
-   * @param datum LabeledPoint.
-   * @return Absolute error of model on the given datapoint.
-   */
-  override def computeError(prediction: Double, datum: LabeledPoint): Double = {
-    val err = datum.label - prediction
+  override def computeError(prediction: Double, label: Double): Double = {
+    val err = label - prediction
     math.abs(err)
   }

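The reworked overload above is a pure function of two doubles. A minimal standalone sketch of the same formula, |label - prediction|, written as plain Scala rather than a call into the Spark object (whose visibility may differ by version):

// Standalone illustration of the absolute-error overload introduced above:
// the error depends only on the raw prediction and the true label.
object AbsoluteErrorExample {
  def computeError(prediction: Double, label: Double): Double =
    math.abs(label - prediction)

  def main(args: Array[String]): Unit = {
    println(computeError(prediction = 0.8, label = 1.0)) // ~0.2
  }
}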
@@ -50,16 +50,8 @@ object LogLoss extends Loss {
     - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
   }
 
-  /**
-   * Method to calculate loss when the predictions are already known.
-   * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
-   * predicted values from previously fit trees.
-   * @param prediction Predicted label.
-   * @param datum LabeledPoint
-   * @return log loss of model on the datapoint.
-   */
-  override def computeError(prediction: Double, datum: LabeledPoint): Double = {
-    val margin = 2.0 * datum.label * prediction
+  override def computeError(prediction: Double, label: Double): Double = {
+    val margin = 2.0 * label * prediction
     // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
     2.0 * MLUtils.log1pExp(-margin)
   }
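The in-code comment above appeals to a numerically stable form of log(1 + exp(x)). The sketch below illustrates that identity with a hypothetical log1pExp helper; it is an assumption about what MLUtils.log1pExp provides, not a copy of it:

// Numerically stable log(1 + exp(x)): for large positive x, exp(x) overflows,
// but log(1 + exp(x)) == x + log(1 + exp(-x)), which stays finite.
object Log1pExpExample {
  def log1pExp(x: Double): Double =
    if (x > 0) x + math.log1p(math.exp(-x)) else math.log1p(math.exp(x))

  def main(args: Array[String]): Unit = {
    println(math.log(1.0 + math.exp(1000.0))) // Infinity: the naive form overflows
    println(log1pExp(1000.0))                 // 1000.0: the stable form does not

    // Log loss for label = 1.0 and prediction = 0.5, as computed above:
    val margin = 2.0 * 1.0 * 0.5
    println(2.0 * log1pExp(-margin))          // ~0.63
  }
}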
@@ -48,17 +48,17 @@ trait Loss extends Serializable {
    * @return Measure of model error on data
    */
   def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
-    data.map(point => computeError(model.predict(point.features), point)).mean()
+    data.map(point => computeError(model.predict(point.features), point.label)).mean()
   }
 
   /**
    * Method to calculate loss when the predictions are already known.
    * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
    * predicted values from previously fit trees.
    * @param prediction Predicted label.
-   * @param datum LabeledPoint
+   * @param label True label.
    * @return Measure of model error on datapoint.
    */
-  def computeError(prediction: Double, datum: LabeledPoint): Double
+  def computeError(prediction: Double, label: Double): Double
 
 }
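To make the shape of the revised trait concrete, here is a simplified, non-Spark sketch: a per-point error defined on (prediction, label), plus a collection-level mean built on top of it, mirroring how evaluateEachIteration can reuse cached predictions without reconstructing LabeledPoints. The names are illustrative, and plain Scala collections stand in for RDD[LabeledPoint]:

// Simplified illustration of the two-level Loss shape after this change.
trait PointwiseLoss {
  // Per-point error, matching the new computeError(prediction, label) signature.
  def computeError(prediction: Double, label: Double): Double

  // Collection-level mean error, analogous to computeError(model, data).
  def meanError(predictionsAndLabels: Seq[(Double, Double)]): Double = {
    val errors = predictionsAndLabels.map { case (pred, label) => computeError(pred, label) }
    errors.sum / errors.size
  }
}

object AbsoluteLossSketch extends PointwiseLoss {
  override def computeError(prediction: Double, label: Double): Double =
    math.abs(label - prediction)
}

Calling AbsoluteLossSketch.meanError(Seq((0.8, 1.0), (2.4, 2.5))) averages the per-point errors, which is what the RDD-level default does with data.map(...).mean().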
@@ -47,16 +47,8 @@ object SquaredError extends Loss {
     2.0 * (model.predict(point.features) - point.label)
   }
 
-  /**
-   * Method to calculate loss when the predictions are already known.
-   * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
-   * predicted values from previously fit trees.
-   * @param prediction Predicted label.
-   * @param datum LabeledPoint
-   * @return Mean squared error of model on datapoint.
-   */
-  override def computeError(prediction: Double, datum: LabeledPoint): Double = {
-    val err = prediction - datum.label
+  override def computeError(prediction: Double, label: Double): Double = {
+    val err = prediction - label
     err * err
   }

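As a companion to the absolute-error sketch, the squared-error overload is also a pure function of two doubles, and the gradient shown earlier, 2.0 * (prediction - label), is its derivative with respect to the prediction. A standalone check of both, again in plain Scala rather than via the Spark object:

// Standalone illustration: squared error and a finite-difference check that
// 2.0 * (prediction - label) is its derivative in the prediction.
object SquaredErrorExample {
  def computeError(prediction: Double, label: Double): Double = {
    val err = prediction - label
    err * err
  }

  def main(args: Array[String]): Unit = {
    println(computeError(prediction = 2.0, label = 1.5)) // 0.25

    val h = 1e-6
    val numericalGradient = (computeError(2.0 + h, 1.5) - computeError(2.0 - h, 1.5)) / (2 * h)
    println(numericalGradient) // ~1.0, matching 2.0 * (2.0 - 1.5)
  }
}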
@@ -133,7 +133,7 @@ class GradientBoostedTreesModel(
 
     var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
       val pred = treeWeights(0) * trees(0).predict(i.features)
-      val error = loss.computeError(pred, i)
+      val error = loss.computeError(pred, i.label)
       (pred, error)
     }
     evaluationArray(0) = predictionAndError.values.mean()
@@ -143,13 +143,13 @@
     val broadcastWeights = sc.broadcast(treeWeights)
 
     (1 until numIterations).map { nTree =>
-      val currentTree = broadcastTrees.value(nTree)
-      val currentTreeWeight = broadcastWeights.value(nTree)
       predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
-        iter map {
+        val currentTree = broadcastTrees.value(nTree)
+        val currentTreeWeight = broadcastWeights.value(nTree)
+        iter.map {
           case (point, (pred, error)) => {
             val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
-            val newError = loss.computeError(newPred, point)
+            val newError = loss.computeError(newPred, point.label)
             (newPred, newError)
           }
         }
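The loop above carries each point's (prediction, error) pair forward and folds in one weighted tree per iteration, recomputing the error with the per-point overload rather than re-running every earlier tree. A simplified, non-Spark sketch of that accumulation for a single point (tree predictions and weights are made-up stand-ins):

// Simplified illustration of the incremental evaluation pattern used in
// evaluateEachIteration: update the running prediction with one weighted tree
// per iteration and recompute the error from (prediction, label) alone.
object IncrementalEvalSketch {
  def absoluteError(prediction: Double, label: Double): Double = math.abs(label - prediction)

  def main(args: Array[String]): Unit = {
    val label = 2.5
    val treePredictions = Seq(2.0, 0.8, -0.3) // stand-ins for trees(i).predict(point.features)
    val treeWeights = Seq(1.0, 0.5, 0.5)      // stand-ins for treeWeights(i)

    var pred = treeWeights(0) * treePredictions(0)
    println(s"iteration 0: prediction = $pred, error = ${absoluteError(pred, label)}")

    (1 until treePredictions.size).foreach { i =>
      pred += treeWeights(i) * treePredictions(i)
      println(s"iteration $i: prediction = $pred, error = ${absoluteError(pred, label)}")
    }
  }
}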
