diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 4e76b3b1db19d..dce1420793169 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -728,18 +728,11 @@ class SparseVector( var maxValue = values(0) foreachActive { (i, v) => - if(values(i) > maxValue){ + if(v > maxValue){ maxIdx = i maxValue = v } } -// while(i < this.indices.size){ -// if(values(i) > maxValue){ -// maxIdx = indices(i) -// maxValue = values(i) -// } -// i += 1 -// } maxIdx } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index b118856176a72..7d35186df62cd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -86,6 +86,10 @@ class VectorsSuite extends FunSuite { val vec2 = Vectors.sparse(n,indices,values).asInstanceOf[SparseVector] val max = vec2.argmax assert(max === 3) + + val vec3 = Vectors.sparse(5,Array(1,3,4),Array(1.0,.5,.7)) + val max2 = vec3.argmax + assert(max2 === 1) } test("vector equals") {