Skip to content

Commit

Permalink
move binary evalution classes to evaluation.binary
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Apr 9, 2014
1 parent 8f78958 commit 3d71525
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,76 +15,52 @@
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation
package org.apache.spark.mllib.evaluation.binary

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.AreaUnderCurve
import org.apache.spark.Logging

/**
* Binary confusion matrix.
* Implementation of [[org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrix]].
*
* @param count label counter for labels with scores greater than or equal to the current score
* @param total label counter for all labels
* @param totalCount label counter for all labels
*/
case class BinaryConfusionMatrix(
private case class BinaryConfusionMatrixImpl(
private val count: LabelCounter,
private val total: LabelCounter) extends Serializable {
private val totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {

/** number of true positives */
def tp: Long = count.numPositives
override def tp: Long = count.numPositives

/** number of false positives */
def fp: Long = count.numNegatives
override def fp: Long = count.numNegatives

/** number of false negatives */
def fn: Long = total.numPositives - count.numPositives
override def fn: Long = totalCount.numPositives - count.numPositives

/** number of true negatives */
def tn: Long = total.numNegatives - count.numNegatives
override def tn: Long = totalCount.numNegatives - count.numNegatives

/** number of positives */
def p: Long = total.numPositives
override def p: Long = totalCount.numPositives

/** number of negatives */
def n: Long = total.numNegatives
}

private trait Metric {
def apply(c: BinaryConfusionMatrix): Double
}

object Precision extends Metric {
override def apply(c: BinaryConfusionMatrix): Double =
c.tp.toDouble / (c.tp + c.fp)
}

object FalsePositiveRate extends Metric {
override def apply(c: BinaryConfusionMatrix): Double =
c.fp.toDouble / c.n
}

object Recall extends Metric {
override def apply(c: BinaryConfusionMatrix): Double =
c.tp.toDouble / c.p
}

case class FMeasure(beta: Double) extends Metric {
private val beta2 = beta * beta
override def apply(c: BinaryConfusionMatrix): Double = {
val precision = Precision(c)
val recall = Recall(c)
(1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
}
override def n: Long = totalCount.numNegatives
}

/**
* Evaluator for binary classification.
*
* @param scoreAndlabels an RDD of (score, label) pairs.
*/
class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) extends Serializable {
class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) extends Serializable with Logging {

private lazy val (cumCounts: RDD[(Double, LabelCounter)], totalCount: LabelCounter, scoreAndConfusion: RDD[(Double, BinaryConfusionMatrix)]) = {
private lazy val (
cumCounts: RDD[(Double, LabelCounter)],
confusionByThreshold: RDD[(Double, BinaryConfusionMatrix)]) = {
// Create a bin for each distinct score value, count positives and negatives within each bin,
// and then sort by score values in descending order.
val counts = scoreAndlabels.combineByKey(
Expand All @@ -99,6 +75,7 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
}, preservesPartitioning = true).collect()
val cum = agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg + c)
val totalCount = cum.last
// Interpolate the value: without `$` the literal text "totalCount" is logged.
logInfo(s"Total counts: $totalCount")
val cumCounts = counts.mapPartitionsWithIndex((index: Int, iter: Iterator[(Double, LabelCounter)]) => {
val cumCount = cum(index)
iter.map { case (score, c) =>
Expand All @@ -108,76 +85,71 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
}, preservesPartitioning = true)
cumCounts.persist()
val scoreAndConfusion = cumCounts.map { case (score, cumCount) =>
(score, BinaryConfusionMatrix(cumCount, totalCount))
(score, BinaryConfusionMatrixImpl(cumCount, totalCount))
}
(cumCounts, totalCount, scoreAndConfusion)
}

/**
 * Unpersists the intermediate RDD of cumulative counts that is persisted while the
 * per-threshold confusion matrices are built.
 *
 * `cumCounts` is part of a lazy value, so calling this before any curve has been
 * requested first triggers that computation.
 */
def unpersist(): Unit = {
  cumCounts.unpersist()
}

/**
 * Builds the receiver operating characteristic (ROC) curve.
 *
 * @return an RDD of (false positive rate, recall) points, one per threshold
 * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
 */
def rocCurve(): RDD[(Double, Double)] = {
  val xMetric = FalsePositiveRate
  val yMetric = Recall
  createCurve(xMetric, yMetric)
}

/**
 * Area under the receiver operating characteristic (ROC) curve.
 *
 * @return the result of `AreaUnderCurve.of` applied to the curve from `rocCurve()`
 */
def rocAUC(): Double = {
  val curve = rocCurve()
  AreaUnderCurve.of(curve)
}

/**
 * Builds the precision-recall curve.
 *
 * @return an RDD of (recall, precision) points, one per threshold
 * @see http://en.wikipedia.org/wiki/Precision_and_recall
 */
def prCurve(): RDD[(Double, Double)] = {
  val xMetric = Recall
  val yMetric = Precision
  createCurve(xMetric, yMetric)
}

/**
 * Area under the precision-recall curve.
 *
 * @return the result of `AreaUnderCurve.of` applied to the curve from `prCurve()`
 */
def prAUC(): Double = {
  val curve = prCurve()
  AreaUnderCurve.of(curve)
}

/**
 * Builds the F-Measure curve, keyed by threshold.
 *
 * @param beta the beta factor used by the F-Measure.
 * @return an RDD of (threshold, F-Measure) pairs.
 * @see http://en.wikipedia.org/wiki/F1_score
 */
def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = {
  val metric = FMeasure(beta)
  createCurve(metric)
}

/**
 * Returns the (threshold, F-Measure) curve with beta = 1.0.
 *
 * @return an RDD of (threshold, F-Measure) pairs, as produced by
 *         `fMeasureByThreshold(1.0)`.
 */
def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0)

private def createCurve(y: Metric): RDD[(Double, Double)] = {
scoreAndConfusion.map { case (s, c) =>
/** Creates a curve of (threshold, metric). */
private def createCurve(y: BinaryClassificationMetric): RDD[(Double, Double)] = {
confusionByThreshold.map { case (s, c) =>
(s, y(c))
}
}

private def createCurve(x: Metric, y: Metric): RDD[(Double, Double)] = {
scoreAndConfusion.map { case (_, c) =>
/** Creates a curve of (metricX, metricY). */
private def createCurve(x: BinaryClassificationMetric, y: BinaryClassificationMetric): RDD[(Double, Double)] = {
confusionByThreshold.map { case (_, c) =>
(x(c), y(c))
}
}
}

/**
 * In-memory (non-RDD) computation of the ROC and precision-recall areas for a
 * collection of (score, label) pairs. Results are printed with `println`; nothing is
 * returned.
 */
class LocalBinaryClassificationEvaluator {
/**
 * Prints the per-score label counts, the cumulative counts, the area under the ROC
 * curve and the area under the precision-recall curve, in that order.
 *
 * @param data pairs whose first element is used as the grouping key (score) and whose
 *             second element is fed to a `LabelCounter` as the label.
 */
def get(data: Iterable[(Double, Double)]) {
// One LabelCounter per distinct score, ordered by score descending.
val counts = data.groupBy(_._1).mapValues { s =>
val c = new LabelCounter()
s.foreach(c += _._2)
c
}.toSeq.sortBy(- _._1)
println("counts: " + counts.toList)
// `total` is mutated in place while walking the sorted scores; each element of `cum`
// stores `total.clone()` taken at that point.
// NOTE(review): this relies on clone() returning an independent copy - confirm
// against LabelCounter.
val total = new LabelCounter()
val cum = counts.map { s =>
total += s._2
(s._1, total.clone())
}
println("cum: " + cum.toList)
// From here on `total` holds the counts accumulated over every score.
// ROC points: (cumulative negatives / all negatives, cumulative positives / all positives).
val roc = cum.map { case (s, c) =>
(1.0 * c.numNegatives / total.numNegatives, 1.0 * c.numPositives / total.numPositives)
}
val rocAUC = AreaUnderCurve.of(roc)
println(rocAUC)
// PR points: (cumulative positives / all positives,
//             cumulative positives / (cumulative positives + cumulative negatives)).
val pr = cum.map { case (s, c) =>
(1.0 * c.numPositives / total.numPositives,
1.0 * c.numPositives / (c.numPositives + c.numNegatives))
}
val prAUC = AreaUnderCurve.of(pr)
println(prAUC)
}
}

/**
* A counter for positives and negatives.
*
* @param numPositives
* @param numNegatives
* @param numPositives number of positive labels
* @param numNegatives number of negative labels
*/
private[evaluation]
class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {
private class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {

/** Process a label. */
def +=(label: Double): LabelCounter = {
Expand Down Expand Up @@ -208,6 +180,6 @@ class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) ext
new LabelCounter(numPositives, numNegatives)
}

override def toString: String = s"[$numPositives,$numNegatives]"
override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation.binary

/**
 * Trait for a binary classification evaluation metric: a function from a binary
 * confusion matrix to a single `Double` value.
 */
private[evaluation] trait BinaryClassificationMetric {
/**
 * Computes the metric value for the given confusion matrix.
 *
 * @param c the confusion matrix to evaluate.
 * @return the metric value.
 */
def apply(c: BinaryConfusionMatrix): Double
}

/**
 * Precision: true positives divided by the sum of true positives and false positives.
 *
 * The division is performed in floating point, so a zero denominator does not throw;
 * the result then follows IEEE 754 division rules.
 */
private[evaluation] object Precision extends BinaryClassificationMetric {
  override def apply(c: BinaryConfusionMatrix): Double = {
    val predictedPositives = c.tp + c.fp
    c.tp.toDouble / predictedPositives
  }
}

/**
 * False positive rate: false positives divided by the number of negatives.
 *
 * The division is performed in floating point, so a zero denominator does not throw;
 * the result then follows IEEE 754 division rules.
 */
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetric {
  override def apply(c: BinaryConfusionMatrix): Double = {
    val falsePositives = c.fp.toDouble
    falsePositives / c.n
  }
}

/**
 * Recall: true positives divided by the number of positives.
 *
 * The division is performed in floating point, so a zero denominator does not throw;
 * the result then follows IEEE 754 division rules.
 */
// The access qualifier must name an enclosing package; it is `evaluation`, matching the
// other metrics in this file.
private[evaluation] object Recall extends BinaryClassificationMetric {
  override def apply(c: BinaryConfusionMatrix): Double =
    c.tp.toDouble / c.p
}

/**
 * F-Measure, combining precision and recall as
 * `(1 + beta^2) * precision * recall / (beta^2 * precision + recall)`.
 *
 * @param beta the beta constant in F-Measure
 * @see http://en.wikipedia.org/wiki/F1_score
 */
private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificationMetric {
  // Computed once per metric instance rather than once per confusion matrix.
  private val betaSquared = beta * beta

  override def apply(c: BinaryConfusionMatrix): Double = {
    val p = Precision(c)
    val r = Recall(c)
    val numerator = (1.0 + betaSquared) * (p * r)
    val denominator = betaSquared * p + r
    numerator / denominator
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation.binary

/**
 * Read-only view of a binary confusion matrix.
 *
 * Implementations supply the four cell counts. The positive and negative totals are
 * derived from those counts by default and may be overridden.
 */
private[evaluation] trait BinaryConfusionMatrix {

  /** Count of true positives. */
  def tp: Long

  /** Count of false positives. */
  def fp: Long

  /** Count of false negatives. */
  def fn: Long

  /** Count of true negatives. */
  def tn: Long

  /** Count of positives: true positives plus false negatives. */
  def p: Long = {
    tp + fn
  }

  /** Count of negatives: false positives plus true negatives. */
  def n: Long = {
    fp + tn
  }
}

0 comments on commit 3d71525

Please sign in to comment.