diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 33e381e6d4d66..e056f2146c3f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -84,72 +84,69 @@ class PrefixSpan private ( logWarning("Input data is not cached.") } val minCount = getMinCount(sequences) - val lengthOnePatternsAndCounts = - getFreqItemAndCounts(minCount, sequences).collect() - val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase( - lengthOnePatternsAndCounts.map(_._1), sequences) - - var patternsCount = lengthOnePatternsAndCounts.length - var allPatternAndCounts = sequences.sparkContext.parallelize( - lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))) - var currentProjectedDatabase = prefixAndProjectedDatabase - while (patternsCount <= minPatternsBeforeShuffle && - currentProjectedDatabase.count() != 0) { - val (nextPatternAndCounts, nextProjectedDatabase) = - getPatternCountsAndProjectedDatabase(minCount, currentProjectedDatabase) + val lengthOnePatternsAndCounts = getFreqItemAndCounts(minCount, sequences) + val prefixSuffixPairs = getPrefixSuffixPairs( + lengthOnePatternsAndCounts.map(_._1).collect(), sequences) + var patternsCount: Long = lengthOnePatternsAndCounts.count() + var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)) + var currentPrefixSuffixPairs = prefixSuffixPairs + while (patternsCount <= minPatternsBeforeShuffle && currentPrefixSuffixPairs.count() != 0) { + val (nextPatternAndCounts, nextPrefixSuffixPairs) = + getPatternCountsAndPrefixSuffixPairs(minCount, currentPrefixSuffixPairs) patternsCount = nextPatternAndCounts.count().toInt - currentProjectedDatabase = nextProjectedDatabase + currentPrefixSuffixPairs = nextPrefixSuffixPairs allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts } if (patternsCount > 0) { - val 
groupedProjectedDatabase = currentProjectedDatabase + val projectedDatabase = currentPrefixSuffixPairs .map(x => (x._1.toSeq, x._2)) .groupByKey() .map(x => (x._1.toArray, x._2.toArray)) - val nextPatternAndCounts = getPatternsInLocal(minCount, groupedProjectedDatabase) + val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase) allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts } allPatternAndCounts } /** - * Get the pattern and counts, and projected database + * Get the pattern and counts, and prefix suffix pairs * @param minCount minimum count - * @param prefixAndProjectedDatabase prefix and projected database, - * @return pattern and counts, and projected database - * (Array[pattern, count], RDD[prefix, projected database ]) + * @param prefixSuffixPairs prefix and suffix pairs + * @return pattern and counts, and prefix suffix pairs + * (RDD[pattern, count], RDD[prefix, suffix]) */ - private def getPatternCountsAndProjectedDatabase( + private def getPatternCountsAndPrefixSuffixPairs( minCount: Long, - prefixAndProjectedDatabase: RDD[(Array[Int], Array[Int])]): + prefixSuffixPairs: RDD[(Array[Int], Array[Int])]): (RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = { - val prefixAndFreqentItemAndCounts = prefixAndProjectedDatabase.flatMap{ x => - x._2.distinct.map(y => ((x._1.toSeq, y), 1L)) + val prefixAndFreqentItemAndCounts = prefixSuffixPairs + .flatMap { case (prefix, suffix) => + suffix.distinct.map(y => ((prefix.toSeq, y), 1L)) }.reduceByKey(_ + _) .filter(_._2 >= minCount) val patternAndCounts = prefixAndFreqentItemAndCounts - .map(x => (x._1._1.toArray ++ Array(x._1._2), x._2)) - val prefixlength = prefixAndProjectedDatabase.take(1)(0)._1.length + .map{ case ((prefix, item), count) => (prefix.toArray :+ item, count) } + val prefixlength = prefixSuffixPairs.first()._1.length if (prefixlength + 1 >= maxPatternLength) { - (patternAndCounts, prefixAndProjectedDatabase.filter(x => false)) + (patternAndCounts, 
prefixSuffixPairs.filter(x => false)) } else { val frequentItemsMap = prefixAndFreqentItemAndCounts - .keys.map(x => (x._1, x._2)) + .keys .groupByKey() .mapValues(_.toSet) .collect .toMap - val nextPrefixAndProjectedDatabase = prefixAndProjectedDatabase + val nextPrefixSuffixPairs = prefixSuffixPairs .filter(x => frequentItemsMap.contains(x._1)) - .flatMap { x => - val frequentItemSet = frequentItemsMap(x._1) - val filteredSequence = x._2.filter(frequentItemSet.contains(_)) - val subProjectedDabase = frequentItemSet.map{ y => - (y, LocalPrefixSpan.getSuffix(y, filteredSequence)) + .flatMap { case (prefix, suffix) => + val frequentItemSet = frequentItemsMap(prefix) + val filteredSuffix = suffix.filter(frequentItemSet.contains(_)) + val nextSuffixes = frequentItemSet.map{ item => + (item, LocalPrefixSpan.getSuffix(item, filteredSuffix)) }.filter(_._2.nonEmpty) - subProjectedDabase.map(y => (x._1 ++ Array(y._1), y._2)) + nextSuffixes.map { case (item, suffix) => (prefix :+ item, suffix) } } - (patternAndCounts, nextPrefixAndProjectedDatabase) + (patternAndCounts, nextPrefixSuffixPairs) } } @@ -177,12 +174,12 @@ class PrefixSpan private ( } /** - * Get the frequent prefixes' projected database. + * Get the frequent prefixes and suffix pairs. * @param frequentPrefixes frequent prefixes * @param sequences sequences data - * @return prefixes and projected database + * @return prefixes and suffix pairs. */ - private def getPrefixAndProjectedDatabase( + private def getPrefixSuffixPairs( frequentPrefixes: Array[Int], sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = { val filteredSequences = sequences.map { p => @@ -199,7 +196,7 @@ class PrefixSpan private ( /** * calculate the patterns in local. * @param minCount the absolute minimum count - * @param data patterns and projected sequences data data + * @param data prefixes and projected sequences data * @return patterns */ private def getPatternsInLocal(