forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
304 additions
and
0 deletions.
There are no files selected for viewing
208 changes: 208 additions & 0 deletions
208
mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.fpm | ||
|
||
import org.apache.spark.Logging | ||
import org.apache.spark.SparkContext._ | ||
import org.apache.spark.broadcast._ | ||
import org.apache.spark.rdd.RDD | ||
|
||
import scala.collection.mutable.{ArrayBuffer, Map} | ||
|
||
/**
 * This class implements the Parallel FP-Growth (PFP) algorithm for mining frequent patterns
 * (frequent itemsets) from input data. PFP partitions the computation in such a way that each
 * machine executes an independent group of mining tasks. More details of this algorithm can be
 * found at http://infolab.stanford.edu/~echang/recsys08-69.pdf
 */
class FPGrowth private(private var minSupport: Double) extends Logging with Serializable {

  /**
   * Constructs a FPGrowth instance with default parameters:
   * {minSupport: 0.5}
   */
  def this() = this(0.5)

  /**
   * Sets the minimal support level (default is 0.5).
   *
   * @param minSupport minimal support level, expressed as a fraction of the number of baskets;
   *                   must lie in [0, 1]
   * @return this instance, so calls can be chained
   * @throws IllegalArgumentException if `minSupport` is outside [0, 1] (NaN included)
   */
  def setMinSupport(minSupport: Double): this.type = {
    require(minSupport >= 0.0 && minSupport <= 1.0,
      "Minimal support level must be in range [0, 1] but got " + minSupport)
    this.minSupport = minSupport
    this
  }

  /**
   * Computes a FPGrowth model that contains the frequent patterns of the input.
   *
   * @param data input data set, one basket (array of items) per element
   * @return FPGrowth model holding every frequent pattern together with its count
   */
  def run(data: RDD[Array[String]]): FPGrowthModel = runAlgorithm(data)

  /**
   * Implementation of PFP: frequent single items are found first, then multi-item patterns
   * are mined from the per-item groups of conditional patterns.
   */
  private def runAlgorithm(data: RDD[Array[String]]): FPGrowthModel = {
    // Absolute threshold: a pattern is kept when its count is >= minSupport * number of baskets.
    val minCount = minSupport * data.count()
    val single = generateSingleItem(data, minCount)
    val combinations = generateCombinations(data, minCount, single)
    new FPGrowthModel(single ++ combinations)
  }

  /**
   * Generates the single-item patterns that reach the minimal count, ordered by descending
   * count.
   */
  private def generateSingleItem(
      data: RDD[Array[String]],
      minCount: Double): Array[(String, Int)] = {
    // Each basket contributes at most 1 to an item's count, so an item repeated inside a
    // single basket does not inflate its support.
    data.flatMap(basket => basket.distinct.toSeq)
      .map(item => (item, 1))
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)
      .collect()
      .sortWith(_._2 > _._2)
  }

  /**
   * Generates the multi-item patterns: every basket is turned into conditional patterns keyed
   * by item, the patterns are grouped by key, and each group is mined independently.
   */
  private def generateCombinations(
      data: RDD[Array[String]],
      minCount: Double,
      singleItem: Array[(String, Int)]): Array[(String, Int)] = {
    // Broadcast the frequent items as a lookup table (item -> count) so that filtering a
    // basket needs one hash lookup per item instead of a scan over all frequent items.
    val frequentItems: Broadcast[scala.collection.immutable.Map[String, Int]] =
      data.context.broadcast(singleItem.toMap)
    data.flatMap(basket => createFPTree(basket, frequentItems))
      .groupByKey()
      .flatMap(partition => runFPTree(partition, minCount))
      .collect()
  }

  /**
   * Creates the conditional patterns for the given basket.
   *
   * The frequent items of the basket are ordered by descending count (ties broken by
   * descending item name). Every item except the first one is then emitted together with
   * the items ranked before it.
   *
   * @param basket one input basket
   * @param singleItem broadcast lookup table of the frequent single items and their counts
   * @return pairs of (item, items ranked before it in this basket)
   */
  private def createFPTree(
      basket: Array[String],
      singleItem: Broadcast[scala.collection.immutable.Map[String, Int]])
    : Array[(String, Array[String])] = {
    val counts = singleItem.value

    // Keep every frequent item once; duplicates inside a basket carry no extra support.
    val frequent = basket.distinct.filter(item => counts.contains(item))

    // Count descending, then name descending.
    val ranked = frequent.sortWith { (a, b) =>
      val countA = counts(a)
      val countB = counts(b)
      countA > countB || (countA == countB && a > b)
    }

    // The top-ranked item has an empty prefix and therefore yields no conditional pattern.
    (1 until ranked.length).map(i => (ranked(i), ranked.take(i))).toArray
  }

  /**
   * Generates the frequent patterns of one group of conditional patterns.
   *
   * For k = 1, 2, ... every k-item combination of every conditional pattern is joined with
   * the group key and counted. Combinations reaching the minimal count are collected. The walk
   * stops at the first k that yields no frequent combination.
   *
   * @param partition the group key together with all conditional patterns emitted for it
   * @param minCount minimal count a pattern needs to be kept
   * @return pairs of (pattern, count); a pattern is its items sorted in descending order and
   *         joined by single spaces
   */
  private def runFPTree(
      partition: (String, Iterable[Array[String]]),
      minCount: Double): Array[(String, Int)] = {
    val (suffix, prefixes) = partition
    val output = scala.collection.mutable.ArrayBuffer.empty[(String, Int)]
    val counts = scala.collection.mutable.Map.empty[String, Int]

    var k = 1
    var growing = true
    while (growing) {
      counts.clear()
      for {
        prefix <- prefixes
        if prefix.length >= k
        combination <- prefix.toList.combinations(k)
      } {
        val pattern = (combination :+ suffix).sortWith(_ > _).mkString(" ")
        counts(pattern) = counts.getOrElse(pattern, 0) + 1
      }

      val eligible = counts.filter(_._2 >= minCount)
      if (eligible.isEmpty) {
        // Nothing frequent at this size, so stop here instead of trying larger combinations.
        growing = false
      } else {
        output ++= eligible
        k += 1
      }
    }
    output.toArray
  }
}
|
||
/** | ||
* Top-level methods for calling FPGrowth. | ||
*/ | ||
object FPGrowth {

