PR Changes + Method Visibility
ahirreddy committed Apr 15, 2014
1 parent 1836944 commit 40491c9
Showing 5 changed files with 35 additions and 37 deletions.
30 changes: 12 additions & 18 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -286,39 +286,33 @@ private[spark] object PythonRDD {
file.close()
}

def pythonToJava(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[_] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
// TODO: Figure out why flatMap is necessary for pyspark
iter.flatMap { row =>
unpickle.loads(row) match {
case objs: java.util.ArrayList[Any] => objs
// In case the partition doesn't have a collection
case obj => Seq(obj)
}
}
}
}

/**
* Convert an RDD of serialized Python dictionaries to Scala Maps
* TODO: Support more Python types.
*/
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
// TODO: Figure out why flatMap is necessary for pyspark
iter.flatMap { row =>
unpickle.loads(row) match {
case objs: java.util.ArrayList[JMap[String, _]] => objs.map(_.toMap)
case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
// In case the partition doesn't have a collection
case obj: JMap[String, _] => Seq(obj.toMap)
case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
}
}
}
}

/**
* Convert an RDD of Java objects to an RDD of serialized Python objects that is usable by
* PySpark.
*/
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter =>
val unpickle = new Pickler
val pickle = new Pickler
iter.map { row =>
unpickle.dumps(row)
pickle.dumps(row)
}
}
}
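
Note: the "why flatMap" TODO above is most likely explained by PySpark's batched pickling, where each element shipped to the JVM is one pickled Python list holding a whole batch of rows. The following standalone sketch (my own illustration, not part of the commit) reproduces the unpickling pattern of pythonToJavaMap without a SparkContext; it only assumes the Pyrolite jar (net.razorvine.pickle) is on the classpath, and the field names and values are invented.

import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap}

import scala.collection.JavaConverters._

import net.razorvine.pickle.{Pickler, Unpickler}

object PickleRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Build the kind of payload a batched serializer ships: one pickled
    // Python list containing several dicts.
    val row1 = new JHashMap[String, Any]()
    row1.put("field1", 1)
    row1.put("field2", "row1")
    val row2 = new JHashMap[String, Any]()
    row2.put("field1", 2)
    row2.put("field2", "row2")
    val batch = new JArrayList[Any]()
    batch.add(row1)
    batch.add(row2)

    val bytes: Array[Byte] = new Pickler().dumps(batch)

    // Same two cases as pythonToJavaMap: a batch unpickles to an ArrayList of
    // maps (hence the flatMap in the real code), a lone object to a single map.
    val rows = new Unpickler().loads(bytes) match {
      case objs: JArrayList[JMap[String, _] @unchecked] => objs.asScala.map(_.asScala.toMap)
      case obj: JMap[String @unchecked, _]              => Seq(obj.asScala.toMap)
    }
    println(rows) // one Scala Map per original dict
  }
}
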
23 changes: 11 additions & 12 deletions python/pyspark/context.py
@@ -174,7 +174,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
SparkContext._pythonToJava = SparkContext._jvm.PythonRDD.pythonToJava
SparkContext._pythonToJavaMap = SparkContext._jvm.PythonRDD.pythonToJavaMap
SparkContext._javaToPython = SparkContext._jvm.PythonRDD.javaToPython

@@ -481,21 +480,21 @@ def __init__(self, sparkContext):
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> sqlCtx.applySchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
>>> srdd = sqlCtx.inferSchema(rdd)
>>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
>>> bad_rdd = sc.parallelize([1,2,3])
>>> sqlCtx.applySchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
>>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
>>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
... "boolean" : True}])
>>> srdd = sqlCtx.applySchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
>>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
... x.boolean))
>>> srdd.collect()[0]
(1, u'string', 1.0, 1, True)
@@ -514,7 +513,7 @@ def _ssql_ctx(self):
self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
return self._scala_SQLContext

def applySchema(self, rdd):
def inferSchema(self, rdd):
"""
Infer and apply a schema to an RDD of L{dict}s. We peek at the first row of the RDD to
determine the field names and types, and then use that to extract all the dictionaries.
@@ -523,7 +522,7 @@ def applySchema(self, rdd):
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
... {"field1" : 3, "field2": "row3"}]
True
Expand All @@ -535,7 +534,7 @@ def applySchema(self, rdd):
(SchemaRDD.__name__, rdd.first()))

jrdd = self._sc._pythonToJavaMap(rdd._jrdd)
srdd = self._ssql_ctx.applySchema(jrdd.rdd())
srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
return SchemaRDD(srdd, self)

def registerRDDAsTable(self, rdd, tableName):
@@ -546,7 +545,7 @@ def registerRDDAsTable(self, rdd, tableName):
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
"""
if (rdd.__class__ is SchemaRDD):
@@ -563,7 +562,7 @@ def parquetFile(self, path):
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.saveAsParquetFile("/tmp/tmp.parquet")
>>> srdd2 = sqlCtx.parquetFile("/tmp/tmp.parquet")
>>> srdd.collect() == srdd2.collect()
@@ -580,7 +579,7 @@ def sql(self, sqlQuery):
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
>>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
@@ -596,7 +595,7 @@ def table(self, tableName):
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
>>> srdd2 = sqlCtx.table("table1")
>>> srdd.collect() == srdd2.collect()
4 changes: 2 additions & 2 deletions python/pyspark/rdd.py
@@ -1445,7 +1445,7 @@ def saveAsParquetFile(self, path):
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.saveAsParquetFile("/tmp/test.parquet")
>>> srdd2 = sqlCtx.parquetFile("/tmp/test.parquet")
>>> srdd2.collect() == srdd.collect()
@@ -1461,7 +1461,7 @@ def registerAsTable(self, name):
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.registerAsTable("test")
>>> srdd2 = sqlCtx.sql("select * from test")
>>> srdd.collect() == srdd2.collect()
8 changes: 5 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -243,9 +243,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
def debugExec() = DebugQuery(executedPlan).execute().collect()
}

// TODO: We only support primitive types, add support for nested types. Difficult because java
// objects don't have classTags
def applySchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
/**
* Peek at the first row of the RDD and infer its schema.
* TODO: We only support primitive types, add support for nested types.
*/
private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
val schema = rdd.first.map { case (fieldName, obj) =>
val dataType = obj.getClass match {
case c: Class[_] if c == classOf[java.lang.String] => StringType
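
Note: a standalone illustration (not from the commit) of the peek-at-the-first-row idea behind inferSchema. The FieldType objects below are made-up stand-ins for Catalyst's DataType hierarchy so the sketch runs without Spark; the field names are invented, and only primitive types are handled, matching the TODO.

sealed trait FieldType
case object StringField extends FieldType
case object IntField extends FieldType
case object LongField extends FieldType
case object DoubleField extends FieldType
case object BoolField extends FieldType

object InferSchemaSketch {
  // Map a value from the first row to a field type by inspecting its runtime
  // class, the same way inferSchema matches on obj.getClass.
  def fieldType(obj: Any): FieldType = obj.getClass match {
    case c if c == classOf[java.lang.String]  => StringField
    case c if c == classOf[java.lang.Integer] => IntField
    case c if c == classOf[java.lang.Long]    => LongField
    case c if c == classOf[java.lang.Double]  => DoubleField
    case c if c == classOf[java.lang.Boolean] => BoolField
    case other => throw new IllegalArgumentException(s"Unsupported type: $other")
  }

  def main(args: Array[String]): Unit = {
    // Only the first row is inspected; every later row is assumed to share
    // this shape.
    val firstRow: Map[String, Any] = Map("field1" -> 1, "field2" -> "row1")
    val schema = firstRow.map { case (name, value) => name -> fieldType(value) }
    println(schema) // Map(field1 -> IntField, field2 -> StringField)
  }
}
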
7 changes: 5 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql

import net.razorvine.pickle.{Pickler, Unpickler}
import net.razorvine.pickle.Pickler

import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext}
import org.apache.spark.annotation.{AlphaComponent, Experimental}
@@ -313,12 +313,15 @@ class SchemaRDD(
/** FOR INTERNAL USE ONLY */
def analyze = sqlContext.analyzer(logicalPlan)

def javaToPython: JavaRDD[Array[Byte]] = {
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
val fieldNames: Seq[String] = this.queryExecution.analyzed.output.map(_.name)
this.mapPartitions { iter =>
val pickle = new Pickler
iter.map { row =>
val map: JMap[String, Any] = new java.util.HashMap
// TODO: We place the map in an ArrayList so that the object is pickled to a List[Dict].
// Ideally we should be able to pickle an object directly into a Python collection so we
// don't have to create an ArrayList every time.
val arr: java.util.ArrayList[Any] = new java.util.ArrayList
row.zip(fieldNames).foreach { case (obj, name) =>
map.put(name, obj)
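
Note: a companion sketch (my own, not part of the commit) for the pickling direction used in javaToPython above. Each row becomes a HashMap keyed by field name and is wrapped in an ArrayList, so the Python side unpickles a List[Dict] as the TODO describes. It needs only the Pyrolite jar on the classpath; the field names and values are illustrative.

import java.util.{ArrayList => JArrayList, HashMap => JHashMap}

import net.razorvine.pickle.Pickler

object RowPickleSketch {
  def pickleRow(fieldNames: Seq[String], row: Seq[Any]): Array[Byte] = {
    // One dict per row, keyed by the field names of the analyzed output.
    val map = new JHashMap[String, Any]()
    row.zip(fieldNames).foreach { case (obj, name) => map.put(name, obj) }
    // The extra ArrayList is the wrapper the TODO talks about: it turns the
    // pickled payload into a Python list containing one dict.
    val arr = new JArrayList[Any]()
    arr.add(map)
    new Pickler().dumps(arr)
  }

  def main(args: Array[String]): Unit = {
    val bytes = pickleRow(Seq("f1", "f2"), Seq(1, "row1"))
    // In Python, pickle.loads on these bytes yields a one-element list holding the dict.
    println(s"pickled ${bytes.length} bytes")
  }
}
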
