package org.apache.spark.mllib.evaluation

import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._

/**
 * Evaluator for multilabel classification.
 *
 * @param predictionAndLabels an RDD of (predicted label set, true label set) pairs,
 *                            one pair per document.
 */
class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) extends Logging {

  // Total number of documents; cached because count() triggers an RDD action.
  private lazy val numDocs = predictionAndLabels.count()

  /**
   * Document-based (macro-averaged) precision: the mean over documents of
   * |predictions ∩ labels| / |predictions|. A document with no predictions
   * contributes 0 instead of dividing by zero.
   */
  lazy val macroPrecisionDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
    if (predictions.nonEmpty) {
      predictions.intersect(labels).size.toDouble / predictions.size
    } else {
      0.0
    }
  }.fold(0.0)(_ + _) / numDocs

  /**
   * Document-based (macro-averaged) recall: the mean over documents of
   * |predictions ∩ labels| / |labels|. A document with no true labels
   * contributes 0 instead of dividing by zero.
   */
  lazy val macroRecallDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
    if (labels.nonEmpty) {
      predictions.intersect(labels).size.toDouble / labels.size
    } else {
      0.0
    }
  }.fold(0.0)(_ + _) / numDocs

  /**
   * Document-based micro-averaged precision: total true positives over all
   * documents divided by the total number of predicted labels.
   */
  lazy val microPrecisionDoc: Double = {
    val (sumTp, sumPredictions) = predictionAndLabels.map { case (predictions, labels) =>
      (predictions.intersect(labels).size, predictions.size)
    }.fold((0, 0)) { case ((tp1, predictions1), (tp2, predictions2)) =>
      (tp1 + tp2, predictions1 + predictions2)
    }
    sumTp.toDouble / sumPredictions
  }

  /**
   * Document-based micro-averaged recall: total true positives over all
   * documents divided by the total number of true labels.
   */
  lazy val microRecallDoc: Double = {
    val (sumTp, sumLabels) = predictionAndLabels.map { case (predictions, labels) =>
      (predictions.intersect(labels).size, labels.size)
    }.fold((0, 0)) { case ((tp1, labels1), (tp2, labels2)) =>
      (tp1 + tp2, labels1 + labels2)
    }
    sumTp.toDouble / sumLabels
  }

  // True positives per class: a class appears here only if it is correctly
  // predicted for at least one document.
  private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    predictions.intersect(labels).map(category => (category, 1))
  }.reduceByKey(_ + _).collectAsMap()

  // False positives per class: predicted but not in the true label set.
  private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    predictions.diff(labels).map(category => (category, 1))
  }.reduceByKey(_ + _).collectAsMap()

  // False negatives per class: in the true label set but not predicted.
  private lazy val fnPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    labels.diff(predictions).map(category => (category, 1))
  }.reduceByKey(_ + _).collectAsMap()

  /**
   * Precision for the given class label: tp / (tp + fp).
   * Returns 0 when the class was never predicted.
   *
   * Note: uses getOrElse on all three maps — a label absent from tpPerClass
   * (zero true positives) previously threw NoSuchElementException.
   */
  def precisionClass(label: Double): Double = {
    val tp = tpPerClass.getOrElse(label, 0)
    val fp = fpPerClass.getOrElse(label, 0)
    if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
  }

  /**
   * Recall for the given class label: tp / (tp + fn).
   * Returns 0 when the class never occurs as a true label.
   */
  def recallClass(label: Double): Double = {
    val tp = tpPerClass.getOrElse(label, 0)
    val fn = fnPerClass.getOrElse(label, 0)
    if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)
  }

  /**
   * F1-measure for the given class label, i.e. the harmonic mean of
   * precisionClass and recallClass. Returns 0 when both are 0.
   */
  def f1MeasureClass(label: Double): Double = {
    val precision = precisionClass(label)
    val recall = recallClass(label)
    if (precision + recall == 0) 0.0 else 2 * precision * recall / (precision + recall)
  }

  // Total true positives across all classes, shared by the micro-averaged metrics.
  private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }

  /**
   * Class-based micro-averaged precision: sum of tp over all classes divided
   * by the sum of tp + fp.
   */
  lazy val microPrecisionClass: Double = {
    val sumFp = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp }
    sumTp.toDouble / (sumTp + sumFp)
  }

  /**
   * Class-based micro-averaged recall: sum of tp over all classes divided
   * by the sum of tp + fn.
   */
  lazy val microRecallClass: Double = {
    // Accumulate as Long for consistency with microPrecisionClass.
    val sumFn = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn }
    sumTp.toDouble / (sumTp + sumFn)
  }

  /**
   * Class-based micro-averaged F1-measure: harmonic mean of
   * microPrecisionClass and microRecallClass. Returns 0 when both are 0.
   */
  lazy val microF1MeasureClass: Double = {
    val precision = microPrecisionClass
    val recall = microRecallClass
    if (precision + recall == 0) 0.0 else 2 * precision * recall / (precision + recall)
  }
}
package org.apache.spark.mllib.evaluation

