Skip to content

Commit

Permalink
refactor(sql): reorganize sqlglot rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Feb 1, 2024
1 parent 01387d0 commit 5a28ad6
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 114 deletions.
39 changes: 18 additions & 21 deletions ibis/backends/base/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import calendar
import itertools
import math
import operator
import string
from collections.abc import Iterator, Mapping
from functools import partial, reduce, singledispatchmethod
Expand All @@ -19,17 +18,18 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sqlglot.rewrites import CTE, Select, Window, sqlize
from ibis.expr.operations.udf import InputType
from ibis.expr.rewrites import (
from ibis.backends.base.sqlglot.rewrites import (
CTE,
Select,
Window,
add_one_to_nth_value_input,
add_order_by_to_empty_ranking_window_functions,
empty_in_values_right_side,
one_to_zero_index,
replace_bucket,
replace_scalar_parameter,
unwrap_scalar_parameter,
sqlize,
)
from ibis.expr.operations.udf import InputType
from ibis.expr.rewrites import replace_bucket

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -222,6 +222,15 @@ def if_(self, condition, true, false: sge.Expression | None = None) -> sge.If:
def cast(self, arg, to: dt.DataType) -> sge.Cast:
return sg.cast(sge.convert(arg), to=self.type_mapper.from_ibis(to))

def _prepare_params(self, params):
result = {}
for param, value in params.items():
node = param.op()
if isinstance(node, ops.Alias):
node = node.arg
result[node] = value
return result

def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression:
"""Translate an ibis operation to a sqlglot expression.
Expand All @@ -245,20 +254,8 @@ def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression:
"""
# substitute parameters immediately to avoid having to define a
# ScalarParameter translation rule
#
# this lets us avoid threading `params` through every `translate_val`
# call only to be used in the one place it would be needed: the
# ScalarParameter `translate_val` rule
params = {
# remove aliases from scalar parameters
param.op().replace(unwrap_scalar_parameter): value
for param, value in (params or {}).items()
}

op = op.replace(
replace_scalar_parameter(params) | reduce(operator.or_, self.rewrites)
)
op, ctes = sqlize(op)
params = self._prepare_params(params)
op, ctes = sqlize(op, params=params, rewrites=self.rewrites)

aliases = {}
counter = itertools.count()
Expand Down
137 changes: 105 additions & 32 deletions ibis/backends/base/sqlglot/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from __future__ import annotations

from typing import Literal, Optional
import operator
from functools import reduce
from typing import TYPE_CHECKING, Any, Literal, Optional

import toolz
from public import public
Expand All @@ -16,11 +18,14 @@
from ibis.common.collections import FrozenDict # noqa: TCH001
from ibis.common.deferred import var
from ibis.common.graph import Graph
from ibis.common.patterns import Object, replace
from ibis.common.patterns import Object, Pattern, _, replace
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import p
from ibis.expr.rewrites import d, p, replace_parameter
from ibis.expr.schema import Schema

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence

x = var("x")
y = var("y")

Expand Down Expand Up @@ -77,25 +82,25 @@ def dtype(self):


@replace(p.Project)
def project_to_select(_, **kwargs):
    """Lower a Project node to a Select over the same parent."""
    parent, values = _.parent, _.values
    return Select(parent, selections=values)


@replace(p.Filter)
def filter_to_select(_, **kwargs):
    """Lower a Filter node to a Select carrying the filter's predicates."""
    select_kwargs = {"selections": _.values, "predicates": _.predicates}
    return Select(_.parent, **select_kwargs)


@replace(p.Sort)
def sort_to_select(_, **kwargs):
    """Lower a Sort node to a Select carrying the sort keys."""
    select_kwargs = {"selections": _.values, "sort_keys": _.keys}
    return Select(_.parent, **select_kwargs)


@replace(p.WindowFunction)
def window_function_to_window(_):
def window_function_to_window(_, **kwargs):
"""Convert a WindowFunction node to a Window node."""
if isinstance(_.frame, ops.RowsWindowFrame) and _.frame.max_lookback is not None:
raise NotImplementedError("max_lookback is not supported for SQL backends")
Expand All @@ -109,18 +114,8 @@ def window_function_to_window(_):
)


@replace(p.Log2)
def replace_log2(_):
    """Rewrite a `Log2` node as a generic `Log` with base 2."""
    return ops.Log(_.arg, base=2)


@replace(p.Log10)
def replace_log10(_):
    """Rewrite a `Log10` node as a generic `Log` with base 10."""
    return ops.Log(_.arg, base=10)


