diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala index 34583da800079..562663ad36b40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala @@ -27,19 +27,19 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl /** Precision. */ private[evaluation] object Precision extends BinaryClassificationMetricComputer { override def apply(c: BinaryConfusionMatrix): Double = - c.tp.toDouble / (c.tp + c.fp) + c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives) } /** False positive rate. */ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer { override def apply(c: BinaryConfusionMatrix): Double = - c.fp.toDouble / c.n + c.numFalsePositives.toDouble / c.numNegatives } /** Recall. */ private[evaluation] object Recall extends BinaryClassificationMetricComputer { override def apply(c: BinaryConfusionMatrix): Double = - c.tp.toDouble / c.p + c.numTruePositives.toDouble / c.numPositives } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala index 658142b0ebf53..ed7b0fc943367 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala @@ -33,22 +33,22 @@ private case class BinaryConfusionMatrixImpl( totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable { /** number of true positives */ - override def tp: Long = count.numPositives + override def numTruePositives: Long = count.numPositives /** number of false positives */ - override def fp: Long = count.numNegatives + override def numFalsePositives: Long = count.numNegatives /** number of false negatives */ - override def fn: Long = totalCount.numPositives - count.numPositives + override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives /** number of true negatives */ - override def tn: Long = totalCount.numNegatives - count.numNegatives + override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives /** number of positives */ - override def p: Long = totalCount.numPositives + override def numPositives: Long = totalCount.numPositives /** number of negatives */ - override def n: Long = totalCount.numNegatives + override def numNegatives: Long = totalCount.numNegatives } /** @@ -57,10 +57,10 @@ private case class BinaryConfusionMatrixImpl( * @param scoreAndLabels an RDD of (score, label) pairs. */ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) - extends Serializable with Logging { + extends Serializable with Logging { private lazy val ( - cumCounts: RDD[(Double, LabelCounter)], + cumulativeCounts: RDD[(Double, LabelCounter)], confusions: RDD[(Double, BinaryConfusionMatrix)]) = { // Create a bin for each distinct score value, count positives and negatives within each bin, // and then sort by score values in descending order. @@ -74,32 +74,32 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) iter.foreach(agg += _) Iterator(agg) }, preservesPartitioning = true).collect() - val partitionwiseCumCounts = + val partitionwiseCumulativeCounts = agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg.clone() += c) - val totalCount = partitionwiseCumCounts.last + val totalCount = partitionwiseCumulativeCounts.last logInfo(s"Total counts: $totalCount") - val cumCounts = counts.mapPartitionsWithIndex( + val cumulativeCounts = counts.mapPartitionsWithIndex( (index: Int, iter: Iterator[(Double, LabelCounter)]) => { - val cumCount = partitionwiseCumCounts(index) + val cumCount = partitionwiseCumulativeCounts(index) iter.map { case (score, c) => cumCount += c (score, cumCount.clone()) } }, preservesPartitioning = true) - cumCounts.persist() - val confusions = cumCounts.map { case (score, cumCount) => + cumulativeCounts.persist() + val confusions = cumulativeCounts.map { case (score, cumCount) => (score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix]) } - (cumCounts, confusions) + (cumulativeCounts, confusions) } /** Unpersist intermediate RDDs used in the computation. */ def unpersist() { - cumCounts.unpersist() + cumulativeCounts.unpersist() } /** Returns thresholds in descending order. */ - def thresholds(): RDD[Double] = cumCounts.map(_._1) + def thresholds(): RDD[Double] = cumulativeCounts.map(_._1) /** * Returns the receiver operating characteristic (ROC) curve, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala index f846d05cd894c..75a75b216002a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala @@ -22,20 +22,20 @@ package org.apache.spark.mllib.evaluation.binary */ private[evaluation] trait BinaryConfusionMatrix { /** number of true positives */ - def tp: Long + def numTruePositives: Long /** number of false positives */ - def fp: Long + def numFalsePositives: Long /** number of false negatives */ - def fn: Long + def numFalseNegatives: Long /** number of true negatives */ - def tn: Long + def numTrueNegatives: Long /** number of positives */ - def p: Long = tp + fn + def numPositives: Long = numTruePositives + numFalseNegatives /** number of negatives */ - def n: Long = fp + tn + def numNegatives: Long = numFalsePositives + numTrueNegatives } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala index f92adbe8378c6..173fdaefab3da 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala @@ -27,27 +27,29 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { val scoreAndLabels = sc.parallelize( Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2) val metrics = new BinaryClassificationMetrics(scoreAndLabels) - val score = Seq(0.8, 0.6, 0.4, 0.1) - val tp = Seq(1, 3, 3, 4) - val fp = Seq(0, 1, 2, 3) - val p = 4 - val n = 3 - val precision = tp.zip(fp).map { case (t, f) => t.toDouble / (t + f) } - val recall = tp.map(t => t.toDouble / p) - val fpr = fp.map(f => f.toDouble / n) + val threshold = Seq(0.8, 0.6, 0.4, 0.1) + val numTruePositives = Seq(1, 3, 3, 4) + val numFalsePositives = Seq(0, 1, 2, 3) + val numPositives = 4 + val numNegatives = 3 + val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) => + t.toDouble / (t + f) + } + val recall = numTruePositives.map(t => t.toDouble / numPositives) + val fpr = numFalsePositives.map(f => f.toDouble / numNegatives) val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0)) val pr = recall.zip(precision) val prCurve = Seq((0.0, 1.0)) ++ pr - val f1 = pr.map { case (re, prec) => 2.0 * (prec * re) / (prec + re) } - val f2 = pr.map { case (re, prec) => 5.0 * (prec * re) / (4.0 * prec + re)} - assert(metrics.thresholds().collect().toSeq === score) + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) } + val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} + assert(metrics.thresholds().collect().toSeq === threshold) assert(metrics.roc().collect().toSeq === rocCurve) assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve)) assert(metrics.pr().collect().toSeq === prCurve) assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve)) - assert(metrics.fMeasureByThreshold().collect().toSeq === score.zip(f1)) - assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === score.zip(f2)) - assert(metrics.precisionByThreshold().collect().toSeq === score.zip(precision)) - assert(metrics.recallByThreshold().collect().toSeq === score.zip(recall)) + assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1)) + assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2)) + assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision)) + assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall)) } }