added more tests
ahirreddy committed Apr 15, 2014
1 parent 7515ba0 commit 79f739d
Showing 2 changed files with 75 additions and 22 deletions.
66 changes: 44 additions & 22 deletions python/pyspark/context.py
@@ -475,19 +475,13 @@ def __init__(self, sparkContext):
@param sparkContext: The SparkContext to wrap.
# SQLContext
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
# applySchema
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}]
True
>>> sqlCtx.applySchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
@@ -498,21 +492,6 @@ def __init__(self, sparkContext):
Traceback (most recent call last):
...
ValueError:...
# registerRDDAsTable
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
# sql
>>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"}, {"f1" : 3, "f2": "row3"}]
True
# table
#>>> sqlCtx.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
#>>> sqlCtx.sql('INSERT INTO src (key, value) VALUES (1, "one")')
#>>> sqlCtx.sql('INSERT INTO src (key, value) VALUES (2, "two")')
#>>> srdd3 = sqlCtx.table("src")
#>>> srdd3.collect() == [{"key" : 1, "value" : "one"}, {"key" : 2, "value": "two"}]
"""
self._sc = sparkContext
self._jsc = self._sc._jsc
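
With this commit the usage examples move out of __init__ and into the individual method docstrings. For orientation, a minimal end-to-end sketch of the wrapped workflow, assuming a live SparkContext sc as in the doctests:

>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1": 1, "field2": "row1"}])
>>> srdd = sqlCtx.applySchema(rdd)           # infer schema from the first dict
>>> sqlCtx.registerRDDAsTable(srdd, "t1")    # temporary table, scoped to sqlCtx
>>> sqlCtx.sql("SELECT field1 FROM t1").collect() == [{"field1": 1}]
True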
@@ -523,6 +502,14 @@ def applySchema(self, rdd):
"""
Infer and apply a schema to an RDD of L{dict}s. We peek at the first row of the RDD to
determine the field names and types, and then use that to extract all the dictionaries.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}]
True
"""
if (rdd.__class__ is SchemaRDD):
raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
@@ -538,29 +525,64 @@ def registerRDDAsTable(self, rdd, tableName):
"""
Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
during the lifetime of this instance of SQLContext.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
"""
if (rdd.__class__ is SchemaRDD):
jschema_rdd = rdd._jschema_rdd
self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
else:
raise ValueError("Can only register SchemaRDD as table")

- def parquetFile(path):
+ def parquetFile(self, path):
"""
Loads a Parquet file, returning the result as a L{SchemaRDD}.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd.saveAsParquetFile("/tmp/tmp.parquet")
>>> srdd2 = sqlCtx.parquetFile("/tmp/tmp.parquet")
>>> srdd.collect() == srdd2.collect()
True
"""
jschema_rdd = self._ssql_ctx.parquetFile(path)
return SchemaRDD(jschema_rdd, self)
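
One caveat: the doctests here and in saveAsParquetFile write to fixed paths under /tmp, which can collide between runs. A sketch of a safer pattern using a fresh temporary directory (the path handling is illustrative):

>>> import os, tempfile
>>> parquet_path = os.path.join(tempfile.mkdtemp(), "out.parquet")  # unique per run
>>> srdd.saveAsParquetFile(parquet_path)
>>> sqlCtx.parquetFile(parquet_path).collect() == srdd.collect()
True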

def sql(self, sqlQuery):
"""
Executes a SQL query using Spark, returning the result as a L{SchemaRDD}.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
>>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"}, {"f1" : 3, "f2": "row3"}]
True
"""
return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
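
Other query shapes follow the same pattern; a sketch with a filter, assuming the table1 registration from the doctest above and that the early SQL parser accepts a WHERE clause:

>>> srdd3 = sqlCtx.sql("SELECT field1 FROM table1 WHERE field1 > 1")
>>> srdd3.collect() == [{"field1": 2}, {"field1": 3}]
True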

def table(self, tableName):
"""
Returns the specified table as a L{SchemaRDD}.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
>>> srdd2 = sqlCtx.table("table1")
>>> srdd.collect() == srdd2.collect()
True
"""
return SchemaRDD(self._ssql_ctx.table(tableName), self)

31 changes: 31 additions & 0 deletions python/pyspark/rdd.py
@@ -1414,7 +1414,38 @@ def _jrdd(self):
def _id(self):
return self._jrdd.id()

def saveAsParquetFile(self, path):
"""
Saves the contents of this L{SchemaRDD} as a Parquet file, preserving the schema. Files
that are written out using this method can be read back in as a SchemaRDD using the
L{SQLContext.parquetFile} method.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd.saveAsParquetFile("/tmp/test.parquet")
>>> srdd2 = sqlCtx.parquetFile("/tmp/test.parquet")
>>> srdd2.collect() == srdd.collect()
True
"""
self._jschema_rdd.saveAsParquetFile(path)

def registerAsTable(self, name):
"""
Registers this RDD as a temporary table using the given name. The lifetime of this temporary
table is tied to the L{SQLContext} that was used to create this SchemaRDD.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.applySchema(rdd)
>>> srdd.registerAsTable("test")
>>> srdd2 = sqlCtx.sql("select * from test")
>>> srdd.collect() == srdd2.collect()
True
"""
self._jschema_rdd.registerAsTable(name)
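
Assuming registerAsTable is the instance-method counterpart of L{SQLContext.registerRDDAsTable}, the two spellings below should register the same table:

>>> srdd.registerAsTable("people")             # instance-method form
>>> sqlCtx.registerRDDAsTable(srdd, "people")  # equivalent context form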

def toPython(self):
