Skip to content

Commit

Permalink
update LocalPrefixSpan impl
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Jul 15, 2015
1 parent 9212256 commit 91e4357
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,34 +30,25 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
* Calculate all patterns of a projected database.
* @param minCount minimum count
* @param maxPatternLength maximum pattern length
* @param prefix prefix
* @param database the projected dabase
* @param prefixes prefixes in reversed order
* @param database the projected database
* @return a set of sequential pattern pairs,
* the key of pair is sequential pattern (a list of items),
* the key of pair is sequential pattern (a list of items in reversed order),
* the value of pair is the pattern's count.
*/
def run(
minCount: Long,
maxPatternLength: Int,
prefix: List[Int],
prefixes: List[Int],
database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {

if (database.isEmpty) return Iterator.empty

if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
val frequentItems = frequentItemAndCounts.map(_._1).toSet
val frequentPatternAndCounts = frequentItemAndCounts
.map { case (item, count) => ((item :: prefix), count) }


if (prefix.length + 1 < maxPatternLength) {
val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_)))
frequentPatternAndCounts.iterator ++ frequentItems.flatMap { item =>
val nextProjected = project(filteredProjectedDatabase, item)
run(minCount, maxPatternLength, item :: prefix, nextProjected)
}
} else {
frequentPatternAndCounts.iterator
val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
frequentItemAndCounts.iterator.flatMap { case (item, count) =>
val newPrefixes = item :: prefixes
val newProjected = project(filteredDatabase, item)
Iterator.single((newPrefixes, count)) ++
run(minCount, maxPatternLength, newPrefixes, newProjected)
}
}

Expand All @@ -78,24 +69,26 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {

def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
database
.map(candidateSeq => getSuffix(prefix, candidateSeq))
.map(getSuffix(prefix, _))
.filter(_.nonEmpty)
}

/**
* Generates frequent items by filtering the input data using minimal count level.
* @param minCount the minimum count for an item to be frequent
* @param database database of sequences
* @return item and count pairs
* @return freq item to count map
*/
private def getFreqItemAndCounts(
minCount: Long,
database: Array[Array[Int]]): Iterable[(Int, Long)] = {
database.flatMap(_.distinct)
.foldRight(mutable.Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) =>
ctr(item) += 1
ctr
database: Array[Array[Int]]): mutable.Map[Int, Long] = {
// TODO: use PrimitiveKeyOpenHashMap
val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
database.foreach { sequence =>
sequence.distinct.foreach { item =>
counts(item) += 1L
}
.filter(_._2 >= minCount)
}
counts.filter(_._2 >= minCount)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
def compareResult(
expectedValue: Array[(Array[Int], Long)],
actualValue: Array[(Array[Int], Long)]): Boolean = {
expectedValue.map(x => (x._1.toList, x._2)).toSet ==
actualValue.map(x => (x._1.toList, x._2)).toSet
expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
actualValue.map(x => (x._1.toSeq, x._2)).toSet
}

val prefixspan = new PrefixSpan()
Expand Down

0 comments on commit 91e4357

Please sign in to comment.