Skip to content

Commit

Permalink
change def to lazy val to make sure that the computations in function…
Browse files Browse the repository at this point in the history
… be evaluated only once
  • Loading branch information
yinxusen committed Apr 10, 2014
1 parent f7a3ca2 commit 18cf072
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.rdd

import breeze.linalg.{axpy, Vector => BV}
Expand All @@ -26,12 +27,12 @@ import org.apache.spark.rdd.RDD
* elements count.
*/
trait VectorRDDStatisticalSummary {
def mean(): Vector
def variance(): Vector
def totalCount(): Long
def numNonZeros(): Vector
def max(): Vector
def min(): Vector
def mean: Vector
def variance: Vector
def totalCount: Long
def numNonZeros: Vector
def max: Vector
def min: Vector
}

private class Aggregator(
Expand All @@ -42,30 +43,32 @@ private class Aggregator(
val currMax: BV[Double],
val currMin: BV[Double]) extends VectorRDDStatisticalSummary with Serializable {

override def mean(): Vector = {
Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
}
override lazy val mean = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)

override def variance(): Vector = {
override lazy val variance = {
val deltaMean = currMean
val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
realM2n :/= totalCnt
Vectors.fromBreeze(realM2n)
var i = 0
while(i < currM2n.size) {
currM2n(i) -= deltaMean(i) * deltaMean(i) * nnz(i) * (nnz(i)-totalCnt) / totalCnt
currM2n(i) /= totalCnt
i += 1
}
Vectors.fromBreeze(currM2n)
}

override def totalCount(): Long = totalCnt.toLong
override lazy val totalCount: Long = totalCnt.toLong

override def numNonZeros(): Vector = Vectors.fromBreeze(nnz)
override lazy val numNonZeros: Vector = Vectors.fromBreeze(nnz)

override def max(): Vector = {
override lazy val max: Vector = {
nnz.activeIterator.foreach {
case (id, count) =>
if ((count == 0.0) || ((count < totalCnt) && (currMax(id) < 0.0))) currMax(id) = 0.0
}
Vectors.fromBreeze(currMax)
}

override def min(): Vector = {
override lazy val min: Vector = {
nnz.activeIterator.foreach {
case (id, count) =>
if ((count == 0.0) || ((count < totalCnt) && (currMin(id) > 0.0))) currMin(id) = 0.0
Expand All @@ -78,6 +81,7 @@ private class Aggregator(
*/
def add(currData: BV[Double]): this.type = {
currData.activeIterator.foreach {
// this case is used for filtering the zero elements if the vector is a dense one.
case (id, 0.0) =>
case (id, value) =>
if (currMax(id) < value) currMax(id) = value
Expand Down Expand Up @@ -106,7 +110,8 @@ private class Aggregator(
other.currMean.activeIterator.foreach {
case (id, 0.0) =>
case (id, value) =>
currMean(id) = (currMean(id) * nnz(id) + other.currMean(id) * other.nnz(id)) / (nnz(id) + other.nnz(id))
currMean(id) =
(currMean(id) * nnz(id) + other.currMean(id) * other.nnz(id)) / (nnz(id) + other.nnz(id))
}

other.currM2n.activeIterator.foreach {
Expand Down Expand Up @@ -157,4 +162,4 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.rdd

import scala.collection.mutable.ArrayBuffer

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.rdd.VectorRDDFunctionsSuite._
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.util.MLUtils._

Expand All @@ -29,7 +31,6 @@ import org.apache.spark.mllib.util.MLUtils._
* between dense and sparse vector are tested.
*/
class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
import VectorRDDFunctionsSuite._

val localData = Array(
Vectors.dense(1.0, 2.0, 3.0),
Expand All @@ -47,16 +48,21 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
val (summary, denseTime) =
time(data.summarizeStatistics())

assert(equivVector(summary.mean(), Vectors.dense(4.0, 5.0, 6.0)),
assert(equivVector(summary.mean, Vectors.dense(4.0, 5.0, 6.0)),
"Column mean do not match.")
assert(equivVector(summary.variance(), Vectors.dense(6.0, 6.0, 6.0)),

assert(equivVector(summary.variance, Vectors.dense(6.0, 6.0, 6.0)),
"Column variance do not match.")
assert(summary.totalCount() === 3, "Column cnt do not match.")
assert(equivVector(summary.numNonZeros(), Vectors.dense(3.0, 3.0, 3.0)),

assert(summary.totalCount === 3, "Column cnt do not match.")

assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 3.0)),
"Column nnz do not match.")
assert(equivVector(summary.max(), Vectors.dense(7.0, 8.0, 9.0)),

assert(equivVector(summary.max, Vectors.dense(7.0, 8.0, 9.0)),
"Column max do not match.")
assert(equivVector(summary.min(), Vectors.dense(1.0, 2.0, 3.0)),

assert(equivVector(summary.min, Vectors.dense(1.0, 2.0, 3.0)),
"Column min do not match.")

val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
Expand All @@ -82,4 +88,4 @@ object VectorRDDFunctionsSuite {
val denominator = math.max(lhs, rhs)
math.abs(lhs - rhs) / denominator < 0.3
}
}
}

0 comments on commit 18cf072

Please sign in to comment.