
[SPARK-11520][ML] RegressionMetrics should support instance weights #9907

Status: Closed (wants to merge 1 commit)
@@ -27,28 +27,41 @@ import org.apache.spark.sql.DataFrame
 /**
  * Evaluator for regression.
  *
- * @param predictionAndObservations an RDD of (prediction, observation) pairs.
+ * @param predictionAndObservationsWithWeight an RDD of (prediction, observation, weight) tuples.
  */
 @Since("1.2.0")
 class RegressionMetrics @Since("1.2.0") (
-    predictionAndObservations: RDD[(Double, Double)]) extends Logging {
+    predictionAndObservationsWithWeight: => RDD[(Double, Double, Double)]) extends Logging {
+
+  /**
+   * An auxiliary constructor taking an RDD of samples without weights.
+   * In this case each sample's weight defaults to 1.0, as in
+   * MultivariateOnlineSummarizer.
+   * @param predictionAndObservation an RDD of (prediction, observation) pairs
+   */
+  def this(predictionAndObservation: RDD[(Double, Double)]) {
+    this(predictionAndObservation.map({
+      case (prediction, observation) => (prediction, observation, 1.0)
+    }))
+  }
 
   /**
    * An auxiliary constructor taking a DataFrame.
    * @param predictionAndObservations a DataFrame with two double columns:
    *                                  prediction and observation
    */
   private[mllib] def this(predictionAndObservations: DataFrame) =
-    this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1))))
+    this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1), 1.0)))
 
   /**
    * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
    */
   private lazy val summary: MultivariateStatisticalSummary = {
-    val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
-      case (prediction, observation) => Vectors.dense(observation, observation - prediction)
+    val summary: MultivariateStatisticalSummary = predictionAndObservationsWithWeight.map {
+      case (prediction, observation, weight) =>
+        (Vectors.dense(observation, observation - prediction), weight)
     }.aggregate(new MultivariateOnlineSummarizer())(
-      (summary, v) => summary.add(v),
+      (summary, sample) => summary.add(sample._1, sample._2),
Review comment (Contributor): We should rename this summary variable; it is used 3 times for different objects.

Also, sample is a bit too generic here, and using the _xx methods is not very intuitive. I suggest you unpack the tuple fully: { case (currentSummary, (vec, weight)) => currentSummary.add(vec, weight) }
       (sum1, sum2) => sum1.merge(sum2)
     )
     summary

@@ -57,8 +70,8 @@ class RegressionMetrics @Since("1.2.0") (
   private lazy val SStot = summary.variance(0) * (summary.count - 1)
   private lazy val SSreg = {
     val yMean = summary.mean(0)
-    predictionAndObservations.map {
-      case (prediction, _) => math.pow(prediction - yMean, 2)
+    predictionAndObservationsWithWeight.map {
+      case (prediction, _, _) => math.pow(prediction - yMean, 2)
     }.sum()
   }
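For intuition, here is a rough, Spark-free sketch in Python of the weighted statistics the summarizer is expected to produce for the (observation, error) columns: the weighted mean, and the unbiased variance under reliability weights as described on the Wikipedia page cited in the test suite. The helper name and structure are my own, not Spark's API; the sample data comes from the PR's weighted test below.

```python
# Hypothetical standalone sketch (not MultivariateOnlineSummarizer itself).
# Variance uses the unbiased "reliability weights" form:
#   sum(w * (x - mean)^2) / (V1 - V2 / V1), with V1 = sum(w), V2 = sum(w^2).
def weighted_mean_var(values, weights):
    v1 = sum(weights)                      # total weight
    v2 = sum(w * w for w in weights)       # sum of squared weights
    mean = sum(w * x for x, w in zip(values, weights)) / v1
    m2 = sum(w * (x - mean) ** 2 for x, w in zip(values, weights))
    return mean, m2 / (v1 - v2 / v1)

preds = [2.25, -0.25, 1.75, 7.75]
obs = [3.0, -0.5, 2.0, 7.0]
weights = [0.1, 0.2, 0.15, 0.05]
errors = [o - p for p, o in zip(preds, obs)]

print(weighted_mean_var(obs, weights))     # ~ (1.7, 7.3)
print(weighted_mean_var(errors, weights))  # ~ (0.05, 0.3)
```

These match the mean and variance rows of the hand-calculated summary in the weighted test's comment block.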

@@ -109,4 +109,55 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
"root mean squared error mismatch")
assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch")
}

test("regression metrics with same(1.0) weight samples") {
val predictionAndObservationWithWeight = sc.parallelize(
Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2)
val metrics = new RegressionMetrics(predictionAndObservationWithWeight)
assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch")
assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5,
"root mean squared error mismatch")
assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch")
}
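The expected values in the uniform-weight test above can be re-derived in a few lines of plain Python (a hand-check sketch, not Spark code): with every weight equal to 1.0, the metrics reduce to the familiar unweighted formulas.

```python
import math

# Hand-check of the uniform-weight expectations (plain Python, not Spark).
preds = [2.25, -0.25, 1.75, 7.75]
obs = [3.0, -0.5, 2.0, 7.0]
n = len(obs)
errors = [o - p for p, o in zip(preds, obs)]

y_mean = sum(obs) / n                                            # 2.875
explained_variance = sum((p - y_mean) ** 2 for p in preds) / n   # 8.796875
mae = sum(abs(e) for e in errors) / n                            # 0.5
mse = sum(e * e for e in errors) / n                             # 0.3125
rmse = math.sqrt(mse)                                            # ~0.5590
ss_err = sum(e * e for e in errors)
ss_tot = sum((y - y_mean) ** 2 for y in obs)
r2 = 1 - ss_err / ss_tot                                         # ~0.95717
```

Each value agrees with the corresponding assertion within its 1E-5 tolerance.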


/**
 * The following values are hand-calculated using the formula from:
* [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
* preds = c(2.25, -0.25, 1.75, 7.75)
* obs = c(3.0, -0.5, 2.0, 7.0)
* weights = c(0.1, 0.2, 0.15, 0.05)
* count = 4
*
* Weighted metrics can be calculated with MultivariateStatisticalSummary.
* (observations, observations - predictions)
* mean (1.7, 0.05)
* variance (7.3, 0.3)
* numNonZeros (0.5, 0.5)
* max (7.0, 0.75)
* min (-0.5, -0.75)
* normL2 (2.0, 0.32596)
* normL1 (1.05, 0.2)
*
* explainedVariance: sum((preds - 1.7)^2) / count = 10.1775
Review comment (Contributor): I am having an issue with these formulas, based on the very Wikipedia page you are referencing. It should not be normalized by the count but by the weight sum, right?
* meanAbsoluteError: normL1(1) / count = 0.05
* meanSquaredError: normL2(1)^2 / count = 0.02656
* rootMeanSquaredError: sqrt(meanSquaredError) = 0.16298
* r2: 1 - normL2(1)^2 / (variance(0) * (count - 1)) = 0.9951484
Review comment (Contributor): Same thing for this formula.

*/
test("regression metrics with weighted samples") {
val predictionAndObservationWithWeight = sc.parallelize(
Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2)
val metrics = new RegressionMetrics(predictionAndObservationWithWeight)
assert(metrics.explainedVariance ~== 10.1775 absTol 1E-5,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.05 absTol 1E-5, "mean absolute error mismatch")
assert(metrics.meanSquaredError ~== 0.02656248 absTol 1E-5, "mean squared error mismatch")
assert(metrics.rootMeanSquaredError ~== 0.16298 absTol 1E-5,
"root mean squared error mismatch")
assert(metrics.r2 ~== 0.9951484 absTol 1E-5, "r2 score mismatch")
}
}
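The hand-calculated weighted expectations can likewise be checked in plain Python. Note that this sketch follows the PR's formulas exactly as written in the comment block (normL1 and normL2 normalized by the sample count, which is precisely the normalization the review comments question), so it reproduces the asserted values rather than settling that debate.

```python
import math

# Sketch reproducing the PR's hand-calculated weighted expectations.
preds = [2.25, -0.25, 1.75, 7.75]
obs = [3.0, -0.5, 2.0, 7.0]
weights = [0.1, 0.2, 0.15, 0.05]
count = len(obs)
errors = [o - p for p, o in zip(preds, obs)]

w_sum = sum(weights)
y_mean = sum(w * y for y, w in zip(obs, weights)) / w_sum      # 1.7
norm_l1 = sum(w * abs(e) for e, w in zip(errors, weights))     # 0.2
norm_l2_sq = sum(w * e * e for e, w in zip(errors, weights))   # 0.10625

# Unbiased weighted variance of the observations (reliability weights): 7.3
v2 = sum(w * w for w in weights)
var_y = (sum(w * (y - y_mean) ** 2 for y, w in zip(obs, weights))
         / (w_sum - v2 / w_sum))

explained_variance = sum((p - y_mean) ** 2 for p in preds) / count  # 10.1775
mae = norm_l1 / count                                               # 0.05
mse = norm_l2_sq / count                                            # 0.0265625
rmse = math.sqrt(mse)                                               # ~0.16298
r2 = 1 - norm_l2_sq / (var_y * (count - 1))                         # ~0.9951484
```

All five results land within the 1E-5 tolerances asserted in the weighted test.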