Skip to content

Commit

Permalink
[SPARK-11520][ML] RegressionMetrics should support instance weights
Browse files Browse the repository at this point in the history
  • Loading branch information
Lewuathe committed Nov 23, 2015
1 parent 4be360d commit 1d4a5fd
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,41 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for regression.
*
* @param predictionAndObservations an RDD of (prediction, observation) pairs.
* @param predictionAndObservationsWithWeight an RDD of (prediction, observation, weight) tuples.
*/
@Since("1.2.0")
class RegressionMetrics @Since("1.2.0") (
predictionAndObservations: RDD[(Double, Double)]) extends Logging {
predictionAndObservationsWithWeight: => RDD[(Double, Double, Double)]) extends Logging {

/**
* An auxiliary constructor taking an RDD of samples that carry no weights.
* Every (prediction, observation) pair is given a weight of 1.0 before being handed to the
* primary constructor, mirroring the default weight that MultivariateOnlineSummarizer uses.
* @param predictionAndObservation an RDD of (prediction, observation) pairs
*/
def this(predictionAndObservation: RDD[(Double, Double)]) =
  // Use `=` rather than procedure syntax, matching the DataFrame constructor in this class.
  this(predictionAndObservation.map {
    case (prediction, observation) => (prediction, observation, 1.0)
  })

/**
* An auxiliary constructor taking a DataFrame.
* @param predictionAndObservations a DataFrame with two double columns:
* prediction and observation
*/
private[mllib] def this(predictionAndObservations: DataFrame) =
this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1))))
this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1), 1.0)))

/**
* Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
*/
private lazy val summary: MultivariateStatisticalSummary = {
val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
case (prediction, observation) => Vectors.dense(observation, observation - prediction)
val summary: MultivariateStatisticalSummary = predictionAndObservationsWithWeight.map {
case (prediction, observation, weight) =>
(Vectors.dense(observation, observation - prediction), weight)
}.aggregate(new MultivariateOnlineSummarizer())(
(summary, v) => summary.add(v),
(summary, sample) => summary.add(sample._1, sample._2),
(sum1, sum2) => sum1.merge(sum2)
)
summary
Expand All @@ -57,8 +70,8 @@ class RegressionMetrics @Since("1.2.0") (
private lazy val SStot = summary.variance(0) * (summary.count - 1)
private lazy val SSreg = {
val yMean = summary.mean(0)
predictionAndObservations.map {
case (prediction, _) => math.pow(prediction - yMean, 2)
predictionAndObservationsWithWeight.map {
case (prediction, _, _) => math.pow(prediction - yMean, 2)
}.sum()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,55 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
"root mean squared error mismatch")
assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch")
}

test("regression metrics with same(1.0) weight samples") {
  // Four (prediction, observation, weight) samples, each with weight 1.0.
  val unitWeightSamples = Seq(
    (2.25, 3.0, 1.0),
    (-0.25, -0.5, 1.0),
    (1.75, 2.0, 1.0),
    (7.75, 7.0, 1.0))
  // Bind the RDD once (2 partitions) and pass that same instance to the metrics object.
  val samplesRdd = sc.parallelize(unitWeightSamples, 2)
  val metrics = new RegressionMetrics(samplesRdd)

  // Absolute tolerance shared by every comparison below.
  val tolerance = 1E-5

  assert(
    metrics.explainedVariance ~== 8.79687 absTol tolerance,
    "explained variance regression score mismatch")
  assert(
    metrics.meanAbsoluteError ~== 0.5 absTol tolerance,
    "mean absolute error mismatch")
  assert(
    metrics.meanSquaredError ~== 0.3125 absTol tolerance,
    "mean squared error mismatch")
  assert(
    metrics.rootMeanSquaredError ~== 0.55901 absTol tolerance,
    "root mean squared error mismatch")
  assert(
    metrics.r2 ~== 0.95717 absTol tolerance,
    "r2 score mismatch")
}


/**
* The following values are hand calculated using the formula:
* [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
* preds = c(2.25, -0.25, 1.75, 7.75)
* obs = c(3.0, -0.5, 2.0, 7.0)
* weights = c(0.1, 0.2, 0.15, 0.05)
* count = 4
*
* Weighted metrics can be calculated with MultivariateStatisticalSummary.
* (observations, observations - predictions)
* mean (1.7, 0.05)
* variance (7.3, 0.3)
* numNonZeros (0.5, 0.5)
* max (7.0, 0.75)
* min (-0.5, -0.75)
* normL2 (2.0, 0.32596)
* normL1 (1.05, 0.2)
*
* explainedVariance: sum((preds - 1.7)^2) / count = 10.1775
* meanAbsoluteError: normL1(1) / count = 0.05
* meanSquaredError: normL2(1)^2 / count = 0.02656
* rootMeanSquaredError: sqrt(meanSquaredError) = 0.16298
* r2: 1 - normL2(1)^2 / (variance(0) * (count - 1)) = 0.9951484
*/
test("regression metrics with weighted samples") {
  // Four (prediction, observation, weight) samples with weights 0.1, 0.2, 0.15 and 0.05.
  val weightedSamples = Seq(
    (2.25, 3.0, 0.1),
    (-0.25, -0.5, 0.2),
    (1.75, 2.0, 0.15),
    (7.75, 7.0, 0.05))
  // Bind the RDD once (2 partitions) and pass that same instance to the metrics object.
  val samplesRdd = sc.parallelize(weightedSamples, 2)
  val metrics = new RegressionMetrics(samplesRdd)

  // Absolute tolerance shared by every comparison below.
  val tolerance = 1E-5

  assert(
    metrics.explainedVariance ~== 10.1775 absTol tolerance,
    "explained variance regression score mismatch")
  assert(
    metrics.meanAbsoluteError ~== 0.05 absTol tolerance,
    "mean absolute error mismatch")
  assert(
    metrics.meanSquaredError ~== 0.02656248 absTol tolerance,
    "mean squared error mismatch")
  assert(
    metrics.rootMeanSquaredError ~== 0.16298 absTol tolerance,
    "root mean squared error mismatch")
  assert(
    metrics.r2 ~== 0.9951484 absTol tolerance,
    "r2 score mismatch")
}
}

0 comments on commit 1d4a5fd

Please sign in to comment.