remove row-wise APIs and refine code
yinxusen committed Apr 10, 2014
Parent: 1338ea1 · Commit: c4651bb
Showing 2 changed files with 52 additions and 85 deletions.
@@ -21,7 +21,6 @@ import breeze.linalg.{Vector => BV, DenseVector => BDV}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLUtils._
 import org.apache.spark.rdd.RDD
-import breeze.linalg._
 
 /**
  * Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
@@ -30,30 +29,6 @@ import breeze.linalg._
  */
 class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
 
-  /**
-   * Compute the mean of each `Vector` in the RDD.
-   */
-  def rowMeans(): RDD[Double] = {
-    self.map(x => x.toArray.sum / x.size)
-  }
-
-  /**
-   * Compute the norm-2 of each `Vector` in the RDD.
-   */
-  def rowNorm2(): RDD[Double] = {
-    self.map(x => math.sqrt(x.toArray.map(x => x*x).sum))
-  }
-
-  /**
-   * Compute the standard deviation of each `Vector` in the RDD.
-   */
-  def rowSDs(): RDD[Double] = {
-    val means = self.rowMeans()
-    self.zip(means)
-      .map{ case(x, m) => x.toBreeze - m }
-      .map{ x => math.sqrt(x.toArray.map(x => x*x).sum / x.size) }
-  }
-
   /**
    * Compute the mean of each column in the RDD.
    */
@@ -137,11 +112,6 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
    */
   def minOption(cmp: (Vector, Vector) => Boolean) = maxMinOption(!cmp(_, _))
 
-  /**
-   * Filter the vectors whose standard deviation is not zero.
-   */
-  def rowShrink(): RDD[Vector] = self.zip(self.rowSDs()).filter(_._2 != 0.0).map(_._1)
-
   /**
    * Filter each column of the RDD whose standard deviation is not zero.
    */
@@ -163,34 +133,66 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
     }
   }
 
-  def parallelMeanAndVar(size: Int): (Vector, Vector, Double, Vector, Vector, Vector) = {
-    val statistics = self.map(_.toBreeze).aggregate((BV.zeros[Double](size), BV.zeros[Double](size), 0.0, BV.zeros[Double](size), BV.fill(size){Double.MinValue}, BV.fill(size){Double.MaxValue}))(
+  /**
+   * Compute full column-wise statistics for the RDD, including
+   * {{{
+   *   Mean: Vector,
+   *   Variance: Vector,
+   *   Count: Double,
+   *   Non-zero count: Vector,
+   *   Maximum elements: Vector,
+   *   Minimum elements: Vector.
+   * }}},
+   * with the size of Vector as input parameter.
+   */
+  def statistics(size: Int): (Vector, Vector, Double, Vector, Vector, Vector) = {
+    val results = self.map(_.toBreeze).aggregate((
+      BV.zeros[Double](size),
+      BV.zeros[Double](size),
+      0.0,
+      BV.zeros[Double](size),
+      BV.fill(size){Double.MinValue},
+      BV.fill(size){Double.MaxValue}))(
       seqOp = (c, v) => (c, v) match {
-        case ((prevMean, prevM2n, cnt, nnz, maxVec, minVec), currData) =>
+        case ((prevMean, prevM2n, cnt, nnzVec, maxVec, minVec), currData) =>
           val currMean = ((prevMean :* cnt) + currData) :/ (cnt + 1.0)
-          val nonZeroCnt = Vectors.sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0))).toBreeze
+          val nonZeroCnt = Vectors
+            .sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0))).toBreeze
           currData.activeIterator.foreach { case (id, value) =>
             if (maxVec(id) < value) maxVec(id) = value
             if (minVec(id) > value) minVec(id) = value
           }
-          (currMean, prevM2n + ((currData - prevMean) :* (currData - currMean)), cnt + 1.0, nnz + nonZeroCnt, maxVec, minVec)
+          (currMean,
+            prevM2n + ((currData - prevMean) :* (currData - currMean)),
+            cnt + 1.0,
+            nnzVec + nonZeroCnt,
+            maxVec,
+            minVec)
       },
       combOp = (lhs, rhs) => (lhs, rhs) match {
-        case ((lhsMean, lhsM2n, lhsCnt, lhsNNZ, lhsMax, lhsMin), (rhsMean, rhsM2n, rhsCnt, rhsNNZ, rhsMax, rhsMin)) =>
-          val totalCnt = lhsCnt + rhsCnt
-          val totalMean = (lhsMean :* lhsCnt) + (rhsMean :* rhsCnt) :/ totalCnt
-          val deltaMean = rhsMean - lhsMean
-          val totalM2n = lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
-          rhsMax.activeIterator.foreach { case (id, value) =>
-            if (lhsMax(id) < value) lhsMax(id) = value
-          }
-          rhsMin.activeIterator.foreach { case (id, value) =>
-            if (lhsMin(id) > value) lhsMin(id) = value
-          }
-          (totalMean, totalM2n, totalCnt, lhsNNZ + rhsNNZ, lhsMax, lhsMin)
+        case (
+          (lhsMean, lhsM2n, lhsCnt, lhsNNZ, lhsMax, lhsMin),
+          (rhsMean, rhsM2n, rhsCnt, rhsNNZ, rhsMax, rhsMin)) =>
+          val totalCnt = lhsCnt + rhsCnt
+          val totalMean = (lhsMean :* lhsCnt) + (rhsMean :* rhsCnt) :/ totalCnt
+          val deltaMean = rhsMean - lhsMean
+          val totalM2n =
+            lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
+          rhsMax.activeIterator.foreach { case (id, value) =>
+            if (lhsMax(id) < value) lhsMax(id) = value
+          }
+          rhsMin.activeIterator.foreach { case (id, value) =>
+            if (lhsMin(id) > value) lhsMin(id) = value
+          }
+          (totalMean, totalM2n, totalCnt, lhsNNZ + rhsNNZ, lhsMax, lhsMin)
       }
     )
 
-    (Vectors.fromBreeze(statistics._1), Vectors.fromBreeze(statistics._2 :/ statistics._3), statistics._3, Vectors.fromBreeze(statistics._4), Vectors.fromBreeze(statistics._5), Vectors.fromBreeze(statistics._6))
+    (Vectors.fromBreeze(results._1),
+      Vectors.fromBreeze(results._2 :/ results._3),
+      results._3,
+      Vectors.fromBreeze(results._4),
+      Vectors.fromBreeze(results._5),
+      Vectors.fromBreeze(results._6))
   }
 }
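For reference, a minimal usage sketch of the new column-wise statistics API. This is not part of the commit: it assumes a local SparkContext, assumes the class lives in org.apache.spark.mllib.rdd, and constructs the wrapper explicitly instead of relying on an implicit conversion.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.rdd.VectorRDDFunctions // package path assumed

object StatisticsExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "statistics-example")
    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 2.0, 3.0),
      Vectors.dense(4.0, 5.0, 6.0),
      Vectors.dense(7.0, 8.0, 9.0)), 2)

    // size = 3 is the dimension of each vector in the RDD.
    val (mean, variance, count, nnz, max, min) =
      new VectorRDDFunctions(data).statistics(3)

    println(mean)     // expect [4.0, 5.0, 6.0]
    println(variance) // expect [6.0, 6.0, 6.0] (population variance, M2n / count)
    println(count)    // expect 3.0
    sc.stop()
  }
}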
@@ -31,10 +31,6 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
     Vectors.dense(7.0, 8.0, 9.0)
   )
 
-  val rowMeans = Array(2.0, 5.0, 8.0)
-  val rowNorm2 = Array(math.sqrt(14.0), math.sqrt(77.0), math.sqrt(194.0))
-  val rowSDs = Array(math.sqrt(2.0 / 3.0), math.sqrt(2.0 / 3.0), math.sqrt(2.0 / 3.0))
-
   val colMeans = Array(4.0, 5.0, 6.0)
   val colNorm2 = Array(math.sqrt(66.0), math.sqrt(93.0), math.sqrt(126.0))
   val colSDs = Array(math.sqrt(6.0), math.sqrt(6.0), math.sqrt(6.0))
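A quick arithmetic check of these fixture values (mine, not part of the commit), using the first column (1.0, 4.0, 7.0) of localData:

val col = Seq(1.0, 4.0, 7.0)
val mean = col.sum / col.size                                            // 4.0
val norm2 = math.sqrt(col.map(x => x * x).sum)                           // sqrt(66.0)
val sd = math.sqrt(col.map(x => (x - mean) * (x - mean)).sum / col.size) // sqrt(6.0), population SD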
@@ -49,35 +45,12 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
     Vectors.dense(7.0, 8.0, 0.0)
   )
 
-  val rowShrinkData = Array(
-    Vectors.dense(1.0, 2.0, 0.0),
-    Vectors.dense(7.0, 8.0, 0.0)
-  )
-
   val colShrinkData = Array(
     Vectors.dense(1.0, 2.0),
     Vectors.dense(0.0, 0.0),
     Vectors.dense(7.0, 8.0)
   )
 
-  test("rowMeans") {
-    val data = sc.parallelize(localData, 2)
-    assert(equivVector(Vectors.dense(data.rowMeans().collect()), Vectors.dense(rowMeans)),
-      "Row means do not match.")
-  }
-
-  test("rowNorm2") {
-    val data = sc.parallelize(localData, 2)
-    assert(equivVector(Vectors.dense(data.rowNorm2().collect()), Vectors.dense(rowNorm2)),
-      "Row norm2s do not match.")
-  }
-
-  test("rowSDs") {
-    val data = sc.parallelize(localData, 2)
-    assert(equivVector(Vectors.dense(data.rowSDs().collect()), Vectors.dense(rowSDs)),
-      "Row SDs do not match.")
-  }
-
   test("colMeans") {
     val data = sc.parallelize(localData, 2)
     assert(equivVector(data.colMeans(), Vectors.dense(colMeans)),
@@ -114,14 +87,6 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
     )
   }
 
-  test("rowShrink") {
-    val data = sc.parallelize(shrinkingData, 2)
-    val res = data.rowShrink().collect()
-    rowShrinkData.zip(res).foreach { case (lhs, rhs) =>
-      assert(equivVector(lhs, rhs), "Row shrink error.")
-    }
-  }
-
   test("columnShrink") {
     val data = sc.parallelize(shrinkingData, 2)
     val res = data.colShrink().collect()
@@ -130,9 +95,9 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
     }
   }
 
-  test("meanAndVar") {
+  test("full-statistics") {
     val data = sc.parallelize(localData, 2)
-    val (mean, sd, cnt, nnz, max, min) = data.parallelMeanAndVar(3)
+    val (mean, sd, cnt, nnz, max, min) = data.statistics(3)
     assert(equivVector(mean, Vectors.dense(colMeans)), "Column means do not match.")
     assert(equivVector(sd, Vectors.dense(colVar)), "Column SD do not match.")
     assert(cnt === 3, "Column cnt do not match.")
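The combOp in statistics above merges per-partition (mean, M2n, count) summaries with the standard parallel-variance merge of Chan et al. Below is a minimal scalar sketch of that step (illustrative only, not part of the commit), checked against a direct computation on the concatenated data:

object MergeCheck {
  // Merge two (mean, m2n, count) summaries, mirroring the vector combOp.
  def merge(lhs: (Double, Double, Double), rhs: (Double, Double, Double)): (Double, Double, Double) = {
    val (lhsMean, lhsM2n, lhsCnt) = lhs
    val (rhsMean, rhsM2n, rhsCnt) = rhs
    val totalCnt = lhsCnt + rhsCnt
    val deltaMean = rhsMean - lhsMean
    val totalMean = (lhsMean * lhsCnt + rhsMean * rhsCnt) / totalCnt
    val totalM2n = lhsM2n + rhsM2n + deltaMean * deltaMean * lhsCnt * rhsCnt / totalCnt
    (totalMean, totalM2n, totalCnt)
  }

  // Direct summary of a whole sequence, for comparison.
  def summary(xs: Seq[Double]): (Double, Double, Double) = {
    val mean = xs.sum / xs.size
    (mean, xs.map(x => (x - mean) * (x - mean)).sum, xs.size.toDouble)
  }

  def main(args: Array[String]): Unit = {
    val (a, b) = (Seq(1.0, 4.0), Seq(7.0))
    println(merge(summary(a), summary(b))) // (4.0, 18.0, 3.0)
    println(summary(a ++ b))               // (4.0, 18.0, 3.0) -- identical
  }
}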