diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
new file mode 100644
index 0000000000000..76fe96a5938c0
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+
+/**
+ * Evaluator for multiclass classification.
+ *
+ * @param scoreAndLabels an RDD of (prediction, label) pairs, where the first
+ *                       element of each pair is the predicted class and the
+ *                       second is the true class.
+ */
+class MulticlassMetrics(scoreAndLabels: RDD[(Double, Double)]) extends Logging {
+
+  /* class = category; label = instance of class; prediction = instance of class */
+
+  // Number of true instances observed for each class.
+  private lazy val labelCountByClass = scoreAndLabels.values.countByValue()
+  // Total number of instances.
+  private lazy val labelCount = labelCountByClass.values.sum
+  // True positives per class, keyed by the true label.
+  private lazy val tpByClass = scoreAndLabels
+    .map { case (prediction, label) => (label, if (label == prediction) 1 else 0) }
+    .reduceByKey(_ + _)
+    .collectAsMap()
+  // False positives per class, keyed by the predicted class. A class that is
+  // predicted but never occurs as a true label appears ONLY in this map.
+  private lazy val fpByClass = scoreAndLabels
+    .map { case (prediction, label) => (prediction, if (prediction != label) 1 else 0) }
+    .reduceByKey(_ + _)
+    .collectAsMap()
+
+  /**
+   * Returns precision for a given label (category).
+   * Returns 0.0 when the label was never predicted; uses getOrElse so a class
+   * occurring only as a prediction (absent from tpByClass) does not throw.
+   * @param label the label.
+   * @return precision.
+   */
+  def precision(label: Double): Double = {
+    val tp = tpByClass.getOrElse(label, 0)
+    val fp = fpByClass.getOrElse(label, 0)
+    if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
+  }
+
+  /**
+   * Returns recall for a given label (category).
+   * Returns 0.0 when the label has no true instances, instead of throwing.
+   * @param label the label.
+   * @return recall.
+   */
+  def recall(label: Double): Double = {
+    val count = labelCountByClass.getOrElse(label, 0L)
+    if (count == 0L) 0.0 else tpByClass.getOrElse(label, 0).toDouble / count
+  }
+
+  /**
+   * Returns F1-measure for a given label (category).
+   * Returns 0.0 (instead of NaN from 0/0) when precision and recall are both zero.
+   * @param label the label.
+   * @return F1-measure.
+   */
+  def f1Measure(label: Double): Double = {
+    val p = precision(label)
+    val r = recall(label)
+    if (p + r == 0) 0.0 else 2 * p * r / (p + r)
+  }
+
+  /**
+   * Returns micro-averaged recall
+   * (equal to microPrecision and microF1Measure for a multiclass classifier).
+   * @return microRecall.
+   */
+  def microRecall: Double = tpByClass.values.sum.toDouble / labelCount
+
+  /**
+   * Returns micro-averaged precision
+   * (equal to microRecall and microF1Measure for a multiclass classifier).
+   * @return microPrecision.
+   */
+  def microPrecision: Double = microRecall
+
+  /**
+   * Returns micro-averaged F1-measure
+   * (equal to microPrecision and microRecall for a multiclass classifier).
+   * @return microF1Measure.
+   */
+  def microF1Measure: Double = microRecall
+
+  /**
+   * Returns recall averaged over all classes, weighted by class frequency.
+   * @return weightedRecall.
+   */
+  def weightedRecall: Double =
+    labelCountByClass.map { case (category, count) =>
+      recall(category) * count.toDouble / labelCount
+    }.sum
+
+  /**
+   * Returns precision averaged over all classes, weighted by class frequency.
+   * @return weightedPrecision.
+   */
+  def weightedPrecision: Double =
+    labelCountByClass.map { case (category, count) =>
+      precision(category) * count.toDouble / labelCount
+    }.sum
+
+  /**
+   * Returns F1-measure averaged over all classes, weighted by class frequency.
+   * @return weightedF1Measure.
+   */
+  def weightedF1Measure: Double =
+    labelCountByClass.map { case (category, count) =>
+      f1Measure(category) * count.toDouble / labelCount
+    }.sum
+
+  /**
+   * Returns a map from each class to its precision.
+   * @return precisionPerClass.
+   */
+  def precisionPerClass: Map[Double, Double] =
+    labelCountByClass.keys.map(category => (category, precision(category))).toMap
+
+  /**
+   * Returns a map from each class to its recall.
+   * @return recallPerClass.
+   */
+  def recallPerClass: Map[Double, Double] =
+    labelCountByClass.keys.map(category => (category, recall(category))).toMap
+
+  /**
+   * Returns a map from each class to its F1-measure.
+   * @return f1MeasurePerClass.
+   */
+  def f1MeasurePerClass: Map[Double, Double] =
+    labelCountByClass.keys.map(category => (category, f1Measure(category))).toMap
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
new file mode 100644
index 0000000000000..b4e3664ab7916
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.scalatest.FunSuite
+
+class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
+  test("Multiclass evaluation metrics") {
+    /*
+     * Confusion matrix for 3-class classification with total 9 instances:
+     * |2|1|1| true class0 (4 instances)
+     * |1|3|0| true class1 (4 instances)
+     * |0|0|1| true class2 (1 instance)
+     *
+     */
+    val scoreAndLabels = sc.parallelize(
+      Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
+        (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
+    val metrics = new MulticlassMetrics(scoreAndLabels)
+
+    val delta = 0.00001
+    val precision0 = 2.0 / (2.0 + 1.0)
+    val precision1 = 3.0 / (3.0 + 1.0)
+    val precision2 = 1.0 / (1.0 + 1.0)
+    val recall0 = 2.0 / (2.0 + 2.0)
+    val recall1 = 3.0 / (3.0 + 1.0)
+    val recall2 = 1.0 / (1.0 + 0.0)
+    val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
+    val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
+    val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
+
+    assert(math.abs(metrics.precision(0.0) - precision0) < delta)
+    assert(math.abs(metrics.precision(1.0) - precision1) < delta)
+    assert(math.abs(metrics.precision(2.0) - precision2) < delta)
+    assert(math.abs(metrics.recall(0.0) - recall0) < delta)
+    assert(math.abs(metrics.recall(1.0) - recall1) < delta)
+    assert(math.abs(metrics.recall(2.0) - recall2) < delta)
+    assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta)
+    assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta)
+    assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta)
+
+    assert(math.abs(metrics.microRecall -
+      (2.0 + 3.0 + 1.0) / ((2.0 + 3.0 + 1.0) + (1.0 + 1.0 + 1.0))) < delta)
+    assert(math.abs(metrics.microRecall - metrics.microPrecision) < delta)
+    assert(math.abs(metrics.microRecall - metrics.microF1Measure) < delta)
+    assert(math.abs(metrics.microRecall - metrics.weightedRecall) < delta)
+    assert(math.abs(metrics.weightedPrecision -
+      ((4.0 / 9.0) * precision0 + (4.0 / 9.0) * precision1 + (1.0 / 9.0) * precision2)) < delta)
+    assert(math.abs(metrics.weightedRecall -
+      ((4.0 / 9.0) * recall0 + (4.0 / 9.0) * recall1 + (1.0 / 9.0) * recall2)) < delta)
+    assert(math.abs(metrics.weightedF1Measure -
+      ((4.0 / 9.0) * f1measure0 + (4.0 / 9.0) * f1measure1 + (1.0 / 9.0) * f1measure2)) < delta)
+
+  }
+}