diff --git a/ibis/__init__.py b/ibis/__init__.py
index bd4a73ce00fb..40ec9f46dd77 100644
--- a/ibis/__init__.py
+++ b/ibis/__init__.py
@@ -57,6 +57,9 @@
     # pip install ibis-framework[spark]
     import ibis.spark.api as spark  # noqa: F401
 
+with suppress(ImportError):
+    import ibis.pyspark.api as pyspark  # noqa: F401
+
 
 def hdfs_connect(
     host='localhost',
diff --git a/ibis/pyspark/__init__.py b/ibis/pyspark/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/ibis/pyspark/api.py b/ibis/pyspark/api.py
new file mode 100644
index 000000000000..36439c6371da
--- /dev/null
+++ b/ibis/pyspark/api.py
@@ -0,0 +1,18 @@
+from ibis.pyspark.client import PySparkClient
+
+
+def connect(session):
+    """
+    Create a `PySparkClient` for use with Ibis from an existing
+    ``pyspark.sql.SparkSession``. The session's time zone is set to UTC so
+    that timestamp semantics match the other Ibis backends.
+    """
+    client = PySparkClient(session)
+
+    # Spark internally stores timestamps as UTC values, and timestamp data that
+    # is brought in without a specified time zone is converted as local time to
+    # UTC with microsecond resolution.
+    # https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics
+    client._session.conf.set('spark.sql.session.timeZone', 'UTC')
+
+    return client
diff --git a/ibis/pyspark/client.py b/ibis/pyspark/client.py
new file mode 100644
index 000000000000..63062cfcee55
--- /dev/null
+++ b/ibis/pyspark/client.py
@@ -0,0 +1,46 @@
+from pyspark.sql.column import Column
+
+import ibis.common.exceptions as com
+import ibis.expr.types as types
+from ibis.pyspark.compiler import PySparkExprTranslator
+from ibis.pyspark.operations import PySparkTable
+from ibis.spark.client import SparkClient
+
+
+class PySparkClient(SparkClient):
+    """
+    An ibis client that uses PySpark SQL DataFrames.
+    """
+
+    table_class = PySparkTable
+
+    def __init__(self, session):
+        super().__init__(session)
+        self.translator = PySparkExprTranslator()
+
+    def compile(self, expr, *args, **kwargs):
+        """Compile an ibis expression to a PySpark DataFrame object.
+        """
+        return self.translator.translate(expr, scope={})
+
+    def execute(self, expr, params=None, limit='default', **kwargs):
+        if isinstance(expr, types.TableExpr):
+            return self.compile(expr).toPandas()
+        elif isinstance(expr, types.ColumnExpr):
+            # expression must be named for the projection
+            expr = expr.name('tmp')
+            return self.compile(expr.to_projection()).toPandas()['tmp']
+        elif isinstance(expr, types.ScalarExpr):
+            compiled = self.compile(expr)
+            if isinstance(compiled, Column):
+                # attach result column to a fake DataFrame and
+                # select the result
+                compiled = self._session.range(0, 1).select(compiled)
+            return compiled.toPandas().iloc[0, 0]
+        else:
+            raise com.IbisError(
+                "Cannot execute expression of type: {}".format(type(expr)))
+
+    def sql(self, query):
+        raise NotImplementedError(
+            "The PySpark backend does not support raw SQL queries")
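Taken together, `api.connect` and `PySparkClient` are the public surface of the new backend. Below is a minimal usage sketch, not part of the diff; it assumes a local SparkSession and reuses a 'table1' temp view like the one the tests further down create.

    # Minimal usage sketch for the new PySpark backend (assumes a local
    # SparkSession and a temp view named 'table1', as in the tests below).
    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F

    import ibis

    session = SparkSession.builder.getOrCreate()
    session.range(0, 10).withColumn('str_col', F.lit('value')) \
        .createTempView('table1')

    client = ibis.pyspark.connect(session)
    table = client.table('table1')

    # TableExpr -> pandas.DataFrame, ScalarExpr -> plain Python value
    # (see PySparkClient.execute above for the dispatch)
    df = table.mutate(doubled=table['id'] * 2).execute()
    max_id = table['id'].max().execute()

As `PySparkClient.execute` shows, table expressions come back as pandas DataFrames, column expressions as pandas Series, and scalar expressions as plain Python values.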
diff --git a/ibis/pyspark/compiler.py b/ibis/pyspark/compiler.py
new file mode 100644
index 000000000000..72e0028318ce
--- /dev/null
+++ b/ibis/pyspark/compiler.py
@@ -0,0 +1,519 @@
+import collections
+import enum
+import functools
+
+import pyspark.sql.functions as F
+
+import ibis.common.exceptions as com
+import ibis.expr.operations as ops
+import ibis.expr.types as types
+from ibis.pyspark.operations import PySparkTable
+
+
+class AggregationContext(enum.Enum):
+    ENTIRE = 0
+    WINDOW = 1
+    GROUP = 2
+
+
+class PySparkExprTranslator:
+    _registry = {}
+
+    @classmethod
+    def compiles(cls, klass):
+        def decorator(f):
+            cls._registry[klass] = f
+            return f
+
+        return decorator
+
+    def translate(self, expr, scope, **kwargs):
+        # The operation node type the typed expression wraps
+        op = expr.op()
+
+        if type(op) in self._registry:
+            formatter = self._registry[type(op)]
+            return formatter(self, expr, scope, **kwargs)
+        else:
+            raise com.OperationNotDefinedError(
+                'No translation rule for {}'.format(type(op))
+            )
+
+
+compiles = PySparkExprTranslator.compiles
+
+
+def compile_with_scope(t, expr, scope):
+    """Compile an expression and put the result in scope.
+
+    If the expression is already in scope, return it.
+    """
+    op = expr.op()
+
+    if op in scope:
+        result = scope[op]
+    else:
+        result = t.translate(expr, scope)
+        scope[op] = result
+
+    return result
+
+
+@compiles(PySparkTable)
+def compile_datasource(t, expr, scope):
+    op = expr.op()
+    name, _, client = op.args
+    return client._session.table(name)
+
+
+@compiles(ops.Selection)
+def compile_selection(t, expr, scope, **kwargs):
+    # Cache compile results for tables
+    op = expr.op()
+
+    # TODO: Support predicates and sort_keys
+    if op.predicates or op.sort_keys:
+        raise NotImplementedError(
+            "predicates and sort_keys are not supported with Selection")
+
+    src_table = compile_with_scope(t, op.table, scope)
+    col_names_in_selection_order = []
+
+    for selection in op.selections:
+        if isinstance(selection, types.TableExpr):
+            col_names_in_selection_order.extend(selection.columns)
+        elif isinstance(selection, types.ColumnExpr):
+            column_name = selection.get_name()
+            col_names_in_selection_order.append(column_name)
+            column = t.translate(selection, scope=scope)
+            src_table = src_table.withColumn(column_name, column)
+
+    return src_table[col_names_in_selection_order]
+
+
+@compiles(ops.TableColumn)
+def compile_column(t, expr, scope, **kwargs):
+    op = expr.op()
+    table = compile_with_scope(t, op.table, scope)
+    return table[op.name]
+
+
+@compiles(ops.SelfReference)
+def compile_self_reference(t, expr, scope, **kwargs):
+    op = expr.op()
+    return t.translate(op.table, scope)
+
+
+@compiles(ops.Equals)
+def compile_equals(t, expr, scope, **kwargs):
+    op = expr.op()
+    return t.translate(op.left, scope) == t.translate(op.right, scope)
+
+
+@compiles(ops.Greater)
+def compile_greater(t, expr, scope, **kwargs):
+    op = expr.op()
+    return t.translate(op.left, scope) > t.translate(op.right, scope)
+
+
+@compiles(ops.GreaterEqual)
+def compile_greater_equal(t, expr, scope, **kwargs):
+    op = expr.op()
+    return t.translate(op.left, scope) >= t.translate(op.right, scope)
+
+
+@compiles(ops.Multiply)
+def compile_multiply(t, expr, scope, **kwargs):
+    op = expr.op()
+    return t.translate(op.left, scope) * t.translate(op.right, scope)
+
+
+@compiles(ops.Subtract)
+def compile_subtract(t, expr, scope, **kwargs):
+    op = expr.op()
+    return t.translate(op.left, scope) - t.translate(op.right, scope)
+
+
+@compiles(ops.Literal)
+def compile_literal(t, expr, scope, raw=False, **kwargs):
+    """If raw is True, don't wrap the result with F.lit().
+    """
+    value = expr.op().value
+
+    if raw:
+        return value
+
+    if isinstance(value, collections.abc.Set):
+        # Don't wrap set with F.lit
+        if isinstance(value, frozenset):
+            # Spark doesn't like frozenset
+            return set(value)
+        else:
+            return value
+    else:
+        return F.lit(expr.op().value)
+
+
+@compiles(ops.Aggregation)
+def compile_aggregation(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_table = t.translate(op.table, scope)
+
+    if op.by:
+        context = AggregationContext.GROUP
+        aggs = [t.translate(m, scope, context=context)
+                for m in op.metrics]
+        bys = [t.translate(b, scope) for b in op.by]
+        return src_table.groupby(*bys).agg(*aggs)
+    else:
+        context = AggregationContext.ENTIRE
+        aggs = [t.translate(m, scope, context=context)
+                for m in op.metrics]
+        return src_table.agg(*aggs)
+
+
+@compiles(ops.Contains)
+def compile_contains(t, expr, scope, **kwargs):
+    op = expr.op()
+    col = t.translate(op.value, scope)
+    return col.isin(t.translate(op.options, scope))
+
+
+def compile_aggregator(t, expr, scope, fn, context=None, **kwargs):
+    op = expr.op()
+    src_col = t.translate(op.arg, scope)
+
+    if getattr(op, "where", None) is not None:
+        condition = t.translate(op.where, scope)
+        src_col = F.when(condition, src_col)
+
+    col = fn(src_col)
+    if context:
+        return col
+    else:
+        # We are trying to compile an expr such as some_col.max()
+        # to a Spark expression.
+        # Here we get the root table df of that column and compile
+        # the expr to:
+        # df.select(max(some_col))
+        return t.translate(expr.op().arg.op().table, scope).select(col)
+
+
+@compiles(ops.GroupConcat)
+def compile_group_concat(t, expr, scope, context=None, **kwargs):
+    sep = expr.op().sep.op().value
+
+    def fn(col):
+        return F.concat_ws(sep, F.collect_list(col))
+    return compile_aggregator(t, expr, scope, fn, context)
+
+
+@compiles(ops.Any)
+def compile_any(t, expr, scope, context=None, **kwargs):
+    return compile_aggregator(t, expr, scope, F.max, context)
+
+
+@compiles(ops.NotAny)
+def compile_notany(t, expr, scope, context=None, **kwargs):
+
+    def fn(col):
+        return ~F.max(col)
+    return compile_aggregator(t, expr, scope, fn, context)
+
+
+@compiles(ops.All)
+def compile_all(t, expr, scope, context=None, **kwargs):
+    return compile_aggregator(t, expr, scope, F.min, context)
+
+
+@compiles(ops.NotAll)
+def compile_notall(t, expr, scope, context=None, **kwargs):
+
+    def fn(col):
+        return ~F.min(col)
+    return compile_aggregator(t, expr, scope, fn, context)
+
+
+@compiles(ops.Count)
+def compile_count(t, expr, scope, context=None, **kwargs):
+    return compile_aggregator(t, expr, scope, F.count, context)
+
+
+@compiles(ops.Max)
+def compile_max(t, expr, scope, context=None, **kwargs):
+    return compile_aggregator(t, expr, scope, F.max, context)
+
+
+@compiles(ops.Min)
+def compile_min(t, expr, scope, context=None, **kwargs):
+    return compile_aggregator(t, expr, scope, F.min, context)
+
+
+@compiles(ops.Mean)
+def compile_mean(t, expr, scope, context=None, **kwargs):
+    return compile_aggregator(t, expr, scope, F.mean, context)
+
+
+@compiles(ops.Sum)
+def compile_sum(t, expr, scope, context=None, **kwargs):
+    return compile_aggregator(t, expr, scope, F.sum, context)
+
+
+@compiles(ops.StandardDev)
+def compile_std(t, expr, scope, context=None, **kwargs):
+    how = expr.op().how
+
+    if how == 'sample':
+        fn = F.stddev_samp
+    elif how == 'pop':
+        fn = F.stddev_pop
+    else:
+        raise com.TranslationError(
+            "Unexpected 'how' in translation: {}"
+            .format(how)
+        )
+
+    return compile_aggregator(t, expr, scope, fn, context)
+
+
+@compiles(ops.Variance)
+def compile_variance(t, expr, scope, context=None, **kwargs):
+    how = expr.op().how
+
+    if how == 'sample':
+        fn = F.var_samp
+    elif how == 'pop':
+        fn = F.var_pop
+    else:
+        raise com.TranslationError(
+            "Unexpected 'how' in translation: {}"
+            .format(how)
+        )
+
+    return compile_aggregator(t, expr, scope, fn, context)
+
+
+@compiles(ops.Arbitrary)
+def compile_arbitrary(t, expr, scope, context=None, **kwargs):
+    how = expr.op().how
+
+    if how == 'first':
+        fn = functools.partial(F.first, ignorenulls=True)
+    elif how == 'last':
+        fn = functools.partial(F.last, ignorenulls=True)
+    else:
+        raise NotImplementedError(
+            "Does not support 'how': {}".format(how)
+        )
+
+    return compile_aggregator(t, expr, scope, fn, context)
+
+
+@compiles(ops.Greatest)
+def compile_greatest(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_columns = t.translate(op.arg, scope)
+    if len(src_columns) == 1:
+        return src_columns[0]
+    else:
+        return F.greatest(*src_columns)
+
+
+@compiles(ops.Least)
+def compile_least(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_columns = t.translate(op.arg, scope)
+    if len(src_columns) == 1:
+        return src_columns[0]
+    else:
+        return F.least(*src_columns)
+
+
+@compiles(ops.Abs)
+def compile_abs(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    return F.abs(src_column)
+
+
+@compiles(ops.Round)
+def compile_round(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    scale = (t.translate(op.digits, scope, raw=True)
+             if op.digits is not None else 0)
+    rounded = F.round(src_column, scale=scale)
+    if scale == 0:
+        rounded = rounded.astype('long')
+    return rounded
+
+
+@compiles(ops.Ceil)
+def compile_ceil(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    return F.ceil(src_column)
+
+
+@compiles(ops.Floor)
+def compile_floor(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    return F.floor(src_column)
+
+
+@compiles(ops.Exp)
+def compile_exp(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    return F.exp(src_column)
+
+
+@compiles(ops.Sign)
+def compile_sign(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+
+    return F.when(src_column == 0, F.lit(0.0)) \
+        .otherwise(F.when(src_column > 0, F.lit(1.0)).otherwise(-1.0))
+
+
+@compiles(ops.Sqrt)
+def compile_sqrt(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    return F.sqrt(src_column)
+
+
+@compiles(ops.Log)
+def compile_log(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    # Spark log method only takes float
+    return F.log(float(t.translate(op.base, scope, raw=True)), src_column)
+
+
+@compiles(ops.Ln)
+def compile_ln(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    return F.log(src_column)
+
+
+@compiles(ops.Log2)
+def compile_log2(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    return F.log2(src_column)
+
+
+@compiles(ops.Log10)
+def compile_log10(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    return F.log10(src_column)
+
+
+@compiles(ops.Modulus)
+def compile_modulus(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    left = t.translate(op.left, scope)
+    right = t.translate(op.right, scope)
+    return left % right
+
+
+@compiles(ops.Negate)
+def compile_negate(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    return -src_column
+
+
+@compiles(ops.Add)
+def compile_add(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    left = t.translate(op.left, scope)
+    right = t.translate(op.right, scope)
+    return left + right
+
+
+@compiles(ops.Divide)
+def compile_divide(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    left = t.translate(op.left, scope)
+    right = t.translate(op.right, scope)
+    return left / right
+
+
+@compiles(ops.FloorDivide)
+def compile_floor_divide(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    left = t.translate(op.left, scope)
+    right = t.translate(op.right, scope)
+    return F.floor(left / right)
+
+
+@compiles(ops.Power)
+def compile_power(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    left = t.translate(op.left, scope)
+    right = t.translate(op.right, scope)
+    return F.pow(left, right)
+
+
+@compiles(ops.IsNan)
+def compile_isnan(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    return F.isnan(src_column)
+
+
+@compiles(ops.IsInf)
+def compile_isinf(t, expr, scope, **kwargs):
+    op = expr.op()
+
+    src_column = t.translate(op.arg, scope)
+    return (src_column == float('inf')) | (src_column == float('-inf'))
+
+
+@compiles(ops.ValueList)
+def compile_value_list(t, expr, scope, **kwargs):
+    op = expr.op()
+    return [t.translate(col, scope) for col in op.values]
+
+
+@compiles(ops.InnerJoin)
+def compile_inner_join(t, expr, scope, **kwargs):
+    return compile_join(t, expr, scope, 'inner')
+
+
+def compile_join(t, expr, scope, how):
+    op = expr.op()
+
+    left_df = t.translate(op.left, scope)
+    right_df = t.translate(op.right, scope)
+    # TODO: Handle multiple predicates
+    predicates = t.translate(op.predicates[0], scope)
+
+    return left_df.join(right_df, predicates, how)
diff --git a/ibis/pyspark/operations.py b/ibis/pyspark/operations.py
new file mode 100644
index 000000000000..6491c4e058c1
--- /dev/null
+++ b/ibis/pyspark/operations.py
@@ -0,0 +1,5 @@
+import ibis.expr.operations as ops
+
+
+class PySparkTable(ops.DatabaseTable):
+    pass
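Every rule in the compiler above follows the same shape: look up the operation type in the registry, translate the arguments, and return a PySpark Column or DataFrame. As a hedged sketch of how one more rule could be registered, the snippet below uses ops.Lowercase and F.lower purely as illustrative names; they are not added by this diff.

    # Hypothetical extension: register a translation rule for an extra unary
    # op, following the same pattern as the rules in compiler.py above.
    import pyspark.sql.functions as F

    import ibis.expr.operations as ops
    from ibis.pyspark.compiler import compiles

    @compiles(ops.Lowercase)  # illustrative op; not added by this diff
    def compile_lower(t, expr, scope, **kwargs):
        op = expr.op()
        src_column = t.translate(op.arg, scope)
        return F.lower(src_column)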
diff --git a/ibis/pyspark/tests/test_basic.py b/ibis/pyspark/tests/test_basic.py
new file mode 100644
index 000000000000..b83612d12d34
--- /dev/null
+++ b/ibis/pyspark/tests/test_basic.py
@@ -0,0 +1,176 @@
+import pandas as pd
+import pandas.util.testing as tm
+import pytest
+
+import ibis
+import ibis.common.exceptions as comm
+
+pytest.importorskip('pyspark')
+pytestmark = pytest.mark.pyspark
+
+
+@pytest.fixture(scope='session')
+def client():
+    from pyspark.sql import SparkSession
+    import pyspark.sql.functions as F
+
+    session = SparkSession.builder.getOrCreate()
+    client = ibis.pyspark.connect(session)
+    df = client._session.range(0, 10)
+    df = df.withColumn("str_col", F.lit('value'))
+    df.createTempView('table1')
+
+    df1 = client._session.createDataFrame([(True,), (False,)]).toDF('v')
+    df1.createTempView('table2')
+    return client
+
+
+def test_basic(client):
+    table = client.table('table1')
+    result = table.compile().toPandas()
+    expected = pd.DataFrame({'id': range(0, 10), 'str_col': 'value'})
+
+    tm.assert_frame_equal(result, expected)
+
+
+def test_projection(client):
+    table = client.table('table1')
+    result1 = table.mutate(v=table['id']).compile().toPandas()
+
+    expected1 = pd.DataFrame(
+        {
+            'id': range(0, 10),
+            'str_col': 'value',
+            'v': range(0, 10),
+        }
+    )
+
+    result2 = (
+        table
+        .mutate(v=table['id'])
+        .mutate(v2=table['id'])
+        .mutate(id=table['id'] * 2)
+        .compile().toPandas()
+    )
+
+    expected2 = pd.DataFrame(
+        {
+            'id': range(0, 20, 2),
+            'str_col': 'value',
+            'v': range(0, 10),
+            'v2': range(0, 10),
+        }
+    )
+
+    tm.assert_frame_equal(result1, expected1)
+    tm.assert_frame_equal(result2, expected2)
+
+
+def test_aggregation_col(client):
+    table = client.table('table1')
+    result = table['id'].count().execute()
+    assert result == table.compile().count()
+
+
+def test_aggregation(client):
+    import pyspark.sql.functions as F
+
+    table = client.table('table1')
+    result = table.aggregate(table['id'].max()).compile()
+    expected = table.compile().agg(F.max('id'))
+
+    tm.assert_frame_equal(result.toPandas(), expected.toPandas())
+
+
+def test_groupby(client):
+    import pyspark.sql.functions as F
+
+    table = client.table('table1')
+    result = table.groupby('id').aggregate(table['id'].max()).compile()
+    expected = table.compile().groupby('id').agg(F.max('id'))
+
+    tm.assert_frame_equal(result.toPandas(), expected.toPandas())
+
+
+@pytest.mark.xfail(
+    reason='This is not implemented yet',
+    raises=comm.OperationNotDefinedError
+)
+def test_window(client):
+    import pyspark.sql.functions as F
+    from pyspark.sql.window import Window
+
+    table = client.table('table1')
+    w = ibis.window()
+    result = (
+        table
+        .mutate(
+            grouped_demeaned=table['id'] - table['id'].mean().over(w))
+        .compile()
+    )
+    result2 = (
+        table
+        .groupby('id')
+        .mutate(
+            grouped_demeaned=table['id'] - table['id'].mean())
+        .compile()
+    )
+
+    spark_window = Window.partitionBy()
+    spark_table = table.compile()
+    expected = spark_table.withColumn(
+        'grouped_demeaned',
+        spark_table['id'] - F.mean(spark_table['id']).over(spark_window)
+    )
+
+    tm.assert_frame_equal(result.toPandas(), expected.toPandas())
+    tm.assert_frame_equal(result2.toPandas(), expected.toPandas())
+
+
+def test_greatest(client):
+    table = client.table('table1')
+    result = (
+        table
+        .mutate(greatest=ibis.greatest(table.id))
+        .compile()
+    )
+    df = table.compile()
+    expected = table.compile().withColumn('greatest', df.id)
+
+    tm.assert_frame_equal(result.toPandas(), expected.toPandas())
+
+
+def test_selection(client):
+    table = client.table('table1')
+    table = table.mutate(id2=table['id'] * 2)
+
+    result1 = table[['id']].compile()
+    result2 = table[['id', 'id2']].compile()
+    result3 = table[[table, (table.id + 1).name('plus1')]].compile()
+    result4 = table[[(table.id + 1).name('plus1'), table]].compile()
+
+    df = table.compile()
+    tm.assert_frame_equal(result1.toPandas(), df[['id']].toPandas())
+    tm.assert_frame_equal(result2.toPandas(), df[['id', 'id2']].toPandas())
+    tm.assert_frame_equal(result3.toPandas(),
+                          df[[df.columns]].withColumn('plus1', df.id + 1)
+                          .toPandas())
+    tm.assert_frame_equal(result4.toPandas(),
+                          df.withColumn('plus1', df.id + 1)
+                          [['plus1', *df.columns]].toPandas())
+
+
+@pytest.mark.xfail(
+    reason='Join is not fully implemented',
+    raises=AssertionError
+)
+def test_join(client):
+    table = client.table('table1')
+    result = table.join(table, ['id', 'str_col']).compile()
+    spark_table = table.compile()
+    expected = (
+        spark_table
+        .join(spark_table, ['id', 'str_col'])
+    )
+
+    tm.assert_frame_equal(result.toPandas(), expected.toPandas())
diff --git a/ibis/spark/api.py b/ibis/spark/api.py
index 284de775efeb..5cbf92bde083 100644
--- a/ibis/spark/api.py
+++ b/ibis/spark/api.py
@@ -28,13 +28,13 @@ def verify(expr, params=None):
     return False
 
 
-def connect(**kwargs):
+def connect(spark_session):
     """
-    Create a `SparkClient` for use with Ibis. Pipes **kwargs into SparkClient,
-    which pipes them into SparkContext. See documentation for SparkContext:
-    https://spark.apache.org/docs/latest/api/python/_modules/pyspark/context.html#SparkContext
+    Create a `SparkClient` for use with Ibis from an existing
+    ``pyspark.sql.SparkSession``. The session's time zone is set to UTC so
+    that timestamp semantics match the other Ibis backends.
     """
-    client = SparkClient(**kwargs)
+    client = SparkClient(spark_session)
 
     # Spark internally stores timestamps as UTC values, and timestamp data that
     # is brought in without a specified time zone is converted as local time to
diff --git a/ibis/spark/client.py b/ibis/spark/client.py
index 895c78788469..2056c8a8c1a4 100644
--- a/ibis/spark/client.py
+++ b/ibis/spark/client.py
@@ -294,10 +294,10 @@ class SparkClient(SQLClient):
     table_class = SparkDatabaseTable
     table_expr_class = SparkTable
 
-    def __init__(self, **kwargs):
-        self._context = ps.SparkContext(**kwargs)
-        self._session = ps.sql.SparkSession(self._context)
-        self._catalog = self._session.catalog
+    def __init__(self, session):
+        self._context = session.sparkContext
+        self._session = session
+        self._catalog = session.catalog
 
     def close(self):
         """
diff --git a/ibis/tests/all/conftest.py b/ibis/tests/all/conftest.py
index 542d1ef48e2f..89ee85c71276 100644
--- a/ibis/tests/all/conftest.py
+++ b/ibis/tests/all/conftest.py
@@ -197,18 +197,36 @@ def geo_df(geo):
 
 
 _spark_testing_client = None
+_pyspark_testing_client = None
 
 
 def get_spark_testing_client(data_directory):
     global _spark_testing_client
+    if _spark_testing_client is None:
+        _spark_testing_client = get_common_spark_testing_client(
+            data_directory,
+            lambda session: ibis.spark.connect(session)
+        )
+    return _spark_testing_client
+
+
+def get_pyspark_testing_client(data_directory):
+    global _pyspark_testing_client
+    if _pyspark_testing_client is None:
+        _pyspark_testing_client = get_common_spark_testing_client(
+            data_directory,
+            lambda session: ibis.pyspark.connect(session)
+        )
+    return _pyspark_testing_client
 
-    if _spark_testing_client is not None:
-        return _spark_testing_client
 
+def get_common_spark_testing_client(data_directory, connect):
     pytest.importorskip('pyspark')
 
     import pyspark.sql.types as pt
+    from pyspark.sql import SparkSession
 
-    _spark_testing_client = ibis.spark.connect()
+    spark = SparkSession.builder.getOrCreate()
+    _spark_testing_client = connect(spark)
     s = _spark_testing_client._session
 
     df_functional_alltypes = s.read.csv(
diff --git a/ibis/tests/all/test_aggregation.py b/ibis/tests/all/test_aggregation.py
index e13af9e5915b..a8aee75b699d 100644
--- a/ibis/tests/all/test_aggregation.py
+++ b/ibis/tests/all/test_aggregation.py
@@ -154,4 +154,5 @@ def test_group_concat(backend, alltypes, df, result_fn, expected_fn):
     expr = result_fn(alltypes)
     result = expr.execute()
     expected = expected_fn(df)
-    assert set(result) == set(expected)
+
+    assert set(result.iloc[:, 1]) == set(expected.iloc[:, 1])
diff --git a/ibis/tests/all/test_client.py b/ibis/tests/all/test_client.py
index c7bca9da848d..1bd7d3193b3a 100644
--- a/ibis/tests/all/test_client.py
+++ b/ibis/tests/all/test_client.py
@@ -3,7 +3,7 @@
 import ibis
 import ibis.expr.datatypes as dt
-from ibis.tests.backends import BigQuery
+from ibis.tests.backends import BigQuery, PySpark
 
 
 @pytest.mark.xfail_unsupported
@@ -25,6 +25,7 @@
         ),
     ],
 )
+@pytest.mark.xfail_backends((PySpark,))
 def test_query_schema(backend, con, alltypes, expr_fn, expected):
     if not hasattr(con, '_build_ast'):
         pytest.skip(
diff --git a/ibis/tests/backends.py b/ibis/tests/backends.py
index 6bea2074b7da..b9b87034e01d 100644
--- a/ibis/tests/backends.py
+++ b/ibis/tests/backends.py
@@ -549,3 +549,26 @@ def batting(self) -> ir.TableExpr:
     @property
     def awards_players(self) -> ir.TableExpr:
         return self.connection.table('awards_players')
+
+
+class PySpark(Backend, RoundAwayFromZero):
+    @staticmethod
+    def skip_if_missing_dependencies() -> None:
+        pytest.importorskip('pyspark')
+
+    @staticmethod
+    def connect(data_directory):
+        from ibis.tests.all.conftest import get_pyspark_testing_client
+        return get_pyspark_testing_client(data_directory)
+
+    @property
+    def functional_alltypes(self) -> ir.TableExpr:
+        return self.connection.table('functional_alltypes')
+
+    @property
+    def batting(self) -> ir.TableExpr:
+        return self.connection.table('batting')
+
+    @property
+    def awards_players(self) -> ir.TableExpr:
+        return self.connection.table('awards_players')
diff --git a/setup.cfg b/setup.cfg
index 3d9461987eff..124f3c83b7c5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -48,6 +48,7 @@ markers =
     postgis
    postgresql
     postgres_extensions
+    pyspark
     skip_backends
     skip_missing_feature
     spark
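Because the `pyspark` marker is registered in setup.cfg above, the new test module can be selected on its own. A possible invocation (paths taken from this diff), shown here through pytest's Python entry point:

    # Equivalent to running `pytest -m pyspark ibis/pyspark/tests/test_basic.py`
    # from the repository root.
    import pytest

    pytest.main(['-m', 'pyspark', 'ibis/pyspark/tests/test_basic.py'])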