Skip to content

Commit

Permalink
pass all tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yinxusen committed Apr 10, 2014
1 parent 28cf060 commit 8ef3377
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -60,19 +60,19 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {

def colNorm2(): Vector = colNorm2(self.take(1).head.size)

-def colNorm2(size: Int): Vector = Vectors.fromBreeze(self.map(_.toBreeze).fold(BV.zeros[Double](size)) {
-case (lhs, rhs) =>
-lhs + (rhs :* rhs)
-}.map(math.sqrt))
+def colNorm2(size: Int): Vector = Vectors.fromBreeze(self.map(_.toBreeze).aggregate(BV.zeros[Double](size))(
+seqOp = (c, v) => c + (v :* v),
+combOp = (lhs, rhs) => lhs + rhs
+).map(math.sqrt))

def colSDs(): Vector = colSDs(self.take(1).head.size)

def colSDs(size: Int): Vector = {
-val means = this.colMeans()
+val means = self.colMeans()
Vectors.fromBreeze(self.map(x => x.toBreeze - means.toBreeze).aggregate((BV.zeros[Double](size), 0.0))(
seqOp = (c, v) => (c, v) match {
case ((prev, cnt), current) =>
-(((prev :* cnt) + current) :/ (cnt + 1.0), cnt + 1.0)
+(((prev :* cnt) + (current :* current)) :/ (cnt + 1.0), cnt + 1.0)
},
combOp = (lhs, rhs) => (lhs, rhs) match {
case ((lhsVec, lhsCnt), (rhsVec, rhsCnt)) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -45,38 +45,37 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
val minVec = Array(1.0, 2.0, 3.0)

test("rowMeans") {
-val data = sc.parallelize(localData)
+val data = sc.parallelize(localData, 2)
assert(equivVector(Vectors.dense(data.rowMeans().collect()), Vectors.dense(rowMeans)), "Row means do not match.")
}

test("rowNorm2") {
-val data = sc.parallelize(localData)
+val data = sc.parallelize(localData, 2)
assert(equivVector(Vectors.dense(data.rowNorm2().collect()), Vectors.dense(rowNorm2)), "Row norm2s do not match.")
}

test("rowSDs") {
-val data = sc.parallelize(localData)
+val data = sc.parallelize(localData, 2)
assert(equivVector(Vectors.dense(data.rowSDs().collect()), Vectors.dense(rowSDs)), "Row SDs do not match.")
}

test("colMeans") {
-val data = sc.parallelize(localData)
+val data = sc.parallelize(localData, 2)
assert(equivVector(data.colMeans(), Vectors.dense(colMeans)), "Column means do not match.")
}

test("colNorm2") {
-val data = sc.parallelize(localData)
+val data = sc.parallelize(localData, 2)
assert(equivVector(data.colNorm2(), Vectors.dense(colNorm2)), "Column norm2s do not match.")
}

test("colSDs") {
-val data = sc.parallelize(localData)
-val test = data.colSDs()
+val data = sc.parallelize(localData, 2)
assert(equivVector(data.colSDs(), Vectors.dense(colSDs)), "Column SDs do not match.")
}

test("maxOption") {
-val data = sc.parallelize(localData)
+val data = sc.parallelize(localData, 2)
assert(equivVectorOption(
data.maxOption((lhs: Vector, rhs: Vector) => lhs.toBreeze.norm(2) >= rhs.toBreeze.norm(2)),
Some(Vectors.dense(maxVec))),
@@ -85,7 +84,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
}

test("minOption") {
-val data = sc.parallelize(localData)
+val data = sc.parallelize(localData, 2)
assert(equivVectorOption(
data.minOption((lhs: Vector, rhs: Vector) => lhs.toBreeze.norm(2) >= rhs.toBreeze.norm(2)),
Some(Vectors.dense(minVec))),
Expand Down

0 comments on commit 8ef3377

Please sign in to comment.