Skip to content

Commit

Permalink
[SPARK-10386][MLLIB] PrefixSpanModel supports save/load
Browse files Browse the repository at this point in the history
```PrefixSpanModel``` supports ```save/load```. It's similar to #9267.

cc jkbradley

Author: Yanbo Liang <[email protected]>

Closes #10664 from yanboliang/spark-10386.
  • Loading branch information
yanboliang authored and jkbradley committed Apr 13, 2016
1 parent dbbe149 commit b0adb9f
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 1 deletion.
96 changes: 95 additions & 1 deletion mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,22 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe._

import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel

/**
Expand Down Expand Up @@ -566,4 +576,88 @@ object PrefixSpan extends Logging {
@Since("1.5.0")
class PrefixSpanModel[Item] @Since("1.5.0") (
    @Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]])
  extends Saveable with Serializable {

  /**
   * Save this model to the given path.
   * It only works for Item datatypes supported by DataFrames.
   *
   * This saves:
   *  - human-readable (JSON) model metadata to path/metadata/
   *  - Parquet formatted data to path/data/
   *
   * The model may be loaded using [[PrefixSpanModel.load]].
   *
   * @param sc  Spark context used to save model data.
   * @param path  Path specifying the directory in which to save this model.
   *              If the directory already exists, this method throws an exception.
   */
  @Since("2.0.0")
  override def save(sc: SparkContext, path: String): Unit = {
    // Delegate to the versioned persistence helper on the companion object.
    PrefixSpanModel.SaveLoadV1_0.save(this, path)
  }

  // Version of the on-disk format written by save(); checked by load().
  override protected val formatVersion: String = "1.0"
}

@Since("2.0.0")
object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {

