-
Notifications
You must be signed in to change notification settings - Fork 28.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Multilabel evaluation metics and tests: macro precision and recall av…
…eraged by docs, micro and per-class precision and recall averaged by class
- Loading branch information
Showing
2 changed files
with
155 additions
and
0 deletions.
There are no files selected for viewing
74 changes: 74 additions & 0 deletions
74
mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package org.apache.spark.mllib.evaluation | ||
|
||
import org.apache.spark.Logging | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.SparkContext._ | ||
|
||
|
||
class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) extends Logging{ | ||
|
||
private lazy val numDocs = predictionAndLabels.count() | ||
|
||
lazy val macroPrecisionDoc = (predictionAndLabels.map{ case(predictions, labels) => | ||
if (predictions.size >0) | ||
predictions.intersect(labels).size.toDouble / predictions.size else 0}.fold(0.0)(_ + _)) / numDocs | ||
|
||
lazy val macroRecallDoc = (predictionAndLabels.map{ case(predictions, labels) => | ||
predictions.intersect(labels).size.toDouble / labels.size}.fold(0.0)(_ + _)) / numDocs | ||
|
||
lazy val microPrecisionDoc = { | ||
val (sumTp, sumPredictions) = predictionAndLabels.map{ case(predictions, labels) => | ||
(predictions.intersect(labels).size, predictions.size)}. | ||
fold((0, 0)){ case((tp1, predictions1), (tp2, predictions2)) => | ||
(tp1 + tp2, predictions1 + predictions2)} | ||
sumTp.toDouble / sumPredictions | ||
} | ||
|
||
lazy val microRecallDoc = { | ||
val (sumTp, sumLabels) = predictionAndLabels.map{ case(predictions, labels) => | ||
(predictions.intersect(labels).size, labels.size)}. | ||
fold((0, 0)){ case((tp1, labels1), (tp2, labels2)) => | ||
(tp1 + tp2, labels1 + labels2)} | ||
sumTp.toDouble / sumLabels | ||
} | ||
|
||
private lazy val tpPerClass = predictionAndLabels.flatMap{ case(predictions, labels) => | ||
predictions.intersect(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap() | ||
|
||
private lazy val fpPerClass = predictionAndLabels.flatMap{ case(predictions, labels) => | ||
predictions.diff(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap() | ||
|
||
private lazy val fnPerClass = predictionAndLabels.flatMap{ case(predictions, labels) => | ||
labels.diff(predictions).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap() | ||
|
||
def precisionClass(label: Double) = if((tpPerClass(label) + fpPerClass.getOrElse(label, 0)) == 0) 0 else | ||
tpPerClass(label).toDouble / (tpPerClass(label) + fpPerClass.getOrElse(label, 0)) | ||
|
||
def recallClass(label: Double) = if((tpPerClass(label) + fnPerClass.getOrElse(label, 0)) == 0) 0 else | ||
tpPerClass(label).toDouble / (tpPerClass(label) + fnPerClass.getOrElse(label, 0)) | ||
|
||
def f1MeasureClass(label: Double) = { | ||
val precision = precisionClass(label) | ||
val recall = recallClass(label) | ||
if((precision + recall) == 0) 0 else 2 * precision * recall / (precision + recall) | ||
} | ||
|
||
private lazy val sumTp = tpPerClass.foldLeft(0L){ case(sumTp, (_, tp)) => sumTp + tp} | ||
|
||
lazy val microPrecisionClass = { | ||
val sumFp = fpPerClass.foldLeft(0L){ case(sumFp, (_, fp)) => sumFp + fp} | ||
sumTp.toDouble / (sumTp + sumFp) | ||
} | ||
|
||
lazy val microRecallClass = { | ||
val sumFn = fnPerClass.foldLeft(0.0){ case(sumFn, (_, fn)) => sumFn + fn} | ||
sumTp.toDouble / (sumTp + sumFn) | ||
} | ||
|
||
lazy val microF1MeasureClass = { | ||
val precision = microPrecisionClass | ||
val recall = microRecallClass | ||
if((precision + recall) == 0) 0 else 2 * precision * recall / (precision + recall) | ||
} | ||
|
||
} |
81 changes: 81 additions & 0 deletions
81
mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
package org.apache.spark.mllib.evaluation | ||
|
||
import org.apache.spark.mllib.util.LocalSparkContext | ||
import org.apache.spark.rdd.RDD | ||
import org.scalatest.FunSuite | ||
|
||
|
||
class MultilabelMetricsSuite extends FunSuite with LocalSparkContext { | ||
test("Multilabel evaluation metrics") { | ||
/* | ||
* Documents true labels (5x class0, 3x class1, 4x class2): | ||
* doc 0 - predict 0, 1 - class 0, 2 | ||
* doc 1 - predict 0, 2 - class 0, 1 | ||
* doc 2 - predict none - class 0 | ||
* doc 3 - predict 2 - class 2 | ||
* doc 4 - predict 2, 0 - class 2, 0 | ||
* doc 5 - predict 0, 1, 2 - class 0, 1 | ||
* doc 6 - predict 1 - class 1, 2 | ||
* | ||
* predicted classes | ||
* class 0 - doc 0, 1, 4, 5 (total 4) | ||
* class 1 - doc 0, 5, 6 (total 3) | ||
* class 2 - doc 1, 3, 4, 5 (total 4) | ||
* | ||
* true classes | ||
* class 0 - doc 0, 1, 2, 4, 5 (total 5) | ||
* class 1 - doc 1, 5, 6 (total 3) | ||
* class 2 - doc 0, 3, 4, 6 (total 4) | ||
* | ||
*/ | ||
val scoreAndLabels:RDD[(Set[Double], Set[Double])] = sc.parallelize( | ||
Seq((Set(0.0, 1.0), Set(0.0, 2.0)), | ||
(Set(0.0, 2.0), Set(0.0, 1.0)), | ||
(Set(), Set(0.0)), | ||
(Set(2.0), Set(2.0)), | ||
(Set(2.0, 0.0), Set(2.0, 0.0)), | ||
(Set(0.0, 1.0, 2.0), Set(0.0, 1.0)), | ||
(Set(1.0), Set(1.0, 2.0))), 2) | ||
val metrics = new MultilabelMetrics(scoreAndLabels) | ||
val delta = 0.00001 | ||
val precision0 = 4.0 / (4 + 0) | ||
val precision1 = 2.0 / (2 + 1) | ||
val precision2 = 2.0 / (2 + 2) | ||
val recall0 = 4.0 / (4 + 1) | ||
val recall1 = 2.0 / (2 + 1) | ||
val recall2 = 2.0 / (2 + 2) | ||
val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) | ||
val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) | ||
val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) | ||
val microPrecisionClass = (4.0 + 2.0 + 2.0) / (4 + 0 + 2 + 1 + 2 + 2) | ||
val microRecallClass = (4.0 + 2.0 + 2.0) / (4 + 1 + 2 + 1 + 2 + 2) | ||
val microF1MeasureClass = 2 * microPrecisionClass * microRecallClass / (microPrecisionClass + microRecallClass) | ||
|
||
val macroPrecisionDoc = 1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0) | ||
val macroRecallDoc = 1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2) | ||
|
||
println("Ev" + metrics.macroPrecisionDoc) | ||
println(macroPrecisionDoc) | ||
println("Ev" + metrics.macroRecallDoc) | ||
println(macroRecallDoc) | ||
assert(math.abs(metrics.precisionClass(0.0) - precision0) < delta) | ||
assert(math.abs(metrics.precisionClass(1.0) - precision1) < delta) | ||
assert(math.abs(metrics.precisionClass(2.0) - precision2) < delta) | ||
assert(math.abs(metrics.recallClass(0.0) - recall0) < delta) | ||
assert(math.abs(metrics.recallClass(1.0) - recall1) < delta) | ||
assert(math.abs(metrics.recallClass(2.0) - recall2) < delta) | ||
assert(math.abs(metrics.f1MeasureClass(0.0) - f1measure0) < delta) | ||
assert(math.abs(metrics.f1MeasureClass(1.0) - f1measure1) < delta) | ||
assert(math.abs(metrics.f1MeasureClass(2.0) - f1measure2) < delta) | ||
|
||
assert(math.abs(metrics.microPrecisionClass - microPrecisionClass) < delta) | ||
assert(math.abs(metrics.microRecallClass - microRecallClass) < delta) | ||
assert(math.abs(metrics.microF1MeasureClass - microF1MeasureClass) < delta) | ||
|
||
assert(math.abs(metrics.macroPrecisionDoc - macroPrecisionDoc) < delta) | ||
assert(math.abs(metrics.macroRecallDoc - macroRecallDoc) < delta) | ||
|
||
|
||
} | ||
|
||
} |