Skip to content

Commit

Permalink
modified NB model type to be more Java-friendly
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Mar 3, 2015
1 parent b61b5e2 commit 3730572
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.mllib;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

import java.util.regex.Pattern;

/**
 * Placeholder Java example showing that the Naive Bayes model type can be
 * selected from Java via {@code NaiveBayes.Bernoulli()}.
 *
 * The program reads a text file, parses each line into a LabeledPoint, builds
 * a NaiveBayes instance configured with the Bernoulli model type, and stops.
 * It does not train a model.
 */
public final class JavaNaiveBayes {

  /**
   * Parses one input line of the form {@code <label>,<f1> <f2> ... <fn>}
   * (label and features separated by a comma, features separated by single
   * spaces) into a LabeledPoint with a dense feature vector.
   */
  static class ParsePoint implements Function<String, LabeledPoint> {
    private static final Pattern COMMA = Pattern.compile(",");
    private static final Pattern SPACE = Pattern.compile(" ");

    /**
     * @param line one line of input text
     * @return the labeled point described by the line
     * @throws IllegalArgumentException if the line has no comma-separated feature part
     * @throws NumberFormatException if the label or a feature is not a valid double
     */
    @Override
    public LabeledPoint call(String line) {
      String[] parts = COMMA.split(line);
      // Fail with a message naming the offending line instead of a bare
      // ArrayIndexOutOfBoundsException when the feature part is missing.
      if (parts.length < 2) {
        throw new IllegalArgumentException(
            "Expected '<label>,<features>' but got: " + line);
      }
      double y = Double.parseDouble(parts[0]);
      String[] tok = SPACE.split(parts[1]);
      double[] x = new double[tok.length];
      for (int i = 0; i < tok.length; ++i) {
        x[i] = Double.parseDouble(tok[i]);
      }
      return new LabeledPoint(y, Vectors.dense(x));
    }
  }

