[SPARK-24101][ML][MLLIB] ML Evaluators should use weight column - added weight column for multiclass classification evaluator

## What changes were proposed in this pull request?

The evaluators BinaryClassificationEvaluator, RegressionEvaluator, and MulticlassClassificationEvaluator, and the corresponding metrics classes BinaryClassificationMetrics, RegressionMetrics, and MulticlassMetrics, should support sample weight data. This PR adds an optional weight column to MulticlassClassificationEvaluator and the underlying MulticlassMetrics.
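
A minimal usage sketch of the new evaluator behavior (the SparkSession `spark`, the DataFrame, and its column names are illustrative; `setWeightCol` is the setter added here, and when no weight column is set every row gets weight 1.0):

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// Toy predictions with a per-row weight column.
val predictions = spark.createDataFrame(Seq(
  (0.0, 0.0, 2.2), (1.0, 0.0, 1.5), (1.0, 1.0, 2.2), (2.0, 2.0, 1.5)
)).toDF("prediction", "label", "weight")

val evaluator = new MulticlassClassificationEvaluator()
  .setPredictionCol("prediction")
  .setLabelCol("label")
  .setWeightCol("weight")   // new in this PR
  .setMetricName("f1")      // default metric

val weightedF1 = evaluator.evaluate(predictions)
```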

I've closed PR apache#16557, as recommended, in favor of three separate pull requests, one for each evaluator (binary/regression/multiclass), to make them easier to review and update.

Note: I've updated the JIRA to https://issues.apache.org/jira/browse/SPARK-24101, which is a child of https://issues.apache.org/jira/browse/SPARK-18693.

## How was this patch tested?

I added tests to the metrics class.
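
For reference, a small sketch of the lower-level MulticlassMetrics usage that the new test exercises (the RDD contents are illustrative; `sc` is an existing SparkContext). The constructor now accepts (prediction, label, weight) triples as well as plain (prediction, label) pairs, and per-class counts become weighted sums, e.g. precision(l) = weighted TP(l) / (weighted TP(l) + weighted FP(l)):

```scala
import org.apache.spark.mllib.evaluation.MulticlassMetrics

// (prediction, label, weight) triples; (prediction, label) pairs are still
// accepted and treated as having weight 1.0.
val scored = sc.parallelize(Seq(
  (0.0, 0.0, 2.2), (1.0, 0.0, 1.5), (1.0, 1.0, 2.2), (2.0, 2.0, 1.5)))

val metrics = new MulticlassMetrics(scored)
metrics.precision(1.0)     // weighted precision for class 1.0
metrics.weightedFMeasure   // class-weight-averaged F1
metrics.accuracy           // weighted fraction of correct predictions
```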

Closes apache#17086 from imatiach-msft/ilmat/multiclass-evaluate.

Authored-by: Ilya Matiach <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
imatiach-msft authored and jackylee-ch committed Feb 18, 2019
1 parent 2c64e49 commit a40aab9
Showing 3 changed files with 170 additions and 62 deletions.
mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{Dataset, Row}
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.5.0")
@Experimental
class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
extends Evaluator with HasPredictionCol with HasLabelCol
with HasWeightCol with DefaultParamsWritable {

@Since("1.5.0")
def this() = this(Identifiable.randomUID("mcEval"))
@@ -67,6 +68,10 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
@Since("1.5.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

/** @group setParam */
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

setDefault(metricName -> "f1")

@Since("2.0.0")
@@ -75,11 +80,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
SchemaUtils.checkNumericType(schema, $(labelCol))

val predictionAndLabels =
dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
case Row(prediction: Double, label: Double) => (prediction, label)
val predictionAndLabelsWithWeights =
dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType),
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
.rdd.map {
case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight)
}
val metrics = new MulticlassMetrics(predictionAndLabels)
val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights)
val metric = $(metricName) match {
case "f1" => metrics.weightedFMeasure
case "weightedPrecision" => metrics.weightedPrecision
mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -27,10 +27,19 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for multiclass classification.
*
* @param predictionAndLabels an RDD of (prediction, label) pairs.
* @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or
* (prediction, label) pairs.
*/
@Since("1.1.0")
class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) {
val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map {
case (prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight)
case (prediction: Double, label: Double) =>
(prediction, label, 1.0)
case other =>
throw new IllegalArgumentException(s"Expected tuples, got $other")
}

/**
* An auxiliary constructor taking a DataFrame.
@@ -39,21 +48,29 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)])
private[mllib] def this(predictionAndLabels: DataFrame) =
this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))

private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
private lazy val labelCount: Long = labelCountByClass.values.sum
private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
.map { case (prediction, label) =>
(label, if (label == prediction) 1 else 0)
private lazy val labelCountByClass: Map[Double, Double] =
predLabelsWeight.map {
case (_: Double, label: Double, weight: Double) =>
(label, weight)
}.reduceByKey(_ + _)
.collectAsMap()
private lazy val labelCount: Double = labelCountByClass.values.sum
private lazy val tpByClass: Map[Double, Double] = predLabelsWeight
.map {
case (prediction: Double, label: Double, weight: Double) =>
(label, if (label == prediction) weight else 0.0)
}.reduceByKey(_ + _)
.collectAsMap()
private lazy val fpByClass: Map[Double, Int] = predictionAndLabels
.map { case (prediction, label) =>
(prediction, if (prediction != label) 1 else 0)
private lazy val fpByClass: Map[Double, Double] = predLabelsWeight
.map {
case (prediction: Double, label: Double, weight: Double) =>
(prediction, if (prediction != label) weight else 0.0)
}.reduceByKey(_ + _)
.collectAsMap()
private lazy val confusions = predictionAndLabels
.map { case (prediction, label) =>
((label, prediction), 1)
private lazy val confusions = predLabelsWeight
.map {
case (prediction: Double, label: Double, weight: Double) =>
((label, prediction), weight)
}.reduceByKey(_ + _)
.collectAsMap()

@@ -71,7 +88,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)])
while (i < n) {
var j = 0
while (j < n) {
values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble
values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0.0)
j += 1
}
i += 1
@@ -92,8 +109,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)])
*/
@Since("1.1.0")
def falsePositiveRate(label: Double): Double = {
val fp = fpByClass.getOrElse(label, 0)
fp.toDouble / (labelCount - labelCountByClass(label))
val fp = fpByClass.getOrElse(label, 0.0)
fp / (labelCount - labelCountByClass(label))
}

/**
@@ -103,7 +120,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)])
@Since("1.1.0")
def precision(label: Double): Double = {
val tp = tpByClass(label)
val fp = fpByClass.getOrElse(label, 0)
val fp = fpByClass.getOrElse(label, 0.0)
if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
}

@@ -112,7 +129,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)])
* @param label the label.
*/
@Since("1.1.0")
def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label)
def recall(label: Double): Double = tpByClass(label) / labelCountByClass(label)

/**
* Returns f-measure for a given label (category)
@@ -140,7 +157,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)])
* out of the total number of instances.)
*/
@Since("2.0.0")
lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount
lazy val accuracy: Double = tpByClass.values.sum / labelCount

/**
* Returns weighted true positive rate
mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -18,10 +18,14 @@
package org.apache.spark.mllib.evaluation

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Matrices
import org.apache.spark.ml.linalg.Matrices
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext

class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {

private val delta = 1e-7

test("Multiclass evaluation metrics") {
/*
* Confusion matrix for 3-class classification with total 9 instances:
@@ -35,7 +39,6 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
(1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
val metrics = new MulticlassMetrics(predictionAndLabels)
val delta = 0.0000001
val tpRate0 = 2.0 / (2 + 2)
val tpRate1 = 3.0 / (3 + 1)
val tpRate2 = 1.0 / (1 + 0)
@@ -55,41 +58,122 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)

assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray))
assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta)
assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta)
assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta)
assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta)
assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta)
assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta)
assert(math.abs(metrics.precision(0.0) - precision0) < delta)
assert(math.abs(metrics.precision(1.0) - precision1) < delta)
assert(math.abs(metrics.precision(2.0) - precision2) < delta)
assert(math.abs(metrics.recall(0.0) - recall0) < delta)
assert(math.abs(metrics.recall(1.0) - recall1) < delta)
assert(math.abs(metrics.recall(2.0) - recall2) < delta)
assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta)
assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta)
assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta)
assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta)
assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta)
assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta)
assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta)
assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta)
assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta)
assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta)
assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta)
assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta)
assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta)
assert(metrics.precision(0.0) ~== precision0 relTol delta)
assert(metrics.precision(1.0) ~== precision1 relTol delta)
assert(metrics.precision(2.0) ~== precision2 relTol delta)
assert(metrics.recall(0.0) ~== recall0 relTol delta)
assert(metrics.recall(1.0) ~== recall1 relTol delta)
assert(metrics.recall(2.0) ~== recall2 relTol delta)
assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta)
assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta)
assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta)
assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta)
assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta)
assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta)

assert(metrics.accuracy ~==
(2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) relTol delta)
assert(metrics.accuracy ~== metrics.weightedRecall relTol delta)
val weight0 = 4.0 / 9
val weight1 = 4.0 / 9
val weight2 = 1.0 / 9
assert(metrics.weightedTruePositiveRate ~==
(weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta)
assert(metrics.weightedFalsePositiveRate ~==
(weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta)
assert(metrics.weightedPrecision ~==
(weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta)
assert(metrics.weightedRecall ~==
(weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta)
assert(metrics.weightedFMeasure ~==
(weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta)
assert(metrics.weightedFMeasure(2.0) ~==
(weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta)
assert(metrics.labels === labels)
}

test("Multiclass evaluation metrics with weights") {
/*
* Confusion matrix for 3-class classification with total 9 instances with 2 weights:
* |2 * w1|1 * w2 |1 * w1| true class0 (4 instances)
* |1 * w2|2 * w1 + 1 * w2|0 | true class1 (4 instances)
* |0 |0 |1 * w2| true class2 (1 instance)
*/
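// Worked totals for these weights (w1 = 2.2, w2 = 1.5), for reference:
//   total weight tw  = 5 * w1 + 4 * w2 = 11.0 + 6.0 = 17.0
//   class-0 weight   = 2 * w1 + 1 * w2 + 1 * w1 = 8.1  (row 0 above)
//   class-1 weight   = 1 * w2 + 2 * w1 + 1 * w2 = 7.4  (row 1 above)
//   class-2 weight   = 1 * w2 = 1.5                     (row 2 above)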
val w1 = 2.2
val w2 = 1.5
val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2
val confusionMatrix = Matrices.dense(3, 3,
Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 0, 1 * w2))
val labels = Array(0.0, 1.0, 2.0)
val predictionAndLabelsWithWeights = sc.parallelize(
Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2),
(1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2),
(2.0, 0.0, w1)), 2)
val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights)
val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1)
val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0)
val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1))
val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2))
val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2))
val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2)
val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2)
val recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1)
val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
val recall2 = (1.0 * w2) / (1.0 * w2 + 0)
val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0)
val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)

assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta)
assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta)
assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta)
assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta)
assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta)
assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta)
assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta)
assert(metrics.precision(0.0) ~== precision0 relTol delta)
assert(metrics.precision(1.0) ~== precision1 relTol delta)
assert(metrics.precision(2.0) ~== precision2 relTol delta)
assert(metrics.recall(0.0) ~== recall0 relTol delta)
assert(metrics.recall(1.0) ~== recall1 relTol delta)
assert(metrics.recall(2.0) ~== recall2 relTol delta)
assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta)
assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta)
assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta)
assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta)
assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta)
assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta)

assert(math.abs(metrics.accuracy -
(2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta)
assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta)
assert(math.abs(metrics.weightedTruePositiveRate -
((4.0 / 9) * tpRate0 + (4.0 / 9) * tpRate1 + (1.0 / 9) * tpRate2)) < delta)
assert(math.abs(metrics.weightedFalsePositiveRate -
((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta)
assert(math.abs(metrics.weightedPrecision -
((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta)
assert(math.abs(metrics.weightedRecall -
((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta)
assert(math.abs(metrics.weightedFMeasure -
((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta)
assert(math.abs(metrics.weightedFMeasure(2.0) -
((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta)
assert(metrics.labels.sameElements(labels))
assert(metrics.accuracy ~==
(2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw relTol delta)
assert(metrics.accuracy ~== metrics.weightedRecall relTol delta)
val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw
val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw
val weight2 = 1 * w2 / tw
assert(metrics.weightedTruePositiveRate ~==
(weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta)
assert(metrics.weightedFalsePositiveRate ~==
(weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta)
assert(metrics.weightedPrecision ~==
(weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta)
assert(metrics.weightedRecall ~==
(weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta)
assert(metrics.weightedFMeasure ~==
(weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta)
assert(metrics.weightedFMeasure(2.0) ~==
(weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta)
assert(metrics.labels === labels)
}
}
