From 0f8759b3a8bd6d795e914e43224e7f0594c8f7f9 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Sun, 30 Mar 2014 14:31:34 -0700
Subject: [PATCH] minor updates to NB

---
 .../apache/spark/mllib/classification/NaiveBayes.scala | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index f4228fe5e7522..924ab43f26e06 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -20,8 +20,10 @@ package org.apache.spark.mllib.classification
 import scala.collection.mutable
 
 import org.jblas.DoubleMatrix
+import breeze.linalg.{Vector => BV}
 
 import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.util.MLUtils
@@ -76,7 +78,13 @@ class NaiveBayes private (var lambda: Double)
    * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    */
   def run(data: RDD[LabeledPoint]) = {
-    runRaw(data.map(v => (v.label, v.features.toArray)))
+    val agg = data.map(p => (p.label, p.features)).combineByKey[(Long, BV[Double])](
+      createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector),
+      mergeValue = (c: (Long, BV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze),
+      mergeCombiners = (c1: (Long, BV[Double]), c2: (Long, BV[Double])) =>
+        (c1._1 + c2._1, c1._2 += c2._2)
+    ).collect()
+    val numLabels = agg.size
   }
 
   /**