diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala index 087bb8a6ba4f1..c3a4710d3a9f0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.util.MLUtils._ +import scala.collection.mutable.ArrayBuffer class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext { import VectorRDDFunctionsSuite._ @@ -31,19 +32,47 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext { Vectors.dense(7.0, 8.0, 9.0) ) + val sparseData = ArrayBuffer(Vectors.sparse(20, Seq((0, 1.0), (9, 2.0), (10, 7.0)))) + for (i <- 0 to 10000) sparseData += Vectors.sparse(20, Seq((9, 0.0))) + sparseData += Vectors.sparse(20, Seq((0, 5.0), (9, 13.0), (16, 2.0))) + sparseData += Vectors.sparse(20, Seq((3, 5.0), (9, 13.0), (18, 2.0))) + test("full-statistics") { val data = sc.parallelize(localData, 2) - val VectorRDDStatisticalSummary(mean, variance, cnt, nnz, max, min) = data.summarizeStatistics(3) + val (VectorRDDStatisticalSummary(mean, variance, cnt, nnz, max, min), denseTime) = time(data.summarizeStatistics(3)) assert(equivVector(mean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.") assert(equivVector(variance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.") assert(cnt === 3, "Column cnt do not match.") assert(equivVector(nnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.") assert(equivVector(max, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.") assert(equivVector(min, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.") + + val dataForSparse = sc.parallelize(sparseData.toSeq, 2) + val (VectorRDDStatisticalSummary(sparseMean, sparseVariance, sparseCnt, sparseNnz, sparseMax, sparseMin), sparseTime) = time(dataForSparse.summarizeStatistics(20)) + /* + assert(equivVector(sparseMean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.") + assert(equivVector(sparseVariance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.") + assert(sparseCnt === 3, "Column cnt do not match.") + assert(equivVector(sparseNnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.") + assert(equivVector(sparseMax, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.") + assert(equivVector(sparseMin, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.") + */ + + + + println(s"dense time is $denseTime, sparse time is $sparseTime.") } + } object VectorRDDFunctionsSuite { + def time[R](block: => R): (R, Double) = { + val t0 = System.nanoTime() + val result = block + val t1 = System.nanoTime() + (result, (t1 - t0).toDouble / 1.0e9) + } + def equivVector(lhs: Vector, rhs: Vector): Boolean = { (lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-9 }