-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-11520][ML] RegressionMetrics should support instance weights #9907
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -109,4 +109,55 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { | |
"root mean squared error mismatch") | ||
assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch") | ||
} | ||
|
||
test("regression metrics with same(1.0) weight samples") { | ||
val predictionAndObservationWithWeight = sc.parallelize( | ||
Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2) | ||
val metrics = new RegressionMetrics(predictionAndObservationWithWeight) | ||
assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5, | ||
"explained variance regression score mismatch") | ||
assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") | ||
assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch") | ||
assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5, | ||
"root mean squared error mismatch") | ||
assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch") | ||
} | ||
|
||
|
||
/** | ||
* The following values are hand calculated using the formula: | ||
* [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] | ||
* preds = c(2.25, -0.25, 1.75, 7.75) | ||
* obs = c(3.0, -0.5, 2.0, 7.0) | ||
* weights = c(0.1, 0.2, 0.15, 0.05) | ||
* count = 4 | ||
* | ||
* Weighted metrics can be calculated with MultivariateStatisticalSummary. | ||
* (observations, observations - predictions) | ||
* mean (1.7, 0.05) | ||
* variance (7.3, 0.3) | ||
* numNonZeros (0.5, 0.5) | ||
* max (7.0, 0.75) | ||
* min (-0.5, -0.75) | ||
* normL2 (2.0, 0.32596) | ||
* normL1 (1.05, 0.2) | ||
* | ||
* explainedVariance: sum((preds - 1.7)^2) / count = 10.1775 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am having an issue with these formulas, based on the very wikipedia page you are referencing. It should not be normalized by the count but by the weight sum, right? |
||
* meanAbsoluteError: normL1(1) / count = 0.05 | ||
* meanSquaredError: normL2(1)^2 / count = 0.02656 | ||
* rootMeanSquaredError: sqrt(meanSquaredError) = 0.16298 | ||
* r2: 1 - normL2(1)^2 / (variance(0) * (count - 1)) = 0.9951484 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same thing for this formula. |
||
*/ | ||
test("regression metrics with weighted samples") { | ||
val predictionAndObservationWithWeight = sc.parallelize( | ||
Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2) | ||
val metrics = new RegressionMetrics(predictionAndObservationWithWeight) | ||
assert(metrics.explainedVariance ~== 10.1775 absTol 1E-5, | ||
"explained variance regression score mismatch") | ||
assert(metrics.meanAbsoluteError ~== 0.05 absTol 1E-5, "mean absolute error mismatch") | ||
assert(metrics.meanSquaredError ~== 0.02656248 absTol 1E-5, "mean squared error mismatch") | ||
assert(metrics.rootMeanSquaredError ~== 0.16298 absTol 1E-5, | ||
"root mean squared error mismatch") | ||
assert(metrics.r2 ~== 0.9951484 absTol 1E-5, "r2 score mismatch") | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should rename this
summary
variable, it is used 3 times for different objects.Also,
sample
is a bit too generic here, and using the_xx
methods are not very intuitive. I suggest you unpack the tuple fully:{ case (currentSummary, (vec, weight)) => currentSummary.add(vec, weight) }