Added weight column for binary classification evaluator
imatiach-msft committed Apr 16, 2018
1 parent 5003736 commit adb8f7a
Showing 6 changed files with 123 additions and 30 deletions.
BinaryClassificationEvaluator.scala
@@ -36,11 +36,17 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.2.0")
@Experimental
class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable {
+extends Evaluator with HasRawPredictionCol with HasLabelCol
+  with HasWeightCol with DefaultParamsWritable {

@Since("1.2.0")
def this() = this(Identifiable.randomUID("binEval"))

+/**
+ * Default number of bins to use for binary classification evaluation.
+ */
+val defaultNumberOfBins = 1000
+
/**
* param for metric name in evaluation (supports `"areaUnderROC"` (default), `"areaUnderPR"`)
* @group param
@@ -68,6 +74,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
@Since("1.2.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

+/** @group setParam */
+@Since("2.2.0")
+def setWeightCol(value: String): this.type = set(weightCol, value)
+
setDefault(metricName -> "areaUnderROC")

@Since("2.0.0")
@@ -77,12 +87,16 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
SchemaUtils.checkNumericType(schema, $(labelCol))

// TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
-val scoreAndLabels =
-  dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
-    case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
-    case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
+val scoreAndLabelsWithWeights =
+  dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType),
+    if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
+    .rdd.map {
+      case Row(rawPrediction: Vector, label: Double, weight: Double) =>
+        (rawPrediction(1), (label, weight))
+      case Row(rawPrediction: Double, label: Double, weight: Double) =>
+        (rawPrediction, (label, weight))
}
-val metrics = new BinaryClassificationMetrics(scoreAndLabels)
+val metrics = new BinaryClassificationMetrics(defaultNumberOfBins, scoreAndLabelsWithWeights)
val metric = $(metricName) match {
case "areaUnderROC" => metrics.areaUnderROC()
case "areaUnderPR" => metrics.areaUnderPR()
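Taken together, the evaluator changes route a weighted `(score, (label, weight))` RDD into `BinaryClassificationMetrics`, substituting `lit(1.0)` for the weight whenever no weight column is set, and they pin the internal curve down-sampling to `defaultNumberOfBins` (1000). A minimal usage sketch of the new setter, assuming a `SparkSession` named `spark`; the column names and data are illustrative, not from the commit:

```scala
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.linalg.Vectors

// Illustrative rows: (label, rawPrediction, weight).
val df = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.2, 0.8), 2.0),  // positive example, weight 2.0
  (0.0, Vectors.dense(0.7, 0.3), 0.5),  // negative example, weight 0.5
  (1.0, Vectors.dense(0.4, 0.6), 1.0)
)).toDF("label", "rawPrediction", "weight")

val auc = new BinaryClassificationEvaluator()
  .setMetricName("areaUnderROC")
  .setWeightCol("weight")  // new in this commit; unit weights are used when unset
  .evaluate(df)
```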
BinaryClassificationMetrics.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for binary classification.
*
-* @param scoreAndLabels an RDD of (score, label) pairs.
+* @param scoreAndLabelsWithWeights an RDD of (score, (label, weight)) pairs.
* @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally
* will be down-sampled to this many "bins". If 0, no down-sampling will occur.
* This is useful because the curve contains a point for each distinct score
@@ -41,12 +41,26 @@ import org.apache.spark.sql.DataFrame
* partition boundaries.
*/
@Since("1.0.0")
class BinaryClassificationMetrics @Since("1.3.0") (
@Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)],
@Since("1.3.0") val numBins: Int) extends Logging {
class BinaryClassificationMetrics @Since("2.2.0") (
val numBins: Int,
@Since("2.2.0") val scoreAndLabelsWithWeights: RDD[(Double, (Double, Double))])
extends Logging {

require(numBins >= 0, "numBins must be nonnegative")

+/**
+ * Retrieves the score and labels (for binary compatibility).
+ * @return The score and labels.
+ */
+@Since("1.0.0")
+def scoreAndLabels: RDD[(Double, Double)] = {
+  scoreAndLabelsWithWeights.map(values => (values._1, values._2._1))
+}
+
+@Since("1.0.0")
+def this(@Since("1.3.0") scoreAndLabels: RDD[(Double, Double)], @Since("1.3.0") numBins: Int) =
+  this(numBins, scoreAndLabels.map(scoreAndLabel => (scoreAndLabel._1, (scoreAndLabel._2, 1.0))))
+
/**
* Defaults `numBins` to 0.
*/
@@ -146,11 +160,13 @@ class BinaryClassificationMetrics @Since("1.3.0") (
private lazy val (
cumulativeCounts: RDD[(Double, BinaryLabelCounter)],
confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
-// Create a bin for each distinct score value, count positives and negatives within each bin,
-// and then sort by score values in descending order.
-val counts = scoreAndLabels.combineByKey(
-  createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,
-  mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
+// Create a bin for each distinct score value, count weighted positives and
+// negatives within each bin, and then sort by score values in descending order.
+val counts = scoreAndLabelsWithWeights.combineByKey(
+  createCombiner = (labelAndWeight: (Double, Double)) =>
+    new BinaryLabelCounter(0L, 0L) += (labelAndWeight._1, labelAndWeight._2),
+  mergeValue = (c: BinaryLabelCounter, labelAndWeight: (Double, Double)) =>
+    c += (labelAndWeight._1, labelAndWeight._2),
mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
).sortByKey(ascending = false)

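On the MLlib side, the primary constructor now takes the bin count first and the weighted RDD second, while the old `(scoreAndLabels, numBins)` shape survives as the secondary constructor that assigns every example a weight of 1.0. A short sketch of both entry points, assuming a `SparkContext` named `sc` and made-up scores:

```scala
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

// New primary constructor: (numBins, RDD[(score, (label, weight))]).
val weighted = sc.parallelize(Seq(
  (0.9, (1.0, 2.0)),  // score 0.9, positive label, weight 2.0
  (0.6, (1.0, 0.7)),
  (0.2, (0.0, 0.5))))
val weightedMetrics = new BinaryClassificationMetrics(0, weighted)
val weightedAuc = weightedMetrics.areaUnderROC()

// Retained secondary constructor: unweighted pairs, unit weights assumed.
val unweighted = sc.parallelize(Seq((0.9, 1.0), (0.2, 0.0)))
val plainMetrics = new BinaryClassificationMetrics(unweighted, 0)
```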
BinaryConfusionMatrix.scala
@@ -22,22 +22,22 @@ package org.apache.spark.mllib.evaluation.binary
*/
private[evaluation] trait BinaryConfusionMatrix {
/** number of true positives */
-def numTruePositives: Long
+def numTruePositives: Double

/** number of false positives */
-def numFalsePositives: Long
+def numFalsePositives: Double

/** number of false negatives */
-def numFalseNegatives: Long
+def numFalseNegatives: Double

/** number of true negatives */
-def numTrueNegatives: Long
+def numTrueNegatives: Double

/** number of positives */
-def numPositives: Long = numTruePositives + numFalseNegatives
+def numPositives: Double = numTruePositives + numFalseNegatives

/** number of negatives */
-def numNegatives: Long = numFalsePositives + numTrueNegatives
+def numNegatives: Double = numFalsePositives + numTrueNegatives
}

/**
@@ -51,20 +51,20 @@ private[evaluation] case class BinaryConfusionMatrixImpl(
totalCount: BinaryLabelCounter) extends BinaryConfusionMatrix {

/** number of true positives */
-override def numTruePositives: Long = count.numPositives
+override def numTruePositives: Double = count.numPositives

/** number of false positives */
-override def numFalsePositives: Long = count.numNegatives
+override def numFalsePositives: Double = count.numNegatives

/** number of false negatives */
-override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives
+override def numFalseNegatives: Double = totalCount.numPositives - count.numPositives

/** number of true negatives */
-override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives
+override def numTrueNegatives: Double = totalCount.numNegatives - count.numNegatives

/** number of positives */
-override def numPositives: Long = totalCount.numPositives
+override def numPositives: Double = totalCount.numPositives

/** number of negatives */
-override def numNegatives: Long = totalCount.numNegatives
+override def numNegatives: Double = totalCount.numNegatives
}
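Widening the confusion-matrix counts from `Long` to `Double` is what lets fractional weights flow through: each "count" becomes a weighted sum, while the derived rates keep their standard definitions. A sketch of those definitions over weighted counts (standard formulas, not code from this diff):

```scala
// Standard binary-classification rates over weighted counts.
def precision(tp: Double, fp: Double): Double = tp / (tp + fp)
def recall(tp: Double, fn: Double): Double = tp / (tp + fn)  // true positive rate
def falsePositiveRate(fp: Double, tn: Double): Double = fp / (fp + tn)
```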
BinaryLabelCounter.scala
@@ -24,14 +24,22 @@ package org.apache.spark.mllib.evaluation.binary
* @param numNegatives number of negative labels
*/
private[evaluation] class BinaryLabelCounter(
-var numPositives: Long = 0L,
-var numNegatives: Long = 0L) extends Serializable {
+var numPositives: Double = 0.0,
+var numNegatives: Double = 0.0) extends Serializable {

/** Processes a label. */
def +=(label: Double): BinaryLabelCounter = {
// Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
// -1.0 for negative as well.
-if (label > 0.5) numPositives += 1L else numNegatives += 1L
+if (label > 0.5) numPositives += 1.0 else numNegatives += 1.0
this
}

+/** Processes a label with a weight. */
+def +=(label: Double, weight: Double): BinaryLabelCounter = {
+  // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
+  // -1.0 for negative as well.
+  if (label > 0.5) numPositives += weight else numNegatives += weight
+  this
+}

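`BinaryLabelCounter` is `private[evaluation]`, so the snippet below is purely illustrative of the two accumulation overloads above: the unary `+=` still bumps a count by 1.0, while the new binary `+=` adds the example's weight instead.

```scala
// Illustrative only: this class is not visible outside its package.
val c = new BinaryLabelCounter()
c += 1.0         // unweighted positive: numPositives += 1.0
c += (0.0, 2.5)  // weighted negative: numNegatives += 2.5
// Now c.numPositives == 1.0 and c.numNegatives == 2.5.
```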
BinaryClassificationEvaluatorSuite.scala
@@ -71,6 +71,33 @@ class BinaryClassificationEvaluatorSuite
assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.")
}

test("should accept weight column") {
val weightCol = "weight"
// get metric with weight column
val evaluator = new BinaryClassificationEvaluator()
.setMetricName("areaUnderROC").setWeightCol(weightCol)
val vectorDF = Seq(
(0d, Vectors.dense(2.5, 12), 1.0),
(1d, Vectors.dense(1, 3), 1.0),
(0d, Vectors.dense(10, 2), 1.0)
).toDF("label", "rawPrediction", weightCol)
val result = evaluator.evaluate(vectorDF)
// without weight column
val evaluator2 = new BinaryClassificationEvaluator()
.setMetricName("areaUnderROC")
val result2 = evaluator2.evaluate(vectorDF)
assert(result === result2)
// use different weights, validate metrics change
val vectorDF2 = Seq(
(0d, Vectors.dense(2.5, 12), 2.5),
(1d, Vectors.dense(1, 3), 0.1),
(0d, Vectors.dense(10, 2), 2.0)
).toDF("label", "rawPrediction", weightCol)
val result3 = evaluator.evaluate(vectorDF2)
// Since wrong result weighted more heavily, expect the score to be lower
assert(result3 < result)
}

test("should support all NumericType labels and not support other types") {
val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
MLTestingUtils.checkNumericTypes(evaluator, spark)
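A note on the new weight test above: its first assertion holds by construction. When every explicit weight is 1.0, the weighted column reproduces exactly the `lit(1.0)` fallback the evaluator substitutes when no weight column is set, so the two AUC values must match.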
BinaryClassificationMetricsSuite.scala
@@ -82,6 +82,34 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
}

test("binary evaluation metrics with weights") {
val w1 = 1.5
val w2 = 0.7
val w3 = 0.4
val scoreAndLabelsWithWeights = sc.parallelize(
Seq((0.1, (0.0, w1)), (0.1, (1.0, w2)), (0.4, (0.0, w1)), (0.6, (0.0, w3)),
(0.6, (1.0, w2)), (0.6, (1.0, w2)), (0.8, (1.0, w1))), 2)
val metrics = new BinaryClassificationMetrics(0, scoreAndLabelsWithWeights)
val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
val numTruePositives =
Seq(1.0 * w1, 1.0 * w1 + 2.0 * w2, 1.0 * w1 + 2.0 * w2, 3.0 * w2 + 1.0 * w1)
val numFalsePositives = Seq(0.0, 1.0 * w3, 1.0 * w1 + 1.0 * w3, 1.0 * w3 + 2.0 * w1)
val numPositives = 3 * w2 + 1 * w1
val numNegatives = 2 * w1 + w3
val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
t.toDouble / (t + f)
}
val recalls = numTruePositives.map(t => t / numPositives)
val fpr = numFalsePositives.map(f => f / numNegatives)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
val pr = recalls.zip(precisions)
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}

validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
}

test("binary evaluation metrics for RDD where all examples have positive label") {
val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
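As a cross-check on the weighted metrics test above: the class totals come out to numPositives = 3·w2 + w1 = 3·0.7 + 1.5 = 3.6 and numNegatives = 2·w1 + w3 = 2·1.5 + 0.4 = 3.4. At the lowest threshold (0.1) every example is predicted positive, so recall reaches 3.6 / 3.6 = 1.0 and the false positive rate 3.4 / 3.4 = 1.0, which is why the ROC curve ends at (1.0, 1.0).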
