Skip to content

Commit

Permalink
eliminated inferItemType
Browse files: browse the repository at this point in the history
  • Loading branch information
jkbradley committed Jan 4, 2016
1 parent 7381b31 commit 4f5c5a3
Showing 1 changed file with 10 additions and 37 deletions.
47 changes: 10 additions & 37 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -88,8 +88,8 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {

@Since("2.0.0")
override def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
val inferredItemset = FPGrowthModel.SaveLoadV1_0.inferItemType(sc, path)
FPGrowthModel.SaveLoadV1_0.load(sc, path, inferredItemset)
//val inferredItemset = FPGrowthModel.SaveLoadV1_0.inferItemType(sc, path)
FPGrowthModel.SaveLoadV1_0.load(sc, path)
}

private[fpm] object SaveLoadV1_0 {
Expand Down Expand Up @@ -122,41 +122,9 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
}

/**
 * Infers the item type of a saved model from the schema of its Parquet data.
 *
 * Reads the data stored under `Loader.dataPath(path)`, validates the types of the
 * "items" and "freq" columns, and returns a placeholder [[FreqItemset]] (frequency 0)
 * whose item array has the JVM element type corresponding to the stored SQL element type.
 *
 * @param sc   SparkContext used to obtain the SQLContext that reads the data
 * @param path directory the model was saved to
 * @return a placeholder [[FreqItemset]] carrying only the inferred item type
 * @throws IllegalArgumentException if "items" is not an ArrayType or "freq" is not a LongType
 * @throws UnsupportedOperationException if the element type of "items" has no mapping below
 */
def inferItemType(sc: SparkContext, path: String): FreqItemset[_] = {
  val sqlContext = SQLContext.getOrCreate(sc)
  val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path))
  val itemsetType = freqItemsets.schema("items").dataType
  val freqType = freqItemsets.schema("freq").dataType
  require(itemsetType.isInstanceOf[ArrayType],
    s"items should be ArrayType, but got $itemsetType")
  require(freqType.isInstanceOf[LongType], s"freq should be LongType, but got $freqType")

  // Safe after the first require above: itemsetType is known to be an ArrayType.
  val ArrayType(itemType, _) = itemsetType

  itemType match {
    case BooleanType => new FreqItemset(Array[Boolean](), 0L)
    case BinaryType => new FreqItemset(Array(Array[Byte]()), 0L)
    case StringType => new FreqItemset(Array[String](), 0L)
    case ByteType => new FreqItemset(Array[Byte](), 0L)
    case ShortType => new FreqItemset(Array[Short](), 0L)
    case IntegerType => new FreqItemset(Array[Int](), 0L)
    case LongType => new FreqItemset(Array[Long](), 0L)
    case FloatType => new FreqItemset(Array[Float](), 0L)
    case DoubleType => new FreqItemset(Array[Double](), 0L)
    case DateType => new FreqItemset(Array[java.sql.Date](), 0L)
    // Accept every precision/scale, not only DecimalType.SYSTEM_DEFAULT: the placeholder
    // element type is java.math.BigDecimal regardless of the declared precision.
    case _: DecimalType => new FreqItemset(Array[java.math.BigDecimal](), 0L)
    case TimestampType => new FreqItemset(Array[java.sql.Timestamp](), 0L)
    case _: ArrayType => new FreqItemset(Array[Seq[_]](), 0L)
    case _: MapType => new FreqItemset(Array[Map[_, _]](), 0L)
    case _: StructType => new FreqItemset(Array[Row](), 0L)
    case other =>
      throw new UnsupportedOperationException(s"Schema for type $other is not supported")
  }
}

def load[Item: ClassTag](
def load(
sc: SparkContext,
path: String,
inferredItemset: FreqItemset[Item]): FPGrowthModel[Item] = {
path: String): FPGrowthModel[_] = {
implicit val formats = DefaultFormats
val sqlContext = SQLContext.getOrCreate(sc)

Expand All @@ -165,6 +133,11 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
assert(formatVersion == thisFormatVersion)

val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path))
val sample = freqItemsets.select("items").head().get(0)
loadImpl(freqItemsets, sample)
}

def loadImpl[Item : ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = {
val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x =>
val items = x.getAs[Seq[Item]](0).toArray
val freq = x.getLong(1)
Expand Down

0 comments on commit 4f5c5a3

Please sign in to comment.