Skip to content

Commit

Permalink
[SPARK-7422] [MLLIB] Add argmax to Vector, SparseVector
Browse files Browse the repository at this point in the history
Modifying Vector, DenseVector, and SparseVector to implement argmax functionality. This work is to set the stage for changes to be done in Spark-7423.

Author: George Dittmar <[email protected]>
Author: George <[email protected]>
Author: dittmarg <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes #6112 from GeorgeDittmar/SPARK-7422 and squashes the following commits:

3e0a939 [George Dittmar] Merge pull request #1 from mengxr/SPARK-7422
127dec5 [Xiangrui Meng] update argmax impl
2ea6a55 [George Dittmar] Added MimaExcludes for Vectors.argmax
98058f4 [George Dittmar] Merge branch 'master' of github.com:apache/spark into SPARK-7422
5fd9380 [George Dittmar] fixing style check error
42341fb [George Dittmar] refactoring arg max check to better handle zero values
b22af46 [George Dittmar] Fixing spaces between commas in unit test
f2eba2f [George Dittmar] Cleaning up unit tests to be fewer lines
aa330e3 [George Dittmar] Fixing some last if else spacing issues
ac53c55 [George Dittmar] changing dense vector argmax unit test to be one line call vs 2
d5b5423 [George Dittmar] Fixing code style and updating if logic on when to check for zero values
ee1a85a [George Dittmar] Cleaning up unit tests a bit and modifying a few cases
3ee8711 [George Dittmar] Fixing corner case issue with zeros in the active values of the sparse vector. Updated unit tests
b1f059f [George Dittmar] Added comment before we start arg max calculation. Updated unit tests to cover corner cases
f21dcce [George Dittmar] commit
af17981 [dittmarg] Initial work fixing bug that was made clear in pr
eeda560 [George] Fixing SparseVector argmax function to ignore zero values while doing the calculation.
4526acc [George] Merge branch 'master' of github.com:apache/spark into SPARK-7422
df9538a [George] Added argmax to sparse vector and added unit test
3cffed4 [George] Adding unit tests for argmax functions for Dense and Sparse vectors
04677af [George] initial work on adding argmax to Vector and SparseVector
  • Loading branch information
GeorgeDittmar authored and mengxr committed Jul 20, 2015
1 parent 79ec072 commit 3f7de7d
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 5 deletions.
57 changes: 52 additions & 5 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ sealed trait Vector extends Serializable {
toDense
}
}

/**
* Find the index of a maximal element. Returns the first maximal element in case of a tie.
* Returns -1 if vector has length 0.
*/
def argmax: Int
}

/**
Expand Down Expand Up @@ -588,11 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
new SparseVector(size, ii, vv)
}

/**
* Find the index of a maximal element. Returns the first maximal element in case of a tie.
* Returns -1 if vector has length 0.
*/
private[spark] def argmax: Int = {
override def argmax: Int = {
if (size == 0) {
-1
} else {
Expand Down Expand Up @@ -717,6 +719,51 @@ class SparseVector(
new SparseVector(size, ii, vv)
}
}

override def argmax: Int = {
if (size == 0) {
-1
} else {
// Find the max active entry.
var maxIdx = indices(0)
var maxValue = values(0)
var maxJ = 0
var j = 1
val na = numActives
while (j < na) {
val v = values(j)
if (v > maxValue) {
maxValue = v
maxIdx = indices(j)
maxJ = j
}
j += 1
}

// If the max active entry is nonpositive and there exists inactive ones, find the first zero.
if (maxValue <= 0.0 && na < size) {
if (maxValue == 0.0) {
// If there exists an inactive entry before maxIdx, find it and return its index.
if (maxJ < maxIdx) {
var k = 0
while (k < maxJ && indices(k) == k) {
k += 1
}
maxIdx = k
}
} else {
// If the max active value is negative, find and return the first inactive index.
var k = 0
while (k < na && indices(k) == k) {
k += 1
}
maxIdx = k
}
}

maxIdx
}
}
}

object SparseVector {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,50 @@ class VectorsSuite extends SparkFunSuite with Logging {
assert(vec.toArray.eq(arr))
}

test("dense argmax") {
val vec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]
assert(vec.argmax === -1)

val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector]
assert(vec2.argmax === 3)

val vec3 = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector]
assert(vec3.argmax === 3)
}

test("sparse to array") {
val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
assert(vec.toArray === arr)
}

test("sparse argmax") {
val vec = Vectors.sparse(0, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
assert(vec.argmax === -1)

val vec2 = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
assert(vec2.argmax === 3)

val vec3 = Vectors.sparse(5, Array(2, 3, 4), Array(1.0, 0.0, -.7))
assert(vec3.argmax === 2)

// check for case that sparse vector is created with
// only negative values {0.0, 0.0,-1.0, -0.7, 0.0}
val vec4 = Vectors.sparse(5, Array(2, 3), Array(-1.0, -.7))
assert(vec4.argmax === 0)

val vec5 = Vectors.sparse(11, Array(0, 3, 10), Array(-1.0, -.7, 0.0))
assert(vec5.argmax === 1)

val vec6 = Vectors.sparse(11, Array(0, 1, 2), Array(-1.0, -.7, 0.0))
assert(vec6.argmax === 2)

val vec7 = Vectors.sparse(5, Array(0, 1, 3), Array(-1.0, 0.0, -.7))
assert(vec7.argmax === 1)

val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
assert(vec8.argmax === 0)
}

test("vector equals") {
val dv1 = Vectors.dense(arr.clone())
val dv2 = Vectors.dense(arr.clone())
Expand Down
4 changes: 4 additions & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ object MimaExcludes {
"org.apache.spark.api.r.StringRRDD.this"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.api.r.BaseRRDD.this")
) ++ Seq(
// SPARK-7422 add argmax for sparse vectors
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Vector.argmax")
)

case v if v.startsWith("1.4") =>
Expand Down

0 comments on commit 3f7de7d

Please sign in to comment.