
add initial version of binary classification evaluator
mengxr committed Mar 31, 2014
1 parent 221ebce commit aa7e278
Showing 4 changed files with 128 additions and 6 deletions.
AreaUnderCurve.scala
@@ -23,7 +23,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
 /**
  * Computes the area under the curve (AUC) using the trapezoidal rule.
  */
-object AreaUnderCurve {
+private[mllib] object AreaUnderCurve {
 
   /**
    * Uses the trapezoidal rule to compute the area under the line connecting the two input points.
@@ -53,8 +53,8 @@ object AreaUnderCurve {
    *
    * @param curve an iterator over ordered 2D points stored in pairs representing a curve
    */
-  def of(curve: Iterator[(Double, Double)]): Double = {
-    curve.sliding(2).withPartial(false).aggregate(0.0)(
+  def of(curve: Iterable[(Double, Double)]): Double = {
+    curve.toIterator.sliding(2).withPartial(false).aggregate(0.0)(
       seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
       combop = _ + _
     )
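For reference, a minimal standalone sketch of the trapezoidal-rule computation that AreaUnderCurve.of performs (plain Scala, no Spark; TrapezoidSketch and its members are hypothetical names, not part of the commit):

object TrapezoidSketch {
  // Area of the trapezoid under the segment connecting two (x, y) points.
  def trapezoid(points: Seq[(Double, Double)]): Double = {
    val Seq((x1, y1), (x2, y2)) = points
    (y1 + y2) / 2.0 * (x2 - x1)
  }

  // Pair up consecutive points; withPartial(false) drops the trailing
  // singleton, so curves with fewer than two points yield 0.0.
  def areaUnderCurve(curve: Iterable[(Double, Double)]): Double = {
    curve.toIterator.sliding(2).withPartial(false)
      .foldLeft(0.0)((auc, pair) => auc + trapezoid(pair))
  }

  def main(args: Array[String]): Unit = {
    // Same curve as the "auc computation" test below: 0.5 + 2.0 + 1.5 = 4.0
    println(areaUnderCurve(Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))))
  }
}

Each consecutive pair of points contributes one trapezoid, so the sum approximates the integral under the piecewise-linear curve.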
BinaryClassificationEvaluator.scala (new file)
@@ -0,0 +1,109 @@
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
+
+class BinaryClassificationEvaluator(scoreAndLabel: RDD[(Double, Double)]) {
+
+}
+
+object BinaryClassificationEvaluator {
+
+  def get(rdd: RDD[(Double, Double)]) {
+    // Create a bin for each distinct score value, count positives and negatives within each bin,
+    // and then sort by score values in descending order.
+    val counts = rdd.combineByKey(
+      createCombiner = (label: Double) => new Counter(0L, 0L) += label,
+      mergeValue = (c: Counter, label: Double) => c += label,
+      mergeCombiners = (c1: Counter, c2: Counter) => c1 += c2
+    ).sortByKey(ascending = false)
+    println(counts.collect().toList)
+    val agg = counts.values.mapPartitions((iter: Iterator[Counter]) => {
+      val agg = new Counter()
+      iter.foreach(agg += _)
+      Iterator(agg)
+    }, preservesPartitioning = true).collect()
+    println(agg.toList)
+    val cum = agg.scanLeft(new Counter())((agg: Counter, c: Counter) => agg + c)
+    val total = cum.last
+    println(total)
+    println(cum.toList)
+    val cumCountsRdd = counts.mapPartitionsWithIndex((index: Int, iter: Iterator[(Double, Counter)]) => {
+      val cumCount = cum(index)
+      iter.map { case (score, c) =>
+        cumCount += c
+        (score, cumCount.clone())
+      }
+    }, preservesPartitioning = true)
+    println("cum: " + cumCountsRdd.collect().toList)
+    val rocAUC = AreaUnderCurve.of(cumCountsRdd.values.map((c: Counter) => {
+      (1.0 * c.numNegatives / total.numNegatives,
+        1.0 * c.numPositives / total.numPositives)
+    }))
+    println(rocAUC)
+    val prAUC = AreaUnderCurve.of(cumCountsRdd.values.map((c: Counter) => {
+      (1.0 * c.numPositives / total.numPositives,
+        1.0 * c.numPositives / (c.numPositives + c.numNegatives))
+    }))
+    println(prAUC)
+  }
+
+  def get(data: Iterable[(Double, Double)]) {
+    val counts = data.groupBy(_._1).mapValues { s =>
+      val c = new Counter()
+      s.foreach(c += _._2)
+      c
+    }.toSeq.sortBy(- _._1)
+    println("counts: " + counts.toList)
+    val total = new Counter()
+    val cum = counts.map { s =>
+      total += s._2
+      (s._1, total.clone())
+    }
+    println("cum: " + cum.toList)
+    val roc = cum.map { case (s, c) =>
+      (1.0 * c.numNegatives / total.numNegatives, 1.0 * c.numPositives / total.numPositives)
+    }
+    val rocAUC = AreaUnderCurve.of(roc)
+    println(rocAUC)
+    val pr = cum.map { case (s, c) =>
+      (1.0 * c.numPositives / total.numPositives,
+        1.0 * c.numPositives / (c.numPositives + c.numNegatives))
+    }
+    val prAUC = AreaUnderCurve.of(pr)
+    println(prAUC)
+  }
+}
+
+class Counter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {
+
+  def +=(label: Double): Counter = {
+    // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
+    // -1.0 for negative as well.
+    if (label > 0.5) numPositives += 1L else numNegatives += 1L
+    this
+  }
+
+  def +=(other: Counter): Counter = {
+    numPositives += other.numPositives
+    numNegatives += other.numNegatives
+    this
+  }
+
+  def +(label: Double): Counter = {
+    this.clone() += label
+  }
+
+  def +(other: Counter): Counter = {
+    this.clone() += other
+  }
+
+  def sum: Long = numPositives + numNegatives
+
+  override def clone(): Counter = {
+    new Counter(numPositives, numNegatives)
+  }
+
+  override def toString(): String = s"[$numPositives,$numNegatives]"
+}
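Counter is a small mutable accumulator: the += variants fold a label or another counter in place, which lets combineByKey merge bins without extra allocation, while + and clone() return copies. Cloning before emitting in mapPartitionsWithIndex matters because the same mutable cumCount instance is reused for every element of a partition. A hypothetical usage sketch (not part of the commit):

val c = new Counter()
c += 1.0                         // positive label
c += 0.0                         // negative label
c += -1.0                        // also negative: label <= 0.5
println(c)                       // [1,2]
println(c.sum)                   // 3
val d = c + new Counter(2L, 0L)  // non-destructive merge via clone()
println(d)                       // [3,2]
println(c)                       // still [1,2]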

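The RDD version of get builds cumulative counts in two passes: bin labels by score with combineByKey, sort descending so partitions are ordered by score, collect one subtotal per partition, scanLeft those subtotals into per-partition starting offsets (cum(i) covers everything before partition i, and cum.last is the grand total), then seed each partition with its offset and keep a running sum. A minimal sketch of the same technique on plain Scala collections standing in for sorted partitions, reusing the Counter class above with the data from the test suite below (hypothetical illustration, not part of the commit):

// Per-score (positive, negative) bins, already sorted by descending score
// and split across three "partitions".
val partitions: Seq[Seq[(Double, Counter)]] = Seq(
  Seq((0.9, new Counter(1L, 0L)), (0.8, new Counter(1L, 0L))),
  Seq((0.6, new Counter(2L, 1L)), (0.4, new Counter(0L, 1L))),
  Seq((0.1, new Counter(1L, 1L)), (0.0, new Counter(0L, 1L))))

// Pass 1: one subtotal per partition (mapPartitions + collect in the RDD version).
val perPartition = partitions.map(_.foldLeft(new Counter())((acc, kv) => acc += kv._2))

// scanLeft turns subtotals into starting offsets; cum.last is the grand total.
val cum = perPartition.scanLeft(new Counter())((acc, c) => acc + c)
val total = cum.last  // [5,4]: 5 positives, 4 negatives

// Pass 2: running cumulative counts within each partition, seeded by its offset.
val cumCounts = partitions.zipWithIndex.flatMap { case (part, i) =>
  val running = cum(i).clone()
  part.map { case (score, c) => running += c; (score, running.clone()) }
}

// One ROC point per threshold: (false positive rate, true positive rate).
val roc = cumCounts.map { case (_, c) =>
  (1.0 * c.numNegatives / total.numNegatives, 1.0 * c.numPositives / total.numPositives)
}

In the RDD version it is safe to mutate cum(index) directly, since each Spark task works on its own deserialized copy of the closure; the local sketch clones first instead.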
AreaUnderCurveSuite.scala
@@ -26,21 +26,21 @@ class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
test("auc computation") {
val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
val auc = 4.0
assert(AreaUnderCurve.of(curve.toIterator) === auc)
assert(AreaUnderCurve.of(curve) === auc)
val rddCurve = sc.parallelize(curve, 2)
assert(AreaUnderCurve.of(rddCurve) == auc)
}

test("auc of an empty curve") {
val curve = Seq.empty[(Double, Double)]
assert(AreaUnderCurve.of(curve.toIterator) === 0.0)
assert(AreaUnderCurve.of(curve) === 0.0)
val rddCurve = sc.parallelize(curve, 2)
assert(AreaUnderCurve.of(rddCurve) === 0.0)
}

test("auc of a curve with a single point") {
val curve = Seq((1.0, 1.0))
assert(AreaUnderCurve.of(curve.toIterator) === 0.0)
assert(AreaUnderCurve.of(curve) === 0.0)
val rddCurve = sc.parallelize(curve, 2)
assert(AreaUnderCurve.of(rddCurve) === 0.0)
}
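The change from Iterator to Iterable in these tests reflects a real pitfall: an Iterator is single-pass, so the old signature consumed the curve on first use and any reuse would silently see an empty curve. A hypothetical demo (not part of the commit):

val pts = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
val it = pts.toIterator
it.foreach(_ => ())  // any full traversal exhausts an Iterator...
println(it.hasNext)  // false: a second pass over it sees nothing
// ...whereas an Iterable like pts can be traversed repeatedly, so the same
// curve value can be passed to AreaUnderCurve.of more than once.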
BinaryClassificationEvaluationSuite.scala (new file)
@@ -0,0 +1,13 @@
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+import org.apache.spark.mllib.util.LocalSparkContext
+
+class BinaryClassificationEvaluationSuite extends FunSuite with LocalSparkContext {
+  test("test") {
+    val data = Seq((0.0, 0.0), (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0), (0.9, 1.0))
+    BinaryClassificationEvaluator.get(data)
+    val rdd = sc.parallelize(data, 3)
+    BinaryClassificationEvaluator.get(rdd)
+  }
+}
