-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-5990] [MLLIB] Model import/export for IsotonicRegression #5270
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,9 +23,16 @@ import java.util.Arrays.binarySearch | |
|
||
import scala.collection.mutable.ArrayBuffer | ||
|
||
import org.json4s._ | ||
import org.json4s.JsonDSL._ | ||
import org.json4s.jackson.JsonMethods._ | ||
|
||
import org.apache.spark.annotation.Experimental | ||
import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} | ||
import org.apache.spark.mllib.util.{Loader, Saveable} | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.SparkContext | ||
import org.apache.spark.sql.{DataFrame, SQLContext} | ||
|
||
/** | ||
* :: Experimental :: | ||
|
@@ -42,7 +49,7 @@ import org.apache.spark.rdd.RDD | |
class IsotonicRegressionModel ( | ||
val boundaries: Array[Double], | ||
val predictions: Array[Double], | ||
val isotonic: Boolean) extends Serializable { | ||
val isotonic: Boolean) extends Serializable with Saveable { | ||
|
||
// Ordering selected by `isotonic`: the natural Double ordering when true, its reverse otherwise.
// NOTE(review): presumably used to compare predictions; the usage is not visible in this hunk.
private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse | ||
|
||
|
@@ -124,6 +131,74 @@ class IsotonicRegressionModel ( | |
predictions(foundIndex) | ||
} | ||
} | ||
|
||
/**
 * Saves this model under `path` by handing its boundaries, predictions and isotonic flag
 * to the version 1.0 writer in the companion object.
 *
 * @param sc   Spark context used to write the output.
 * @param path Destination passed through to the writer.
 */
override def save(sc: SparkContext, path: String): Unit =
  IsotonicRegressionModel.SaveLoadV1_0.save(
    sc,
    path,
    boundaries = boundaries,
    predictions = predictions,
    isotonic = isotonic)
|
||
/** Format version string reported by this model: "1.0", the same literal as `SaveLoadV1_0.thisFormatVersion`. */
override protected def formatVersion: String = "1.0" | ||
} | ||
|
||
/**
 * Companion object of [[IsotonicRegressionModel]] providing model import, plus the
 * version 1.0 save/load implementation used by the instance-level `save`.
 */
object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {

  import org.apache.spark.mllib.util.Loader._

  /** Save/load implementation for format version "1.0". */
  private object SaveLoadV1_0 {

    def thisFormatVersion: String = "1.0"

    /** Hard-code class name string in case it changes in the future */
    def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"

    /** Model data for model import/export */
    case class Data(boundary: Double, prediction: Double)

    /**
     * Writes the model to `path`.
     *
     * The metadata (class name, format version and the isotonic flag) is rendered as a
     * compact JSON string and written as a single-partition text file at
     * `metadataPath(path)`. Each (boundary, prediction) pair becomes one [[Data]] row,
     * saved as a Parquet file at `dataPath(path)`.
     *
     * @param sc          Spark context used for writing.
     * @param path        Base path of the saved model.
     * @param boundaries  Boundary values, paired positionally with `predictions`.
     * @param predictions Prediction values, paired positionally with `boundaries`.
     * @param isotonic    Flag stored in the metadata.
     */
    def save(
        sc: SparkContext,
        path: String,
        boundaries: Array[Double],
        predictions: Array[Double],
        isotonic: Boolean): Unit = {
      // No `import sqlContext.implicits._` here: createDataFrame is invoked explicitly,
      // so no implicit conversion is needed.
      val sqlContext = new SQLContext(sc)

      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
          ("isotonic" -> isotonic)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

      sqlContext.createDataFrame(
        boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) }
      ).saveAsParquetFile(dataPath(path))
    }

    /**
     * Reads the Parquet data at `dataPath(path)`, checks it against the [[Data]] schema,
     * and returns the boundaries and predictions sorted by ascending boundary.
     *
     * @param sc   Spark context used for reading.
     * @param path Base path of the saved model.
     * @return `(boundaries, predictions)` as two arrays, paired positionally.
     */
    def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
      val sqlContext = new SQLContext(sc)
      val dataRDD = sqlContext.parquetFile(dataPath(path))

      checkSchema[Data](dataRDD.schema)
      val dataArray = dataRDD.select("boundary", "prediction").collect()
      // Sort by boundary after collecting, so the arrays come back in ascending boundary order.
      val (boundaries, predictions) = dataArray.map { x =>
        (x.getDouble(0), x.getDouble(1))
      }.toList.sortBy(_._1).unzip
      (boundaries.toArray, predictions.toArray)
    }
  }

  /**
   * Loads an [[IsotonicRegressionModel]] from `path`.
   *
   * The stored class name and format version are validated first; only then is the
   * `isotonic` field read from the metadata. This way, metadata written by a different
   * model type produces the descriptive "did not recognize model" error instead of a
   * JSON extraction failure on the missing field.
   *
   * @param sc   Spark context used for reading.
   * @param path Base path of the saved model.
   * @return The reconstructed model.
   * @throws Exception if the (class name, format version) pair is not supported.
   */
  override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
    implicit val formats = DefaultFormats
    val (loadedClassName, version, metadata) = loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        val isotonic = (metadata \ "isotonic").extract[Boolean]
        val (boundaries, predictions) = SaveLoadV1_0.load(sc, path)
        new IsotonicRegressionModel(boundaries, predictions, isotonic)
      case _ => throw new Exception(
        s"IsotonicRegressionModel.load did not recognize model with (className, format version):" +
        s"($loadedClassName, $version). Supported:\n" +
        s" ($classNameV1_0, 1.0)"
      )
    }
  }
}
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ import org.scalatest.{Matchers, FunSuite} | |
|
||
import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
import org.apache.spark.mllib.util.TestingUtils._ | ||
import org.apache.spark.util.Utils | ||
|
||
class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { | ||
|
||
|
@@ -73,6 +74,26 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M | |
assert(model.isotonic) | ||
} | ||
|
||
// Round-trips a hand-built model through save/load and checks that every field survives.
test("model save/load") {
  val boundaries = Array(0.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)
  val predictions = Array(1.0, 2.0, 2.0, 6.0, 16.5, 16.5, 17.0, 18.0)
  val model = new IsotonicRegressionModel(boundaries, predictions, isotonic = true)

  val tempDir = Utils.createTempDir()
  val path = tempDir.toURI.toString

  // Save model, load it back, and compare.
  try {
    model.save(sc, path)
    val sameModel = IsotonicRegressionModel.load(sc, path)
    assert(model.boundaries === sameModel.boundaries)
    assert(model.predictions === sameModel.predictions)
    // Compare against the loaded model; comparing `model.isotonic` with itself always passes.
    assert(model.isotonic === sameModel.isotonic)
  } finally {
    // Always clean up the temporary directory, even when an assertion fails.
    Utils.deleteRecursively(tempDir)
  }
}
|
||
test("isotonic regression with size 0") { | ||
val model = runIsotonicRegression(Seq(), true) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this line because no implicits are used.