From 548e9de33291436a459d6ae9f5dc0163bbfbf867 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Thu, 3 Apr 2014 12:43:16 +0800
Subject: [PATCH] minor revision

---
 .../spark/mllib/rdd/VectorRDDFunctions.scala  | 63 +++++++++----------
 .../mllib/rdd/VectorRDDFunctionsSuite.scala   |  8 +--
 2 files changed, 35 insertions(+), 36 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala
index fcb59c571e4f8..fcb5a5f18b127 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.mllib.rdd
 
-import breeze.linalg.{axpy, Vector => BV}
+import breeze.linalg.{Vector => BV, DenseVector => BDV}
 
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
 import org.apache.spark.rdd.RDD
@@ -29,7 +29,7 @@ import org.apache.spark.rdd.RDD
 trait VectorRDDStatisticalSummary {
   def mean: Vector
   def variance: Vector
-  def totalCount: Long
+  def count: Long
   def numNonZeros: Vector
   def max: Vector
   def min: Vector
@@ -37,44 +37,43 @@ trait VectorRDDStatisticalSummary {
 
 /**
  * Aggregates [[org.apache.spark.mllib.rdd.VectorRDDStatisticalSummary VectorRDDStatisticalSummary]]
- * together with add() and merge() function.
+ * together with add() and merge() functions. An online variance algorithm is used in add(), and a
+ * parallel variance algorithm in merge(); see
+ * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] for both.
+ * Zero elements are ignored in add() and merge(), which reduces the O(n) algorithm to O(nnz).
+ * The real variance is computed afterwards, once the other statistics are available, by one more
+ * parallel combination pass.
  */
-private class Aggregator(
-    val currMean: BV[Double],
-    val currM2n: BV[Double],
+private class VectorRDDStatisticsAggregator(
+    val currMean: BDV[Double],
+    val currM2n: BDV[Double],
     var totalCnt: Double,
-    val nnz: BV[Double],
-    val currMax: BV[Double],
-    val currMin: BV[Double]) extends VectorRDDStatisticalSummary with Serializable {
+    val nnz: BDV[Double],
+    val currMax: BDV[Double],
+    val currMin: BDV[Double]) extends VectorRDDStatisticalSummary with Serializable {
 
   // lazy val is used so that each statistic is computed only once. Same below.
   override lazy val mean = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
 
-  // Online variance solution used in add() function, while parallel variance solution used in
-  // merge() function. Reference here:
-  // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-  // Solution here ignoring the zero elements when calling add() and merge(), for decreasing the
-  // O(n) algorithm to O(nnz). Real variance is computed here after we get other statistics, simply
-  // by another parallel combination process.
   override lazy val variance = {
     val deltaMean = currMean
     var i = 0
-    while(i < currM2n.size) {
-      currM2n(i) += deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt-nnz(i)) / totalCnt
+    while (i < currM2n.size) {
+      currM2n(i) += deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
       currM2n(i) /= totalCnt
       i += 1
     }
     Vectors.fromBreeze(currM2n)
   }
 
-  override lazy val totalCount: Long = totalCnt.toLong
+  override lazy val count: Long = totalCnt.toLong
 
   override lazy val numNonZeros: Vector = Vectors.fromBreeze(nnz)
 
   override lazy val max: Vector = {
     nnz.iterator.foreach {
       case (id, count) =>
-        if ((count == 0.0) || ((count < totalCnt) && (currMax(id) < 0.0))) currMax(id) = 0.0
+        if ((count < totalCnt) && (currMax(id) < 0.0)) currMax(id) = 0.0
     }
     Vectors.fromBreeze(currMax)
   }
@@ -82,7 +81,7 @@ private class Aggregator(
   override lazy val min: Vector = {
     nnz.iterator.foreach {
       case (id, count) =>
-        if ((count == 0.0) || ((count < totalCnt) && (currMin(id) > 0.0))) currMin(id) = 0.0
+        if ((count < totalCnt) && (currMin(id) > 0.0)) currMin(id) = 0.0
     }
     Vectors.fromBreeze(currMin)
   }
@@ -92,7 +91,7 @@ private class Aggregator(
    */
   def add(currData: BV[Double]): this.type = {
     currData.activeIterator.foreach {
-      // this case is used for filtering the zero elements if the vector.
+      // this case filters out the zero elements of the vector.
       case (id, 0.0) =>
       case (id, value) =>
         if (currMax(id) < value) currMax(id) = value
@@ -112,7 +111,7 @@ private class Aggregator(
   /**
    * Combine function used for combining intermediate results together from every worker.
    */
-  def merge(other: Aggregator): this.type = {
+  def merge(other: VectorRDDStatisticsAggregator): this.type = {
 
     totalCnt += other.totalCnt
 
@@ -145,7 +144,7 @@ private class Aggregator(
         if (currMin(id) > value) currMin(id) = value
     }
 
-    axpy(1.0, other.nnz, nnz)
+    nnz += other.nnz
     this
   }
 }
@@ -160,18 +159,18 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
 
   /**
    * Compute full column-wise statistics for the RDD.
   */
-  def summarizeStatistics(): VectorRDDStatisticalSummary = {
-    val size = self.take(1).head.size
+  def computeSummaryStatistics(): VectorRDDStatisticalSummary = {
+    val size = self.first().size
 
-    val zeroValue = new Aggregator(
-      BV.zeros[Double](size),
-      BV.zeros[Double](size),
+    val zeroValue = new VectorRDDStatisticsAggregator(
+      BDV.zeros[Double](size),
+      BDV.zeros[Double](size),
       0.0,
-      BV.zeros[Double](size),
-      BV.fill(size)(Double.MinValue),
-      BV.fill(size)(Double.MaxValue))
+      BDV.zeros[Double](size),
+      BDV.fill(size)(Double.MinValue),
+      BDV.fill(size)(Double.MaxValue))
 
-    self.map(_.toBreeze).aggregate[Aggregator](zeroValue)(
+    self.map(_.toBreeze).aggregate[VectorRDDStatisticsAggregator](zeroValue)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
     )
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala
index b621bf79b6e8b..87cfd6c8c436c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala
@@ -45,7 +45,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
 
   test("dense statistical summary") {
     val data = sc.parallelize(localData, 2)
-    val summary = data.summarizeStatistics()
+    val summary = data.computeSummaryStatistics()
 
     assert(equivVector(summary.mean, Vectors.dense(4.0, 5.0, 6.0)),
       "Dense column mean do not match.")
@@ -53,7 +53,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
     assert(equivVector(summary.variance, Vectors.dense(6.0, 6.0, 6.0)),
       "Dense column variance do not match.")
 
-    assert(summary.totalCount === 3, "Dense column cnt do not match.")
+    assert(summary.count === 3, "Dense column cnt do not match.")
 
     assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 3.0)),
       "Dense column nnz do not match.")
@@ -67,7 +67,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
 
   test("sparse statistical summary") {
     val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
-    val summary = dataForSparse.summarizeStatistics()
+    val summary = dataForSparse.computeSummaryStatistics()
 
     assert(equivVector(summary.mean, Vectors.dense(0.06, 0.05, 0.0)),
       "Sparse column mean do not match.")
@@ -75,7 +75,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
     assert(equivVector(summary.variance, Vectors.dense(0.2564, 0.2475, 0.0)),
       "Sparse column variance do not match.")
 
-    assert(summary.totalCount === 100, "Sparse column cnt do not match.")
+    assert(summary.count === 100, "Sparse column cnt do not match.")
 
     assert(equivVector(summary.numNonZeros, Vectors.dense(2.0, 1.0, 0.0)),
       "Sparse column nnz do not match.")
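
Per column, merge() applies the standard parallel-variance combination from the Wikipedia
article cited in the new scaladoc, and add() applies the corresponding online (Welford)
update. Below is a minimal, self-contained Scala sketch of both updates for a single
column; the names Partial, add, merge, and VarianceMergeSketch are illustrative only and
do not appear in this patch.

// Sketch of the single-column online/parallel variance updates, following the
// formulas from the Wikipedia reference above; this is not code from this PR.
object VarianceMergeSketch {
  // Per-partition partial statistics: count, running mean, and
  // M2 = sum of squared deviations from the running mean.
  case class Partial(n: Long, mean: Double, m2: Double)

  // Online (Welford) update, the role played by the aggregator's add().
  def add(p: Partial, x: Double): Partial = {
    val n = p.n + 1
    val delta = x - p.mean
    val mean = p.mean + delta / n
    Partial(n, mean, p.m2 + delta * (x - mean))
  }

  // Parallel combination, the role played by the aggregator's merge().
  def merge(a: Partial, b: Partial): Partial = {
    val n = a.n + b.n
    val delta = b.mean - a.mean
    val mean = a.mean + delta * b.n / n
    val m2 = a.m2 + b.m2 + delta * delta * a.n * b.n / n
    Partial(n, mean, m2)
  }

  def main(args: Array[String]): Unit = {
    // Two "partitions" of one column, accumulated independently then merged.
    val (left, right) = Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0).splitAt(3)
    val zero = Partial(0L, 0.0, 0.0)
    val merged = merge(left.foldLeft(zero)(add), right.foldLeft(zero)(add))
    // Population variance of 1..6 is 35/12 = 2.9166...
    println(s"mean = ${merged.mean}, variance = ${merged.m2 / merged.n}")
  }
}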
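
The variance lazy val corrects statistics that were accumulated over nonzero entries only:
with nnz(i) nonzero values out of totalCnt, adding
deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt to currM2n(i) and then
dividing by totalCnt yields the population variance of the full column, implicit zeros
included. A small standalone check of that identity against a naive dense computation; the
numbers and names here are made up for illustration.

// Verifies the sparse variance correction used in the `variance` lazy val on a
// toy column of 10 entries, 3 of them nonzero; illustrative code only.
object SparseVarianceCheck {
  def main(args: Array[String]): Unit = {
    val totalCnt = 10.0
    val nonZeros = Seq(2.0, 4.0, 6.0) // nnz = 3; the remaining 7 entries are 0.0
    val nnz = nonZeros.size.toDouble

    // Mean and M2 over the nonzero entries alone (what add() accumulates).
    val meanNnz = nonZeros.sum / nnz
    val m2Nnz = nonZeros.map(x => (x - meanNnz) * (x - meanNnz)).sum

    // The correction from the patch: account for the (totalCnt - nnz) zeros.
    val corrected = (m2Nnz + meanNnz * meanNnz * nnz * (totalCnt - nnz) / totalCnt) / totalCnt

    // Naive dense computation over all ten entries, for comparison.
    val dense = nonZeros ++ Seq.fill((totalCnt - nnz).toInt)(0.0)
    val denseMean = dense.sum / totalCnt
    val naive = dense.map(x => (x - denseMean) * (x - denseMean)).sum / totalCnt

    println(s"corrected = $corrected, naive = $naive") // both 4.16, up to floating point
  }
}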
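
With the rename, callers invoke computeSummaryStatistics() exactly as the updated test
suite does. A usage sketch follows, assuming the implicit conversion from RDD[Vector] to
VectorRDDFunctions (defined elsewhere in this PR, not visible in this diff) is in scope.

// Illustrative usage only; assumes a Spark build of this era and the
// RDD[Vector] -> VectorRDDFunctions implicit conversion from the rest of the PR.
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors

object SummaryUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "summary-usage")
    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 0.0, 3.0),
      Vectors.dense(4.0, 5.0, 0.0),
      Vectors.dense(7.0, 8.0, 9.0)), 2)

    val summary = data.computeSummaryStatistics()
    println(summary.mean)        // column means
    println(summary.count)       // 3, via the renamed count field
    println(summary.numNonZeros) // [3.0, 2.0, 2.0]
    sc.stop()
  }
}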