From 48ee053b86210d0d2ff03c13a0c4187d962c0a5d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 4 Apr 2014 09:52:23 +0800 Subject: [PATCH] fix minor error --- .../spark/mllib/rdd/VectorRDDFunctions.scala | 27 +++++++++++++++++-- .../mllib/rdd/VectorRDDFunctionsSuite.scala | 18 ++++++------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala index 6e6dd242d3853..0b677d9c4fdef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala @@ -27,11 +27,35 @@ import org.apache.spark.rdd.RDD * count. */ trait VectorRDDStatisticalSummary { + + /** + * Computes the mean of columns in RDD[Vector]. + */ def mean: Vector + + /** + * Computes the sample variance of columns in RDD[Vector]. + */ def variance: Vector + + /** + * Computes the number of vectors in RDD[Vector]. + */ def count: Long + + /** + * Computes the number of non-zero elements in each column of RDD[Vector]. + */ def numNonZeros: Vector + + /** + * Computes the maximum of each column in RDD[Vector]. + */ def max: Vector + + /** + * Computes the minimum of each column in RDD[Vector]. + */ def min: Vector } @@ -53,7 +77,6 @@ private class VectorRDDStatisticsAggregator( val currMin: BDV[Double]) extends VectorRDDStatisticalSummary with Serializable { - // lazy val is used for computing only once time. Same below. 
override def mean = { val realMean = BDV.zeros[Double](currMean.length) var i = 0 @@ -71,7 +94,7 @@ private class VectorRDDStatisticsAggregator( while (i < currM2n.size) { realVariance(i) = currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt - realVariance(i) /= totalCnt + realVariance(i) /= (totalCnt - 1.0) i += 1 } Vectors.fromBreeze(realVariance) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala index bf1b3693cfbf0..9bf92d54429a4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala @@ -34,8 +34,8 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext { val localData = Array( Vectors.dense(1.0, 2.0, 3.0), - Vectors.dense(4.0, 5.0, 6.0), - Vectors.dense(7.0, 8.0, 9.0) + Vectors.dense(4.0, 0.0, 6.0), + Vectors.dense(0.0, 8.0, 9.0) ) val sparseData = ArrayBuffer(Vectors.sparse(3, Seq((0, 1.0)))) @@ -47,21 +47,21 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext { val data = sc.parallelize(localData, 2) val summary = data.computeSummaryStatistics() - assert(equivVector(summary.mean, Vectors.dense(4.0, 5.0, 6.0)), + assert(equivVector(summary.mean, Vectors.dense(5.0 / 3.0, 10.0 / 3.0, 6.0)), "Dense column mean do not match.") - assert(equivVector(summary.variance, Vectors.dense(6.0, 6.0, 6.0)), + assert(equivVector(summary.variance, Vectors.dense(4.333333333333334, 17.333333333333336, 9.0)), "Dense column variance do not match.") assert(summary.count === 3, "Dense column cnt do not match.") - assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 3.0)), + assert(equivVector(summary.numNonZeros, Vectors.dense(2.0, 2.0, 3.0)), "Dense column nnz do not match.") - assert(equivVector(summary.max, Vectors.dense(7.0, 8.0, 9.0)), + 
assert(equivVector(summary.max, Vectors.dense(4.0, 8.0, 9.0)), "Dense column max do not match.") - assert(equivVector(summary.min, Vectors.dense(1.0, 2.0, 3.0)), + assert(equivVector(summary.min, Vectors.dense(0.0, 0.0, 3.0)), "Dense column min do not match.") } @@ -72,7 +72,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext { assert(equivVector(summary.mean, Vectors.dense(0.06, 0.05, 0.0)), "Sparse column mean do not match.") - assert(equivVector(summary.variance, Vectors.dense(0.2564, 0.2475, 0.0)), + assert(equivVector(summary.variance, Vectors.dense(0.258989898989899, 0.25, 0.0)), "Sparse column variance do not match.") assert(summary.count === 100, "Sparse column cnt do not match.") @@ -90,6 +90,6 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext { object VectorRDDFunctionsSuite { def equivVector(lhs: Vector, rhs: Vector): Boolean = { - (lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-9 + (lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-5 } }