[SPARK-4001][MLlib] adding parallel FP-Growth algorithm for frequent pattern mining in MLlib

Apriori is the classic algorithm for frequent item set mining in a transactional data set. It would be useful if the Apriori algorithm were added to MLlib in Spark. This PR adds an implementation for it. There is one point I am not sure about: in order to filter out the eligible frequent item sets, I am currently using a cartesian operation on two RDDs to calculate the degree of support of each item set, and I am not sure whether it would be better to use a broadcast variable to achieve the same. I will add an example that uses this algorithm if required.

Author: Jacky Li <[email protected]>
Author: Jacky Li <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes apache#2847 from jackylk/apriori and squashes the following commits:

bee3093 [Jacky Li] Merge pull request #1 from mengxr/SPARK-4001
7e69725 [Xiangrui Meng] simplify FPTree and update FPGrowth
ec21f7d [Jacky Li] fix scalastyle
93f3280 [Jacky Li] create FPTree class
d110ab2 [Jacky Li] change test case to use MLlibTestSparkContext
a6c5081 [Jacky Li] Add Parallel FPGrowth algorithm
eb3e4ca [Jacky Li] add FPGrowth
03df2b6 [Jacky Li] refactory according to comments
7b77ad7 [Jacky Li] fix scalastyle check
f68a0bd [Jacky Li] add 2 apriori implemenation and fp-growth implementation
889b33f [Jacky Li] modify per scalastyle check
da2cba7 [Jacky Li] adding apriori algorithm for frequent item set mining in Spark
Showing 4 changed files with 484 additions and 0 deletions.
162 changes: 162 additions & 0 deletions
mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -0,0 +1,162 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

import java.{util => ju}

import scala.collection.mutable

import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable

/**
 * This class implements the Parallel FP-growth algorithm to do frequent pattern mining on
 * input data. Parallel FP-growth (PFP) partitions computation in such a way that each machine
 * executes an independent group of mining tasks. More details on this algorithm can be found at
 * [[http://dx.doi.org/10.1145/1454008.1454027, PFP]], and the original FP-growth paper can be
 * found at [[http://dx.doi.org/10.1145/335191.335372, FP-growth]].
 *
 * @param minSupport the minimal support level of the frequent pattern; any pattern that appears
 *                   more than (minSupport * size-of-the-dataset) times will be output
 * @param numPartitions number of partitions used by parallel FP-growth
 */
class FPGrowth private (
    private var minSupport: Double,
    private var numPartitions: Int) extends Logging with Serializable {

  /**
   * Constructs an FPGrowth instance with default parameters:
   * {minSupport: 0.3, numPartitions: auto}
   */
  def this() = this(0.3, -1)

  /**
   * Sets the minimal support level (default: 0.3).
   */
  def setMinSupport(minSupport: Double): this.type = {
    this.minSupport = minSupport
    this
  }

  /**
   * Sets the number of partitions used by parallel FP-growth (default: same as input data).
   */
  def setNumPartitions(numPartitions: Int): this.type = {
    this.numPartitions = numPartitions
    this
  }

  /**
   * Computes an FP-Growth model that contains frequent itemsets.
   * @param data input data set, each element contains a transaction
   * @return an [[FPGrowthModel]]
   */
  def run(data: RDD[Array[String]]): FPGrowthModel = {
    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("Input data is not cached.")
    }
    val count = data.count()
    val minCount = math.ceil(minSupport * count).toLong
    val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
    val partitioner = new HashPartitioner(numParts)
    val freqItems = genFreqItems(data, minCount, partitioner)
    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
    new FPGrowthModel(freqItemsets)
  }

  /**
   * Generates frequent items by filtering the input data using the minimal support level.
   * @param minCount minimum count for frequent itemsets
   * @param partitioner partitioner used to distribute items
   * @return array of frequent items ordered by descending frequency
   */
  private def genFreqItems(
      data: RDD[Array[String]],
      minCount: Long,
      partitioner: Partitioner): Array[String] = {
    data.flatMap { t =>
      val uniq = t.toSet
      if (t.length != uniq.size) {
        throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
      }
      t
    }.map(v => (v, 1L))
      .reduceByKey(partitioner, _ + _)
      .filter(_._2 >= minCount)
      .collect()
      .sortBy(-_._2)
      .map(_._1)
  }

  /**
   * Generates frequent itemsets by building FP-Trees; the extraction is done on each partition.
   * @param data transactions
   * @param minCount minimum count for frequent itemsets
   * @param freqItems frequent items
   * @param partitioner partitioner used to distribute transactions
   * @return an RDD of (frequent itemset, count)
   */
  private def genFreqItemsets(
      data: RDD[Array[String]],
      minCount: Long,
      freqItems: Array[String],
      partitioner: Partitioner): RDD[(Array[String], Long)] = {
    val itemToRank = freqItems.zipWithIndex.toMap
    data.flatMap { transaction =>
      genCondTransactions(transaction, itemToRank, partitioner)
    }.aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
      (tree, transaction) => tree.add(transaction, 1L),
      (tree1, tree2) => tree1.merge(tree2))
      .flatMap { case (part, tree) =>
        tree.extract(minCount, x => partitioner.getPartition(x) == part)
      }.map { case (ranks, count) =>
        (ranks.map(i => freqItems(i)).toArray, count)
      }
  }

  /**
   * Generates conditional transactions.
   * @param transaction a transaction
   * @param itemToRank map from items to their rank
   * @param partitioner partitioner used to distribute transactions
   * @return a map of (target partition, conditional transaction)
   */
  private def genCondTransactions(
      transaction: Array[String],
      itemToRank: Map[String, Int],
      partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
    val output = mutable.Map.empty[Int, Array[Int]]
    // Filter the transaction down to frequent items and sort the remaining items by rank.
    val filtered = transaction.flatMap(itemToRank.get)
    ju.Arrays.sort(filtered)
    val n = filtered.length
    var i = n - 1
    while (i >= 0) {
      val item = filtered(i)
      val part = partitioner.getPartition(item)
      if (!output.contains(part)) {
        output(part) = filtered.slice(0, i + 1)
      }
      i -= 1
    }
    output
  }
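
  // A worked example (hypothetical ranks, not from the commit): with freqItems ranked
  // z=0, x=1, y=2 and a 2-partition HashPartitioner, the transaction ["y", "z", "x"]
  // filters and sorts to ranks [0, 1, 2]. Scanning from the largest rank down:
  //   rank 2 -> partition 0 receives [0, 1, 2]
  //   rank 1 -> partition 1 receives [0, 1]
  //   rank 0 -> partition 0 already has an entry, so it is skipped
  // Each partition thus receives only the longest prefix it needs to mine its own
  // group of suffixes, which is how PFP keeps the mining tasks independent.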
}
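For context, a minimal usage sketch of the API added above. This is not part of the commit: it assumes a live SparkContext named sc and hypothetical toy data; the expected counts are worked out by hand from the run() semantics (minCount = ceil(minSupport * count)).

import org.apache.spark.mllib.fpm.FPGrowth

// Three toy transactions over items "a", "b", "c".
val transactions = sc.parallelize(Seq(
  Array("a", "b", "c"),
  Array("a", "b"),
  Array("a")), 2).cache()  // cache to avoid the "Input data is not cached" warning

val model = new FPGrowth()
  .setMinSupport(0.6)      // minCount = ceil(0.6 * 3) = 2
  .setNumPartitions(2)
  .run(transactions)

// Expected frequent itemsets: {a}:3, {b}:2, {a,b}:2 ("c" appears only once).
model.freqItemsets.collect().foreach { case (items, count) =>
  println(items.mkString("{", ",", "}") + ": " + count)
}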
134 changes: 134 additions & 0 deletions
mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala
@@ -0,0 +1,134 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/**
 * FP-Tree data structure used in FP-Growth.
 * @tparam T item type
 */
private[fpm] class FPTree[T] extends Serializable {

  import FPTree._

  val root: Node[T] = new Node(null)

  private val summaries: mutable.Map[T, Summary[T]] = mutable.Map.empty

  /** Adds a transaction with count. */
  def add(t: Iterable[T], count: Long = 1L): this.type = {
    require(count > 0)
    var curr = root
    curr.count += count
    t.foreach { item =>
      val summary = summaries.getOrElseUpdate(item, new Summary)
      summary.count += count
      val child = curr.children.getOrElseUpdate(item, {
        val newNode = new Node(curr)
        newNode.item = item
        summary.nodes += newNode
        newNode
      })
      child.count += count
      curr = child
    }
    this
  }

  /** Merges another FP-Tree. */
  def merge(other: FPTree[T]): this.type = {
    other.transactions.foreach { case (t, c) =>
      add(t, c)
    }
    this
  }

  /** Gets the subtree projected on the given suffix. */
  private def project(suffix: T): FPTree[T] = {
    val tree = new FPTree[T]
    if (summaries.contains(suffix)) {
      val summary = summaries(suffix)
      summary.nodes.foreach { node =>
        var t = List.empty[T]
        var curr = node.parent
        while (!curr.isRoot) {
          t = curr.item :: t
          curr = curr.parent
        }
        tree.add(t, node.count)
      }
    }
    tree
  }

  /** Returns all transactions in an iterator. */
  def transactions: Iterator[(List[T], Long)] = getTransactions(root)

  /** Returns all transactions under this node. */
  private def getTransactions(node: Node[T]): Iterator[(List[T], Long)] = {
    var count = node.count
    node.children.iterator.flatMap { case (item, child) =>
      getTransactions(child).map { case (t, c) =>
        count -= c
        (item :: t, c)
      }
    } ++ {
      if (count > 0) {
        Iterator.single((Nil, count))
      } else {
        Iterator.empty
      }
    }
  }

  /** Extracts all patterns with valid suffix and minimum count. */
  def extract(
      minCount: Long,
      validateSuffix: T => Boolean = _ => true): Iterator[(List[T], Long)] = {
    summaries.iterator.flatMap { case (item, summary) =>
      if (validateSuffix(item) && summary.count >= minCount) {
        Iterator.single((item :: Nil, summary.count)) ++
          project(item).extract(minCount).map { case (t, c) =>
            (item :: t, c)
          }
      } else {
        Iterator.empty
      }
    }
  }
}

private[fpm] object FPTree {

  /** Represents a node in an FP-Tree. */
  class Node[T](val parent: Node[T]) extends Serializable {
    var item: T = _
    var count: Long = 0L
    val children: mutable.Map[T, Node[T]] = mutable.Map.empty

    def isRoot: Boolean = parent == null
  }

  /** Summary of an item in an FP-Tree. */
  private class Summary[T] extends Serializable {
    var count: Long = 0L
    val nodes: ListBuffer[Node[T]] = ListBuffer.empty
  }
}
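To make the tree operations concrete, here is a small sketch (not part of the commit; since FPTree is private[fpm] it only compiles from within the org.apache.spark.mllib.fpm package). It builds a tree from three hypothetical transactions and mines the patterns with count >= 2; the commented results are traced by hand from add, transactions, and extract above.

val tree = new FPTree[String]
tree.add(Seq("a", "b"), 2L)   // two identical transactions {a, b}
tree.add(Seq("a", "c"))       // one transaction {a, c}

// transactions walks root-to-leaf paths and recovers what was added:
// (List(a, b), 2) and (List(a, c), 1).
tree.transactions.foreach(println)

// extract(2) mines all patterns with count >= 2. Item counts are a=3, b=2,
// c=1, so it yields (List(a), 3), (List(b), 2) and, via the projection on
// suffix b (whose only prefix path is a), (List(b, a), 2).
tree.extract(2L).foreach(println)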
73 changes: 73 additions & 0 deletions
mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -0,0 +1,73 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.mllib.fpm

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {

  test("FP-Growth") {
    val transactions = Seq(
      "r z h k p",
      "z y x w v u t s",
      "s x o n r",
      "x z y m t s q e",
      "z",
      "x z y r q t p")
      .map(_.split(" "))
    val rdd = sc.parallelize(transactions, 2).cache()

    val fpg = new FPGrowth()

    val model6 = fpg
      .setMinSupport(0.9)
      .setNumPartitions(1)
      .run(rdd)
    assert(model6.freqItemsets.count() === 0)

    val model3 = fpg
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(rdd)
    val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
      (items.toSet, count)
    }
    val expected = Set(
      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
      (Set("r"), 3L),
      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
      (Set("t", "y", "x"), 3L),
      (Set("t", "y", "x", "z"), 3L))
    assert(freqItemsets3.toSet === expected)

    val model2 = fpg
      .setMinSupport(0.3)
      .setNumPartitions(4)
      .run(rdd)
    assert(model2.freqItemsets.count() === 54)

    val model1 = fpg
      .setMinSupport(0.1)
      .setNumPartitions(8)
      .run(rdd)
    assert(model1.freqItemsets.count() === 625)
  }
}
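A quick sanity check on the thresholds the suite exercises (my arithmetic, not part of the commit): with 6 transactions, run() computes minCount = ceil(minSupport * 6), so each model corresponds to a distinct count cutoff.

math.ceil(0.9 * 6).toLong  // = 6: even "z" appears in only 5 transactions, hence model6 is empty
math.ceil(0.5 * 6).toLong  // = 3: exactly the 18 itemsets listed in `expected` reach this count
math.ceil(0.3 * 6).toLong  // = 2: the suite asserts 54 itemsets at this cutoff
math.ceil(0.1 * 6).toLong  // = 1: every itemset that occurs at all, 625 in total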