diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala new file mode 100644 index 0000000000000..9591c7966e06a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import java.{util => ju} + +import scala.collection.mutable + +import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable + +/** + * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data. + * Parallel FPGrowth (PFP) partitions computation in such a way that each machine executes an + * independent group of mining tasks. More detail of this algorithm can be found at + * [[http://dx.doi.org/10.1145/1454008.1454027, PFP]], and the original FP-growth paper can be + * found at [[http://dx.doi.org/10.1145/335191.335372, FP-growth]] + * + * @param minSupport the minimal support level of the frequent pattern, any pattern appears + * more than (minSupport * size-of-the-dataset) times will be output + * @param numPartitions number of partitions used by parallel FP-growth + */ +class FPGrowth private ( + private var minSupport: Double, + private var numPartitions: Int) extends Logging with Serializable { + + /** + * Constructs a FPGrowth instance with default parameters: + * {minSupport: 0.3, numPartitions: auto} + */ + def this() = this(0.3, -1) + + /** + * Sets the minimal support level (default: 0.3). + */ + def setMinSupport(minSupport: Double): this.type = { + this.minSupport = minSupport + this + } + + /** + * Sets the number of partitions used by parallel FP-growth (default: same as input data). + */ + def setNumPartitions(numPartitions: Int): this.type = { + this.numPartitions = numPartitions + this + } + + /** + * Computes an FP-Growth model that contains frequent itemsets. + * @param data input data set, each element contains a transaction + * @return an [[FPGrowthModel]] + */ + def run(data: RDD[Array[String]]): FPGrowthModel = { + if (data.getStorageLevel == StorageLevel.NONE) { + logWarning("Input data is not cached.") + } + val count = data.count() + val minCount = math.ceil(minSupport * count).toLong + val numParts = if (numPartitions > 0) numPartitions else data.partitions.length + val partitioner = new HashPartitioner(numParts) + val freqItems = genFreqItems(data, minCount, partitioner) + val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner) + new FPGrowthModel(freqItemsets) + } + + /** + * Generates frequent items by filtering the input data using minimal support level. + * @param minCount minimum count for frequent itemsets + * @param partitioner partitioner used to distribute items + * @return array of frequent pattern ordered by their frequencies + */ + private def genFreqItems( + data: RDD[Array[String]], + minCount: Long, + partitioner: Partitioner): Array[String] = { + data.flatMap { t => + val uniq = t.toSet + if (t.length != uniq.size) { + throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.") + } + t + }.map(v => (v, 1L)) + .reduceByKey(partitioner, _ + _) + .filter(_._2 >= minCount) + .collect() + .sortBy(-_._2) + .map(_._1) + } + + /** + * Generate frequent itemsets by building FP-Trees, the extraction is done on each partition. + * @param data transactions + * @param minCount minimum count for frequent itemsets + * @param freqItems frequent items + * @param partitioner partitioner used to distribute transactions + * @return an RDD of (frequent itemset, count) + */ + private def genFreqItemsets( + data: RDD[Array[String]], + minCount: Long, + freqItems: Array[String], + partitioner: Partitioner): RDD[(Array[String], Long)] = { + val itemToRank = freqItems.zipWithIndex.toMap + data.flatMap { transaction => + genCondTransactions(transaction, itemToRank, partitioner) + }.aggregateByKey(new FPTree[Int], partitioner.numPartitions)( + (tree, transaction) => tree.add(transaction, 1L), + (tree1, tree2) => tree1.merge(tree2)) + .flatMap { case (part, tree) => + tree.extract(minCount, x => partitioner.getPartition(x) == part) + }.map { case (ranks, count) => + (ranks.map(i => freqItems(i)).toArray, count) + } + } + + /** + * Generates conditional transactions. + * @param transaction a transaction + * @param itemToRank map from item to their rank + * @param partitioner partitioner used to distribute transactions + * @return a map of (target partition, conditional transaction) + */ + private def genCondTransactions( + transaction: Array[String], + itemToRank: Map[String, Int], + partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { + val output = mutable.Map.empty[Int, Array[Int]] + // Filter the basket by frequent items pattern and sort their ranks. + val filtered = transaction.flatMap(itemToRank.get) + ju.Arrays.sort(filtered) + val n = filtered.length + var i = n - 1 + while (i >= 0) { + val item = filtered(i) + val part = partitioner.getPartition(item) + if (!output.contains(part)) { + output(part) = filtered.slice(0, i + 1) + } + i -= 1 + } + output + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala new file mode 100644 index 0000000000000..1d2d777c00793 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +/** + * FP-Tree data structure used in FP-Growth. + * @tparam T item type + */ +private[fpm] class FPTree[T] extends Serializable { + + import FPTree._ + + val root: Node[T] = new Node(null) + + private val summaries: mutable.Map[T, Summary[T]] = mutable.Map.empty + + /** Adds a transaction with count. */ + def add(t: Iterable[T], count: Long = 1L): this.type = { + require(count > 0) + var curr = root + curr.count += count + t.foreach { item => + val summary = summaries.getOrElseUpdate(item, new Summary) + summary.count += count + val child = curr.children.getOrElseUpdate(item, { + val newNode = new Node(curr) + newNode.item = item + summary.nodes += newNode + newNode + }) + child.count += count + curr = child + } + this + } + + /** Merges another FP-Tree. */ + def merge(other: FPTree[T]): this.type = { + other.transactions.foreach { case (t, c) => + add(t, c) + } + this + } + + /** Gets a subtree with the suffix. */ + private def project(suffix: T): FPTree[T] = { + val tree = new FPTree[T] + if (summaries.contains(suffix)) { + val summary = summaries(suffix) + summary.nodes.foreach { node => + var t = List.empty[T] + var curr = node.parent + while (!curr.isRoot) { + t = curr.item :: t + curr = curr.parent + } + tree.add(t, node.count) + } + } + tree + } + + /** Returns all transactions in an iterator. */ + def transactions: Iterator[(List[T], Long)] = getTransactions(root) + + /** Returns all transactions under this node. */ + private def getTransactions(node: Node[T]): Iterator[(List[T], Long)] = { + var count = node.count + node.children.iterator.flatMap { case (item, child) => + getTransactions(child).map { case (t, c) => + count -= c + (item :: t, c) + } + } ++ { + if (count > 0) { + Iterator.single((Nil, count)) + } else { + Iterator.empty + } + } + } + + /** Extracts all patterns with valid suffix and minimum count. */ + def extract( + minCount: Long, + validateSuffix: T => Boolean = _ => true): Iterator[(List[T], Long)] = { + summaries.iterator.flatMap { case (item, summary) => + if (validateSuffix(item) && summary.count >= minCount) { + Iterator.single((item :: Nil, summary.count)) ++ + project(item).extract(minCount).map { case (t, c) => + (item :: t, c) + } + } else { + Iterator.empty + } + } + } +} + +private[fpm] object FPTree { + + /** Representing a node in an FP-Tree. */ + class Node[T](val parent: Node[T]) extends Serializable { + var item: T = _ + var count: Long = 0L + val children: mutable.Map[T, Node[T]] = mutable.Map.empty + + def isRoot: Boolean = parent == null + } + + /** Summary of a item in an FP-Tree. */ + private class Summary[T] extends Serializable { + var count: Long = 0L + val nodes: ListBuffer[Node[T]] = ListBuffer.empty + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala new file mode 100644 index 0000000000000..71ef60da6dd32 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { + + test("FP-Growth") { + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + val fpg = new FPGrowth() + + val model6 = fpg + .setMinSupport(0.9) + .setNumPartitions(1) + .run(rdd) + assert(model6.freqItemsets.count() === 0) + + val model3 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) => + (items.toSet, count) + } + val expected = Set( + (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L), + (Set("r"), 3L), + (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L), + (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L), + (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L), + (Set("t", "y", "x"), 3L), + (Set("t", "y", "x", "z"), 3L)) + assert(freqItemsets3.toSet === expected) + + val model2 = fpg + .setMinSupport(0.3) + .setNumPartitions(4) + .run(rdd) + assert(model2.freqItemsets.count() === 54) + + val model1 = fpg + .setMinSupport(0.1) + .setNumPartitions(8) + .run(rdd) + assert(model1.freqItemsets.count() === 625) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala new file mode 100644 index 0000000000000..04017f67c311d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import scala.language.existentials + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class FPTreeSuite extends FunSuite with MLlibTestSparkContext { + + test("add transaction") { + val tree = new FPTree[String] + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "y")) + .add(Seq("b")) + + assert(tree.root.children.size == 2) + assert(tree.root.children.contains("a")) + assert(tree.root.children("a").item.equals("a")) + assert(tree.root.children("a").count == 2) + assert(tree.root.children.contains("b")) + assert(tree.root.children("b").item.equals("b")) + assert(tree.root.children("b").count == 1) + var child = tree.root.children("a") + assert(child.children.size == 1) + assert(child.children.contains("b")) + assert(child.children("b").item.equals("b")) + assert(child.children("b").count == 2) + child = child.children("b") + assert(child.children.size == 2) + assert(child.children.contains("c")) + assert(child.children.contains("y")) + assert(child.children("c").item.equals("c")) + assert(child.children("y").item.equals("y")) + assert(child.children("c").count == 1) + assert(child.children("y").count == 1) + } + + test("merge tree") { + val tree1 = new FPTree[String] + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "y")) + .add(Seq("b")) + + val tree2 = new FPTree[String] + .add(Seq("a", "b")) + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "c", "d")) + .add(Seq("a", "x")) + .add(Seq("a", "x", "y")) + .add(Seq("c", "n")) + .add(Seq("c", "m")) + + val tree3 = tree1.merge(tree2) + + assert(tree3.root.children.size == 3) + assert(tree3.root.children("a").count == 7) + assert(tree3.root.children("b").count == 1) + assert(tree3.root.children("c").count == 2) + val child1 = tree3.root.children("a") + assert(child1.children.size == 2) + assert(child1.children("b").count == 5) + assert(child1.children("x").count == 2) + val child2 = child1.children("b") + assert(child2.children.size == 2) + assert(child2.children("y").count == 1) + assert(child2.children("c").count == 3) + val child3 = child2.children("c") + assert(child3.children.size == 1) + assert(child3.children("d").count == 1) + val child4 = child1.children("x") + assert(child4.children.size == 1) + assert(child4.children("y").count == 1) + val child5 = tree3.root.children("c") + assert(child5.children.size == 2) + assert(child5.children("n").count == 1) + assert(child5.children("m").count == 1) + } + + test("extract freq itemsets") { + val tree = new FPTree[String] + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "y")) + .add(Seq("a", "b")) + .add(Seq("a")) + .add(Seq("b")) + .add(Seq("b", "n")) + + val freqItemsets = tree.extract(3L).map { case (items, count) => + (items.toSet, count) + }.toSet + val expected = Set( + (Set("a"), 4L), + (Set("b"), 5L), + (Set("a", "b"), 3L)) + assert(freqItemsets === expected) + } +}