[SPARK-24103][ML][MLLIB] ML Evaluators should use weight column - added weight column for binary classification evaluator #17084

Closed
wants to merge 12 commits
@@ -36,7 +36,8 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.2.0")
@Experimental
class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable {
extends Evaluator with HasRawPredictionCol with HasLabelCol
with HasWeightCol with DefaultParamsWritable {

@Since("1.2.0")
def this() = this(Identifiable.randomUID("binEval"))
@@ -68,21 +69,34 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.2.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

/** @group setParam */
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

setDefault(metricName -> "areaUnderROC")

@Since("2.0.0")
override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
SchemaUtils.checkNumericType(schema, $(labelCol))
if (isDefined(weightCol)) {
SchemaUtils.checkNumericType(schema, $(weightCol))
}

// TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
val scoreAndLabels =
dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
val scoreAndLabelsWithWeights =
dataset.select(
col($(rawPredictionCol)),
col($(labelCol)).cast(DoubleType),
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
else col($(weightCol)).cast(DoubleType)).rdd.map {
case Row(rawPrediction: Vector, label: Double, weight: Double) =>
(rawPrediction(1), label, weight)
case Row(rawPrediction: Double, label: Double, weight: Double) =>
(rawPrediction, label, weight)
}
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights)
val metric = $(metricName) match {
case "areaUnderROC" => metrics.areaUnderROC()
case "areaUnderPR" => metrics.areaUnderPR()
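The evaluator above feeds weighted (score, label, weight) triples into BinaryClassificationMetrics and reads back areaUnderROC. As a rough illustration of what the weighted metric computes, here is a pure-Python sketch (not Spark code; the function name and the trapezoidal integration are my own simplification, and unlike Spark it does not bin tied scores, so it only matches Spark exactly when all scores are distinct):

```python
def weighted_roc_auc(score_label_weight):
    """Weighted area under the ROC curve via trapezoidal integration.

    Assumes both classes are present. Tied scores are processed one at a
    time rather than grouped into per-score bins as Spark does.
    """
    data = sorted(score_label_weight, key=lambda t: -t[0])  # high scores first
    total_pos = sum(w for _, y, w in data if y > 0.5)
    total_neg = sum(w for _, y, w in data if y <= 0.5)
    tp = fp = 0.0
    points = [(0.0, 0.0)]  # (FPR, TPR) curve starting at the origin
    for _, label, weight in data:
        if label > 0.5:
            tp += weight
        else:
            fp += weight
        points.append((fp / total_neg, tp / total_pos))
    auc = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        auc += (x1 - x0) * (y0 + y1) / 2.0
    return auc
```

With all weights equal to 1.0 this reduces to the unweighted AUC, which is why the evaluator substitutes lit(1.0) when no weight column is set.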
@@ -26,7 +26,7 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for binary classification.
*
* @param scoreAndLabels an RDD of (score, label) pairs.
* @param scoreAndLabelsWithOptWeight an RDD of (score, label) or (score, label, weight) tuples.
* @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally
* will be down-sampled to this many "bins". If 0, no down-sampling will occur.
* This is useful because the curve contains a point for each distinct score
@@ -41,9 +41,19 @@
* partition boundaries.
*/
@Since("1.0.0")
class BinaryClassificationMetrics @Since("1.3.0") (
@Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)],
@Since("1.3.0") val numBins: Int) extends Logging {
class BinaryClassificationMetrics @Since("3.0.0") (
@Since("1.3.0") val scoreAndLabelsWithOptWeight: RDD[_ <: Product],
Member: Let's see, so last time we decided this was OK because the type change is only visible at compile time, and this should be source compatible. I don't think the argument name change will be an issue here, as it isn't a named optional arg. Should be OK, but we'll have to check with MiMa.

Member: BTW, do we check elsewhere for positive weights? Do they need to be > 0?

Contributor (author): No, we currently do not check for weight > 0. Should that check live in BinaryClassificationMetrics? I should probably update the other metrics classes as well, then?

Member: There's no useful case for negative weights, right? I can see allowing 0 for convenience, as it at least has a clear meaning.

Maybe this could be a separate change, but yes, anywhere the user supplies weights it's useful to check this somewhere along the way.

Contributor (author): Yes, I don't believe there is a useful case for negative weights, but 0 may be useful. I added a check in BinaryClassificationMetrics. Should I add it to the other evaluators in this PR as well?

Contributor (author): Specifically, the check was:

require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")

@Since("1.3.0") val numBins: Int = 1000)
extends Logging {
val scoreLabelsWeight: RDD[(Double, (Double, Double))] = scoreAndLabelsWithOptWeight.map {
case (prediction: Double, label: Double, weight: Double) =>
require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
(prediction, (label, weight))
case (prediction: Double, label: Double) =>
(prediction, (label, 1.0))
case other =>
throw new IllegalArgumentException(s"Expected tuples, got $other")
}

require(numBins >= 0, "numBins must be nonnegative")
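The new constructor accepts either (score, label) pairs or (score, label, weight) triples and normalizes them to (score, (label, weight)), defaulting missing weights to 1.0 and enforcing the non-negative-weight check discussed in the review thread. A hypothetical pure-Python equivalent of that normalization (function name is mine):

```python
def normalize_score_and_labels(rows):
    """Turn (score, label) or (score, label, weight) tuples into
    (score, (label, weight)) pairs, defaulting the weight to 1.0 and
    rejecting negative weights, mirroring scoreLabelsWeight above."""
    out = []
    for row in rows:
        if len(row) == 3:
            score, label, weight = row
            if weight < 0.0:
                raise ValueError(
                    "instance weight, %s has to be >= 0.0" % weight)
        elif len(row) == 2:
            score, label = row
            weight = 1.0  # unweighted input: every instance counts once
        else:
            raise ValueError("Expected tuples, got %r" % (row,))
        out.append((score, (label, weight)))
    return out
```

Keeping the two-element form working is what preserves source compatibility for existing callers that pass an RDD[(Double, Double)].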

@@ -53,6 +63,15 @@ class BinaryClassificationMetrics @Since("1.3.0") (
@Since("1.0.0")
def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0)

/**
* Retrieves the score and labels (for binary compatibility).
* @return The score and labels.
*/
@Since("1.3.0")
def scoreAndLabels: RDD[(Double, Double)] = {
scoreLabelsWeight.map { case (prediction, (label, _)) => (prediction, label) }
}

/**
* An auxiliary constructor taking a DataFrame.
* @param scoreAndLabels a DataFrame with two double columns: score and label
@@ -146,11 +165,13 @@
private lazy val (
cumulativeCounts: RDD[(Double, BinaryLabelCounter)],
confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
// Create a bin for each distinct score value, count positives and negatives within each bin,
// and then sort by score values in descending order.
val counts = scoreAndLabels.combineByKey(
createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,
mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
// Create a bin for each distinct score value, count weighted positives and
// negatives within each bin, and then sort by score values in descending order.
val counts = scoreLabelsWeight.combineByKey(
createCombiner = (labelAndWeight: (Double, Double)) =>
new BinaryLabelCounter(0.0, 0.0) += (labelAndWeight._1, labelAndWeight._2),
mergeValue = (c: BinaryLabelCounter, labelAndWeight: (Double, Double)) =>
c += (labelAndWeight._1, labelAndWeight._2),
mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
).sortByKey(ascending = false)

@@ -27,33 +27,33 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializable
/** Precision. Defined as 1.0 when there are no positive examples. */
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double = {
val totalPositives = c.numTruePositives + c.numFalsePositives
if (totalPositives == 0) {
val totalPositives = c.weightedTruePositives + c.weightedFalsePositives
if (totalPositives == 0.0) {
1.0
} else {
c.numTruePositives.toDouble / totalPositives
c.weightedTruePositives / totalPositives
}
}
}

/** False positive rate. Defined as 0.0 when there are no negative examples. */
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double = {
if (c.numNegatives == 0) {
if (c.weightedNegatives == 0.0) {
0.0
} else {
c.numFalsePositives.toDouble / c.numNegatives
c.weightedFalsePositives / c.weightedNegatives
}
}
}

/** Recall. Defined as 0.0 when there are no positive examples. */
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double = {
if (c.numPositives == 0) {
if (c.weightedPositives == 0.0) {
0.0
} else {
c.numTruePositives.toDouble / c.numPositives
c.weightedTruePositives / c.weightedPositives
}
}
}
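Each computer divides one weighted count by another, with fixed conventions for empty denominators. A plain-Python sketch of the three computers above (function and parameter names are mine), preserving the same edge cases:

```python
def precision(weighted_tp, weighted_fp):
    # Defined as 1.0 when there are no predicted positives.
    total = weighted_tp + weighted_fp
    return 1.0 if total == 0.0 else weighted_tp / total


def recall(weighted_tp, weighted_fn):
    # Defined as 0.0 when there are no actual positives.
    positives = weighted_tp + weighted_fn
    return 0.0 if positives == 0.0 else weighted_tp / positives


def false_positive_rate(weighted_fp, weighted_tn):
    # Defined as 0.0 when there are no actual negatives.
    negatives = weighted_fp + weighted_tn
    return 0.0 if negatives == 0.0 else weighted_fp / negatives
```

Note that moving from Long to Double counts lets the `.toDouble` conversions in the old code disappear, and the zero checks become comparisons against 0.0.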
@@ -21,23 +21,23 @@ package org.apache.spark.mllib.evaluation.binary
* Trait for a binary confusion matrix.
*/
private[evaluation] trait BinaryConfusionMatrix {
/** number of true positives */
def numTruePositives: Long
/** weighted number of true positives */
def weightedTruePositives: Double

/** number of false positives */
def numFalsePositives: Long
/** weighted number of false positives */
def weightedFalsePositives: Double

/** number of false negatives */
def numFalseNegatives: Long
/** weighted number of false negatives */
def weightedFalseNegatives: Double

/** number of true negatives */
def numTrueNegatives: Long
/** weighted number of true negatives */
def weightedTrueNegatives: Double

/** number of positives */
def numPositives: Long = numTruePositives + numFalseNegatives
/** weighted number of positives */
def weightedPositives: Double = weightedTruePositives + weightedFalseNegatives

/** number of negatives */
def numNegatives: Long = numFalsePositives + numTrueNegatives
/** weighted number of negatives */
def weightedNegatives: Double = weightedFalsePositives + weightedTrueNegatives
}

@@ -51,20 +51,22 @@ private[evaluation] case class BinaryConfusionMatrixImpl(
totalCount: BinaryLabelCounter) extends BinaryConfusionMatrix {

/** weighted number of true positives */
override def numTruePositives: Long = count.numPositives
override def weightedTruePositives: Double = count.weightedNumPositives

/** weighted number of false positives */
override def numFalsePositives: Long = count.numNegatives
override def weightedFalsePositives: Double = count.weightedNumNegatives

/** weighted number of false negatives */
override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives
override def weightedFalseNegatives: Double =
totalCount.weightedNumPositives - count.weightedNumPositives

/** weighted number of true negatives */
override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives
override def weightedTrueNegatives: Double =
totalCount.weightedNumNegatives - count.weightedNumNegatives

/** weighted number of positives */
override def numPositives: Long = totalCount.numPositives
override def weightedPositives: Double = totalCount.weightedNumPositives

/** weighted number of negatives */
override def numNegatives: Long = totalCount.numNegatives
override def weightedNegatives: Double = totalCount.weightedNumNegatives
}
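BinaryConfusionMatrixImpl derives false negatives and true negatives by subtracting the cumulative counter (everything at or above the current threshold) from the dataset totals. A small Python sketch of that arithmetic (dict-based, names are mine):

```python
def confusion_from_counts(count, total_count):
    """Derive a weighted confusion matrix from two cumulative counters,
    mirroring BinaryConfusionMatrixImpl: `count` holds the weighted
    (positives, negatives) at or above a threshold, and `total_count`
    the weighted totals over the whole dataset."""
    tp, fp = count  # everything predicted positive at this threshold
    total_pos, total_neg = total_count
    return {
        "weightedTruePositives": tp,
        "weightedFalsePositives": fp,
        "weightedFalseNegatives": total_pos - tp,
        "weightedTrueNegatives": total_neg - fp,
    }
```

This is why only two running sums per bin need to be materialized: the other two confusion-matrix cells fall out of the totals.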
@@ -20,31 +20,39 @@ package org.apache.spark.mllib.evaluation.binary
/**
* A counter for positives and negatives.
*
* @param numPositives number of positive labels
* @param numNegatives number of negative labels
* @param weightedNumPositives weighted number of positive labels
* @param weightedNumNegatives weighted number of negative labels
*/
private[evaluation] class BinaryLabelCounter(
var numPositives: Long = 0L,
var numNegatives: Long = 0L) extends Serializable {
var weightedNumPositives: Double = 0.0,
var weightedNumNegatives: Double = 0.0) extends Serializable {

/** Processes a label. */
def +=(label: Double): BinaryLabelCounter = {
// Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
// -1.0 for negative as well.
if (label > 0.5) numPositives += 1L else numNegatives += 1L
if (label > 0.5) weightedNumPositives += 1.0 else weightedNumNegatives += 1.0
this
}

/** Processes a label with a weight. */
def +=(label: Double, weight: Double): BinaryLabelCounter = {
// Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
// -1.0 for negative as well.
if (label > 0.5) weightedNumPositives += weight else weightedNumNegatives += weight
this
}

/** Merges another counter. */
def +=(other: BinaryLabelCounter): BinaryLabelCounter = {
numPositives += other.numPositives
numNegatives += other.numNegatives
weightedNumPositives += other.weightedNumPositives
weightedNumNegatives += other.weightedNumNegatives
this
}

override def clone: BinaryLabelCounter = {
new BinaryLabelCounter(numPositives, numNegatives)
new BinaryLabelCounter(weightedNumPositives, weightedNumNegatives)
}

override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
override def toString: String = s"{numPos: $weightedNumPositives, numNeg: $weightedNumNegatives}"
}
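The counter change is the heart of the PR: tallies of Longs become running sums of Doubles, and the unweighted `+=` is just the weighted one with weight 1.0. A plain-Python analogue (class name is mine):

```python
class WeightedLabelCounter:
    """Plain-Python analogue of BinaryLabelCounter after this change:
    counts are running sums of instance weights rather than Long tallies."""

    def __init__(self, weighted_pos=0.0, weighted_neg=0.0):
        self.weighted_pos = weighted_pos
        self.weighted_neg = weighted_neg

    def add(self, label, weight=1.0):
        # label > 0.5 treats 1.0 as positive and both 0.0 and -1.0 as negative
        if label > 0.5:
            self.weighted_pos += weight
        else:
            self.weighted_neg += weight
        return self

    def merge(self, other):
        # Combine partial counts from another partition.
        self.weighted_pos += other.weighted_pos
        self.weighted_neg += other.weighted_neg
        return self
```

Because `merge` is associative and commutative, the counter still works as a combineByKey combiner exactly as the Long-based version did.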
@@ -71,6 +71,33 @@ class BinaryClassificationEvaluatorSuite
assert(thrown.getMessage.replace("\n", "") contains "but was actually of type string.")
}

test("should accept weight column") {
val weightCol = "weight"
// get metric with weight column
val evaluator = new BinaryClassificationEvaluator()
.setMetricName("areaUnderROC").setWeightCol(weightCol)
val vectorDF = Seq(
(0d, Vectors.dense(2.5, 12), 1.0),
(1d, Vectors.dense(1, 3), 1.0),
(0d, Vectors.dense(10, 2), 1.0)
).toDF("label", "rawPrediction", weightCol)
val result = evaluator.evaluate(vectorDF)
// without weight column
val evaluator2 = new BinaryClassificationEvaluator()
.setMetricName("areaUnderROC")
val result2 = evaluator2.evaluate(vectorDF)
assert(result === result2)
// use different weights, validate metrics change
val vectorDF2 = Seq(
(0d, Vectors.dense(2.5, 12), 2.5),
(1d, Vectors.dense(1, 3), 0.1),
(0d, Vectors.dense(10, 2), 2.0)
).toDF("label", "rawPrediction", weightCol)
val result3 = evaluator.evaluate(vectorDF2)
// Since wrong result weighted more heavily, expect the score to be lower
assert(result3 < result)
}

test("should support all NumericType labels and not support other types") {
val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
MLTestingUtils.checkNumericTypes(evaluator, spark)
@@ -82,6 +82,34 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
}

test("binary evaluation metrics with weights") {
val w1 = 1.5
val w2 = 0.7
val w3 = 0.4
val scoreAndLabelsWithWeights = sc.parallelize(
Seq((0.1, 0.0, w1), (0.1, 1.0, w2), (0.4, 0.0, w1), (0.6, 0.0, w3),
(0.6, 1.0, w2), (0.6, 1.0, w2), (0.8, 1.0, w1)), 2)
val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights, 0)
val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
val numTruePositives =
Seq(1.0 * w1, 1.0 * w1 + 2.0 * w2, 1.0 * w1 + 2.0 * w2, 3.0 * w2 + 1.0 * w1)
val numFalsePositives = Seq(0.0, 1.0 * w3, 1.0 * w1 + 1.0 * w3, 1.0 * w3 + 2.0 * w1)
val numPositives = 3 * w2 + 1 * w1
val numNegatives = 2 * w1 + w3
val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
t.toDouble / (t + f)
}
val recalls = numTruePositives.map(t => t / numPositives)
val fpr = numFalsePositives.map(f => f / numNegatives)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
val pr = recalls.zip(precisions)
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}

validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
}
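The expected values in this test can be verified by hand: at each threshold, the weighted TP and FP counts are simply sums of the weights of instances scored at or above it. A small Python helper (mine, not part of the suite) reproducing the counts at threshold 0.6:

```python
def weighted_counts_at(threshold, rows):
    """Weighted true/false positive mass when predicting positive for
    score >= threshold; `rows` are (score, label, weight) triples."""
    tp = sum(w for s, y, w in rows if s >= threshold and y > 0.5)
    fp = sum(w for s, y, w in rows if s >= threshold and y <= 0.5)
    return tp, fp


w1, w2, w3 = 1.5, 0.7, 0.4
rows = [(0.1, 0.0, w1), (0.1, 1.0, w2), (0.4, 0.0, w1), (0.6, 0.0, w3),
        (0.6, 1.0, w2), (0.6, 1.0, w2), (0.8, 1.0, w1)]
# At threshold 0.6 this should match the test's hand computation:
# TP = 1.0 * w1 + 2.0 * w2, FP = 1.0 * w3.
tp, fp = weighted_counts_at(0.6, rows)
```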

test("binary evaluation metrics for RDD where all examples have positive label") {
val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
18 changes: 13 additions & 5 deletions python/pyspark/ml/evaluation.py
@@ -106,7 +106,7 @@ def isLargerBetter(self):


@inherit_doc
class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol,
class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol, HasWeightCol,
JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -130,6 +130,14 @@ class BinaryClassificationEvaluator
>>> evaluator2 = BinaryClassificationEvaluator.load(bce_path)
>>> str(evaluator2.getRawPredictionCol())
'raw'
>>> scoreAndLabelsAndWeight = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1], x[2]),
... [(0.1, 0.0, 1.0), (0.1, 1.0, 0.9), (0.4, 0.0, 0.7), (0.6, 0.0, 0.9),
... (0.6, 1.0, 1.0), (0.6, 1.0, 0.3), (0.8, 1.0, 1.0)])
>>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
...
>>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw", weightCol="weight")
>>> evaluator.evaluate(dataset)
0.70...

.. versionadded:: 1.4.0
"""
Expand All @@ -140,10 +148,10 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction

@keyword_only
def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
metricName="areaUnderROC"):
metricName="areaUnderROC", weightCol=None):
"""
__init__(self, rawPredictionCol="rawPrediction", labelCol="label", \
metricName="areaUnderROC")
metricName="areaUnderROC", weightCol=None)
"""
super(BinaryClassificationEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
@@ -169,10 +177,10 @@ def getMetricName(self):
@keyword_only
@since("1.4.0")
def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
metricName="areaUnderROC"):
metricName="areaUnderROC", weightCol=None):
"""
setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \
metricName="areaUnderROC")
metricName="areaUnderROC", weightCol=None)
Sets params for binary classification evaluator.
"""
kwargs = self._input_kwargs