Add new object LocalPrefixSpan, and do some optimization.
zhangjiajin committed Jul 10, 2015
1 parent ba5df34 commit 574e56c
Showing 3 changed files with 158 additions and 102 deletions.
129 changes: 129 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.fpm

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental

/**
* :: Experimental ::
*
* Calculate all patterns of a projected database locally.
*/
@Experimental
private[fpm] object LocalPrefixSpan extends Logging with Serializable {

/**
* Calculate all patterns of a projected database locally.
* @param minCount minimum count
* @param maxPatternLength maximum pattern length
* @param prefix prefix
* @param projectedDatabase the projected database
* @return a set of sequential pattern pairs;
* the key of each pair is a pattern (a list of elements),
* the value is the pattern's count.
*/
def run(
minCount: Long,
maxPatternLength: Int,
prefix: Array[Int],
projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
getPatternsWithPrefix(minCount, maxPatternLength, prefix, projectedDatabase)
}

/**
* Calculate the suffix sequence following a prefix in a sequence.
* @param prefix prefix
* @param sequence sequence
* @return suffix sequence
*/
def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
val index = sequence.indexOf(prefix)
if (index == -1) {
Array()
} else {
sequence.drop(index + 1)
}
}
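
// Illustrative example (editor's addition, not in the committed source):
// getSuffix returns everything after the first occurrence of the prefix,
// or an empty array when the prefix never occurs.
//   getSuffix(2, Array(1, 2, 3, 2, 4))   // => Array(3, 2, 4)
//   getSuffix(5, Array(1, 2, 3))         // => Array()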

/**
* Generates frequent items by filtering the input data using the minimum count threshold.
* @param minCount the absolute minimum count
* @param sequences sequences data
* @return array of (item, count) pairs
*/
private def getFreqItemAndCounts(
minCount: Long,
sequences: Array[Array[Int]]): Array[(Int, Long)] = {
sequences.flatMap(_.distinct)
.groupBy(x => x)
.mapValues(_.length.toLong)
.filter(_._2 >= minCount)
.toArray
}
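
// Illustrative example (editor's addition, not in the committed source):
// each item is counted at most once per sequence (via distinct) and must
// reach minCount to survive the filter.
//   getFreqItemAndCounts(2L, Array(Array(1, 2), Array(1, 3), Array(1)))
//   // => Array((1, 3L)); items 2 and 3 occur in only one sequence each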

/**
* Get the frequent prefixes' projected database.
* @param prePrefix the frequent prefixes' prefix
* @param frequentPrefixes frequent prefixes
* @param sequences sequences data
* @return prefixes and projected database
*/
private def getPatternAndProjectedDatabase(
prePrefix: Array[Int],
frequentPrefixes: Array[Int],
sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
val filteredProjectedDatabase = sequences
.map(x => x.filter(frequentPrefixes.contains(_)))
frequentPrefixes.map { x =>
val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
(prePrefix ++ Array(x), sub)
}.filter(x => x._2.nonEmpty)
}
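
// Illustrative example (editor's addition, not in the committed source):
// with prePrefix = Array(1), frequentPrefixes = Array(2, 3) and sequences
// Array(Array(2, 3, 4), Array(2, 4, 3)), the infrequent item 4 is filtered
// out first, so both sequences become Array(2, 3). Prefix 2 then yields
// (Array(1, 2), Array(Array(3), Array(3))); prefix 3 produces only empty
// suffixes, so it is dropped from the result.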

/**
* Calculate all patterns of a projected database locally.
* @param minCount the minimum count
* @param maxPatternLength maximum pattern length
* @param prefix prefix
* @param projectedDatabase projected database
* @return patterns
*/
private def getPatternsWithPrefix(
minCount: Long,
maxPatternLength: Int,
prefix: Array[Int],
projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
val frequentPatternAndCounts = frequentPrefixAndCounts
.map(x => (prefix ++ Array(x._1), x._2))
val prefixProjectedDatabases = getPatternAndProjectedDatabase(
prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)

val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
if (continueProcess) {
val nextPatterns = prefixProjectedDatabases
.map(x => getPatternsWithPrefix(minCount, maxPatternLength, x._1, x._2))
.reduce(_ ++ _)
frequentPatternAndCounts ++ nextPatterns
} else {
frequentPatternAndCounts
}
}
}
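
Putting the pieces together, the new object can be exercised on a small in-memory database. The following sketch is an editor's illustration rather than part of the commit; because LocalPrefixSpan is private[fpm], it assumes the code runs from inside the org.apache.spark.mllib.fpm package (for example, from a test):

package org.apache.spark.mllib.fpm

object LocalPrefixSpanExample extends App {
  // Three sequences; a pattern must appear in at least two of them.
  val database = Array(Array(1, 2), Array(1, 2, 3), Array(1, 3))
  val patterns = LocalPrefixSpan.run(
    minCount = 2L,
    maxPatternLength = 3,
    prefix = Array.empty[Int],
    projectedDatabase = database)
  // Expected pairs (in some order):
  // [1] -> 3, [2] -> 2, [3] -> 2, [1,2] -> 2, [1,3] -> 2
  patterns.foreach { case (p, n) => println(p.mkString("[", ",", "]") + " -> " + n) }
}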
127 changes: 27 additions & 100 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -53,7 +53,8 @@ class PrefixSpan private (
* Sets the minimal support level (default: `0.1`).
*/
def setMinSupport(minSupport: Double): this.type = {
require(minSupport >= 0 && minSupport <= 1)
require(minSupport >= 0 && minSupport <= 1,
"The minimum support value must be between 0 and 1, including 0 and 1.")
this.minSupport = minSupport
this
}
@@ -62,7 +63,8 @@
* Sets maximal pattern length (default: `10`).
*/
def setMaxPatternLength(maxPatternLength: Int): this.type = {
require(maxPatternLength >= 1)
require(maxPatternLength >= 1,
"The maximum pattern length value must be greater than 0.")
this.maxPatternLength = maxPatternLength
this
}
@@ -73,35 +75,38 @@
* a sequence is an ordered list of elements.
* @return a set of sequential pattern pairs,
* the key of pair is pattern (a list of elements),
* the value of pair is the pattern's support value.
* the value of pair is the pattern's count.
*/
def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
if (sequences.getStorageLevel == StorageLevel.NONE) {
logWarning("Input data is not cached.")
}
val minCount = getAbsoluteMinSupport(sequences)
val minCount = getMinCount(sequences)
val (lengthOnePatternsAndCounts, prefixAndCandidates) =
findLengthOnePatterns(minCount, sequences)
val repartitionedRdd = makePrefixProjectedDatabases(prefixAndCandidates)
val nextPatterns = getPatternsInLocal(minCount, repartitionedRdd)
val allPatterns = lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)) ++ nextPatterns
val projectedDatabase = makePrefixProjectedDatabases(prefixAndCandidates)
val nextPatterns = getPatternsInLocal(minCount, projectedDatabase)
val lengthOnePatternsAndCountsRdd =
sequences.sparkContext.parallelize(
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
allPatterns
}

/**
* Get the absolute minimum support value (sequences count * minSupport).
* Get the minimum count (sequences count * minSupport).
* @param sequences input data set, containing a set of sequences
* @return absolute minimum support value,
* @return the minimum count.
*/
private def getAbsoluteMinSupport(sequences: RDD[Array[Int]]): Long = {
if (minSupport == 0) 0L else (sequences.count() * minSupport).toLong
private def getMinCount(sequences: RDD[Array[Int]]): Long = {
if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
}
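
// Illustrative arithmetic (editor's addition, not in the commit): with 4
// input sequences and minSupport = 0.33, the old (4 * 0.33).toLong
// truncated to 1, while math.ceil(4 * 0.33).toLong rounds up to 2, so a
// pattern now has to occur in at least two sequences to count as frequent.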

/**
* Generates frequent items by filtering the input data using minimal support level.
* @param minCount the absolute minimum support
* Generates frequent items by filtering the input data using the minimum count threshold.
* @param minCount the absolute minimum count
* @param sequences original sequences data
* @return array of frequent pattern ordered by their frequencies
* @return array of (item, count) pairs
*/
private def getFreqItemAndCounts(
minCount: Long,
@@ -111,22 +116,6 @@ class PrefixSpan private (
.filter(_._2 >= minCount)
}

/**
* Generates frequent items by filtering the input data using minimal support level.
* @param minCount the absolute minimum support
* @param sequences sequences data
* @return array of frequent pattern ordered by their frequencies
*/
private def getFreqItemAndCounts(
minCount: Long,
sequences: Array[Array[Int]]): Array[(Int, Long)] = {
sequences.flatMap(_.distinct)
.groupBy(x => x)
.mapValues(_.length.toLong)
.filter(_._2 >= minCount)
.toArray
}

/**
* Get the frequent prefixes' projected database.
* @param frequentPrefixes frequent prefixes
@@ -141,44 +130,25 @@
}
filteredSequences.flatMap { x =>
frequentPrefixes.map { y =>
val sub = getSuffix(y, x)
val sub = LocalPrefixSpan.getSuffix(y, x)
(Array(y), sub)
}
}.filter(x => x._2.nonEmpty)
}

/**
* Get the frequent prefixes' projected database.
* @param prePrefix the frequent prefixes' prefix
* @param frequentPrefixes frequent prefixes
* @param sequences sequences data
* @return prefixes and projected database
*/
private def getPatternAndProjectedDatabase(
prePrefix: Array[Int],
frequentPrefixes: Array[Int],
sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
val filteredProjectedDatabase = sequences
.map(x => x.filter(frequentPrefixes.contains(_)))
frequentPrefixes.map { x =>
val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
(prePrefix ++ Array(x), sub)
}.filter(x => x._2.nonEmpty)
}.filter(_._2.nonEmpty)
}
}

/**
* Find the patterns whose length is one
* @param minCount the absolute minimum support
* @param minCount the minimum count
* @param sequences original sequences data
* @return length-one patterns and projection table
*/
private def findLengthOnePatterns(
minCount: Long,
sequences: RDD[Array[Int]]): (RDD[(Int, Long)], RDD[(Array[Int], Array[Int])]) = {
sequences: RDD[Array[Int]]): (Array[(Int, Long)], RDD[(Array[Int], Array[Int])]) = {
val frequentLengthOnePatternAndCounts = getFreqItemAndCounts(minCount, sequences)
val prefixAndProjectedDatabase = getPatternAndProjectedDatabase(
frequentLengthOnePatternAndCounts.keys.collect(), sequences)
(frequentLengthOnePatternAndCounts, prefixAndProjectedDatabase)
(frequentLengthOnePatternAndCounts.collect(), prefixAndProjectedDatabase)
}

/**
@@ -195,58 +165,15 @@

/**
* Calculate the patterns locally.
* @param minCount the absolute minimum support
* @param minCount the absolute minimum count
* @param data patterns and projected sequences data
* @return patterns
*/
private def getPatternsInLocal(
minCount: Long,
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
data.flatMap { x =>
getPatternsWithPrefix(minCount, x._1, x._2)
}
}

/**
* calculate the patterns with one prefix in local.
* @param minCount the absolute minimum support
* @param prefix prefix
* @param projectedDatabase patterns and projected sequences data
* @return patterns
*/
private def getPatternsWithPrefix(
minCount: Long,
prefix: Array[Int],
projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
val frequentPatternAndCounts = frequentPrefixAndCounts
.map(x => (prefix ++ Array(x._1), x._2))
val prefixProjectedDatabases = getPatternAndProjectedDatabase(
prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)

val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
if (continueProcess) {
val nextPatterns = prefixProjectedDatabases
.map(x => getPatternsWithPrefix(minCount, x._1, x._2))
.reduce(_ ++ _)
frequentPatternAndCounts ++ nextPatterns
} else {
frequentPatternAndCounts
}
}

/**
* calculate suffix sequence following a prefix in a sequence
* @param prefix prefix
* @param sequence sequence
* @return suffix sequence
*/
private def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
val index = sequence.indexOf(prefix)
if (index == -1) {
Array()
} else {
sequence.drop(index + 1)
LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
}
}
}
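
For reference, the public entry point is exercised end to end in the suite below; a minimal driver-side sketch (an editor's illustration, assuming an existing SparkContext named sc) looks like this:

import org.apache.spark.mllib.fpm.PrefixSpan

// `sc` is an assumed, already-constructed SparkContext.
val sequences = sc.parallelize(Seq(
  Array(1, 2, 3),
  Array(1, 3, 2),
  Array(1, 2)), 2).cache()  // cache to avoid the "Input data is not cached" warning

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)       // minCount = ceil(3 * 0.5) = 2 sequences
  .setMaxPatternLength(5)
prefixSpan.run(sequences).collect().foreach { case (pattern, count) =>
  println(pattern.mkString("[", ",", "]") + ": " + count)
}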
@@ -60,7 +60,7 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
}

val prefixspan = new PrefixSpan()
.setMinSupport(0.34)
.setMinSupport(0.33)
.setMaxPatternLength(50)
val result1 = prefixspan.run(rdd)
val expectedValue1 = Array(
@@ -97,7 +97,7 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
)
assert(compareResult(expectedValue2, result2.collect()))

prefixspan.setMinSupport(0.34).setMaxPatternLength(2)
prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
val result3 = prefixspan.run(rdd)
val expectedValue3 = Array(
(Array(1), 4L),