diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index 8b30da69ec86a..3b7ba7288c0f3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.evaluation
 import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.{Matrices, Matrix}
 import org.apache.spark.rdd.RDD
 
 import scala.collection.Map
@@ -31,19 +32,19 @@ import scala.collection.Map
  * @param predictionAndLabels an RDD of (prediction, label) pairs.
  */
 @Experimental
-class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) extends Logging {
+class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
 
   private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
   private lazy val labelCount: Long = labelCountByClass.values.sum
   private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
     .map { case (prediction, label) =>
-    (label, if (label == prediction) 1 else 0)
-  }.reduceByKey(_ + _)
+      (label, if (label == prediction) 1 else 0)
+    }.reduceByKey(_ + _)
     .collectAsMap()
   private lazy val fpByClass: Map[Double, Int] = predictionAndLabels
     .map { case (prediction, label) =>
-    (prediction, if (prediction != label) 1 else 0)
-  }.reduceByKey(_ + _)
+      (prediction, if (prediction != label) 1 else 0)
+    }.reduceByKey(_ + _)
     .collectAsMap()
   private lazy val confusions = predictionAndLabels.map { case (prediction, label) =>
     ((prediction, label), 1)
@@ -55,12 +56,13 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) extends Logg
    * they are ordered by class label ascending,
    * as in "labels"
    */
-  lazy val confusionMatrix: Array[Array[Int]] = {
-    val matrix = Array.ofDim[Int](labels.size, labels.size)
+  lazy val confusionMatrix: Matrix = {
+    val transposedMatrix = Array.ofDim[Double](labels.size, labels.size)
     for (i <- 0 to labels.size - 1; j <- 0 to labels.size - 1) {
-      matrix(j)(i) = confusions.getOrElse((labels(i), labels(j)), 0)
+      transposedMatrix(i)(j) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble
     }
-    matrix
+    val flatMatrix = transposedMatrix.flatMap(arr => arr)
+    Matrices.dense(transposedMatrix.length, transposedMatrix(0).length, flatMatrix)
   }
 
   /**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
index c7b01f0135251..555343d7cdb21 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.evaluation
 
+import org.apache.spark.mllib.linalg.Matrices
 import org.apache.spark.mllib.util.LocalSparkContext
 import org.scalatest.FunSuite
 
@@ -28,7 +29,7 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
      * |1|3|0| true class1 (4 instances)
      * |0|0|1| true class2 (1 instance)
      */
-    val confusionMatrix = Array(Array(2, 1, 1), Array(1, 3, 0), Array(0, 0, 1))
+    val confusionMatrix = Matrices.dense(3, 3, Array(2, 1, 0, 1, 3, 0, 1, 0, 1))
    val labels = Array(0.0, 1.0, 2.0)
     val predictionAndLabels = sc.parallelize(
       Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
@@ -51,7 +52,7 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
     val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
     val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)
 
-    assert(metrics.confusionMatrix.deep == confusionMatrix.deep)
+    assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray))
     assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta)
     assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta)
     assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta)
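
For reference, a minimal usage sketch of the Matrix-typed confusionMatrix introduced by this patch. It is not part of the diff: it assumes a spark-shell session (which provides the SparkContext as `sc`) with this change applied, and the nine (prediction, label) pairs below are reconstructed to match the confusion matrix described in the test's comment.

// Illustrative only; assumes spark-shell's `sc` and mllib with this patch applied.
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Matrices, Matrix}

// Nine (prediction, label) pairs that reproduce the 3-class confusion matrix
// from the test's comment (true classes in rows, predicted classes in columns).
val predictionAndLabels = sc.parallelize(
  Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
    (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
val metrics = new MulticlassMetrics(predictionAndLabels)

// confusionMatrix is now a dense mllib Matrix instead of Array[Array[Int]].
// Entry (i, j) counts instances whose true class is labels(i) and whose
// predicted class is labels(j), with class labels sorted ascending.
val cm: Matrix = metrics.confusionMatrix

// Matrices.dense is column-major: the array below encodes the columns
// (2, 1, 0), (1, 3, 0), (1, 0, 1) of the 3 x 3 matrix in the test's comment.
val expected = Matrices.dense(3, 3, Array(2, 1, 0, 1, 3, 0, 1, 0, 1))
assert(cm.toArray.sameElements(expected.toArray))

Comparing via toArray, as the updated test does, works because both sides flatten to the same column-major Array[Double]; the old Array[Array[Int]] representation needed the `.deep` comparison instead.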