Skip to content

Commit

Permalink
[SPARK-6827] [MLLIB] Wrap FPGrowthModel.freqItemsets and make it cons…
Browse files Browse the repository at this point in the history
…istent with Java API

Make PySpark ```FPGrowthModel.freqItemsets``` consistent with Java/Scala API like ```MatrixFactorizationModel.userFeatures```
It return a RDD with each tuple is composed of an array and a long value.
I think it's difficult to implement namedtuples to wrap the output because items of freqItemsets can be any type with arbitrary length which is tedious to impelement corresponding SerDe function.

Author: Yanbo Liang <[email protected]>

Closes #5614 from yanboliang/spark-6827 and squashes the following commits:

da8c404 [Yanbo Liang] use namedtuple
5532e78 [Yanbo Liang] Wrap FPGrowthModel.freqItemsets and make it consistent with Java API
  • Loading branch information
yanboliang authored and mengxr committed Apr 23, 2015
1 parent baf865d commit f4f3998
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions python/pyspark/mllib/fpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
# limitations under the License.
#

import numpy
from numpy import array
from collections import namedtuple

from pyspark import SparkContext
from pyspark.rdd import ignore_unicode_prefix
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
Expand All @@ -36,14 +40,14 @@ class FPGrowthModel(JavaModelWrapper):
>>> rdd = sc.parallelize(data, 2)
>>> model = FPGrowth.train(rdd, 0.6, 2)
>>> sorted(model.freqItemsets().collect())
[([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)]
[FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
"""

def freqItemsets(self):
"""
Get the frequent itemsets of this model
Returns the frequent itemsets of this model.
"""
return self.call("getFreqItemsets")
return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1])))


class FPGrowth(object):
Expand All @@ -67,6 +71,11 @@ def train(cls, data, minSupport=0.3, numPartitions=-1):
model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions))
return FPGrowthModel(model)

class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])):
"""
Represents an (items, freq) tuple.
"""


def _test():
import doctest
Expand Down

0 comments on commit f4f3998

Please sign in to comment.