Skip to content

Commit

Permalink
Added new model save/load format 2.0 for NaiveBayesModel after modelT…
Browse files Browse the repository at this point in the history
…ype parameter was added. Updated tests. Also updated ModelType enum-like type.
  • Loading branch information
jkbradley committed Mar 22, 2015
1 parent 852a727 commit 6a8f383
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

import NaiveBayes.ModelType.{Bernoulli, Multinomial}


/**
* Model for Naive Bayes Classifiers.
Expand All @@ -54,7 +56,7 @@ class NaiveBayesModel private[mllib] (
extends ClassificationModel with Serializable with Saveable {

private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
this(labels, pi, theta, NaiveBayes.Multinomial)
this(labels, pi, theta, Multinomial)

/** A Java-friendly constructor that takes three Iterable parameters. */
private[mllib] def this(
Expand All @@ -70,10 +72,13 @@ class NaiveBayesModel private[mllib] (
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
// application of this condition (in predict function).
private val (brzNegTheta, brzNegThetaSum) = modelType match {
case NaiveBayes.Multinomial => (None, None)
case NaiveBayes.Bernoulli =>
case Multinomial => (None, None)
case Bernoulli =>
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
(Option(negTheta), Option(brzSum(negTheta, Axis._1)))
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
}

override def predict(testData: RDD[Vector]): RDD[Double] = {
Expand All @@ -86,29 +91,32 @@ class NaiveBayesModel private[mllib] (

override def predict(testData: Vector): Double = {
modelType match {
case NaiveBayes.Multinomial =>
case Multinomial =>
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
case NaiveBayes.Bernoulli =>
case Bernoulli =>
labels (brzArgmax (brzPi +
(brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
}
}

override def save(sc: SparkContext, path: String): Unit = {
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString)
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType.toString)
NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
}

override protected def formatVersion: String = "1.0"
override protected def formatVersion: String = "2.0"
}

object NaiveBayesModel extends Loader[NaiveBayesModel] {

import org.apache.spark.mllib.util.Loader._

private object SaveLoadV1_0 {
private[mllib] object SaveLoadV2_0 {

def thisFormatVersion: String = "1.0"
def thisFormatVersion: String = "2.0"

/** Hard-code class name string in case it changes in the future */
def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"
Expand All @@ -127,8 +135,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
// Create JSON metadata.
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length) ~
("modelType" -> data.modelType)))
("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

// Create Parquet data.
Expand All @@ -151,36 +158,82 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
new NaiveBayesModel(labels, pi, theta, modelType)
}

}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
implicit val formats = DefaultFormats
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String])
private[mllib] object SaveLoadV1_0 {

def thisFormatVersion: String = "1.0"

/** Hard-code class name string in case it changes in the future */
def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"

/** Model data for model import/export */
case class Data(
labels: Array[Double],
pi: Array[Double],
theta: Array[Array[Double]])

def save(sc: SparkContext, path: String, data: Data): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Create JSON metadata.
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

// Create Parquet data.
val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
dataRDD.saveAsParquetFile(dataPath(path))
}

def load(sc: SparkContext, path: String): NaiveBayesModel = {
val sqlContext = new SQLContext(sc)
// Load Parquet data.
val dataRDD = sqlContext.parquetFile(dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
val data = dataArray(0)
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
new NaiveBayesModel(labels, pi, theta)
}
}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
val classNameV2_0 = SaveLoadV2_0.thisClassName
val (model, numFeatures, numClasses) = (loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val model = SaveLoadV1_0.load(sc, path)
assert(model.pi.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class priors vector pi had ${model.pi.size} elements")
assert(model.theta.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class conditionals array theta had ${model.theta.size} elements")
assert(model.theta.forall(_.size == numFeatures),
s"NaiveBayesModel.load expected $numFeatures features," +
s" but class conditionals array theta had elements of size:" +
s" ${model.theta.map(_.size).mkString(",")}")
assert(model.modelType == getModelType(metadata))
model
(model, numFeatures, numClasses)
case (className, "2.0") if className == classNameV2_0 =>
val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val model = SaveLoadV2_0.load(sc, path)
(model, numFeatures, numClasses)
case _ => throw new Exception(
s"NaiveBayesModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
assert(model.pi.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class priors vector pi had ${model.pi.size} elements")
assert(model.theta.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class conditionals array theta had ${model.theta.size} elements")
assert(model.theta.forall(_.size == numFeatures),
s"NaiveBayesModel.load expected $numFeatures features," +
s" but class conditionals array theta had elements of size:" +
s" ${model.theta.map(_.size).mkString(",")}")
model
}
}

Expand All @@ -197,9 +250,9 @@ class NaiveBayes private (
private var lambda: Double,
private var modelType: NaiveBayes.ModelType) extends Serializable with Logging {

def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
def this(lambda: Double) = this(lambda, Multinomial)

def this() = this(1.0, NaiveBayes.Multinomial)
def this() = this(1.0, Multinomial)

/** Set the smoothing parameter. Default: 1.0. */
def setLambda(lambda: Double): NaiveBayes = {
Expand All @@ -210,9 +263,22 @@ class NaiveBayes private (
/** Get the smoothing parameter. */
def getLambda: Double = lambda

/** Set the model type. Default: Multinomial. */
def setModelType(model: NaiveBayes.ModelType): NaiveBayes = {
this.modelType = model
/**
* Set the model type using a string (case-insensitive).
* Supported options: "multinomial" and "bernoulli".
* (default: multinomial)
*/
def setModelType(modelType: String): NaiveBayes = {
setModelType(NaiveBayes.ModelType.fromString(modelType))
}

/**
* Set the model type.
* Supported options: [[NaiveBayes.ModelType.Bernoulli]], [[NaiveBayes.ModelType.Multinomial]]
* (default: Multinomial)
*/
def setModelType(modelType: NaiveBayes.ModelType): NaiveBayes = {
this.modelType = modelType
this
}

Expand Down Expand Up @@ -270,8 +336,11 @@ class NaiveBayes private (
labels(i) = label
pi(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = modelType match {
case NaiveBayes.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
case NaiveBayes.Bernoulli => math.log(n + 2.0 * lambda)
case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
case Bernoulli => math.log(n + 2.0 * lambda)
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
}
var j = 0
while (j < numFeatures) {
Expand Down Expand Up @@ -317,7 +386,7 @@ object NaiveBayes {
* @param lambda The smoothing parameter
*/
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
new NaiveBayes(lambda, NaiveBayes.Multinomial).run(input)
new NaiveBayes(lambda, NaiveBayes.ModelType.Multinomial).run(input)
}

/**
Expand All @@ -339,12 +408,45 @@ object NaiveBayes {
* multinomial or bernoulli
*/
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input)
new NaiveBayes(lambda, ModelType.fromString(modelType)).run(input)
}

/** Provides static methods for using ModelType. */
sealed abstract class ModelType extends Serializable

object ModelType extends Serializable {

/**
* Get the model type from a string.
* @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive)
*/
def fromString(modelType: String): ModelType = modelType.toLowerCase match {
case "multinomial" => Multinomial
case "bernoulli" => Bernoulli
case _ =>
throw new IllegalArgumentException(
s"NaiveBayes.ModelType.fromString did not recognize string: $modelType")
}

final val Multinomial: ModelType = {
case object Multinomial extends ModelType with Serializable {
override def toString: String = "multinomial"
}
Multinomial
}

final val Bernoulli: ModelType = {
case object Bernoulli extends ModelType with Serializable {
override def toString: String = "bernoulli"
}
Bernoulli
}
}

/** Java-friendly accessor for supported ModelType options */
final val modelTypes = ModelType

/*
object MODELTYPE extends Serializable{
final val MULTINOMIAL_STRING = "multinomial"
final val BERNOULLI_STRING = "bernoulli"
Expand All @@ -368,6 +470,6 @@ object NaiveBayes {
final val Bernoulli: ModelType = new ModelType {
override def toString: String = ModelType.BERNOULLI_STRING
}

*/
}

Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,22 @@

package org.apache.spark.mllib.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

public class JavaNaiveBayesSuite implements Serializable {
private transient JavaSparkContext sc;
Expand Down Expand Up @@ -102,4 +104,11 @@ public Vector call(LabeledPoint v) throws Exception {
// Should be able to get the first prediction.
predictions.first();
}

@Test
public void testModelTypeSetters() {
NaiveBayes nb = new NaiveBayes()
.setModelType(NaiveBayes.modelTypes().Bernoulli())
.setModelType(NaiveBayes.modelTypes().Multinomial());
}
}
Loading

0 comments on commit 6a8f383

Please sign in to comment.