Skip to content

Commit

Permalink
add a TODO to NB
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Mar 31, 2014
1 parent b9b7ef7 commit 7c1bc01
Showing 1 changed file with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,30 +82,32 @@ class NaiveBayes private (var lambda: Double) extends Serializable with Logging
*/
def run(data: RDD[LabeledPoint]) = {
// Aggregates term frequencies per label.
// TODO: Calling combineByKey and then collect creates two stages; we could implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector),
mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze),
mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
(c1._1 + c2._1, c1._2 += c2._2)
).collect()
val numLabels = aggregated.length
var numExamples = 0L
var numDocuments = 0L
aggregated.foreach { case (_, (n, _)) =>
numExamples += n
numDocuments += n
}
val numFeatures = aggregated.head match { case (_, (_, v)) => v.size }
val labels = new Array[Double](numLabels)
val pi = new Array[Double](numLabels)
val theta = Array.fill(numLabels)(new Array[Double](numFeatures))
val piLogDenom = math.log(numExamples + numLabels * lambda)
val piLogDenom = math.log(numDocuments + numLabels * lambda)
var i = 0
aggregated.foreach { case (label, (n, sum)) =>
aggregated.foreach { case (label, (n, sumTermFreqs)) =>
labels(i) = label
val thetaLogDenom = math.log(brzSum(sum) + numFeatures * lambda)
val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
pi(i) = math.log(n + lambda) - piLogDenom
var j = 0
while (j < numFeatures) {
theta(i)(j) = math.log(sum(j) + lambda) - thetaLogDenom
theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
j += 1
}
i += 1
Expand Down

0 comments on commit 7c1bc01

Please sign in to comment.