diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala new file mode 100644 index 0000000000000..7858ec602483f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.rdd.RDDFunctions._ + +/** + * Computes the area under the curve (AUC) using the trapezoidal rule. + */ +private[evaluation] object AreaUnderCurve { + + /** + * Uses the trapezoidal rule to compute the area under the line connecting the two input points. + * @param points two 2D points stored in Seq + */ + private def trapezoid(points: Seq[(Double, Double)]): Double = { + require(points.length == 2) + val x = points.head + val y = points.last + (y._1 - x._1) * (y._2 + x._2) / 2.0 + } + + /** + * Returns the area under the given curve. + * + * @param curve a RDD of ordered 2D points stored in pairs representing a curve + */ + def of(curve: RDD[(Double, Double)]): Double = { + curve.sliding(2).aggregate(0.0)( + seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + combOp = _ + _ + ) + } + + /** + * Returns the area under the given curve. + * + * @param curve an iterator over ordered 2D points stored in pairs representing a curve + */ + def of(curve: Iterable[(Double, Double)]): Double = { + curve.toIterator.sliding(2).withPartial(false).aggregate(0.0)( + seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + combop = _ + _ + ) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala new file mode 100644 index 0000000000000..562663ad36b40 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +/** + * Trait for a binary classification evaluation metric computer. + */ +private[evaluation] trait BinaryClassificationMetricComputer extends Serializable { + def apply(c: BinaryConfusionMatrix): Double +} + +/** Precision. */ +private[evaluation] object Precision extends BinaryClassificationMetricComputer { + override def apply(c: BinaryConfusionMatrix): Double = + c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives) +} + +/** False positive rate. */ +private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer { + override def apply(c: BinaryConfusionMatrix): Double = + c.numFalsePositives.toDouble / c.numNegatives +} + +/** Recall. */ +private[evaluation] object Recall extends BinaryClassificationMetricComputer { + override def apply(c: BinaryConfusionMatrix): Double = + c.numTruePositives.toDouble / c.numPositives +} + +/** + * F-Measure. + * @param beta the beta constant in F-Measure + * @see http://en.wikipedia.org/wiki/F1_score + */ +private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificationMetricComputer { + private val beta2 = beta * beta + override def apply(c: BinaryConfusionMatrix): Double = { + val precision = Precision(c) + val recall = Recall(c) + (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala new file mode 100644 index 0000000000000..ed7b0fc943367 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +import org.apache.spark.rdd.{UnionRDD, RDD} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.evaluation.AreaUnderCurve +import org.apache.spark.Logging + +/** + * Implementation of [[org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrix]]. + * + * @param count label counter for labels with scores greater than or equal to the current score + * @param totalCount label counter for all labels + */ +private case class BinaryConfusionMatrixImpl( + count: LabelCounter, + totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable { + + /** number of true positives */ + override def numTruePositives: Long = count.numPositives + + /** number of false positives */ + override def numFalsePositives: Long = count.numNegatives + + /** number of false negatives */ + override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives + + /** number of true negatives */ + override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives + + /** number of positives */ + override def numPositives: Long = totalCount.numPositives + + /** number of negatives */ + override def numNegatives: Long = totalCount.numNegatives +} + +/** + * Evaluator for binary classification. + * + * @param scoreAndLabels an RDD of (score, label) pairs. + */ +class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) + extends Serializable with Logging { + + private lazy val ( + cumulativeCounts: RDD[(Double, LabelCounter)], + confusions: RDD[(Double, BinaryConfusionMatrix)]) = { + // Create a bin for each distinct score value, count positives and negatives within each bin, + // and then sort by score values in descending order. + val counts = scoreAndLabels.combineByKey( + createCombiner = (label: Double) => new LabelCounter(0L, 0L) += label, + mergeValue = (c: LabelCounter, label: Double) => c += label, + mergeCombiners = (c1: LabelCounter, c2: LabelCounter) => c1 += c2 + ).sortByKey(ascending = false) + val agg = counts.values.mapPartitions({ iter => + val agg = new LabelCounter() + iter.foreach(agg += _) + Iterator(agg) + }, preservesPartitioning = true).collect() + val partitionwiseCumulativeCounts = + agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg.clone() += c) + val totalCount = partitionwiseCumulativeCounts.last + logInfo(s"Total counts: $totalCount") + val cumulativeCounts = counts.mapPartitionsWithIndex( + (index: Int, iter: Iterator[(Double, LabelCounter)]) => { + val cumCount = partitionwiseCumulativeCounts(index) + iter.map { case (score, c) => + cumCount += c + (score, cumCount.clone()) + } + }, preservesPartitioning = true) + cumulativeCounts.persist() + val confusions = cumulativeCounts.map { case (score, cumCount) => + (score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix]) + } + (cumulativeCounts, confusions) + } + + /** Unpersist intermediate RDDs used in the computation. */ + def unpersist() { + cumulativeCounts.unpersist() + } + + /** Returns thresholds in descending order. */ + def thresholds(): RDD[Double] = cumulativeCounts.map(_._1) + + /** + * Returns the receiver operating characteristic (ROC) curve, + * which is an RDD of (false positive rate, true positive rate) + * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic + */ + def roc(): RDD[(Double, Double)] = { + val rocCurve = createCurve(FalsePositiveRate, Recall) + val sc = confusions.context + val first = sc.makeRDD(Seq((0.0, 0.0)), 1) + val last = sc.makeRDD(Seq((1.0, 1.0)), 1) + new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last)) + } + + /** + * Computes the area under the receiver operating characteristic (ROC) curve. + */ + def areaUnderROC(): Double = AreaUnderCurve.of(roc()) + + /** + * Returns the precision-recall curve, which is an RDD of (recall, precision), + * NOT (precision, recall), with (0.0, 1.0) prepended to it. + * @see http://en.wikipedia.org/wiki/Precision_and_recall + */ + def pr(): RDD[(Double, Double)] = { + val prCurve = createCurve(Recall, Precision) + val sc = confusions.context + val first = sc.makeRDD(Seq((0.0, 1.0)), 1) + first.union(prCurve) + } + + /** + * Computes the area under the precision-recall curve. + */ + def areaUnderPR(): Double = AreaUnderCurve.of(pr()) + + /** + * Returns the (threshold, F-Measure) curve. + * @param beta the beta factor in F-Measure computation. + * @return an RDD of (threshold, F-Measure) pairs. + * @see http://en.wikipedia.org/wiki/F1_score + */ + def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta)) + + /** Returns the (threshold, F-Measure) curve with beta = 1.0. */ + def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0) + + /** Returns the (threshold, precision) curve. */ + def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision) + + /** Returns the (threshold, recall) curve. */ + def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall) + + /** Creates a curve of (threshold, metric). */ + private def createCurve(y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = { + confusions.map { case (s, c) => + (s, y(c)) + } + } + + /** Creates a curve of (metricX, metricY). */ + private def createCurve( + x: BinaryClassificationMetricComputer, + y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = { + confusions.map { case (_, c) => + (x(c), y(c)) + } + } +} + +/** + * A counter for positives and negatives. + * + * @param numPositives number of positive labels + * @param numNegatives number of negative labels + */ +private class LabelCounter( + var numPositives: Long = 0L, + var numNegatives: Long = 0L) extends Serializable { + + /** Processes a label. */ + def +=(label: Double): LabelCounter = { + // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle + // -1.0 for negative as well. + if (label > 0.5) numPositives += 1L else numNegatives += 1L + this + } + + /** Merges another counter. */ + def +=(other: LabelCounter): LabelCounter = { + numPositives += other.numPositives + numNegatives += other.numNegatives + this + } + + override def clone: LabelCounter = { + new LabelCounter(numPositives, numNegatives) + } + + override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}" +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala new file mode 100644 index 0000000000000..75a75b216002a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +/** + * Trait for a binary confusion matrix. + */ +private[evaluation] trait BinaryConfusionMatrix { + /** number of true positives */ + def numTruePositives: Long + + /** number of false positives */ + def numFalsePositives: Long + + /** number of false negatives */ + def numFalseNegatives: Long + + /** number of true negatives */ + def numTrueNegatives: Long + + /** number of positives */ + def numPositives: Long = numTruePositives + numFalseNegatives + + /** number of negatives */ + def numNegatives: Long = numFalsePositives + numTrueNegatives +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala new file mode 100644 index 0000000000000..873de871fd884 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD + +/** + * Machine learning specific RDD functions. + */ +private[mllib] +class RDDFunctions[T: ClassTag](self: RDD[T]) { + + /** + * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * window over them. The ordering is first based on the partition index and then the ordering of + * items within each partition. This is similar to sliding in Scala collections, except that it + * becomes an empty RDD if the window size is greater than the total number of items. It needs to + * trigger a Spark job if the parent RDD has more than one partitions and the window size is + * greater than 1. + */ + def sliding(windowSize: Int): RDD[Seq[T]] = { + require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") + if (windowSize == 1) { + self.map(Seq(_)) + } else { + new SlidingRDD[T](self, windowSize) + } + } +} + +private[mllib] +object RDDFunctions { + + /** Implicit conversion from an RDD to RDDFunctions. */ + implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala new file mode 100644 index 0000000000000..dd80782c0f001 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import scala.collection.mutable +import scala.reflect.ClassTag + +import org.apache.spark.{TaskContext, Partition} +import org.apache.spark.rdd.RDD + +private[mllib] +class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]) + extends Partition with Serializable { + override val index: Int = idx +} + +/** + * Represents a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * window over them. The ordering is first based on the partition index and then the ordering of + * items within each partition. This is similar to sliding in Scala collections, except that it + * becomes an empty RDD if the window size is greater than the total number of items. It needs to + * trigger a Spark job if the parent RDD has more than one partitions. To make this operation + * efficient, the number of items per partition should be larger than the window size and the + * window size should be small, e.g., 2. + * + * @param parent the parent RDD + * @param windowSize the window size, must be greater than 1 + * + * @see [[org.apache.spark.mllib.rdd.RDDFunctions#sliding]] + */ +private[mllib] +class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) + extends RDD[Seq[T]](parent) { + + require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") + + override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = { + val part = split.asInstanceOf[SlidingRDDPartition[T]] + (firstParent[T].iterator(part.prev, context) ++ part.tail) + .sliding(windowSize) + .withPartial(false) + } + + override def getPreferredLocations(split: Partition): Seq[String] = + firstParent[T].preferredLocations(split.asInstanceOf[SlidingRDDPartition[T]].prev) + + override def getPartitions: Array[Partition] = { + val parentPartitions = parent.partitions + val n = parentPartitions.size + if (n == 0) { + Array.empty + } else if (n == 1) { + Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty)) + } else { + val n1 = n - 1 + val w1 = windowSize - 1 + // Get the first w1 items of each partition, starting from the second partition. + val nextHeads = + parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true) + val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]() + var i = 0 + var partitionIndex = 0 + while (i < n1) { + var j = i + val tail = mutable.ListBuffer[T]() + // Keep appending to the current tail until appended a head of size w1. + while (j < n1 && nextHeads(j).size < w1) { + tail ++= nextHeads(j) + j += 1 + } + if (j < n1) { + tail ++= nextHeads(j) + j += 1 + } + partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail) + partitionIndex += 1 + // Skip appended heads. + i = j + } + // If the head of last partition has size w1, we also need to add this partition. + if (nextHeads.last.size == w1) { + partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty) + } + partitions.toArray + } + } + + // TODO: Override methods such as aggregate, which only requires one Spark job. +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala new file mode 100644 index 0000000000000..1c9844f289fe0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext + +class AreaUnderCurveSuite extends FunSuite with LocalSparkContext { + test("auc computation") { + val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) + val auc = 4.0 + assert(AreaUnderCurve.of(curve) === auc) + val rddCurve = sc.parallelize(curve, 2) + assert(AreaUnderCurve.of(rddCurve) == auc) + } + + test("auc of an empty curve") { + val curve = Seq.empty[(Double, Double)] + assert(AreaUnderCurve.of(curve) === 0.0) + val rddCurve = sc.parallelize(curve, 2) + assert(AreaUnderCurve.of(rddCurve) === 0.0) + } + + test("auc of a curve with a single point") { + val curve = Seq((1.0, 1.0)) + assert(AreaUnderCurve.of(curve) === 0.0) + val rddCurve = sc.parallelize(curve, 2) + assert(AreaUnderCurve.of(rddCurve) === 0.0) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala new file mode 100644 index 0000000000000..173fdaefab3da --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.evaluation.AreaUnderCurve + +class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { + test("binary evaluation metrics") { + val scoreAndLabels = sc.parallelize( + Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2) + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val threshold = Seq(0.8, 0.6, 0.4, 0.1) + val numTruePositives = Seq(1, 3, 3, 4) + val numFalsePositives = Seq(0, 1, 2, 3) + val numPositives = 4 + val numNegatives = 3 + val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) => + t.toDouble / (t + f) + } + val recall = numTruePositives.map(t => t.toDouble / numPositives) + val fpr = numFalsePositives.map(f => f.toDouble / numNegatives) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0)) + val pr = recall.zip(precision) + val prCurve = Seq((0.0, 1.0)) ++ pr + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) } + val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} + assert(metrics.thresholds().collect().toSeq === threshold) + assert(metrics.roc().collect().toSeq === rocCurve) + assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve)) + assert(metrics.pr().collect().toSeq === prCurve) + assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve)) + assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1)) + assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2)) + assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision)) + assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala new file mode 100644 index 0000000000000..3f3b10dfff35e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.rdd.RDDFunctions._ + +class RDDFunctionsSuite extends FunSuite with LocalSparkContext { + + test("sliding") { + val data = 0 until 6 + for (numPartitions <- 1 to 8) { + val rdd = sc.parallelize(data, numPartitions) + for (windowSize <- 1 to 6) { + val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList + val expected = data.sliding(windowSize).map(_.toList).toList + assert(sliding === expected) + } + assert(rdd.sliding(7).collect().isEmpty, + "Should return an empty RDD if the window size is greater than the number of items.") + } + } + + test("sliding with empty partitions") { + val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) + val rdd = sc.parallelize(data, data.length).flatMap(s => s) + assert(rdd.partitions.size === data.length) + val sliding = rdd.sliding(3) + val expected = data.flatMap(x => x).sliding(3).toList + assert(sliding.collect().toList === expected) + } +}