Skip to content

Commit

Permalink
[SPARK-1225, 1241] [MLLIB] Add AreaUnderCurve and BinaryClassificatio…
Browse files Browse the repository at this point in the history
…nMetrics

This PR implements a generic version of `AreaUnderCurve` using the `RDD.sliding` implementation from #136 . It also contains refactoring of #160 for binary classification evaluation.

Author: Xiangrui Meng <[email protected]>

Closes #364 from mengxr/auc and squashes the following commits:

a05941d [Xiangrui Meng] replace TP/FP/TN/FN by their full names
3f42e98 [Xiangrui Meng] add (0, 0), (1, 1) to roc, and (0, 1) to pr
fb4b6d2 [Xiangrui Meng] rename Evaluator to Metrics and add more metrics
b1b7dab [Xiangrui Meng] fix code styles
9dc3518 [Xiangrui Meng] add tests for BinaryClassificationEvaluator
ca31da5 [Xiangrui Meng] remove PredictionAndResponse
3d71525 [Xiangrui Meng] move binary evalution classes to evaluation.binary
8f78958 [Xiangrui Meng] add PredictionAndResponse
dda82d5 [Xiangrui Meng] add confusion matrix
aa7e278 [Xiangrui Meng] add initial version of binary classification evaluator
221ebce [Xiangrui Meng] add a new test to sliding
a920865 [Xiangrui Meng] Merge branch 'sliding' into auc
a9b250a [Xiangrui Meng] move sliding to mllib
cab9a52 [Xiangrui Meng] use last for the last element
db6cb30 [Xiangrui Meng] remove unnecessary toSeq
9916202 [Xiangrui Meng] change RDD.sliding return type to RDD[Seq[T]]
284d991 [Xiangrui Meng] change SlidedRDD to SlidingRDD
c1c6c22 [Xiangrui Meng] add AreaUnderCurve
65461b2 [Xiangrui Meng] Merge branch 'sliding' into auc
5ee6001 [Xiangrui Meng] add TODO
d2a600d [Xiangrui Meng] add sliding to rdd

(cherry picked from commit f5ace8d)
Signed-off-by: Matei Zaharia <[email protected]>
  • Loading branch information
mengxr authored and mateiz committed Apr 11, 2014
1 parent 170b09d commit e6128b5
Show file tree
Hide file tree
Showing 9 changed files with 671 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.rdd.RDDFunctions._

/**
* Computes the area under the curve (AUC) using the trapezoidal rule.
*/
private[evaluation] object AreaUnderCurve {

/**
* Uses the trapezoidal rule to compute the area under the line connecting the two input points.
* @param points two 2D points stored in Seq
*/
private def trapezoid(points: Seq[(Double, Double)]): Double = {
require(points.length == 2)
val x = points.head
val y = points.last
(y._1 - x._1) * (y._2 + x._2) / 2.0
}

/**
* Returns the area under the given curve.
*
* @param curve a RDD of ordered 2D points stored in pairs representing a curve
*/
def of(curve: RDD[(Double, Double)]): Double = {
curve.sliding(2).aggregate(0.0)(
seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
combOp = _ + _
)
}

/**
* Returns the area under the given curve.
*
* @param curve an iterator over ordered 2D points stored in pairs representing a curve
*/
def of(curve: Iterable[(Double, Double)]): Double = {
curve.toIterator.sliding(2).withPartial(false).aggregate(0.0)(
seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
combop = _ + _
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation.binary

/**
* Trait for a binary classification evaluation metric computer.
*/
private[evaluation] trait BinaryClassificationMetricComputer extends Serializable {
def apply(c: BinaryConfusionMatrix): Double
}

/** Precision. */
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double =
c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
}

/** False positive rate. */
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double =
c.numFalsePositives.toDouble / c.numNegatives
}

/** Recall. */
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double =
c.numTruePositives.toDouble / c.numPositives
}

