Skip to content

Commit

Permalink
all-in-one version test passed
Browse files Browse the repository at this point in the history
  • Loading branch information
yinxusen committed Apr 10, 2014
1 parent cc65810 commit 1338ea1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import breeze.linalg.{Vector => BV, DenseVector => BDV}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLUtils._
import org.apache.spark.rdd.RDD
import breeze.numerics._
import breeze.linalg._

/**
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
Expand Down Expand Up @@ -163,23 +163,34 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
}
}

def parallelMeanAndVar(size: Int): (Vector, Vector) = {
val statistics = self.map(_.toBreeze).aggregate((BV.zeros[Double](size), BV.zeros[Double](size), 0.0))(
def parallelMeanAndVar(size: Int): (Vector, Vector, Double, Vector, Vector, Vector) = {
val statistics = self.map(_.toBreeze).aggregate((BV.zeros[Double](size), BV.zeros[Double](size), 0.0, BV.zeros[Double](size), BV.fill(size){Double.MinValue}, BV.fill(size){Double.MaxValue}))(
seqOp = (c, v) => (c, v) match {
case ((prevMean, prevM2n, cnt), currData) =>
case ((prevMean, prevM2n, cnt, nnz, maxVec, minVec), currData) =>
val currMean = ((prevMean :* cnt) + currData) :/ (cnt + 1.0)
(currMean, prevM2n + ((currData - prevMean) :* (currData - currMean)), cnt + 1.0)
val nonZeroCnt = Vectors.sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0))).toBreeze
currData.activeIterator.foreach { case (id, value) =>
if (maxVec(id) < value) maxVec(id) = value
if (minVec(id) > value) minVec(id) = value
}
(currMean, prevM2n + ((currData - prevMean) :* (currData - currMean)), cnt + 1.0, nnz + nonZeroCnt, maxVec, minVec)
},
combOp = (lhs, rhs) => (lhs, rhs) match {
case ((lhsMean, lhsM2n, lhsCnt), (rhsMean, rhsM2n, rhsCnt)) =>
case ((lhsMean, lhsM2n, lhsCnt, lhsNNZ, lhsMax, lhsMin), (rhsMean, rhsM2n, rhsCnt, rhsNNZ, rhsMax, rhsMin)) =>
val totalCnt = lhsCnt + rhsCnt
val totalMean = (lhsMean :* lhsCnt) + (rhsMean :* rhsCnt) :/ totalCnt
val deltaMean = rhsMean - lhsMean
val totalM2n = lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
(totalMean, totalM2n, totalCnt)
rhsMax.activeIterator.foreach { case (id, value) =>
if (lhsMax(id) < value) lhsMax(id) = value
}
rhsMin.activeIterator.foreach { case (id, value) =>
if (lhsMin(id) > value) lhsMin(id) = value
}
(totalMean, totalM2n, totalCnt, lhsNNZ + rhsNNZ, lhsMax, lhsMin)
}
)

(Vectors.fromBreeze(statistics._1), Vectors.fromBreeze(statistics._2 :/ statistics._3))
(Vectors.fromBreeze(statistics._1), Vectors.fromBreeze(statistics._2 :/ statistics._3), statistics._3, Vectors.fromBreeze(statistics._4), Vectors.fromBreeze(statistics._5), Vectors.fromBreeze(statistics._6))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,13 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {

test("meanAndVar") {
val data = sc.parallelize(localData, 2)
val (mean, sd) = data.parallelMeanAndVar(3)
val (mean, sd, cnt, nnz, max, min) = data.parallelMeanAndVar(3)
assert(equivVector(mean, Vectors.dense(colMeans)), "Column means do not match.")
assert(equivVector(sd, Vectors.dense(colVar)), "Column SD do not match.")
assert(cnt === 3, "Column cnt do not match.")
assert(equivVector(nnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.")
assert(equivVector(max, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.")
assert(equivVector(min, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.")
}
}

Expand Down

0 comments on commit 1338ea1

Please sign in to comment.