Skip to content

Commit

Permalink
[SPARK-5990] [MLLIB] Model import/export for IsotonicRegression
Browse files Browse the repository at this point in the history
Model import/export for IsotonicRegression

Author: Yanbo Liang <[email protected]>

Closes apache#5270 from yanboliang/spark-5990 and squashes the following commits:

872028d [Yanbo Liang] fix code style
f80ec1b [Yanbo Liang] address comments
49600cc [Yanbo Liang] address comments
429ff7d [Yanbo Liang] store each interval as a record
2b2f5a1 [Yanbo Liang] Model import/export for IsotonicRegression
  • Loading branch information
yanboliang authored and mengxr committed Apr 21, 2015
1 parent ab9128f commit 1f2f723
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@ import java.util.Arrays.binarySearch

import scala.collection.mutable.ArrayBuffer

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}

/**
* :: Experimental ::
Expand All @@ -42,7 +49,7 @@ import org.apache.spark.rdd.RDD
class IsotonicRegressionModel (
val boundaries: Array[Double],
val predictions: Array[Double],
val isotonic: Boolean) extends Serializable {
val isotonic: Boolean) extends Serializable with Saveable {

private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse

Expand Down Expand Up @@ -124,6 +131,75 @@ class IsotonicRegressionModel (
predictions(foundIndex)
}
}

override def save(sc: SparkContext, path: String): Unit = {
IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic)
}

override protected def formatVersion: String = "1.0"
}

object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {

import org.apache.spark.mllib.util.Loader._

private object SaveLoadV1_0 {

def thisFormatVersion: String = "1.0"

/** Hard-code class name string in case it changes in the future */
def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"

/** Model data for model import/export */
case class Data(boundary: Double, prediction: Double)

def save(
sc: SparkContext,
path: String,
boundaries: Array[Double],
predictions: Array[Double],
isotonic: Boolean): Unit = {
val sqlContext = new SQLContext(sc)

val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("isotonic" -> isotonic)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

sqlContext.createDataFrame(
boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) }
).saveAsParquetFile(dataPath(path))
}

def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
val sqlContext = new SQLContext(sc)
val dataRDD = sqlContext.parquetFile(dataPath(path))

checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("boundary", "prediction").collect()
val (boundaries, predictions) = dataArray.map { x =>
(x.getDouble(0), x.getDouble(1))
}.toList.sortBy(_._1).unzip
(boundaries.toArray, predictions.toArray)
}
}

override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
implicit val formats = DefaultFormats
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
val isotonic = (metadata \ "isotonic").extract[Boolean]
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val (boundaries, predictions) = SaveLoadV1_0.load(sc, path)
new IsotonicRegressionModel(boundaries, predictions, isotonic)
case _ => throw new Exception(
s"IsotonicRegressionModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)"
)
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.scalatest.{Matchers, FunSuite}

import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils

class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {

Expand Down Expand Up @@ -73,6 +74,26 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M
assert(model.isotonic)
}

test("model save/load") {
val boundaries = Array(0.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)
val predictions = Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)
val model = new IsotonicRegressionModel(boundaries, predictions, true)

val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString

// Save model, load it back, and compare.
try {
model.save(sc, path)
val sameModel = IsotonicRegressionModel.load(sc, path)
assert(model.boundaries === sameModel.boundaries)
assert(model.predictions === sameModel.predictions)
assert(model.isotonic === model.isotonic)
} finally {
Utils.deleteRecursively(tempDir)
}
}

test("isotonic regression with size 0") {
val model = runIsotonicRegression(Seq(), true)

Expand Down

0 comments on commit 1f2f723

Please sign in to comment.