From e89a030aa63945ddda22e2f879e9e0cafebe0801 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Mon, 26 Jun 2017 00:19:31 -0400 Subject: [PATCH] Updated based on comments - fixed since tag, renamed vars, added check for weight col --- .../BinaryClassificationEvaluator.scala | 6 +++- .../BinaryClassificationMetrics.scala | 6 ++-- .../BinaryClassificationMetricComputers.scala | 12 +++---- .../binary/BinaryConfusionMatrix.scala | 36 +++++++++---------- 4 files changed, 32 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index fc7cdd0b3042f..88a435018b21f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -85,11 +85,15 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va val schema = dataset.schema SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT)) SchemaUtils.checkNumericType(schema, $(labelCol)) + if (isDefined(weightCol)) { + SchemaUtils.checkNumericType(schema, $(weightCol)) + } // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. val scoreAndLabelsWithWeights = dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType), - if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) + else col($(weightCol)).cast(DoubleType)) .rdd.map { case Row(rawPrediction: Vector, label: Double, weight: Double) => (rawPrediction(1), (label, weight)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 0cd7d088739eb..c6b9cb6758f71 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -52,12 +52,12 @@ class BinaryClassificationMetrics @Since("2.2.0") ( * Retrieves the score and labels (for binary compatibility). * @return The score and labels. */ - @Since("1.0.0") + @Since("1.3.0") def scoreAndLabels: RDD[(Double, Double)] = { scoreAndLabelsWithWeights.map(values => (values._1, values._2._1)) } - @Since("1.0.0") + @Since("1.3.0") def this(@Since("1.3.0") scoreAndLabels: RDD[(Double, Double)], @Since("1.3.0") numBins: Int) = this(numBins, scoreAndLabels.map(scoreAndLabel => (scoreAndLabel._1, (scoreAndLabel._2, 1.0)))) @@ -164,7 +164,7 @@ class BinaryClassificationMetrics @Since("2.2.0") ( // negatives within each bin, and then sort by score values in descending order. val counts = scoreAndLabelsWithWeights.combineByKey( createCombiner = (labelAndWeight: (Double, Double)) => - new BinaryLabelCounter(0L, 0L) += (labelAndWeight._1, labelAndWeight._2), + new BinaryLabelCounter(0.0, 0.0) += (labelAndWeight._1, labelAndWeight._2), mergeValue = (c: BinaryLabelCounter, labelAndWeight: (Double, Double)) => c += (labelAndWeight._1, labelAndWeight._2), mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala index 5a4c6aef50b7b..784db3196c3a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala @@ -27,11 +27,11 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl /** Precision. Defined as 1.0 when there are no positive examples. */ private[evaluation] object Precision extends BinaryClassificationMetricComputer { override def apply(c: BinaryConfusionMatrix): Double = { - val totalPositives = c.numTruePositives + c.numFalsePositives + val totalPositives = c.weightedTruePositives + c.weightedFalsePositives if (totalPositives == 0) { 1.0 } else { - c.numTruePositives.toDouble / totalPositives + c.weightedTruePositives.toDouble / totalPositives } } } @@ -39,10 +39,10 @@ private[evaluation] object Precision extends BinaryClassificationMetricComputer /** False positive rate. Defined as 0.0 when there are no negative examples. */ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer { override def apply(c: BinaryConfusionMatrix): Double = { - if (c.numNegatives == 0) { + if (c.weightedNegatives == 0) { 0.0 } else { - c.numFalsePositives.toDouble / c.numNegatives + c.weightedFalsePositives.toDouble / c.weightedNegatives } } } @@ -50,10 +50,10 @@ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricC /** Recall. Defined as 0.0 when there are no positive examples. */ private[evaluation] object Recall extends BinaryClassificationMetricComputer { override def apply(c: BinaryConfusionMatrix): Double = { - if (c.numPositives == 0) { + if (c.weightedPositives == 0) { 0.0 } else { - c.numTruePositives.toDouble / c.numPositives + c.weightedTruePositives.toDouble / c.weightedPositives } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala index 1c0130700421e..c398fb8302801 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala @@ -21,23 +21,23 @@ package org.apache.spark.mllib.evaluation.binary * Trait for a binary confusion matrix. */ private[evaluation] trait BinaryConfusionMatrix { - /** number of true positives */ - def numTruePositives: Double + /** weighted number of true positives */ + def weightedTruePositives: Double - /** number of false positives */ - def numFalsePositives: Double + /** weighted number of false positives */ + def weightedFalsePositives: Double - /** number of false negatives */ - def numFalseNegatives: Double + /** weighted number of false negatives */ + def weightedFalseNegatives: Double - /** number of true negatives */ - def numTrueNegatives: Double + /** weighted number of true negatives */ + def weightedTrueNegatives: Double - /** number of positives */ - def numPositives: Double = numTruePositives + numFalseNegatives + /** weighted number of positives */ + def weightedPositives: Double = weightedTruePositives + weightedFalseNegatives - /** number of negatives */ - def numNegatives: Double = numFalsePositives + numTrueNegatives + /** weighted number of negatives */ + def weightedNegatives: Double = weightedFalsePositives + weightedTrueNegatives } /** @@ -51,20 +51,20 @@ private[evaluation] case class BinaryConfusionMatrixImpl( totalCount: BinaryLabelCounter) extends BinaryConfusionMatrix { /** number of true positives */ - override def numTruePositives: Double = count.numPositives + override def weightedTruePositives: Double = count.numPositives /** number of false positives */ - override def numFalsePositives: Double = count.numNegatives + override def weightedFalsePositives: Double = count.numNegatives /** number of false negatives */ - override def numFalseNegatives: Double = totalCount.numPositives - count.numPositives + override def weightedFalseNegatives: Double = totalCount.numPositives - count.numPositives /** number of true negatives */ - override def numTrueNegatives: Double = totalCount.numNegatives - count.numNegatives + override def weightedTrueNegatives: Double = totalCount.numNegatives - count.numNegatives /** number of positives */ - override def numPositives: Double = totalCount.numPositives + override def weightedPositives: Double = totalCount.numPositives /** number of negatives */ - override def numNegatives: Double = totalCount.numNegatives + override def weightedNegatives: Double = totalCount.numNegatives }