From e00dc000295f4199e6639c01e303e18cba0442f2 Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Fri, 9 Aug 2019 14:43:28 -0400
Subject: [PATCH] Address PR comments

---
 ibis/pyspark/api.py              |   4 +-
 ibis/pyspark/client.py           |  30 +++--
 ibis/pyspark/compiler.py         | 203 +++++++++++++++++--------------
 ibis/pyspark/operations.py       |   2 +-
 ibis/pyspark/tests/test_basic.py |  13 +-
 ibis/tests/backends.py           |   1 -
 6 files changed, 142 insertions(+), 111 deletions(-)

diff --git a/ibis/pyspark/api.py b/ibis/pyspark/api.py
index 73209634a47f..36439c6371da 100644
--- a/ibis/pyspark/api.py
+++ b/ibis/pyspark/api.py
@@ -1,4 +1,4 @@
-from ibis.pyspark.client import PysparkClient
+from ibis.pyspark.client import PySparkClient
 
 
 def connect(session):
@@ -7,7 +7,7 @@ def connect(session):
     which pipes them into SparkContext. See documentation for SparkContext:
     https://spark.apache.org/docs/latest/api/python/_modules/pyspark/context.html#SparkContext
     """
-    client = PysparkClient(session)
+    client = PySparkClient(session)
 
     # Spark internally stores timestamps as UTC values, and timestamp data that
     # is brought in without a specified time zone is converted as local time to
diff --git a/ibis/pyspark/client.py b/ibis/pyspark/client.py
index 02dd8adfe4aa..ce6741ffefde 100644
--- a/ibis/pyspark/client.py
+++ b/ibis/pyspark/client.py
@@ -1,26 +1,30 @@
+from pyspark.sql.column import Column
+
+import ibis.common.exceptions as com
 import ibis.expr.types as types
-from ibis.pyspark.compiler import translate
-from ibis.pyspark.operations import PysparkTable
+from ibis.pyspark.compiler import PySparkExprTranslator
+from ibis.pyspark.operations import PySparkTable
 from ibis.spark.client import SparkClient
 
-from pyspark.sql.column import Column
-
 
-class PysparkClient(SparkClient):
+class PySparkClient(SparkClient):
     """
-    An ibis client that uses Pyspark SQL Dataframe
+    An ibis client that uses PySpark SQL DataFrame
     """
 
     dialect = None
-    table_class = PysparkTable
+    table_class = PySparkTable
+
+    def __init__(self, session):
+        super().__init__(session)
+        self.translator = PySparkExprTranslator()
 
     def compile(self, expr, *args, **kwargs):
-        """Compile an ibis expression to a Pyspark DataFrame object
+        """Compile an ibis expression to a PySpark DataFrame object
         """
-        return translate(expr)
+        return self.translator.translate(expr)
 
     def execute(self, expr, params=None, limit='default', **kwargs):
-
         if isinstance(expr, types.TableExpr):
             return self.compile(expr).toPandas()
         elif isinstance(expr, types.ColumnExpr):
@@ -32,8 +36,8 @@ def execute(self, expr, params=None, limit='default', **kwargs):
         if isinstance(compiled, Column):
             # attach result column to a fake DataFrame and
             # select the result
-            compiled = self._session.range(0, 1) \
-                .select(compiled)
+            compiled = self._session.range(0, 1).select(compiled)
             return compiled.toPandas().iloc[0, 0]
         else:
-            raise ValueError("Unexpected type: ", type(expr))
+            raise com.IbisError(
+                "Cannot execute expression of type: {}".format(type(expr)))
diff --git a/ibis/pyspark/compiler.py b/ibis/pyspark/compiler.py
index 40ffba53e1f7..4fce6710b7f9 100644
--- a/ibis/pyspark/compiler.py
+++ b/ibis/pyspark/compiler.py
@@ -1,21 +1,24 @@
 import collections
+import enum
 import functools
 
 import pyspark.sql.functions as F
-from pyspark.sql.functions import PandasUDFType, pandas_udf
 from pyspark.sql.window import Window
 
 import ibis.common.exceptions as com
 import ibis.expr.operations as ops
 import ibis.expr.types as types
-from ibis.pyspark.operations import PysparkTable
-from ibis.sql.compiler import Dialect
+from ibis.pyspark.operations import PySparkTable
 
-_operation_registry = {}
 
+class AggregationContext(enum.Enum):
+    ENTIRE = 0
+    WINDOW = 1
+    GROUP = 2
 
-class PysparkExprTranslator:
-    _registry = _operation_registry
+
+class PySparkExprTranslator:
+    _registry = {}
 
     @classmethod
     def compiles(cls, klass):
@@ -38,85 +41,103 @@ def translate(self, expr, **kwargs):
         )
 
 
-class PysparkDialect(Dialect):
-    translator = PysparkExprTranslator
-
-
-compiles = PysparkExprTranslator.compiles
+compiles = PySparkExprTranslator.compiles
 
 
-@compiles(PysparkTable)
+@compiles(PySparkTable)
 def compile_datasource(t, expr):
     op = expr.op()
     name, _, client = op.args
     return client._session.table(name)
 
 
+def compile_table_and_cache(t, expr, cache):
+    """Compile a Table expression and cache the result
+    """
+
+    assert isinstance(expr, types.TableExpr)
+    if expr in cache:
+        table = cache[expr]
+    else:
+        table = t.translate(expr)
+        cache[expr] = table
+    return table
+
+
 @compiles(ops.Selection)
-def compile_selection(t, expr):
+def compile_selection(t, expr, **kwargs):
+    # Cache compile results for tables
+    table_cache = {}
+
     op = expr.op()
-    src_table = t.translate(op.table)
+    src_table = compile_table_and_cache(t, op.table, table_cache)
     col_names_in_selection_order = []
+
     for selection in op.selections:
         if isinstance(selection, types.TableExpr):
             col_names_in_selection_order.extend(selection.columns)
         elif isinstance(selection, types.ColumnExpr):
             column_name = selection.get_name()
             col_names_in_selection_order.append(column_name)
-            if column_name not in src_table.columns:
-                column = t.translate(selection)
-                src_table = src_table.withColumn(column_name, column)
+            column = t.translate(selection, table_cache=table_cache)
+            src_table = src_table.withColumn(column_name, column)
 
     return src_table[col_names_in_selection_order]
 
 
 @compiles(ops.TableColumn)
-def compile_column(t, expr):
+def compile_column(t, expr, table_cache={}, **kwargs):
     op = expr.op()
-    return t.translate(op.table)[op.name]
+    table = compile_table_and_cache(t, op.table, table_cache)
+    return table[op.name]
 
 
 @compiles(ops.SelfReference)
-def compile_self_reference(t, expr):
+def compile_self_reference(t, expr, **kwargs):
     op = expr.op()
     return t.translate(op.table)
 
 
 @compiles(ops.Equals)
-def compile_equals(t, expr):
+def compile_equals(t, expr, **kwargs):
     op = expr.op()
     return t.translate(op.left) == t.translate(op.right)
 
 
 @compiles(ops.Greater)
-def compile_greater(t, expr):
+def compile_greater(t, expr, **kwargs):
     op = expr.op()
     return t.translate(op.left) > t.translate(op.right)
 
 
 @compiles(ops.GreaterEqual)
-def compile_greater_equal(t, expr):
+def compile_greater_equal(t, expr, **kwargs):
     op = expr.op()
     return t.translate(op.left) >= t.translate(op.right)
 
 
 @compiles(ops.Multiply)
-def compile_multiply(t, expr):
+def compile_multiply(t, expr, **kwargs):
     op = expr.op()
     return t.translate(op.left) * t.translate(op.right)
 
 
 @compiles(ops.Subtract)
-def compile_subtract(t, expr):
+def compile_subtract(t, expr, **kwargs):
     op = expr.op()
     return t.translate(op.left) - t.translate(op.right)
 
 
 @compiles(ops.Literal)
-def compile_literal(t, expr):
+def compile_literal(t, expr, raw=False, **kwargs):
+    """ If raw is True, don't wrap the result with F.lit()
+    """
     value = expr.op().value
 
+    if raw:
+        return value
+
     if isinstance(value, collections.abc.Set):
         # Don't wrap set with F.lit
         if isinstance(value, frozenset):
@@ -129,27 +150,32 @@
 
 
 @compiles(ops.Aggregation)
-def compile_aggregation(t, expr):
+def compile_aggregation(t, expr, **kwargs):
     op = expr.op()
     src_table = t.translate(op.table)
-    aggs = [t.translate(m, context="agg")
-            for m in op.metrics]
 
     if op.by:
+        context = AggregationContext.GROUP
+        aggs = [t.translate(m, context=context)
+                for m in op.metrics]
         bys = [t.translate(b) for b in op.by]
         return src_table.groupby(*bys).agg(*aggs)
     else:
+        context = AggregationContext.ENTIRE
+        aggs = [t.translate(m, context=context)
+                for m in op.metrics]
         return src_table.agg(*aggs)
 
 
 @compiles(ops.Contains)
-def compile_contains(t, expr):
-    col = t.translate(expr.op().value)
-    return col.isin(t.translate(expr.op().options))
+def compile_contains(t, expr, **kwargs):
+    op = expr.op()
+    col = t.translate(op.value)
+    return col.isin(t.translate(op.options))
 
 
-def compile_aggregator(t, expr, fn, context=None):
+def compile_aggregator(t, expr, fn, context=None, **kwargs):
     op = expr.op()
     src_col = t.translate(op.arg)
 
@@ -165,7 +191,7 @@ def compile_aggregator(t, expr, fn, context=None):
 
 
 @compiles(ops.GroupConcat)
-def compile_group_concat(t, expr, context=None):
+def compile_group_concat(t, expr, context=None, **kwargs):
     sep = expr.op().sep.op().value
 
     def fn(col):
@@ -174,12 +200,12 @@ def fn(col):
 
 
 @compiles(ops.Any)
-def compile_any(t, expr, context=None):
+def compile_any(t, expr, context=None, **kwargs):
     return compile_aggregator(t, expr, F.max, context)
 
 
 @compiles(ops.NotAny)
-def compile_notany(t, expr, context=None):
+def compile_notany(t, expr, context=None, **kwargs):
     def fn(col):
         return ~F.max(col)
 
@@ -187,12 +213,12 @@ def fn(col):
 
 
 @compiles(ops.All)
-def compile_all(t, expr, context=None):
+def compile_all(t, expr, context=None, **kwargs):
     return compile_aggregator(t, expr, F.min, context)
 
 
 @compiles(ops.NotAll)
-def compile_notall(t, expr, context=None):
+def compile_notall(t, expr, context=None, **kwargs):
     def fn(col):
         return ~F.min(col)
 
@@ -200,32 +226,32 @@ def fn(col):
 
 
 @compiles(ops.Count)
-def compile_count(t, expr, context=None):
+def compile_count(t, expr, context=None, **kwargs):
     return compile_aggregator(t, expr, F.count, context)
 
 
 @compiles(ops.Max)
-def compile_max(t, expr, context=None):
+def compile_max(t, expr, context=None, **kwargs):
     return compile_aggregator(t, expr, F.max, context)
 
 
 @compiles(ops.Min)
-def compile_min(t, expr, context=None):
+def compile_min(t, expr, context=None, **kwargs):
     return compile_aggregator(t, expr, F.min, context)
 
 
 @compiles(ops.Mean)
-def compile_mean(t, expr, context=None):
+def compile_mean(t, expr, context=None, **kwargs):
     return compile_aggregator(t, expr, F.mean, context)
 
 
 @compiles(ops.Sum)
-def compile_sum(t, expr, context=None):
+def compile_sum(t, expr, context=None, **kwargs):
     return compile_aggregator(t, expr, F.sum, context)
 
 
 @compiles(ops.StandardDev)
-def compile_std(t, expr, context=None):
+def compile_std(t, expr, context=None, **kwargs):
     how = expr.op().how
 
     if how == 'sample':
@@ -233,13 +259,16 @@ def compile_std(t, expr, context=None):
     elif how == 'pop':
         fn = F.stddev_pop
     else:
-        raise AssertionError("Unexpected how: {}".format(how))
+        raise com.TranslationError(
+            "Unexpected 'how' in translation: {}"
+            .format(how)
+        )
 
     return compile_aggregator(t, expr, fn, context)
 
 
 @compiles(ops.Variance)
-def compile_variance(t, expr, context=None):
+def compile_variance(t, expr, context=None, **kwargs):
     how = expr.op().how
 
     if how == 'sample':
@@ -247,13 +276,16 @@ def compile_variance(t, expr, context=None):
     elif how == 'pop':
         fn = F.var_pop
     else:
-        raise AssertionError("Unexpected how: {}".format(how))
+        raise com.TranslationError(
+            "Unexpected 'how' in translation: {}"
+            .format(how)
+        )
 
     return compile_aggregator(t, expr, fn, context)
 
 
 @compiles(ops.Arbitrary)
-def compile_arbitrary(t, expr, context=None):
+def compile_arbitrary(t, expr, context=None, **kwargs):
     how = expr.op().how
 
     if how == 'first':
@@ -261,19 +293,23 @@
     elif how == 'last':
         fn = functools.partial(F.last, ignorenulls=True)
     else:
-        raise NotImplementedError
+        raise NotImplementedError(
+            "Does not support 'how': {}".format(how)
+        )
 
     return compile_aggregator(t, expr, fn, context)
 
 
 @compiles(ops.WindowOp)
-def compile_window_op(t, expr):
+def compile_window_op(t, expr, **kwargs):
     op = expr.op()
-    return t.translate(op.expr).over(compile_window(op.window))
+
+    return (t.translate(op.expr, context=AggregationContext.WINDOW)
+            .over(compile_window(op.window)))
 
 
 @compiles(ops.Greatest)
-def compile_greatest(t, expr):
+def compile_greatest(t, expr, **kwargs):
     op = expr.op()
 
     src_columns = t.translate(op.arg)
@@ -284,7 +320,7 @@ def compile_greatest(t, expr):
 
 
 @compiles(ops.Least)
-def compile_least(t, expr):
+def compile_least(t, expr, **kwargs):
     op = expr.op()
 
     src_columns = t.translate(op.arg)
@@ -295,7 +331,7 @@ def compile_least(t, expr):
 
 
 @compiles(ops.Abs)
-def compile_abs(t, expr):
+def compile_abs(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
@@ -303,11 +339,11 @@ def compile_abs(t, expr):
 
 
 @compiles(ops.Round)
-def compile_round(t, expr):
+def compile_round(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
-    scale = op.digits.op().value if op.digits is not None else 0
+    scale = t.translate(op.digits, raw=True) if op.digits is not None else 0
     rounded = F.round(src_column, scale=scale)
     if scale == 0:
         rounded = rounded.astype('long')
@@ -315,7 +351,7 @@ def compile_round(t, expr):
 
 
 @compiles(ops.Ceil)
-def compile_ceil(t, expr):
+def compile_ceil(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
@@ -323,7 +359,7 @@ def compile_ceil(t, expr):
 
 
 @compiles(ops.Floor)
-def compile_floor(t, expr):
+def compile_floor(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
@@ -331,7 +367,7 @@ def compile_floor(t, expr):
 
 
 @compiles(ops.Exp)
-def compile_exp(t, expr):
+def compile_exp(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
@@ -339,7 +375,7 @@ def compile_exp(t, expr):
 
 
 @compiles(ops.Sign)
-def compile_sign(t, expr):
+def compile_sign(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
@@ -349,7 +385,7 @@ def compile_sign(t, expr):
 
 
 @compiles(ops.Sqrt)
-def compile_sqrt(t, expr):
+def compile_sqrt(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
@@ -357,15 +393,16 @@ def compile_sqrt(t, expr):
 
 
 @compiles(ops.Log)
-def compile_log(t, expr):
+def compile_log(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
-    return F.log(float(op.base.op().value), src_column)
+    # Spark log method only takes float
+    return F.log(float(t.translate(op.base, raw=True)), src_column)
 
 
 @compiles(ops.Ln)
-def compile_ln(t, expr):
+def compile_ln(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
@@ -373,7 +410,7 @@ def compile_ln(t, expr):
 
 
 @compiles(ops.Log2)
-def compile_log2(t, expr):
+def compile_log2(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
@@ -381,7 +418,7 @@ def compile_log2(t, expr):
 
 
 @compiles(ops.Log10)
-def compile_log10(t, expr):
+def compile_log10(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
@@ -389,7 +426,7 @@
 
 
 @compiles(ops.Modulus)
-def compile_modulus(t, expr):
+def compile_modulus(t, expr, **kwargs):
     op = expr.op()
 
     left = t.translate(op.left)
@@ -398,7 +435,7 @@ def compile_modulus(t, expr):
 
 
 @compiles(ops.Negate)
-def compile_negate(t, expr):
+def compile_negate(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
@@ -406,7 +443,7 @@ def compile_negate(t, expr):
 
 
 @compiles(ops.Add)
-def compile_add(t, expr):
+def compile_add(t, expr, **kwargs):
     op = expr.op()
 
     left = t.translate(op.left)
@@ -415,7 +452,7 @@ def compile_add(t, expr):
 
 
 @compiles(ops.Divide)
-def compile_divide(t, expr):
+def compile_divide(t, expr, **kwargs):
     op = expr.op()
 
     left = t.translate(op.left)
@@ -424,7 +461,7 @@ def compile_divide(t, expr):
 
 
 @compiles(ops.FloorDivide)
-def compile_floor_divide(t, expr):
+def compile_floor_divide(t, expr, **kwargs):
     op = expr.op()
 
     left = t.translate(op.left)
@@ -433,7 +470,7 @@ def compile_floor_divide(t, expr):
 
 
 @compiles(ops.Power)
-def compile_power(t, expr):
+def compile_power(t, expr, **kwargs):
     op = expr.op()
 
     left = t.translate(op.left)
@@ -442,7 +479,7 @@ def compile_power(t, expr):
 
 
 @compiles(ops.IsNan)
-def compile_isnan(t, expr):
+def compile_isnan(t, expr, **kwargs):
     op = expr.op()
 
     src_column = t.translate(op.arg)
@@ -450,26 +487,21 @@ def compile_isnan(t, expr):
 
 
 @compiles(ops.IsInf)
-def compile_isinf(t, expr):
-    import numpy as np
+def compile_isinf(t, expr, **kwargs):
     op = expr.op()
 
-    @pandas_udf('boolean', PandasUDFType.SCALAR)
-    def isinf(v):
-        return np.isinf(v)
-
     src_column = t.translate(op.arg)
-    return isinf(src_column)
+    return (src_column == float('inf')) | (src_column == float('-inf'))
 
 
 @compiles(ops.ValueList)
-def compile_value_list(t, expr):
+def compile_value_list(t, expr, **kwargs):
     op = expr.op()
     return [t.translate(col) for col in op.values]
 
 
 @compiles(ops.InnerJoin)
-def compile_inner_join(t, expr):
+def compile_inner_join(t, expr, **kwargs):
     return compile_join(t, expr, 'inner')
 
 
@@ -489,10 +521,3 @@ def compile_join(t, expr, how):
 def compile_window(expr):
     spark_window = Window.partitionBy()
     return spark_window
-
-
-t = PysparkExprTranslator()
-
-
-def translate(expr):
-    return t.translate(expr)
diff --git a/ibis/pyspark/operations.py b/ibis/pyspark/operations.py
index 9ae34861dabb..6491c4e058c1 100644
--- a/ibis/pyspark/operations.py
+++ b/ibis/pyspark/operations.py
@@ -1,5 +1,5 @@
 import ibis.expr.operations as ops
 
 
-class PysparkTable(ops.DatabaseTable):
+class PySparkTable(ops.DatabaseTable):
     pass
diff --git a/ibis/pyspark/tests/test_basic.py b/ibis/pyspark/tests/test_basic.py
index d7298d340a5e..064f34e2aa4b 100644
--- a/ibis/pyspark/tests/test_basic.py
+++ b/ibis/pyspark/tests/test_basic.py
@@ -37,21 +37,24 @@ def test_projection(client):
         {
             'id': range(0, 10),
             'str_col': 'value',
-            'v': range(0, 10)
+            'v': range(0, 10),
         }
     )
 
     result2 = (
-        table.mutate(v=table['id']).mutate(v2=table['id'])
+        table
+        .mutate(v=table['id'])
+        .mutate(v2=table['id'])
+        .mutate(id=table['id'] * 2)
         .compile().toPandas()
     )
 
     expected2 = pd.DataFrame(
         {
-            'id': range(0, 10),
+            'id': range(0, 20, 2),
             'str_col': 'value',
             'v': range(0, 10),
-            'v2': range(0, 10)
+            'v2': range(0, 10),
         }
     )
 
@@ -124,7 +127,7 @@ def test_greatest(client):
 
 def test_selection(client):
     table = client.table('table1')
-    table = table.mutate(id2=table['id'])
+    table = table.mutate(id2=table['id'] * 2)
 
     result1 = table[['id']].compile()
     result2 = table[['id', 'id2']].compile()
diff --git a/ibis/tests/backends.py b/ibis/tests/backends.py
index 8db83c811c95..42e907ffbe3b 100644
--- a/ibis/tests/backends.py
+++ b/ibis/tests/backends.py
@@ -552,7 +552,6 @@ def awards_players(self) -> ir.TableExpr:
 
 
 class PySpark(Backend, RoundAwayFromZero):
-
     @staticmethod
     def skip_if_missing_dependencies() -> None:
         pytest.importorskip('pyspark')
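
Usage note (editor's sketch, not part of the patch): the diff renames
PysparkClient to PySparkClient, routes compilation through a
PySparkExprTranslator instance held on the client, and has execute() raise
com.IbisError for unsupported expression types. Below is a minimal sketch of
how the reworked entry points fit together. The SparkSession setup and the
'table1' temp view are assumptions mirroring the fixtures in
ibis/pyspark/tests/test_basic.py, not code contained in this patch.

    from pyspark.sql import SparkSession

    import ibis.pyspark.api as api

    # Assumed local session; the test suite registers 'table1' via a fixture.
    session = SparkSession.builder.master('local[1]').getOrCreate()
    session.range(0, 10).createOrReplaceTempView('table1')

    client = api.connect(session)   # returns a PySparkClient
    table = client.table('table1')
    expr = table.mutate(id2=table['id'] * 2)

    df = client.compile(expr)       # a pyspark.sql.DataFrame
    result = client.execute(expr)   # a pandas.DataFrame via toPandas()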