From 712cb8881cee541e74f307c67c0c34f2170cb3e2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 19 Mar 2014 17:06:47 -0700 Subject: [PATCH] make Vectors.sparse Java friendly rename VectorSuite to VectorsSuite --- .../apache/spark/mllib/linalg/Vectors.scala | 21 +++++++++++++++---- .../{VectorSuite.scala => VectorsSuite.scala} | 11 +++++++++- 2 files changed, 27 insertions(+), 5 deletions(-) rename mllib/src/test/scala/org/apache/spark/mllib/linalg/{VectorSuite.scala => VectorsSuite.scala} (81%) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index c102084caa997..16a19df4472f4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,6 +17,10 @@ package org.apache.spark.mllib.linalg +import java.lang.{Iterable => JavaIterable} + +import scala.collection.JavaConverters._ + import breeze.linalg.{Vector => BreezeVector, DenseVector => BreezeDenseVector, SparseVector => BreezeSparseVector} @@ -68,11 +72,11 @@ object Vectors { * @param size vector size. * @param elements vector elements in (index, value) pairs. */ - def sparse(size: Int, elements: Iterable[(Int, Double)]): Vector = { + def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { require(size > 0) - val (indices, values) = elements.toArray.sortBy(_._1).unzip + val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 indices.foreach { i => require(prev < i, "Found duplicate indices: " + i) @@ -83,14 +87,23 @@ object Vectors { new SparseVector(size, indices.toArray, values.toArray) } + /** + * Creates a sparse vector using unordered (index, value) pairs in a Java friendly way. + * + * @param size vector size. + * @param elements vector elements in (index, value) pairs. 
+ */ + def sparse(size: Int, elements: JavaIterable[(Int, Double)]): Vector = + sparse(size, elements.asScala.toSeq) + /** * Creates a vector instance from a breeze vector. */ private[mllib] def fromBreeze(breezeVector: BreezeVector[Double]): Vector = { breezeVector match { case v: BreezeDenseVector[Double] => { - require(v.offset == 0) - require(v.stride == 1) + require(v.offset == 0, s"Do not support non-zero offset ${v.offset}.") + require(v.stride == 1, s"Do not support stride other than 1, but got ${v.stride}.") new DenseVector(v.data) } case v: BreezeSparseVector[Double] => { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala similarity index 81% rename from mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index e3ee97121f822..adf2005b84f1d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.mllib.linalg +import scala.collection.JavaConverters._ + import org.scalatest.FunSuite -class VectorSuite extends FunSuite { +class VectorsSuite extends FunSuite { val arr = Array(0.1, 0.2, 0.3, 0.4) val n = 20 @@ -45,4 +47,11 @@ class VectorSuite extends FunSuite { assert(vec.indices === indices) assert(vec.values === values) } + + test("sparse vector construction with unordered elements stored as Java Iterable") { + val vec = Vectors.sparse(n, indices.toSeq.zip(values).reverse.asJava).asInstanceOf[SparseVector] + assert(vec.size === n) + assert(vec.indices === indices) + assert(vec.values === values) + } }