Merge pull request #1 from mengxr/SPARK-4001
simplify FPTree and update FPGrowth
jackylk committed Feb 2, 2015
2 parents ec21f7d + 7e69725 commit bee3093
Showing 5 changed files with 251 additions and 437 deletions.
180 changes: 92 additions & 88 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -17,13 +17,15 @@

package org.apache.spark.mllib.fpm

-import scala.collection.mutable.ArrayBuffer
+import java.{util => ju}

-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.Logging
-import org.apache.spark.rdd.RDD
+import scala.collection.mutable

+import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner}
+import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

+class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable
+
/**
 * This class implements the Parallel FP-growth algorithm to do frequent pattern mining on input data.
@@ -34,125 +36,127 @@ import org.apache.spark.rdd.RDD
 *
 * @param minSupport the minimal support level of the frequent pattern; any pattern that appears
 *                   more than (minSupport * size-of-the-dataset) times will be output
+ * @param numPartitions number of partitions used by parallel FP-growth
 */
-class FPGrowth private(private var minSupport: Double) extends Logging with Serializable {
+class FPGrowth private (
+    private var minSupport: Double,
+    private var numPartitions: Int) extends Logging with Serializable {

  /**
   * Constructs a FPGrowth instance with default parameters:
-   * {minSupport: 0.3}
+   * {minSupport: 0.3, numPartitions: auto}
   */
-  def this() = this(0.3)
+  def this() = this(0.3, -1)

  /**
-   * set the minimal support level, default is 0.3
-   * @param minSupport minimal support level
+   * Sets the minimal support level (default: 0.3).
   */
  def setMinSupport(minSupport: Double): this.type = {
    this.minSupport = minSupport
    this
  }

  /**
-   * Compute a FPGrowth Model that contains frequent pattern result.
+   * Sets the number of partitions used by parallel FP-growth (default: same as input data).
+   */
+  def setNumPartitions(numPartitions: Int): this.type = {
+    this.numPartitions = numPartitions
+    this
+  }
+
+  /**
+   * Computes an FP-Growth model that contains frequent itemsets.
   * @param data input data set, each element contains a transaction
-   * @return FPGrowth Model
+   * @return an [[FPGrowthModel]]
   */
  def run(data: RDD[Array[String]]): FPGrowthModel = {
    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("Input data is not cached.")
    }
    val count = data.count()
-    val minCount = minSupport * count
-    val single = generateSingleItem(data, minCount)
-    val combinations = generateCombinations(data, minCount, single)
-    val all = single.map(v => (Array[String](v._1), v._2)).union(combinations)
-    new FPGrowthModel(all.collect())
+    val minCount = math.ceil(minSupport * count).toLong
+    val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
+    val partitioner = new HashPartitioner(numParts)
+    val freqItems = genFreqItems(data, minCount, partitioner)
+    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
+    new FPGrowthModel(freqItemsets)
  }
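For a concrete reading of the new run flow, here is a minimal usage sketch; the SparkContext sc and the toy transactions are assumptions for illustration, not part of this commit:

// Hypothetical data: 4 transactions over items a, b, c (cached, to avoid the warning above).
val transactions = sc.parallelize(Seq(
  Array("a", "b"),
  Array("a", "b", "c"),
  Array("a", "c"),
  Array("b", "c"))).cache()

val model = new FPGrowth()
  .setMinSupport(0.5)   // minCount = ceil(0.5 * 4) = 2
  .setNumPartitions(2)
  .run(transactions)

model.freqItemsets.collect().foreach { case (items, count) =>
  println(items.mkString("{", ",", "}") + s": $count")
}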

  /**
-   * Generate single item pattern by filtering the input data using minimal support level
-   * @return array of frequent pattern with its count
+   * Generates frequent items by filtering the input data using minimal support level.
+   * @param minCount minimum count for frequent itemsets
+   * @param partitioner partitioner used to distribute items
+   * @return array of frequent items ordered by their frequencies
   */
-  private def generateSingleItem(
+  private def genFreqItems(
      data: RDD[Array[String]],
-      minCount: Double): RDD[(String, Long)] = {
-    val single = data.flatMap(v => v.toSet)
-      .map(v => (v, 1L))
-      .reduceByKey(_ + _)
+      minCount: Long,
+      partitioner: Partitioner): Array[String] = {
+    data.flatMap { t =>
+      val uniq = t.toSet
+      if (t.length != uniq.size) {
+        throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
+      }
+      t
+    }.map(v => (v, 1L))
+      .reduceByKey(partitioner, _ + _)
      .filter(_._2 >= minCount)
-      .sortBy(_._2)
-    single
+      .collect()
+      .sortBy(-_._2)
+      .map(_._1)
  }
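The same counting, filtering, and frequency-descending ordering can be sketched on plain Scala collections; the toy transactions here are assumed for illustration:

val txs = Seq(Array("a", "b"), Array("b", "c"), Array("b", "a"))
val minCount = 2L
val freqItems = txs.flatten
  .groupBy(identity)
  .map { case (item, occs) => (item, occs.size.toLong) } // count each item
  .filter(_._2 >= minCount)                              // keep the frequent ones
  .toArray
  .sortBy(-_._2)                                         // most frequent first
  .map(_._1)
// freqItems == Array("b", "a"): b occurs 3 times, a twice; c is dropped.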

  /**
-   * Generate combination of frequent pattern by computing on FPTree,
-   * the computation is done on each FPTree partitions.
-   * @return array of frequent pattern with its count
+   * Generates frequent itemsets by building FP-Trees; the extraction is done on each partition.
+   * @param data transactions
+   * @param minCount minimum count for frequent itemsets
+   * @param freqItems frequent items
+   * @param partitioner partitioner used to distribute transactions
+   * @return an RDD of (frequent itemset, count)
   */
-  private def generateCombinations(
+  private def genFreqItemsets(
      data: RDD[Array[String]],
-      minCount: Double,
-      singleItem: RDD[(String, Long)]): RDD[(Array[String], Long)] = {
-    val single = data.context.broadcast(singleItem.collect())
-    data.flatMap(transaction => createConditionPatternBase(transaction, single))
-      .aggregateByKey(new FPTree)(
-        (aggregator, condPattBase) => aggregator.add(condPattBase),
-        (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
-      .flatMap(partition => partition._2.mine(minCount, partition._1))
+      minCount: Long,
+      freqItems: Array[String],
+      partitioner: Partitioner): RDD[(Array[String], Long)] = {
+    val itemToRank = freqItems.zipWithIndex.toMap
+    data.flatMap { transaction =>
+      genCondTransactions(transaction, itemToRank, partitioner)
+    }.aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
+      (tree, transaction) => tree.add(transaction, 1L),
+      (tree1, tree2) => tree1.merge(tree2))
+    .flatMap { case (part, tree) =>
+      tree.extract(minCount, x => partitioner.getPartition(x) == part)
+    }.map { case (ranks, count) =>
+      (ranks.map(i => freqItems(i)).toArray, count)
+    }
  }
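Two details above are worth spelling out: itemsets are mined over integer ranks and translated back to item strings only in the final map step, and the partitioner.getPartition(x) == part filter makes each partition emit only the itemsets whose suffix item it owns, so no itemset is reported twice. A small sketch of the decoding, with a hypothetical mined pair:

val freqItems = Array("b", "a", "c")          // ranks 0, 1, 2; most frequent first
val itemToRank = freqItems.zipWithIndex.toMap // Map(b -> 0, a -> 1, c -> 2)

// Suppose one partition's tree yields the ranks (0, 2) with count 2...
val (ranks, count) = (List(0, 2), 2L)
// ...then the map step above restores the item strings:
val itemset = (ranks.map(i => freqItems(i)).toArray, count) // ({b, c}, 2)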

  /**
-   * Create FP-Tree partition for the giving basket
-   * @return an array contains a tuple, whose first element is the single
-   *         item (hash key) and second element is its condition pattern base
+   * Generates conditional transactions.
+   * @param transaction a transaction
+   * @param itemToRank map from items to their ranks
+   * @param partitioner partitioner used to distribute transactions
+   * @return a map of (target partition, conditional transaction)
   */
-  private def createConditionPatternBase(
+  private def genCondTransactions(
      transaction: Array[String],
-      singleBC: Broadcast[Array[(String, Long)]]): Array[(String, Array[String])] = {
-    var output = ArrayBuffer[(String, Array[String])]()
-    var combination = ArrayBuffer[String]()
-    var items = ArrayBuffer[(String, Long)]()
-    val single = singleBC.value
-    val singleMap = single.toMap
-
-    // Filter the basket by single item pattern and sort
-    // by single item and its count
-    val candidates = transaction
-      .filter(singleMap.contains)
-      .map(item => (item, singleMap(item)))
-      .sortBy(_._1)
-      .sortBy(_._2)
-      .toArray
-
-    val itemIterator = candidates.iterator
-    while (itemIterator.hasNext) {
-      combination.clear()
-      val item = itemIterator.next()
-      val firstNItems = candidates.take(candidates.indexOf(item))
-      if (firstNItems.length > 0) {
-        val iterator = firstNItems.iterator
-        while (iterator.hasNext) {
-          val elem = iterator.next()
-          combination += elem._1
-        }
-        output += ((item._1, combination.toArray))
+      itemToRank: Map[String, Int],
+      partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
+    val output = mutable.Map.empty[Int, Array[Int]]
+    // Filter the basket by frequent items pattern and sort their ranks.
+    val filtered = transaction.flatMap(itemToRank.get)
+    ju.Arrays.sort(filtered)
+    val n = filtered.length
+    var i = n - 1
+    while (i >= 0) {
+      val item = filtered(i)
+      val part = partitioner.getPartition(item)
+      if (!output.contains(part)) {
+        output(part) = filtered.slice(0, i + 1)
      }
+      i -= 1
    }
-    output.toArray
+    output
  }
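A worked example of the loop above, with a hypothetical two-partition modulo partitioner standing in for HashPartitioner: walking the sorted ranks from the end, each partition keeps only the longest conditional transaction ending in an item it owns.

import scala.collection.mutable

def getPartition(item: Int): Int = item % 2   // stand-in for the real partitioner

val filtered = Array(0, 1, 2)                 // sorted ranks of one transaction
val output = mutable.Map.empty[Int, Array[Int]]
var i = filtered.length - 1
while (i >= 0) {
  val part = getPartition(filtered(i))
  if (!output.contains(part)) {               // first hit per partition is the longest prefix
    output(part) = filtered.slice(0, i + 1)
  }
  i -= 1
}
// output: Map(0 -> Array(0, 1, 2), 1 -> Array(0, 1))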

}

-/**
- * Top-level methods for calling FPGrowth.
- */
-object FPGrowth{
-
-  /**
-   * Generate a FPGrowth Model using the given minimal support level.
-   *
-   * @param data input baskets stored as `RDD[Array[String]]`
-   * @param minSupport minimal support level, for example 0.5
-   */
-  def train(data: RDD[Array[String]], minSupport: Double): FPGrowthModel = {
-    new FPGrowth().setMinSupport(minSupport).run(data)
-  }
-}
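Since this commit removes the companion object, the old FPGrowth.train(data, 0.5) call is now written with the builder API instead; roughly, for any data: RDD[Array[String]]:

val model = new FPGrowth().setMinSupport(0.5).run(data)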