import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.rdd.RDD
import org.scalatest.FunSuite

/**
 * Tests MultilabelMetrics against hand-computed per-class, micro-averaged,
 * and document-based (macro) metrics on a small fixed dataset.
 */
class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
  test("Multilabel evaluation metrics") {
    /*
     * Documents true labels (5x class0, 3x class1, 4x class2):
     * doc 0 - predict 0, 1    - class 0, 2
     * doc 1 - predict 0, 2    - class 0, 1
     * doc 2 - predict none    - class 0
     * doc 3 - predict 2       - class 2
     * doc 4 - predict 2, 0    - class 2, 0
     * doc 5 - predict 0, 1, 2 - class 0, 1
     * doc 6 - predict 1       - class 1, 2
     *
     * predicted classes
     * class 0 - doc 0, 1, 4, 5 (total 4)
     * class 1 - doc 0, 5, 6 (total 3)
     * class 2 - doc 1, 3, 4, 5 (total 4)
     *
     * true classes
     * class 0 - doc 0, 1, 2, 4, 5 (total 5)
     * class 1 - doc 1, 5, 6 (total 3)
     * class 2 - doc 0, 3, 4, 6 (total 4)
     */
    val scoreAndLabels: RDD[(Set[Double], Set[Double])] = sc.parallelize(
      Seq((Set(0.0, 1.0), Set(0.0, 2.0)),
        (Set(0.0, 2.0), Set(0.0, 1.0)),
        (Set(), Set(0.0)),
        (Set(2.0), Set(2.0)),
        (Set(2.0, 0.0), Set(2.0, 0.0)),
        (Set(0.0, 1.0, 2.0), Set(0.0, 1.0)),
        (Set(1.0), Set(1.0, 2.0))), 2)
    val metrics = new MultilabelMetrics(scoreAndLabels)
    val delta = 0.00001

    // Per-class expectations, derived from the tp/fp/fn tallies above.
    val precision0 = 4.0 / (4 + 0)
    val precision1 = 2.0 / (2 + 1)
    val precision2 = 2.0 / (2 + 2)
    val recall0 = 4.0 / (4 + 1)
    val recall1 = 2.0 / (2 + 1)
    val recall2 = 2.0 / (2 + 2)
    val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
    val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
    val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)

    // Micro-averaged (class-based) expectations.
    val microPrecisionClass = (4.0 + 2.0 + 2.0) / (4 + 0 + 2 + 1 + 2 + 2)
    val microRecallClass = (4.0 + 2.0 + 2.0) / (4 + 1 + 2 + 1 + 2 + 2)
    val microF1MeasureClass =
      2 * microPrecisionClass * microRecallClass / (microPrecisionClass + microRecallClass)

    // Document-based (macro) expectations: per-document ratios averaged over 7 docs.
    val macroPrecisionDoc =
      1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
    val macroRecallDoc =
      1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2)

    assert(math.abs(metrics.precisionClass(0.0) - precision0) < delta)
    assert(math.abs(metrics.precisionClass(1.0) - precision1) < delta)
    assert(math.abs(metrics.precisionClass(2.0) - precision2) < delta)
    assert(math.abs(metrics.recallClass(0.0) - recall0) < delta)
    assert(math.abs(metrics.recallClass(1.0) - recall1) < delta)
    assert(math.abs(metrics.recallClass(2.0) - recall2) < delta)
    assert(math.abs(metrics.f1MeasureClass(0.0) - f1measure0) < delta)
    assert(math.abs(metrics.f1MeasureClass(1.0) - f1measure1) < delta)
    assert(math.abs(metrics.f1MeasureClass(2.0) - f1measure2) < delta)

    assert(math.abs(metrics.microPrecisionClass - microPrecisionClass) < delta)
    assert(math.abs(metrics.microRecallClass - microRecallClass) < delta)
    assert(math.abs(metrics.microF1MeasureClass - microF1MeasureClass) < delta)

    assert(math.abs(metrics.macroPrecisionDoc - macroPrecisionDoc) < delta)
    assert(math.abs(metrics.macroRecallDoc - macroRecallDoc) < delta)
  }
}