From 13a3b636d9425c5713cd1381203ee1b60f71b8c8 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Tue, 5 Jan 2016 13:31:59 -0800
Subject: [PATCH] [SPARK-6724][MLLIB] Support model save/load for FPGrowthModel

Support model save/load for FPGrowthModel

Author: Yanbo Liang

Closes #9267 from yanboliang/spark-6724.
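A minimal round-trip sketch of the new save/load API, mirroring the new
FPGrowthSuite tests below (assumes an active SparkContext named sc; the
output path is illustrative):

    import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}

    val transactions = Seq(
      "r z h k p",
      "z y x w v u t s",
      "s x o n r")
      .map(_.split(" "))
    val rdd = sc.parallelize(transactions, 2)

    val model = new FPGrowth()
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(rdd)

    // save() writes JSON metadata to <path>/metadata and Parquet data to
    // <path>/data, and throws if the target directory already exists.
    model.save(sc, "/tmp/fpGrowthModel")

    // load() returns an FPGrowthModel[_]; the item type is recovered from
    // the saved Parquet schema.
    val sameModel = FPGrowthModel.load(sc, "/tmp/fpGrowthModel")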
---
 .../org/apache/spark/mllib/fpm/FPGrowth.scala | 100 +++++++++++++++++-
 .../spark/mllib/fpm/JavaFPGrowthSuite.java    |  40 +++++++
 .../spark/mllib/fpm/FPGrowthSuite.scala       |  68 ++++++++++++
 3 files changed, 205 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 70ef1ed30c71a..5273ed4d76650 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -17,19 +17,29 @@
 package org.apache.spark.mllib.fpm
 
-import java.{util => ju}
 import java.lang.{Iterable => JavaIterable}
+import java.{util => ju}
 
-import scala.collection.mutable
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe._
+
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
 
 import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.mllib.fpm.FPGrowth._
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 
 /**
  * Model trained by [[FPGrowth]], which holds frequent itemsets.
@@ -39,7 +49,8 @@ import org.apache.spark.storage.StorageLevel
  */
 @Since("1.3.0")
 class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
-    @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable {
+    @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]])
+  extends Saveable with Serializable {
   /**
    * Generates association rules for the [[Item]]s in [[freqItemsets]].
    * @param confidence minimal confidence of the rules produced
@@ -49,6 +60,89 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
     val associationRules = new AssociationRules(confidence)
     associationRules.run(freqItemsets)
   }
+
+  /**
+   * Save this model to the given path.
+   * It only works for Item datatypes supported by DataFrames.
+   *
+   * This saves:
+   *  - human-readable (JSON) model metadata to path/metadata/
+   *  - Parquet formatted data to path/data/
+   *
+   * The model may be loaded using [[FPGrowthModel.load]].
+   *
+   * @param sc Spark context used to save model data.
+   * @param path Path specifying the directory in which to save this model.
+   *             If the directory already exists, this method throws an exception.
+   */
+  @Since("2.0.0")
+  override def save(sc: SparkContext, path: String): Unit = {
+    FPGrowthModel.SaveLoadV1_0.save(this, path)
+  }
+
+  override protected val formatVersion: String = "1.0"
+}
+
+@Since("2.0.0")
+object FPGrowthModel extends Loader[FPGrowthModel[_]] {
+
+  @Since("2.0.0")
+  override def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
+    FPGrowthModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  private[fpm] object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private val thisClassName = "org.apache.spark.mllib.fpm.FPGrowthModel"
+
+    def save(model: FPGrowthModel[_], path: String): Unit = {
+      val sc = model.freqItemsets.sparkContext
+      val sqlContext = SQLContext.getOrCreate(sc)
+
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      // Get the type of item class
+      val sample = model.freqItemsets.first().items(0)
+      val className = sample.getClass.getCanonicalName
+      val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
+      val tpe = classSymbol.selfType
+
+      val itemType = ScalaReflection.schemaFor(tpe).dataType
+      val fields = Array(StructField("items", ArrayType(itemType)),
+        StructField("freq", LongType))
+      val schema = StructType(fields)
+      val rowDataRDD = model.freqItemsets.map { x =>
+        Row(x.items, x.freq)
+      }
+      sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
+      implicit val formats = DefaultFormats
+      val sqlContext = SQLContext.getOrCreate(sc)
+
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+
+      val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path))
+      val sample = freqItemsets.select("items").head().get(0)
+      loadImpl(freqItemsets, sample)
+    }
+
+    def loadImpl[Item : ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = {
+      val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x =>
+        val items = x.getAs[Seq[Item]](0).toArray
+        val freq = x.getLong(1)
+        new FreqItemset(items, freq)
+      }
+      new FPGrowthModel(freqItemsetsRDD)
+    }
+  }
 }
 
 /**
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 154f75d75e4a6..eeeabfe359e68 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.fpm;
 
+import java.io.File;
 import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
@@ -28,6 +29,7 @@
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.util.Utils;
 
 public class JavaFPGrowthSuite implements Serializable {
   private transient JavaSparkContext sc;
@@ -69,4 +71,42 @@ public void runFPGrowth() {
       long freq = itemset.freq();
     }
   }
+
+  @Test
+  public void runFPGrowthSaveLoad() {
+
+    @SuppressWarnings("unchecked")
+    JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
+      Arrays.asList("r z h k p".split(" ")),
+      Arrays.asList("z y x w v u t s".split(" ")),
+      Arrays.asList("s x o n r".split(" ")),
+      Arrays.asList("x z y m t s q e".split(" ")),
+      Arrays.asList("z".split(" ")),
+      Arrays.asList("x z y r q t p".split(" "))), 2);
+
+    FPGrowthModel<String> model = new FPGrowth()
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd);
+
+    File tempDir = Utils.createTempDir(
+      System.getProperty("java.io.tmpdir"), "JavaFPGrowthSuite");
+    String outputPath = tempDir.getPath();
+
+    try {
+      model.save(sc.sc(), outputPath);
+      FPGrowthModel newModel = FPGrowthModel.load(sc.sc(), outputPath);
+      List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
+        .collect();
+      assertEquals(18, freqItemsets.size());
+
+      for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
+        // Test return types.
+        List<String> items = itemset.javaItems();
+        long freq = itemset.freq();
+      }
+    } finally {
+      Utils.deleteRecursively(tempDir);
+    }
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 4a9bfdb348d9f..b9e997c207bc7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
 
 class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -274,4 +275,71 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
      */
     assert(model1.freqItemsets.count() === 65)
   }
+
+  test("model save/load with String type") {
+    val transactions = Seq(
+      "r z h k p",
+      "z y x w v u t s",
+      "s x o n r",
+      "x z y m t s q e",
+      "z",
+      "x z y r q t p")
+      .map(_.split(" "))
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    val model3 = new FPGrowth()
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd)
+    val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+      (itemset.items.toSet, itemset.freq)
+    }
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    try {
+      model3.save(sc, path)
+      val newModel = FPGrowthModel.load(sc, path)
+      val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+        (itemset.items.toSet, itemset.freq)
+      }
+      assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
+  test("model save/load with Int type") {
+    val transactions = Seq(
+      "1 2 3",
+      "1 2 3 4",
+      "5 4 3 2 1",
+      "6 5 4 3 2 1",
+      "2 4",
+      "1 3",
+      "1 7")
+      .map(_.split(" ").map(_.toInt).toArray)
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    val model3 = new FPGrowth()
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd)
+    val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+      (itemset.items.toSet, itemset.freq)
+    }
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    try {
+      model3.save(sc, path)
+      val newModel = FPGrowthModel.load(sc, path)
+      val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+        (itemset.items.toSet, itemset.freq)
+      }
+      assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
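For reference, a sketch of the on-disk layout written by SaveLoadV1_0.save,
derived from the code above (part-file names come from the Hadoop output
format and are illustrative):

    <path>/metadata/part-00000   one JSON line:
        {"class":"org.apache.spark.mllib.fpm.FPGrowthModel","version":"1.0"}
    <path>/data/                 Parquet files with schema:
        items: array<item type inferred via ScalaReflection>, freq: bigint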