Skip to content

Commit

Permalink
[SPARK-6065] Optimize word2vec.findSynonynms using blas calls
Browse files Browse the repository at this point in the history
  • Loading branch information
MechCoder committed Apr 17, 2015
1 parent 8220d52 commit 1350cf3
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -479,9 +479,23 @@ class Word2VecModel private[mllib] (
*/
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
// TODO: optimize top-k
val fVector = vector.toArray.map(_.toFloat)
model.mapValues(vec => cosineSimilarity(fVector, vec))

val fVector = vector.toArray
val flatVec = model.toSeq.flatMap { case(w, v) =>
v.map(_.toDouble)}.toArray

val numDim = model.head._2.size
val numWords = model.size
val cosineArray = Array.fill[Double](numWords)(0)

blas.dgemv(
"T", numDim, numWords, 1.0, flatVec, numDim, fVector, 1, 0.0, cosineArray, 1)

// Need not divide with the norm of the given vector since it is constant.
val updatedCosines = model.zipWithIndex.map { case (vec, ind) =>
cosineArray(ind) / blas.snrm2(numDim, vec._2, 1) }

model.keys.zip(updatedCosines)
.toSeq
.sortBy(- _._2)
.take(num + 1)
Expand Down

0 comments on commit 1350cf3

Please sign in to comment.