Skip to content

Commit

Permalink
replace TP/FP/TN/FN by their full names
Browse files: browse the repository at this point in the history
  • Loading branch information
mengxr committed Apr 10, 2014
1 parent 3f42e98 commit a05941d
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl
/**
 * Precision: true positives divided by the sum of true positives and false positives.
 *
 * The numerator is converted to Double before dividing, so the division is floating-point;
 * a zero denominator therefore yields NaN (or Infinity) rather than throwing.
 */
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double =
// NOTE(review): the next two lines look like the before/after sides of the rename diff
// (tp/fp -> numTruePositives/numFalsePositives); both express the same ratio.
c.tp.toDouble / (c.tp + c.fp)
c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
}

/**
 * False positive rate: false positives divided by the number of negatives.
 *
 * The numerator is converted to Double before dividing, so the division is floating-point;
 * a zero denominator therefore yields NaN (or Infinity) rather than throwing.
 */
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double =
// NOTE(review): the next two lines look like the before/after sides of the rename diff
// (fp/n -> numFalsePositives/numNegatives); both express the same ratio.
c.fp.toDouble / c.n
c.numFalsePositives.toDouble / c.numNegatives
}

/**
 * Recall: true positives divided by the number of positives.
 *
 * The numerator is converted to Double before dividing, so the division is floating-point;
 * a zero denominator therefore yields NaN (or Infinity) rather than throwing.
 */
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double =
// NOTE(review): the next two lines look like the before/after sides of the rename diff
// (tp/p -> numTruePositives/numPositives); both express the same ratio.
c.tp.toDouble / c.p
c.numTruePositives.toDouble / c.numPositives
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@ private case class BinaryConfusionMatrixImpl(
totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {

// NOTE(review): `count` appears to hold the cumulative label counts up to the current
// threshold (it is passed as `cumCount` where this class is constructed) and `totalCount`
// the counts over the whole data set -- confirm against the constructor, which is not
// fully visible here.
// Each short-named member below is immediately followed by a long-named counterpart with
// the same body; these look like the before/after sides of the rename diff.

/** number of true positives: the positive count held by `count` */
override def tp: Long = count.numPositives
override def numTruePositives: Long = count.numPositives

/** number of false positives: the negative count held by `count` */
override def fp: Long = count.numNegatives
override def numFalsePositives: Long = count.numNegatives

/** number of false negatives: total positives minus the positives held by `count` */
override def fn: Long = totalCount.numPositives - count.numPositives
override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives

/** number of true negatives: total negatives minus the negatives held by `count` */
override def tn: Long = totalCount.numNegatives - count.numNegatives
override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives

/** number of positives, read directly from `totalCount` instead of summing two cells */
override def p: Long = totalCount.numPositives
override def numPositives: Long = totalCount.numPositives

/** number of negatives, read directly from `totalCount` instead of summing two cells */
override def n: Long = totalCount.numNegatives
override def numNegatives: Long = totalCount.numNegatives
}

/**
Expand All @@ -57,10 +57,10 @@ private case class BinaryConfusionMatrixImpl(
* @param scoreAndLabels an RDD of (score, label) pairs.
*/
class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])
extends Serializable with Logging {
extends Serializable with Logging {

private lazy val (
cumCounts: RDD[(Double, LabelCounter)],
cumulativeCounts: RDD[(Double, LabelCounter)],
confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
// Create a bin for each distinct score value, count positives and negatives within each bin,
// and then sort by score values in descending order.
Expand All @@ -74,32 +74,32 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])
iter.foreach(agg += _)
Iterator(agg)
}, preservesPartitioning = true).collect()
val partitionwiseCumCounts =
val partitionwiseCumulativeCounts =
agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg.clone() += c)
val totalCount = partitionwiseCumCounts.last
val totalCount = partitionwiseCumulativeCounts.last
logInfo(s"Total counts: $totalCount")
val cumCounts = counts.mapPartitionsWithIndex(
val cumulativeCounts = counts.mapPartitionsWithIndex(
(index: Int, iter: Iterator[(Double, LabelCounter)]) => {
val cumCount = partitionwiseCumCounts(index)
val cumCount = partitionwiseCumulativeCounts(index)
iter.map { case (score, c) =>
cumCount += c
(score, cumCount.clone())
}
}, preservesPartitioning = true)
cumCounts.persist()
val confusions = cumCounts.map { case (score, cumCount) =>
cumulativeCounts.persist()
val confusions = cumulativeCounts.map { case (score, cumCount) =>
(score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix])
}
(cumCounts, confusions)
(cumulativeCounts, confusions)
}

