add scala docs and refine shrink method
yinxusen committed Apr 10, 2014
1 parent 8ef3377 commit e09d5d2
Showing 1 changed file with 59 additions and 4 deletions.
@@ -16,7 +16,7 @@
*/
package org.apache.spark.mllib.rdd

-import breeze.linalg.{Vector => BV, *}
import breeze.linalg.{Vector => BV, DenseVector => BDV}

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLUtils._
@@ -28,23 +28,38 @@ import org.apache.spark.rdd.RDD
*/
class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {

/**
* Compute the mean of each `Vector` in the RDD.
*/
def rowMeans(): RDD[Double] = {
self.map(x => x.toArray.sum / x.size)
}

/**
* Compute the norm-2 of each `Vector` in the RDD.
*/
def rowNorm2(): RDD[Double] = {
self.map(x => math.sqrt(x.toArray.map(x => x*x).sum))
}

/**
* Compute the standard deviation of each `Vector` in the RDD.
*/
def rowSDs(): RDD[Double] = {
val means = self.rowMeans()
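// Subtract each row's mean from its elements, then take the root mean square of the residuals.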
self.zip(means)
.map{ case(x, m) => x.toBreeze - m }
.map{ x => math.sqrt(x.toArray.map(x => x*x).sum / x.size) }
}

/**
* Compute the mean of each column in the RDD.
*/
def colMeans(): Vector = colMeans(self.take(1).head.size)

/**
* Compute the mean of each column in the RDD with `size` as the dimension of each `Vector`.
*/
def colMeans(size: Int): Vector = {
Vectors.fromBreeze(self.map(_.toBreeze).aggregate((BV.zeros[Double](size), 0.0))(
seqOp = (c, v) => (c, v) match {
@@ -58,15 +73,27 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
)._1)
}

/**
* Compute the norm-2 of each column in the RDD.
*/
def colNorm2(): Vector = colNorm2(self.take(1).head.size)

/**
* Compute the norm-2 of each column in the RDD with `size` as the dimension of each `Vector`.
*/
def colNorm2(size: Int): Vector = Vectors.fromBreeze(self.map(_.toBreeze).aggregate(BV.zeros[Double](size))(
seqOp = (c, v) => c + (v :* v),
combOp = (lhs, rhs) => lhs + rhs
).map(math.sqrt))

/**
* Compute the standard deviation of each column in the RDD.
*/
def colSDs(): Vector = colSDs(self.take(1).head.size)

/**
* Compute the standard deviation of each column in the RDD with `size` as the dimension of each `Vector`.
*/
def colSDs(size: Int): Vector = {
val means = self.colMeans()
Vectors.fromBreeze(self.map(x => x.toBreeze - means.toBreeze).aggregate((BV.zeros[Double](size), 0.0))(
@@ -81,21 +108,49 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
)._1.map(math.sqrt))
}

/**
* Find the max or min vector in the RDD, wrapped in an `Option`, according to the given comparison function.
*/
private def maxMinOption(cmp: (Vector, Vector) => Boolean): Option[Vector] = {
def cmpMaxMin(x1: Vector, x2: Vector) = if (cmp(x1, x2)) x1 else x2
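// Reduce each partition to at most one candidate vector, then reduce the per-partition
// candidates on the driver; empty partitions contribute nothing.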
self.mapPartitions { iterator =>
Seq(iterator.reduceOption(cmpMaxMin)).iterator
}.collect { case Some(x) => x }.collect().reduceOption(cmpMaxMin)
}

/**
* Find the optional max vector in the RDD; `None` will be returned if there are no elements at all.
*/
def maxOption(cmp: (Vector, Vector) => Boolean) = maxMinOption(cmp)

/**
* Find the optional min vector in the RDD; `None` will be returned if there are no elements at all.
*/
def minOption(cmp: (Vector, Vector) => Boolean) = maxMinOption(!cmp(_, _))

-def rowShrink(): RDD[Vector] = self.filter(x => x.toArray.sum != 0)
/**
* Keep only the rows (vectors) whose standard deviation is not zero.
*/
def rowShrink(): RDD[Vector] = self.zip(self.rowSDs()).filter(_._2 != 0.0).map(_._1)

/**
* Keep only the columns of the RDD whose standard deviation is not zero.
*/
def colShrink(): RDD[Vector] = {
-val means = self.colMeans()
-self.map( v => Vectors.dense(v.toArray.zip(means.toArray).filter{ case (x, m) => m != 0.0 }.map(_._1)))
val sds = self.colSDs()
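// Drop the columns whose standard deviation is zero; dense rows stay dense, while other
// (e.g. sparse) rows are rebuilt as sparse vectors from the filtered values.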
self.take(1).head.toBreeze.isInstanceOf[BDV[Double]] match {
case true =>
self.map{ v =>
Vectors.dense(v.toArray.zip(sds.toArray).filter{case (x, m) => m != 0.0}.map(_._1))
}
case false =>
self.map { v =>
val filtered = v.toArray.zip(sds.toArray).filter{case (x, m) => m != 0.0}.map(_._1)
val denseVector = Vectors.dense(filtered).toBreeze
val size = denseVector.size
val iterElement = denseVector.activeIterator.toSeq
Vectors.sparse(size, iterElement)
}
}
}
}
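
A minimal usage sketch (not part of this commit) of the methods above; the object name and the data values are made up for illustration, and the RDD is wrapped in `VectorRDDFunctions` explicitly rather than relying on an implicit conversion:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.rdd.VectorRDDFunctions

object VectorRDDFunctionsExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "VectorRDDFunctionsExample")

    // Three rows; the second column is constant, so its standard deviation is zero.
    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 2.0, 0.0),
      Vectors.dense(3.0, 2.0, 4.0),
      Vectors.dense(5.0, 2.0, 8.0)))

    val funcs = new VectorRDDFunctions(data)

    println(funcs.colMeans())                     // per-column means
    println(funcs.rowNorm2().collect().toSeq)     // L2 norm of every row
    funcs.colShrink().collect().foreach(println)  // the constant second column is dropped

    sc.stop()
  }
}

Because the second column is constant across rows, `colShrink()` removes it, while `rowShrink()` would keep all three rows, since each row has a non-zero standard deviation.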
