Skip to content

Commit

Permalink
Explicit types and indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
MechCoder committed Apr 17, 2015
1 parent 64575b0 commit da1642d
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -431,16 +431,23 @@ class Word2Vec extends Serializable with Logging {
class Word2VecModel private[mllib] (
model: Map[String, Array[Float]]) extends Serializable with Saveable {

val indexedModel = model.keys.zip(0 until model.size).toMap
// Maintain a ordered list of words based on the index in the initial model.
private val wordList: Array[String] = model.keys.toArray
private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap

private val (wordVectors, wordVecNorms) = {
private val (wordVectors: DenseMatrix, wordVecNorms: Array[Double]) = {
val numDim = model.head._2.size
val numWords = indexedModel.size
val numWords = wordIndex.size
val flatVec = model.toSeq.flatMap { case(w, v) =>
v.map(_.toDouble)}.toArray
val wordVectors = new DenseMatrix(numWords, numDim, flatVec, isTransposed=true)
val wordVecNorms = model.map { case (word, vec) =>
blas.snrm2(numDim, vec, 1)}.toArray
val wordVecNorms = new Array[Double](numWords)
var i = 0
while (i < numWords) {
val vec = model.get(wordList(i)).get
wordVecNorms(i) = blas.snrm2(numDim, vec, 1)
i += 1
}
(wordVectors, wordVecNorms)
}

Expand Down Expand Up @@ -495,13 +502,16 @@ class Word2VecModel private[mllib] (

val numWords = wordVectors.numRows
val cosineVec = Vectors.zeros(numWords).asInstanceOf[DenseVector]
BLAS.gemv(1.0, wordVectors, vector.asInstanceOf[DenseVector], 0.0, cosineVec)
BLAS.gemv(1.0, wordVectors, new DenseVector(vector.toArray), 0.0, cosineVec)

// Need not divide with the norm of the given vector since it is constant.
val updatedCosines = indexedModel.map { case (_, ind) =>
cosineVec(ind) / wordVecNorms(ind) }

indexedModel.keys.zip(updatedCosines)
val updatedCosines = new Array[Double](numWords)
var ind = 0
while (ind < numWords) {
updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind)
ind += 1
}
wordList.zip(updatedCosines)
.toSeq
.sortBy(- _._2)
.take(num + 1)
Expand All @@ -514,7 +524,7 @@ class Word2VecModel private[mllib] (
*/
def getVectors: Map[String, Array[Float]] = {
val numDim = wordVectors.numCols
indexedModel.map { case (word, ind) =>
wordIndex.map { case (word, ind) =>
val startInd = numDim * ind
val endInd = startInd + numDim
(word, wordVectors.values.slice(startInd, endInd).map(_.toFloat)) }
Expand Down

0 comments on commit da1642d

Please sign in to comment.