Skip to content

Commit

Permalink
[SPARK-8468] [ML] Take the negative of some metrics in RegressionEval…
Browse files Browse the repository at this point in the history
…uator to get correct cross validation

JIRA: https://issues.apache.org/jira/browse/SPARK-8468

Author: Liang-Chi Hsieh <[email protected]>

Closes apache#6905 from viirya/cv_min and squashes the following commits:

930d3db [Liang-Chi Hsieh] Fix python unit test and add document.
d632135 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cv_min
16e3b2c [Liang-Chi Hsieh] Take the negative instead of reciprocal.
c3dd8d9 [Liang-Chi Hsieh] For comments.
b5f52c1 [Liang-Chi Hsieh] Add param to CrossValidator for choosing whether to maximize evaulation value.
  • Loading branch information
viirya authored and jkbradley committed Jun 20, 2015
1 parent 1b6fe9b commit 0b89951
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ final class RegressionEvaluator(override val uid: String)

/**
* param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`)
*
* Because we will maximize evaluation value (ref: `CrossValidator`),
* when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),
* we take and output the negative of this metric.
* @group param
*/
val metricName: Param[String] = {
Expand Down Expand Up @@ -70,13 +74,13 @@ final class RegressionEvaluator(override val uid: String)
val metrics = new RegressionMetrics(predictionAndLabels)
val metric = $(metricName) match {
case "rmse" =>
metrics.rootMeanSquaredError
-metrics.rootMeanSquaredError
case "mse" =>
metrics.meanSquaredError
-metrics.meanSquaredError
case "r2" =>
metrics.r2
case "mae" =>
metrics.meanAbsoluteError
-metrics.meanAbsoluteError
}
metric
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array

/**
* :: Experimental ::
* A param amd its value.
* A param and its value.
*/
@Experimental
case class ParamPair[T](param: Param[T], value: T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext

// default = rmse
val evaluator = new RegressionEvaluator()
assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001)
assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001)

// r2 score
evaluator.setMetricName("r2")
assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001)

// mae
evaluator.setMetricName("mae")
assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.StructType

Expand Down Expand Up @@ -58,6 +59,36 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(cvModel.avgMetrics.length === lrParamMaps.length)
}

test("cross validation with linear regression") {
val dataset = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))

val trainer = new LinearRegression
val lrParamMaps = new ParamGridBuilder()
.addGrid(trainer.regParam, Array(1000.0, 0.001))
.addGrid(trainer.maxIter, Array(0, 10))
.build()
val eval = new RegressionEvaluator()
val cv = new CrossValidator()
.setEstimator(trainer)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setNumFolds(3)
val cvModel = cv.fit(dataset)
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
assert(cvModel.avgMetrics.length === lrParamMaps.length)

eval.setMetricName("r2")
val cvModel2 = cv.fit(dataset)
val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent2.getRegParam === 0.001)
assert(parent2.getMaxIter === 10)
assert(cvModel2.avgMetrics.length === lrParamMaps.length)
}

test("validateParams should check estimatorParamMaps") {
import CrossValidatorSuite._

Expand Down
8 changes: 5 additions & 3 deletions python/pyspark/ml/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,15 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
...
>>> evaluator = RegressionEvaluator(predictionCol="raw")
>>> evaluator.evaluate(dataset)
2.842...
-2.842...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
0.993...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
2.649...
-2.649...
"""
# a placeholder to make it appear in the generated doc
# Because we will maximize evaluation value (ref: `CrossValidator`),
# when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),
# we take and output the negative of this metric.
metricName = Param(Params._dummy(), "metricName",
"metric name in evaluation (mse|rmse|r2|mae)")

Expand Down

0 comments on commit 0b89951

Please sign in to comment.