  /**
   * Entry point. Expects exactly three arguments:
   * {@code <input_dir> <step_size> <niters>}.
   *
   * @param args command-line arguments
   */
  public static void main(String[] args) {
    if (args.length != 3) {
      System.err.println("Usage: JavaNaiveBayes <input_dir> <step_size> <niters>");
      System.exit(1);
    }
    // Parse the numeric arguments before creating the Spark context so that
    // malformed input fails fast without leaving a context running.
    // NOTE(review): both values are parsed but not used by anything below.
    double stepSize = Double.parseDouble(args[1]);
    int iterations = Integer.parseInt(args[2]);

    SparkConf sparkConf = new SparkConf().setAppName("JavaNaiveBayes");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    try {
      JavaRDD<String> lines = sc.textFile(args[0]);
      JavaRDD<LabeledPoint> points = lines.map(new ParsePoint()).cache();

      // Example which compiles. (Don't actually include!)
      NaiveBayes nb = new NaiveBayes();
      nb.setModelType(NaiveBayes.Bernoulli());
    } finally {
      // Always release the context, even if something above throws.
      sc.stop();
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ object SparseNaiveBayes {

println(s"numTraining = $numTraining, numTest = $numTest.")

// Example which compiles. (Don't actually include!)
val nb = new NaiveBayes()
nb.setModelType(NaiveBayes.Bernoulli)

val model = new NaiveBayes().setLambda(params.lambda).run(training)

val prediction = model.predict(test.map(_.features))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,11 @@ import org.json4s.{DefaultFormats, JValue}
import org.apache.spark.{Logging, SparkContext, SparkException}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}


/**
 * Enumeration of the Naive Bayes model variants: `Multinomial` and `Bernoulli`.
 */
object NaiveBayesModels extends Enumeration {
// Alias so the enumeration's value type can be written as NaiveBayesModels.
type NaiveBayesModels = Value
val Multinomial, Bernoulli = Value

// Implicit conversion from a model value to a String, obtained by calling toString on it.
implicit def toString(model: NaiveBayesModels): String = {
model.toString
}
}

/**
* Model for Naive Bayes Classifiers.
*
Expand All @@ -60,17 +47,18 @@ class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]],
val modelType: NaiveBayesModels) extends ClassificationModel with Serializable with Saveable {
val modelType: NaiveBayes.ModelType)
extends ClassificationModel with Serializable with Saveable {

def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
this(labels, pi, theta, NaiveBayesModels.Multinomial)
this(labels, pi, theta, NaiveBayes.Multinomial)

private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t

private val brzNegTheta: Option[BDM[Double]] = modelType match {
case NaiveBayesModels.Multinomial => None
case NaiveBayesModels.Bernoulli =>
case NaiveBayes.Multinomial => None
case NaiveBayes.Bernoulli =>
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
Option(negTheta)
}
Expand All @@ -85,17 +73,17 @@ class NaiveBayesModel private[mllib] (

override def predict(testData: Vector): Double = {
modelType match {
case NaiveBayesModels.Multinomial =>
case NaiveBayes.Multinomial =>
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
case NaiveBayesModels.Bernoulli =>
case NaiveBayes.Bernoulli =>
labels (brzArgmax (brzPi +
(brzTheta - brzNegTheta.get) * testData.toBreeze +
brzSum(brzNegTheta.get, Axis._1)))
}
}

override def save(sc: SparkContext, path: String): Unit = {
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType)
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString)
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
}

Expand Down Expand Up @@ -147,15 +135,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
val modelType: NaiveBayesModels = NaiveBayesModels.withName(data.getAs[String](3))
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
new NaiveBayesModel(labels, pi, theta, modelType)
}
}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
def getModelType(metadata: JValue): NaiveBayesModels = {
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
implicit val formats = DefaultFormats
NaiveBayesModels.withName((metadata \ "modelType").extract[String])
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String])
}
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
Expand Down Expand Up @@ -191,12 +179,13 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
* document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
*/
class NaiveBayes private (private var lambda: Double,
var modelType: NaiveBayesModels) extends Serializable with Logging {
class NaiveBayes private (
private var lambda: Double,
var modelType: NaiveBayes.ModelType) extends Serializable with Logging {

def this(lambda: Double) = this(lambda, NaiveBayesModels.Multinomial)
def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)

def this() = this(1.0, NaiveBayesModels.Multinomial)
def this() = this(1.0, NaiveBayes.Multinomial)

/** Set the smoothing parameter. Default: 1.0. */
def setLambda(lambda: Double): NaiveBayes = {
Expand All @@ -205,7 +194,7 @@ class NaiveBayes private (private var lambda: Double,
}

/** Set the model type. Default: Multinomial. */
def setModelType(model: NaiveBayesModels): NaiveBayes = {
def setModelType(model: NaiveBayes.ModelType): NaiveBayes = {
this.modelType = model
this
}
Expand Down Expand Up @@ -262,8 +251,8 @@ class NaiveBayes private (private var lambda: Double,
labels(i) = label
pi(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = modelType match {
case NaiveBayesModels.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
case NaiveBayesModels.Bernoulli => math.log(n + 2.0 * lambda)
case NaiveBayes.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
case NaiveBayes.Bernoulli => math.log(n + 2.0 * lambda)
}
var j = 0
while (j < numFeatures) {
Expand Down Expand Up @@ -330,6 +319,32 @@ object NaiveBayes {
* Multinomial or Bernoulli
*/
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
  require(modelType != null, "modelType must not be null")
  // Honour the requested model type instead of always training a Multinomial model.
  // The name is lower-cased first so the capitalised spellings "Multinomial" and
  // "Bernoulli" are accepted as well as the lower-case names fromString matches on.
  val requestedType = ModelType.fromString(modelType.toLowerCase(java.util.Locale.ROOT))
  new NaiveBayes(lambda, requestedType).run(input)
}

/**
 * Base type of the Naive Bayes model-type values. It is sealed, so subclasses can only be
 * defined in this source file.
 */
sealed abstract class ModelType

/**
 * Canonical string names of the Naive Bayes model types, plus parsing of a name back into
 * its [[ModelType]] value.
 */
object MODELTYPE {
  final val MULTINOMIAL_STRING = "multinomial"
  final val BERNOULLI_STRING = "bernoulli"

  /**
   * Converts a model type name into the corresponding [[ModelType]].
   *
   * Matching ignores case, so the capitalised spellings "Multinomial" and "Bernoulli" are
   * accepted in addition to the canonical lower-case names.
   *
   * @param modelType name of the model type
   * @return `Multinomial` or `Bernoulli`
   * @throws IllegalArgumentException if the name is null or not a known model type
   */
  def fromString(modelType: String): ModelType =
    // Option(...) turns a null name into None so it falls through to the error case
    // rather than failing with a NullPointerException while lower-casing.
    Option(modelType).map(_.toLowerCase(java.util.Locale.ROOT)) match {
      case Some(MULTINOMIAL_STRING) => Multinomial
      case Some(BERNOULLI_STRING) => Bernoulli
      case _ =>
        throw new IllegalArgumentException(s"Cannot recognize NaiveBayes ModelType: $modelType")
    }
}

// Term-level alias for MODELTYPE, so its members can be reached as ModelType.xxx
// (for example ModelType.fromString).
final val ModelType = MODELTYPE

/** Multinomial model type value; its toString returns ModelType.MULTINOMIAL_STRING. */
final val Multinomial: ModelType = new ModelType {
override def toString: String = ModelType.MULTINOMIAL_STRING
}

/** Bernoulli model type value; its toString returns ModelType.BERNOULLI_STRING. */
final val Bernoulli: ModelType = new ModelType {
override def toString: String = ModelType.BERNOULLI_STRING
}

}

0 comments on commit 3730572

Please sign in to comment.