Skip to content

Commit

Permalink
make Vectors.sparse Java friendly
Browse files Browse the repository at this point in the history
rename VectorSuite to VectorsSuite
  • Loading branch information
mengxr committed Mar 20, 2014
1 parent 27858e4 commit 712cb88
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
21 changes: 17 additions & 4 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

package org.apache.spark.mllib.linalg

import java.lang.{Iterable => JavaIterable}

import scala.collection.JavaConverters._

import breeze.linalg.{Vector => BreezeVector, DenseVector => BreezeDenseVector,
SparseVector => BreezeSparseVector}

Expand Down Expand Up @@ -68,11 +72,11 @@ object Vectors {
* @param size vector size.
* @param elements vector elements in (index, value) pairs.
*/
def sparse(size: Int, elements: Iterable[(Int, Double)]): Vector = {
def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = {

require(size > 0)

val (indices, values) = elements.toArray.sortBy(_._1).unzip
val (indices, values) = elements.sortBy(_._1).unzip
var prev = -1
indices.foreach { i =>
require(prev < i, "Found duplicate indices: " + i)
Expand All @@ -83,14 +87,23 @@ object Vectors {
new SparseVector(size, indices.toArray, values.toArray)
}

/**
* Creates a sparse vector using unordered (index, value) pairs.
*
* @param size vector size.
* @param elements vector elements in (index, value) pairs.
*/
def sparse(size: Int, elements: JavaIterable[(Int, Double)]): Vector =
sparse(size, elements.asScala.toSeq)

/**
* Creates a vector instance from a breeze vector.
*/
private[mllib] def fromBreeze(breezeVector: BreezeVector[Double]): Vector = {
breezeVector match {
case v: BreezeDenseVector[Double] => {
require(v.offset == 0)
require(v.stride == 1)
require(v.offset == 0, s"Do not support non-zero offset ${v.offset}.")
require(v.stride == 1, s"Do not support stride other than 1, but got ${v.stride}.")
new DenseVector(v.data)
}
case v: BreezeSparseVector[Double] => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

package org.apache.spark.mllib.linalg

import scala.collection.JavaConverters._

import org.scalatest.FunSuite

class VectorSuite extends FunSuite {
class VectorsSuite extends FunSuite {

val arr = Array(0.1, 0.2, 0.3, 0.4)
val n = 20
Expand All @@ -45,4 +47,11 @@ class VectorSuite extends FunSuite {
assert(vec.indices === indices)
assert(vec.values === values)
}

test("sparse vector construction with unordered elements stored as Java Iterable") {
val vec = Vectors.sparse(n, indices.toSeq.zip(values).reverse.asJava).asInstanceOf[SparseVector]
assert(vec.size === n)
assert(vec.indices === indices)
assert(vec.values === values)
}
}

0 comments on commit 712cb88

Please sign in to comment.