[SPARK-8998] [MLLIB] Distribute PrefixSpan computation for large projected databases

Continuation of work by zhangjiajin

Closes apache#7412

Author: zhangjiajin <[email protected]>
Author: Feynman Liang <[email protected]>
Author: zhang jiajin <[email protected]>

Closes apache#7783 from feynmanliang/SPARK-8998-improve-distributed and squashes the following commits:

a61943d [Feynman Liang] Collect small patterns to local
4ddf479 [Feynman Liang] Parallelize freqItemCounts
ad23aa9 [zhang jiajin] Merge pull request #1 from feynmanliang/SPARK-8998-collectBeforeLocal
87fa021 [Feynman Liang] Improve extend prefix readability
c2caa5c [Feynman Liang] Readability improvements and comments
1235cfc [Feynman Liang] Use Iterable[Array[_]] over Array[Array[_]] for database
da0091b [Feynman Liang] Use lists for prefixes to reuse data
cb2a4fc [Feynman Liang] Inline code for readability
01c9ae9 [Feynman Liang] Add getters
6e149fa [Feynman Liang] Fix splitPrefixSuffixPairs
64271b3 [zhangjiajin] Modified codes according to comments.
d2250b7 [zhangjiajin] remove minPatternsBeforeLocalProcessing, add maxSuffixesBeforeLocalProcessing.
b07e20c [zhangjiajin] Merge branch 'master' of https://github.com/apache/spark into CollectEnoughPrefixes
095aa3a [zhangjiajin] Modified the code according to the review comments.
baa2885 [zhangjiajin] Modified the code according to the review comments.
6560c69 [zhangjiajin] Add feature: Collect enough frequent prefixes before projection in PrefixeSpan
a8fde87 [zhangjiajin] Merge branch 'master' of https://github.com/apache/spark
4dd1c8a [zhangjiajin] initialize file before rebase.
078d410 [zhangjiajin] fix a scala style error.
22b0ef4 [zhangjiajin] Add feature: Collect enough frequent prefixes before projection in PrefixSpan.
ca9c4c8 [zhangjiajin] Modified the code according to the review comments.
574e56c [zhangjiajin] Add new object LocalPrefixSpan, and do some optimization.
ba5df34 [zhangjiajin] Fix a Scala style error.
4c60fb3 [zhangjiajin] Fix some Scala style errors.
1dd33ad [zhangjiajin] Modified the code according to the review comments.
89bc368 [zhangjiajin] Fixed a Scala style error.
a2eb14c [zhang jiajin] Delete PrefixspanSuite.scala
951fd42 [zhang jiajin] Delete Prefixspan.scala
575995f [zhangjiajin] Modified the code according to the review comments.
91fd7e6 [zhangjiajin] Add new algorithm PrefixSpan and test file.
zhangjiajin authored and mengxr committed Jul 30, 2015
1 parent c581593 commit d212a31
Showing 3 changed files with 161 additions and 69 deletions.
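Before the diff itself, here is a minimal usage sketch of the public PrefixSpan API this commit touches. It mirrors the test suite at the end of the diff; the sequence data, the local SparkContext setup, and the printed output are illustrative assumptions, not part of the change.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.fpm.PrefixSpan

// Minimal sketch, assuming a local Spark deployment; the sequences below are
// made-up item-ID sequences, not data from this commit.
val sc = new SparkContext(
  new SparkConf().setMaster("local[2]").setAppName("PrefixSpanExample"))

val sequences = Array(
  Array(1, 3, 4, 5),
  Array(2, 3, 1),
  Array(2, 4, 1),
  Array(3, 1, 3, 4, 5))
val rdd = sc.parallelize(sequences, 2).cache()

// Same parameters as the test suite: a pattern must occur in at least 33% of
// the sequences (with 4 sequences, minCount = ceil(4 * 0.33) = 2), and
// patterns longer than 50 items are not considered.
val prefixSpan = new PrefixSpan()
  .setMinSupport(0.33)
  .setMaxPatternLength(50)

// run returns (pattern, count) pairs, where each pattern is an Array[Int]
val patterns = prefixSpan.run(rdd)
patterns.collect().foreach { case (pattern, count) =>
  println(s"${pattern.mkString("[", ", ", "]")} -> $count")
}

sc.stop()

The commit leaves this public API unchanged; what changes is how run distributes the prefix-projection work internally, as shown in the diff below.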
mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
minCount: Long,
maxPatternLength: Int,
prefixes: List[Int],
database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
@@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
}
}

def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
database
.map(getSuffix(prefix, _))
.filter(_.nonEmpty)
@@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
*/
private def getFreqItemAndCounts(
minCount: Long,
database: Array[Array[Int]]): mutable.Map[Int, Long] = {
database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
// TODO: use PrimitiveKeyOpenHashMap
val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
database.foreach { sequence =>
203 changes: 147 additions & 56 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -17,6 +17,8 @@

package org.apache.spark.mllib.fpm

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
@@ -43,28 +45,45 @@ class PrefixSpan private (
private var minSupport: Double,
private var maxPatternLength: Int) extends Logging with Serializable {

/**
* The maximum number of items allowed in a projected database before local processing. If a
* projected database exceeds this size, another iteration of distributed PrefixSpan is run.
*/
// TODO: make configurable with a better default value, 10000 may be too small
private val maxLocalProjDBSize: Long = 10000

/**
* Constructs a default instance with default parameters
* {minSupport: `0.1`, maxPatternLength: `10`}.
*/
def this() = this(0.1, 10)

/**
* Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
* frequent).
*/
def getMinSupport: Double = this.minSupport

/**
* Sets the minimal support level (default: `0.1`).
*/
def setMinSupport(minSupport: Double): this.type = {
require(minSupport >= 0 && minSupport <= 1,
"The minimum support value must be between 0 and 1, including 0 and 1.")
require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
this.minSupport = minSupport
this
}

/**
* Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider).
*/
def getMaxPatternLength: Double = this.maxPatternLength

/**
* Sets maximal pattern length (default: `10`).
*/
def setMaxPatternLength(maxPatternLength: Int): this.type = {
require(maxPatternLength >= 1,
"The maximum pattern length value must be greater than 0.")
// TODO: support unbounded pattern length when maxPatternLength = 0
require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
this.maxPatternLength = maxPatternLength
this
}
@@ -78,81 +97,153 @@ class PrefixSpan private (
* the value of pair is the pattern's count.
*/
def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
val sc = sequences.sparkContext

if (sequences.getStorageLevel == StorageLevel.NONE) {
logWarning("Input data is not cached.")
}
val minCount = getMinCount(sequences)
val lengthOnePatternsAndCounts =
getFreqItemAndCounts(minCount, sequences).collect()
val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
lengthOnePatternsAndCounts.map(_._1), sequences)
val groupedProjectedDatabase = prefixAndProjectedDatabase
.map(x => (x._1.toSeq, x._2))
.groupByKey()
.map(x => (x._1.toArray, x._2.toArray))
val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
val lengthOnePatternsAndCountsRdd =
sequences.sparkContext.parallelize(
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
allPatterns

// Convert min support to a min number of transactions for this dataset
val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong

// (Frequent item -> number of occurrences); all items here satisfy the `minSupport` threshold
val freqItemCounts = sequences
.flatMap(seq => seq.distinct.map(item => (item, 1L)))
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
.collect()

// Pairs of (length 1 prefix, suffix consisting of frequent items)
val itemSuffixPairs = {
val freqItems = freqItemCounts.map(_._1).toSet
sequences.flatMap { seq =>
val filteredSeq = seq.filter(freqItems.contains(_))
freqItems.flatMap { item =>
val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
candidateSuffix match {
case suffix if !suffix.isEmpty => Some((List(item), suffix))
case _ => None
}
}
}
}

// Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
// frequent length-one prefixes)
var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))

// Remaining work to be processed locally and in a distributed fashion, respectively
var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)

// Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
// projected database sizes <= `maxLocalProjDBSize`)
while (pairsForDistributed.count() != 0) {
val (nextPatternAndCounts, nextPrefixSuffixPairs) =
extendPrefixes(minCount, pairsForDistributed)
pairsForDistributed.unpersist()
val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
pairsForDistributed = largerPairsPart
pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
pairsForLocal ++= smallerPairsPart
resultsAccumulator ++= nextPatternAndCounts.collect()
}

// Process the small projected databases locally
val remainingResults = getPatternsInLocal(
minCount, sc.parallelize(pairsForLocal, 1).groupByKey())

(sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
.map { case (pattern, count) => (pattern.toArray, count) }
}


/**
* Get the minimum count (sequences count * minSupport).
* @param sequences input data set, contains a set of sequences,
* @return minimum count,
* Partitions the prefix-suffix pairs by projected database size.
* @param prefixSuffixPairs prefix (length n) and suffix pairs,
* @return prefix-suffix pairs partitioned by whether their projected database size is <= or
* greater than [[maxLocalProjDBSize]]
*/
private def getMinCount(sequences: RDD[Array[Int]]): Long = {
if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
: (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
val prefixToSuffixSize = prefixSuffixPairs
.aggregateByKey(0)(
seqOp = { case (count, suffix) => count + suffix.length },
combOp = { _ + _ })
val smallPrefixes = prefixToSuffixSize
.filter(_._2 <= maxLocalProjDBSize)
.keys
.collect()
.toSet
val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
(small.collect(), large)
}

/**
* Generates frequent items by filtering the input data using minimal count level.
* @param minCount the absolute minimum count
* @param sequences original sequences data
* @return array of item and count pair
* Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
* and remaining work.
* @param minCount minimum count
* @param prefixSuffixPairs prefix (length N) and suffix pairs,
* @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
* prefix, corresponding suffix) pairs.
*/
private def getFreqItemAndCounts(
private def extendPrefixes(
minCount: Long,
sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
sequences.flatMap(_.distinct.map((_, 1L)))
prefixSuffixPairs: RDD[(List[Int], Array[Int])])
: (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {

// (length N prefix, item from suffix) pairs and their corresponding number of occurrences
// Every extended prefix (prefix :+ item) is guaranteed to meet the `minSupport` threshold
val prefixItemPairAndCounts = prefixSuffixPairs
.flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
}

/**
* Get the frequent prefixes' projected database.
* @param frequentPrefixes frequent prefixes
* @param sequences sequences data
* @return prefixes and projected database
*/
private def getPrefixAndProjectedDatabase(
frequentPrefixes: Array[Int],
sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
val filteredSequences = sequences.map { p =>
p.filter (frequentPrefixes.contains(_) )
}
filteredSequences.flatMap { x =>
frequentPrefixes.map { y =>
val sub = LocalPrefixSpan.getSuffix(y, x)
(Array(y), sub)
}.filter(_._2.nonEmpty)
}
// Map from prefix to set of possible next items from suffix
val prefixToNextItems = prefixItemPairAndCounts
.keys
.groupByKey()
.mapValues(_.toSet)
.collect()
.toMap


// Frequent patterns with length N+1 and their corresponding counts
val extendedPrefixAndCounts = prefixItemPairAndCounts
.map { case ((prefix, item), count) => (item :: prefix, count) }

// Remaining work, all prefixes will have length N+1
val extendedPrefixAndSuffix = prefixSuffixPairs
.filter(x => prefixToNextItems.contains(x._1))
.flatMap { case (prefix, suffix) =>
val frequentNextItems = prefixToNextItems(prefix)
val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
frequentNextItems.flatMap { item =>
LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
case _ => None
}
}
}

(extendedPrefixAndCounts, extendedPrefixAndSuffix)
}

/**
* calculate the patterns in local.
* Calculates the patterns locally.
* @param minCount the absolute minimum count
* @param data patterns and projected sequences data data
* @param data prefixes and projected sequences data
* @return patterns
*/
private def getPatternsInLocal(
minCount: Long,
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
data.flatMap { case (prefix, projDB) =>
LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
.map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
data.flatMap {
case (prefix, projDB) =>
LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
.map { case (pattern: List[Int], count: Long) =>
(pattern.reverse, count)
}
}
}
}
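The heart of this change is the size-based split in partitionByProjDBSize above: prefixes whose projected databases contain at most maxLocalProjDBSize items (10000 here, with a TODO noting that may be too small) are collected and finished locally, while everything else goes through another distributed extension round. Below is a plain-collections sketch of that split on in-memory data rather than RDDs; the threshold of 4 and the sample pairs are illustrative assumptions.

// Plain-collections sketch of the size-based split that partitionByProjDBSize
// performs distributively: sum suffix lengths per prefix, then route prefixes
// whose projected database is small enough to local processing.
val maxLocalProjDBSizeSketch = 4  // illustrative; the real default is 10000
val prefixSuffixPairs: Seq[(List[Int], Array[Int])] = Seq(
  (List(1), Array(2, 3)),
  (List(1), Array(3)),
  (List(2), Array(1, 3, 4)),
  (List(2), Array(1, 3, 4, 5)))

// Projected database size per prefix: List(1) -> 3, List(2) -> 7
val projDBSizes: Map[List[Int], Int] = prefixSuffixPairs
  .groupBy(_._1)
  .mapValues(_.map(_._2.length).sum)
  .toMap

val smallPrefixes = projDBSizes.filter(_._2 <= maxLocalProjDBSizeSketch).keySet
val (smallPairs, largePairs) = prefixSuffixPairs.partition {
  case (prefix, _) => smallPrefixes.contains(prefix)
}
// smallPairs (prefix List(1)) would be processed locally; largePairs (prefix
// List(2)) would go through another distributed PrefixSpan iteration.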
mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {

val rdd = sc.parallelize(sequences, 2).cache()

def compareResult(
expectedValue: Array[(Array[Int], Long)],
actualValue: Array[(Array[Int], Long)]): Boolean = {
expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
actualValue.map(x => (x._1.toSeq, x._2)).toSet
}

val prefixspan = new PrefixSpan()
.setMinSupport(0.33)
.setMaxPatternLength(50)
@@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4, 5), 2L),
(Array(5), 3L)
)
assert(compareResult(expectedValue1, result1.collect()))
assert(compareResults(expectedValue1, result1.collect()))

prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
val result2 = prefixspan.run(rdd)
@@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4), 4L),
(Array(5), 3L)
)
assert(compareResult(expectedValue2, result2.collect()))
assert(compareResults(expectedValue2, result2.collect()))

prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
val result3 = prefixspan.run(rdd)
@@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4, 5), 2L),
(Array(5), 3L)
)
assert(compareResult(expectedValue3, result3.collect()))
assert(compareResults(expectedValue3, result3.collect()))
}

private def compareResults(
expectedValue: Array[(Array[Int], Long)],
actualValue: Array[(Array[Int], Long)]): Boolean = {
expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
actualValue.map(x => (x._1.toSeq, x._2)).toSet
}

}
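One helper that the new run and extendPrefixes code relies on, LocalPrefixSpan.getSuffix, is referenced but not shown in this diff. The sketch below captures the assumed behaviour of such a suffix projection in PrefixSpan (take everything after the first occurrence of the prefix item); it is illustrative only and not copied from LocalPrefixSpan.scala.

// Illustrative sketch only (assumed behaviour, not the actual implementation):
// return the part of `sequence` after the first occurrence of `item`, or an
// empty array when `item` does not occur.
def getSuffixSketch(item: Int, sequence: Array[Int]): Array[Int] = {
  val index = sequence.indexOf(item)
  if (index == -1) Array.empty[Int] else sequence.drop(index + 1)
}

// Example: projecting the sequences <1,2,3,4> and <2,3,1,4> onto prefix item 3
// keeps the non-empty suffixes <4> and <1,4>, which is the shape of data that
// itemSuffixPairs and extendPrefixes pass around in the diff above.
val projected = Seq(Array(1, 2, 3, 4), Array(2, 3, 1, 4))
  .map(seq => getSuffixSketch(3, seq))
  .filter(_.nonEmpty)
projected.foreach(suffix => println(suffix.mkString("<", ",", ">")))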