  /**
   * Load a [[PrefixSpanModel]] previously saved with `save`.
   *
   * @param sc  Spark context used to load model data.
   * @param path  Path to the saved model directory (containing metadata/ and data/).
   * @return model with an existential item type; the concrete item type is recovered
   *         at runtime from the stored data.
   */
  @Since("2.0.0")
  override def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
    PrefixSpanModel.SaveLoadV1_0.load(sc, path)
  }

  /** Persistence implementation for on-disk format version 1.0. */
  private[fpm] object SaveLoadV1_0 {

    // Format version written into the JSON metadata and asserted on load.
    private val thisFormatVersion = "1.0"

    // Fully-qualified class name written into the JSON metadata and asserted on load.
    private val thisClassName = "org.apache.spark.mllib.fpm.PrefixSpanModel"

    /**
     * Save the model: JSON metadata (class name + format version) as a single text
     * file under metadata/, and the frequent sequences as Parquet under data/.
     */
    def save(model: PrefixSpanModel[_], path: String): Unit = {
      val sc = model.freqSequences.sparkContext
      val sqlContext = SQLContext.getOrCreate(sc)

      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

      // Get the type of item class: sample the first item of the first sequence and
      // resolve its runtime class to a Scala type via runtime reflection, so a
      // matching Catalyst DataType can be derived below.
      // NOTE(review): assumes freqSequences is non-empty and the first sequence has
      // a non-empty first itemset — first()/(0)(0) would fail otherwise.
      val sample = model.freqSequences.first().sequence(0)(0)
      val className = sample.getClass.getCanonicalName
      val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
      val tpe = classSymbol.selfType

      // Schema: sequence is an array of itemsets (array of item), freq is a count.
      val itemType = ScalaReflection.schemaFor(tpe).dataType
      val fields = Array(StructField("sequence", ArrayType(ArrayType(itemType))),
        StructField("freq", LongType))
      val schema = StructType(fields)
      val rowDataRDD = model.freqSequences.map { x =>
        Row(x.sequence, x.freq)
      }
      sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
    }

    /**
     * Load the model: validate the metadata written by [[save]], read the Parquet
     * data, and use a sampled item to fix the item type for [[loadImpl]].
     */
    def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
      implicit val formats = DefaultFormats
      val sqlContext = SQLContext.getOrCreate(sc)

      // Fail fast if the directory was written by a different class or format version.
      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
      assert(className == thisClassName)
      assert(formatVersion == thisFormatVersion)

      val freqSequences = sqlContext.read.parquet(Loader.dataPath(path))
      // The sample value's static type drives the Item type parameter of loadImpl.
      val sample = freqSequences.select("sequence").head().get(0)
      loadImpl(freqSequences, sample)
    }

    /**
     * Rebuild the typed model from the loaded DataFrame.
     *
     * @param freqSequences  DataFrame with "sequence" (array of array of Item) and
     *                       "freq" (long) columns, as written by [[save]].
     * @param sample  a sample value used only to bind the Item type parameter.
     */
    def loadImpl[Item: ClassTag](freqSequences: DataFrame, sample: Item): PrefixSpanModel[Item] = {
      val freqSequencesRDD = freqSequences.select("sequence", "freq").rdd.map { x =>
        // Parquet arrays come back as Seq; convert to the Array-of-Array form
        // expected by FreqSequence.
        val sequence = x.getAs[Seq[Seq[Item]]](0).map(_.toArray).toArray
        val freq = x.getLong(1)
        new PrefixSpan.FreqSequence(sequence, freq)
      }
      new PrefixSpanModel(freqSequencesRDD)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.mllib.fpm;

import java.io.File;
import java.util.Arrays;
import java.util.List;

Expand All @@ -28,6 +29,7 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
import org.apache.spark.util.Utils;

public class JavaPrefixSpanSuite {
private transient JavaSparkContext sc;
Expand Down Expand Up @@ -64,4 +66,39 @@ public void runPrefixSpan() {
long freq = freqSeq.freq();
}
}

@Test
public void runPrefixSpanSaveLoad() {
  // Same input data as runPrefixSpan: four sequences of itemsets.
  JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
    Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
    Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
    Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
    Arrays.asList(Arrays.asList(6))
  ), 2);

  PrefixSpanModel<Integer> model = new PrefixSpan()
    .setMinSupport(0.5)
    .setMaxPatternLength(5)
    .run(sequences);

  File tempDir = Utils.createTempDir(
    System.getProperty("java.io.tmpdir"), "JavaPrefixSpanSuite");
  String outputPath = tempDir.getPath();

  try {
    model.save(sc.sc(), outputPath);
    PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath);
    List<FreqSequence<Integer>> reloaded =
      newModel.freqSequences().toJavaRDD().<FreqSequence<Integer>>collect();
    Assert.assertEquals(5, reloaded.size());
    // Check that each frequent sequence can be materialized from the Java API.
    for (PrefixSpan.FreqSequence<Integer> freqSeq : reloaded) {
      List<List<Integer>> seq = freqSeq.javaSequence();
      long freq = freqSeq.freq();
    }
  } finally {
    Utils.deleteRecursively(tempDir);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils

class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {

Expand Down Expand Up @@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
compareResults(expected, model.freqSequences.collect())
}

test("model save/load") {
  // Four input sequences of itemsets, partitioned across 2 partitions.
  val inputSequences = Seq(
    Array(Array(1, 2), Array(3)),
    Array(Array(1), Array(3, 2), Array(1, 2)),
    Array(Array(1, 2), Array(5)),
    Array(Array(6)))
  val sequenceRDD = sc.parallelize(inputSequences, 2).cache()

  val model = new PrefixSpan()
    .setMinSupport(0.5)
    .setMaxPatternLength(5)
    .run(sequenceRDD)

  val saveDir = Utils.createTempDir()
  val savePath = saveDir.toURI.toString
  try {
    model.save(sc, savePath)
    val reloaded = PrefixSpanModel.load(sc, savePath)
    // Item order within an itemset is not significant, so compare each itemset
    // as a Set before comparing the collections of (sequence, freq) pairs.
    val expected = model.freqSequences.collect().map { fs =>
      (fs.sequence.map(_.toSet).toSeq, fs.freq)
    }.toSet
    val actual = reloaded.freqSequences.collect().map { fs =>
      (fs.sequence.map(_.toSet).toSeq, fs.freq)
    }.toSet
    assert(expected === actual)
  } finally {
    Utils.deleteRecursively(saveDir)
  }
}

private def compareResults[Item](
expectedValue: Array[(Array[Array[Item]], Long)],
actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = {
Expand Down

0 comments on commit b0adb9f

Please sign in to comment.