[SPARK-6065] [MLlib] Optimize word2vec.findSynonyms using blas calls
1. Use blas calls to find the dot product between two vectors.
2. Prevent re-computing the L2 norm of the given vector for each word in model.

Author: MechCoder <[email protected]>

Closes apache#5467 from MechCoder/spark-6065 and squashes the following commits:

dd0b0b2 [MechCoder] Preallocate wordVectors
ffc9240 [MechCoder] Minor
6b74c81 [MechCoder] Switch back to native blas calls
da1642d [MechCoder] Explicit types and indexing
64575b0 [MechCoder] Save indexedmap and a wordvecmat instead of matrix
fbe0108 [MechCoder] Made the following changes 1. Calculate norms during initialization. 2. Use Blas calls from linalg.blas
1350cf3 [MechCoder] [SPARK-6065] Optimize word2vec.findSynonynms using blas calls
MechCoder authored and nemccarthy committed Jun 19, 2015
1 parent b30d4ec commit 2979450
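
As a rough illustration of the two points in the commit message, here is a minimal, self-contained Scala sketch of the same pattern, assuming the netlib-java BLAS class (the `blas` alias used inside Word2Vec.scala) is on the classpath; the object name, the toy vectors, and the query are made up for illustration and are not part of the patch:

import com.github.fommil.netlib.BLAS.{getInstance => blas}

object SynonymScoringSketch {
  // Four toy words with 3-dimensional vectors, flattened so that word i occupies
  // the slice [i * vectorSize, (i + 1) * vectorSize) -- the layout the patch uses.
  val vectorSize = 3
  val words = Array("north", "south", "east", "west")
  val numWords = words.length
  val wordVectors: Array[Float] =
    Array(1f, 0f, 0f,  0f, 1f, 0f,  0.7f, 0.7f, 0f,  0f, 0f, 1f)

  // Point 2: the per-word L2 norms are computed once, not once per query.
  val wordVecNorms: Array[Double] = Array.tabulate(numWords) { i =>
    blas.snrm2(vectorSize, wordVectors.slice(i * vectorSize, (i + 1) * vectorSize), 1).toDouble
  }

  // Point 1: a single sgemv computes the dot product of the query against every
  // word vector at once (wordVectors viewed as a vectorSize x numWords matrix).
  def scores(query: Array[Float]): Array[Double] = {
    val dots = new Array[Float](numWords)
    blas.sgemv("T", vectorSize, numWords, 1f, wordVectors, vectorSize, query, 1, 0f, dots, 1)
    // The query's own norm scales every score equally, so it can be skipped when ranking.
    Array.tabulate(numWords)(i => dots(i) / wordVecNorms(i))
  }

  def main(args: Array[String]): Unit = {
    words.zip(scores(Array(1f, 1f, 0f))).sortBy(-_._2).take(2).foreach(println)
  }
}

The diff below applies the same idea inside Word2VecModel, with the flattened array and the norms built once when the model is constructed.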
Showing 1 changed file with 51 additions and 6 deletions.
mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector}
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd._
 import org.apache.spark.util.Utils
@@ -429,7 +429,36 @@ class Word2Vec extends Serializable with Logging {
  */
 @Experimental
 class Word2VecModel private[mllib] (
-    private val model: Map[String, Array[Float]]) extends Serializable with Saveable {
+    model: Map[String, Array[Float]]) extends Serializable with Saveable {
+
+  // wordList: Ordered list of words obtained from model.
+  private val wordList: Array[String] = model.keys.toArray
+
+  // wordIndex: Maps each word to an index, which can retrieve the corresponding
+  // vector from wordVectors (see below).
+  private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
+
+  // vectorSize: Dimension of each word's vector.
+  private val vectorSize = model.head._2.size
+  private val numWords = wordIndex.size
+
+  // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word
+  // mapped with index i can be retrieved by the slice
+  // (ind * vectorSize, ind * vectorSize + vectorSize)
+  // wordVecNorms: Array of length numWords, each value being the Euclidean norm
+  // of the wordVector.
+  private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
+    val wordVectors = new Array[Float](vectorSize * numWords)
+    val wordVecNorms = new Array[Double](numWords)
+    var i = 0
+    while (i < numWords) {
+      val vec = model.get(wordList(i)).get
+      Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize)
+      wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
+      i += 1
+    }
+    (wordVectors, wordVecNorms)
+  }
 
   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
@@ -443,7 +472,7 @@ class Word2VecModel private[mllib] (
   override protected def formatVersion = "1.0"
 
   def save(sc: SparkContext, path: String): Unit = {
-    Word2VecModel.SaveLoadV1_0.save(sc, path, model)
+    Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors)
   }
 
   /**
@@ -479,9 +508,23 @@
    */
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
     // TODO: optimize top-k
     val fVector = vector.toArray.map(_.toFloat)
-    model.mapValues(vec => cosineSimilarity(fVector, vec))
+    val cosineVec = Array.fill[Float](numWords)(0)
+    val alpha: Float = 1
+    val beta: Float = 0
+
+    blas.sgemv(
+      "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
+
+    // Need not divide with the norm of the given vector since it is constant.
+    val updatedCosines = new Array[Double](numWords)
+    var ind = 0
+    while (ind < numWords) {
+      updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind)
+      ind += 1
+    }
+    wordList.zip(updatedCosines)
       .toSeq
       .sortBy(- _._2)
       .take(num + 1)
@@ -493,7 +536,9 @@ class Word2VecModel private[mllib] (
    * Returns a map of words to their vector representations.
    */
   def getVectors: Map[String, Array[Float]] = {
-    model
+    wordIndex.map { case (word, ind) =>
+      (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
+    }
   }
 }

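For completeness, a hedged sketch of how the optimized path is reached through the public API, e.g. from spark-shell; an existing SparkContext `sc` is assumed, and the corpus path, vector size, and query word are placeholders rather than values from this commit:

import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.rdd.RDD

// Placeholder corpus: one Seq[String] of tokens per line of text.
val corpus: RDD[Seq[String]] = sc.textFile("data/corpus.txt").map(_.split(" ").toSeq)

val model: Word2VecModel = new Word2Vec()
  .setVectorSize(100)
  .fit(corpus)

// findSynonyms(String, Int) transforms the word and delegates to findSynonyms(Vector, Int),
// so both overloads go through the sgemv-based scoring touched by this commit.
val byWord: Array[(String, Double)] = model.findSynonyms("day", 5)
val byVector: Array[(String, Double)] = model.findSynonyms(model.transform("day"), 5)

byWord.foreach { case (word, score) => println(s"$word\t$score") }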