diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index cb864caca47af..4e1c6a63fb01c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -25,7 +25,6 @@ import org.json4s.jackson.JsonMethods._ import org.json4s.{DefaultFormats, JValue} import org.apache.spark.{Logging, SparkContext, SparkException} -import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} @@ -33,21 +32,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} -/** - * Model types supported in Naive Bayes: - * multinomial and Bernoulli currently supported - */ -object NaiveBayesModels extends Enumeration { - type NaiveBayesModels = Value - val Multinomial, Bernoulli = Value - - implicit def toString(model: NaiveBayesModels): String = { - model.toString - } -} - - - /** * Model for Naive Bayes Classifiers. * @@ -62,20 +46,21 @@ class NaiveBayesModel private[mllib] ( val labels: Array[Double], val pi: Array[Double], val theta: Array[Array[Double]], - val modelType: NaiveBayesModels) extends ClassificationModel with Serializable with Saveable { + val modelType: NaiveBayes.ModelType) + extends ClassificationModel with Serializable with Saveable { def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = - this(labels, pi, theta, NaiveBayesModels.Multinomial) + this(labels, pi, theta, NaiveBayes.Multinomial) private val brzPi = new BDV[Double](pi) private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t //Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0 - //precomputing log(1.0 - exp(theta)) and its sum for linear algebra application + //this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application //of this condition in predict function private val (brzNegTheta, brzNegThetaSum) = modelType match { - case NaiveBayesModels.Multinomial => (None, None) - case NaiveBayesModels.Bernoulli => + case NaiveBayes.Multinomial => (None, None) + case NaiveBayes.Bernoulli => val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) (Option(negTheta), Option(brzSum(brzNegTheta, Axis._1))) } @@ -90,16 +75,16 @@ class NaiveBayesModel private[mllib] ( override def predict(testData: Vector): Double = { modelType match { - case NaiveBayesModels.Multinomial => + case NaiveBayes.Multinomial => labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) - case NaiveBayesModels.Bernoulli => + case NaiveBayes.Bernoulli => labels (brzArgmax (brzPi + (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) } } override def save(sc: SparkContext, path: String): Unit = { - val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType) + val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString) NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) } @@ -152,15 +137,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { val labels = data.getAs[Seq[Double]](0).toArray val pi = data.getAs[Seq[Double]](1).toArray val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray - val modelType: NaiveBayesModels = NaiveBayesModels.withName(data.getAs[String](3)) + val modelType = NaiveBayes.ModelType.fromString(data.getString(3)) new NaiveBayesModel(labels, pi, theta, modelType) } } override def load(sc: SparkContext, path: String): NaiveBayesModel = { - def getModelType(metadata: JValue): NaiveBayesModels = { + def getModelType(metadata: JValue): NaiveBayes.ModelType = { implicit val formats = DefaultFormats - NaiveBayesModels.withName((metadata \ "modelType").extract[String]) + NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String]) } val (loadedClassName, version, metadata) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -196,12 +181,14 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ -class NaiveBayes private (private var lambda: Double, - private var modelType: NaiveBayesModels) extends Serializable with Logging { - def this(lambda: Double) = this(lambda, NaiveBayesModels.Multinomial) +class NaiveBayes private ( + private var lambda: Double, + private var modelType: NaiveBayes.ModelType) extends Serializable with Logging { + + def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) - def this() = this(1.0, NaiveBayesModels.Multinomial) + def this() = this(1.0, NaiveBayes.Multinomial) /** Set the smoothing parameter. Default: 1.0. */ def setLambda(lambda: Double): NaiveBayes = { @@ -210,7 +197,7 @@ class NaiveBayes private (private var lambda: Double, } /** Set the model type. Default: Multinomial. */ - def setModelType(model: NaiveBayesModels): NaiveBayes = { + def setModelType(model: NaiveBayes.ModelType): NaiveBayes = { this.modelType = model this } @@ -267,8 +254,8 @@ class NaiveBayes private (private var lambda: Double, labels(i) = label pi(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = modelType match { - case NaiveBayesModels.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) - case NaiveBayesModels.Bernoulli => math.log(n + 2.0 * lambda) + case NaiveBayes.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + case NaiveBayes.Bernoulli => math.log(n + 2.0 * lambda) } var j = 0 while (j < numFeatures) { @@ -337,6 +324,37 @@ object NaiveBayes { * Multinomial or Bernoulli */ def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { - new NaiveBayes(lambda, NaiveBayesModels.withName(modelType)).run(input) + new NaiveBayes(lambda, Multinomial).run(input) + } + + + /** + * Model types supported in Naive Bayes: + * multinomial and Bernoulli currently supported + */ + sealed abstract class ModelType + + object MODELTYPE { + final val MULTINOMIAL_STRING = "multinomial" + final val BERNOULLI_STRING = "bernoulli" + + def fromString(modelType: String): ModelType = modelType match { + case MULTINOMIAL_STRING => Multinomial + case BERNOULLI_STRING => Bernoulli + case _ => + throw new IllegalArgumentException(s"Cannot recognize NaiveBayes ModelType: $modelType") + } + } + + final val ModelType = MODELTYPE + + final val Multinomial: ModelType = new ModelType { + override def toString: String = ModelType.MULTINOMIAL_STRING + } + + final val Bernoulli: ModelType = new ModelType { + override def toString: String = ModelType.BERNOULLI_STRING } + } +