diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 794b1e7d9d881..f1602c1bc5333 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Dataset, Row} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasLabelCol + with HasWeightCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -67,6 +68,10 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** @group setParam */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(metricName -> "f1") @Since("2.0.0") @@ -75,11 +80,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkNumericType(schema, $(labelCol)) - val predictionAndLabels = - dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { - case Row(prediction: Double, label: Double) => (prediction, label) + val predictionAndLabelsWithWeights = + dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType), + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) + .rdd.map { + case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) } - val metrics = new MulticlassMetrics(predictionAndLabels) + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) val metric = $(metricName) match { case "f1" => metrics.weightedFMeasure case "weightedPrecision" => metrics.weightedPrecision diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 980e0c92531a2..ad83c24ede964 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -27,10 +27,19 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for multiclass classification. * - * @param predictionAndLabels an RDD of (prediction, label) pairs. + * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or + * (prediction, label) pairs. */ @Since("1.1.0") -class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) { +class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) { + val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map { + case (prediction: Double, label: Double, weight: Double) => + (prediction, label, weight) + case (prediction: Double, label: Double) => + (prediction, label, 1.0) + case other => + throw new IllegalArgumentException(s"Expected tuples, got $other") + } /** * An auxiliary constructor taking a DataFrame. @@ -39,21 +48,29 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl private[mllib] def this(predictionAndLabels: DataFrame) = this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) - private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() - private lazy val labelCount: Long = labelCountByClass.values.sum - private lazy val tpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (label, if (label == prediction) 1 else 0) + private lazy val labelCountByClass: Map[Double, Double] = + predLabelsWeight.map { + case (_: Double, label: Double, weight: Double) => + (label, weight) + }.reduceByKey(_ + _) + .collectAsMap() + private lazy val labelCount: Double = labelCountByClass.values.sum + private lazy val tpByClass: Map[Double, Double] = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + (label, if (label == prediction) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val fpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (prediction, if (prediction != label) 1 else 0) + private lazy val fpByClass: Map[Double, Double] = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + (prediction, if (prediction != label) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val confusions = predictionAndLabels - .map { case (prediction, label) => - ((label, prediction), 1) + private lazy val confusions = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + ((label, prediction), weight) }.reduceByKey(_ + _) .collectAsMap() @@ -71,7 +88,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl while (i < n) { var j = 0 while (j < n) { - values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble + values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0.0) j += 1 } i += 1 @@ -92,8 +109,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl */ @Since("1.1.0") def falsePositiveRate(label: Double): Double = { - val fp = fpByClass.getOrElse(label, 0) - fp.toDouble / (labelCount - labelCountByClass(label)) + val fp = fpByClass.getOrElse(label, 0.0) + fp / (labelCount - labelCountByClass(label)) } /** @@ -103,7 +120,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl @Since("1.1.0") def precision(label: Double): Double = { val tp = tpByClass(label) - val fp = fpByClass.getOrElse(label, 0) + val fp = fpByClass.getOrElse(label, 0.0) if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) } @@ -112,7 +129,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * @param label the label. */ @Since("1.1.0") - def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) + def recall(label: Double): Double = tpByClass(label) / labelCountByClass(label) /** * Returns f-measure for a given label (category) @@ -140,7 +157,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * out of the total number of instances.) */ @Since("2.0.0") - lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount + lazy val accuracy: Double = tpByClass.values.sum / labelCount /** * Returns weighted true positive rate diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 5394baab94bcf..8779de590a256 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -18,10 +18,14 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.Matrices +import org.apache.spark.ml.linalg.Matrices +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { + + private val delta = 1e-7 + test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: @@ -35,7 +39,6 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) val metrics = new MulticlassMetrics(predictionAndLabels) - val delta = 0.0000001 val tpRate0 = 2.0 / (2 + 2) val tpRate1 = 3.0 / (3 + 1) val tpRate2 = 1.0 / (1 + 0) @@ -55,41 +58,122 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) - assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) - assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta) - assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta) - assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta) - assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) - assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) - assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) - assert(math.abs(metrics.precision(0.0) - precision0) < delta) - assert(math.abs(metrics.precision(1.0) - precision1) < delta) - assert(math.abs(metrics.precision(2.0) - precision2) < delta) - assert(math.abs(metrics.recall(0.0) - recall0) < delta) - assert(math.abs(metrics.recall(1.0) - recall1) < delta) - assert(math.abs(metrics.recall(2.0) - recall2) < delta) - assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) - assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) + assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) + assert(metrics.precision(0.0) ~== precision0 relTol delta) + assert(metrics.precision(1.0) ~== precision1 relTol delta) + assert(metrics.precision(2.0) ~== precision2 relTol delta) + assert(metrics.recall(0.0) ~== recall0 relTol delta) + assert(metrics.recall(1.0) ~== recall1 relTol delta) + assert(metrics.recall(2.0) ~== recall2 relTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) + + assert(metrics.accuracy ~== + (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) relTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) + val weight0 = 4.0 / 9 + val weight1 = 4.0 / 9 + val weight2 = 1.0 / 9 + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) + assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) + assert(metrics.labels === labels) + } + + test("Multiclass evaluation metrics with weights") { + /* + * Confusion matrix for 3-class classification with total 9 instances with 2 weights: + * |2 * w1|1 * w2 |1 * w1| true class0 (4 instances) + * |1 * w2|2 * w1 + 1 * w2|0 | true class1 (4 instances) + * |0 |0 |1 * w2| true class2 (1 instance) + */ + val w1 = 2.2 + val w2 = 1.5 + val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2 + val confusionMatrix = Matrices.dense(3, 3, + Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 0, 1 * w2)) + val labels = Array(0.0, 1.0, 2.0) + val predictionAndLabelsWithWeights = sc.parallelize( + Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2), + (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2), + (2.0, 0.0, w1)), 2) + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) + val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0) + val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1)) + val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2)) + val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2)) + val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2) + val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2) + val recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val recall2 = (1.0 * w2) / (1.0 * w2 + 0) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0) + val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) + val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) + + assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) + assert(metrics.precision(0.0) ~== precision0 relTol delta) + assert(metrics.precision(1.0) ~== precision1 relTol delta) + assert(metrics.precision(2.0) ~== precision2 relTol delta) + assert(metrics.recall(0.0) ~== recall0 relTol delta) + assert(metrics.recall(1.0) ~== recall1 relTol delta) + assert(metrics.recall(2.0) ~== recall2 relTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) - assert(math.abs(metrics.accuracy - - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) - assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) - assert(math.abs(metrics.weightedTruePositiveRate - - ((4.0 / 9) * tpRate0 + (4.0 / 9) * tpRate1 + (1.0 / 9) * tpRate2)) < delta) - assert(math.abs(metrics.weightedFalsePositiveRate - - ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) - assert(math.abs(metrics.weightedPrecision - - ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta) - assert(math.abs(metrics.weightedRecall - - ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta) - assert(math.abs(metrics.weightedFMeasure - - ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta) - assert(math.abs(metrics.weightedFMeasure(2.0) - - ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta) - assert(metrics.labels.sameElements(labels)) + assert(metrics.accuracy ~== + (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw relTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) + val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw + val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw + val weight2 = 1 * w2 / tw + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) + assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) + assert(metrics.labels === labels) } }