Skip to content

Commit

Permalink
[SPARK-24146][PYSPARK][ML] spark.ml parity for sequential pattern min…
Browse files Browse the repository at this point in the history
…ing - PrefixSpan: Python API

## What changes were proposed in this pull request?

spark.ml parity for sequential pattern mining - PrefixSpan: Python API

## How was this patch tested?

doctests

Author: WeichenXu <[email protected]>

Closes #21265 from WeichenXu123/prefix_span_py.
  • Loading branch information
WeichenXu123 authored and mengxr committed May 31, 2018
1 parent 0053e15 commit 90ae98d
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 4 deletions.
6 changes: 3 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " +
"sequential pattern. Sequential pattern that appears more than " +
"(minSupport * size-of-the-dataset)." +
"(minSupport * size-of-the-dataset) " +
"times will be output.", ParamValidators.gtEq(0.0))

/** @group getParam */
Expand Down Expand Up @@ -128,10 +128,10 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
* Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
*
* @param dataset A dataset or a dataframe containing a sequence column which is
* {{{Seq[Seq[_]]}}} type
* {{{ArrayType(ArrayType(T))}}} type, T is the item type for the input dataset.
* @return A `DataFrame` that contains columns of sequence and corresponding frequency.
* The schema of it will be:
* - `sequence: Seq[Seq[T]]` (T is the item type)
* - `sequence: ArrayType(ArrayType(T))` (T is the item type)
* - `freq: Long`
*/
@Since("2.4.0")
Expand Down
104 changes: 103 additions & 1 deletion python/pyspark/ml/fpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
#

from pyspark import keyword_only, since
from pyspark.sql import DataFrame
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm
from pyspark.ml.param.shared import *

__all__ = ["FPGrowth", "FPGrowthModel"]
Expand Down Expand Up @@ -243,3 +244,104 @@ def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items",

def _create_model(self, java_model):
return FPGrowthModel(java_model)


class PrefixSpan(JavaParams):
"""
.. note:: Experimental
A parallel PrefixSpan algorithm to mine frequent sequential patterns.
The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
Efficiently by Prefix-Projected Pattern Growth
(see <a href="http://doi.org/10.1109/ICDE.2001.914830">here</a>).
This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns`
method to run the PrefixSpan algorithm.
@see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
(Wikipedia)</a>
.. versionadded:: 2.4.0
"""

minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " +
"sequential pattern. Sequential pattern that appears more than " +
"(minSupport * size-of-the-dataset) times will be output. Must be >= 0.",
typeConverter=TypeConverters.toFloat)

maxPatternLength = Param(Params._dummy(), "maxPatternLength",
"The maximal length of the sequential pattern. Must be > 0.",
typeConverter=TypeConverters.toInt)

maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize",
"The maximum number of items (including delimiters used in the " +
"internal storage format) allowed in a projected database before " +
"local processing. If a projected database exceeds this size, " +
"another iteration of distributed prefix growth is run. " +
"Must be > 0.",
typeConverter=TypeConverters.toInt)

sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " +
"dataset, rows with nulls in this column are ignored.",
typeConverter=TypeConverters.toString)

@keyword_only
def __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
sequenceCol="sequence"):
"""
__init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
sequenceCol="sequence")
"""
super(PrefixSpan, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid)
self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
sequenceCol="sequence")
kwargs = self._input_kwargs
self.setParams(**kwargs)

@keyword_only
@since("2.4.0")
def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
sequenceCol="sequence"):
"""
setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
sequenceCol="sequence")
"""
kwargs = self._input_kwargs
return self._set(**kwargs)

@since("2.4.0")
def findFrequentSequentialPatterns(self, dataset):
"""
.. note:: Experimental
Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
:param dataset: A dataframe containing a sequence column which is
`ArrayType(ArrayType(T))` type, T is the item type for the input dataset.
:return: A `DataFrame` that contains columns of sequence and corresponding frequency.
The schema of it will be:
- `sequence: ArrayType(ArrayType(T))` (T is the item type)
- `freq: Long`
>>> from pyspark.ml.fpm import PrefixSpan
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(sequence=[[1, 2], [3]]),
... Row(sequence=[[1], [3, 2], [1, 2]]),
... Row(sequence=[[1, 2], [5]]),
... Row(sequence=[[6]])]).toDF()
>>> prefixSpan = PrefixSpan(minSupport=0.5, maxPatternLength=5)
>>> prefixSpan.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False)
+----------+----+
|sequence |freq|
+----------+----+
|[[1]] |3 |
|[[1], [3]]|2 |
|[[1, 2]] |3 |
|[[2]] |3 |
|[[3]] |2 |
+----------+----+
.. versionadded:: 2.4.0
"""
self._transfer_params_to_java()
jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
return DataFrame(jdf, dataset.sql_ctx)

0 comments on commit 90ae98d

Please sign in to comment.