Performance improvements in LocalPrefixSpan, fix tests
Feynman Liang committed Jul 12, 2015
1 parent 0c5207c commit 70b93e3
Showing 3 changed files with 32 additions and 36 deletions.
mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

@@ -20,6 +20,8 @@ package org.apache.spark.mllib.fpm
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 
+import scala.collection.mutable.ArrayBuffer
+
 /**
  *
  * :: Experimental ::
@@ -42,22 +44,20 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
   def run(
       minCount: Long,
       maxPatternLength: Int,
-      prefix: Array[Int],
-      projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
+      prefix: ArrayBuffer[Int],
+      projectedDatabase: Array[Array[Int]]): Iterator[(Array[Int], Long)] = {
     val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
     val frequentPatternAndCounts = frequentPrefixAndCounts
-      .map(x => (prefix ++ Array(x._1), x._2))
+      .map(x => ((prefix :+ x._1).toArray, x._2))
     val prefixProjectedDatabases = getPatternAndProjectedDatabase(
       prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
 
-    val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
-    if (continueProcess) {
-      val nextPatterns = prefixProjectedDatabases
-        .map(x => run(minCount, maxPatternLength, x._1, x._2))
-        .reduce(_ ++ _)
-      frequentPatternAndCounts ++ nextPatterns
+    if (prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength) {
+      frequentPatternAndCounts.iterator ++ prefixProjectedDatabases.flatMap {
+        case (nextPrefix, projDB) => run(minCount, maxPatternLength, nextPrefix, projDB)
+      }
     } else {
-      frequentPatternAndCounts
+      frequentPatternAndCounts.iterator
    }
  }

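The heart of the change above: run now returns an Iterator instead of an Array, and child results are chained with flatMap rather than eagerly reduced with Array ++. A minimal standalone sketch of the difference (illustrative only, not part of the commit):

    // Old shape: each ++ on arrays allocates and copies a fresh array, so a
    // deep recursion pays roughly quadratic copying costs before returning.
    val parts: Seq[Array[Int]] = Seq.fill(1000)(Array.range(0, 100))
    val eager: Array[Int] = parts.reduce(_ ++ _)

    // New shape: ++ on iterators is O(1); elements are produced only when the
    // caller consumes them, e.g. inside RDD#flatMap.
    val lazily: Iterator[Int] = parts.iterator.flatMap(_.iterator)
    assert(eager.length == lazily.size)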
@@ -86,28 +86,30 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       sequences: Array[Array[Int]]): Array[(Int, Long)] = {
     sequences.flatMap(_.distinct)
-      .groupBy(x => x)
-      .mapValues(_.length.toLong)
+      .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) =>
+        ctr + (item -> (ctr(item) + 1))
+      }
       .filter(_._2 >= minCount)
       .toArray
   }

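An aside on the getFreqItemAndCounts rewrite just above: groupBy materializes an intermediate collection of occurrences per item before mapValues takes lengths, while the fold maintains a single running count Map in one pass. A quick equivalence check (illustrative only, not part of the commit):

    val items = Seq(3, 1, 2, 2, 3, 3)
    val viaGroupBy = items.groupBy(identity).mapValues(_.length.toLong).toMap
    val viaFold = items.foldRight(Map[Int, Long]().withDefaultValue(0L)) {
      case (item, ctr) => ctr + (item -> (ctr(item) + 1))
    }
    assert(viaFold == viaGroupBy) // Map(1 -> 1, 2 -> 2, 3 -> 3)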
   /**
    * Get the frequent prefixes' projected database.
-   * @param prePrefix the frequent prefixes' prefix
-   * @param frequentPrefixes frequent prefixes
-   * @param sequences sequences data
-   * @return prefixes and projected database
+   * @param prefix the frequent prefixes' prefix
+   * @param frequentPrefixes frequent next prefixes
+   * @param projDB projected database for given prefix
+   * @return extensions of prefix by one item and corresponding projected databases
    */
   private def getPatternAndProjectedDatabase(
-      prePrefix: Array[Int],
+      prefix: ArrayBuffer[Int],
       frequentPrefixes: Array[Int],
-      sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
-    val filteredProjectedDatabase = sequences
-      .map(x => x.filter(frequentPrefixes.contains(_)))
-    frequentPrefixes.map { x =>
-      val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
-      (prePrefix ++ Array(x), sub)
+      projDB: Array[Array[Int]]): Array[(ArrayBuffer[Int], Array[Array[Int]])] = {
+    val filteredProjectedDatabase = projDB.map(x => x.filter(frequentPrefixes.contains(_)))
+    frequentPrefixes.map { nextItem =>
+      val nextProjDB = filteredProjectedDatabase
+        .map(candidateSeq => getSuffix(nextItem, candidateSeq))
+        .filter(_.nonEmpty)
+      (prefix :+ nextItem, nextProjDB)
     }.filter(x => x._2.nonEmpty)
   }
 }
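For context, getSuffix (unchanged by this commit, so absent from the diff) is what produces each projected sequence. A hypothetical stand-in, assuming the usual PrefixSpan semantics of keeping everything after the first occurrence of the item:

    // Stand-in for LocalPrefixSpan.getSuffix (assumption: the real method is
    // outside this diff). Returns the part of the sequence after the first
    // occurrence of `item`, or an empty array if `item` never occurs.
    def suffixAfter(item: Int, sequence: Array[Int]): Array[Int] = {
      val i = sequence.indexOf(item)
      if (i == -1) Array.empty[Int] else sequence.drop(i + 1)
    }

    suffixAfter(2, Array(1, 2, 3)) // Array(3)
    suffixAfter(2, Array(2, 3))    // Array(3)
    suffixAfter(2, Array(3))       // Array() -- dropped by .filter(_.nonEmpty)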
mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

@@ -22,6 +22,8 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 
+import scala.collection.mutable.ArrayBuffer
+
 /**
  *
  * :: Experimental ::
@@ -150,8 +152,8 @@ class PrefixSpan private (
   private def getPatternsInLocal(
       minCount: Long,
       data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
-    data.flatMap { x =>
-      LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
+    data.flatMap { case (prefix, projDB) =>
+      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.to[ArrayBuffer], projDB)
     }
   }
 }
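With this change, getPatternsInLocal consumes the lazy iterator directly inside flatMap, so pattern growth within a partition never materializes intermediate arrays. A hedged usage sketch of the public API at this point in its development (sample data and parameter values are made up; sc is an existing SparkContext):

    import org.apache.spark.mllib.fpm.PrefixSpan
    import org.apache.spark.rdd.RDD

    val sequences: RDD[Array[Int]] = sc.parallelize(Seq(
      Array(1, 3, 4, 5),
      Array(2, 3, 1),
      Array(2, 4, 1),
      Array(3, 1, 3, 4, 5)))

    val prefixSpan = new PrefixSpan()
      .setMinSupport(0.5)
      .setMaxPatternLength(5)
    val patterns: RDD[(Array[Int], Long)] = prefixSpan.run(sequences)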
mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala

@@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.rdd.RDD
 
-class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("PrefixSpan using Integer type") {

@@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
     def compareResult(
         expectedValue: Array[(Array[Int], Long)],
         actualValue: Array[(Array[Int], Long)]): Boolean = {
-      val sortedExpectedValue = expectedValue.sortWith{ (x, y) =>
-        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
-      }
-      val sortedActualValue = actualValue.sortWith{ (x, y) =>
-        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
-      }
-      sortedExpectedValue.zip(sortedActualValue)
-        .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2)
-        .reduce(_&&_)
+      expectedValue.map(x => (x._1.toList, x._2)).toSet ==
+        actualValue.map(x => (x._1.toList, x._2)).toSet
     }

     val prefixspan = new PrefixSpan()
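The test fix above swaps a brittle sort-and-zip comparison for set equality. The .toList conversion is the load-bearing detail: Scala arrays compare by reference, so sets of (Array, count) pairs would never match even with identical contents. Illustratively:

    assert(Array(1, 2) != Array(1, 2)) // arrays use reference equality
    assert(List(1, 2) == List(1, 2))   // lists compare structurally
    assert(Set((List(1, 2), 3L)) == Set((List(1, 2), 3L)))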
