Skip to content

Commit

Permalink
fix minor error
Browse files Browse the repository at this point in the history
  • Loading branch information
yinxusen committed Apr 10, 2014
1 parent e624f93 commit 48ee053
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,35 @@ import org.apache.spark.rdd.RDD
* count.
*/
trait VectorRDDStatisticalSummary {

/**
* Computes the mean of columns in RDD[Vector].
*/
def mean: Vector

/**
* Computes the sample variance of columns in RDD[Vector].
*/
def variance: Vector

/**
* Computes number of vectors in RDD[Vector].
*/
def count: Long

/**
* Computes the number of non-zero elements in each column of RDD[Vector].
*/
def numNonZeros: Vector

/**
* Computes the maximum of each column in RDD[Vector].
*/
def max: Vector

/**
* Computes the minimum of each column in RDD[Vector].
*/
def min: Vector
}

Expand All @@ -53,7 +77,6 @@ private class VectorRDDStatisticsAggregator(
val currMin: BDV[Double])
extends VectorRDDStatisticalSummary with Serializable {

// lazy val is used for computing only once time. Same below.
override def mean = {
val realMean = BDV.zeros[Double](currMean.length)
var i = 0
Expand All @@ -71,7 +94,7 @@ private class VectorRDDStatisticsAggregator(
while (i < currM2n.size) {
realVariance(i) =
currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
realVariance(i) /= totalCnt
realVariance(i) /= (totalCnt - 1.0)
i += 1
}
Vectors.fromBreeze(realVariance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {

val localData = Array(
Vectors.dense(1.0, 2.0, 3.0),
Vectors.dense(4.0, 5.0, 6.0),
Vectors.dense(7.0, 8.0, 9.0)
Vectors.dense(4.0, 0.0, 6.0),
Vectors.dense(0.0, 8.0, 9.0)
)

val sparseData = ArrayBuffer(Vectors.sparse(3, Seq((0, 1.0))))
Expand All @@ -47,21 +47,21 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
val data = sc.parallelize(localData, 2)
val summary = data.computeSummaryStatistics()

assert(equivVector(summary.mean, Vectors.dense(4.0, 5.0, 6.0)),
assert(equivVector(summary.mean, Vectors.dense(5.0 / 3.0, 10.0 / 3.0, 6.0)),
"Dense column mean do not match.")

assert(equivVector(summary.variance, Vectors.dense(6.0, 6.0, 6.0)),
assert(equivVector(summary.variance, Vectors.dense(4.333333333333334, 17.333333333333336, 9.0)),
"Dense column variance do not match.")

assert(summary.count === 3, "Dense column cnt do not match.")

assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 3.0)),
assert(equivVector(summary.numNonZeros, Vectors.dense(2.0, 2.0, 3.0)),
"Dense column nnz do not match.")

assert(equivVector(summary.max, Vectors.dense(7.0, 8.0, 9.0)),
assert(equivVector(summary.max, Vectors.dense(4.0, 8.0, 9.0)),
"Dense column max do not match.")

assert(equivVector(summary.min, Vectors.dense(1.0, 2.0, 3.0)),
assert(equivVector(summary.min, Vectors.dense(0.0, 0.0, 3.0)),
"Dense column min do not match.")
}

Expand All @@ -72,7 +72,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
assert(equivVector(summary.mean, Vectors.dense(0.06, 0.05, 0.0)),
"Sparse column mean do not match.")

assert(equivVector(summary.variance, Vectors.dense(0.2564, 0.2475, 0.0)),
assert(equivVector(summary.variance, Vectors.dense(0.258989898989899, 0.25, 0.0)),
"Sparse column variance do not match.")

assert(summary.count === 100, "Sparse column cnt do not match.")
Expand All @@ -90,6 +90,6 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {

object VectorRDDFunctionsSuite {
def equivVector(lhs: Vector, rhs: Vector): Boolean = {
(lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-9
(lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-5
}
}

0 comments on commit 48ee053

Please sign in to comment.