@replace(Object(Select, Object(Select)))
def merge_select_select(_):
def merge_select_select(_, **kwargs):
"""Merge subsequent Select relations into one.
This rewrite eliminates `_.parent` by merging the outer and the inner
Expand Down Expand Up @@ -165,27 +160,105 @@ def extract_ctes(node):
return result


def sqlize(
    node: ops.Node,
    params: Mapping[ops.ScalarParameter, Any],
    rewrites: Sequence[Pattern] = (),
) -> tuple[ops.Node, list[ops.Node]]:
    """Lower the ibis expression graph to a SQL-like relational algebra.

    Parameters
    ----------
    node
        The root node of the expression graph.
    params
        A mapping of scalar parameters to their values.
    rewrites
        Supplementary rewrites to apply to the expression graph. May be
        empty, in which case no backend specific rewrite pass is run.

    Returns
    -------
    Tuple of the rewritten expression graph and a list of CTEs.
    """
    assert isinstance(node, ops.Relation)

    # apply the backend specific rewrites; `reduce` without an initial value
    # raises TypeError on an empty sequence (the default), so skip the pass
    # entirely when there is nothing to apply
    if rewrites:
        node = node.replace(reduce(operator.or_, rewrites))

    # lower the expression graph to a SQL-like relational algebra
    context = {"params": params}
    sqlized = node.replace(
        replace_parameter
        | project_to_select
        | filter_to_select
        | sort_to_select
        | window_function_to_window,
        context=context,
    )

    # squash subsequent Select nodes into one
    simplified = sqlized.replace(merge_select_select)

    # extract common table expressions while wrapping them in a CTE node
    ctes = extract_ctes(simplified)
    subs = {cte: CTE(cte) for cte in ctes}
    result = simplified.replace(subs)

    return result, ctes


# supplemental rewrites selectively used on a per-backend basis
"""Replace `log2` and `log10` with `log`."""
replace_log2 = p.Log2 >> d.Log(_.arg, base=2)
replace_log10 = p.Log10 >> d.Log(_.arg, base=10)


"""Add an ORDER BY clause to rank window functions that don't have one."""
add_order_by_to_empty_ranking_window_functions = p.WindowFunction(
func=p.NTile(y),
frame=p.WindowFrame(order_by=()) >> _.copy(order_by=(y,)),
)

"""Replace checks against an empty right side with `False`."""
empty_in_values_right_side = p.InValues(options=()) >> d.Literal(False, dtype=dt.bool)


@replace(
    p.WindowFunction(p.RankBase | p.NTile)
    | p.StringFind
    | p.FindInSet
    | p.ArrayPosition
)
def one_to_zero_index(_, **kwargs):
    """Wrap the matched node in a subtraction of one.

    Applies to rank/ntile window functions and to the string-find,
    find-in-set and array-position operations matched by the pattern above.
    """
    zero_based = ops.Subtract(_, 1)
    return zero_based


@replace(ops.NthValue)
def add_one_to_nth_value_input(_, **kwargs):
    """Increment the `nth` argument of an `NthValue` node by one.

    A literal `nth` is folded into a new literal of the same dtype; any
    other expression is wrapped in an `Add` node instead.
    """
    position = _.nth
    if not isinstance(position, ops.Literal):
        return _.copy(nth=ops.Add(position, 1))
    shifted = ops.Literal(position.value + 1, dtype=position.dtype)
    return _.copy(nth=shifted)


@replace(p.Sample)
def rewrite_sample_as_filter(_, **kwargs):
    """Express `Sample` as a filter keeping rows where `random() <= fraction`.

    Raises
    ------
    UnsupportedOperationError
        If the sample specifies a `seed`.
    """
    if _.seed is None:
        keep_row = ops.LessEqual(ops.RandomScalar(), _.fraction)
        return ops.Filter(_.parent, (keep_row,))
    raise com.UnsupportedOperationError(
        "`Table.sample` with a random seed is unsupported"
    )


@replace(p.WindowFunction(p.First(x, y)))
def rewrite_first_to_first_value(_, x, y):
@replace(p.WindowFunction(p.First(x, where=y)))
def rewrite_first_to_first_value(_, x, y, **kwargs):
"""Rewrite Ibis's first to first_value when used in a window function."""
if y is not None:
raise com.UnsupportedOperationError(
Expand All @@ -194,8 +267,8 @@ def rewrite_first_to_first_value(_, x, y):
return _.copy(func=ops.FirstValue(x))


@replace(p.WindowFunction(p.Last(x, y)))
def rewrite_last_to_last_value(_, x, y):
@replace(p.WindowFunction(p.Last(x, where=y)))
def rewrite_last_to_last_value(_, x, y, **kwargs):
"""Rewrite Ibis's last to last_value when used in a window function."""
if y is not None:
raise com.UnsupportedOperationError(
Expand All @@ -205,7 +278,7 @@ def rewrite_last_to_last_value(_, x, y):


@replace(p.WindowFunction(frame=y @ p.WindowFrame(order_by=())))
def rewrite_empty_order_by_window(_, y, **kwargs):
    """Give a window frame that has no ordering a single `NULL` sort key."""
    null_ordered_frame = y.copy(order_by=(ops.NULL,))
    return _.copy(frame=null_ordered_frame)


Expand All @@ -220,5 +293,5 @@ def exclude_unsupported_window_frame_from_row_number(_, y):
y @ p.WindowFrame(start=None),
)
)
def exclude_unsupported_window_frame_from_ops(_, y):
def exclude_unsupported_window_frame_from_ops(_, y, **kwargs):
return _.copy(frame=y.copy(start=None, end=0, order_by=y.order_by or (ops.NULL,)))
6 changes: 1 addition & 5 deletions ibis/backends/pandas/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ibis.common.collections import FrozenDict
from ibis.common.patterns import replace
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import replace_parameter
from ibis.expr.schema import Schema
from ibis.util import gen_name

Expand Down Expand Up @@ -298,11 +299,6 @@ def rewrite_scalar_subquery(_, **kwargs):
return PandasScalarSubquery(_.rel)


@replace(ops.ScalarParameter)
def replace_parameter(_, params, **kwargs):
    """Replace a scalar parameter with a literal of the same dtype.

    The literal's value is looked up in `params` using the parameter node
    itself as the key.
    """
    return ops.Literal(value=params[_], dtype=_.dtype)


@replace(ops.UnboundTable)
def bind_unbound_table(_, backend, **kwargs):
    """Turn an unbound table into a database table sourced from `backend`."""
    table_spec = {"name": _.name, "schema": _.schema, "source": backend}
    return ops.DatabaseTable(**table_spec)
Expand Down
Loading

0 comments on commit 5a28ad6

Please sign in to comment.