Skip to content

Commit

Permalink
refine the code style
Browse files Browse the repository at this point in the history
  • Loading branch information
yinxusen committed Apr 10, 2014
1 parent ad6c82d commit 9af2e95
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import org.apache.spark.mllib.util.MLUtils._
import org.apache.spark.rdd.RDD

/**
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an implicit conversion.
* Import `org.apache.spark.MLContext._` at the top of your program to use these functions.
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
* implicit conversion. Import `org.apache.spark.MLContext._` at the top of your program to use
* these functions.
*/
class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {

Expand Down Expand Up @@ -81,31 +82,36 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
/**
 * Compute the norm-2 of each column in the RDD with `size` as the dimension of each `Vector`.
 *
 * The element-wise squares of all rows are summed into a zero vector of length `size`, and the
 * square root of every entry of that sum is returned.
 *
 * @param size dimension of each `Vector` in the RDD
 * @return a vector whose i-th entry is the 2-norm of column i
 */
def colNorm2(size: Int): Vector = {
  val squaredSums = self.map(_.toBreeze).aggregate(BV.zeros[Double](size))(
    seqOp = (acc, row) => acc + (row :* row),
    combOp = (left, right) => left + right
  )
  Vectors.fromBreeze(squaredSums.map(math.sqrt))
}

/**
 * Compute the standard deviation of each column in the RDD, using the size of the first
 * element as the dimension. Fails on an empty RDD, because `head` is applied to the result of
 * `take(1)`.
 */
def colSDs(): Vector = {
  val firstRow = self.take(1).head
  colSDs(firstRow.size)
}

/**
 * Compute the standard deviation of each column in the RDD with `size` as the dimension of each
 * `Vector`.
 *
 * Every row is centred on the column means, and a running mean of the squared deviations is
 * carried together with the number of rows folded in so far. The result is the square root of
 * that mean, so the divisor is the row count `n` (not `n - 1`).
 *
 * @param size dimension of each `Vector` in the RDD
 * @return a vector whose i-th entry is the standard deviation of column i
 */
def colSDs(size: Int): Vector = {
  val means = self.colMeans()
  val (meanSquaredDev, _) = self.map(x => x.toBreeze - means.toBreeze)
    .aggregate((BV.zeros[Double](size), 0.0))(
      seqOp = (acc, centred) => acc match {
        case (prev, cnt) =>
          (((prev :* cnt) + (centred :* centred)) :/ (cnt + 1.0), cnt + 1.0)
      },
      combOp = (lhs, rhs) => (lhs, rhs) match {
        // Two accumulators that have seen no rows (e.g. the zero value merged with the result
        // of an empty partition) would be divided by a zero count, producing NaN entries that
        // then survive every later merge. Pass the empty accumulator through instead.
        case ((_, lhsCnt), (_, rhsCnt)) if lhsCnt + rhsCnt == 0.0 => lhs
        case ((lhsVec, lhsCnt), (rhsVec, rhsCnt)) =>
          // Count-weighted average of the two partial means.
          (((lhsVec :* lhsCnt) + (rhsVec :* rhsCnt)) :/ (lhsCnt + rhsCnt), lhsCnt + rhsCnt)
      }
    )
  Vectors.fromBreeze(meanSquaredDev.map(math.sqrt))
}

/**
Expand All @@ -119,12 +125,14 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
}

/**
 * Find the optional max vector in the RDD; `None` will be returned if there are no elements at
 * all. The comparison `cmp` is forwarded to `maxMinOption` unchanged.
 */
def maxOption(cmp: (Vector, Vector) => Boolean) = maxMinOption((lhs, rhs) => cmp(lhs, rhs))

/**
 * Find the optional min vector in the RDD; `None` will be returned if there are no elements at
 * all. The negation of `cmp` is what gets forwarded to `maxMinOption`.
 */
def minOption(cmp: (Vector, Vector) => Boolean) = maxMinOption((lhs, rhs) => !cmp(lhs, rhs))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,32 +61,38 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {

test("rowMeans") {
  // Distribute the fixture over 2 partitions and compare against the expected row means.
  val data = sc.parallelize(localData, 2)
  val actual = Vectors.dense(data.rowMeans().collect())
  assert(equivVector(actual, Vectors.dense(rowMeans)),
    "Row means do not match.")
}

test("rowNorm2") {
  // Distribute the fixture over 2 partitions and compare against the expected row 2-norms.
  val data = sc.parallelize(localData, 2)
  val actual = Vectors.dense(data.rowNorm2().collect())
  assert(equivVector(actual, Vectors.dense(rowNorm2)),
    "Row norm2s do not match.")
}

test("rowSDs") {
  // Distribute the fixture over 2 partitions and compare against the expected row SDs.
  val data = sc.parallelize(localData, 2)
  val actual = Vectors.dense(data.rowSDs().collect())
  assert(equivVector(actual, Vectors.dense(rowSDs)),
    "Row SDs do not match.")
}

test("colMeans") {
  // Distribute the fixture over 2 partitions and compare against the expected column means.
  val data = sc.parallelize(localData, 2)
  val actual = data.colMeans()
  assert(equivVector(actual, Vectors.dense(colMeans)),
    "Column means do not match.")
}

test("colNorm2") {
  // Distribute the fixture over 2 partitions and compare against the expected column 2-norms.
  val data = sc.parallelize(localData, 2)
  val actual = data.colNorm2()
  assert(equivVector(actual, Vectors.dense(colNorm2)),
    "Column norm2s do not match.")
}

test("colSDs") {
  // Distribute the fixture over 2 partitions and compare against the expected column SDs.
  val data = sc.parallelize(localData, 2)
  val actual = data.colSDs()
  assert(equivVector(actual, Vectors.dense(colSDs)),
    "Column SDs do not match.")
}

test("maxOption") {
Expand Down

0 comments on commit 9af2e95

Please sign in to comment.