[SPARK-8998][MLlib] Distribute PrefixSpan computation for large projected databases #7783

Closed

Changes from all commits (30 commits)
91fd7e6
Add new algorithm PrefixSpan and test file.
zhangjiajin Jul 7, 2015
575995f
Modified the code according to the review comments.
zhangjiajin Jul 8, 2015
951fd42
Delete Prefixspan.scala
zhangjiajin Jul 8, 2015
a2eb14c
Delete PrefixspanSuite.scala
zhangjiajin Jul 8, 2015
89bc368
Fixed a Scala style error.
zhangjiajin Jul 8, 2015
1dd33ad
Modified the code according to the review comments.
zhangjiajin Jul 9, 2015
4c60fb3
Fix some Scala style errors.
zhangjiajin Jul 9, 2015
ba5df34
Fix a Scala style error.
zhangjiajin Jul 9, 2015
574e56c
Add new object LocalPrefixSpan, and do some optimization.
zhangjiajin Jul 10, 2015
ca9c4c8
Modified the code according to the review comments.
zhangjiajin Jul 11, 2015
22b0ef4
Add feature: Collect enough frequent prefixes before projection in Pr…
zhangjiajin Jul 14, 2015
078d410
fix a scala style error.
zhangjiajin Jul 14, 2015
4dd1c8a
initialize file before rebase.
zhangjiajin Jul 15, 2015
a8fde87
Merge branch 'master' of https://github.com/apache/spark
zhangjiajin Jul 15, 2015
6560c69
Add feature: Collect enough frequent prefixes before projection in Pr…
zhangjiajin Jul 15, 2015
baa2885
Modified the code according to the review comments.
zhangjiajin Jul 15, 2015
095aa3a
Modified the code according to the review comments.
zhangjiajin Jul 16, 2015
b07e20c
Merge branch 'master' of https://github.com/apache/spark into Collect…
zhangjiajin Jul 16, 2015
d2250b7
remove minPatternsBeforeLocalProcessing, add maxSuffixesBeforeLocalPr…
zhangjiajin Jul 18, 2015
64271b3
Modified codes according to comments.
zhangjiajin Jul 27, 2015
6e149fa
Fix splitPrefixSuffixPairs
Jul 28, 2015
01c9ae9
Add getters
Jul 28, 2015
cb2a4fc
Inline code for readability
Jul 28, 2015
da0091b
Use lists for prefixes to reuse data
Jul 28, 2015
1235cfc
Use Iterable[Array[_]] over Array[Array[_]] for database
Jul 28, 2015
c2caa5c
Readability improvements and comments
Jul 28, 2015
87fa021
Improve extend prefix readability
Jul 28, 2015
ad23aa9
Merge pull request #1 from feynmanliang/SPARK-8998-collectBeforeLocal
zhangjiajin Jul 29, 2015
4ddf479
Parallelize freqItemCounts
Jul 30, 2015
a61943d
Collect small patterns to local
Jul 30, 2015
@@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
minCount: Long,
maxPatternLength: Int,
prefixes: List[Int],
database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
@@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
}
}

def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
database
.map(getSuffix(prefix, _))
.filter(_.nonEmpty)
@@ -81,7 +81,7 @@
*/
private def getFreqItemAndCounts(
minCount: Long,
database: Array[Array[Int]]): mutable.Map[Int, Long] = {
database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
// TODO: use PrimitiveKeyOpenHashMap
val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
database.foreach { sequence =>
mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala (203 changes: 147 additions & 56 deletions)
@@ -17,6 +17,8 @@

package org.apache.spark.mllib.fpm

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
@@ -43,28 +45,45 @@ class PrefixSpan private (
private var minSupport: Double,
private var maxPatternLength: Int) extends Logging with Serializable {

/**
* The maximum number of items allowed in a projected database before local processing. If a
* projected database exceeds this size, another iteration of distributed PrefixSpan is run.
*/
// TODO: make configurable with a better default value, 10000 may be too small
private val maxLocalProjDBSize: Long = 10000

/**
* Constructs a default instance with default parameters
* {minSupport: `0.1`, maxPatternLength: `10`}.
*/
def this() = this(0.1, 10)

/**
* Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
* frequent).
*/
def getMinSupport: Double = this.minSupport

/**
* Sets the minimal support level (default: `0.1`).
*/
def setMinSupport(minSupport: Double): this.type = {
require(minSupport >= 0 && minSupport <= 1,
"The minimum support value must be between 0 and 1, including 0 and 1.")
require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
this.minSupport = minSupport
this
}

/**
* Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider).
*/
def getMaxPatternLength: Double = this.maxPatternLength

/**
* Sets maximal pattern length (default: `10`).
*/
def setMaxPatternLength(maxPatternLength: Int): this.type = {
require(maxPatternLength >= 1,
"The maximum pattern length value must be greater than 0.")
// TODO: support unbounded pattern length when maxPatternLength = 0
require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
this.maxPatternLength = maxPatternLength
this
}
@@ -78,81 +97,153 @@ class PrefixSpan private (
* the value of pair is the pattern's count.
*/
def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
val sc = sequences.sparkContext

if (sequences.getStorageLevel == StorageLevel.NONE) {
logWarning("Input data is not cached.")
}
val minCount = getMinCount(sequences)
val lengthOnePatternsAndCounts =
getFreqItemAndCounts(minCount, sequences).collect()
val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
lengthOnePatternsAndCounts.map(_._1), sequences)
val groupedProjectedDatabase = prefixAndProjectedDatabase
.map(x => (x._1.toSeq, x._2))
.groupByKey()
.map(x => (x._1.toArray, x._2.toArray))
val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
val lengthOnePatternsAndCountsRdd =
sequences.sparkContext.parallelize(
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
allPatterns

// Convert min support to a min number of transactions for this dataset
val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong

// Frequent items -> number of occurrences; all items here satisfy the `minSupport` threshold
val freqItemCounts = sequences
.flatMap(seq => seq.distinct.map(item => (item, 1L)))
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
.collect()

// Pairs of (length 1 prefix, suffix consisting of frequent items)
val itemSuffixPairs = {
val freqItems = freqItemCounts.map(_._1).toSet
sequences.flatMap { seq =>
val filteredSeq = seq.filter(freqItems.contains(_))
freqItems.flatMap { item =>
val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
candidateSuffix match {
case suffix if !suffix.isEmpty => Some((List(item), suffix))
case _ => None
}
}
}
}

// Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
// frequent length-one prefixes)
var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))

// Remaining work to be processed locally and distributively, respectively
var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)

// Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
// projected database sizes <= `maxLocalProjDBSize`)
while (pairsForDistributed.count() != 0) {
val (nextPatternAndCounts, nextPrefixSuffixPairs) =
extendPrefixes(minCount, pairsForDistributed)
pairsForDistributed.unpersist()
val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
pairsForDistributed = largerPairsPart
pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
pairsForLocal ++= smallerPairsPart
resultsAccumulator ++= nextPatternAndCounts.collect()
Contributor Author:

This will cause all results except those generated from pairsForLocal to be collected to the driver, since we continue processing until pairsForDistributed is empty.

This could potentially be many times the size of the dataset, since a length-k sequence has up to 2^k subsequences.

Contributor:

That is the worst case. We should assume that the number of frequent patterns is small; having 1 billion frequent patterns doesn't provide any useful insight. So users should start with a high minSupport and collect a just-enough number of frequent patterns.

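To make the reviewer's suggestion concrete, here is a minimal usage sketch based on the API exercised in PrefixSpanSuite below; the SparkContext setup, input sequences, and support value are illustrative, not part of this PR.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.rdd.RDD

object PrefixSpanUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("PrefixSpanUsageSketch").setMaster("local[2]"))

    // Illustrative input; a real workload would be a large, cached RDD[Array[Int]].
    val sequences: RDD[Array[Int]] = sc.parallelize(Seq(
      Array(1, 3, 4, 5),
      Array(2, 3, 1),
      Array(2, 4, 1),
      Array(3, 1, 3, 4, 5),
      Array(3, 4, 4, 3),
      Array(6, 5, 3)), numSlices = 2).cache()

    // Start with a relatively high minSupport so only "just enough" frequent
    // patterns are produced and collected to the driver; relax it later if needed.
    val prefixSpan = new PrefixSpan()
      .setMinSupport(0.5)
      .setMaxPatternLength(5)
    val patterns: RDD[(Array[Int], Long)] = prefixSpan.run(sequences)

    patterns.collect().foreach { case (pattern, count) =>
      println(s"${pattern.mkString("[", ",", "]")}: $count")
    }
    sc.stop()
  }
}
```

With minSupport = 0.5 over six input sequences, a pattern must appear in at least three of them, which keeps the result set (and hence what lands on the driver) small.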
}

// Process the small projected databases locally
val remainingResults = getPatternsInLocal(
minCount, sc.parallelize(pairsForLocal, 1).groupByKey())

(sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
.map { case (pattern, count) => (pattern.toArray, count) }
}


/**
* Get the minimum count (sequences count * minSupport).
* @param sequences input data set, contains a set of sequences,
* @return minimum count,
* Partitions the prefix-suffix pairs by projected database size.
* @param prefixSuffixPairs prefix (length n) and suffix pairs,
* @return prefix-suffix pairs partitioned by whether their projected database size is <= or
* greater than [[maxLocalProjDBSize]]
*/
private def getMinCount(sequences: RDD[Array[Int]]): Long = {
if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
: (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
val prefixToSuffixSize = prefixSuffixPairs
.aggregateByKey(0)(
seqOp = { case (count, suffix) => count + suffix.length },
combOp = { _ + _ })
val smallPrefixes = prefixToSuffixSize
.filter(_._2 <= maxLocalProjDBSize)
.keys
.collect()
.toSet
val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
(small.collect(), large)
Contributor Author:

Since small.collect() == smallerPairsPart in run() and pairsForLocal ++= smallerPairsPart, we will end up collecting both pairsForLocal and resultsAccumulator to the driver.

This means we are assuming that all the frequent patterns will fit on the driver.

Contributor Author:

Could we instead collect and then immediately parallelize again, to break lineage while still avoiding storing the entire result set on the driver?

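A minimal sketch of the collect-then-re-parallelize idea raised above, with illustrative names (LineageBreakSketch, reparallelize, and the slice count are not part of this PR). It truncates the lineage that accumulates across iterations of the while loop, but the intermediate pairs still pass through the driver, which is exactly the trade-off being questioned.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

object LineageBreakSketch {
  // `prefixSuffixPairs` stands in for the per-iteration RDD[(List[Int], Array[Int])]
  // built inside run(); collect() cuts the lineage, parallelize() redistributes the data.
  def reparallelize(
      sc: SparkContext,
      prefixSuffixPairs: RDD[(List[Int], Array[Int])]): RDD[(List[Int], Array[Int])] = {
    // Materializing on the driver truncates the accumulated lineage...
    val collected = prefixSuffixPairs.collect()
    // ...but the whole intermediate result transits the driver before being
    // redistributed, so this only helps while the collected pairs stay small.
    sc.parallelize(collected, numSlices = math.max(1, collected.length / 1000))
  }
}
```

RDD.checkpoint() (with a checkpoint directory configured) is the usual way to truncate lineage without routing data through the driver, at the cost of a write to reliable storage.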
}

/**
* Generates frequent items by filtering the input data using minimal count level.
* @param minCount the absolute minimum count
* @param sequences original sequences data
* @return array of item and count pair
* Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
* and remaining work.
* @param minCount minimum count
* @param prefixSuffixPairs prefix (length N) and suffix pairs,
* @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
* prefix, corresponding suffix) pairs.
*/
private def getFreqItemAndCounts(
private def extendPrefixes(
minCount: Long,
sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
sequences.flatMap(_.distinct.map((_, 1L)))
prefixSuffixPairs: RDD[(List[Int], Array[Int])])
: (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {

// (length N prefix, item from suffix) pairs and their corresponding number of occurrences
// Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
val prefixItemPairAndCounts = prefixSuffixPairs
.flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
}

/**
* Get the frequent prefixes' projected database.
* @param frequentPrefixes frequent prefixes
* @param sequences sequences data
* @return prefixes and projected database
*/
private def getPrefixAndProjectedDatabase(
frequentPrefixes: Array[Int],
sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
val filteredSequences = sequences.map { p =>
p.filter (frequentPrefixes.contains(_) )
}
filteredSequences.flatMap { x =>
frequentPrefixes.map { y =>
val sub = LocalPrefixSpan.getSuffix(y, x)
(Array(y), sub)
}.filter(_._2.nonEmpty)
}
// Map from prefix to set of possible next items from suffix
val prefixToNextItems = prefixItemPairAndCounts
.keys
.groupByKey()
.mapValues(_.toSet)
.collect()
.toMap


// Frequent patterns with length N+1 and their corresponding counts
val extendedPrefixAndCounts = prefixItemPairAndCounts
.map { case ((prefix, item), count) => (item :: prefix, count) }

// Remaining work, all prefixes will have length N+1
val extendedPrefixAndSuffix = prefixSuffixPairs
.filter(x => prefixToNextItems.contains(x._1))
.flatMap { case (prefix, suffix) =>
val frequentNextItems = prefixToNextItems(prefix)
val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
frequentNextItems.flatMap { item =>
LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
case _ => None
}
}
}

(extendedPrefixAndCounts, extendedPrefixAndSuffix)
}

/**
* calculate the patterns in local.
* Calculate the patterns locally.
* @param minCount the absolute minimum count
* @param data patterns and projected sequences data data
* @param data prefixes and projected sequences data
* @return patterns
*/
private def getPatternsInLocal(
minCount: Long,
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
data.flatMap { case (prefix, projDB) =>
LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
.map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
data.flatMap {
case (prefix, projDB) =>
LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
.map { case (pattern: List[Int], count: Long) =>
(pattern.reverse, count)
}
}
}
}
@@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {

val rdd = sc.parallelize(sequences, 2).cache()

def compareResult(
expectedValue: Array[(Array[Int], Long)],
actualValue: Array[(Array[Int], Long)]): Boolean = {
expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
actualValue.map(x => (x._1.toSeq, x._2)).toSet
}

val prefixspan = new PrefixSpan()
.setMinSupport(0.33)
.setMaxPatternLength(50)
@@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4, 5), 2L),
(Array(5), 3L)
)
assert(compareResult(expectedValue1, result1.collect()))
assert(compareResults(expectedValue1, result1.collect()))

prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
val result2 = prefixspan.run(rdd)
@@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4), 4L),
(Array(5), 3L)
)
assert(compareResult(expectedValue2, result2.collect()))
assert(compareResults(expectedValue2, result2.collect()))

prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
val result3 = prefixspan.run(rdd)
@@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4, 5), 2L),
(Array(5), 3L)
)
assert(compareResult(expectedValue3, result3.collect()))
assert(compareResults(expectedValue3, result3.collect()))
}

private def compareResults(
expectedValue: Array[(Array[Int], Long)],
actualValue: Array[(Array[Int], Long)]): Boolean = {
expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
actualValue.map(x => (x._1.toSeq, x._2)).toSet
}

}