Skip to content

Commit

Permalink
fix: fix SQL backend has_operation to include operations supported th…
Browse files Browse the repository at this point in the history
…rough rewrite rules
  • Loading branch information
jcrist committed Mar 21, 2024
1 parent f2e1465 commit 133a1f1
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ibis/backends/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ def dialect(self) -> sg.Dialect:
@classmethod
def has_operation(cls, operation: type[ops.Value]) -> bool:
compiler = cls.compiler
if operation in compiler.extra_supported_ops:
return True
method = getattr(compiler, f"visit_{operation.__name__}", None)
return method is not None and method not in (
return method not in (
None,
compiler.visit_Undefined,
compiler.visit_Unsupported,
)
Expand Down
24 changes: 24 additions & 0 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from public import public

import ibis.common.exceptions as com
import ibis.common.patterns as pats
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.rewrites import (
Expand Down Expand Up @@ -175,6 +176,20 @@ class SQLGlotCompiler(abc.ABC):
)
"""A sequence of rewrites to apply to the expression tree before compilation."""

extra_supported_ops: frozenset = frozenset(
(
ops.Project,
ops.Filter,
ops.Sort,
ops.WindowFunction,
ops.RowsWindowFrame,
ops.RangeWindowFrame,
)
)
"""A frozenset of ops classes that are supported, but don't have explicit
`visit_*` methods (usually due to being handled by rewrite rules). Used by
`has_operation`"""

no_limit_value: sge.Null | None = None
"""The value to use to indicate no limit."""

Expand Down Expand Up @@ -366,6 +381,15 @@ def impl(self, _, *, _name: str = target_name, **kw):
if not hasattr(cls, name):
setattr(cls, name, cls.visit_Undefined)

# Expand extra_supported_ops with any rewrite rules
extra_supported_ops = set(cls.extra_supported_ops)
for rule in cls.rewrites:
if isinstance(rule, pats.Replace) and isinstance(
rule.matcher, pats.InstanceOf
):
extra_supported_ops.add(rule.matcher.type)
cls.extra_supported_ops = frozenset(extra_supported_ops)

@property
@abc.abstractmethod
def dialect(self) -> str:
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/sqlite/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pytest import param

import ibis
import ibis.expr.operations as ops
from ibis.conftest import not_windows


Expand Down Expand Up @@ -75,3 +76,15 @@ def test_connect(url, ext, tmp_path):
con = ibis.connect(url(path))
one = ibis.literal(1)
assert con.execute(one) == 1


def test_has_operation(con):
# Core operations handled in non-standard ways
for op in [ops.Project, ops.Filter, ops.Sort, ops.Aggregate]:
assert con.has_operation(op)
# Handled by base class rewrite
assert con.has_operation(ops.Capitalize)
# Handled by compiler-specific rewrite
assert con.has_operation(ops.Sample)
# Handled by visit_* method
assert con.has_operation(ops.Cast)

0 comments on commit 133a1f1

Please sign in to comment.