Skip to content

Commit

Permalink
add tests for BinaryClassificationEvaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Apr 9, 2014
1 parent ca31da5 commit 9dc3518
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
/**
* Computes the area under the curve (AUC) using the trapezoidal rule.
*/
private[mllib] object AreaUnderCurve {
private[evaluation] object AreaUnderCurve {

/**
* Uses the trapezoidal rule to compute the area under the line connecting the two input points.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import org.apache.spark.Logging
* @param totalCount label counter for all labels
*/
private case class BinaryConfusionMatrixImpl(
private val count: LabelCounter,
private val totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {
count: LabelCounter,
totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {

/** number of true positives */
override def tp: Long = count.numPositives
Expand All @@ -54,16 +54,16 @@ private case class BinaryConfusionMatrixImpl(
/**
* Evaluator for binary classification.
*
* @param scoreAndlabels an RDD of (score, label) pairs.
* @param scoreAndLabels an RDD of (score, label) pairs.
*/
class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) extends Serializable with Logging {
class BinaryClassificationEvaluator(scoreAndLabels: RDD[(Double, Double)]) extends Serializable with Logging {

private lazy val (
cumCounts: RDD[(Double, LabelCounter)],
confusionByThreshold: RDD[(Double, BinaryConfusionMatrix)]) = {
confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
// Create a bin for each distinct score value, count positives and negatives within each bin,
// and then sort by score values in descending order.
val counts = scoreAndlabels.combineByKey(
val counts = scoreAndLabels.combineByKey(
createCombiner = (label: Double) => new LabelCounter(0L, 0L) += label,
mergeValue = (c: LabelCounter, label: Double) => c += label,
mergeCombiners = (c1: LabelCounter, c2: LabelCounter) => c1 += c2
Expand All @@ -73,21 +73,21 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
iter.foreach(agg += _)
Iterator(agg)
}, preservesPartitioning = true).collect()
val cum = agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg + c)
val totalCount = cum.last
logInfo(s"Total counts: totalCount")
val partitionwiseCumCounts = agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg + c)
val totalCount = partitionwiseCumCounts.last
logInfo(s"Total counts: $totalCount")
val cumCounts = counts.mapPartitionsWithIndex((index: Int, iter: Iterator[(Double, LabelCounter)]) => {
val cumCount = cum(index)
val cumCount = partitionwiseCumCounts(index)
iter.map { case (score, c) =>
cumCount += c
(score, cumCount.clone())
}
}, preservesPartitioning = true)
cumCounts.persist()
val scoreAndConfusion = cumCounts.map { case (score, cumCount) =>
(score, BinaryConfusionMatrixImpl(cumCount, totalCount))
val confusions = cumCounts.map { case (score, cumCount) =>
(score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix])
}
(cumCounts, totalCount, scoreAndConfusion)
(cumCounts, confusions)
}

/** Unpersist intermediate RDDs used in the computation. */
Expand Down Expand Up @@ -126,18 +126,18 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta))

/** Returns the (threshold, F-Measure) curve with beta = 1.0. */
def fMeasureByThreshold() = fMeasureByThreshold(1.0)
def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0)

/** Creates a curve of (threshold, metric). */
private def createCurve(y: BinaryClassificationMetric): RDD[(Double, Double)] = {
confusionByThreshold.map { case (s, c) =>
confusions.map { case (s, c) =>
(s, y(c))
}
}

/** Creates a curve of (metricX, metricY). */
private def createCurve(x: BinaryClassificationMetric, y: BinaryClassificationMetric): RDD[(Double, Double)] = {
confusionByThreshold.map { case (_, c) =>
confusions.map { case (_, c) =>
(x(c), y(c))
}
}
Expand All @@ -151,35 +151,29 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
*/
private class LabelCounter(
    var numPositives: Long = 0L,
    var numNegatives: Long = 0L) extends Serializable {

  /**
   * Records a single label in this counter (in place).
   *
   * @param label the label to count
   * @return this counter, after the update
   */
  def +=(label: Double): LabelCounter = {
    // Labels are expected to be 1.0 (positive) or 0.0 (negative). Thresholding at 0.5 means a
    // negative label encoded as -1.0 is counted correctly as well.
    val isPositive = label > 0.5
    if (isPositive) {
      numPositives = numPositives + 1L
    } else {
      numNegatives = numNegatives + 1L
    }
    this
  }

  /**
   * Adds the counts held by another counter to this counter (in place).
   *
   * @param other the counter whose counts are added
   * @return this counter, after the update
   */
  def +=(other: LabelCounter): LabelCounter = {
    numPositives = numPositives + other.numPositives
    numNegatives = numNegatives + other.numNegatives
    this
  }

  /** Returns a new counter equal to this one with `label` recorded; this counter is unchanged. */
  def +(label: Double): LabelCounter = {
    val copy = clone()
    copy.+=(label)
  }

  /** Returns a new counter holding the sum of this counter and `other`; this counter is unchanged. */
  def +(other: LabelCounter): LabelCounter = {
    val copy = clone()
    copy.+=(other)
  }

  /** Total number of labels counted so far, positives plus negatives. */
  def sum: Long = numPositives + numNegatives

  /** Creates an independent counter with the same positive and negative counts. */
  override def clone: LabelCounter = new LabelCounter(numPositives, numNegatives)

  /** Renders both counts, e.g. `{numPos: 3, numNeg: 4}`. */
  override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
}

Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.mllib.evaluation.binary
/**
* Trait for a binary classification evaluation metric.
*/
private[evaluation] trait BinaryClassificationMetric {
private[evaluation] trait BinaryClassificationMetric extends Serializable {
def apply(c: BinaryConfusionMatrix): Double
}

Expand All @@ -37,7 +37,7 @@ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetric
}

/** Recall. */
private[evalution] object Recall extends BinaryClassificationMetric {
private[evaluation] object Recall extends BinaryClassificationMetric {
override def apply(c: BinaryConfusionMatrix): Double =
c.tp.toDouble / c.p
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.util.LocalSparkContext

class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {

test("auc computation") {
val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
val auc = 4.0
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation.binary

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.evaluation.AreaUnderCurve

class BinaryClassificationEvaluatorSuite extends FunSuite with LocalSparkContext {

  // Compares the ROC curve, the PR curve, their areas, and the F-measure curves (beta = 1.0 and
  // beta = 2.0) against values worked out by hand for a seven-example data set.
  test("binary evaluation metrics") {
    // (score, label) pairs: 4 positives and 3 negatives, distributed over 2 partitions.
    val examples = Seq(
      (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0))
    val evaluator = new BinaryClassificationEvaluator(sc.parallelize(examples, 2))

    // Distinct scores in descending order; each one is used as a threshold.
    val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
    // Cumulative true/false positive counts over examples with score >= the threshold.
    val truePositives = Seq(1, 3, 3, 4)
    val falsePositives = Seq(0, 1, 2, 3)
    val numPositives = 4
    val numNegatives = 3

    val expectedPrecision = truePositives.zip(falsePositives).map { case (t, f) =>
      t.toDouble / (t + f)
    }
    val expectedRecall = truePositives.map(_.toDouble / numPositives)
    val expectedFpr = falsePositives.map(_.toDouble / numNegatives)

    // Curves are sequences of (x, y) points: ROC is (FPR, recall), PR is (recall, precision).
    val expectedRoc = expectedFpr.zip(expectedRecall)
    val expectedPr = expectedRecall.zip(expectedPrecision)

    // F-measure = (1 + beta^2) * P * R / (beta^2 * P + R), written out for beta = 1 and beta = 2.
    val expectedF1 = expectedPr.map { case (r, p) =>
      2.0 * (p * r) / (p + r)
    }
    val expectedF2 = expectedPr.map { case (r, p) =>
      5.0 * (p * r) / (4.0 * p + r)
    }

    val actualRoc = evaluator.rocCurve().collect().toSeq
    assert(actualRoc === expectedRoc)
    assert(evaluator.rocAUC() === AreaUnderCurve.of(expectedRoc))

    val actualPr = evaluator.prCurve().collect().toSeq
    assert(actualPr === expectedPr)
    assert(evaluator.prAUC() === AreaUnderCurve.of(expectedPr))

    val actualF1 = evaluator.fMeasureByThreshold().collect().toSeq
    assert(actualF1 === thresholds.zip(expectedF1))

    val actualF2 = evaluator.fMeasureByThreshold(2.0).collect().toSeq
    assert(actualF2 === thresholds.zip(expectedF2))
  }
}

0 comments on commit 9dc3518

Please sign in to comment.