From e2c91c37dff6b6d9ae002a9095ed969955d11cac Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Mon, 30 Jun 2014 13:51:04 +0400 Subject: [PATCH] Fixes to multiclass metrics --- .../mllib/evaluation/MulticlassMetrics.scala | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index dff84224a0e31..ebcee86acedfd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -60,72 +60,75 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log * @param label the label. * @return F1-measure. */ - def f1Measure(label: Double): Double = - 2 * precision(label) * recall(label) / (precision(label) + recall(label)) + def f1Measure(label: Double): Double ={ + val p = precision(label) + val r = recall(label) + if((p + r) == 0) 0 else 2 * p * r / (p + r) + } /** * Returns micro-averaged Recall * (equals to microPrecision and microF1measure for multiclass classifier) * @return microRecall. */ - def microRecall: Double = - tpByClass.foldLeft(0L){case (sum,(_, tp)) => sum + tp}.toDouble / labelCount.toDouble + lazy val microRecall: Double = + tpByClass.foldLeft(0L){case (sum,(_, tp)) => sum + tp}.toDouble / labelCount /** * Returns micro-averaged Precision * (equals to microPrecision and microF1measure for multiclass classifier) * @return microPrecision. */ - def microPrecision: Double = microRecall + lazy val microPrecision: Double = microRecall /** * Returns micro-averaged F1-measure * (equals to microPrecision and microRecall for multiclass classifier) * @return microF1measure. */ - def microF1Measure: Double = microRecall + lazy val microF1Measure: Double = microRecall /** * Returns weighted averaged Recall * @return weightedRecall. 
*/ - def weightedRecall: Double = labelCountByClass.foldLeft(0.0){case(wRecall, (category, count)) => - wRecall + recall(category) * count.toDouble / labelCount.toDouble} + lazy val weightedRecall: Double = labelCountByClass.foldLeft(0.0){case(wRecall, (category, count)) => + wRecall + recall(category) * count.toDouble / labelCount} /** * Returns weighted averaged Precision * @return weightedPrecision. */ - def weightedPrecision: Double = + lazy val weightedPrecision: Double = labelCountByClass.foldLeft(0.0){case(wPrecision, (category, count)) => - wPrecision + precision(category) * count.toDouble / labelCount.toDouble} + wPrecision + precision(category) * count.toDouble / labelCount} /** * Returns weighted averaged F1-measure * @return weightedF1Measure. */ - def weightedF1Measure: Double = + lazy val weightedF1Measure: Double = labelCountByClass.foldLeft(0.0){case(wF1measure, (category, count)) => - wF1measure + f1Measure(category) * count.toDouble / labelCount.toDouble} + wF1measure + f1Measure(category) * count.toDouble / labelCount} /** * Returns map with Precisions for individual classes * @return precisionPerClass. */ - def precisionPerClass = + lazy val precisionPerClass = labelCountByClass.map{case (category, _) => (category, precision(category))}.toMap /** * Returns map with Recalls for individual classes * @return recallPerClass. */ - def recallPerClass = + lazy val recallPerClass = labelCountByClass.map{case (category, _) => (category, recall(category))}.toMap /** * Returns map with F1-measures for individual classes * @return f1MeasurePerClass. */ - def f1MeasurePerClass = + lazy val f1MeasurePerClass = labelCountByClass.map{case (category, _) => (category, f1Measure(category))}.toMap }