From 967d041fa806a87a8bdf3bd74fac84a7a6fe7495 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Wed, 2 Apr 2014 17:47:12 +0800
Subject: [PATCH] Full revision with Aggregator class

---
 .../spark/mllib/rdd/VectorRDDFunctions.scala       | 155 ++++--------------
 .../mllib/rdd/VectorRDDFunctionsSuite.scala        |  14 +-
 2 files changed, 42 insertions(+), 127 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala
index b5518e4a91a4f..23623e2a28309 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala
@@ -40,14 +40,12 @@ private class Aggregator(
     var totalCnt: Double,
     val nnz: BV[Double],
     val currMax: BV[Double],
-    val currMin: BV[Double]) extends VectorRDDStatisticalSummary {
-  nnz.activeIterator.foreach {
-    case (id, 0.0) =>
-      currMax(id) = 0.0
-      currMin(id) = 0.0
-    case _ =>
+    val currMin: BV[Double]) extends VectorRDDStatisticalSummary with Serializable {
+
+  override def mean(): Vector = {
+    Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
   }
-  override def mean(): Vector = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
+
   override def variance(): Vector = {
     val deltaMean = currMean
     val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
@@ -58,8 +56,23 @@ private class Aggregator(
   override def totalCount(): Long = totalCnt.toLong
 
   override def numNonZeros(): Vector = Vectors.fromBreeze(nnz)
-  override def max(): Vector = Vectors.fromBreeze(currMax)
-  override def min(): Vector = Vectors.fromBreeze(currMin)
+
+  override def max(): Vector = {
+    nnz.activeIterator.foreach {
+      case (id, 0.0) => currMax(id) = 0.0
+      case _ =>
+    }
+    Vectors.fromBreeze(currMax)
+  }
+
+  override def min(): Vector = {
+    nnz.activeIterator.foreach {
+      case (id, 0.0) => currMin(id) = 0.0
+      case _ =>
+    }
+    Vectors.fromBreeze(currMin)
+  }
+
   /**
    * Aggregate function used for aggregating elements within a single worker.
    */
@@ -75,15 +88,19 @@ private class Aggregator(
         currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)
 
         nnz(id) += 1.0
-        totalCnt += 1.0
     }
+
+    totalCnt += 1.0
     this
   }
+
   /**
    * Combine function used for combining intermediate results from different workers.
    */
-  def merge(other: this.type): this.type = {
+  def merge(other: Aggregator): this.type = {
+    totalCnt += other.totalCnt
+
     val deltaMean = currMean - other.currMean
 
     other.currMean.activeIterator.foreach {
@@ -114,14 +131,6 @@
   }
 }
 
-case class VectorRDDStatisticalAggregator(
-    mean: BV[Double],
-    statCnt: BV[Double],
-    totalCnt: Double,
-    nnz: BV[Double],
-    currMax: BV[Double],
-    currMin: BV[Double])
-
 /**
  * Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
  * implicit conversion. Import `org.apache.spark.MLContext._` at the top of your program to use
@@ -129,83 +138,13 @@ case class VectorRDDStatisticalAggregator(
  */
 class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
 
-  /**
-   * Aggregate function used for aggregating elements in a worker together.
-   */
-  private def seqOp(
-      aggregator: VectorRDDStatisticalAggregator,
-      currData: BV[Double]): VectorRDDStatisticalAggregator = {
-    aggregator match {
-      case VectorRDDStatisticalAggregator(prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) =>
-        currData.activeIterator.foreach {
-          case (id, 0.0) =>
-          case (id, value) =>
-            if (maxVec(id) < value) maxVec(id) = value
-            if (minVec(id) > value) minVec(id) = value
-
-            val tmpPrevMean = prevMean(id)
-            prevMean(id) = (prevMean(id) * cnt + value) / (cnt + 1.0)
-            prevM2n(id) += (value - prevMean(id)) * (value - tmpPrevMean)
-
-            nnzVec(id) += 1.0
-        }
-
-        VectorRDDStatisticalAggregator(
-          prevMean,
-          prevM2n,
-          cnt + 1.0,
-          nnzVec,
-          maxVec,
-          minVec)
-    }
-  }
-
-  /**
-   * Combine function used for combining intermediate results together from every worker.
-   */
-  private def combOp(
-      statistics1: VectorRDDStatisticalAggregator,
-      statistics2: VectorRDDStatisticalAggregator): VectorRDDStatisticalAggregator = {
-    (statistics1, statistics2) match {
-      case (VectorRDDStatisticalAggregator(mean1, m2n1, cnt1, nnz1, max1, min1),
-        VectorRDDStatisticalAggregator(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
-        val totalCnt = cnt1 + cnt2
-        val deltaMean = mean2 - mean1
-
-        mean2.activeIterator.foreach {
-          case (id, 0.0) =>
-          case (id, value) =>
-            mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id))
-        }
-
-        m2n2.activeIterator.foreach {
-          case (id, 0.0) =>
-          case (id, value) =>
-            m2n1(id) +=
-              value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id)+nnz2(id))
-        }
-
-        max2.activeIterator.foreach {
-          case (id, value) =>
-            if (max1(id) < value) max1(id) = value
-        }
-
-        min2.activeIterator.foreach {
-          case (id, value) =>
-            if (min1(id) > value) min1(id) = value
-        }
-
-        axpy(1.0, nnz2, nnz1)
-        VectorRDDStatisticalAggregator(mean1, m2n1, totalCnt, nnz1, max1, min1)
-    }
-  }
-
   /**
    * Compute full column-wise statistics for the RDD, with buffers sized from the first Vector.
    */
-  def summarizeStatistics(): VectorRDDStatisticalAggregator = {
+  def summarizeStatistics(): VectorRDDStatisticalSummary = {
     val size = self.take(1).head.size
-    val zeroValue = VectorRDDStatisticalAggregator(
+
+    val zeroValue = new Aggregator(
       BV.zeros[Double](size),
       BV.zeros[Double](size),
       0.0,
@@ -213,33 +152,9 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
       BV.fill(size)(Double.MinValue),
       BV.fill(size)(Double.MaxValue))
 
-    val VectorRDDStatisticalAggregator(currMean, currM2n, totalCnt, nnz, currMax, currMin) =
-      self.map(_.toBreeze).aggregate(zeroValue)(seqOp, combOp)
-
-    // solve real mean
-    val realMean = currMean :* nnz :/ totalCnt
-
-    // solve real m2n
-    val deltaMean = currMean
-    val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
-
-    // remove the initial value in max and min, i.e. the Double.MaxValue or Double.MinValue.
-    nnz.activeIterator.foreach {
-      case (id, 0.0) =>
-        currMax(id) = 0.0
-        currMin(id) = 0.0
-      case _ =>
-    }
-
-    // get variance
-    realM2n :/= totalCnt
-
-    VectorRDDStatisticalAggregator(
-      realMean,
-      realM2n,
-      totalCnt,
-      nnz,
-      currMax,
-      currMin)
+    self.map(_.toBreeze).aggregate[Aggregator](zeroValue)(
+      (aggregator, data) => aggregator.add(data),
+      (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+    )
   }
 }
\ No newline at end of file
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala
index 49cde4b4e11d4..ec76c2279697a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala
@@ -44,19 +44,19 @@
 
   test("full-statistics") {
     val data = sc.parallelize(localData, 2)
-    val (VectorRDDStatisticalAggregator(mean, variance, cnt, nnz, max, min), denseTime) =
+    val (summary, denseTime) =
       time(data.summarizeStatistics())
 
-    assert(equivVector(Vectors.fromBreeze(mean), Vectors.dense(4.0, 5.0, 6.0)),
+    assert(equivVector(summary.mean(), Vectors.dense(4.0, 5.0, 6.0)),
       "Column mean does not match.")
-    assert(equivVector(Vectors.fromBreeze(variance), Vectors.dense(6.0, 6.0, 6.0)),
+    assert(equivVector(summary.variance(), Vectors.dense(6.0, 6.0, 6.0)),
       "Column variance does not match.")
-    assert(cnt === 3.0, "Column cnt do not match.")
-    assert(equivVector(Vectors.fromBreeze(nnz), Vectors.dense(3.0, 3.0, 3.0)),
+    assert(summary.totalCount() === 3, "Column cnt does not match.")
+    assert(equivVector(summary.numNonZeros(), Vectors.dense(3.0, 3.0, 3.0)),
       "Column nnz does not match.")
-    assert(equivVector(Vectors.fromBreeze(max), Vectors.dense(7.0, 8.0, 9.0)),
+    assert(equivVector(summary.max(), Vectors.dense(7.0, 8.0, 9.0)),
      "Column max does not match.")
-    assert(equivVector(Vectors.fromBreeze(min), Vectors.dense(1.0, 2.0, 3.0)),
+    assert(equivVector(summary.min(), Vectors.dense(1.0, 2.0, 3.0)),
      "Column min does not match.")
 
     val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
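
A minimal usage sketch of the new summary API, for trying the branch out locally. It assumes the
implicit conversion to VectorRDDFunctions is brought into scope by
`import org.apache.spark.MLContext._`, as the class doc above suggests (the exact import path may
differ); the object name SummaryExample and the local master setting are illustrative only. The
expected outputs mirror the dense data in the "full-statistics" test.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.MLContext._ // assumed per the class doc; adjust if the conversion lives elsewhere

object SummaryExample {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("SummaryExample"))

    // Same dense data as the "full-statistics" test: rows (1,2,3), (4,5,6), (7,8,9).
    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 2.0, 3.0),
      Vectors.dense(4.0, 5.0, 6.0),
      Vectors.dense(7.0, 8.0, 9.0)), 2)

    // One aggregate() pass over the RDD computes all column-wise statistics at once.
    val summary = data.summarizeStatistics()

    println(summary.mean())        // expected: [4.0,5.0,6.0]
    println(summary.variance())    // expected: [6.0,6.0,6.0]
    println(summary.totalCount())  // expected: 3
    println(summary.numNonZeros()) // expected: [3.0,3.0,3.0]
    println(summary.max())         // expected: [7.0,8.0,9.0]
    println(summary.min())         // expected: [1.0,2.0,3.0]

    sc.stop()
  }
}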