Remove no-longer-needed slice code and handle review comments
mateiz committed Apr 15, 2014
1 parent ea5a25a commit d52e763
Showing 4 changed files with 4 additions and 67 deletions.
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -32,7 +32,7 @@ import org.apache.spark.rdd.RDD
* :: DeveloperApi ::
* The Java stubs necessary for the Python mllib bindings.
*
* See mllib/python/pyspark._common.py for the mutually agreed upon data format.
* See python/pyspark/mllib/_common.py for the mutually agreed upon data format.
*/
@DeveloperApi
class PythonMLLibAPI extends Serializable {
41 changes: 0 additions & 41 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -60,8 +60,6 @@ trait Vector extends Serializable {
* @param i index
*/
private[mllib] def apply(i: Int): Double = toBreeze(i)

private[mllib] def slice(start: Int, end: Int): Vector
}

/**
@@ -159,10 +157,6 @@ class DenseVector(val values: Array[Double]) extends Vector {
private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)

override def apply(i: Int) = values(i)

private[mllib] override def slice(start: Int, end: Int): Vector = {
new DenseVector(values.slice(start, end))
}
}

/**
@@ -193,39 +187,4 @@ class SparseVector(
}

private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)

override def apply(pos: Int): Double = {
// A more efficient apply() than creating a new Breeze vector
var i = 0
while (i < indices.length) {
if (indices(i) == pos) {
return values(i)
} else if (indices(i) > pos) {
return 0.0
}
i += 1
}
0.0
}

private[mllib] override def slice(start: Int, end: Int): Vector = {
require(start <= end, s"invalid range: ${start} to ${end}")
require(start >= 0, s"invalid range: ${start} to ${end}")
require(end <= size, s"invalid range: ${start} to ${end}")
// Figure out the range of indices that fall within the given bounds
var i = 0
var indexRangeStart = 0
var indexRangeEnd = 0
while (i < indices.length && indices(i) < start) {
i += 1
}
indexRangeStart = i
while (i < indices.length && indices(i) < end) {
i += 1
}
indexRangeEnd = i
val newIndices = indices.slice(indexRangeStart, indexRangeEnd).map(_ - start)
val newValues = values.slice(indexRangeStart, indexRangeEnd)
new SparseVector(end - start, newIndices, newValues)
}
}
20 changes: 0 additions & 20 deletions mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -100,24 +100,4 @@ class VectorsSuite extends FunSuite
assert(vec2(6) === 4.0)
assert(vec2(7) === 0.0)
}

test("slicing dense vectors") {
val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0)
val slice = vec.slice(1, 3)
assert(slice === Vectors.dense(2.0, 3.0))
assert(slice.isInstanceOf[DenseVector], "slice was not DenseVector")
}

test("slicing sparse vectors") {
val vec = Vectors.sparse(7, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0))
val slice = vec.slice(1, 5)
assert(slice === Vectors.sparse(4, Array(1,3), Array(2.0, 3.0)))
assert(slice.isInstanceOf[SparseVector], "slice was not SparseVector")
val slice2 = vec.slice(1, 2)
assert(slice2 === Vectors.sparse(1, Array(), Array()))
assert(slice2.isInstanceOf[SparseVector], "slice was not SparseVector")
val slice3 = vec.slice(6, 7)
assert(slice3 === Vectors.sparse(1, Array(0), Array(4.0)))
assert(slice3.isInstanceOf[SparseVector], "slice was not SparseVector")
}
}
8 changes: 3 additions & 5 deletions python/pyspark/mllib/_common.py
@@ -17,7 +17,7 @@

import struct
import numpy
from numpy import ndarray, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
from numpy import ndarray, float64, int64, int32, array_equal, array
from pyspark import SparkContext, RDD
from pyspark.mllib.linalg import SparseVector
from pyspark.serializers import Serializer
@@ -243,8 +243,6 @@ def _deserialize_double_matrix(ba):

def _serialize_labeled_point(p):
"""Serialize a LabeledPoint with a features vector of any type."""
#from pyspark.mllib.regression import LabeledPoint
#assert type(p) == LabeledPoint, "Expected a LabeledPoint object"
from pyspark.mllib.regression import LabeledPoint
serialized_features = _serialize_double_vector(p.features)
header = bytearray(9)
@@ -318,9 +316,9 @@ def _get_initial_weights(initial_weights, data):
if initial_weights.ndim != 1:
raise TypeError("At least one data element has "
+ initial_weights.ndim + " dimensions, which is not 1")
initial_weights = numpy.ones([initial_weights.shape[0]])
initial_weights = numpy.zeros([initial_weights.shape[0]])
elif type(initial_weights) == SparseVector:
initial_weights = numpy.ones([initial_weights.size])
initial_weights = numpy.zeros([initial_weights.size])
return initial_weights


