Skip to content

Commit

Permalink
add regression metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
liangyanbo committed Oct 28, 2014
1 parent fae095b commit 43bb12b
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation

import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.rdd.RDDFunctions._

/**
 * :: Experimental ::
 * Evaluator for regression.
 *
 * Every metric is derived from one shared summary of the input, which is computed lazily the
 * first time any metric is requested and then reused.
 *
 * @param valuesAndPreds an RDD of (value, pred) pairs.
 */
@Experimental
class RegressionMetrics(valuesAndPreds: RDD[(Double, Double)]) extends Logging {

  /**
   * Use MultivariateOnlineSummarizer to calculate mean and variance of different combinations.
   * MultivariateOnlineSummarizer is a numerically stable algorithm to compute mean and variance
   * in an online fashion.
   *
   * Each (value, pred) pair is expanded into one row whose columns are:
   *  - 0: value
   *  - 1: pred
   *  - 2: value - pred (signed error)
   *  - 3: |value - pred| (absolute error)
   *  - 4: (value - pred)^2^ (squared error)
   */
  private lazy val summarizer: MultivariateOnlineSummarizer = {
    valuesAndPreds.map { case (value, pred) =>
      // Compute the residual once and reuse it for the three error columns.
      val error = value - pred
      Vectors.dense(Array(value, pred, error, math.abs(error), math.pow(error, 2.0)))
    }.treeAggregate(new MultivariateOnlineSummarizer())(
      (summary, v) => summary.add(v),
      (sum1, sum2) => sum1.merge(sum2)
    )
  }

  /**
   * Computes the explained variance regression score:
   * 1 - variance(value - pred) / variance(value).
   *
   * @return the explained variance score; the division is not guarded, so the result is not a
   *         finite number when the variance of the values is zero.
   */
  def explainedVarianceScore(): Double = {
    1 - summarizer.variance(2) / summarizer.variance(0)
  }

  /**
   * Computes the mean absolute error, which is a risk function corresponding to the
   * expected value of the absolute error loss or l1-norm loss.
   *
   * @return the mean of |value - pred| over all pairs.
   */
  def mae(): Double = {
    summarizer.mean(3)
  }

  /**
   * Computes the mean square error, which is a risk function corresponding to the
   * expected value of the squared error loss or quadratic loss.
   *
   * @return the mean of (value - pred)^2^ over all pairs.
   */
  def mse(): Double = {
    summarizer.mean(4)
  }

  /**
   * Computes R^2^, the coefficient of determination.
   *
   * The mean squared error times the count gives the residual sum of squares. The variance of
   * the values times (count - 1) gives the total sum of squares; this relies on the summarizer
   * reporting the sample variance with a (count - 1) denominator.
   *
   * @return 1 - (residual sum of squares) / (total sum of squares); the division is not
   *         guarded, so the result is not a finite number when the denominator is zero.
   */
  def r2Score(): Double = {
    val n = summarizer.count
    1 - summarizer.mean(4) * n / (summarizer.variance(0) * (n - 1))
  }

  /**
   * Misspelled alias of `r2Score()`, kept so that existing callers continue to compile.
   * New code should call `r2Score()` instead.
   *
   * @return the same value as `r2Score()`.
   */
  def r2_socre(): Double = r2Score()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation

import org.scalatest.FunSuite
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.util.TestingUtils._

/**
 * Checks the metrics reported by RegressionMetrics against fixed expected values, using a
 * small local RDD split across two partitions.
 */
class RegressionMetricsSuite extends FunSuite with LocalSparkContext {

  test("regression metrics") {
    // (value, pred) pairs in which some predictions differ from the values.
    val pairs = Seq((3.0, 2.5), (-0.5, 0.0), (2.0, 2.0), (7.0, 8.0))
    val evaluator = new RegressionMetrics(sc.parallelize(pairs, 2))

    assert(
      evaluator.explainedVarianceScore() ~== 0.95717 absTol 1E-5,
      "explained variance regression score mismatch")
    assert(evaluator.mae() ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
    assert(evaluator.mse() ~== 0.375 absTol 1E-5, "mean square error mismatch")
    assert(evaluator.r2_socre() ~== 0.94861 absTol 1E-5, "r2 score mismatch")
  }

  test("regression metrics with complete fitting") {
    // Every prediction equals its value, so the errors are zero and both scores are 1.
    val pairs = Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0))
    val evaluator = new RegressionMetrics(sc.parallelize(pairs, 2))

    assert(
      evaluator.explainedVarianceScore() ~== 1.0 absTol 1E-5,
      "explained variance regression score mismatch")
    assert(evaluator.mae() ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
    assert(evaluator.mse() ~== 0.0 absTol 1E-5, "mean square error mismatch")
    assert(evaluator.r2_socre() ~== 1.0 absTol 1E-5, "r2 score mismatch")
  }
}

0 comments on commit 43bb12b

Please sign in to comment.