Skip to content

Commit

Permalink
Updated based on comments - fixed since tag, renamed vars, added chec…
Browse files Browse the repository at this point in the history
…k for weight col
  • Loading branch information
imatiach-msft committed Apr 16, 2018
1 parent adb8f7a commit e89a030
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,15 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
val schema = dataset.schema
SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
SchemaUtils.checkNumericType(schema, $(labelCol))
if (isDefined(weightCol)) {
SchemaUtils.checkNumericType(schema, $(weightCol))
}

// TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
val scoreAndLabelsWithWeights =
dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType),
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
else col($(weightCol)).cast(DoubleType))
.rdd.map {
case Row(rawPrediction: Vector, label: Double, weight: Double) =>
(rawPrediction(1), (label, weight))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ class BinaryClassificationMetrics @Since("2.2.0") (
* Retrieves the score and labels (for binary compatibility).
* @return The score and labels.
*/
@Since("1.0.0")
@Since("1.3.0")
def scoreAndLabels: RDD[(Double, Double)] = {
scoreAndLabelsWithWeights.map(values => (values._1, values._2._1))
}

@Since("1.0.0")
@Since("1.3.0")
def this(@Since("1.3.0") scoreAndLabels: RDD[(Double, Double)], @Since("1.3.0") numBins: Int) =
this(numBins, scoreAndLabels.map(scoreAndLabel => (scoreAndLabel._1, (scoreAndLabel._2, 1.0))))

Expand Down Expand Up @@ -164,7 +164,7 @@ class BinaryClassificationMetrics @Since("2.2.0") (
// negatives within each bin, and then sort by score values in descending order.
val counts = scoreAndLabelsWithWeights.combineByKey(
createCombiner = (labelAndWeight: (Double, Double)) =>
new BinaryLabelCounter(0L, 0L) += (labelAndWeight._1, labelAndWeight._2),
new BinaryLabelCounter(0.0, 0.0) += (labelAndWeight._1, labelAndWeight._2),
mergeValue = (c: BinaryLabelCounter, labelAndWeight: (Double, Double)) =>
c += (labelAndWeight._1, labelAndWeight._2),
mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,33 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl
/** Precision. Defined as 1.0 when nothing is predicted positive (TP + FP is zero). */
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double = {
val totalPositives = c.numTruePositives + c.numFalsePositives
val totalPositives = c.weightedTruePositives + c.weightedFalsePositives
if (totalPositives == 0) {
1.0
} else {
c.numTruePositives.toDouble / totalPositives
c.weightedTruePositives.toDouble / totalPositives
}
}
}

/** False positive rate. Defined as 0.0 when there are no negative examples. */
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double = {
if (c.numNegatives == 0) {
if (c.weightedNegatives == 0) {
0.0
} else {
c.numFalsePositives.toDouble / c.numNegatives
c.weightedFalsePositives.toDouble / c.weightedNegatives
}
}
}

/** Recall. Defined as 0.0 when there are no positive examples. */
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double = {
if (c.numPositives == 0) {
if (c.weightedPositives == 0) {
0.0
} else {
c.numTruePositives.toDouble / c.numPositives
c.weightedTruePositives.toDouble / c.weightedPositives
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@ package org.apache.spark.mllib.evaluation.binary
* Trait for a binary confusion matrix.
*/
private[evaluation] trait BinaryConfusionMatrix {
/** number of true positives */
def numTruePositives: Double
/** weighted number of true positives */
def weightedTruePositives: Double

/** number of false positives */
def numFalsePositives: Double
/** weighted number of false positives */
def weightedFalsePositives: Double

/** number of false negatives */
def numFalseNegatives: Double
/** weighted number of false negatives */
def weightedFalseNegatives: Double

/** number of true negatives */
def numTrueNegatives: Double
/** weighted number of true negatives */
def weightedTrueNegatives: Double

/** number of positives */
def numPositives: Double = numTruePositives + numFalseNegatives
/** weighted number of positives */
def weightedPositives: Double = weightedTruePositives + weightedFalseNegatives

/** number of negatives */
def numNegatives: Double = numFalsePositives + numTrueNegatives
/** weighted number of negatives */
def weightedNegatives: Double = weightedFalsePositives + weightedTrueNegatives
}

/**
Expand All @@ -51,20 +51,20 @@ private[evaluation] case class BinaryConfusionMatrixImpl(
totalCount: BinaryLabelCounter) extends BinaryConfusionMatrix {

/** weighted number of true positives */
override def numTruePositives: Double = count.numPositives
override def weightedTruePositives: Double = count.numPositives

/** weighted number of false positives */
override def numFalsePositives: Double = count.numNegatives
override def weightedFalsePositives: Double = count.numNegatives

/** weighted number of false negatives */
override def numFalseNegatives: Double = totalCount.numPositives - count.numPositives
override def weightedFalseNegatives: Double = totalCount.numPositives - count.numPositives

/** weighted number of true negatives */
override def numTrueNegatives: Double = totalCount.numNegatives - count.numNegatives
override def weightedTrueNegatives: Double = totalCount.numNegatives - count.numNegatives

/** weighted number of positives */
override def numPositives: Double = totalCount.numPositives
override def weightedPositives: Double = totalCount.numPositives

/** weighted number of negatives */
override def numNegatives: Double = totalCount.numNegatives
override def weightedNegatives: Double = totalCount.numNegatives
}

0 comments on commit e89a030

Please sign in to comment.