From 1d4a5fd5888a6e2860d45a4d755e4b5bea63690e Mon Sep 17 00:00:00 2001 From: Lewuathe Date: Mon, 23 Nov 2015 18:51:21 +0900 Subject: [PATCH] [SPARK-11520][ML] RegressionMetrics should support instance weights --- .../mllib/evaluation/RegressionMetrics.scala | 29 ++++++++--- .../evaluation/RegressionMetricsSuite.scala | 51 +++++++++++++++++++ 2 files changed, 72 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 1d8f4fe340fb4..4f8287c04e816 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -27,11 +27,23 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for regression. * - * @param predictionAndObservations an RDD of (prediction, observation) pairs. + * @param predictionAndObservationsWithWeight an RDD of (prediction, observation, weight) tuples. */ @Since("1.2.0") class RegressionMetrics @Since("1.2.0") ( - predictionAndObservations: RDD[(Double, Double)]) extends Logging { + predictionAndObservationsWithWeight: => RDD[(Double, Double, Double)]) extends Logging { + + /** + * An auxiliary constructor taking an RDD of samples without weights. + * In this case, each sample is given the default weight of 1.0, + * as MultivariateOnlineSummarizer does. + * @param predictionAndObservation an RDD of (prediction, observation) pairs + */ + def this(predictionAndObservation: RDD[(Double, Double)]) { + this(predictionAndObservation.map({ + case (prediction, observation) => (prediction, observation, 1.0) + })) + } /** * An auxiliary constructor taking a DataFrame. 
@@ -39,16 +51,17 @@ class RegressionMetrics @Since("1.2.0") ( * prediction and observation */ private[mllib] def this(predictionAndObservations: DataFrame) = - this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1)))) + this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1), 1.0))) /** * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. */ private lazy val summary: MultivariateStatisticalSummary = { - val summary: MultivariateStatisticalSummary = predictionAndObservations.map { - case (prediction, observation) => Vectors.dense(observation, observation - prediction) + val summary: MultivariateStatisticalSummary = predictionAndObservationsWithWeight.map { + case (prediction, observation, weight) => + (Vectors.dense(observation, observation - prediction), weight) }.aggregate(new MultivariateOnlineSummarizer())( - (summary, v) => summary.add(v), + (summary, sample) => summary.add(sample._1, sample._2), (sum1, sum2) => sum1.merge(sum2) ) summary @@ -57,8 +70,8 @@ class RegressionMetrics @Since("1.2.0") ( private lazy val SStot = summary.variance(0) * (summary.count - 1) private lazy val SSreg = { val yMean = summary.mean(0) - predictionAndObservations.map { - case (prediction, _) => math.pow(prediction - yMean, 2) + predictionAndObservationsWithWeight.map { + case (prediction, _, _) => math.pow(prediction - yMean, 2) }.sum() } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 4b7f1be58f99b..bd3056ec5e296 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -109,4 +109,55 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { "root mean squared error mismatch") assert(metrics.r2 ~== 1.0 absTol 
1E-5, "r2 score mismatch") } + + test("regression metrics with same(1.0) weight samples") { + val predictionAndObservationWithWeight = sc.parallelize( + Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservationWithWeight) + assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch") + } + + + /** + * The following values are hand-calculated using the formula: + * [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] + * preds = c(2.25, -0.25, 1.75, 7.75) + * obs = c(3.0, -0.5, 2.0, 7.0) + * weights = c(0.1, 0.2, 0.15, 0.05) + * count = 4 (number of samples; NOTE(review): not the weight sum 0.5 -- confirm intended) + * + * Weighted metrics can be calculated with MultivariateStatisticalSummary. 
+ * (observations, observations - predictions) + * mean (1.7, 0.05) + * variance (7.3, 0.3) + * numNonZeros (0.5, 0.5) + * max (7.0, 0.75) + * min (-0.5, -0.75) + * normL2 (2.0, 0.32596) + * normL1 (1.05, 0.2) + * + * explainedVariance: sum((preds - 1.7)^2) / count = 10.1775 + * meanAbsoluteError: normL1(1) / count = 0.05 + * meanSquaredError: normL2(1)^2 / count = 0.02656 + * rootMeanSquaredError: sqrt(meanSquaredError) = 0.16298 + * r2: 1 - normL2(1)^2 / (variance(0) * (count - 1)) = 0.9951484 + */ + test("regression metrics with weighted samples") { + val predictionAndObservationWithWeight = sc.parallelize( + Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2) + val metrics = new RegressionMetrics(predictionAndObservationWithWeight) + assert(metrics.explainedVariance ~== 10.1775 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.05 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.02656248 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.16298 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.9951484 absTol 1E-5, "r2 score mismatch") + } }