From 127dec5d7aa536f70891dfce10d0fb56a3b57723 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 16 Jul 2015 21:49:34 -0700 Subject: [PATCH] update argmax impl --- .../apache/spark/mllib/linalg/Vectors.scala | 36 +++++++++++++++---- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 68c933752a959..9067b3ba9a7bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -724,22 +724,44 @@ class SparseVector( if (size == 0) { -1 } else { + // Find the max active entry. var maxIdx = indices(0) var maxValue = values(0) - - foreachActive { (i, v) => + var maxJ = 0 + var j = 1 + val na = numActives + while (j < na) { + val v = values(j) if (v > maxValue) { - maxIdx = i maxValue = v + maxIdx = indices(j) + maxJ = j } + j += 1 } - var k = 0 - while (k < indices.length && indices(k) == k && values(k) != 0.0) { - k += 1 + // If the max active entry is nonpositive and there exist inactive ones, find the first zero. + if (maxValue <= 0.0 && na < size) { + if (maxValue == 0.0) { + // If there exists an inactive entry before maxIdx, find it and return its index. + if (maxJ < maxIdx) { + var k = 0 + while (k < maxJ && indices(k) == k) { + k += 1 + } + maxIdx = k + } + } else { + // If the max active value is negative, find and return the first inactive index. + var k = 0 + while (k < na && indices(k) == k) { + k += 1 + } + maxIdx = k + } } - if (maxValue <= 0.0 || k >= maxIdx) k else maxIdx + maxIdx } } }