[SPARK-6724][MLLIB] Support model save/load for FPGrowthModel
Support model save/load for FPGrowthModel

Author: Yanbo Liang <[email protected]>

Closes #9267 from yanboliang/spark-6724.
yanboliang authored and jkbradley committed Jan 5, 2016
1 parent 047a31b commit 13a3b63
Showing 3 changed files with 205 additions and 3 deletions.
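To set the stage before the diffs: the patch adds save/load to FPGrowthModel via MLlib's Loader/Saveable framework. A minimal usage sketch follows (the path and transaction data are illustrative, not from the patch):

import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}

// Assumes an existing SparkContext `sc`; the transactions are illustrative.
val transactions = sc.parallelize(Seq(
  Array("a", "b", "c"),
  Array("a", "b"),
  Array("b", "c")))

val model = new FPGrowth()
  .setMinSupport(0.5)
  .run(transactions)

// Writes JSON metadata to <path>/metadata and Parquet data to <path>/data;
// throws if the directory already exists.
model.save(sc, "/tmp/fpgrowth-model")

// Returns FPGrowthModel[_]; the item type is recovered from the saved data.
val sameModel = FPGrowthModel.load(sc, "/tmp/fpgrowth-model")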
mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala (97 additions, 3 deletions)
@@ -17,19 +17,29 @@

package org.apache.spark.mllib.fpm

import java.lang.{Iterable => JavaIterable}
import java.{util => ju}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe._

import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.mllib.fpm.FPGrowth._
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel

/**
@@ -39,7 +49,8 @@ import org.apache.spark.storage.StorageLevel
*/
@Since("1.3.0")
class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
@Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable {
@Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]])
extends Saveable with Serializable {
  /**
   * Generates association rules for the [[Item]]s in [[freqItemsets]].
   * @param confidence minimal confidence of the rules produced
@@ -49,6 +60,89 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
    val associationRules = new AssociationRules(confidence)
    associationRules.run(freqItemsets)
  }

  /**
   * Save this model to the given path.
   * It only works for Item datatypes supported by DataFrames.
   *
   * This saves:
   *  - human-readable (JSON) model metadata to path/metadata/
   *  - Parquet formatted data to path/data/
   *
   * The model may be loaded using [[FPGrowthModel.load]].
   *
   * @param sc  Spark context used to save model data.
   * @param path  Path specifying the directory in which to save this model.
   *              If the directory already exists, this method throws an exception.
   */
  @Since("2.0.0")
  override def save(sc: SparkContext, path: String): Unit = {
    FPGrowthModel.SaveLoadV1_0.save(this, path)
  }

  override protected val formatVersion: String = "1.0"
}
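For reference, save and formatVersion are exactly the members MLlib's Saveable trait requires of a model. A minimal sketch of that contract (the assumed shape of org.apache.spark.mllib.util.Saveable, abbreviated):

trait Saveable {
  // Write the model under the given path, using metadata/ and data/ subdirectories.
  def save(sc: org.apache.spark.SparkContext, path: String): Unit

  // Format version string recorded in the metadata and checked on load.
  protected def formatVersion: String
}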

@Since("2.0.0")
object FPGrowthModel extends Loader[FPGrowthModel[_]] {

@Since("2.0.0")
override def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
FPGrowthModel.SaveLoadV1_0.load(sc, path)
}

private[fpm] object SaveLoadV1_0 {

private val thisFormatVersion = "1.0"

private val thisClassName = "org.apache.spark.mllib.fpm.FPGrowthModel"

def save(model: FPGrowthModel[_], path: String): Unit = {
val sc = model.freqItemsets.sparkContext
val sqlContext = SQLContext.getOrCreate(sc)

val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

// Get the type of item class
val sample = model.freqItemsets.first().items(0)
val className = sample.getClass.getCanonicalName
val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
val tpe = classSymbol.selfType

val itemType = ScalaReflection.schemaFor(tpe).dataType
val fields = Array(StructField("items", ArrayType(itemType)),
StructField("freq", LongType))
val schema = StructType(fields)
val rowDataRDD = model.freqItemsets.map { x =>
Row(x.items, x.freq)
}
sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
}

def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
implicit val formats = DefaultFormats
val sqlContext = SQLContext.getOrCreate(sc)

val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)

val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path))
val sample = freqItemsets.select("items").head().get(0)
loadImpl(freqItemsets, sample)
}

def loadImpl[Item : ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = {
val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x =>
val items = x.getAs[Seq[Item]](0).toArray
val freq = x.getLong(1)
new FreqItemset(items, freq)
}
new FPGrowthModel(freqItemsetsRDD)
}
}
}
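Concretely, the metadata written by SaveLoadV1_0.save is a single JSON line. A self-contained sketch of what those json4s calls produce:

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

val metadata = compact(render(
  ("class" -> "org.apache.spark.mllib.fpm.FPGrowthModel") ~ ("version" -> "1.0")))
// metadata == {"class":"org.apache.spark.mllib.fpm.FPGrowthModel","version":"1.0"}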

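A note on the save path: since Item is erased at runtime, the writer samples one item from the first frequent itemset, recovers its concrete class, and derives the Catalyst schema from it. A standalone sketch of that step (the sample value is hypothetical; the calls mirror the diff):

import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._

// Hypothetical item, standing in for model.freqItemsets.first().items(0).
val sample: Any = "z"

// Turn the runtime class into a scala-reflect Type, then into a Spark SQL DataType.
val className = sample.getClass.getCanonicalName
val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
val itemType = ScalaReflection.schemaFor(classSymbol.selfType).dataType

// Schema of the Parquet data: items: array<itemType>, freq: long.
val schema = StructType(Array(
  StructField("items", ArrayType(itemType)),
  StructField("freq", LongType)))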
mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -17,6 +17,7 @@

package org.apache.spark.mllib.fpm;

import java.io.File;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
@@ -28,6 +29,7 @@

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.util.Utils;

public class JavaFPGrowthSuite implements Serializable {
  private transient JavaSparkContext sc;
@@ -69,4 +71,42 @@ public void runFPGrowth() {
      long freq = itemset.freq();
    }
  }

  @Test
  public void runFPGrowthSaveLoad() {

    @SuppressWarnings("unchecked")
    JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
      Arrays.asList("r z h k p".split(" ")),
      Arrays.asList("z y x w v u t s".split(" ")),
      Arrays.asList("s x o n r".split(" ")),
      Arrays.asList("x z y m t s q e".split(" ")),
      Arrays.asList("z".split(" ")),
      Arrays.asList("x z y r q t p".split(" "))), 2);

    FPGrowthModel<String> model = new FPGrowth()
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(rdd);

    File tempDir = Utils.createTempDir(
      System.getProperty("java.io.tmpdir"), "JavaFPGrowthSuite");
    String outputPath = tempDir.getPath();

    try {
      model.save(sc.sc(), outputPath);
      FPGrowthModel newModel = FPGrowthModel.load(sc.sc(), outputPath);
      List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
        .collect();
      assertEquals(18, freqItemsets.size());

      for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
        // Test return types.
        List<String> items = itemset.javaItems();
        long freq = itemset.freq();
      }
    } finally {
      Utils.deleteRecursively(tempDir);
    }
  }
}
mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils

class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {

@@ -274,4 +275,71 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
     */
    assert(model1.freqItemsets.count() === 65)
  }

test("model save/load with String type") {
val transactions = Seq(
"r z h k p",
"z y x w v u t s",
"s x o n r",
"x z y m t s q e",
"z",
"x z y r q t p")
.map(_.split(" "))
val rdd = sc.parallelize(transactions, 2).cache()

val model3 = new FPGrowth()
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd)
val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
(itemset.items.toSet, itemset.freq)
}

val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
try {
model3.save(sc, path)
val newModel = FPGrowthModel.load(sc, path)
val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
(itemset.items.toSet, itemset.freq)
}
assert(freqItemsets3.toSet === newFreqItemsets.toSet)
} finally {
Utils.deleteRecursively(tempDir)
}
}

test("model save/load with Int type") {
val transactions = Seq(
"1 2 3",
"1 2 3 4",
"5 4 3 2 1",
"6 5 4 3 2 1",
"2 4",
"1 3",
"1 7")
.map(_.split(" ").map(_.toInt).toArray)
val rdd = sc.parallelize(transactions, 2).cache()

val model3 = new FPGrowth()
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd)
val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
(itemset.items.toSet, itemset.freq)
}

val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
try {
model3.save(sc, path)
val newModel = FPGrowthModel.load(sc, path)
val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
(itemset.items.toSet, itemset.freq)
}
assert(freqItemsets3.toSet === newFreqItemsets.toSet)
} finally {
Utils.deleteRecursively(tempDir)
}
}
}
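Both tests normalize each itemset to an (items.toSet, freq) pair and compare the results as sets: neither the order of itemsets in the RDD nor the order of items within an itemset is guaranteed to survive the Parquet round trip. An illustrative helper (not part of the patch) capturing the same check:

import org.apache.spark.mllib.fpm.FPGrowthModel

// Order-insensitive equality of the frequent itemsets of two models.
def sameFreqItemsets[Item](a: FPGrowthModel[Item], b: FPGrowthModel[Item]): Boolean = {
  def normalize(m: FPGrowthModel[Item]): Set[(Set[Item], Long)] =
    m.freqItemsets.collect().map(is => (is.items.toSet, is.freq)).toSet
  normalize(a) == normalize(b)
}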
