diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala
index c95cbf525a2e8..658142b0ebf53 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.mllib.evaluation.binary
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, UnionRDD}
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.evaluation.AreaUnderCurve
 import org.apache.spark.Logging
@@ -103,10 +103,17 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])
 
   /**
    * Returns the receiver operating characteristic (ROC) curve,
-   * which is an RDD of (false positive rate, true positive rate).
+   * which is an RDD of (false positive rate, true positive rate)
+   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
    * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
    */
-  def roc(): RDD[(Double, Double)] = createCurve(FalsePositiveRate, Recall)
+  def roc(): RDD[(Double, Double)] = {
+    val rocCurve = createCurve(FalsePositiveRate, Recall)
+    val sc = confusions.context
+    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
+    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
+    new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
+  }
 
   /**
    * Computes the area under the receiver operating characteristic (ROC) curve.
@@ -114,11 +121,16 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])
   def areaUnderROC(): Double = AreaUnderCurve.of(roc())
 
   /**
-   * Returns the precision-recall curve,
-   * which is an RDD of (recall, precision), NOT (precision, recall).
+   * Returns the precision-recall curve, which is an RDD of (recall, precision),
+   * NOT (precision, recall), with (0.0, 1.0) prepended to it.
    * @see http://en.wikipedia.org/wiki/Precision_and_recall
    */
-  def pr(): RDD[(Double, Double)] = createCurve(Recall, Precision)
+  def pr(): RDD[(Double, Double)] = {
+    val prCurve = createCurve(Recall, Precision)
+    val sc = confusions.context
+    val first = sc.makeRDD(Seq((0.0, 1.0)), 1)
+    first.union(prCurve)
+  }
 
   /**
    * Computes the area under the precision-recall curve.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala
index a1f3c44becb66..f92adbe8378c6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala
@@ -35,15 +35,16 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
     val precision = tp.zip(fp).map { case (t, f) => t.toDouble / (t + f) }
     val recall = tp.map(t => t.toDouble / p)
     val fpr = fp.map(f => f.toDouble / n)
-    val roc = fpr.zip(recall)
+    val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
     val pr = recall.zip(precision)
+    val prCurve = Seq((0.0, 1.0)) ++ pr
     val f1 = pr.map { case (re, prec) => 2.0 * (prec * re) / (prec + re) }
     val f2 = pr.map { case (re, prec) => 5.0 * (prec * re) / (4.0 * prec + re)}
     assert(metrics.thresholds().collect().toSeq === score)
-    assert(metrics.roc().collect().toSeq === roc)
-    assert(metrics.areaUnderROC() === AreaUnderCurve.of(roc))
-    assert(metrics.pr().collect().toSeq === pr)
-    assert(metrics.areaUnderPR() === AreaUnderCurve.of(pr))
+    assert(metrics.roc().collect().toSeq === rocCurve)
+    assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve))
+    assert(metrics.pr().collect().toSeq === prCurve)
+    assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve))
     assert(metrics.fMeasureByThreshold().collect().toSeq === score.zip(f1))
     assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === score.zip(f2))
     assert(metrics.precisionByThreshold().collect().toSeq === score.zip(precision))