diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index fb6567ac2e431..e956185319a69 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -82,6 +82,8 @@ class NaiveBayes private (var lambda: Double) extends Serializable with Logging */ def run(data: RDD[LabeledPoint]) = { // Aggregates term frequencies per label. + // TODO: Calling combineByKey and collect creates two stages, we can implement something + // TODO: similar to reduceByKeyLocally to save one stage. val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector), mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze), @@ -89,23 +91,23 @@ class NaiveBayes private (var lambda: Double) extends Serializable with Logging (c1._1 + c2._1, c1._2 += c2._2) ).collect() val numLabels = aggregated.length - var numExamples = 0L + var numDocuments = 0L aggregated.foreach { case (_, (n, _)) => - numExamples += n + numDocuments += n } val numFeatures = aggregated.head match { case (_, (_, v)) => v.size } val labels = new Array[Double](numLabels) val pi = new Array[Double](numLabels) val theta = Array.fill(numLabels)(new Array[Double](numFeatures)) - val piLogDenom = math.log(numExamples + numLabels * lambda) + val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 - aggregated.foreach { case (label, (n, sum)) => + aggregated.foreach { case (label, (n, sumTermFreqs)) => labels(i) = label - val thetaLogDenom = math.log(brzSum(sum) + numFeatures * lambda) + val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda) pi(i) = math.log(n + lambda) - piLogDenom var j = 0 while (j < numFeatures) { - theta(i)(j) = math.log(sum(j) + lambda) - thetaLogDenom + theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom j += 1 } i += 1