From 37305729334922c40804752598a30a2fb892c317 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 3 Mar 2015 15:22:20 -0800 Subject: [PATCH] modified NB model type to be more Java-friendly --- .../spark/examples/mllib/JavaNaiveBayes.java | 67 ++++++++++++++++ .../examples/mllib/SparseNaiveBayes.scala | 4 + .../mllib/classification/NaiveBayes.scala | 77 +++++++++++-------- 3 files changed, 117 insertions(+), 31 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayes.java diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayes.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayes.java new file mode 100644 index 0000000000000..952648f14da89 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayes.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.NaiveBayes; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; + +import java.util.regex.Pattern; + +public final class JavaNaiveBayes { + + static class ParsePoint implements Function { + private static final Pattern COMMA = Pattern.compile(","); + private static final Pattern SPACE = Pattern.compile(" "); + + @Override + public LabeledPoint call(String line) { + String[] parts = COMMA.split(line); + double y = Double.parseDouble(parts[0]); + String[] tok = SPACE.split(parts[1]); + double[] x = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + x[i] = Double.parseDouble(tok[i]); + } + return new LabeledPoint(y, Vectors.dense(x)); + } + } + + public static void main(String[] args) { + if (args.length != 3) { + System.err.println("Usage: JavaLR "); + System.exit(1); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaLR"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + JavaRDD lines = sc.textFile(args[0]); + JavaRDD points = lines.map(new ParsePoint()).cache(); + double stepSize = Double.parseDouble(args[1]); + int iterations = Integer.parseInt(args[2]); + + // Example which compiles. (Don't actually include!) + NaiveBayes nb = new NaiveBayes(); + nb.setModelType(NaiveBayes.Bernoulli()); + + sc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index f1ff4e6911f5e..9c157a6719084 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -89,6 +89,10 @@ object SparseNaiveBayes { println(s"numTraining = $numTraining, numTest = $numTest.") + // Example which compiles. (Don't actually include!) + val nb = new NaiveBayes() + nb.setModelType(NaiveBayes.Bernoulli) + val model = new NaiveBayes().setLambda(params.lambda).run(training) val prediction = model.predict(test.map(_.features)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 8f9418deb045a..0ebfd87889ea8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -27,24 +27,11 @@ import org.json4s.{DefaultFormats, JValue} import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} -/** - * - */ -object NaiveBayesModels extends Enumeration { - type NaiveBayesModels = Value - val Multinomial, Bernoulli = Value - - implicit def toString(model: NaiveBayesModels): String = { - model.toString - } -} - /** * Model for Naive Bayes Classifiers. * @@ -60,17 +47,18 @@ class NaiveBayesModel private[mllib] ( val labels: Array[Double], val pi: Array[Double], val theta: Array[Array[Double]], - val modelType: NaiveBayesModels) extends ClassificationModel with Serializable with Saveable { + val modelType: NaiveBayes.ModelType) + extends ClassificationModel with Serializable with Saveable { def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = - this(labels, pi, theta, NaiveBayesModels.Multinomial) + this(labels, pi, theta, NaiveBayes.Multinomial) private val brzPi = new BDV[Double](pi) private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t private val brzNegTheta: Option[BDM[Double]] = modelType match { - case NaiveBayesModels.Multinomial => None - case NaiveBayesModels.Bernoulli => + case NaiveBayes.Multinomial => None + case NaiveBayes.Bernoulli => val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) Option(negTheta) } @@ -85,9 +73,9 @@ class NaiveBayesModel private[mllib] ( override def predict(testData: Vector): Double = { modelType match { - case NaiveBayesModels.Multinomial => + case NaiveBayes.Multinomial => labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) - case NaiveBayesModels.Bernoulli => + case NaiveBayes.Bernoulli => labels (brzArgmax (brzPi + (brzTheta - brzNegTheta.get) * testData.toBreeze + brzSum(brzNegTheta.get, Axis._1))) @@ -95,7 +83,7 @@ class NaiveBayesModel private[mllib] ( } override def save(sc: SparkContext, path: String): Unit = { - val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType) + val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString) NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) } @@ -147,15 +135,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { val labels = data.getAs[Seq[Double]](0).toArray val pi = data.getAs[Seq[Double]](1).toArray val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray - val modelType: NaiveBayesModels = NaiveBayesModels.withName(data.getAs[String](3)) + val modelType = NaiveBayes.ModelType.fromString(data.getString(3)) new NaiveBayesModel(labels, pi, theta, modelType) } } override def load(sc: SparkContext, path: String): NaiveBayesModel = { - def getModelType(metadata: JValue): NaiveBayesModels = { + def getModelType(metadata: JValue): NaiveBayes.ModelType = { implicit val formats = DefaultFormats - NaiveBayesModels.withName((metadata \ "modelType").extract[String]) + NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String]) } val (loadedClassName, version, metadata) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -191,12 +179,13 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ -class NaiveBayes private (private var lambda: Double, - var modelType: NaiveBayesModels) extends Serializable with Logging { +class NaiveBayes private ( + private var lambda: Double, + var modelType: NaiveBayes.ModelType) extends Serializable with Logging { - def this(lambda: Double) = this(lambda, NaiveBayesModels.Multinomial) + def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) - def this() = this(1.0, NaiveBayesModels.Multinomial) + def this() = this(1.0, NaiveBayes.Multinomial) /** Set the smoothing parameter. Default: 1.0. */ def setLambda(lambda: Double): NaiveBayes = { @@ -205,7 +194,7 @@ class NaiveBayes private (private var lambda: Double, } /** Set the model type. Default: Multinomial. */ - def setModelType(model: NaiveBayesModels): NaiveBayes = { + def setModelType(model: NaiveBayes.ModelType): NaiveBayes = { this.modelType = model this } @@ -262,8 +251,8 @@ class NaiveBayes private (private var lambda: Double, labels(i) = label pi(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = modelType match { - case NaiveBayesModels.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) - case NaiveBayesModels.Bernoulli => math.log(n + 2.0 * lambda) + case NaiveBayes.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + case NaiveBayes.Bernoulli => math.log(n + 2.0 * lambda) } var j = 0 while (j < numFeatures) { @@ -330,6 +319,32 @@ object NaiveBayes { * Multinomial or Bernoulli */ def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { - new NaiveBayes(lambda, NaiveBayesModels.withName(modelType)).run(input) + new NaiveBayes(lambda, Multinomial).run(input) } + + sealed abstract class ModelType + + object MODELTYPE { + final val MULTINOMIAL_STRING = "multinomial" + final val BERNOULLI_STRING = "bernoulli" + + def fromString(modelType: String): ModelType = modelType match { + case MULTINOMIAL_STRING => Multinomial + case BERNOULLI_STRING => Bernoulli + case _ => + throw new IllegalArgumentException(s"Cannot recognize NaiveBayes ModelType: $modelType") + } + } + + final val ModelType = MODELTYPE + + final val Multinomial: ModelType = new ModelType { + override def toString: String = ModelType.MULTINOMIAL_STRING + } + + final val Bernoulli: ModelType = new ModelType { + override def toString: String = ModelType.BERNOULLI_STRING + } + } +