  /**
   * Trains a FPGrowth model: creates a FPGrowth instance with default parameters, applies the
   * given minimal support level to it and runs it on the data.
   *
   * @param data input baskets stored as `RDD[Array[String]]`
   * @param minSupport minimal support level, for example 0.5
   * @return the FPGrowth model produced by the run
   */
  def train(data: RDD[Array[String]], minSupport: Double): FPGrowthModel = {
    val miner = new FPGrowth()
    miner.setMinSupport(minSupport)
    miner.run(data)
  }
}
|
24 changes: 24 additions & 0 deletions
24
mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthModel.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.fpm | ||
|
||
/**
 * An FP-Growth model holding the result of a run: each element of `frequentPattern` is a
 * frequent pattern paired with its count.
 */
// Plain serializable holder; `frequentPattern` is exposed read-only as (pattern, count) pairs.
class FPGrowthModel(val frequentPattern: Array[(String, Int)])
  extends Serializable
72 changes: 72 additions & 0 deletions
72
mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.mllib.fpm | ||
|
||
import org.scalatest.FunSuite | ||
import org.apache.spark.mllib.util.LocalSparkContext | ||
|
||
class FPGrowthSuite extends FunSuite with LocalSparkContext {

  test("test FPGrowth algorithm") {
    val lines = FPGrowthSuite.createTestData()
    assert(lines.length === 6)

    val dataSet = sc.parallelize(lines)
    assert(dataSet.count() == 6)

    // One basket per line, items separated by a single space.
    val baskets = dataSet.map(_.split(" "))
    assert(baskets.count() == 6)

    // Minimal support level -> number of frequent patterns expected for the test data.
    // The same FPGrowth instance is re-configured and re-run for every level, in this order.
    val expectedPatternCounts = Seq(
      0.9 -> 0,
      0.8 -> 1,
      0.7 -> 1,
      0.6 -> 2,
      0.5 -> 18,
      0.4 -> 18,
      0.3 -> 54,
      0.2 -> 54,
      0.1 -> 625)

    val algorithm = new FPGrowth()
    for ((support, expected) <- expectedPatternCounts) {
      algorithm.setMinSupport(support)
      assert(algorithm.run(baskets).frequentPattern.length == expected)
    }
  }
}
|
||
object FPGrowthSuite {

  /**
   * Builds the test data set: six baskets, each encoded as one line of space-separated items.
   */
  def createTestData(): Array[String] =
    Array(
      "r z h k p",
      "z y x w v u t s",
      "s x o n r",
      "x z y m t s q e",
      "z",
      "x z y r q t p")
}