[SPARK-4001][MLlib] adding parallel FP-Growth algorithm for frequent pattern mining in MLlib

Apriori is the classic algorithm for frequent itemset mining in a transactional data set. It would be useful to have it in MLlib in Spark, and this PR adds an implementation.
There is one point I am not sure is most efficient: in order to filter out the eligible frequent itemsets, I currently use a cartesian operation on two RDDs to compute the support of each itemset, and I am not sure whether it would be better to use a broadcast variable to achieve the same.
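For illustration, here is a minimal sketch of the broadcast-variable alternative (the names `countSupport` and `candidates` are hypothetical and not part of this PR):

```scala
import org.apache.spark.rdd.RDD

// Hypothetical sketch: count candidate-itemset support with a broadcast
// variable instead of a cartesian join. All names here are illustrative.
def countSupport(
    data: RDD[Set[String]],
    candidates: Array[Set[String]]): Array[(Set[String], Long)] = {
  // Ship the (small) candidate list to every executor once.
  val bcCandidates = data.sparkContext.broadcast(candidates)
  data.flatMap { transaction =>
    // Emit (candidate, 1) for every candidate contained in this transaction.
    bcCandidates.value.filter(_.subsetOf(transaction)).map((_, 1L))
  }.reduceByKey(_ + _)
    .collect()
}
```

This avoids materializing the cartesian product: each partition scans its transactions once against the broadcast candidate list.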

I will add an example of using this algorithm if required.

Author: Jacky Li <[email protected]>
Author: Jacky Li <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes apache#2847 from jackylk/apriori and squashes the following commits:

bee3093 [Jacky Li] Merge pull request #1 from mengxr/SPARK-4001
7e69725 [Xiangrui Meng] simplify FPTree and update FPGrowth
ec21f7d [Jacky Li] fix scalastyle
93f3280 [Jacky Li] create FPTree class
d110ab2 [Jacky Li] change test case to use MLlibTestSparkContext
a6c5081 [Jacky Li] Add Parallel FPGrowth algorithm
eb3e4ca [Jacky Li] add FPGrowth
03df2b6 [Jacky Li] refactory according to comments
7b77ad7 [Jacky Li] fix scalastyle check
f68a0bd [Jacky Li] add 2 apriori implemenation and fp-growth implementation
889b33f [Jacky Li] modify per scalastyle check
da2cba7 [Jacky Li] adding apriori algorithm for frequent item set mining in Spark
jackylk authored and mengxr committed Feb 2, 2015
1 parent d85cd4e commit 859f724
Showing 4 changed files with 484 additions and 0 deletions.
162 changes: 162 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -0,0 +1,162 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

import java.{util => ju}

import scala.collection.mutable

import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable

/**
 * This class implements the Parallel FP-growth algorithm to do frequent pattern mining on
 * input data. Parallel FP-growth (PFP) partitions computation in such a way that each machine
 * executes an independent group of mining tasks. More details on this algorithm can be found at
 * [[http://dx.doi.org/10.1145/1454008.1454027, PFP]], and the original FP-growth paper can be
 * found at [[http://dx.doi.org/10.1145/335191.335372, FP-growth]].
 *
 * @param minSupport the minimal support level of the frequent pattern; any pattern that appears
 *                   more than (minSupport * size-of-the-dataset) times will be output
 * @param numPartitions number of partitions used by parallel FP-growth
 */
class FPGrowth private (
    private var minSupport: Double,
    private var numPartitions: Int) extends Logging with Serializable {

  /**
   * Constructs an FPGrowth instance with default parameters:
   * {minSupport: 0.3, numPartitions: auto}
   */
  def this() = this(0.3, -1)

  /**
   * Sets the minimal support level (default: 0.3).
   */
  def setMinSupport(minSupport: Double): this.type = {
    this.minSupport = minSupport
    this
  }

  /**
   * Sets the number of partitions used by parallel FP-growth (default: same as input data).
   */
  def setNumPartitions(numPartitions: Int): this.type = {
    this.numPartitions = numPartitions
    this
  }

  /**
   * Computes an FP-Growth model that contains frequent itemsets.
   * @param data input data set, each element contains a transaction
   * @return an [[FPGrowthModel]]
   */
  def run(data: RDD[Array[String]]): FPGrowthModel = {
    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("Input data is not cached.")
    }
    val count = data.count()
    val minCount = math.ceil(minSupport * count).toLong
    val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
    val partitioner = new HashPartitioner(numParts)
    val freqItems = genFreqItems(data, minCount, partitioner)
    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
    new FPGrowthModel(freqItemsets)
  }

  /**
   * Generates frequent items by filtering the input data using the minimal support level.
   * @param minCount minimum count for frequent itemsets
   * @param partitioner partitioner used to distribute items
   * @return array of frequent items, ordered by descending frequency
   */
  private def genFreqItems(
      data: RDD[Array[String]],
      minCount: Long,
      partitioner: Partitioner): Array[String] = {
    data.flatMap { t =>
      val uniq = t.toSet
      if (t.length != uniq.size) {
        throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
      }
      t
    }.map(v => (v, 1L))
      .reduceByKey(partitioner, _ + _)
      .filter(_._2 >= minCount)
      .collect()
      .sortBy(-_._2)
      .map(_._1)
  }

  /**
   * Generates frequent itemsets by building FP-Trees; the extraction is done on each partition.
   * @param data transactions
   * @param minCount minimum count for frequent itemsets
   * @param freqItems frequent items
   * @param partitioner partitioner used to distribute transactions
   * @return an RDD of (frequent itemset, count)
   */
  private def genFreqItemsets(
      data: RDD[Array[String]],
      minCount: Long,
      freqItems: Array[String],
      partitioner: Partitioner): RDD[(Array[String], Long)] = {
    val itemToRank = freqItems.zipWithIndex.toMap
    data.flatMap { transaction =>
      genCondTransactions(transaction, itemToRank, partitioner)
    }.aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
      (tree, transaction) => tree.add(transaction, 1L),
      (tree1, tree2) => tree1.merge(tree2))
    .flatMap { case (part, tree) =>
      tree.extract(minCount, x => partitioner.getPartition(x) == part)
    }.map { case (ranks, count) =>
      (ranks.map(i => freqItems(i)).toArray, count)
    }
  }

  /**
   * Generates conditional transactions.
   * @param transaction a transaction
   * @param itemToRank map from item to its rank
   * @param partitioner partitioner used to distribute transactions
   * @return a map of (target partition, conditional transaction)
   */
  private def genCondTransactions(
      transaction: Array[String],
      itemToRank: Map[String, Int],
      partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
    val output = mutable.Map.empty[Int, Array[Int]]
    // Filter the basket by frequent items pattern and sort their ranks.
    val filtered = transaction.flatMap(itemToRank.get)
    ju.Arrays.sort(filtered)
    val n = filtered.length
    var i = n - 1
    while (i >= 0) {
      val item = filtered(i)
      val part = partitioner.getPartition(item)
      if (!output.contains(part)) {
        output(part) = filtered.slice(0, i + 1)
      }
      i -= 1
    }
    output
  }
}
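For context, a short usage sketch of the API above (it assumes an existing SparkContext `sc`; the data is made up):

```scala
import org.apache.spark.mllib.fpm.FPGrowth

val transactions = sc.parallelize(Seq(
  Array("a", "b", "c"),
  Array("a", "b", "d"),
  Array("a", "e"))).cache()

val model = new FPGrowth()
  .setMinSupport(0.6)    // minCount = ceil(0.6 * 3) = 2
  .setNumPartitions(2)
  .run(transactions)

// Expected frequent itemsets: {a} -> 3, {b} -> 2, {a,b} -> 2
model.freqItemsets.collect().foreach { case (items, count) =>
  println(items.mkString("{", ",", "}") + ": " + count)
}
```

Internally, each transaction is mapped to the ranks of its frequent items, sorted, and split into conditional transactions keyed by the partition of their last item: with ranks [0, 2, 3], the prefixes [0, 2, 3], [0, 2], and [0] are routed to the partitions owning ranks 3, 2, and 0 respectively, skipping any partition already covered.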
134 changes: 134 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala
@@ -0,0 +1,134 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/**
 * FP-Tree data structure used in FP-Growth.
 * @tparam T item type
 */
private[fpm] class FPTree[T] extends Serializable {

  import FPTree._

  val root: Node[T] = new Node(null)

  private val summaries: mutable.Map[T, Summary[T]] = mutable.Map.empty

  /** Adds a transaction with count. */
  def add(t: Iterable[T], count: Long = 1L): this.type = {
    require(count > 0)
    var curr = root
    curr.count += count
    t.foreach { item =>
      val summary = summaries.getOrElseUpdate(item, new Summary)
      summary.count += count
      val child = curr.children.getOrElseUpdate(item, {
        val newNode = new Node(curr)
        newNode.item = item
        summary.nodes += newNode
        newNode
      })
      child.count += count
      curr = child
    }
    this
  }

  /** Merges another FP-Tree. */
  def merge(other: FPTree[T]): this.type = {
    other.transactions.foreach { case (t, c) =>
      add(t, c)
    }
    this
  }

  /** Gets a subtree with the suffix. */
  private def project(suffix: T): FPTree[T] = {
    val tree = new FPTree[T]
    if (summaries.contains(suffix)) {
      val summary = summaries(suffix)
      summary.nodes.foreach { node =>
        var t = List.empty[T]
        var curr = node.parent
        while (!curr.isRoot) {
          t = curr.item :: t
          curr = curr.parent
        }
        tree.add(t, node.count)
      }
    }
    tree
  }

  /** Returns all transactions in an iterator. */
  def transactions: Iterator[(List[T], Long)] = getTransactions(root)

  /** Returns all transactions under this node. */
  private def getTransactions(node: Node[T]): Iterator[(List[T], Long)] = {
    var count = node.count
    node.children.iterator.flatMap { case (item, child) =>
      getTransactions(child).map { case (t, c) =>
        count -= c
        (item :: t, c)
      }
    } ++ {
      if (count > 0) {
        Iterator.single((Nil, count))
      } else {
        Iterator.empty
      }
    }
  }

  /** Extracts all patterns with a valid suffix and minimum count. */
  def extract(
      minCount: Long,
      validateSuffix: T => Boolean = _ => true): Iterator[(List[T], Long)] = {
    summaries.iterator.flatMap { case (item, summary) =>
      if (validateSuffix(item) && summary.count >= minCount) {
        Iterator.single((item :: Nil, summary.count)) ++
          project(item).extract(minCount).map { case (t, c) =>
            (item :: t, c)
          }
      } else {
        Iterator.empty
      }
    }
  }
}

private[fpm] object FPTree {

  /** Represents a node in an FP-Tree. */
  class Node[T](val parent: Node[T]) extends Serializable {
    var item: T = _
    var count: Long = 0L
    val children: mutable.Map[T, Node[T]] = mutable.Map.empty

    def isRoot: Boolean = parent == null
  }

  /** Summary of an item in an FP-Tree. */
  private class Summary[T] extends Serializable {
    var count: Long = 0L
    val nodes: ListBuffer[Node[T]] = ListBuffer.empty
  }
}
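To make the data structure concrete, here is a small sketch of how FPGrowth exercises FPTree (FPTree is private[fpm], so this only compiles inside the org.apache.spark.mllib.fpm package; the data is made up):

```scala
val tree = new FPTree[String]
tree.add(List("a", "b"))        // one occurrence of {a, b}
tree.add(List("a", "b", "c"))   // one occurrence of {a, b, c}
tree.add(List("a"), 2L)         // a transaction may carry a count > 1

// Extract all patterns occurring at least twice. Items come out
// suffix-first, e.g. List("b", "a") stands for the itemset {a, b}.
tree.extract(2L).foreach { case (pattern, count) =>
  println(s"$pattern: $count")  // List(a) -> 4, List(b) -> 2, List(b, a) -> 2
}
```

Note that transactions must be added with items in a consistent (rank) order, as genFreqItemsets does, so that shared prefixes collapse into shared tree paths.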
73 changes: 73 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -0,0 +1,73 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.mllib.fpm

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {

  test("FP-Growth") {
    val transactions = Seq(
      "r z h k p",
      "z y x w v u t s",
      "s x o n r",
      "x z y m t s q e",
      "z",
      "x z y r q t p")
      .map(_.split(" "))
    val rdd = sc.parallelize(transactions, 2).cache()

    val fpg = new FPGrowth()

    val model6 = fpg
      .setMinSupport(0.9)
      .setNumPartitions(1)
      .run(rdd)
    assert(model6.freqItemsets.count() === 0)

    val model3 = fpg
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(rdd)
    val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
      (items.toSet, count)
    }
    val expected = Set(
      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
      (Set("r"), 3L),
      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
      (Set("t", "y", "x"), 3L),
      (Set("t", "y", "x", "z"), 3L))
    assert(freqItemsets3.toSet === expected)

    val model2 = fpg
      .setMinSupport(0.3)
      .setNumPartitions(4)
      .run(rdd)
    assert(model2.freqItemsets.count() === 54)

    val model1 = fpg
      .setMinSupport(0.1)
      .setNumPartitions(8)
      .run(rdd)
    assert(model1.freqItemsets.count() === 625)
  }
}
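A quick sanity check of these assertions: run() computes minCount = math.ceil(minSupport * n) for this n = 6 transaction dataset.

```scala
val n = 6  // number of transactions in the test
Seq(0.9, 0.5, 0.3, 0.1).foreach { s =>
  println(s"minSupport=$s -> minCount=${math.ceil(s * n).toLong}")
}
// minSupport=0.9 -> minCount=6: no item occurs 6 times (even "z" occurs
//                   only 5 times), so model6 is empty.
// minSupport=0.5 -> minCount=3: exactly the 18 itemsets in `expected`.
// minSupport=0.3 -> minCount=2
// minSupport=0.1 -> minCount=1
```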
