Skip to content

Commit

Permalink
Multilabel evaluation metrics and tests: macro precision and recall av…
Browse files Browse the repository at this point in the history
…eraged by docs, micro and per-class precision and recall averaged by class
  • Loading branch information
avulanov committed Jun 30, 2014
1 parent 67fca18 commit 154164b
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.apache.spark.mllib.evaluation

import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._


/**
 * Evaluator for multilabel classification.
 *
 * @param predictionAndLabels an RDD of (predicted label set, true label set) pairs,
 *                            one pair per document.
 */
class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) extends Logging {

  // Total number of documents in the dataset.
  private lazy val numDocs = predictionAndLabels.count()

  /**
   * Document-based precision averaged over all documents:
   * mean of |predictions ∩ labels| / |predictions|.
   * A document with no predictions contributes 0.
   */
  lazy val macroPrecisionDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
    if (predictions.nonEmpty) {
      predictions.intersect(labels).size.toDouble / predictions.size
    } else 0.0
  }.fold(0.0)(_ + _) / numDocs

  /**
   * Document-based recall averaged over all documents:
   * mean of |predictions ∩ labels| / |labels|.
   * A document with an empty label set contributes 0 (guard avoids NaN from 0/0).
   */
  lazy val macroRecallDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
    if (labels.nonEmpty) {
      predictions.intersect(labels).size.toDouble / labels.size
    } else 0.0
  }.fold(0.0)(_ + _) / numDocs

  /**
   * Document-based micro-averaged precision:
   * total true positives over total number of predictions, pooled across documents.
   */
  lazy val microPrecisionDoc: Double = {
    val (sumTp, sumPredictions) = predictionAndLabels.map { case (predictions, labels) =>
      (predictions.intersect(labels).size, predictions.size)
    }.fold((0, 0)) { case ((tp1, predictions1), (tp2, predictions2)) =>
      (tp1 + tp2, predictions1 + predictions2)
    }
    sumTp.toDouble / sumPredictions
  }

  /**
   * Document-based micro-averaged recall:
   * total true positives over total number of true labels, pooled across documents.
   */
  lazy val microRecallDoc: Double = {
    val (sumTp, sumLabels) = predictionAndLabels.map { case (predictions, labels) =>
      (predictions.intersect(labels).size, labels.size)
    }.fold((0, 0)) { case ((tp1, labels1), (tp2, labels2)) =>
      (tp1 + tp2, labels1 + labels2)
    }
    sumTp.toDouble / sumLabels
  }

  // True positives per class: label was both predicted and true.
  private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    predictions.intersect(labels).map(category => (category, 1))
  }.reduceByKey(_ + _).collectAsMap()

  // False positives per class: label was predicted but not true.
  private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    predictions.diff(labels).map(category => (category, 1))
  }.reduceByKey(_ + _).collectAsMap()

  // False negatives per class: label was true but not predicted.
  private lazy val fnPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    labels.diff(predictions).map(category => (category, 1))
  }.reduceByKey(_ + _).collectAsMap()

  /**
   * Precision for the given class: tp / (tp + fp).
   * Returns 0.0 when the class was never predicted.
   * Uses getOrElse throughout so a class with zero true positives
   * (absent from tpPerClass) does not throw NoSuchElementException.
   */
  def precisionClass(label: Double): Double = {
    val tp = tpPerClass.getOrElse(label, 0)
    val fp = fpPerClass.getOrElse(label, 0)
    if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
  }

  /**
   * Recall for the given class: tp / (tp + fn).
   * Returns 0.0 when the class never appears among the true labels.
   */
  def recallClass(label: Double): Double = {
    val tp = tpPerClass.getOrElse(label, 0)
    val fn = fnPerClass.getOrElse(label, 0)
    if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)
  }

  /**
   * F1-measure for the given class: harmonic mean of precision and recall.
   * Returns 0.0 when both precision and recall are zero.
   */
  def f1MeasureClass(label: Double): Double = {
    val precision = precisionClass(label)
    val recall = recallClass(label)
    if ((precision + recall) == 0) 0.0 else 2 * precision * recall / (precision + recall)
  }

  // Total true positives pooled over all classes.
  private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }

  /**
   * Class-based micro-averaged precision: total tp / (total tp + total fp).
   */
  lazy val microPrecisionClass: Double = {
    val sumFp = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp }
    sumTp.toDouble / (sumTp + sumFp)
  }

  /**
   * Class-based micro-averaged recall: total tp / (total tp + total fn).
   * Accumulates with Long (not Double) for consistency with microPrecisionClass.
   */
  lazy val microRecallClass: Double = {
    val sumFn = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn }
    sumTp.toDouble / (sumTp + sumFn)
  }

  /**
   * Class-based micro-averaged F1-measure:
   * harmonic mean of the micro-averaged precision and recall.
   */
  lazy val microF1MeasureClass: Double = {
    val precision = microPrecisionClass
    val recall = microRecallClass
    if ((precision + recall) == 0) 0.0 else 2 * precision * recall / (precision + recall)
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package org.apache.spark.mllib.evaluation

import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.rdd.RDD
import org.scalatest.FunSuite


class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
  test("Multilabel evaluation metrics") {
    /*
     * Documents true labels (5x class0, 3x class1, 4x class2):
     * doc 0 - predict 0, 1 - class 0, 2
     * doc 1 - predict 0, 2 - class 0, 1
     * doc 2 - predict none - class 0
     * doc 3 - predict 2 - class 2
     * doc 4 - predict 2, 0 - class 2, 0
     * doc 5 - predict 0, 1, 2 - class 0, 1
     * doc 6 - predict 1 - class 1, 2
     *
     * predicted classes
     * class 0 - doc 0, 1, 4, 5 (total 4)
     * class 1 - doc 0, 5, 6 (total 3)
     * class 2 - doc 1, 3, 4, 5 (total 4)
     *
     * true classes
     * class 0 - doc 0, 1, 2, 4, 5 (total 5)
     * class 1 - doc 1, 5, 6 (total 3)
     * class 2 - doc 0, 3, 4, 6 (total 4)
     */
    val scoreAndLabels: RDD[(Set[Double], Set[Double])] = sc.parallelize(
      Seq((Set(0.0, 1.0), Set(0.0, 2.0)),
        (Set(0.0, 2.0), Set(0.0, 1.0)),
        (Set(), Set(0.0)),
        (Set(2.0), Set(2.0)),
        (Set(2.0, 0.0), Set(2.0, 0.0)),
        (Set(0.0, 1.0, 2.0), Set(0.0, 1.0)),
        (Set(1.0), Set(1.0, 2.0))), 2)
    val metrics = new MultilabelMetrics(scoreAndLabels)
    val delta = 0.00001

    // Per-class expectations: precision = tp / (tp + fp), recall = tp / (tp + fn).
    val precision0 = 4.0 / (4 + 0)
    val precision1 = 2.0 / (2 + 1)
    val precision2 = 2.0 / (2 + 2)
    val recall0 = 4.0 / (4 + 1)
    val recall1 = 2.0 / (2 + 1)
    val recall2 = 2.0 / (2 + 2)
    val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
    val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
    val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)

    // Micro averages pool tp/fp/fn across all classes.
    val microPrecisionClass = (4.0 + 2.0 + 2.0) / (4 + 0 + 2 + 1 + 2 + 2)
    val microRecallClass = (4.0 + 2.0 + 2.0) / (4 + 1 + 2 + 1 + 2 + 2)
    val microF1MeasureClass =
      2 * microPrecisionClass * microRecallClass / (microPrecisionClass + microRecallClass)

    // Document-based macro averages; doc 2 (no predictions) contributes 0 to precision.
    val macroPrecisionDoc = 1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
    val macroRecallDoc = 1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2)

    assert(math.abs(metrics.precisionClass(0.0) - precision0) < delta)
    assert(math.abs(metrics.precisionClass(1.0) - precision1) < delta)
    assert(math.abs(metrics.precisionClass(2.0) - precision2) < delta)
    assert(math.abs(metrics.recallClass(0.0) - recall0) < delta)
    assert(math.abs(metrics.recallClass(1.0) - recall1) < delta)
    assert(math.abs(metrics.recallClass(2.0) - recall2) < delta)
    assert(math.abs(metrics.f1MeasureClass(0.0) - f1measure0) < delta)
    assert(math.abs(metrics.f1MeasureClass(1.0) - f1measure1) < delta)
    assert(math.abs(metrics.f1MeasureClass(2.0) - f1measure2) < delta)

    assert(math.abs(metrics.microPrecisionClass - microPrecisionClass) < delta)
    assert(math.abs(metrics.microRecallClass - microRecallClass) < delta)
    assert(math.abs(metrics.microF1MeasureClass - microF1MeasureClass) < delta)

    assert(math.abs(metrics.macroPrecisionDoc - macroPrecisionDoc) < delta)
    assert(math.abs(metrics.macroRecallDoc - macroRecallDoc) < delta)
  }

}

0 comments on commit 154164b

Please sign in to comment.