Skip to content

Commit

Permalink
Use Iterable[Array[_]] over Array[Array[_]] for database
Browse files Browse the repository at this point in the history
  • Loading branch information
Feynman Liang committed Jul 28, 2015
1 parent da0091b commit 1235cfc
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
minCount: Long,
maxPatternLength: Int,
prefixes: List[Int],
database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
Expand All @@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
}
}

def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
database
.map(getSuffix(prefix, _))
.filter(_.nonEmpty)
Expand All @@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
*/
private def getFreqItemAndCounts(
minCount: Long,
database: Array[Array[Int]]): mutable.Map[Int, Long] = {
database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
// TODO: use PrimitiveKeyOpenHashMap
val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
database.foreach { sequence =>
Expand Down
37 changes: 18 additions & 19 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ class PrefixSpan private (
private var minSupport: Double,
private var maxPatternLength: Int) extends Logging with Serializable {

private val maxProjectedDBSizeBeforeLocalProcessing: Long = 10000
/**
* The maximum number of items allowed in a projected database before local processing. If a
* projected database exceeds this size, another iteration of distributed PrefixSpan is run.
*/
private val maxLocalProjDBSize: Long = 10000

/**
* Constructs a default instance with default parameters
Expand All @@ -63,8 +67,7 @@ class PrefixSpan private (
* Sets the minimal support level (default: `0.1`).
*/
/**
 * Sets the minimal support level (default: `0.1`).
 *
 * @param minSupport minimal support, must lie in the closed interval [0, 1]
 * @return this builder instance, to allow chained setter calls
 */
def setMinSupport(minSupport: Double): this.type = {
  // Validate once; out-of-range values are a caller error, so fail fast.
  require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
  this.minSupport = minSupport
  this
}
Expand All @@ -79,8 +82,7 @@ class PrefixSpan private (
*/
/**
 * Sets the maximal pattern length.
 *
 * @param maxPatternLength maximum length of mined frequent sequential patterns; must be >= 1
 * @return this builder instance, to allow chained setter calls
 */
def setMaxPatternLength(maxPatternLength: Int): this.type = {
  // TODO: support unbounded pattern length when maxPatternLength = 0
  // Validate once; zero/negative lengths are a caller error, so fail fast.
  require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
  this.maxPatternLength = maxPatternLength
  this
}
Expand Down Expand Up @@ -119,13 +121,13 @@ class PrefixSpan private (
}.filter(_._2.nonEmpty)
}
}
var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = splitPrefixSuffixPairs(prefixSuffixPairs)
var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = partitionByProjDBSize(prefixSuffixPairs)

while (largePrefixSuffixPairs.count() != 0) {
val (nextPatternAndCounts, nextPrefixSuffixPairs) =
getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs)
largePrefixSuffixPairs.unpersist()
val (smallerPairsPart, largerPairsPart) = splitPrefixSuffixPairs(nextPrefixSuffixPairs)
val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
largePrefixSuffixPairs = largerPairsPart
largePrefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK)
smallPrefixSuffixPairs ++= smallerPairsPart
Expand All @@ -136,7 +138,6 @@ class PrefixSpan private (
val projectedDatabase = smallPrefixSuffixPairs
// TODO aggregateByKey
.groupByKey()
.mapValues(_.toArray)
val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
allPatternAndCounts ++= nextPatternAndCounts
}
Expand All @@ -145,23 +146,21 @@ class PrefixSpan private (


/**
* Split prefix suffix pairs to two parts:
* Prefixes with projected databases smaller than maxSuffixesBeforeLocalProcessing and
* Prefixes with projected databases larger than maxSuffixesBeforeLocalProcessing
* Partitions the prefix-suffix pairs by projected database size.
*
* @param prefixSuffixPairs prefix (length n) and suffix pairs,
* @return small size prefix suffix pairs and big size prefix suffix pairs
* (RDD[prefix, suffix], RDD[prefix, suffix ])
* @return prefix-suffix pairs partitioned by whether their projected database size is <= or
* greater than [[maxLocalProjDBSize]]
*/
private def splitPrefixSuffixPairs(
prefixSuffixPairs: RDD[(List[Int], Array[Int])]):
(RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
: (RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
val prefixToSuffixSize = prefixSuffixPairs
.aggregateByKey(0)(
seqOp = { case (count, suffix) => count + suffix.length },
combOp = { _ + _ })
val smallPrefixes = prefixToSuffixSize
.filter(_._2 <= maxProjectedDBSizeBeforeLocalProcessing)
.map(_._1)
.filter(_._2 <= maxLocalProjDBSize)
.keys
.collect()
.toSet
val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
Expand Down Expand Up @@ -214,7 +213,7 @@ class PrefixSpan private (
*/
private def getPatternsInLocal(
minCount: Long,
data: RDD[(List[Int], Array[Array[Int]])]): RDD[(List[Int], Long)] = {
data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
data.flatMap {
case (prefix, projDB) =>
LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
Expand Down

0 comments on commit 1235cfc

Please sign in to comment.