Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8997][MLlib]Performance improvements in LocalPrefixSpan #7360

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,97 +17,78 @@

package org.apache.spark.mllib.fpm

import scala.collection.mutable

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental

/**
*
* :: Experimental ::
*
* Calculate all patterns of a projected database in local.
*/
@Experimental
private[fpm] object LocalPrefixSpan extends Logging with Serializable {

/**
* Calculate all patterns of a projected database.
* @param minCount minimum count
* @param maxPatternLength maximum pattern length
* @param prefix prefix
* @param projectedDatabase the projected dabase
* @param prefixes prefixes in reversed order
* @param database the projected database
* @return a set of sequential pattern pairs,
* the key of pair is sequential pattern (a list of items),
* the key of pair is sequential pattern (a list of items in reversed order),
* the value of pair is the pattern's count.
*/
def run(
minCount: Long,
maxPatternLength: Int,
prefix: Array[Int],
projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
val frequentPatternAndCounts = frequentPrefixAndCounts
.map(x => (prefix ++ Array(x._1), x._2))
val prefixProjectedDatabases = getPatternAndProjectedDatabase(
prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)

val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
if (continueProcess) {
val nextPatterns = prefixProjectedDatabases
.map(x => run(minCount, maxPatternLength, x._1, x._2))
.reduce(_ ++ _)
frequentPatternAndCounts ++ nextPatterns
} else {
frequentPatternAndCounts
prefixes: List[Int],
database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
frequentItemAndCounts.iterator.flatMap { case (item, count) =>
val newPrefixes = item :: prefixes
val newProjected = project(filteredDatabase, item)
Iterator.single((newPrefixes, count)) ++
run(minCount, maxPatternLength, newPrefixes, newProjected)
}
}

/**
* calculate suffix sequence following a prefix in a sequence
* @param prefix prefix
* @param sequence sequence
* Calculate suffix sequence immediately after the first occurrence of an item.
* @param item item to get suffix after
* @param sequence sequence to extract suffix from
* @return suffix sequence
*/
def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
val index = sequence.indexOf(prefix)
def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
val index = sequence.indexOf(item)
if (index == -1) {
Array()
} else {
sequence.drop(index + 1)
}
}

def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
database
.map(getSuffix(prefix, _))
.filter(_.nonEmpty)
}

/**
* Generates frequent items by filtering the input data using minimal count level.
* @param minCount the absolute minimum count
* @param sequences sequences data
* @return array of item and count pair
* @param minCount the minimum count for an item to be frequent
* @param database database of sequences
* @return freq item to count map
*/
private def getFreqItemAndCounts(
minCount: Long,
sequences: Array[Array[Int]]): Array[(Int, Long)] = {
sequences.flatMap(_.distinct)
.groupBy(x => x)
.mapValues(_.length.toLong)
.filter(_._2 >= minCount)
.toArray
}

/**
* Get the frequent prefixes' projected database.
* @param prePrefix the frequent prefixes' prefix
* @param frequentPrefixes frequent prefixes
* @param sequences sequences data
* @return prefixes and projected database
*/
private def getPatternAndProjectedDatabase(
prePrefix: Array[Int],
frequentPrefixes: Array[Int],
sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
val filteredProjectedDatabase = sequences
.map(x => x.filter(frequentPrefixes.contains(_)))
frequentPrefixes.map { x =>
val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
(prePrefix ++ Array(x), sub)
}.filter(x => x._2.nonEmpty)
database: Array[Array[Int]]): mutable.Map[Int, Long] = {
// TODO: use PrimitiveKeyOpenHashMap
val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
database.foreach { sequence =>
sequence.distinct.foreach { item =>
counts(item) += 1L
}
}
counts.filter(_._2 >= minCount)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ class PrefixSpan private (
private def getPatternsInLocal(
minCount: Long,
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
data.flatMap { x =>
LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
data.flatMap { case (prefix, projDB) =>
LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

call toArray.reverse here

.map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD

class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {

test("PrefixSpan using Integer type") {

Expand Down Expand Up @@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
def compareResult(
expectedValue: Array[(Array[Int], Long)],
actualValue: Array[(Array[Int], Long)]): Boolean = {
val sortedExpectedValue = expectedValue.sortWith{ (x, y) =>
x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
}
val sortedActualValue = actualValue.sortWith{ (x, y) =>
x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
}
sortedExpectedValue.zip(sortedActualValue)
.map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2)
.reduce(_&&_)
expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
actualValue.map(x => (x._1.toSeq, x._2)).toSet
}

val prefixspan = new PrefixSpan()
Expand Down