Skip to content

Commit

Permalink
use weighted sum in combOp
Browse files Browse the repository at this point in the history
  • Loading branch information
Liquan Pei committed Aug 3, 2014
1 parent 7efbb6f commit 1a8fb41
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class Word2Vec(
private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha

private def learnVocab(words:RDD[String]) {
private def learnVocab(words:RDD[String]){
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.map(x => VocabWord(
Expand All @@ -110,6 +110,10 @@ class Word2Vec(
logInfo("trainWordsCount = " + trainWordsCount)
}

private def learnVocabPerPartition(words:RDD[String]) {

}

private def createExpTable(): Array[Double] = {
val expTable = new Array[Double](EXP_TABLE_SIZE)
var i = 0
Expand Down Expand Up @@ -303,8 +307,12 @@ class Word2Vec(
combOp = (c1, c2) => (c1, c2) match {
case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
val n = syn0_1.length
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
blas.dscal(n, weight1, syn0_1, 1)
blas.dscal(n, weight1, syn1_1, 1)
blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
})
syn0Global = aggSyn0
Expand Down

0 comments on commit 1a8fb41

Please sign in to comment.