Skip to content

Commit

Permalink
add sparse vectors test
Browse files Browse the repository at this point in the history
  • Loading branch information
yinxusen committed Apr 10, 2014
1 parent 4cfbadf commit f6e8e9a
Showing 1 changed file with 30 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.util.MLUtils._
import scala.collection.mutable.ArrayBuffer

class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
import VectorRDDFunctionsSuite._
Expand All @@ -31,19 +32,47 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
Vectors.dense(7.0, 8.0, 9.0)
)

val sparseData = ArrayBuffer(Vectors.sparse(20, Seq((0, 1.0), (9, 2.0), (10, 7.0))))
for (i <- 0 to 10000) sparseData += Vectors.sparse(20, Seq((9, 0.0)))
sparseData += Vectors.sparse(20, Seq((0, 5.0), (9, 13.0), (16, 2.0)))
sparseData += Vectors.sparse(20, Seq((3, 5.0), (9, 13.0), (18, 2.0)))

test("full-statistics") {
val data = sc.parallelize(localData, 2)
val VectorRDDStatisticalSummary(mean, variance, cnt, nnz, max, min) = data.summarizeStatistics(3)
val (VectorRDDStatisticalSummary(mean, variance, cnt, nnz, max, min), denseTime) = time(data.summarizeStatistics(3))
assert(equivVector(mean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.")
assert(equivVector(variance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.")
assert(cnt === 3, "Column cnt do not match.")
assert(equivVector(nnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.")
assert(equivVector(max, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.")
assert(equivVector(min, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.")

val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
val (VectorRDDStatisticalSummary(sparseMean, sparseVariance, sparseCnt, sparseNnz, sparseMax, sparseMin), sparseTime) = time(dataForSparse.summarizeStatistics(20))
/*
assert(equivVector(sparseMean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.")
assert(equivVector(sparseVariance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.")
assert(sparseCnt === 3, "Column cnt do not match.")
assert(equivVector(sparseNnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.")
assert(equivVector(sparseMax, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.")
assert(equivVector(sparseMin, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.")
*/



println(s"dense time is $denseTime, sparse time is $sparseTime.")
}

}

object VectorRDDFunctionsSuite {
def time[R](block: => R): (R, Double) = {
val t0 = System.nanoTime()
val result = block
val t1 = System.nanoTime()
(result, (t1 - t0).toDouble / 1.0e9)
}

def equivVector(lhs: Vector, rhs: Vector): Boolean = {
(lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-9
}
Expand Down

0 comments on commit f6e8e9a

Please sign in to comment.