/**
* F-Measure.
* @param beta the beta constant in F-Measure
* @see http://en.wikipedia.org/wiki/F1_score
*/
private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificationMetricComputer {
private val beta2 = beta * beta
override def apply(c: BinaryConfusionMatrix): Double = {
val precision = Precision(c)
val recall = Recall(c)
(1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation.binary

import org.apache.spark.rdd.{UnionRDD, RDD}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.AreaUnderCurve
import org.apache.spark.Logging

/**
* Implementation of [[org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrix]].
*
* @param count label counter for labels with scores greater than or equal to the current score
* @param totalCount label counter for all labels
*/
private case class BinaryConfusionMatrixImpl(
count: LabelCounter,
totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {

/** number of true positives */
override def numTruePositives: Long = count.numPositives

/** number of false positives */
override def numFalsePositives: Long = count.numNegatives

/** number of false negatives */
override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives

/** number of true negatives */
override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives

/** number of positives */
override def numPositives: Long = totalCount.numPositives

/** number of negatives */
override def numNegatives: Long = totalCount.numNegatives
}

/**
* Evaluator for binary classification.
*
* @param scoreAndLabels an RDD of (score, label) pairs.
*/
class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])
extends Serializable with Logging {

private lazy val (
cumulativeCounts: RDD[(Double, LabelCounter)],
confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
// Create a bin for each distinct score value, count positives and negatives within each bin,
// and then sort by score values in descending order.
val counts = scoreAndLabels.combineByKey(
createCombiner = (label: Double) => new LabelCounter(0L, 0L) += label,
mergeValue = (c: LabelCounter, label: Double) => c += label,
mergeCombiners = (c1: LabelCounter, c2: LabelCounter) => c1 += c2
).sortByKey(ascending = false)
val agg = counts.values.mapPartitions({ iter =>
val agg = new LabelCounter()
iter.foreach(agg += _)
Iterator(agg)
}, preservesPartitioning = true).collect()
val partitionwiseCumulativeCounts =
agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg.clone() += c)
val totalCount = partitionwiseCumulativeCounts.last
logInfo(s"Total counts: $totalCount")
val cumulativeCounts = counts.mapPartitionsWithIndex(
(index: Int, iter: Iterator[(Double, LabelCounter)]) => {
val cumCount = partitionwiseCumulativeCounts(index)
iter.map { case (score, c) =>
cumCount += c
(score, cumCount.clone())
}
}, preservesPartitioning = true)
cumulativeCounts.persist()
val confusions = cumulativeCounts.map { case (score, cumCount) =>
(score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix])
}
(cumulativeCounts, confusions)
}

/** Unpersist intermediate RDDs used in the computation. */
def unpersist() {
cumulativeCounts.unpersist()
}

/** Returns thresholds in descending order. */
def thresholds(): RDD[Double] = cumulativeCounts.map(_._1)

/**
* Returns the receiver operating characteristic (ROC) curve,
* which is an RDD of (false positive rate, true positive rate)
* with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
* @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
*/
def roc(): RDD[(Double, Double)] = {
val rocCurve = createCurve(FalsePositiveRate, Recall)
val sc = confusions.context
val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
}

/**
* Computes the area under the receiver operating characteristic (ROC) curve.
*/
def areaUnderROC(): Double = AreaUnderCurve.of(roc())

/**
* Returns the precision-recall curve, which is an RDD of (recall, precision),
* NOT (precision, recall), with (0.0, 1.0) prepended to it.
* @see http://en.wikipedia.org/wiki/Precision_and_recall
*/
def pr(): RDD[(Double, Double)] = {
val prCurve = createCurve(Recall, Precision)
val sc = confusions.context
val first = sc.makeRDD(Seq((0.0, 1.0)), 1)
first.union(prCurve)
}

/**
* Computes the area under the precision-recall curve.
*/
def areaUnderPR(): Double = AreaUnderCurve.of(pr())

/**
* Returns the (threshold, F-Measure) curve.
* @param beta the beta factor in F-Measure computation.
* @return an RDD of (threshold, F-Measure) pairs.
* @see http://en.wikipedia.org/wiki/F1_score
*/
def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta))

/** Returns the (threshold, F-Measure) curve with beta = 1.0. */
def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0)

/** Returns the (threshold, precision) curve. */
def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision)

/** Returns the (threshold, recall) curve. */
def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall)

/** Creates a curve of (threshold, metric). */
private def createCurve(y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {
confusions.map { case (s, c) =>
(s, y(c))
}
}

/** Creates a curve of (metricX, metricY). */
private def createCurve(
x: BinaryClassificationMetricComputer,
y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {
confusions.map { case (_, c) =>
(x(c), y(c))
}
}
}

/**
* A counter for positives and negatives.
*
* @param numPositives number of positive labels
* @param numNegatives number of negative labels
*/
private class LabelCounter(
var numPositives: Long = 0L,
var numNegatives: Long = 0L) extends Serializable {

/** Processes a label. */
def +=(label: Double): LabelCounter = {
// Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
// -1.0 for negative as well.
if (label > 0.5) numPositives += 1L else numNegatives += 1L
this
}

/** Merges another counter. */
def +=(other: LabelCounter): LabelCounter = {
numPositives += other.numPositives
numNegatives += other.numNegatives
this
}

override def clone: LabelCounter = {
new LabelCounter(numPositives, numNegatives)
}

override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.evaluation.binary

/**
* Trait for a binary confusion matrix.
*/
private[evaluation] trait BinaryConfusionMatrix {
/** number of true positives */
def numTruePositives: Long

/** number of false positives */
def numFalsePositives: Long

/** number of false negatives */
def numFalseNegatives: Long

/** number of true negatives */
def numTrueNegatives: Long

/** number of positives */
def numPositives: Long = numTruePositives + numFalseNegatives

/** number of negatives */
def numNegatives: Long = numFalsePositives + numTrueNegatives
}
Loading

0 comments on commit e6128b5

Please sign in to comment.