
Commit

Moved everything into sql.py
ahirreddy committed Apr 15, 2014
1 parent a19afe4 commit f2312c7
Showing 5 changed files with 343 additions and 321 deletions.
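For orientation, here is a minimal sketch (not part of the commit) of what the move means for user code: the public SQL classes are now imported from pyspark.sql instead of pyspark.context and pyspark.rdd. The app name, field names, and data below are illustrative, and a local pyspark build at this revision is assumed.

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row   # previously pyspark.context.SQLContext and pyspark.rdd.Row

    sc = SparkContext("local", "sql-module-move")   # illustrative app name
    sqlCtx = SQLContext(sc)
    srdd = sqlCtx.inferSchema(sc.parallelize([{"field1": 1, "field2": "row1"}]))
    srdd.registerAsTable("example")
    rows = sqlCtx.sql("SELECT field1 FROM example").collect()   # list of Row objects
    sc.stop()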
12 changes: 6 additions & 6 deletions python/pyspark/__init__.py
@@ -36,12 +36,12 @@
Finer-grained cache persistence levels.
Spark SQL:
- L{SQLContext<pyspark.context.SQLContext>}
- L{SQLContext<pyspark.sql.SQLContext>}
Main entry point for SQL functionality.
- L{SchemaRDD<pyspark.rdd.SchemaRDD>}
- L{SchemaRDD<pyspark.sql.SchemaRDD>}
A Resilient Distributed Dataset (RDD) with schema information for the data it contains. In
addition to normal RDD operations, SchemaRDDs also support SQL.
- L{Row<pyspark.rdd.Row>}
- L{Row<pyspark.sql.Row>}
A Row of data returned by a Spark SQL query.
Hive:
@@ -58,10 +58,10 @@

from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.context import SQLContext
from pyspark.sql import SQLContext
from pyspark.rdd import RDD
from pyspark.rdd import SchemaRDD
from pyspark.rdd import Row
from pyspark.sql import SchemaRDD
from pyspark.sql import Row
from pyspark.files import SparkFiles
from pyspark.storagelevel import StorageLevel

224 changes: 1 addition & 223 deletions python/pyspark/context.py
@@ -32,10 +32,9 @@
PairDeserializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD, SchemaRDD
from pyspark.rdd import RDD

from py4j.java_collections import ListConverter
from py4j.protocol import Py4JError


class SparkContext(object):
@@ -175,8 +174,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
SparkContext._pythonToJavaMap = SparkContext._jvm.PythonRDD.pythonToJavaMap
SparkContext._javaToPython = SparkContext._jvm.PythonRDD.javaToPython

if instance:
if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
@@ -463,225 +460,6 @@ def sparkUser(self):
"""
return self._jsc.sc().sparkUser()

class SQLContext:
"""
Main entry point for Spark SQL functionality. A SQLContext can be used to create L{SchemaRDD}s,
register L{SchemaRDD}s as tables, execute SQL over tables, cache tables, and read Parquet files.
"""

def __init__(self, sparkContext):
"""
Create a new SQLContext.
@param sparkContext: The SparkContext to wrap.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
>>> bad_rdd = sc.parallelize([1,2,3])
>>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
>>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
... "boolean" : True}])
>>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
... x.boolean))
>>> srdd.collect()[0]
(1, u'string', 1.0, 1, True)
"""
self._sc = sparkContext
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm

@property
def _ssql_ctx(self):
"""
Accessor for the JVM Spark SQL context. Subclasses can override this property to provide
their own JVM Contexts.
"""
if not hasattr(self, '_scala_SQLContext'):
self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
return self._scala_SQLContext

def inferSchema(self, rdd):
"""
Infer and apply a schema to an RDD of L{dict}s. We peek at the first row of the RDD to
determine the field names and types, and then use that to extract all the dictionaries.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
... {"field1" : 3, "field2": "row3"}]
True
"""
if (rdd.__class__ is SchemaRDD):
raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
elif not isinstance(rdd.first(), dict):
raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
(SchemaRDD.__name__, rdd.first()))

jrdd = self._sc._pythonToJavaMap(rdd._jrdd)
srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
return SchemaRDD(srdd, self)

def registerRDDAsTable(self, rdd, tableName):
"""
Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
during the lifetime of this instance of SQLContext.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
"""
if (rdd.__class__ is SchemaRDD):
jschema_rdd = rdd._jschema_rdd
self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
else:
raise ValueError("Can only register SchemaRDD as table")

def parquetFile(self, path):
"""
Loads a Parquet file, returning the result as a L{SchemaRDD}.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.saveAsParquetFile("/tmp/tmp.parquet")
>>> srdd2 = sqlCtx.parquetFile("/tmp/tmp.parquet")
>>> srdd.collect() == srdd2.collect()
True
"""
jschema_rdd = self._ssql_ctx.parquetFile(path)
return SchemaRDD(jschema_rdd, self)

def sql(self, sqlQuery):
"""
Executes a SQL query using Spark, returning the result as a L{SchemaRDD}.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
>>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
... {"f1" : 3, "f2": "row3"}]
True
"""
return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)

def table(self, tableName):
"""
Returns the specified table as a L{SchemaRDD}.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
>>> srdd2 = sqlCtx.table("table1")
>>> srdd.collect() == srdd2.collect()
True
"""
return SchemaRDD(self._ssql_ctx.table(tableName), self)

def cacheTable(self, tableName):
"""
Caches the specified table in-memory.
"""
self._ssql_ctx.cacheTable(tableName)

def uncacheTable(self, tableName):
"""
Removes the specified table from the in-memory cache.
"""
self._ssql_ctx.uncacheTable(tableName)

class HiveContext(SQLContext):
"""
An instance of the Spark SQL execution engine that integrates with data stored in Hive.
Configuration for Hive is read from hive-site.xml on the classpath. It supports running both SQL
and HiveQL commands.
"""

@property
def _ssql_ctx(self):
try:
if not hasattr(self, '_scala_HiveContext'):
self._scala_HiveContext = self._get_hive_ctx()
return self._scala_HiveContext
except Py4JError as e:
raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " \
"sbt/sbt assembly" , e)

def _get_hive_ctx(self):
return self._jvm.HiveContext(self._jsc.sc())

def hiveql(self, hqlQuery):
"""
Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
"""
return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)

def hql(self, hqlQuery):
"""
Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
"""
return self.hiveql(hqlQuery)

class LocalHiveContext(HiveContext):
"""
Starts up an instance of Hive where metadata is stored locally. An in-process metadata database is
created with data stored in ./metadata. Warehouse data is stored in ./warehouse.
>>> import os
>>> from pyspark.context import LocalHiveContext
>>> hiveCtx = LocalHiveContext(sc)
>>> try:
... supress = hiveCtx.hql("DROP TABLE src")
... except Exception:
... pass
>>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
>>> supress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
>>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
>>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
>>> num = results.count()
>>> reduce_sum = results.reduce(lambda x, y: x + y)
>>> num
500
>>> reduce_sum
130091
"""

def _get_hive_ctx(self):
return self._jvm.LocalHiveContext(self._jsc.sc())

class TestHiveContext(HiveContext):

def _get_hive_ctx(self):
return self._jvm.TestHiveContext(self._jsc.sc())

def _test():
import atexit
import doctest
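As a consolidated reference, a hedged sketch of the Parquet round-trip described by the doctests above; the path is illustrative, and sc and sqlCtx are assumed to be set up as in those doctests:

    rdd = sc.parallelize([{"field1": 1, "field2": "row1"}, {"field1": 2, "field2": "row2"}])
    srdd = sqlCtx.inferSchema(rdd)                      # peeks at the first dict to derive the schema
    srdd.saveAsParquetFile("/tmp/roundtrip.parquet")    # schema is preserved alongside the data
    srdd2 = sqlCtx.parquetFile("/tmp/roundtrip.parquet")
    sqlCtx.registerRDDAsTable(srdd2, "roundtrip")
    assert sqlCtx.table("roundtrip").collect() == srdd.collect()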
89 changes: 0 additions & 89 deletions python/pyspark/rdd.py
@@ -1387,95 +1387,6 @@ def _jrdd(self):
def _is_pipelinable(self):
return not (self.is_cached or self.is_checkpointed)

class Row(dict):
"""
An extended L{dict} that takes a L{dict} in its constructor, and exposes those items as fields.
>>> r = Row({"hello" : "world", "foo" : "bar"})
>>> r.hello
'world'
>>> r.foo
'bar'
"""

def __init__(self, d):
d.update(self.__dict__)
self.__dict__ = d
dict.__init__(self, d)

class SchemaRDD(RDD):
"""
An RDD of Row objects that has an associated schema. The underlying JVM object is a SchemaRDD,
not a PythonRDD, so we can utilize the relational query API exposed by Spark SQL.
For normal L{RDD} operations (map, count, etc.) the L{SchemaRDD} is not operated on directly, as
its underlying implementation is an RDD composed of Java objects. Instead it is converted to a
PythonRDD in the JVM, on which Python operations can be done.
"""

def __init__(self, jschema_rdd, sql_ctx):
self.sql_ctx = sql_ctx
self._sc = sql_ctx._sc
self._jschema_rdd = jschema_rdd

self.is_cached = False
self.is_checkpointed = False
self.ctx = self.sql_ctx._sc
self._jrdd_deserializer = self.ctx.serializer

@property
def _jrdd(self):
"""
Lazy evaluation of PythonRDD object. Only done when a user calls methods defined by the
L{RDD} super class (map, count, etc.).
"""
return self.toPython()._jrdd

@property
def _id(self):
return self._jrdd.id()

def saveAsParquetFile(self, path):
"""
Saves the contents of this L{SchemaRDD} as a parquet file, preserving the schema. Files
that are written out using this method can be read back in as a SchemaRDD using the
L{SQLContext.parquetFile} method.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.saveAsParquetFile("/tmp/test.parquet")
>>> srdd2 = sqlCtx.parquetFile("/tmp/test.parquet")
>>> srdd2.collect() == srdd.collect()
True
"""
self._jschema_rdd.saveAsParquetFile(path)

def registerAsTable(self, name):
"""
Registers this RDD as a temporary table using the given name. The lifetime of this temporary
table is tied to the L{SQLContext} that was used to create this SchemaRDD.
>>> from pyspark.context import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"},
... {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.registerAsTable("test")
>>> srdd2 = sqlCtx.sql("select * from test")
>>> srdd.collect() == srdd2.collect()
True
"""
self._jschema_rdd.registerAsTable(name)

def toPython(self):
jrdd = self._jschema_rdd.javaToPython()
# TODO: This is inefficient, we should construct the Python Row object
# in Java land in the javaToPython function. May require a custom
# pickle serializer in Pyrolite
return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d))

def _test():
import doctest
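Because a SchemaRDD lazily converts itself to a PythonRDD of Row objects when plain RDD methods are called (the toPython path above), relational and functional operations can be mixed. A small hedged sketch, again assuming the sc and sqlCtx setup from the doctests; the table name and data are illustrative:

    srdd = sqlCtx.inferSchema(sc.parallelize([{"field1": 1, "field2": "row1"},
                                              {"field1": 2, "field2": "row2"}]))
    srdd.registerAsTable("people")
    projected = sqlCtx.sql("SELECT field1 FROM people")
    doubled = projected.map(lambda r: r.field1 * 2)    # Row exposes each column as an attribute
    total = doubled.reduce(lambda x, y: x + y)         # 6, computed through the ordinary RDD path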