/**
 * Unpersist intermediate RDDs used in the computation.
 *
 * Releases the cumulative-count RDD that is persisted while the lazy value is initialised.
 * Because that RDD is one component of a `lazy val`, referencing it here forces the
 * initialiser to run first if nothing has triggered it yet.
 */
def unpersist() {
// NOTE(review): the two calls below look like the before/after sides of the
// cumCounts -> cumulativeCounts rename diff.
cumCounts.unpersist()
cumulativeCounts.unpersist()
}

/**
 * Returns thresholds in descending order.
 *
 * Each threshold is the first element (the score) of an entry in the cumulative-count RDD.
 */
// NOTE(review): the two definitions below look like the before/after sides of the
// cumCounts -> cumulativeCounts rename diff.
def thresholds(): RDD[Double] = cumCounts.map(_._1)
def thresholds(): RDD[Double] = cumulativeCounts.map(_._1)

/**
* Returns the receiver operating characteristic (ROC) curve,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,20 @@ package org.apache.spark.mllib.evaluation.binary
*/
private[evaluation] trait BinaryConfusionMatrix {
// The four cell counts are abstract; the two totals further down have default
// implementations derived from the cells, which implementations may override.
// NOTE(review): each short-named member is immediately followed by its long-named
// counterpart; these look like the before/after sides of the rename diff.

/** number of true positives */
def tp: Long
def numTruePositives: Long

/** number of false positives */
def fp: Long
def numFalsePositives: Long

/** number of false negatives */
def fn: Long
def numFalseNegatives: Long

/** number of true negatives */
def tn: Long
def numTrueNegatives: Long

/** number of positives; by default, true positives plus false negatives */
def p: Long = tp + fn
def numPositives: Long = numTruePositives + numFalseNegatives

/** number of negatives; by default, false positives plus true negatives */
def n: Long = fp + tn
def numNegatives: Long = numFalsePositives + numTrueNegatives
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,29 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
val scoreAndLabels = sc.parallelize(
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val score = Seq(0.8, 0.6, 0.4, 0.1)
val tp = Seq(1, 3, 3, 4)
val fp = Seq(0, 1, 2, 3)
val p = 4
val n = 3
val precision = tp.zip(fp).map { case (t, f) => t.toDouble / (t + f) }
val recall = tp.map(t => t.toDouble / p)
val fpr = fp.map(f => f.toDouble / n)
val threshold = Seq(0.8, 0.6, 0.4, 0.1)
val numTruePositives = Seq(1, 3, 3, 4)
val numFalsePositives = Seq(0, 1, 2, 3)
val numPositives = 4
val numNegatives = 3
val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
t.toDouble / (t + f)
}
val recall = numTruePositives.map(t => t.toDouble / numPositives)
val fpr = numFalsePositives.map(f => f.toDouble / numNegatives)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
val pr = recall.zip(precision)
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (re, prec) => 2.0 * (prec * re) / (prec + re) }
val f2 = pr.map { case (re, prec) => 5.0 * (prec * re) / (4.0 * prec + re)}
assert(metrics.thresholds().collect().toSeq === score)
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
assert(metrics.thresholds().collect().toSeq === threshold)
assert(metrics.roc().collect().toSeq === rocCurve)
assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve))
assert(metrics.pr().collect().toSeq === prCurve)
assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve))
assert(metrics.fMeasureByThreshold().collect().toSeq === score.zip(f1))
assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === score.zip(f2))
assert(metrics.precisionByThreshold().collect().toSeq === score.zip(precision))
assert(metrics.recallByThreshold().collect().toSeq === score.zip(recall))
assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1))
assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2))
assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision))
assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall))
}
}

0 comments on commit a05941d

Please sign in to comment.