[SPARK-8998][MLlib] Distribute PrefixSpan computation for large projected databases #7783
Changes from all commits
@@ -17,6 +17,8 @@

package org.apache.spark.mllib.fpm

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
@@ -43,28 +45,45 @@ class PrefixSpan private (
    private var minSupport: Double,
    private var maxPatternLength: Int) extends Logging with Serializable {

  /**
   * The maximum number of items allowed in a projected database before local processing. If a
   * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
   */
  // TODO: make configurable with a better default value, 10000 may be too small
  private val maxLocalProjDBSize: Long = 10000

  /**
   * Constructs a default instance with default parameters
   * {minSupport: `0.1`, maxPatternLength: `10`}.
   */
  def this() = this(0.1, 10)

  /**
   * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
   * frequent).
   */
  def getMinSupport: Double = this.minSupport

  /**
   * Sets the minimal support level (default: `0.1`).
   */
  def setMinSupport(minSupport: Double): this.type = {
    require(minSupport >= 0 && minSupport <= 1,
      "The minimum support value must be between 0 and 1, including 0 and 1.")
    require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
    this.minSupport = minSupport
    this
  }
  /**
   * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to
   * consider).
   */
  def getMaxPatternLength: Double = this.maxPatternLength

  /**
   * Sets maximal pattern length (default: `10`).
   */
  def setMaxPatternLength(maxPatternLength: Int): this.type = {
    require(maxPatternLength >= 1,
      "The maximum pattern length value must be greater than 0.")
    // TODO: support unbounded pattern length when maxPatternLength = 0
    require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
    this.maxPatternLength = maxPatternLength
    this
  }
@@ -78,81 +97,153 @@ class PrefixSpan private (
   * the value of pair is the pattern's count.
   */
  def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
    val sc = sequences.sparkContext

    if (sequences.getStorageLevel == StorageLevel.NONE) {
      logWarning("Input data is not cached.")
    }
    val minCount = getMinCount(sequences)
    val lengthOnePatternsAndCounts =
      getFreqItemAndCounts(minCount, sequences).collect()
    val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
      lengthOnePatternsAndCounts.map(_._1), sequences)
    val groupedProjectedDatabase = prefixAndProjectedDatabase
      .map(x => (x._1.toSeq, x._2))
      .groupByKey()
      .map(x => (x._1.toArray, x._2.toArray))
    val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
    val lengthOnePatternsAndCountsRdd =
      sequences.sparkContext.parallelize(
        lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
    val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
    allPatterns

    // Convert min support to a min number of transactions for this dataset
    val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong

    // (Frequent items -> number of occurrences); all items here satisfy the `minSupport` threshold
    val freqItemCounts = sequences
      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)
      .collect()

    // Pairs of (length 1 prefix, suffix consisting of frequent items)
    val itemSuffixPairs = {
      val freqItems = freqItemCounts.map(_._1).toSet
      sequences.flatMap { seq =>
        val filteredSeq = seq.filter(freqItems.contains(_))
        freqItems.flatMap { item =>
          val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
          candidateSuffix match {
            case suffix if !suffix.isEmpty => Some((List(item), suffix))
            case _ => None
          }
        }
      }
    }

    // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
    // frequent length-one prefixes)
    var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))

    // Remaining work to be locally and distributively processed, respectively
    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)

    // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
    // projected database sizes <= `maxLocalProjDBSize`)
    while (pairsForDistributed.count() != 0) {
      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
        extendPrefixes(minCount, pairsForDistributed)
      pairsForDistributed.unpersist()
      val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
      pairsForDistributed = largerPairsPart
      pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
      pairsForLocal ++= smallerPairsPart
      resultsAccumulator ++= nextPatternAndCounts.collect()
    }

    // Process the small projected databases locally
    val remainingResults = getPatternsInLocal(
      minCount, sc.parallelize(pairsForLocal, 1).groupByKey())

    (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
      .map { case (pattern, count) => (pattern.toArray, count) }
  }
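
For reference, a minimal usage sketch of the new run method, assuming an existing SparkContext named sc (e.g. in spark-shell); the sample data, support threshold, and variable names are illustrative only:

import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.rdd.RDD

// Toy dataset: each sequence is an Array[Int] of item ids (illustrative only).
val sequences = sc.parallelize(Seq(
  Array(1, 2, 3),
  Array(1, 3, 2, 1),
  Array(2, 1),
  Array(1, 2, 3, 4)
)).cache()  // cache the input to avoid the "Input data is not cached" warning

// Defaults are minSupport = 0.1 and maxPatternLength = 10 (from the no-arg constructor above).
val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)      // a pattern must appear in at least half of the sequences
  .setMaxPatternLength(5)

val patterns: RDD[(Array[Int], Long)] = prefixSpan.run(sequences)
patterns.collect().foreach { case (pattern, count) =>
  println(pattern.mkString("<", ", ", ">") + ": " + count)
}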

  /**
   * Get the minimum count (sequences count * minSupport).
   * @param sequences input data set, contains a set of sequences,
   * @return minimum count,
   * Partitions the prefix-suffix pairs by projected database size.
   * @param prefixSuffixPairs prefix (length n) and suffix pairs,
   * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
   *         greater than [[maxLocalProjDBSize]]
   */
  private def getMinCount(sequences: RDD[Array[Int]]): Long = {
    if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
    : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
    val prefixToSuffixSize = prefixSuffixPairs
      .aggregateByKey(0)(
        seqOp = { case (count, suffix) => count + suffix.length },
        combOp = { _ + _ })
    val smallPrefixes = prefixToSuffixSize
      .filter(_._2 <= maxLocalProjDBSize)
      .keys
      .collect()
      .toSet
    val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
    val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
    (small.collect(), large)
  }

Review comment (on the small.collect() above): This means that we are assuming all the frequent patterns will fit on the driver.

Review comment: Could we instead […]

Reply: Driver still stores the data in that case. See https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala.
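
To make the split concrete, here is a driver-only sketch of the same size-based partitioning idea on plain Scala collections (the toy prefixes, suffix sizes, and threshold are illustrative; the real method does this on an RDD with aggregateByKey):

// Toy (prefix, suffix) pairs; in the real code these live in an RDD.
val pairs = Seq(
  (List(1), Array(2, 3, 2)),
  (List(1), Array(3, 4)),
  (List(2), Array.fill(20000)(5)))  // a prefix whose projected database is large

val maxLocalProjDBSize = 10000L

// Total projected database size per prefix (sum of suffix lengths).
val projDBSizes: Map[List[Int], Long] = pairs
  .groupBy(_._1)
  .map { case (prefix, ps) => (prefix, ps.map(_._2.length.toLong).sum) }

// Prefixes small enough to finish locally on the driver; the rest stay distributed.
val smallPrefixes = projDBSizes.filter(_._2 <= maxLocalProjDBSize).keySet
val (small, large) = pairs.partition { case (prefix, _) => smallPrefixes.contains(prefix) }
// small -> handed to LocalPrefixSpan; large -> another distributed iteration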

  /**
   * Generates frequent items by filtering the input data using minimal count level.
   * @param minCount the absolute minimum count
   * @param sequences original sequences data
   * @return array of item and count pair
   * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
   * and remaining work.
   * @param minCount minimum count
   * @param prefixSuffixPairs prefix (length N) and suffix pairs,
   * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
   *         prefix, corresponding suffix) pairs.
   */
  private def getFreqItemAndCounts(
  private def extendPrefixes(
      minCount: Long,
      sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
    sequences.flatMap(_.distinct.map((_, 1L)))
      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
    : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {

    // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
    // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
    val prefixItemPairAndCounts = prefixSuffixPairs
      .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)
  }

  /**
   * Get the frequent prefixes' projected database.
   * @param frequentPrefixes frequent prefixes
   * @param sequences sequences data
   * @return prefixes and projected database
   */
  private def getPrefixAndProjectedDatabase(
      frequentPrefixes: Array[Int],
      sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
    val filteredSequences = sequences.map { p =>
      p.filter(frequentPrefixes.contains(_))
    }
    filteredSequences.flatMap { x =>
      frequentPrefixes.map { y =>
        val sub = LocalPrefixSpan.getSuffix(y, x)
        (Array(y), sub)
      }.filter(_._2.nonEmpty)
    }
    // Map from prefix to set of possible next items from suffix
    val prefixToNextItems = prefixItemPairAndCounts
      .keys
      .groupByKey()
      .mapValues(_.toSet)
      .collect()
      .toMap

    // Frequent patterns with length N+1 and their corresponding counts
    val extendedPrefixAndCounts = prefixItemPairAndCounts
      .map { case ((prefix, item), count) => (item :: prefix, count) }

    // Remaining work, all prefixes will have length N+1
    val extendedPrefixAndSuffix = prefixSuffixPairs
      .filter(x => prefixToNextItems.contains(x._1))
      .flatMap { case (prefix, suffix) =>
        val frequentNextItems = prefixToNextItems(prefix)
        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
        frequentNextItems.flatMap { item =>
          LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
            case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
            case _ => None
          }
        }
      }

    (extendedPrefixAndCounts, extendedPrefixAndSuffix)
  }
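
A driver-only sketch of the extension step for a single prefix-suffix pair. LocalPrefixSpan.getSuffix is not shown in this diff, so the toy getSuffix below only assumes its rough behaviour (everything after the first occurrence of the item); prefixes are stored in reversed order, matching item :: prefix above:

// Assumed behaviour of LocalPrefixSpan.getSuffix: the part of `seq` after the
// first occurrence of `item`, or an empty array if `item` does not occur.
def getSuffix(item: Int, seq: Array[Int]): Array[Int] = {
  val idx = seq.indexOf(item)
  if (idx == -1) Array.empty[Int] else seq.drop(idx + 1)
}

// One piece of remaining work: prefix <1> (stored reversed) and its projected suffix.
val prefix = List(1)
val suffix = Array(2, 3, 2)

// Extend the prefix by each distinct item of the suffix, keeping only non-empty projections.
val extended = suffix.distinct.toSeq.flatMap { item =>
  val projected = getSuffix(item, suffix)
  if (projected.nonEmpty) Some((item :: prefix, projected)) else None
}
// extended now holds (List(2, 1) -> Array(3, 2)) and (List(3, 1) -> Array(2))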

  /**
   * calculate the patterns in local.
   * Calculate the patterns in local.
   * @param minCount the absolute minimum count
   * @param data patterns and projected sequences data
   * @param data prefixes and projected sequences data
   * @return patterns
   */
  private def getPatternsInLocal(
      minCount: Long,
      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
    data.flatMap { case (prefix, projDB) =>
      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
        .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
    data.flatMap {
      case (prefix, projDB) =>
        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
          .map { case (pattern: List[Int], count: Long) =>
            (pattern.reverse, count)
          }
    }
  }
}
Review comment: This will cause all results except for those generated from pairsForLocal to be collected to the driver, since we continue processing until pairsForDistributed is empty. Could potentially be many times the size of the dataset, since a length k sequence has up to 2^k subsequences.
Reply: That is the worst case. We should assume that the number of frequent patterns is small. Having 1 billion frequent patterns doesn't provide any useful insights. So users should start with a high minSupport and collect just enough frequent patterns.
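
For scale, a quick check of the 2^k figure quoted above (a length-k sequence has at most 2^k subsequences, counting the empty one), which shows why even moderate k approaches the billion-pattern regime mentioned in the reply; the snippet is only an illustration of the arithmetic:

// Each element of a length-k sequence is either kept or dropped, giving at most 2^k subsequences.
def maxSubsequences(k: Int): BigInt = BigInt(2).pow(k)

println(maxSubsequences(10))  // 1024
println(maxSubsequences(30))  // 1073741824 -- already over a billion candidate patterns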