Skip to content

Commit

Permalink
add shrink test
Browse files Browse the repository at this point in the history
  • Loading branch information
yinxusen committed Apr 10, 2014
1 parent e09d5d2 commit ad6c82d
Showing 1 changed file with 40 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,19 @@

package org.apache.spark.mllib.rdd

import org.apache.spark.mllib.linalg.Vector
import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils._
import VectorRDDFunctionsSuite._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.util.MLUtils._

class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
import VectorRDDFunctionsSuite._

val localData = Array(
Vectors.dense(1.0, 2.0, 3.0),
Vectors.dense(4.0, 5.0, 6.0),
Vectors.dense(7.0, 8.0, 9.0)
)
Vectors.dense(1.0, 2.0, 3.0),
Vectors.dense(4.0, 5.0, 6.0),
Vectors.dense(7.0, 8.0, 9.0)
)

val rowMeans = Array(2.0, 5.0, 8.0)
val rowNorm2 = Array(math.sqrt(14.0), math.sqrt(77.0), math.sqrt(194.0))
Expand All @@ -44,6 +42,23 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
val maxVec = Array(7.0, 8.0, 9.0)
val minVec = Array(1.0, 2.0, 3.0)

val shrinkingData = Array(
Vectors.dense(1.0, 2.0, 0.0),
Vectors.dense(0.0, 0.0, 0.0),
Vectors.dense(7.0, 8.0, 0.0)
)

val rowShrinkData = Array(
Vectors.dense(1.0, 2.0, 0.0),
Vectors.dense(7.0, 8.0, 0.0)
)

val colShrinkData = Array(
Vectors.dense(1.0, 2.0),
Vectors.dense(0.0, 0.0),
Vectors.dense(7.0, 8.0)
)

test("rowMeans") {
val data = sc.parallelize(localData, 2)
assert(equivVector(Vectors.dense(data.rowMeans().collect()), Vectors.dense(rowMeans)), "Row means do not match.")
Expand Down Expand Up @@ -91,6 +106,22 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
"Optional minimum does not match."
)
}

test("rowShrink") {
val data = sc.parallelize(shrinkingData, 2)
val res = data.rowShrink().collect()
rowShrinkData.zip(res).foreach { case (lhs, rhs) =>
assert(equivVector(lhs, rhs), "Row shrink error.")
}
}

test("columnShrink") {
val data = sc.parallelize(shrinkingData, 2)
val res = data.colShrink().collect()
colShrinkData.zip(res).foreach { case (lhs, rhs) =>
assert(equivVector(lhs, rhs), "Column shrink error.")
}
}
}

object VectorRDDFunctionsSuite {
Expand Down

0 comments on commit ad6c82d

Please sign in to comment.