From cd5f79fa67903080441b9c2cb91e3a171be827ca Mon Sep 17 00:00:00 2001
From: Ahir Reddy
Date: Mon, 7 Apr 2014 21:25:14 -0700
Subject: [PATCH] Switched to using Scala SQLContext

---
 .../apache/spark/api/python/PythonRDD.scala |ꞏ14 +++++++
 python/pyspark/context.py                   | 16 +++++---
 python/pyspark/java_gateway.py              |  2 +-
 .../org/apache/spark/sql/SQLContext.scala   | 27 ++++++++++++++
 .../org/apache/spark/sql/SchemaRDD.scala    | 13 +++++++
 .../spark/sql/api/java/JavaSQLContext.scala | 37 +------------------
 .../spark/sql/api/java/JavaSchemaRDD.scala  | 12 ------
 7 files changed, 66 insertions(+), 55 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 9eb16e1cec050..11ab81f1498ba 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -300,6 +300,20 @@ object PythonRDD {
     }
   }
 
+  def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
+    pyRDD.rdd.mapPartitions { iter =>
+      val unpickle = new Unpickler
+      // TODO: Figure out why flatMap is necessary for pyspark
+      iter.flatMap { row =>
+        unpickle.loads(row) match {
+          case objs: java.util.ArrayList[JMap[String, _]] => objs.map(_.toMap)
+          // In case the partition doesn't have a collection
+          case obj: JMap[String, _] => Seq(obj.toMap)
+        }
+      }
+    }
+  }
+
   def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
     jRDD.rdd.mapPartitions { iter =>
       val unpickle = new Pickler
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 22a98a7ec955e..b8ac6db974573 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -175,6 +175,7 @@ def _ensure_initialized(cls, instance=None, gateway=None):
                 SparkContext._jvm = SparkContext._gateway.jvm
                 SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
                 SparkContext._pythonToJava = SparkContext._jvm.PythonRDD.pythonToJava
+                SparkContext._pythonToJavaMap = SparkContext._jvm.PythonRDD.pythonToJavaMap
                 SparkContext._javaToPython = SparkContext._jvm.PythonRDD.javaToPython
 
             if instance:
@@ -468,15 +469,18 @@ def __init__(self, sparkContext):
         self._sc = sparkContext
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
-        self._jsql_ctx = self._jvm.JavaSQLContext(self._jsc)
+        self._ssql_ctx = self._jvm.SQLContext(self._jsc.sc())
 
     def sql(self, sqlQuery):
-        return SchemaRDD(self._jsql_ctx.sql(sqlQuery), self)
+        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
 
-    def applySchema(self, rdd, fieldNames):
-        fieldNames = ListConverter().convert(fieldNames, self._sc._gateway._gateway_client)
-        jrdd = self._sc._pythonToJava(rdd._jrdd)
-        srdd = self._jsql_ctx.applySchema(jrdd, fieldNames)
+    def applySchema(self, rdd):
+        first = rdd.first()
+        if (rdd.__class__ is SchemaRDD):
+            raise Exception("Cannot apply schema to %s" % SchemaRDD.__name__)
+
+        jrdd = self._sc._pythonToJavaMap(rdd._jrdd)
+        srdd = self._ssql_ctx.applySchema(jrdd.rdd())
         return SchemaRDD(srdd, self)


diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 8b079f7215b4b..d8dd2a65225e1 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -64,6 +64,6 @@ def run(self):
     java_import(gateway.jvm, "org.apache.spark.api.java.*")
     java_import(gateway.jvm, "org.apache.spark.api.python.*")
     java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
-    java_import(gateway.jvm, "org.apache.spark.sql.api.java.JavaSQLContext")
+    java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
"org.apache.spark.sql.SQLContext") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d3d4c56bafe41..fdf28822c59d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -26,11 +26,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.execution._ +import org.apache.spark.api.java.JavaRDD /** * :: AlphaComponent :: @@ -241,4 +243,29 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def debugExec() = DebugQuery(executedPlan).execute().collect() } + + def applySchema(rdd: RDD[Map[String, _]]): SchemaRDD = { + val schema = rdd.first.map { case (fieldName, obj) => + val dataType = obj.getClass match { + case c: Class[_] if c == classOf[java.lang.String] => StringType + case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType + // case c: Class[_] if c == java.lang.Short.TYPE => ShortType + // case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType + // case c: Class[_] if c == java.lang.Long.TYPE => LongType + // case c: Class[_] if c == java.lang.Double.TYPE => DoubleType + // case c: Class[_] if c == java.lang.Byte.TYPE => ByteType + // case c: Class[_] if c == java.lang.Float.TYPE => FloatType + // case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType + } + AttributeReference(fieldName, dataType, true)() + }.toSeq + + val rowRdd = rdd.mapPartitions { iter => + iter.map { map => + new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row + } + } + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 91500416eefaa..f15fb113d3754 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import net.razorvine.pickle.{Pickler, Unpickler} + import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext} import org.apache.spark.annotation.{AlphaComponent, Experimental} import org.apache.spark.rdd.RDD @@ -25,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.types.BooleanType +import org.apache.spark.api.java.JavaRDD /** * :: AlphaComponent :: @@ -308,4 +311,14 @@ class SchemaRDD( /** FOR INTERNAL USE ONLY */ def analyze = sqlContext.analyzer(logicalPlan) + + def javaToPython: JavaRDD[Array[Byte]] = { + this.mapPartitions { iter => + val unpickle = new Pickler + iter.map { row => + val fields: Array[Any] = (for (i <- 0 to row.length - 1) yield row(i)).toArray + unpickle.dumps(fields) + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala 
index bd9fe7fbb0096..4ca4505fbfc5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.api.java
 
 import java.beans.{Introspector, PropertyDescriptor}
+import java.util.{Map => JMap}
 
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.sql.SQLContext
@@ -82,42 +83,6 @@ class JavaSQLContext(sparkContext: JavaSparkContext) {
     new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
   }
 
-  /**
-   * Applies a schema to an RDD of Array[Any]
-   */
-  def applySchema(rdd: JavaRDD[_], fieldNames: java.util.ArrayList[Any]): JavaSchemaRDD = {
-    val fields = rdd.first match {
-      case row: java.util.ArrayList[_] => row.toArray.map(_.getClass)
-      case row => throw new Exception(s"Rows must be Lists 1 ${row.getClass}")
-    }
-
-    val schema = fields.zip(fieldNames.toArray).map { case (klass, fieldName) =>
-      val dataType = klass match {
-        case c: Class[_] if c == classOf[java.lang.String] => StringType
-        case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
-        // case c: Class[_] if c == java.lang.Short.TYPE => ShortType
-        // case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
-        // case c: Class[_] if c == java.lang.Long.TYPE => LongType
-        // case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
-        // case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
-        // case c: Class[_] if c == java.lang.Float.TYPE => FloatType
-        // case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
-      }
-
-      println(fieldName.toString)
-      // TODO: No bueno, fieldName.toString used because I can't figure out the casting
-      AttributeReference(fieldName.toString, dataType, true)()
-    }
-
-    val rowRdd = rdd.rdd.mapPartitions { iter =>
-      iter.map {
-        case row: java.util.ArrayList[_] => new GenericRow(row.toArray.asInstanceOf[Array[Any]]): ScalaRow
-        case row => throw new Exception(s"Rows must be Lists 2 ${row.getClass}")
-      }
-    }
-    new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
-  }
-
   /**
    * Loads a parquet file, returning the result as a [[JavaSchemaRDD]].
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index f068519cc0e5e..d43d672938f51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.api.java
 
-import net.razorvine.pickle.{Pickler, Unpickler}
-
 import org.apache.spark.api.java.{JavaRDDLike, JavaRDD}
 import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -47,14 +45,4 @@ class JavaSchemaRDD(
   override def wrapRDD(rdd: RDD[Row]): JavaRDD[Row] = JavaRDD.fromRDD(rdd)
 
   val rdd = baseSchemaRDD.map(new Row(_))
-
-  def javaToPython: JavaRDD[Array[Byte]] = {
-    this.rdd.mapPartitions { iter =>
-      val unpickle = new Pickler
-      iter.map { row =>
-        val fields: Array[Any] = (for (i <- 0 to row.length - 1) yield row.get(i)).toArray
-        unpickle.dumps(fields)
-      }
-    }
-  }
 }
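
For reference, a minimal PySpark sketch of how the applySchema API introduced above might be exercised (not part of the patch). Variable names, the app name, and the sample dicts are illustrative; the `from pyspark.context import SQLContext` import simply mirrors the file this patch edits and assumes the class is not re-exported elsewhere. Per the Scala applySchema above, only string and integer values are mapped to Catalyst types at this point.

    from pyspark import SparkContext
    from pyspark.context import SQLContext  # SQLContext is defined in context.py in this patch

    sc = SparkContext("local", "applySchema-sketch")
    sqlCtx = SQLContext(sc)  # wraps the Scala SQLContext via the Py4J gateway

    # An RDD of Python dicts. Its pickled rows are converted to Java Maps by
    # PythonRDD.pythonToJavaMap and passed to the Scala SQLContext.applySchema,
    # which infers StringType/IntegerType columns from the first row.
    people = sc.parallelize([
        {"name": "Alice", "age": 1},
        {"name": "Bob", "age": 2},
    ])
    srdd = sqlCtx.applySchema(people)  # returns a Python SchemaRDD wrapper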