Link existing tests with PySpark backend (ibis-project#7)
icexelloss committed Aug 7, 2019
1 parent 45278cc commit 9fe05ab
Showing 7 changed files with 177 additions and 14 deletions.
126 changes: 126 additions & 0 deletions conftest.py
@@ -18,3 +18,129 @@ def data_directory():
        pytest.skip('test data directory not found')

    return datadir


@pytest.fixture(scope='session')
def spark_session(data_directory):
    pytest.importorskip('pyspark')

    import pyspark.sql.types as pt
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    df_functional_alltypes = spark.read.csv(
        path=str(data_directory / 'functional_alltypes.csv'),
        schema=pt.StructType([
            pt.StructField('index', pt.IntegerType(), True),
            pt.StructField('Unnamed: 0', pt.IntegerType(), True),
            pt.StructField('id', pt.IntegerType(), True),
            # cast below, Spark can't read 0/1 as bool
            pt.StructField('bool_col', pt.ByteType(), True),
            pt.StructField('tinyint_col', pt.ByteType(), True),
            pt.StructField('smallint_col', pt.ShortType(), True),
            pt.StructField('int_col', pt.IntegerType(), True),
            pt.StructField('bigint_col', pt.LongType(), True),
            pt.StructField('float_col', pt.FloatType(), True),
            pt.StructField('double_col', pt.DoubleType(), True),
            pt.StructField('date_string_col', pt.StringType(), True),
            pt.StructField('string_col', pt.StringType(), True),
            pt.StructField('timestamp_col', pt.TimestampType(), True),
            pt.StructField('year', pt.IntegerType(), True),
            pt.StructField('month', pt.IntegerType(), True),
        ]),
        mode='FAILFAST',
        header=True,
    )
    df_functional_alltypes = df_functional_alltypes.withColumn(
        "bool_col", df_functional_alltypes["bool_col"].cast("boolean"))
    df_functional_alltypes.createOrReplaceTempView('functional_alltypes')

    df_batting = spark.read.csv(
        path=str(data_directory / 'batting.csv'),
        schema=pt.StructType([
            pt.StructField('playerID', pt.StringType(), True),
            pt.StructField('yearID', pt.IntegerType(), True),
            pt.StructField('stint', pt.IntegerType(), True),
            pt.StructField('teamID', pt.StringType(), True),
            pt.StructField('lgID', pt.StringType(), True),
            pt.StructField('G', pt.IntegerType(), True),
            pt.StructField('AB', pt.DoubleType(), True),
            pt.StructField('R', pt.DoubleType(), True),
            pt.StructField('H', pt.DoubleType(), True),
            pt.StructField('X2B', pt.DoubleType(), True),
            pt.StructField('X3B', pt.DoubleType(), True),
            pt.StructField('HR', pt.DoubleType(), True),
            pt.StructField('RBI', pt.DoubleType(), True),
            pt.StructField('SB', pt.DoubleType(), True),
            pt.StructField('CS', pt.DoubleType(), True),
            pt.StructField('BB', pt.DoubleType(), True),
            pt.StructField('SO', pt.DoubleType(), True),
            pt.StructField('IBB', pt.DoubleType(), True),
            pt.StructField('HBP', pt.DoubleType(), True),
            pt.StructField('SH', pt.DoubleType(), True),
            pt.StructField('SF', pt.DoubleType(), True),
            pt.StructField('GIDP', pt.DoubleType(), True),
        ]),
        header=True,
    )
    df_batting.createOrReplaceTempView('batting')

    df_awards_players = spark.read.csv(
        path=str(data_directory / 'awards_players.csv'),
        schema=pt.StructType([
            pt.StructField('playerID', pt.StringType(), True),
            pt.StructField('awardID', pt.StringType(), True),
            pt.StructField('yearID', pt.IntegerType(), True),
            pt.StructField('lgID', pt.StringType(), True),
            pt.StructField('tie', pt.StringType(), True),
            pt.StructField('notes', pt.StringType(), True),
        ]),
        header=True,
    )
    df_awards_players.createOrReplaceTempView('awards_players')

    df_simple = spark.createDataFrame([(1, 'a')], ['foo', 'bar'])
    df_simple.createOrReplaceTempView('simple')

    df_struct = spark.createDataFrame(
        [((1, 2, 'a'),)],
        ['struct_col']
    )
    df_struct.createOrReplaceTempView('struct')

    df_nested_types = spark.createDataFrame(
        [
            (
                [1, 2],
                [[3, 4], [5, 6]],
                {'a': [[2, 4], [3, 5]]},
            )
        ],
        [
            'list_of_ints',
            'list_of_list_of_ints',
            'map_string_list_of_list_of_ints'
        ]
    )
    df_nested_types.createOrReplaceTempView('nested_types')

    df_complicated = spark.createDataFrame(
        [({(1, 3): [[2, 4], [3, 5]]},)],
        ['map_tuple_list_of_list_of_ints']
    )
    df_complicated.createOrReplaceTempView('complicated')

    return spark


@pytest.fixture(scope='session')
def spark_client_testing(spark_session):
    pytest.importorskip('pyspark')
    return ibis.spark.connect(spark_session)


@pytest.fixture(scope='session')
def pyspark_client_testing(spark_session):
    pytest.importorskip('pyspark')
    return ibis.pyspark.connect(spark_session)
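
As an aside, a sketch of how a test might consume these session-scoped fixtures. The test below is hypothetical and not part of this commit; it assumes the PySpark translator handles a simple column projection, and relies on pytest injecting the fixture by argument name:

```python
# Hypothetical test sketch, not part of the commit: uses the session-scoped
# fixtures defined above (pytest injects them by argument name).
def test_pyspark_projection(pyspark_client_testing):
    table = pyspark_client_testing.table('functional_alltypes')
    expr = table[['id', 'string_col']]
    result = pyspark_client_testing.execute(expr)  # pandas.DataFrame
    assert list(result.columns) == ['id', 'string_col']
```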
10 changes: 8 additions & 2 deletions ibis/pyspark/api.py
@@ -1,12 +1,18 @@
from ibis.pyspark.client import PysparkClient


def connect(**kwargs):
def connect(session):
"""
Create a `SparkClient` for use with Ibis. Pipes **kwargs into SparkClient,
which pipes them into SparkContext. See documentation for SparkContext:
https://spark.apache.org/docs/latest/api/python/_modules/pyspark/context.html#SparkContext
"""
    client = PysparkClient(**kwargs)
    client = PysparkClient(session)

    # Spark internally stores timestamps as UTC values, and timestamp data that
    # is brought in without a specified time zone is converted as local time to
    # UTC with microsecond resolution.
    # https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics
    client._session.conf.set('spark.sql.session.timeZone', 'UTC')

    return client
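
With this change, `connect` takes an already-constructed `SparkSession` instead of building a `SparkContext` itself. A minimal usage sketch, mirroring the fixtures in conftest.py above (the builder call is illustrative, and it assumes the pyspark backend is importable as `ibis.pyspark`):

```python
from pyspark.sql import SparkSession

import ibis

# Reuse (or lazily create) a SparkSession, then hand it to the backend.
spark = SparkSession.builder.getOrCreate()
client = ibis.pyspark.connect(spark)

# connect() also pins the session time zone to UTC, so timestamp columns
# round-trip with consistent semantics.
assert spark.conf.get('spark.sql.session.timeZone') == 'UTC'
```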
9 changes: 5 additions & 4 deletions ibis/pyspark/client.py
@@ -1,6 +1,7 @@
from ibis.spark.client import SparkClient
from ibis.pyspark.operations import PysparkTable
from ibis.pyspark.compiler import translate
from ibis.pyspark.operations import PysparkTable
from ibis.spark.client import SparkClient


class PysparkClient(SparkClient):
"""
@@ -15,5 +16,5 @@ def compile(self, expr, *args, **kwargs):
"""
return translate(expr)

def execute(self, df, params=None, limit='default', **kwargs):
return df.toPandas()
def execute(self, expr, params=None, limit='default', **kwargs):
return self.compile(expr).toPandas()
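
The compile/execute split is the point of this client: `compile` translates an ibis expression into a PySpark DataFrame, while `execute` materializes the result through `toPandas()`. A hedged sketch (the table and column names come from the test fixtures above, and it assumes the translator supports a simple projection):

```python
# Illustrative only: exercise compile() vs execute() on the PySpark client.
from pyspark.sql import SparkSession

import ibis

spark = SparkSession.builder.getOrCreate()
client = ibis.pyspark.connect(spark)

table = client.table('functional_alltypes')   # temp view registered by the fixture
expr = table[['id', 'string_col']]

spark_df = client.compile(expr)    # a pyspark.sql.DataFrame; evaluation stays lazy
pandas_df = client.execute(expr)   # equivalent to compile(expr).toPandas()
```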
4 changes: 2 additions & 2 deletions ibis/spark/api.py
@@ -28,13 +28,13 @@ def verify(expr, params=None):
        return False


def connect(**kwargs):
def connect(spark_session):
"""
Create a `SparkClient` for use with Ibis. Pipes **kwargs into SparkClient,
which pipes them into SparkContext. See documentation for SparkContext:
https://spark.apache.org/docs/latest/api/python/_modules/pyspark/context.html#SparkContext
"""
    client = SparkClient(**kwargs)
    client = SparkClient(spark_session)

    # Spark internally stores timestamps as UTC values, and timestamp data that
    # is brought in without a specified time zone is converted as local time to
8 changes: 4 additions & 4 deletions ibis/spark/client.py
@@ -193,10 +193,10 @@ class SparkClient(SQLClient):
    query_class = SparkQuery
    table_class = SparkTable

    def __init__(self, **kwargs):
        self._context = ps.SparkContext(**kwargs)
        self._session = ps.sql.SparkSession(self._context)
        self._catalog = self._session.catalog
    def __init__(self, session):
        self._context = session.sparkContext
        self._session = session
        self._catalog = session.catalog

    def close(self):
        """
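
Because the constructor now wraps an existing session instead of building its own `SparkContext`, multiple clients can share one `SparkSession`, which is exactly how the session-scoped fixtures in conftest.py above use it. A speculative sketch (assumes both backends are importable, as in the fixtures):

```python
from pyspark.sql import SparkSession

import ibis

# One SparkSession, two Ibis clients layered on top of it.
spark = SparkSession.builder.getOrCreate()
spark_client = ibis.spark.connect(spark)
pyspark_client = ibis.pyspark.connect(spark)

# Both clients see the same session and the same registered temp views.
assert spark_client._session is pyspark_client._session
```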
23 changes: 21 additions & 2 deletions ibis/tests/all/conftest.py
@@ -198,17 +198,36 @@ def geo_df(geo):

_spark_testing_client = None


def get_spark_testing_client(data_directory):
    global _spark_testing_client
    if _spark_testing_client is None:
        _spark_testing_client = get_common_spark_testing_client(
            data_directory,
            lambda: ibis.spark.connect()
        )
    return _spark_testing_client


def get_pyspark_testing_client(data_directory):
    global _pyspark_testing_client
    if _pyspark_testing_client is None:
        _pyspark_testing_client = get_common_spark_testing_client(
            data_directory,
            lambda: ibis.pyspark.connect()
        )
    return _pyspark_testing_client


def get_common_spark_testing_client(data_directory, connect):
    global _spark_testing_client

    if _spark_testing_client is not None:
        return _spark_testing_client

    pytest.importorskip('pyspark')
    import pyspark.sql.types as pt

    _spark_testing_client = ibis.spark.connect()
    _spark_testing_client = connect()
    s = _spark_testing_client._session

    df_functional_alltypes = s.read.csv(
11 changes: 11 additions & 0 deletions ibis/tests/backends.py
@@ -535,3 +535,14 @@ def skip_if_missing_dependencies() -> None:
    def connect(data_directory):
        from ibis.tests.all.conftest import get_spark_testing_client
        return get_spark_testing_client(data_directory)


class PySpark(Backend, RoundHalfToEven):
    @staticmethod
    def skip_if_missing_dependencies() -> None:
        pytest.importorskip('pyspark')

    @staticmethod
    def connect(data_directory):
        from ibis.tests.all.conftest import get_pyspark_testing_client
        return get_pyspark_testing_client(data_directory)
