From 18cf07215c1e781e6b96c3986e62ec9e3e9fa788 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 2 Apr 2014 18:23:57 +0800 Subject: [PATCH] change def to lazy val so that each summary statistic is computed only once --- .../spark/mllib/rdd/VectorRDDFunctions.scala | 43 +++++++++++-------- .../mllib/rdd/VectorRDDFunctionsSuite.scala | 22 ++++++---- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala index 736fc363f2e5d..3ddc507a2e601 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.rdd import breeze.linalg.{axpy, Vector => BV} @@ -26,12 +27,12 @@ import org.apache.spark.rdd.RDD * elements count. 
*/ trait VectorRDDStatisticalSummary { - def mean(): Vector - def variance(): Vector - def totalCount(): Long - def numNonZeros(): Vector - def max(): Vector - def min(): Vector + def mean: Vector + def variance: Vector + def totalCount: Long + def numNonZeros: Vector + def max: Vector + def min: Vector } private class Aggregator( @@ -42,22 +43,24 @@ private class Aggregator( val currMax: BV[Double], val currMin: BV[Double]) extends VectorRDDStatisticalSummary with Serializable { - override def mean(): Vector = { - Vectors.fromBreeze(currMean :* nnz :/ totalCnt) - } + override lazy val mean = Vectors.fromBreeze(currMean :* nnz :/ totalCnt) - override def variance(): Vector = { + override lazy val variance = { val deltaMean = currMean - val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt) - realM2n :/= totalCnt - Vectors.fromBreeze(realM2n) + var i = 0 + while(i < currM2n.size) { + currM2n(i) -= deltaMean(i) * deltaMean(i) * nnz(i) * (nnz(i)-totalCnt) / totalCnt + currM2n(i) /= totalCnt + i += 1 + } + Vectors.fromBreeze(currM2n) } - override def totalCount(): Long = totalCnt.toLong + override lazy val totalCount: Long = totalCnt.toLong - override def numNonZeros(): Vector = Vectors.fromBreeze(nnz) + override lazy val numNonZeros: Vector = Vectors.fromBreeze(nnz) - override def max(): Vector = { + override lazy val max: Vector = { nnz.activeIterator.foreach { case (id, count) => if ((count == 0.0) || ((count < totalCnt) && (currMax(id) < 0.0))) currMax(id) = 0.0 @@ -65,7 +68,7 @@ private class Aggregator( Vectors.fromBreeze(currMax) } - override def min(): Vector = { + override lazy val min: Vector = { nnz.activeIterator.foreach { case (id, count) => if ((count == 0.0) || ((count < totalCnt) && (currMin(id) > 0.0))) currMin(id) = 0.0 @@ -78,6 +81,7 @@ private class Aggregator( */ def add(currData: BV[Double]): this.type = { currData.activeIterator.foreach { + // this case is used for filtering the zero elements if the vector is 
a dense one. case (id, 0.0) => case (id, value) => if (currMax(id) < value) currMax(id) = value @@ -106,7 +110,8 @@ private class Aggregator( other.currMean.activeIterator.foreach { case (id, 0.0) => case (id, value) => - currMean(id) = (currMean(id) * nnz(id) + other.currMean(id) * other.nnz(id)) / (nnz(id) + other.nnz(id)) + currMean(id) = + (currMean(id) * nnz(id) + other.currMean(id) * other.nnz(id)) / (nnz(id) + other.nnz(id)) } other.currM2n.activeIterator.foreach { @@ -157,4 +162,4 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable { (aggregator1, aggregator2) => aggregator1.merge(aggregator2) ) } -} \ No newline at end of file +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala index ec76c2279697a..5eb9d8e2c3da8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.rdd import scala.collection.mutable.ArrayBuffer @@ -21,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.rdd.VectorRDDFunctionsSuite._ import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.util.MLUtils._ @@ -29,7 +31,6 @@ import org.apache.spark.mllib.util.MLUtils._ * between dense and sparse vector are tested. 
*/ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext { - import VectorRDDFunctionsSuite._ val localData = Array( Vectors.dense(1.0, 2.0, 3.0), @@ -47,16 +48,21 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext { val (summary, denseTime) = time(data.summarizeStatistics()) - assert(equivVector(summary.mean(), Vectors.dense(4.0, 5.0, 6.0)), + assert(equivVector(summary.mean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.") - assert(equivVector(summary.variance(), Vectors.dense(6.0, 6.0, 6.0)), + + assert(equivVector(summary.variance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.") - assert(summary.totalCount() === 3, "Column cnt do not match.") - assert(equivVector(summary.numNonZeros(), Vectors.dense(3.0, 3.0, 3.0)), + + assert(summary.totalCount === 3, "Column cnt do not match.") + + assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.") - assert(equivVector(summary.max(), Vectors.dense(7.0, 8.0, 9.0)), + + assert(equivVector(summary.max, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.") - assert(equivVector(summary.min(), Vectors.dense(1.0, 2.0, 3.0)), + + assert(equivVector(summary.min, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.") val dataForSparse = sc.parallelize(sparseData.toSeq, 2) @@ -82,4 +88,4 @@ object VectorRDDFunctionsSuite { val denominator = math.max(lhs, rhs) math.abs(lhs - rhs) / denominator < 0.3 } -} \ No newline at end of file +}