Commit 9212256: MengXR code review comments

Feynman Liang committed Jul 15, 2015
1 parent f055d82 commit 9212256

Showing 2 changed files with 11 additions and 14 deletions.
mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

@@ -17,16 +17,13 @@
 
 package org.apache.spark.mllib.fpm
 
+import scala.collection.mutable
+
 import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
 
 /**
- *
- * :: Experimental ::
- *
  * Calculate all patterns of a projected database in local.
  */
-@Experimental
 private[fpm] object LocalPrefixSpan extends Logging with Serializable {
 
   /**
@@ -43,18 +40,18 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       maxPatternLength: Int,
       prefix: List[Int],
-      database: Iterable[Array[Int]]): Iterator[(Array[Int], Long)] = {
+      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
 
     if (database.isEmpty) return Iterator.empty
 
     val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
     val frequentItems = frequentItemAndCounts.map(_._1).toSet
     val frequentPatternAndCounts = frequentItemAndCounts
-      .map { case (item, count) => ((item :: prefix).reverse.toArray, count) }
+      .map { case (item, count) => ((item :: prefix), count) }
 
-    val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_)))
 
     if (prefix.length + 1 < maxPatternLength) {
+      val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_)))
       frequentPatternAndCounts.iterator ++ frequentItems.flatMap { item =>
        val nextProjected = project(filteredProjectedDatabase, item)
        run(minCount, maxPatternLength, item :: prefix, nextProjected)
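
In this hunk, run now carries each pattern through the recursion as a reversed List[Int] (prepend, reverse once at the end) and only builds the filtered projected database when the pattern is still allowed to grow. A minimal sketch of the filtering step on toy data (PruneDemo and its values are illustrative, not part of the commit):

// Illustrative only, not part of the commit: items that fall below minCount
// can never occur in a frequent pattern, so dropping them from every sequence
// shrinks each projected database without changing any surviving count.
object PruneDemo {
  def main(args: Array[String]): Unit = {
    val database = Array(Array(1, 2, 9), Array(2, 9), Array(2))
    val frequentItems = Set(2, 9) // suppose item 1 fell below minCount
    val filtered = database.map(x => x.filter(frequentItems.contains(_)))
    println(filtered.map(_.mkString(" ")).mkString(" | ")) // 2 9 | 2 9 | 2
  }
}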
@@ -79,7 +76,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     }
   }
 
-  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
+  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
     database
       .map(candidateSeq => getSuffix(prefix, candidateSeq))
       .filter(_.nonEmpty)
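
project keeps, for each sequence, only the non-empty suffix produced by getSuffix and drops sequences that no longer contain the prefix item. getSuffix itself is defined outside this diff, so the sketch below assumes the usual PrefixSpan convention of taking the suffix after the first occurrence of the item:

// Illustrative only: an assumed getSuffix that keeps the suffix after the
// first occurrence of the prefix item, returning an empty array otherwise.
object ProjectDemo {
  def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
    val i = sequence.indexOf(prefix)
    if (i == -1) Array.empty[Int] else sequence.drop(i + 1)
  }

  def main(args: Array[String]): Unit = {
    val database = Array(Array(1, 2, 3), Array(2, 3, 1), Array(3, 3))
    val projected = database
      .map(candidateSeq => getSuffix(2, candidateSeq))
      .filter(_.nonEmpty)
    println(projected.map(_.mkString(" ")).mkString(" | ")) // 3 | 3 1
  }
}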
@@ -93,10 +90,11 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      database: Iterable[Array[Int]]): Iterable[(Int, Long)] = {
+      database: Array[Array[Int]]): Iterable[(Int, Long)] = {
     database.flatMap(_.distinct)
-      .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) =>
-        ctr + (item -> (ctr(item) + 1))
+      .foldRight(mutable.Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) =>
+        ctr(item) += 1
+        ctr
       }
       .filter(_._2 >= minCount)
   }
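
getFreqItemAndCounts now folds every sequence's distinct items into a single mutable.Map, which is what the new scala.collection.mutable import at the top of the file is for; counting in place avoids allocating a fresh immutable Map per item. A toy run of the same counting pattern (FreqCountDemo is illustrative, not part of the commit):

// Illustrative only: each sequence contributes an item at most once (via
// .distinct), so the resulting counts are per-sequence supports, not total
// occurrence counts.
import scala.collection.mutable

object FreqCountDemo {
  def main(args: Array[String]): Unit = {
    val database = Array(Array(1, 1, 2), Array(2, 3), Array(2))
    val minCount = 2L
    val counts = database.flatMap(_.distinct)
      .foldRight(mutable.Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) =>
        ctr(item) += 1
        ctr
      }
      .filter(_._2 >= minCount)
    println(counts) // only item 2 survives: it appears in all 3 sequences
  }
}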
mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

@@ -22,8 +22,6 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 
-import scala.collection.mutable.ArrayBuffer
-
 /**
  *
  * :: Experimental ::
@@ -154,6 +152,7 @@ class PrefixSpan private (
       data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
     data.flatMap { case (prefix, projDB) =>
       LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
+        .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
     }
   }
 }
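
This is the other half of the List change in LocalPrefixSpan.run: each locally mined pattern arrives as a reversed List[Int], and the caller converts it to a forward Array[Int] exactly once. A minimal sketch of prepend-then-reverse-once, with illustrative values:

// Illustrative only: prepending to a List is O(1) while appending is O(n),
// so patterns are grown head-first during mining and reversed one time when
// they cross back to the caller.
object ReversedPrefixDemo {
  def main(args: Array[String]): Unit = {
    val prefix = List(3, 2, 1)             // pattern <1 2 3>, stored reversed
    val extended = 4 :: prefix             // O(1) extension with item 4
    val emitted = extended.reverse.toArray // single reversal at emission
    println(emitted.mkString(" "))         // 1 2 3 4
  }
}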
