Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-10386] [MLlib] PrefixSpanModel supports save/load #10664

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,22 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe._

import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel

/**
Expand Down Expand Up @@ -566,4 +576,88 @@ object PrefixSpan extends Logging {
@Since("1.5.0")
class PrefixSpanModel[Item] @Since("1.5.0") (
@Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]])
extends Serializable
extends Saveable with Serializable {

/**
* Save this model to the given path.
* It only works for Item datatypes supported by DataFrames.
*
* This saves:
* - human-readable (JSON) model metadata to path/metadata/
* - Parquet formatted data to path/data/
*
* The model may be loaded using [[PrefixSpanModel.load]].
*
* @param sc Spark context used to save model data.
* @param path Path specifying the directory in which to save this model.
* If the directory already exists, this method throws an exception.
*/
@Since("2.0.0")
override def save(sc: SparkContext, path: String): Unit = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is sc used ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not here, but it's needed for other algorithms (with local representations).

PrefixSpanModel.SaveLoadV1_0.save(this, path)
}

override protected val formatVersion: String = "1.0"
}

@Since("2.0.0")
object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {

@Since("2.0.0")
override def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
PrefixSpanModel.SaveLoadV1_0.load(sc, path)
}

private[fpm] object SaveLoadV1_0 {

private val thisFormatVersion = "1.0"

private val thisClassName = "org.apache.spark.mllib.fpm.PrefixSpanModel"

def save(model: PrefixSpanModel[_], path: String): Unit = {
val sc = model.freqSequences.sparkContext
val sqlContext = SQLContext.getOrCreate(sc)

val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

// Get the type of item class
val sample = model.freqSequences.first().sequence(0)(0)
val className = sample.getClass.getCanonicalName
val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
val tpe = classSymbol.selfType

val itemType = ScalaReflection.schemaFor(tpe).dataType
val fields = Array(StructField("sequence", ArrayType(ArrayType(itemType))),
StructField("freq", LongType))
val schema = StructType(fields)
val rowDataRDD = model.freqSequences.map { x =>
Row(x.sequence, x.freq)
}
sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
}

def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
implicit val formats = DefaultFormats
val sqlContext = SQLContext.getOrCreate(sc)

val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)

val freqSequences = sqlContext.read.parquet(Loader.dataPath(path))
val sample = freqSequences.select("sequence").head().get(0)
loadImpl(freqSequences, sample)
}

def loadImpl[Item: ClassTag](freqSequences: DataFrame, sample: Item): PrefixSpanModel[Item] = {
val freqSequencesRDD = freqSequences.select("sequence", "freq").rdd.map { x =>
val sequence = x.getAs[Seq[Seq[Item]]](0).map(_.toArray).toArray
val freq = x.getLong(1)
new PrefixSpan.FreqSequence(sequence, freq)
}
new PrefixSpanModel(freqSequencesRDD)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.mllib.fpm;

import java.io.File;
import java.util.Arrays;
import java.util.List;

Expand All @@ -28,6 +29,7 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
import org.apache.spark.util.Utils;

public class JavaPrefixSpanSuite {
private transient JavaSparkContext sc;
Expand Down Expand Up @@ -64,4 +66,39 @@ public void runPrefixSpan() {
long freq = freqSeq.freq();
}
}

@Test
public void runPrefixSpanSaveLoad() {
JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
Arrays.asList(Arrays.asList(6))
), 2);
PrefixSpan prefixSpan = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5);
PrefixSpanModel<Integer> model = prefixSpan.run(sequences);

File tempDir = Utils.createTempDir(
System.getProperty("java.io.tmpdir"), "JavaPrefixSpanSuite");
String outputPath = tempDir.getPath();

try {
model.save(sc.sc(), outputPath);
PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath);
JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD();
List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
Assert.assertEquals(5, localFreqSeqs.size());
// Check that each frequent sequence could be materialized.
for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
List<List<Integer>> seq = freqSeq.javaSequence();
long freq = freqSeq.freq();
}
} finally {
Utils.deleteRecursively(tempDir);
}


}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils

class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {

Expand Down Expand Up @@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
compareResults(expected, model.freqSequences.collect())
}

test("model save/load") {
val sequences = Seq(
Array(Array(1, 2), Array(3)),
Array(Array(1), Array(3, 2), Array(1, 2)),
Array(Array(1, 2), Array(5)),
Array(Array(6)))
val rdd = sc.parallelize(sequences, 2).cache()

val prefixSpan = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
val model = prefixSpan.run(rdd)

val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
try {
model.save(sc, path)
val newModel = PrefixSpanModel.load(sc, path)
val originalSet = model.freqSequences.collect().map { x =>
(x.sequence.map(_.toSet).toSeq, x.freq)
}.toSet
val newSet = newModel.freqSequences.collect().map { x =>
(x.sequence.map(_.toSet).toSeq, x.freq)
}.toSet
assert(originalSet === newSet)
} finally {
Utils.deleteRecursively(tempDir)
}
}

private def compareResults[Item](
expectedValue: Array[(Array[Array[Item]], Long)],
actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = {
Expand Down