add (0, 0), (1, 1) to roc, and (0, 1) to pr
mengxr committed Apr 9, 2014
1 parent fb4b6d2 commit 3f42e98
Showing 2 changed files with 24 additions and 11 deletions.
BinaryClassificationMetrics.scala
@@ -17,7 +17,7 @@

package org.apache.spark.mllib.evaluation.binary

-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{UnionRDD, RDD}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.AreaUnderCurve
import org.apache.spark.Logging
@@ -103,22 +103,34 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])

/**
* Returns the receiver operating characteristic (ROC) curve,
-* which is an RDD of (false positive rate, true positive rate).
+* which is an RDD of (false positive rate, true positive rate)
+* with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
* @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
*/
-def roc(): RDD[(Double, Double)] = createCurve(FalsePositiveRate, Recall)
+def roc(): RDD[(Double, Double)] = {
+  val rocCurve = createCurve(FalsePositiveRate, Recall)
+  val sc = confusions.context
+  val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
+  val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
+  new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
+}

/**
* Computes the area under the receiver operating characteristic (ROC) curve.
*/
def areaUnderROC(): Double = AreaUnderCurve.of(roc())

/**
-* Returns the precision-recall curve,
-* which is an RDD of (recall, precision), NOT (precision, recall).
+* Returns the precision-recall curve, which is an RDD of (recall, precision),
+* NOT (precision, recall), with (0.0, 1.0) prepended to it.
* @see http://en.wikipedia.org/wiki/Precision_and_recall
*/
-def pr(): RDD[(Double, Double)] = createCurve(Recall, Precision)
+def pr(): RDD[(Double, Double)] = {
+  val prCurve = createCurve(Recall, Precision)
+  val sc = confusions.context
+  val first = sc.makeRDD(Seq((0.0, 1.0)), 1)
+  first.union(prCurve)
+}

/**
* Computes the area under the precision-recall curve.
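
For context, here is a minimal, hypothetical driver sketch of the changed API, assuming a local SparkContext and made-up (score, label) pairs; the object name and data below are illustrative and not part of this commit. After this change, roc() starts at (0.0, 0.0) and ends at (1.0, 1.0), and pr() starts at (0.0, 1.0); the suite below exercises the same calls.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.binary.BinaryClassificationMetrics

object RocPrAnchorsExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("roc-pr-anchors"))
    // Made-up (score, label) pairs; labels are 0.0 or 1.0.
    val scoreAndLabels = sc.parallelize(Seq(
      (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    // roc() now begins with (0.0, 0.0) and ends with (1.0, 1.0).
    println(metrics.roc().collect().toSeq)
    // pr() now begins with (0.0, 1.0).
    println(metrics.pr().collect().toSeq)
    println(metrics.areaUnderROC())
    sc.stop()
  }
}
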
BinaryClassificationMetricsSuite.scala
@@ -35,15 +35,16 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
val precision = tp.zip(fp).map { case (t, f) => t.toDouble / (t + f) }
val recall = tp.map(t => t.toDouble / p)
val fpr = fp.map(f => f.toDouble / n)
-val roc = fpr.zip(recall)
+val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
val pr = recall.zip(precision)
+val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (re, prec) => 2.0 * (prec * re) / (prec + re) }
val f2 = pr.map { case (re, prec) => 5.0 * (prec * re) / (4.0 * prec + re)}
assert(metrics.thresholds().collect().toSeq === score)
-assert(metrics.roc().collect().toSeq === roc)
-assert(metrics.areaUnderROC() === AreaUnderCurve.of(roc))
-assert(metrics.pr().collect().toSeq === pr)
-assert(metrics.areaUnderPR() === AreaUnderCurve.of(pr))
+assert(metrics.roc().collect().toSeq === rocCurve)
+assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve))
+assert(metrics.pr().collect().toSeq === prCurve)
+assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve))
assert(metrics.fMeasureByThreshold().collect().toSeq === score.zip(f1))
assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === score.zip(f2))
assert(metrics.precisionByThreshold().collect().toSeq === score.zip(precision))
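
The added anchor points also affect the area computations. AreaUnderCurve.of is not shown in this diff; assuming it integrates a piecewise-linear curve with the trapezoid rule, the plain-Scala sketch below (object, helper name, and sample points are all illustrative) shows how prepending (0.0, 0.0) and appending (1.0, 1.0) extends the ROC area to cover the full false positive rate range [0.0, 1.0].

object CurveAnchorSketch {
  // Trapezoid-rule area under a piecewise-linear curve given as (x, y) points.
  def trapezoidArea(points: Seq[(Double, Double)]): Double =
    points.sliding(2).collect { case Seq((x1, y1), (x2, y2)) =>
      (x2 - x1) * (y1 + y2) / 2.0
    }.sum

  def main(args: Array[String]): Unit = {
    // Made-up interior ROC points, as roc() would have produced before this commit.
    val interior = Seq((0.2, 0.5), (0.4, 0.8), (0.7, 0.9))
    // This commit prepends (0.0, 0.0) and appends (1.0, 1.0).
    val anchored = Seq((0.0, 0.0)) ++ interior ++ Seq((1.0, 1.0))
    println(trapezoidArea(interior)) // area over false positive rates 0.2 to 0.7 only
    println(trapezoidArea(anchored)) // area over the full range 0.0 to 1.0
  }
}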
