From cc658100c161fee84fff48874715fd542c518db4 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Sun, 30 Mar 2014 15:13:28 +0800 Subject: [PATCH] add parallel mean and variance --- .../spark/mllib/rdd/VectorRDDFunctions.scala | 21 +++++++++++++++++++ .../mllib/rdd/VectorRDDFunctionsSuite.scala | 8 +++++++ 2 files changed, 29 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala index 1f53a60bc3171..1e941b2429914 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala @@ -21,6 +21,7 @@ import breeze.linalg.{Vector => BV, DenseVector => BDV} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils._ import org.apache.spark.rdd.RDD +import breeze.numerics._ /** * Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an @@ -161,4 +162,24 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable { } } } + + def parallelMeanAndVar(size: Int): (Vector, Vector) = { + val statistics = self.map(_.toBreeze).aggregate((BV.zeros[Double](size), BV.zeros[Double](size), 0.0))( + seqOp = (c, v) => (c, v) match { + case ((prevMean, prevM2n, cnt), currData) => + val currMean = ((prevMean :* cnt) + currData) :/ (cnt + 1.0) + (currMean, prevM2n + ((currData - prevMean) :* (currData - currMean)), cnt + 1.0) + }, + combOp = (lhs, rhs) => (lhs, rhs) match { + case ((lhsMean, lhsM2n, lhsCnt), (rhsMean, rhsM2n, rhsCnt)) => + val totalCnt = lhsCnt + rhsCnt + val totalMean = (lhsMean :* lhsCnt) + (rhsMean :* rhsCnt) :/ totalCnt + val deltaMean = rhsMean - lhsMean + val totalM2n = lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt) + (totalMean, totalM2n, totalCnt) + } + ) + + (Vectors.fromBreeze(statistics._1), Vectors.fromBreeze(statistics._2 :/ statistics._3)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala index f4ff560148ede..1fab692a12533 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala @@ -38,6 +38,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext { val colMeans = Array(4.0, 5.0, 6.0) val colNorm2 = Array(math.sqrt(66.0), math.sqrt(93.0), math.sqrt(126.0)) val colSDs = Array(math.sqrt(6.0), math.sqrt(6.0), math.sqrt(6.0)) + val colVar = Array(6.0, 6.0, 6.0) val maxVec = Array(7.0, 8.0, 9.0) val minVec = Array(1.0, 2.0, 3.0) @@ -128,6 +129,13 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext { assert(equivVector(lhs, rhs), "Column shrink error.") } } + + test("meanAndVar") { + val data = sc.parallelize(localData, 2) + val (mean, sd) = data.parallelMeanAndVar(3) + assert(equivVector(mean, Vectors.dense(colMeans)), "Column means do not match.") + assert(equivVector(sd, Vectors.dense(colVar)), "Column SD do not match.") + } } object VectorRDDFunctionsSuite {