diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index b164e1843522a..71e66392d64ba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -17,29 +17,34 @@ package org.apache.spark.mllib.fpm +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.broadcast.Broadcast import org.apache.spark.Logging -import org.apache.spark.SparkContext._ -import org.apache.spark.broadcast._ import org.apache.spark.rdd.RDD -import scala.collection.mutable.{ArrayBuffer, Map} + /** - * This class implements Parallel FPGrowth algorithm to do frequent pattern matching on input data. + * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data. * Parallel FPGrowth (PFP) partitions computation in such a way that each machine executes an * independent group of mining tasks. More detail of this algorithm can be found at - * http://infolab.stanford.edu/~echang/recsys08-69.pdf + * [[http://dx.doi.org/10.1145/1454008.1454027 PFP]], and the original FP-growth paper can be found at + * [[http://dx.doi.org/10.1145/335191.335372 FP-growth]] + * + * @param minSupport the minimal support level of the frequent pattern; any pattern that appears at least + * (minSupport * size-of-the-dataset) times will be output */ class FPGrowth private(private var minSupport: Double) extends Logging with Serializable { /** * Constructs a FPGrowth instance with default parameters: - * {minSupport: 0.5} + * {minSupport: 0.3} */ - def this() = this(0.5) + def this() = this(0.3) /** - * set the minimal support level, default is 0.5 + * set the minimal support level, default is 0.3 * @param minSupport minimal support level */ def setMinSupport(minSupport: Double): this.type = { @@ -49,87 +54,82 @@ class FPGrowth private(private var minSupport: Double) extends Logging with Seri /** * 
Compute a FPGrowth Model that contains frequent pattern result. - * @param data input data set + * @param data input data set, each element contains a transaction * @return FPGrowth Model */ def run(data: RDD[Array[String]]): FPGrowthModel = { - val model = runAlgorithm(data) - model - } - - /** - * Implementation of PFP. - */ - private def runAlgorithm(data: RDD[Array[String]]): FPGrowthModel = { val count = data.count() val minCount = minSupport * count val single = generateSingleItem(data, minCount) val combinations = generateCombinations(data, minCount, single) - new FPGrowthModel(single ++ combinations) + val all = single.map(v => (Array[String](v._1), v._2)).union(combinations) + new FPGrowthModel(all.collect()) } /** * Generate single item pattern by filtering the input data using minimal support level + * @return array of frequent pattern with its count */ private def generateSingleItem( data: RDD[Array[String]], - minCount: Double): Array[(String, Int)] = { - data.flatMap(v => v) - .map(v => (v, 1)) + minCount: Double): RDD[(String, Long)] = { + val single = data.flatMap(v => v.toSet) + .map(v => (v, 1L)) .reduceByKey(_ + _) .filter(_._2 >= minCount) - .collect() - .distinct - .sortWith(_._2 > _._2) + .sortBy(_._2) + single } /** - * Generate combination of items by computing on FPTree, + * Generate combination of frequent pattern by computing on FPTree, * the computation is done on each FPTree partitions. 
+ * @return array of frequent pattern with its count */ private def generateCombinations( data: RDD[Array[String]], minCount: Double, - singleItem: Array[(String, Int)]): Array[(String, Int)] = { - val single = data.context.broadcast(singleItem) - data.flatMap(basket => createFPTree(basket, single)) - .groupByKey() - .flatMap(partition => runFPTree(partition, minCount)) - .collect() + singleItem: RDD[(String, Long)]): RDD[(Array[String], Long)] = { + val single = data.context.broadcast(singleItem.collect()) + data.flatMap(transaction => createConditionPatternBase(transaction, single)) + .aggregateByKey(new FPTree)( + (aggregator, condPattBase) => aggregator.add(condPattBase), + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + .flatMap(partition => partition._2.mine(minCount, partition._1)) } /** * Create FP-Tree partition for the giving basket + * @return an array contains a tuple, whose first element is the single + * item (hash key) and second element is its condition pattern base */ - private def createFPTree( - basket: Array[String], - singleItem: Broadcast[Array[(String, Int)]]): Array[(String, Array[String])] = { + private def createConditionPatternBase( + transaction: Array[String], + singleBC: Broadcast[Array[(String, Long)]]): Array[(String, Array[String])] = { var output = ArrayBuffer[(String, Array[String])]() var combination = ArrayBuffer[String]() - val single = singleItem.value - var items = ArrayBuffer[(String, Int)]() - - // Filter the basket by single item pattern - val iterator = basket.iterator - while (iterator.hasNext){ - val item = iterator.next - val opt = single.find(_._1.equals(item)) - if (opt != None) { - items ++= opt - } - } - - // Sort it and create the item combinations - val sortedItems = items.sortWith(_._1 > _._1).sortWith(_._2 > _._2).toArray - val itemIterator = sortedItems.iterator + var items = ArrayBuffer[(String, Long)]() + val single = singleBC.value + val singleMap = single.toMap + + // Filter the basket by 
single item pattern and sort + // by single item and its count + val candidates = transaction + .filter(singleMap.contains) + .map(item => (item, singleMap(item))) + .sortBy(_._1) + .sortBy(_._2) + .toArray + + val itemIterator = candidates.iterator while (itemIterator.hasNext) { combination.clear() - val item = itemIterator.next - val firstNItems = sortedItems.take(sortedItems.indexOf(item)) + val item = itemIterator.next() + val firstNItems = candidates.take(candidates.indexOf(item)) if (firstNItems.length > 0) { val iterator = firstNItems.iterator while (iterator.hasNext) { - val elem = iterator.next + val elem = iterator.next() combination += elem._1 } output += ((item._1, combination.toArray)) @@ -138,56 +138,6 @@ class FPGrowth private(private var minSupport: Double) extends Logging with Seri output.toArray } - /** - * Generate frequent pattern by walking through the FPTree - */ - private def runFPTree( - partition: (String, Iterable[Array[String]]), - minCount: Double): Array[(String, Int)] = { - val key = partition._1 - val value = partition._2 - val output = ArrayBuffer[(String, Int)]() - val map = Map[String, Int]() - - // Walk through the FPTree partition to generate all combinations that satisfy - // the minimal support level. 
- var k = 1 - while (k > 0) { - map.clear() - val iterator = value.iterator - while (iterator.hasNext) { - val pattern = iterator.next - if (pattern.length >= k) { - val combination = pattern.toList.combinations(k).toList - val itemIterator = combination.iterator - while (itemIterator.hasNext){ - val item = itemIterator.next - val list2key: List[String] = (item :+ key).sortWith(_ > _) - val newKey = list2key.mkString(" ") - if (map.get(newKey) == None) { - map(newKey) = 1 - } else { - map(newKey) = map.apply(newKey) + 1 - } - } - } - } - var eligible: Array[(String, Int)] = null - if (map.size != 0) { - val candidate = map.filter(_._2 >= minCount) - if (candidate.size != 0) { - eligible = candidate.toArray - output ++= eligible - } - } - if ((eligible == null) || (eligible.length == 0)) { - k = 0 - } else { - k = k + 1 - } - } - output.toArray - } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthModel.scala index cb3348d654733..1f490d6ccdd59 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthModel.scala @@ -20,5 +20,5 @@ package org.apache.spark.mllib.fpm /** * A FPGrowth Model for FPGrowth, each element is a frequent pattern with count. */ -class FPGrowthModel (val frequentPattern: Array[(String, Int)]) extends Serializable { +class FPGrowthModel (val frequentPattern: Array[(Array[String], Long)]) extends Serializable { } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala new file mode 100644 index 0000000000000..2dc2631d55232 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import scala.collection.mutable.{ListBuffer, ArrayBuffer, Map} + +class FPTree extends Serializable { + + val root: FPTreeNode = new FPTreeNode(null, 0) + + def add(transaction: Array[String]): this.type = { + var index = 0 + val size = transaction.size + var curr = root + while (index < size) { + if (curr.children.contains(transaction(index))) { + val node = curr.children(transaction(index)) + node.count = node.count + 1 + curr = node + } else { + val newNode = new FPTreeNode(transaction(index), 1) + newNode.parent = curr + curr.children(transaction(index)) = newNode + curr = newNode + } + index = index + 1 + } + + // TODO: in order to further reduce the amount of data for shuffle, + // remove the same pattern which has the same hash number + this + } + + /** + * merge with the input tree + * @param tree the tree to merge + * @return tree after merge + */ + def merge(tree: FPTree): this.type = { + // merge two trees recursively to remove all duplicated nodes + mergeTree(this.root, tree.root) + this + } + + /** + * merge two trees from their root node + * @param tree1 root node of the tree one + * @param tree2 root node of the tree two + * @return root node after merge + */ + private def mergeTree(tree1: FPTreeNode, tree2: FPTreeNode): 
FPTreeNode = { + // firstly merge two roots, then iterate on the second tree, merge all children of it to the first tree + require(tree1 != null) + require(tree2 != null) + if (!tree2.isRoot) { + require(tree1.item.equals(tree2.item)) + tree1.count = tree1.count + tree2.count + } + if (!tree2.isLeaf) { + val it = tree2.children.iterator + while (it.hasNext) { + val node = mergeSubTree(tree1, it.next()._2) + tree1.children(node.item) = node + node.parent = tree1 + } + } + tree1 + } + + /** + * merge the second tree into the children of the first tree, if there is a match + * @param tree1Root root node of the tree one + * @param subTree2 the child of the tree two + * @return root node after merge + */ + private def mergeSubTree(tree1Root: FPTreeNode, subTree2: FPTreeNode): FPTreeNode = { + if (tree1Root.children.contains(subTree2.item)) { + mergeTree(tree1Root.children(subTree2.item), subTree2) + } else { + subTree2 + } + } + + /** + * Generate all frequent patterns by mining the FPTree recursively + * @param minCount minimal count + * @param suffix key of this tree + * @return + */ + def mine(minCount: Double, suffix: String): Array[(Array[String], Long)] = { + val condPattBase = expandFPTree(this) + mineFPTree(condPattBase, minCount, suffix) + } + + /** + * This function will walk through the tree and build all conditional pattern base out of it + * @param tree the tree to expand + * @return conditional pattern base + */ + private def expandFPTree(tree: FPTree): ArrayBuffer[ArrayBuffer[String]] = { + var output: ArrayBuffer[ArrayBuffer[String]] = null + if (!tree.root.isLeaf) { + val it = tree.root.children.iterator + while (it.hasNext) { + val childOuput = expandFPTreeNode(it.next()._2) + if (output == null) output = childOuput else output ++= childOuput + } + } + output + } + + /** + * Expand from the input node + * @param node tree node + * @return conditional pattern base + */ + private def expandFPTreeNode(node: FPTreeNode): ArrayBuffer[ArrayBuffer[String]] = 
{ + // Iterate on all children and build the output recursively + val output = new ArrayBuffer[ArrayBuffer[String]]() + for (i <- 0 to node.count - 1) { + output.append(ArrayBuffer[String](node.item)) + } + val it = node.children.iterator + var i = 0 + while (it.hasNext) { + val child = it.next() + val childOutput = expandFPTreeNode(child._2) + require(childOutput.size <= output.size) + for (buffer <- childOutput) { + output(i) ++= buffer + i = i + 1 + } + } + output + } + + /** + * Generate all frequent patterns by combinations of condition pattern base. + * This implementation is different from classical fp-growth algorithm which generates + * FPTree recursively. + * + * @param condPattBase condition pattern base + * @param minCount the minimum count + * @param suffix key of the condition pattern base + * @return frequent item set + */ + private def mineFPTree( + condPattBase: ArrayBuffer[ArrayBuffer[String]], + minCount: Double, + suffix: String): Array[(Array[String], Long)] = { + // frequent item + val key = suffix + // the set of construction CPFTree + val value = condPattBase + + // tree step.start 2th + var k = 1 + // save all frequent item sets + val fimSetBuffer = ArrayBuffer[(String, Long)]() + // save step k's lineComList temp value to next step k+1 compute combinations + var lineComListTempBuffer = ArrayBuffer[String]() + // loop the data set from 1 to k while k>0 + while (k > 0) { + // save step k's lineComList temp value + var lineComListBuffer = ListBuffer[List[String]]() + // loop every value to combinations while each value length >= k + for (v <- value) { + val vLen = v.length + if (vLen >= k) { + // calculate each value combinations while k == 1 + if (k == 1) { + val lineCom = v.toList.combinations(k) + lineComListBuffer ++= lineCom.toList + } else { + /* if each value length > k,it need calculate the intersect of each value & before combinations */ + val union_lineComListTemp2v = v intersect lineComListTempBuffer.toArray.array + // 
calculate each value combinations after intersect + if (union_lineComListTemp2v.length >= k) { + val lineCom = union_lineComListTemp2v.toList.combinations(k) + lineComListBuffer ++= lineCom.toList + } + } + } + } + + var lineComList: Array[(String, Long)] = null + // reset + lineComListTempBuffer = ArrayBuffer[String]() + // calculate frequent item set + if (lineComListBuffer != null || lineComListBuffer.size != 0) { + val lineComListTemp = lineComListBuffer + .map( v => ( (v :+ key).sortWith(_ > _),1) ) + .groupBy(_._1) + .map(v => (v._1,v._2.length)) + .filter(_._2 >= minCount) + if ( lineComListTemp != null || lineComListTemp.size != 0) { + lineComList = lineComListTemp + .map(v => (v._1.mkString(" "), v._2.toLong)) + .toArray + fimSetBuffer ++= lineComList + for (lcl <- lineComList) { + lineComListTempBuffer ++= lcl._1.split(" ") + } + } + } + // reset k value + if (lineComList == null || lineComList.length == 0) { + k = 0 + } else { + k = k + 1 + } + } + val fimSetArray = fimSetBuffer + .map(v => (v._1.split(" "), v._2)) + .toArray + fimSetArray + } +} + +class FPTreeNode(val item: String, var count: Int) extends Serializable { + var parent: FPTreeNode = null + val children: Map[String, FPTreeNode] = Map[String, FPTreeNode]() + def isLeaf: Boolean = children.size == 0 + def isRoot: Boolean = parent == null +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index e29399ffe71fe..02181bf2b83e8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.mllib.fpm import org.scalatest.FunSuite + import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { +class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { test("test FPGrowth algorithm") { - val arr 
= FPGrowthSuite.createTestData() + val arr = FPGrowthSuite.createFIMDataSet() - assert(arr.length === 6) + assert(arr.length == 6) val dataSet = sc.parallelize(arr) assert(dataSet.count() == 6) val rdd = dataSet.map(line => line.split(" ")) @@ -58,15 +59,42 @@ object FPGrowthSuite /** * Create test data set */ - def createTestData():Array[String] = - { - val arr = Array[String]( - "r z h k p", - "z y x w v u t s", - "s x o n r", - "x z y m t s q e", - "z", - "x z y r q t p") - arr + def createFIMDataSet():Array[String] = + { + val arr = Array[String]( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + arr + } + + def printTree(tree: FPTree) = printTreeRoot(tree.root, 0) + + private def printTreeRoot(tree: FPTreeNode, level: Int): Unit = { + printNode(tree, level) + if (tree.isLeaf) return + val it = tree.children.iterator + while (it.hasNext) { + val child = it.next() + printTreeRoot(child._2, level + 1) + } + } + + private def printNode(node: FPTreeNode, level: Int) = { + for (i <- 0 to level) { + print("\t") + } + println(node.item + " " + node.count) + } + + def printFrequentPattern(pattern: Array[(Array[String], Long)]) = { + for (a <- pattern) { + a._1.foreach(x => print(x + " ")) + print(a._2) + println } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala new file mode 100644 index 0000000000000..92d415da1ddb7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.scalatest.FunSuite + +class FPTreeSuite extends FunSuite with MLlibTestSparkContext { + + test("add transaction to tree") { + val tree = new FPTree + tree.add(Array[String]("a", "b", "c")) + tree.add(Array[String]("a", "b", "y")) + tree.add(Array[String]("b")) + FPGrowthSuite.printTree(tree) + + assert(tree.root.children.size == 2) + assert(tree.root.children.contains("a")) + assert(tree.root.children("a").item.equals("a")) + assert(tree.root.children("a").count == 2) + assert(tree.root.children.contains("b")) + assert(tree.root.children("b").item.equals("b")) + assert(tree.root.children("b").count == 1) + var child = tree.root.children("a") + assert(child.children.size == 1) + assert(child.children.contains("b")) + assert(child.children("b").item.equals("b")) + assert(child.children("b").count == 2) + child = child.children("b") + assert(child.children.size == 2) + assert(child.children.contains("c")) + assert(child.children.contains("y")) + assert(child.children("c").item.equals("c")) + assert(child.children("y").item.equals("y")) + assert(child.children("c").count == 1) + assert(child.children("y").count == 1) + } + + test("merge tree") { + val tree1 = new FPTree + tree1.add(Array[String]("a", "b", "c")) + tree1.add(Array[String]("a", "b", "y")) + tree1.add(Array[String]("b")) + FPGrowthSuite.printTree(tree1) + + val tree2 = new FPTree + tree2.add(Array[String]("a", "b")) + tree2.add(Array[String]("a", "b", "c")) + 
tree2.add(Array[String]("a", "b", "c", "d")) + tree2.add(Array[String]("a", "x")) + tree2.add(Array[String]("a", "x", "y")) + tree2.add(Array[String]("c", "n")) + tree2.add(Array[String]("c", "m")) + FPGrowthSuite.printTree(tree2) + + val tree3 = tree1.merge(tree2) + FPGrowthSuite.printTree(tree3) + + assert(tree3.root.children.size == 3) + assert(tree3.root.children("a").count == 7) + assert(tree3.root.children("b").count == 1) + assert(tree3.root.children("c").count == 2) + val child1 = tree3.root.children("a") + assert(child1.children.size == 2) + assert(child1.children("b").count == 5) + assert(child1.children("x").count == 2) + val child2 = child1.children("b") + assert(child2.children.size == 2) + assert(child2.children("y").count == 1) + assert(child2.children("c").count == 3) + val child3 = child2.children("c") + assert(child3.children.size == 1) + assert(child3.children("d").count == 1) + val child4 = child1.children("x") + assert(child4.children.size == 1) + assert(child4.children("y").count == 1) + val child5 = tree3.root.children("c") + assert(child5.children.size == 2) + assert(child5.children("n").count == 1) + assert(child5.children("m").count == 1) + } + + /* + test("expand tree") { + val tree = new FPTree + tree.add(Array[String]("a", "b", "c")) + tree.add(Array[String]("a", "b", "y")) + tree.add(Array[String]("a", "b")) + tree.add(Array[String]("a")) + tree.add(Array[String]("b")) + tree.add(Array[String]("b", "n")) + + FPGrowthSuite.printTree(tree) + val buffer = tree.expandFPTree(tree) + for (a <- buffer) { + a.foreach(x => print(x + " ")) + println + } + } + */ + + test("mine tree") { + val tree = new FPTree + tree.add(Array[String]("a", "b", "c")) + tree.add(Array[String]("a", "b", "y")) + tree.add(Array[String]("a", "b")) + tree.add(Array[String]("a")) + tree.add(Array[String]("b")) + tree.add(Array[String]("b", "n")) + + FPGrowthSuite.printTree(tree) + val buffer = tree.mine(3.0, "t") + + for (a <- buffer) { + a._1.foreach(x => print(x + " 
")) + print(a._2) + println + } + val s1 = buffer(0)._1 + val s2 = buffer(1)._1 + val s3 = buffer(2)._1 + assert(s1(1).equals("a")) + assert(s2(1).equals("b")) + assert(s3(1).equals("b")) + assert(s3(2).equals("a")) + assert(buffer(0)._2 == 4) + assert(buffer(1)._2 == 5) + assert(buffer(2)._2 == 3) + } +}