From c1c6c2228a446ed42bf4382d4703309865f6dc54 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Thu, 13 Mar 2014 13:47:11 -0700
Subject: [PATCH] add AreaUnderCurve

---
 .../mllib/evaluation/AreaUnderCurve.scala     | 57 +++++++++++++++++++
 .../evaluation/AreaUnderCurveSuite.scala      | 47 ++++++++++++++++
 2 files changed, 104 insertions(+)
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
new file mode 100644
index 0000000000000..8d014c9f38726
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.rdd.RDDFunctions._  // provides sliding() on RDDs
+
+/**
+ * Computes the area under the curve (AUC) using the trapezoidal rule.
+ */
+object AreaUnderCurve {
+
+  /** Computes the area of the trapezoid under the segment joining two adjacent curve points. */
+  private def trapezoid(points: Array[(Double, Double)]): Double = {
+    require(points.length == 2)
+    (points(1)._1 - points(0)._1) * (points(1)._2 + points(0)._2) / 2.0
+  }
+
+  /**
+   * Returns the area under the given curve.
+   *
+   * @param curve an RDD of ordered 2D points stored in pairs representing a curve
+   */
+  def of(curve: RDD[(Double, Double)]): Double = {
+    curve.sliding(2).aggregate(0.0)(
+      seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points),
+      combOp = (_ + _)
+    )
+  }
+
+  /**
+   * Returns the area under the given curve.
+   *
+   * @param curve an iterable of ordered 2D points stored in pairs representing a curve
+   */
+  def of(curve: Iterable[(Double, Double)]): Double = {
+    curve.sliding(2).map(_.toArray).filter(_.size == 2).aggregate(0.0)(
+      seqop = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points),
+      combop = (_ + _)
+    )
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
new file mode 100644
index 0000000000000..78dd65c1721b6
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.LocalSparkContext
+
+class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
+
+  test("auc computation") {
+    val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
+    val auc = 4.0
+    assert(AreaUnderCurve.of(curve) === auc)
+    val rddCurve = sc.parallelize(curve, 2)
+    assert(AreaUnderCurve.of(rddCurve) === auc)
+  }
+
+  test("auc of an empty curve") {
+    val curve = Seq.empty[(Double, Double)]
+    assert(AreaUnderCurve.of(curve) === 0.0)
+    val rddCurve = sc.parallelize(curve, 2)
+    assert(AreaUnderCurve.of(rddCurve) === 0.0)
+  }
+
+  test("auc of a curve with a single point") {
+    val curve = Seq((1.0, 1.0))
+    assert(AreaUnderCurve.of(curve) === 0.0)
+    val rddCurve = sc.parallelize(curve, 2)
+    assert(AreaUnderCurve.of(rddCurve) === 0.0)
+  }
+}
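
For quick reference, a minimal standalone sketch of the same trapezoidal-rule computation in plain Scala (no Spark; not part of the patch, and the object name AucSketch is only illustrative), checking the expected value 4.0 used in the "auc computation" test above: the adjacent pairs of (0,0), (1,1), (2,3), (3,0) contribute trapezoid areas 0.5, 2.0, and 1.5.

object AucSketch {
  // Area of the trapezoid under the segment joining two adjacent curve points.
  private def trapezoid(a: (Double, Double), b: (Double, Double)): Double =
    (b._1 - a._1) * (b._2 + a._2) / 2.0

  def main(args: Array[String]): Unit = {
    val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
    // Adjacent pairs contribute 1*(0+1)/2 = 0.5, 1*(1+3)/2 = 2.0, 1*(3+0)/2 = 1.5
    val auc = curve.sliding(2).collect { case Seq(a, b) => trapezoid(a, b) }.sum
    println(auc)  // prints 4.0, matching the value expected by AreaUnderCurveSuite
  }
}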