From 4f5c5a3e852b19a9a0d9b776bb6866ff6eb6921b Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Mon, 4 Jan 2016 13:59:48 -0800
Subject: [PATCH] eliminated inferItemType

---
 .../org/apache/spark/mllib/fpm/FPGrowth.scala | 47 ++++---------------
 1 file changed, 10 insertions(+), 37 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 06792f2acf37e..821161e904606 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -38,7 +38,7 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.sql.types._
 
@@ -88,8 +88,8 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
 
   @Since("2.0.0")
   override def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
-    val inferredItemset = FPGrowthModel.SaveLoadV1_0.inferItemType(sc, path)
-    FPGrowthModel.SaveLoadV1_0.load(sc, path, inferredItemset)
+    //val inferredItemset = FPGrowthModel.SaveLoadV1_0.inferItemType(sc, path)
+    FPGrowthModel.SaveLoadV1_0.load(sc, path)
   }
 
   private[fpm] object SaveLoadV1_0 {
@@ -122,41 +122,9 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
       sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
     }
 
-    def inferItemType(sc: SparkContext, path: String): FreqItemset[_] = {
-      val sqlContext = SQLContext.getOrCreate(sc)
-      val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path))
-      val itemsetType = freqItemsets.schema("items").dataType
-      val freqType = freqItemsets.schema("freq").dataType
-      require(itemsetType.isInstanceOf[ArrayType],
-        s"items should be ArrayType, but got $itemsetType")
-      require(freqType.isInstanceOf[LongType], s"freq should be LongType, but got $freqType")
-      val itemType = itemsetType.asInstanceOf[ArrayType].elementType
-      val result = itemType match {
-        case BooleanType => new FreqItemset(Array[Boolean](), 0L)
-        case BinaryType => new FreqItemset(Array(Array[Byte]()), 0L)
-        case StringType => new FreqItemset(Array[String](), 0L)
-        case ByteType => new FreqItemset(Array[Byte](), 0L)
-        case ShortType => new FreqItemset(Array[Short](), 0L)
-        case IntegerType => new FreqItemset(Array[Int](), 0L)
-        case LongType => new FreqItemset(Array[Long](), 0L)
-        case FloatType => new FreqItemset(Array[Float](), 0L)
-        case DoubleType => new FreqItemset(Array[Double](), 0L)
-        case DateType => new FreqItemset(Array[java.sql.Date](), 0L)
-        case DecimalType.SYSTEM_DEFAULT => new FreqItemset(Array[java.math.BigDecimal](), 0L)
-        case TimestampType => new FreqItemset(Array[java.sql.Timestamp](), 0L)
-        case _: ArrayType => new FreqItemset(Array[Seq[_]](), 0L)
-        case _: MapType => new FreqItemset(Array[Map[_, _]](), 0L)
-        case _: StructType => new FreqItemset(Array[Row](), 0L)
-        case other =>
-          throw new UnsupportedOperationException(s"Schema for type $other is not supported")
-      }
-      result
-    }
-
-    def load[Item: ClassTag](
+    def load(
         sc: SparkContext,
-        path: String,
-        inferredItemset: FreqItemset[Item]): FPGrowthModel[Item] = {
+        path: String): FPGrowthModel[_] = {
      implicit val formats = DefaultFormats
       val sqlContext = SQLContext.getOrCreate(sc)
 
@@ -165,6 +133,11 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
       assert(formatVersion == thisFormatVersion)
 
       val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path))
+      val sample = freqItemsets.select("items").head().get(0)
+      loadImpl(freqItemsets, sample)
+    }
+
+    def loadImpl[Item : ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = {
       val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x =>
         val items = x.getAs[Seq[Item]](0).toArray
         val freq = x.getLong(1)
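
Illustrative usage sketch (not part of the patch): after this change, FPGrowthModel.load(sc, path) takes only the SparkContext and path and returns an existentially typed FPGrowthModel[_]; the stored item type is recovered internally from a sample of the "items" column by SaveLoadV1_0.loadImpl instead of the removed inferItemType dispatch. The object name, app name, save path, and toy transactions below are made up, and the sketch assumes a build of this branch where FPGrowthModel save/load are available.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}

object FPGrowthSaveLoadSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("fpgrowth-save-load-sketch").setMaster("local[2]"))

    // Toy transactions; the concrete item type is String here, but the caller of
    // load() no longer needs to know or declare that.
    val transactions = sc.parallelize(Seq(
      Array("a", "b", "c"),
      Array("a", "b"),
      Array("b", "c")))
    val model = new FPGrowth().setMinSupport(0.5).setNumPartitions(2).run(transactions)

    val path = "/tmp/fpgrowth-model-sketch"  // hypothetical path
    model.save(sc, path)

    // With this patch applied, load takes only (sc, path) and returns FPGrowthModel[_];
    // the item type is bound inside loadImpl from a sample row of the saved data.
    val loaded: FPGrowthModel[_] = FPGrowthModel.load(sc, path)
    loaded.freqItemsets.collect().foreach { itemset =>
      println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
    }

    sc.stop()
  }
}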