Skip to content

Commit

Permalink
minor revision
Browse files Browse the repository at this point in the history
  • Loading branch information
yinxusen committed Apr 10, 2014
1 parent 86522c4 commit 548e9de
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.mllib.rdd

import breeze.linalg.{axpy, Vector => BV}
import breeze.linalg.{Vector => BV, DenseVector => BDV}

import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.rdd.RDD
Expand All @@ -29,60 +29,59 @@ import org.apache.spark.rdd.RDD
trait VectorRDDStatisticalSummary {
def mean: Vector
def variance: Vector
def totalCount: Long
def count: Long
def numNonZeros: Vector
def max: Vector
def min: Vector
}

/**
* Aggregates [[org.apache.spark.mllib.rdd.VectorRDDStatisticalSummary VectorRDDStatisticalSummary]]
* together with add() and merge() function.
* together with add() and merge() functions. An online variance algorithm is used in add(), while
* a parallel variance algorithm is used in merge(). Reference:
* [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]. The solution
* ignores the zero elements when calling add() and merge(), reducing the O(n) algorithm to
* O(nnz). The real variance is computed after the other statistics are obtained, simply by another
* parallel combination process.
*/
private class Aggregator(
val currMean: BV[Double],
val currM2n: BV[Double],
private class VectorRDDStatisticsAggregator(
val currMean: BDV[Double],
val currM2n: BDV[Double],
var totalCnt: Double,
val nnz: BV[Double],
val currMax: BV[Double],
val currMin: BV[Double]) extends VectorRDDStatisticalSummary with Serializable {
val nnz: BDV[Double],
val currMax: BDV[Double],
val currMin: BDV[Double]) extends VectorRDDStatisticalSummary with Serializable {

// lazy val is used so each statistic is computed only once. Same below.
override lazy val mean = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)

// Online variance solution used in add() function, while parallel variance solution used in
// merge() function. Reference here:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
// Solution here ignoring the zero elements when calling add() and merge(), for decreasing the
// O(n) algorithm to O(nnz). Real variance is computed here after we get other statistics, simply
// by another parallel combination process.
override lazy val variance = {
val deltaMean = currMean
var i = 0
while(i < currM2n.size) {
currM2n(i) += deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt-nnz(i)) / totalCnt
while (i < currM2n.size) {
currM2n(i) += deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
currM2n(i) /= totalCnt
i += 1
}
Vectors.fromBreeze(currM2n)
}

override lazy val totalCount: Long = totalCnt.toLong
override lazy val count: Long = totalCnt.toLong

override lazy val numNonZeros: Vector = Vectors.fromBreeze(nnz)

override lazy val max: Vector = {
nnz.iterator.foreach {
case (id, count) =>
if ((count == 0.0) || ((count < totalCnt) && (currMax(id) < 0.0))) currMax(id) = 0.0
if ((count < totalCnt) && (currMax(id) < 0.0)) currMax(id) = 0.0
}
Vectors.fromBreeze(currMax)
}

override lazy val min: Vector = {
nnz.iterator.foreach {
case (id, count) =>
if ((count == 0.0) || ((count < totalCnt) && (currMin(id) > 0.0))) currMin(id) = 0.0
if ((count < totalCnt) && (currMin(id) > 0.0)) currMin(id) = 0.0
}
Vectors.fromBreeze(currMin)
}
Expand All @@ -92,7 +91,7 @@ private class Aggregator(
*/
def add(currData: BV[Double]): this.type = {
currData.activeIterator.foreach {
// this case is used for filtering the zero elements if the vector is a dense one.
// this case is used for filtering out the zero elements of the vector.
case (id, 0.0) =>
case (id, value) =>
if (currMax(id) < value) currMax(id) = value
Expand All @@ -112,7 +111,7 @@ private class Aggregator(
/**
* Combine function used for combining intermediate results together from every worker.
*/
def merge(other: Aggregator): this.type = {
def merge(other: VectorRDDStatisticsAggregator): this.type = {

totalCnt += other.totalCnt

Expand Down Expand Up @@ -145,7 +144,7 @@ private class Aggregator(
if (currMin(id) > value) currMin(id) = value
}

axpy(1.0, other.nnz, nnz)
nnz += other.nnz
this
}
}
Expand All @@ -160,18 +159,18 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
/**
* Compute full column-wise statistics for the RDD with the size of Vector as input parameter.
*/
def summarizeStatistics(): VectorRDDStatisticalSummary = {
val size = self.take(1).head.size
def computeSummaryStatistics(): VectorRDDStatisticalSummary = {
val size = self.first().size

val zeroValue = new Aggregator(
BV.zeros[Double](size),
BV.zeros[Double](size),
val zeroValue = new VectorRDDStatisticsAggregator(
BDV.zeros[Double](size),
BDV.zeros[Double](size),
0.0,
BV.zeros[Double](size),
BV.fill(size)(Double.MinValue),
BV.fill(size)(Double.MaxValue))
BDV.zeros[Double](size),
BDV.fill(size)(Double.MinValue),
BDV.fill(size)(Double.MaxValue))

self.map(_.toBreeze).aggregate[Aggregator](zeroValue)(
self.map(_.toBreeze).aggregate[VectorRDDStatisticsAggregator](zeroValue)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {

test("dense statistical summary") {
val data = sc.parallelize(localData, 2)
val summary = data.summarizeStatistics()
val summary = data.computeSummaryStatistics()

assert(equivVector(summary.mean, Vectors.dense(4.0, 5.0, 6.0)),
"Dense column mean do not match.")

assert(equivVector(summary.variance, Vectors.dense(6.0, 6.0, 6.0)),
"Dense column variance do not match.")

assert(summary.totalCount === 3, "Dense column cnt do not match.")
assert(summary.count === 3, "Dense column cnt do not match.")

assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 3.0)),
"Dense column nnz do not match.")
Expand All @@ -67,15 +67,15 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {

test("sparse statistical summary") {
val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
val summary = dataForSparse.summarizeStatistics()
val summary = dataForSparse.computeSummaryStatistics()

assert(equivVector(summary.mean, Vectors.dense(0.06, 0.05, 0.0)),
"Sparse column mean do not match.")

assert(equivVector(summary.variance, Vectors.dense(0.2564, 0.2475, 0.0)),
"Sparse column variance do not match.")

assert(summary.totalCount === 100, "Sparse column cnt do not match.")
assert(summary.count === 100, "Sparse column cnt do not match.")

assert(equivVector(summary.numNonZeros, Vectors.dense(2.0, 1.0, 0.0)),
"Sparse column nnz do not match.")
Expand Down

0 comments on commit 548e9de

Please sign in to comment.