From 78dc3939750045193dad6f209b0216ed638c9fb5 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 14 Feb 2024 07:56:26 -0500 Subject: [PATCH] refactor(backends): remove singledispatchmethod from the sql backends (#8338) ## Description of changes This pull request does two things: 1. Removes use of `singledispatchmethod` in the SQL compilers 1. Fixes the support matrix accuracy for SQL backends Follow-ups: * ~Deal with filtering out RisingWave geospatial~ Handled here * `__init_subclass__` for `SIMPLE_OPS` handled here * Fix coverage accuracy for the non-SQL backends ## Issues closed Fixes #8283. Thanks to @jcrist for the `__init_subclass__` tip, that saved N backends duplication of filling in undefined operations. --- docs/support_matrix.qmd | 2 - ibis/backends/base/sqlglot/__init__.py | 10 +- ibis/backends/base/sqlglot/compiler.py | 207 +++++++----------- ibis/backends/bigquery/compiler.py | 154 ++++--------- .../datetime/out.sql | 2 +- .../string_time/out.sql | 2 +- .../string_timestamp/out.sql | 2 +- .../time/out.sql | 2 +- .../timestamp/out.sql | 2 +- .../test_literal_year/date/out.sql | 2 +- .../test_literal_year/datetime/out.sql | 2 +- .../test_literal_year/string_date/out.sql | 2 +- .../string_timestamp/out.sql | 2 +- .../test_literal_year/timestamp/out.sql | 2 +- .../test_literal_year/timestamp_date/out.sql | 2 +- ibis/backends/clickhouse/compiler.py | 109 +++------ ibis/backends/datafusion/compiler.py | 130 ++++------- ibis/backends/druid/compiler.py | 116 ++++------ ibis/backends/duckdb/compiler.py | 54 +---- ibis/backends/exasol/compiler.py | 154 ++++++------- ibis/backends/flink/compiler.py | 147 +++++-------- ibis/backends/impala/compiler.py | 119 ++++------ ibis/backends/mssql/compiler.py | 181 +++++++-------- ibis/backends/mysql/compiler.py | 101 +++------ ibis/backends/oracle/compiler.py | 117 ++++------ ibis/backends/polars/__init__.py | 10 +- ibis/backends/postgres/compiler.py | 79 +------ .../backends/postgres/tests/test_functions.py | 3 +- ibis/backends/pyspark/compiler.py | 75 ++----- ibis/backends/risingwave/compiler.py | 33 ++- ibis/backends/snowflake/compiler.py | 115 ++-------- ibis/backends/sqlite/compiler.py | 144 ++++-------- ibis/backends/tests/test_generic.py | 15 +- ibis/backends/tests/test_string.py | 1 - ibis/backends/trino/compiler.py | 80 ++----- 35 files changed, 711 insertions(+), 1467 deletions(-) diff --git a/docs/support_matrix.qmd b/docs/support_matrix.qmd index 48ec4d07bc4e..e9e70d37bf06 100644 --- a/docs/support_matrix.qmd +++ b/docs/support_matrix.qmd @@ -7,8 +7,6 @@ hide: ```{python} #| echo: false -from pathlib import Path - import pandas as pd import ibis diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index 47c715129cf2..9380ba8c8916 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -35,10 +35,12 @@ def dialect(self) -> sg.Dialect: @classmethod def has_operation(cls, operation: type[ops.Value]) -> bool: - # singledispatchmethod overrides `__get__` so we can't directly access - # the dispatcher - dispatcher = cls.compiler.visit_node.register.__self__.dispatcher - return dispatcher.dispatch(operation) is not dispatcher.dispatch(object) + compiler = cls.compiler + method = getattr(compiler, f"visit_{operation.__name__}", None) + return method is not None and method not in ( + compiler.visit_Undefined, + compiler.visit_Unsupported, + ) def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: 
import pandas as pd diff --git a/ibis/backends/base/sqlglot/compiler.py b/ibis/backends/base/sqlglot/compiler.py index c6760b06cde1..bb481db9bd91 100644 --- a/ibis/backends/base/sqlglot/compiler.py +++ b/ibis/backends/base/sqlglot/compiler.py @@ -6,7 +6,7 @@ import math import string from collections.abc import Iterator, Mapping -from functools import partial, reduce, singledispatchmethod +from functools import partial, reduce from itertools import starmap from typing import TYPE_CHECKING, Any, Callable, ClassVar @@ -19,9 +19,6 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.base.sqlglot.rewrites import ( - CTE, - Select, - Window, add_one_to_nth_value_input, add_order_by_to_empty_ranking_window_functions, empty_in_values_right_side, @@ -39,6 +36,17 @@ from ibis.backends.base.sqlglot.datatypes import SqlglotType +def get_leaf_classes(op): + for child_class in op.__subclasses__(): + if not child_class.__subclasses__(): + yield child_class + else: + yield from get_leaf_classes(child_class) + + +ALL_OPERATIONS = frozenset(get_leaf_classes(ops.Node)) + + class AggGen: __slots__ = ("aggfunc",) @@ -181,11 +189,26 @@ class SQLGlotCompiler(abc.ABC): ) """Backend's negative infinity literal.""" + UNSUPPORTED_OPERATIONS: frozenset[type[ops.Node]] = frozenset() + """Set of operations the backend doesn't support.""" + def __init__(self) -> None: self.agg = AggGen(aggfunc=self._aggregate) self.f = FuncGen() self.v = VarGen() + def __init_subclass__(cls, **kwargs): + for leaf in ALL_OPERATIONS: + if not hasattr(cls, f"visit_{leaf.__name__}"): + setattr(cls, f"visit_{leaf.__name__}", cls.visit_Undefined) + + for leaf in cls.UNSUPPORTED_OPERATIONS: + # change to visit_Unsupported in a follow up + # TODO: handle geoespatial ops as a separate case? + setattr(cls, f"visit_{leaf.__name__}", cls.visit_Undefined) + + super().__init_subclass__(**kwargs) + @property @abc.abstractmethod def dialect(self) -> str: @@ -295,31 +318,34 @@ def fn(node, _, **kwargs): return out - @singledispatchmethod - def visit_node(self, op: ops.Node, **_): - raise com.OperationNotDefinedError( - f"No translation rule for {type(op).__name__}" - ) + def visit_node(self, op: ops.Node, **kwargs): + if isinstance(op, ops.ScalarUDF): + return self.visit_ScalarUDF(op, **kwargs) + elif isinstance(op, ops.AggUDF): + return self.visit_AggUDF(op, **kwargs) + else: + method = getattr(self, f"visit_{type(op).__name__}", None) + if method is not None: + return method(op, **kwargs) + else: + raise com.OperationNotDefinedError( + f"No translation rule for {type(op).__name__}" + ) - @visit_node.register(ops.Field) def visit_Field(self, op, *, rel, name): return sg.column( self._gen_valid_name(name), table=rel.alias_or_name, quoted=self.quoted ) - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): return self.cast(arg, to) - @visit_node.register(ops.ScalarSubquery) def visit_ScalarSubquery(self, op, *, rel): return rel.this.subquery() - @visit_node.register(ops.Alias) def visit_Alias(self, op, *, arg, name): return arg - @visit_node.register(ops.Literal) def visit_Literal(self, op, *, value, dtype): """Compile a literal value. 
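The two pieces above (the `ALL_OPERATIONS` leaf walk plus `__init_subclass__`, and the name-based `visit_node`) are the heart of the refactor. Below is a minimal, self-contained sketch of that pattern under simplified assumptions: `Node`, `Cast`, `Fancy`, `Compiler`, and `ToyCompiler` are illustrative stand-ins, not the ibis classes themselves.

```python
# Sketch of the dispatch pattern this patch adopts: collect leaf operation
# classes once, fill in missing visit_* methods at subclass-creation time,
# and dispatch by method name instead of functools.singledispatchmethod.


class Node:
    pass


class Cast(Node):
    pass


class Fancy(Node):
    """An operation the toy backend below never implements."""


def leaf_classes(cls):
    # Walk the hierarchy and yield only leaves, mirroring get_leaf_classes().
    for sub in cls.__subclasses__():
        if sub.__subclasses__():
            yield from leaf_classes(sub)
        else:
            yield sub


ALL_OPERATIONS = frozenset(leaf_classes(Node))


class Compiler:
    UNSUPPORTED_OPERATIONS: frozenset = frozenset()

    def __init_subclass__(cls, **kwargs):
        # Any operation without an explicit visit_* method falls back to
        # visit_Undefined, so subclasses never fail with AttributeError.
        for op in ALL_OPERATIONS:
            if not hasattr(cls, f"visit_{op.__name__}"):
                setattr(cls, f"visit_{op.__name__}", cls.visit_Undefined)
        for op in cls.UNSUPPORTED_OPERATIONS:
            setattr(cls, f"visit_{op.__name__}", cls.visit_Unsupported)
        super().__init_subclass__(**kwargs)

    def visit_node(self, op, **kwargs):
        # Name-based dispatch: the method is looked up by the op's class name.
        return getattr(self, f"visit_{type(op).__name__}")(op, **kwargs)

    def visit_Undefined(self, op, **_):
        raise NotImplementedError(f"no rule for {type(op).__name__}")

    def visit_Unsupported(self, op, **_):
        raise NotImplementedError(f"{type(op).__name__} not supported")


class ToyCompiler(Compiler):
    def visit_Cast(self, op, **kwargs):
        return "CAST(...)"


compiler = ToyCompiler()
print(compiler.visit_node(Cast()))   # -> CAST(...)
try:
    compiler.visit_node(Fancy())
except NotImplementedError as e:
    print("undefined:", e)           # Fancy fell back to visit_Undefined
```

Because the defaulting runs once per compiler subclass at class-creation time, each backend only writes the `visit_*` methods it actually implements; that is the duplication the PR description credits the `__init_subclass__` tip with avoiding.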
@@ -443,17 +469,14 @@ def visit_DefaultLiteral(self, op, *, value, dtype): raise NotImplementedError(f"Unsupported type: {dtype!r}") - @visit_node.register(ops.BitwiseNot) def visit_BitwiseNot(self, op, *, arg): return sge.BitwiseNot(this=arg) ### Mathematical Calisthenics - @visit_node.register(ops.E) def visit_E(self, op): return self.f.exp(1) - @visit_node.register(ops.Log) def visit_Log(self, op, *, arg, base): if base is None: return self.f.ln(arg) @@ -462,7 +485,6 @@ def visit_Log(self, op, *, arg, base): else: return self.f.ln(arg) / self.f.ln(base) - @visit_node.register(ops.Clip) def visit_Clip(self, op, *, arg, lower, upper): if upper is not None: arg = self.if_(arg.is_(NULL), arg, self.f.least(upper, arg)) @@ -472,16 +494,15 @@ def visit_Clip(self, op, *, arg, lower, upper): return arg - @visit_node.register(ops.FloorDivide) def visit_FloorDivide(self, op, *, left, right): return self.cast(self.f.floor(left / right), op.dtype) - @visit_node.register(ops.Ceil) - @visit_node.register(ops.Floor) - def visit_CeilFloor(self, op, *, arg): - return self.cast(self.f[type(op).__name__.lower()](arg), op.dtype) + def visit_Ceil(self, op, *, arg): + return self.cast(self.f.ceil(arg), op.dtype) + + def visit_Floor(self, op, *, arg): + return self.cast(self.f.floor(arg), op.dtype) - @visit_node.register(ops.Round) def visit_Round(self, op, *, arg, digits): if digits is not None: return sge.Round(this=arg, decimals=digits) @@ -489,21 +510,17 @@ def visit_Round(self, op, *, arg, digits): ### Dtype Dysmorphia - @visit_node.register(ops.TryCast) def visit_TryCast(self, op, *, arg, to): return sge.TryCast(this=arg, to=self.type_mapper.from_ibis(to)) ### Comparator Conundrums - @visit_node.register(ops.Between) def visit_Between(self, op, *, arg, lower_bound, upper_bound): return sge.Between(this=arg, low=lower_bound, high=upper_bound) - @visit_node.register(ops.Negate) def visit_Negate(self, op, *, arg): return -paren(arg) - @visit_node.register(ops.Not) def visit_Not(self, op, *, arg): if isinstance(arg, sge.Filter): return sge.Filter(this=sg.not_(arg.this), expression=arg.expression) @@ -511,61 +528,45 @@ def visit_Not(self, op, *, arg): ### Timey McTimeFace - @visit_node.register(ops.Time) def visit_Time(self, op, *, arg): return self.cast(arg, to=dt.time) - @visit_node.register(ops.TimestampNow) def visit_TimestampNow(self, op): return sge.CurrentTimestamp() - @visit_node.register(ops.Strftime) def visit_Strftime(self, op, *, arg, format_str): return sge.TimeToStr(this=arg, format=format_str) - @visit_node.register(ops.ExtractEpochSeconds) def visit_ExtractEpochSeconds(self, op, *, arg): return self.f.epoch(self.cast(arg, dt.timestamp)) - @visit_node.register(ops.ExtractYear) def visit_ExtractYear(self, op, *, arg): return self.f.extract(self.v.year, arg) - @visit_node.register(ops.ExtractMonth) def visit_ExtractMonth(self, op, *, arg): return self.f.extract(self.v.month, arg) - @visit_node.register(ops.ExtractDay) def visit_ExtractDay(self, op, *, arg): return self.f.extract(self.v.day, arg) - @visit_node.register(ops.ExtractDayOfYear) def visit_ExtractDayOfYear(self, op, *, arg): return self.f.extract(self.v.dayofyear, arg) - @visit_node.register(ops.ExtractQuarter) def visit_ExtractQuarter(self, op, *, arg): return self.f.extract(self.v.quarter, arg) - @visit_node.register(ops.ExtractWeekOfYear) def visit_ExtractWeekOfYear(self, op, *, arg): return self.f.extract(self.v.week, arg) - @visit_node.register(ops.ExtractHour) def visit_ExtractHour(self, op, *, arg): return 
self.f.extract(self.v.hour, arg) - @visit_node.register(ops.ExtractMinute) def visit_ExtractMinute(self, op, *, arg): return self.f.extract(self.v.minute, arg) - @visit_node.register(ops.ExtractSecond) def visit_ExtractSecond(self, op, *, arg): return self.f.extract(self.v.second, arg) - @visit_node.register(ops.TimestampTruncate) - @visit_node.register(ops.DateTruncate) - @visit_node.register(ops.TimeTruncate) def visit_TimestampTruncate(self, op, *, arg, unit): unit_mapping = { "Y": "year", @@ -584,11 +585,15 @@ def visit_TimestampTruncate(self, op, *, arg, unit): return self.f.date_trunc(unit, arg) - @visit_node.register(ops.DayOfWeekIndex) + def visit_DateTruncate(self, op, *, arg, unit): + return self.visit_TimestampTruncate(op, arg=arg, unit=unit) + + def visit_TimeTruncate(self, op, *, arg, unit): + return self.visit_TimestampTruncate(op, arg=arg, unit=unit) + def visit_DayOfWeekIndex(self, op, *, arg): return (self.f.dayofweek(arg) + 6) % 7 - @visit_node.register(ops.DayOfWeekName) def visit_DayOfWeekName(self, op, *, arg): # day of week number is 0-indexed # Sunday == 0 @@ -598,25 +603,20 @@ def visit_DayOfWeekName(self, op, *, arg): ifs=list(itertools.starmap(self.if_, enumerate(calendar.day_name))), ) - @visit_node.register(ops.IntervalFromInteger) def visit_IntervalFromInteger(self, op, *, arg, unit): return sge.Interval(this=sge.convert(arg), unit=unit.singular.upper()) ### String Instruments - @visit_node.register(ops.Strip) def visit_Strip(self, op, *, arg): return self.f.trim(arg, string.whitespace) - @visit_node.register(ops.RStrip) def visit_RStrip(self, op, *, arg): return self.f.rtrim(arg, string.whitespace) - @visit_node.register(ops.LStrip) def visit_LStrip(self, op, *, arg): return self.f.ltrim(arg, string.whitespace) - @visit_node.register(ops.Substring) def visit_Substring(self, op, *, arg, start, length): start += 1 arg_length = self.f.length(arg) @@ -633,7 +633,6 @@ def visit_Substring(self, op, *, arg, start, length): self.f.substring(arg, start + arg_length, length), ) - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): if end is not None: raise com.UnsupportedOperationError( @@ -647,72 +646,57 @@ def visit_StringFind(self, op, *, arg, substr, start, end): return self.f.strpos(arg, substr) - @visit_node.register(ops.RegexReplace) def visit_RegexReplace(self, op, *, arg, pattern, replacement): return self.f.regexp_replace(arg, pattern, replacement, "g") - @visit_node.register(ops.StringConcat) def visit_StringConcat(self, op, *, arg): return self.f.concat(*arg) - @visit_node.register(ops.StringJoin) def visit_StringJoin(self, op, *, sep, arg): return self.f.concat_ws(sep, *arg) - @visit_node.register(ops.StringSQLLike) def visit_StringSQLLike(self, op, *, arg, pattern, escape): return arg.like(pattern) - @visit_node.register(ops.StringSQLILike) def visit_StringSQLILike(self, op, *, arg, pattern, escape): return arg.ilike(pattern) ### NULL PLAYER CHARACTER - @visit_node.register(ops.IsNull) def visit_IsNull(self, op, *, arg): return arg.is_(NULL) - @visit_node.register(ops.NotNull) def visit_NotNull(self, op, *, arg): return arg.is_(sg.not_(NULL)) - @visit_node.register(ops.InValues) def visit_InValues(self, op, *, value, options): return value.isin(*options) ### Counting - @visit_node.register(ops.CountDistinct) def visit_CountDistinct(self, op, *, arg, where): return self.agg.count(sge.Distinct(expressions=[arg]), where=where) - @visit_node.register(ops.CountDistinctStar) def visit_CountDistinctStar(self, op, *, 
arg, where): return self.agg.count(sge.Distinct(expressions=[STAR]), where=where) - @visit_node.register(ops.CountStar) def visit_CountStar(self, op, *, arg, where): return self.agg.count(STAR, where=where) - @visit_node.register(ops.Sum) def visit_Sum(self, op, *, arg, where): if op.arg.dtype.is_boolean(): arg = self.cast(arg, dt.int32) return self.agg.sum(arg, where=where) - @visit_node.register(ops.Mean) def visit_Mean(self, op, *, arg, where): if op.arg.dtype.is_boolean(): arg = self.cast(arg, dt.int32) return self.agg.avg(arg, where=where) - @visit_node.register(ops.Min) def visit_Min(self, op, *, arg, where): if op.arg.dtype.is_boolean(): return self.agg.bool_and(arg, where=where) return self.agg.min(arg, where=where) - @visit_node.register(ops.Max) def visit_Max(self, op, *, arg, where): if op.arg.dtype.is_boolean(): return self.agg.bool_or(arg, where=where) @@ -720,8 +704,6 @@ def visit_Max(self, op, *, arg, where): ### Stats - @visit_node.register(ops.Quantile) - @visit_node.register(ops.MultiQuantile) def visit_Quantile(self, op, *, arg, quantile, where): suffix = "cont" if op.arg.dtype.is_numeric() else "disc" funcname = f"percentile_{suffix}" @@ -733,9 +715,8 @@ def visit_Quantile(self, op, *, arg, quantile, where): expr = sge.Filter(this=expr, expression=sge.Where(this=where)) return expr - @visit_node.register(ops.Variance) - @visit_node.register(ops.StandardDev) - @visit_node.register(ops.Covariance) + visit_MultiQuantile = visit_Quantile + def visit_VarianceStandardDevCovariance(self, op, *, how, where, **kw): hows = {"sample": "samp", "pop": "pop"} funcs = { @@ -754,7 +735,10 @@ def visit_VarianceStandardDevCovariance(self, op, *, how, where, **kw): funcname = f"{funcs[type(op)]}_{hows[how]}" return self.agg[funcname](*args, where=where) - @visit_node.register(ops.Arbitrary) + visit_Variance = ( + visit_StandardDev + ) = visit_Covariance = visit_VarianceStandardDevCovariance + def visit_Arbitrary(self, op, *, arg, how, where): if how == "heavy": raise com.UnsupportedOperationError( @@ -762,69 +746,56 @@ def visit_Arbitrary(self, op, *, arg, how, where): ) return self.agg[how](arg, where=where) - @visit_node.register(ops.SimpleCase) - @visit_node.register(ops.SearchedCase) def visit_SimpleCase(self, op, *, base=None, cases, results, default): return sge.Case( this=base, ifs=list(map(self.if_, cases, results)), default=default ) - @visit_node.register(ops.ExistsSubquery) + visit_SearchedCase = visit_SimpleCase + def visit_ExistsSubquery(self, op, *, rel): select = rel.this.select(1, append=False) return self.f.exists(select) - @visit_node.register(ops.InSubquery) def visit_InSubquery(self, op, *, rel, needle): return needle.isin(rel.this) - @visit_node.register(ops.Array) def visit_Array(self, op, *, exprs): return self.f.array(*exprs) - @visit_node.register(ops.StructColumn) def visit_StructColumn(self, op, *, names, values): return sge.Struct.from_arg_list( [value.as_(name, quoted=self.quoted) for name, value in zip(names, values)] ) - @visit_node.register(ops.StructField) def visit_StructField(self, op, *, arg, field): return sge.Dot(this=arg, expression=sg.to_identifier(field, quoted=self.quoted)) - @visit_node.register(ops.IdenticalTo) def visit_IdenticalTo(self, op, *, left, right): return sge.NullSafeEQ(this=left, expression=right) - @visit_node.register(ops.Greatest) def visit_Greatest(self, op, *, arg): return self.f.greatest(*arg) - @visit_node.register(ops.Least) def visit_Least(self, op, *, arg): return self.f.least(*arg) - @visit_node.register(ops.Coalesce) def 
visit_Coalesce(self, op, *, arg): return self.f.coalesce(*arg) ### Ordering and window functions - @visit_node.register(ops.SortKey) def visit_SortKey(self, op, *, expr, ascending: bool): return sge.Ordered(this=expr, desc=not ascending) - @visit_node.register(ops.ApproxMedian) def visit_ApproxMedian(self, op, *, arg, where): return self.agg.approx_quantile(arg, 0.5, where=where) - @visit_node.register(ops.WindowBoundary) def visit_WindowBoundary(self, op, *, value, preceding): # TODO: bit of a hack to return a dict, but there's no sqlglot expression # that corresponds to _only_ this information return {"value": value, "side": "preceding" if preceding else "following"} - @visit_node.register(Window) def visit_Window(self, op, *, how, func, start, end, group_by, order_by): if start is None: start = {} @@ -862,8 +833,6 @@ def visit_Window(self, op, *, how, func, start, end, group_by, order_by): def _minimize_spec(start, end, spec): return spec - @visit_node.register(ops.Lag) - @visit_node.register(ops.Lead) def visit_LagLead(self, op, *, arg, offset, default): args = [arg] @@ -878,11 +847,11 @@ def visit_LagLead(self, op, *, arg, offset, default): return self.f[type(op).__name__.lower()](*args) - @visit_node.register(ops.Argument) + visit_Lag = visit_Lead = visit_LagLead + def visit_Argument(self, op, *, name: str, shape, dtype): return sg.to_identifier(op.param) - @visit_node.register(ops.RowID) def visit_RowID(self, op, *, table): return sg.column(op.name, table=table.alias_or_name, quoted=self.quoted) @@ -903,17 +872,12 @@ def __sql_name__(self, op: ops.ScalarUDF | ops.AggUDF) -> str: self.dialect ) - @visit_node.register(ops.ScalarUDF) def visit_ScalarUDF(self, op, **kw): return self.f[self.__sql_name__(op)](*kw.values()) - @visit_node.register(ops.AggUDF) def visit_AggUDF(self, op, *, where, **kw): return self.agg[self.__sql_name__(op)](*kw.values(), where=where) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.DateDelta) - @visit_node.register(ops.TimestampDelta) def visit_TimestampDelta(self, op, *, part, left, right): # dialect is necessary due to sqlglot's default behavior # of `part` coming last @@ -921,14 +885,14 @@ def visit_TimestampDelta(self, op, *, part, left, right): this=left, expression=right, unit=part, dialect=self.dialect ) - @visit_node.register(ops.TimestampBucket) + visit_TimeDelta = visit_DateDelta = visit_TimestampDelta + def visit_TimestampBucket(self, op, *, arg, interval, offset): origin = self.f.cast("epoch", self.type_mapper.from_ibis(dt.timestamp)) if offset is not None: origin += offset return self.f.time_bucket(interval, arg, origin) - @visit_node.register(ops.ArrayConcat) def visit_ArrayConcat(self, op, *, arg): return sge.ArrayConcat(this=arg[0], expressions=list(arg[1:])) @@ -961,7 +925,6 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]): self._dedup_name, toolz.keymap(self._gen_valid_name, exprs).items() ) - @visit_node.register(Select) def visit_Select(self, op, *, parent, selections, predicates, sort_keys): # if we've constructed a useless projection return the parent relation if not selections and not predicates and not sort_keys: @@ -980,11 +943,9 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys): return result - @visit_node.register(ops.DummyTable) def visit_DummyTable(self, op, *, values): return sg.select(*self._cleanup_names(values)) - @visit_node.register(ops.UnboundTable) def visit_UnboundTable( self, op, *, name: str, schema: sch.Schema, namespace: ops.Namespace ) -> sg.Table: @@ 
-992,13 +953,11 @@ def visit_UnboundTable( name, db=namespace.schema, catalog=namespace.database, quoted=self.quoted ) - @visit_node.register(ops.InMemoryTable) def visit_InMemoryTable( self, op, *, name: str, schema: sch.Schema, data ) -> sg.Table: return sg.table(name, quoted=self.quoted) - @visit_node.register(ops.DatabaseTable) def visit_DatabaseTable( self, op, @@ -1012,11 +971,9 @@ def visit_DatabaseTable( name, db=namespace.schema, catalog=namespace.database, quoted=self.quoted ) - @visit_node.register(ops.SelfReference) def visit_SelfReference(self, op, *, parent, identifier): return parent - @visit_node.register(ops.JoinChain) def visit_JoinChain(self, op, *, first, rest, values): result = sg.select(*self._cleanup_names(values)).from_(first) @@ -1026,7 +983,6 @@ def visit_JoinChain(self, op, *, first, rest, values): result = result.join(link) return result - @visit_node.register(ops.JoinLink) def visit_JoinLink(self, op, *, how, table, predicates): sides = { "inner": None, @@ -1063,7 +1019,6 @@ def visit_JoinLink(self, op, *, how, table, predicates): def _generate_groups(groups): return map(sge.convert, range(1, len(groups) + 1)) - @visit_node.register(ops.Aggregate) def visit_Aggregate(self, op, *, parent, groups, metrics): sel = sg.select( *self._cleanup_names(groups), *self._cleanup_names(metrics) @@ -1079,7 +1034,6 @@ def _add_parens(self, op, sg_expr): return paren(sg_expr) return sg_expr - @visit_node.register(ops.Filter) def visit_Filter(self, op, *, parent, predicates): predicates = ( self._add_parens(raw_predicate, predicate) @@ -1090,14 +1044,12 @@ def visit_Filter(self, op, *, parent, predicates): except AttributeError: return sg.select(STAR).from_(parent).where(*predicates) - @visit_node.register(ops.Sort) def visit_Sort(self, op, *, parent, keys): try: return parent.order_by(*keys) except AttributeError: return sg.select(STAR).from_(parent).order_by(*keys) - @visit_node.register(ops.Union) def visit_Union(self, op, *, left, right, distinct): if isinstance(left, sge.Table): left = sg.select(STAR).from_(left) @@ -1111,7 +1063,6 @@ def visit_Union(self, op, *, left, right, distinct): distinct=distinct, ) - @visit_node.register(ops.Intersection) def visit_Intersection(self, op, *, left, right, distinct): if isinstance(left, sge.Table): left = sg.select(STAR).from_(left) @@ -1125,7 +1076,6 @@ def visit_Intersection(self, op, *, left, right, distinct): distinct=distinct, ) - @visit_node.register(ops.Difference) def visit_Difference(self, op, *, left, right, distinct): if isinstance(left, sge.Table): left = sg.select(STAR).from_(left) @@ -1139,7 +1089,6 @@ def visit_Difference(self, op, *, left, right, distinct): distinct=distinct, ) - @visit_node.register(ops.Limit) def visit_Limit(self, op, *, parent, n, offset): # push limit/offset into subqueries if isinstance(parent, sge.Subquery) and parent.this.args.get("limit") is None: @@ -1175,11 +1124,9 @@ def visit_Limit(self, op, *, parent, n, offset): return result.subquery(alias) return result - @visit_node.register(ops.Distinct) def visit_Distinct(self, op, *, parent): return sg.select(STAR).distinct().from_(parent) - @visit_node.register(ops.DropNa) def visit_DropNa(self, op, *, parent, how, subset): if subset is None: subset = [ @@ -1205,7 +1152,6 @@ def visit_DropNa(self, op, *, parent, how, subset): except AttributeError: return sg.select(STAR).from_(parent).where(predicate) - @visit_node.register(ops.FillNa) def visit_FillNa(self, op, *, parent, replacements): if isinstance(replacements, Mapping): mapping = replacements 
@@ -1225,11 +1171,9 @@ def visit_FillNa(self, op, *, parent, replacements): } return sg.select(*self._cleanup_names(exprs)).from_(parent) - @visit_node.register(CTE) def visit_CTE(self, op, *, parent): return sg.table(parent.alias_or_name, quoted=self.quoted) - @visit_node.register(ops.View) def visit_View(self, op, *, child, name: str): if isinstance(child, sge.Table): child = sg.select(STAR).from_(child) @@ -1239,26 +1183,28 @@ def visit_View(self, op, *, child, name: str): except AttributeError: return child.as_(name) - @visit_node.register(ops.SQLStringView) def visit_SQLStringView(self, op, *, query: str, child, schema): return sg.parse_one(query, read=self.dialect) - @visit_node.register(ops.SQLQueryResult) def visit_SQLQueryResult(self, op, *, query, schema, source): return sg.parse_one(query, dialect=self.dialect).subquery() - @visit_node.register(ops.JoinTable) def visit_JoinTable(self, op, *, parent, index): return parent - @visit_node.register(ops.Value) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - - @visit_node.register(ops.RegexExtract) def visit_RegexExtract(self, op, *, arg, pattern, index): return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect) + def visit_Undefined(self, op, **_): + raise com.OperationNotDefinedError( + f"Compilation rule for {type(op).__name__!r} operation is not defined" + ) + + def visit_Unsupported(self, op, **_): + raise com.UnsupportedOperationError( + f"{type(op).__name__!r} operation is not supported in the {self.dialect} backend" + ) + _SIMPLE_OPS = { ops.Abs: "abs", @@ -1376,7 +1322,6 @@ def visit_RegexExtract(self, op, *, arg, pattern, index): for _op, _sym in _BINARY_INFIX_OPS.items(): - @SQLGlotCompiler.visit_node.register(_op) def _fmt(self, op, *, _sym: sge.Expression = _sym, left, right): return _sym( this=self._add_parens(op.left, left), @@ -1393,14 +1338,12 @@ def _fmt(self, op, *, _sym: sge.Expression = _sym, left, right): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @SQLGlotCompiler.visit_node.register(_op) - def _fmt(self, op, *, _name: str = _name, where, **kw): + def _fmt(self, _, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @SQLGlotCompiler.visit_node.register(_op) - def _fmt(self, op, *, _name: str = _name, **kw): + def _fmt(self, _, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) setattr(SQLGlotCompiler, f"visit_{_op.__name__}", _fmt) diff --git a/ibis/backends/bigquery/compiler.py b/ibis/backends/bigquery/compiler.py index ada8ddf901c9..58e019adc789 100644 --- a/ibis/backends/bigquery/compiler.py +++ b/ibis/backends/bigquery/compiler.py @@ -3,7 +3,6 @@ from __future__ import annotations import re -from functools import singledispatchmethod import sqlglot as sg import sqlglot.expressions as sge @@ -42,6 +41,29 @@ class BigQueryCompiler(SQLGlotCompiler): *SQLGlotCompiler.rewrites, ) + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.CountDistinctStar, + ops.DateDiff, + ops.ExtractAuthority, + ops.ExtractFile, + ops.ExtractFragment, + ops.ExtractHost, + ops.ExtractPath, + ops.ExtractProtocol, + ops.ExtractQuery, + ops.ExtractUserInfo, + ops.FindInSet, + ops.Median, + ops.Quantile, + ops.MultiQuantile, + ops.RegexSplit, + ops.RowID, + ops.TimestampBucket, + ops.TimestampDiff, + ) + ) + NAN = sge.Cast( this=sge.convert("NaN"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) ) @@ -71,21 +93,14 @@ def _minimize_spec(start, end, spec): return None return spec - 
@singledispatchmethod - def visit_node(self, op, **kw): - return super().visit_node(op, **kw) - - @visit_node.register(ops.GeoXMax) - @visit_node.register(ops.GeoXMin) - @visit_node.register(ops.GeoYMax) - @visit_node.register(ops.GeoYMin) def visit_BoundingBox(self, op, *, arg): name = type(op).__name__[len("Geo") :].lower() return sge.Dot( this=self.f.st_boundingbox(arg), expression=sg.to_identifier(name) ) - @visit_node.register(ops.GeoSimplify) + visit_GeoXMax = visit_GeoXMin = visit_GeoYMax = visit_GeoYMin = visit_BoundingBox + def visit_GeoSimplify(self, op, *, arg, tolerance, preserve_collapsed): if ( not isinstance(op.preserve_collapsed, ops.Literal) @@ -97,27 +112,21 @@ def visit_GeoSimplify(self, op, *, arg, tolerance, preserve_collapsed): ) return self.f.st_simplify(arg, tolerance) - @visit_node.register(ops.ApproxMedian) def visit_ApproxMedian(self, op, *, arg, where): return self.agg.approx_quantiles(arg, 2, where=where)[self.f.offset(1)] - @visit_node.register(ops.Pi) def visit_Pi(self, op): return self.f.acos(-1) - @visit_node.register(ops.E) def visit_E(self, op): return self.f.exp(1) - @visit_node.register(ops.TimeDelta) def visit_TimeDelta(self, op, *, left, right, part): return self.f.time_diff(left, right, part, dialect=self.dialect) - @visit_node.register(ops.DateDelta) def visit_DateDelta(self, op, *, left, right, part): return self.f.date_diff(left, right, part, dialect=self.dialect) - @visit_node.register(ops.TimestampDelta) def visit_TimestampDelta(self, op, *, left, right, part): left_tz = op.left.dtype.timezone right_tz = op.right.dtype.timezone @@ -131,27 +140,22 @@ def visit_TimestampDelta(self, op, *, left, right, part): "timestamp difference with mixed timezone/timezoneless values is not implemented" ) - @visit_node.register(ops.GroupConcat) def visit_GroupConcat(self, op, *, arg, sep, where): if where is not None: arg = self.if_(where, arg, NULL) return self.f.string_agg(arg, sep) - @visit_node.register(ops.FloorDivide) def visit_FloorDivide(self, op, *, left, right): return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype) - @visit_node.register(ops.Log2) def visit_Log2(self, op, *, arg): return self.f.log(arg, 2, dialect=self.dialect) - @visit_node.register(ops.Log) def visit_Log(self, op, *, arg, base): if base is None: return self.f.ln(arg) return self.f.log(arg, base, dialect=self.dialect) - @visit_node.register(ops.ArrayRepeat) def visit_ArrayRepeat(self, op, *, arg, times): start = step = 1 array_length = self.f.array_length(arg) @@ -165,13 +169,11 @@ def visit_ArrayRepeat(self, op, *, arg, times): sg.select(arg[self.f.safe_ordinal(idx)]).from_(self._unnest(series, as_=i)) ) - @visit_node.register(ops.Capitalize) def visit_Capitalize(self, op, *, arg): return self.f.concat( self.f.upper(self.f.substr(arg, 1, 1)), self.f.lower(self.f.substr(arg, 2)) ) - @visit_node.register(ops.NthValue) def visit_NthValue(self, op, *, arg, nth): if not isinstance(op.nth, ops.Literal): raise com.UnsupportedOperationError( @@ -179,33 +181,23 @@ def visit_NthValue(self, op, *, arg, nth): ) return self.f.nth_value(arg, nth) - @visit_node.register(ops.StrRight) def visit_StrRight(self, op, *, arg, nchars): return self.f.substr(arg, -self.f.least(self.f.length(arg), nchars)) - @visit_node.register(ops.StringJoin) def visit_StringJoin(self, op, *, arg, sep): return self.f.array_to_string(self.f.array(*arg), sep) - @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): return self.f.mod(self.f.extract(self.v.dayofweek, arg) + 5, 
7) - @visit_node.register(ops.DayOfWeekName) def visit_DayOfWeekName(self, op, *, arg): return self.f.initcap(sge.Cast(this=arg, to="STRING FORMAT 'DAY'")) - @visit_node.register(ops.StringToTimestamp) def visit_StringToTimestamp(self, op, *, arg, format_str): if (timezone := op.dtype.timezone) is not None: return self.f.parse_timestamp(format_str, arg, timezone) return self.f.parse_datetime(format_str, arg) - @visit_node.register(ops.Floor) - def visit_Floor(self, op, *, arg): - return self.cast(self.f.floor(arg), op.dtype) - - @visit_node.register(ops.ArrayCollect) def visit_ArrayCollect(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, NULL) @@ -214,7 +206,6 @@ def visit_ArrayCollect(self, op, *, arg, where): def _neg_idx_to_pos(self, arg, idx): return self.if_(idx < 0, self.f.array_length(arg) + idx, idx) - @visit_node.register(ops.ArraySlice) def visit_ArraySlice(self, op, *, arg, start, stop): index = sg.to_identifier("bq_arr_slice") cond = [index >= self._neg_idx_to_pos(arg, start)] @@ -227,11 +218,9 @@ def visit_ArraySlice(self, op, *, arg, start, stop): sg.select(el).from_(self._unnest(arg, as_=el, offset=index)).where(*cond) ) - @visit_node.register(ops.ArrayIndex) def visit_ArrayIndex(self, op, *, arg, index): return arg[self.f.safe_offset(index)] - @visit_node.register(ops.ArrayContains) def visit_ArrayContains(self, op, *, arg, other): name = sg.to_identifier(util.gen_name("bq_arr_contains")) return sge.Exists( @@ -240,11 +229,9 @@ def visit_ArrayContains(self, op, *, arg, other): .where(name.eq(other)) ) - @visit_node.register(ops.StringContains) def visit_StringContains(self, op, *, haystack, needle): return self.f.strpos(haystack, needle) > 0 - @visit_node.register(ops.StringFind) def visti_StringFind(self, op, *, arg, substr, start, end): if start is not None: raise NotImplementedError( @@ -294,7 +281,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype): return sge.convert(str(value)) return None - @visit_node.register(ops.IntervalFromInteger) def visit_IntervalFromInteger(self, op, *, arg, unit): if unit == IntervalUnit.NANOSECOND: raise com.UnsupportedOperationError( @@ -302,7 +288,6 @@ def visit_IntervalFromInteger(self, op, *, arg, unit): ) return sge.Interval(this=arg, unit=self.v[unit.singular]) - @visit_node.register(ops.Strftime) def visit_Strftime(self, op, *, arg, format_str): arg_dtype = op.arg.dtype if arg_dtype.is_timestamp(): @@ -316,12 +301,10 @@ def visit_Strftime(self, op, *, arg, format_str): assert arg_dtype.is_time(), arg_dtype return self.f.format_time(format_str, arg) - @visit_node.register(ops.IntervalMultiply) def visit_IntervalMultiply(self, op, *, left, right): unit = self.v[op.left.dtype.resolution.upper()] return sge.Interval(this=self.f.extract(unit, left) * right, unit=unit) - @visit_node.register(ops.TimestampFromUNIX) def visit_TimestampFromUNIX(self, op, *, arg, unit): unit = op.unit if unit == TimestampUnit.SECOND: @@ -337,7 +320,6 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): else: raise com.UnsupportedOperationError(f"Unit not supported: {unit}") - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): from_ = op.arg.dtype if from_.is_timestamp() and to.is_integer(): @@ -360,34 +342,22 @@ def visit_Cast(self, op, *, arg, to): return self.cast(self.f.trunc(arg), dt.int64) return super().visit_Cast(op, arg=arg, to=to) - @visit_node.register(ops.JSONGetItem) def visit_JSONGetItem(self, op, *, arg, index): return arg[index] - @visit_node.register(ops.ExtractEpochSeconds) def 
visit_ExtractEpochSeconds(self, op, *, arg): return self.f.unix_seconds(arg) - @visit_node.register(ops.ExtractWeekOfYear) def visit_ExtractWeekOfYear(self, op, *, arg): return self.f.extract(self.v.isoweek, arg) - @visit_node.register(ops.ExtractYear) - @visit_node.register(ops.ExtractQuarter) - @visit_node.register(ops.ExtractMonth) - @visit_node.register(ops.ExtractDay) - @visit_node.register(ops.ExtractDayOfYear) - @visit_node.register(ops.ExtractHour) - @visit_node.register(ops.ExtractMinute) - @visit_node.register(ops.ExtractSecond) - @visit_node.register(ops.ExtractMicrosecond) - @visit_node.register(ops.ExtractMillisecond) - def visit_ExtractDateField(self, op, *, arg): - name = type(op).__name__[len("Extract") :].upper() - return self.f.extract(self.v[name], arg) - - @visit_node.register(ops.TimestampTruncate) - def visit_Timestamp(self, op, *, arg, unit): + def visit_ExtractMillisecond(self, op, *, arg): + return self.f.extract(self.v.millisecond, arg) + + def visit_ExtractMicrosecond(self, op, *, arg): + return self.f.extract(self.v.microsecond, arg) + + def visit_TimestampTruncate(self, op, *, arg, unit): if unit == IntervalUnit.NANOSECOND: raise com.UnsupportedOperationError( f"BigQuery does not support truncating {op.arg.dtype} values to unit {unit!r}" @@ -398,7 +368,6 @@ def visit_Timestamp(self, op, *, arg, unit): unit = unit.name return self.f.timestamp_trunc(arg, self.v[unit], dialect=self.dialect) - @visit_node.register(ops.DateTruncate) def visit_DateTruncate(self, op, *, arg, unit): if unit == DateUnit.WEEK: unit = "WEEK(MONDAY)" @@ -406,7 +375,6 @@ def visit_DateTruncate(self, op, *, arg, unit): unit = unit.name return self.f.date_trunc(arg, self.v[unit], dialect=self.dialect) - @visit_node.register(ops.TimeTruncate) def visit_TimeTruncate(self, op, *, arg, unit): if unit == TimeUnit.NANOSECOND: raise com.UnsupportedOperationError( @@ -454,11 +422,9 @@ def _make_range(self, func, start, stop, step, step_dtype): ) return self.if_(condition, self.f.array(inner), self.f.array()) - @visit_node.register(ops.IntegerRange) def visit_IntegerRange(self, op, *, start, stop, step): return self._make_range(self.f.generate_array, start, stop, step, op.step.dtype) - @visit_node.register(ops.TimestampRange) def visit_TimestampRange(self, op, *, start, stop, step): if op.start.dtype.timezone is None or op.stop.dtype.timezone is None: raise com.IbisTypeError( @@ -468,7 +434,6 @@ def visit_TimestampRange(self, op, *, start, stop, step): self.f.generate_timestamp_array, start, stop, step, op.step.dtype ) - @visit_node.register(ops.First) def visit_First(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, NULL) @@ -477,15 +442,13 @@ def visit_First(self, op, *, arg, where): ) return array[self.f.safe_offset(0)] - @visit_node.register(ops.Last) def visit_Last(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, NULL) array = self.f.array_reverse(self.f.array_agg(sge.IgnoreNulls(this=arg))) return array[self.f.safe_offset(0)] - @visit_node.register(ops.Arbitrary) - def _arbitrary(self, op, *, arg, how, where): + def visit_Arbitrary(self, op, *, arg, how, where): if how != "first": raise com.UnsupportedOperationError( f"{how!r} value not supported for arbitrary in BigQuery" @@ -493,17 +456,14 @@ def _arbitrary(self, op, *, arg, how, where): return self.agg.any_value(arg, where=where) - @visit_node.register(ops.ArrayFilter) def visit_ArrayFilter(self, op, *, arg, body, param): return self.f.array( sg.select(param).from_(self._unnest(arg, 
as_=param)).where(body) ) - @visit_node.register(ops.ArrayMap) def visit_ArrayMap(self, op, *, arg, body, param): return self.f.array(sg.select(body).from_(self._unnest(arg, as_=param))) - @visit_node.register(ops.ArrayZip) def visit_ArrayZip(self, op, *, arg): lengths = [self.f.array_length(arr) - 1 for arr in arg] idx = sg.to_identifier(util.gen_name("bq_arr_idx")) @@ -518,7 +478,6 @@ def visit_ArrayZip(self, op, *, arg): sge.Select(kind="STRUCT", expressions=struct_fields).from_(indices) ) - @visit_node.register(ops.ArrayPosition) def visit_ArrayPosition(self, op, *, arg, other): name = sg.to_identifier(util.gen_name("bq_arr")) idx = sg.to_identifier(util.gen_name("bq_arr_idx")) @@ -532,27 +491,23 @@ def _unnest(self, expression, *, as_, offset=None): alias = sge.TableAlias(columns=[sg.to_identifier(as_)]) return sge.Unnest(expressions=[expression], alias=alias, offset=offset) - @visit_node.register(ops.ArrayRemove) def visit_ArrayRemove(self, op, *, arg, other): name = sg.to_identifier(util.gen_name("bq_arr")) unnest = self._unnest(arg, as_=name) return self.f.array(sg.select(name).from_(unnest).where(name.neq(other))) - @visit_node.register(ops.ArrayDistinct) def visit_ArrayDistinct(self, op, *, arg): name = util.gen_name("bq_arr") return self.f.array( sg.select(name).distinct().from_(self._unnest(arg, as_=name)) ) - @visit_node.register(ops.ArraySort) def visit_ArraySort(self, op, *, arg): name = util.gen_name("bq_arr") return self.f.array( sg.select(name).from_(self._unnest(arg, as_=name)).order_by(name) ) - @visit_node.register(ops.ArrayUnion) def visit_ArrayUnion(self, op, *, left, right): lname = util.gen_name("bq_arr_left") rname = util.gen_name("bq_arr_right") @@ -560,7 +515,6 @@ def visit_ArrayUnion(self, op, *, left, right): rhs = sg.select(rname).from_(self._unnest(right, as_=rname)) return self.f.array(sg.union(lhs, rhs, distinct=True)) - @visit_node.register(ops.ArrayIntersect) def visit_ArrayIntersect(self, op, *, left, right): lname = util.gen_name("bq_arr_left") rname = util.gen_name("bq_arr_right") @@ -568,7 +522,6 @@ def visit_ArrayIntersect(self, op, *, left, right): rhs = sg.select(rname).from_(self._unnest(right, as_=rname)) return self.f.array(sg.intersect(lhs, rhs, distinct=True)) - @visit_node.register(ops.Substring) def visit_Substring(self, op, *, arg, start, length): if isinstance(op.length, ops.Literal) and (value := op.length.value) < 0: raise com.IbisInputError( @@ -579,7 +532,6 @@ def visit_Substring(self, op, *, arg, start, length): if_neg = self.f.substr(arg, self.f.length(arg) + start + 1, *suffix) return self.if_(start >= 0, if_pos, if_neg) - @visit_node.register(ops.RegexExtract) def visit_RegexExtract(self, op, *, arg, pattern, index): matches = self.f.regexp_contains(arg, pattern) nonzero_index_replace = self.f.regexp_replace( @@ -593,8 +545,6 @@ def visit_RegexExtract(self, op, *, arg, pattern, index): extract = self.if_(index.eq(0), zero_index_replace, nonzero_index_replace) return self.if_(matches, extract, NULL) - @visit_node.register(ops.TimestampAdd) - @visit_node.register(ops.TimestampSub) def visit_TimestampAddSub(self, op, *, left, right): if not isinstance(right, sge.Interval): raise com.OperationNotDefinedError( @@ -610,8 +560,8 @@ def visit_TimestampAddSub(self, op, *, left, right): funcname = f"TIMESTAMP_{opname.upper()}" return self.f.anon[funcname](left, right) - @visit_node.register(ops.DateAdd) - @visit_node.register(ops.DateSub) + visit_TimestampAdd = visit_TimestampSub = visit_TimestampAddSub + def visit_DateAddSub(self, op, *, 
left, right): if not isinstance(right, sge.Interval): raise com.OperationNotDefinedError( @@ -626,7 +576,8 @@ def visit_DateAddSub(self, op, *, left, right): funcname = f"DATE_{opname.upper()}" return self.f.anon[funcname](left, right) - @visit_node.register(ops.Covariance) + visit_DateAdd = visit_DateSub = visit_DateAddSub + def visit_Covariance(self, op, *, left, right, how, where): if where is not None: left = self.if_(where, left, NULL) @@ -642,7 +593,6 @@ def visit_Covariance(self, op, *, left, right, how, where): assert how in ("POP", "SAMP"), 'how not in ("POP", "SAMP")' return self.agg[f"COVAR_{how}"](left, right, where=where) - @visit_node.register(ops.Correlation) def visit_Correlation(self, op, *, left, right, how, where): if how == "sample": raise ValueError(f"Correlation with how={how!r} is not supported.") @@ -659,7 +609,6 @@ def visit_Correlation(self, op, *, left, right, how, where): return self.agg.corr(left, right, where=where) - @visit_node.register(ops.TypeOf) def visit_TypeOf(self, op, *, arg): name = sg.to_identifier(util.gen_name("bq_typeof")) from_ = self._unnest(self.f.array(self.f.format("%T", arg)), as_=name) @@ -689,11 +638,9 @@ def visit_TypeOf(self, op, *, arg): case = sge.Case(ifs=ifs, default=sge.convert("UNKNOWN")) return sg.select(case).from_(from_).subquery() - @visit_node.register(ops.Xor) def visit_Xor(self, op, *, left, right): return sg.or_(sg.and_(left, sg.not_(right)), sg.and_(sg.not_(left), right)) - @visit_node.register(ops.HashBytes) def visit_HashBytes(self, op, *, arg, how): if how not in ("md5", "sha1", "sha256", "sha512"): raise NotImplementedError(how) @@ -703,47 +650,22 @@ def visit_HashBytes(self, op, *, arg, how): def _gen_valid_name(name: str) -> str: return "_".join(_NAME_REGEX.findall(name)) or "tmp" - @visit_node.register(ops.CountStar) def visit_CountStar(self, op, *, arg, where): if where is not None: return self.f.countif(where) return self.f.count(STAR) - @visit_node.register(ops.Degrees) def visit_Degrees(self, op, *, arg): return paren(180 * arg / self.f.acos(-1)) - @visit_node.register(ops.Radians) def visit_Radians(self, op, *, arg): return paren(self.f.acos(-1) * arg / 180) - @visit_node.register(ops.CountDistinct) def visit_CountDistinct(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, NULL) return self.f.count(sge.Distinct(expressions=[arg])) - @visit_node.register(ops.CountDistinctStar) - @visit_node.register(ops.DateDiff) - @visit_node.register(ops.ExtractAuthority) - @visit_node.register(ops.ExtractFile) - @visit_node.register(ops.ExtractFragment) - @visit_node.register(ops.ExtractHost) - @visit_node.register(ops.ExtractPath) - @visit_node.register(ops.ExtractProtocol) - @visit_node.register(ops.ExtractQuery) - @visit_node.register(ops.ExtractUserInfo) - @visit_node.register(ops.FindInSet) - @visit_node.register(ops.Median) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.RegexSplit) - @visit_node.register(ops.RowID) - @visit_node.register(ops.TimestampBucket) - @visit_node.register(ops.TimestampDiff) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - _SIMPLE_OPS = { ops.StringAscii: "ascii", @@ -806,13 +728,11 @@ def visit_Undefined(self, op, **_): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @BigQueryCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - 
@BigQueryCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/datetime/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/datetime/out.sql index bfefa9d352a9..ab23dad0560f 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/datetime/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/datetime/out.sql @@ -1,2 +1,2 @@ SELECT - EXTRACT(HOUR FROM datetime('2017-01-01T04:55:59')) AS `tmp` \ No newline at end of file + EXTRACT(hour FROM datetime('2017-01-01T04:55:59')) AS `tmp` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/string_time/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/string_time/out.sql index b3acf7e814fc..dcbea35ecd83 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/string_time/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/string_time/out.sql @@ -1,2 +1,2 @@ SELECT - EXTRACT(HOUR FROM time(4, 55, 59)) AS `tmp` \ No newline at end of file + EXTRACT(hour FROM time(4, 55, 59)) AS `tmp` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/string_timestamp/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/string_timestamp/out.sql index bfefa9d352a9..ab23dad0560f 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/string_timestamp/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/string_timestamp/out.sql @@ -1,2 +1,2 @@ SELECT - EXTRACT(HOUR FROM datetime('2017-01-01T04:55:59')) AS `tmp` \ No newline at end of file + EXTRACT(hour FROM datetime('2017-01-01T04:55:59')) AS `tmp` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/time/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/time/out.sql index b3acf7e814fc..dcbea35ecd83 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/time/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/time/out.sql @@ -1,2 +1,2 @@ SELECT - EXTRACT(HOUR FROM time(4, 55, 59)) AS `tmp` \ No newline at end of file + EXTRACT(hour FROM time(4, 55, 59)) AS `tmp` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/timestamp/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/timestamp/out.sql index bfefa9d352a9..ab23dad0560f 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/timestamp/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_timestamp_or_time/timestamp/out.sql @@ -1,2 +1,2 @@ SELECT - EXTRACT(HOUR FROM datetime('2017-01-01T04:55:59')) AS `tmp` \ No newline at end of file + EXTRACT(hour FROM datetime('2017-01-01T04:55:59')) AS `tmp` \ No newline at 
end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/date/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/date/out.sql index c4e62dab9bdf..df8033bc163e 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/date/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/date/out.sql @@ -1,2 +1,2 @@ SELECT - EXTRACT(YEAR FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_ 1_ 1` \ No newline at end of file + EXTRACT(year FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_ 1_ 1` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/datetime/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/datetime/out.sql index c1f82282802c..afa341282049 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/datetime/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/datetime/out.sql @@ -1,2 +1,2 @@ SELECT - EXTRACT(YEAR FROM datetime('2017-01-01T04:55:59')) AS `ExtractYear_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59` \ No newline at end of file + EXTRACT(year FROM datetime('2017-01-01T04:55:59')) AS `ExtractYear_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/string_date/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/string_date/out.sql index c4e62dab9bdf..df8033bc163e 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/string_date/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/string_date/out.sql @@ -1,2 +1,2 @@ SELECT - EXTRACT(YEAR FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_ 1_ 1` \ No newline at end of file + EXTRACT(year FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_ 1_ 1` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/string_timestamp/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/string_timestamp/out.sql index c1f82282802c..afa341282049 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/string_timestamp/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/string_timestamp/out.sql @@ -1,2 +1,2 @@ SELECT - EXTRACT(YEAR FROM datetime('2017-01-01T04:55:59')) AS `ExtractYear_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59` \ No newline at end of file + EXTRACT(year FROM datetime('2017-01-01T04:55:59')) AS `ExtractYear_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/timestamp/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/timestamp/out.sql index c1f82282802c..afa341282049 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/timestamp/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/timestamp/out.sql @@ -1,2 +1,2 @@ SELECT - EXTRACT(YEAR FROM datetime('2017-01-01T04:55:59')) AS `ExtractYear_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59` \ No newline at end of file + EXTRACT(year FROM datetime('2017-01-01T04:55:59')) AS 
`ExtractYear_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/timestamp_date/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/timestamp_date/out.sql index c4e62dab9bdf..df8033bc163e 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/timestamp_date/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_literal_year/timestamp_date/out.sql @@ -1,2 +1,2 @@ SELECT - EXTRACT(YEAR FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_ 1_ 1` \ No newline at end of file + EXTRACT(year FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_ 1_ 1` \ No newline at end of file diff --git a/ibis/backends/clickhouse/compiler.py b/ibis/backends/clickhouse/compiler.py index 78dd25b012d2..bb5311a8815b 100644 --- a/ibis/backends/clickhouse/compiler.py +++ b/ibis/backends/clickhouse/compiler.py @@ -2,7 +2,6 @@ import calendar import math -from functools import singledispatchmethod from typing import Any import sqlglot as sg @@ -30,6 +29,18 @@ class ClickHouseCompiler(SQLGlotCompiler): type_mapper = ClickHouseType rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.RowID, + ops.CumeDist, + ops.PercentRank, + ops.Time, + ops.TimeDelta, + ops.StringToTimestamp, + ops.Levenshtein, + ) + ) + def _aggregate(self, funcname: str, *args, where): has_filter = where is not None func = self.f[funcname + "If" * has_filter] @@ -48,11 +59,6 @@ def _minimize_spec(start, end, spec): return None return spec - @singledispatchmethod - def visit_node(self, op, **kw): - return super().visit_node(op, **kw) - - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): _interval_cast_suffixes = { "s": "Second", @@ -74,21 +80,17 @@ def visit_Cast(self, op, *, arg, to): return self.f.toTimeZone(result, timezone) return result - @visit_node.register(ops.TryCast) def visit_TryCast(self, op, *, arg, to): return self.f.accurateCastOrNull(arg, self.type_mapper.to_string(to)) - @visit_node.register(ops.ArrayIndex) def visit_ArrayIndex(self, op, *, arg, index): return arg[self.if_(index >= 0, index + 1, index)] - @visit_node.register(ops.ArrayRepeat) def visit_ArrayRepeat(self, op, *, arg, times): param = sg.to_identifier("_") func = sge.Lambda(this=arg, expressions=[param]) return self.f.arrayFlatten(self.f.arrayMap(func, self.f.range(times))) - @visit_node.register(ops.ArraySlice) def visit_ArraySlice(self, op, *, arg, start, stop): start = parenthesize(op.start, start) start_correct = self.if_(start < 0, start, start + 1) @@ -109,15 +111,12 @@ def visit_ArraySlice(self, op, *, arg, start, stop): else: return self.f.arraySlice(arg, start_correct) - @visit_node.register(ops.CountStar) def visit_CountStar(self, op, *, where, arg): if where is not None: return self.f.countIf(where) return sge.Count(this=STAR) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.MultiQuantile) - def visit_QuantileMultiQuantile(self, op, *, arg, quantile, where): + def visit_Quantile(self, op, *, arg, quantile, where): if where is None: return self.agg.quantile(arg, quantile, where=where) @@ -128,7 +127,8 @@ def visit_QuantileMultiQuantile(self, op, *, arg, quantile, where): params=[arg, where], ) - @visit_node.register(ops.Correlation) + visit_MultiQuantile = visit_Quantile + def visit_Correlation(self, op, *, left, right, how, where): if how == "pop": raise ValueError( 
@@ -136,7 +136,6 @@ def visit_Correlation(self, op, *, left, right, how, where): ) return self.agg.corr(left, right, where=where) - @visit_node.register(ops.Arbitrary) def visit_Arbitrary(self, op, *, arg, how, where): if how == "first": return self.agg.any(arg, where=where) @@ -146,7 +145,6 @@ def visit_Arbitrary(self, op, *, arg, how, where): assert how == "heavy" return self.agg.anyHeavy(arg, where=where) - @visit_node.register(ops.Substring) def visit_Substring(self, op, *, arg, start, length): # Clickhouse is 1-indexed suffix = (length,) * (length is not None) @@ -154,7 +152,6 @@ def visit_Substring(self, op, *, arg, start, length): if_neg = self.f.substring(arg, self.f.length(arg) + start + 1, *suffix) return self.if_(start >= 0, if_pos, if_neg) - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): if end is not None: raise com.UnsupportedOperationError( @@ -166,11 +163,9 @@ def visit_StringFind(self, op, *, arg, substr, start, end): return self.f.locate(arg, substr) - @visit_node.register(ops.RegexSearch) def visit_RegexSearch(self, op, *, arg, pattern): return sge.RegexpLike(this=arg, expression=pattern) - @visit_node.register(ops.RegexExtract) def visit_RegexExtract(self, op, *, arg, pattern, index): arg = self.cast(arg, dt.String(nullable=False)) @@ -185,20 +180,16 @@ def visit_RegexExtract(self, op, *, arg, pattern, index): return self.if_(self.f.notEmpty(then), then, NULL) - @visit_node.register(ops.FindInSet) def visit_FindInSet(self, op, *, needle, values): return self.f.indexOf(self.f.array(*values), needle) - @visit_node.register(ops.Sign) def visit_Sign(self, op, *, arg): """Workaround for missing sign function in older versions of clickhouse.""" return self.f.intDivOrZero(arg, self.f.abs(arg)) - @visit_node.register(ops.Hash) def visit_Hash(self, op, *, arg): return self.f.sipHash64(arg) - @visit_node.register(ops.HashBytes) def visit_HashBytes(self, op, *, arg, how): supported_algorithms = { "md5": "MD5", @@ -221,14 +212,13 @@ def visit_HashBytes(self, op, *, arg, how): return self.f[funcname](arg) - @visit_node.register(ops.IntervalFromInteger) def visit_IntervalFromInteger(self, op, *, arg, unit): dtype = op.dtype if dtype.unit.short in ("ms", "us", "ns"): raise com.UnsupportedOperationError( "Clickhouse doesn't support subsecond interval resolutions" ) - return super().visit_node(op, arg=arg, unit=unit) + return super().visit_IntervalFromInteger(op, arg=arg, unit=unit) def visit_NonNullLiteral(self, op, *, value, dtype): if dtype.is_inet(): @@ -322,16 +312,12 @@ def visit_NonNullLiteral(self, op, *, value, dtype): else: return None - @visit_node.register(ops.TimestampFromUNIX) def visit_TimestampFromUNIX(self, op, *, arg, unit): if (unit := unit.short) in {"ms", "us", "ns"}: raise com.UnsupportedOperationError(f"{unit!r} unit is not supported!") return self.f.toDateTime(arg) - @visit_node.register(ops.DateTruncate) - @visit_node.register(ops.TimestampTruncate) - @visit_node.register(ops.TimeTruncate) - def visit_TimeTruncate(self, op, *, arg, unit): + def visit_TimestampTruncate(self, op, *, arg, unit): converters = { "Y": "toStartOfYear", "M": "toStartOfMonth", @@ -348,7 +334,8 @@ def visit_TimeTruncate(self, op, *, arg, unit): return self.f[converter](arg) - @visit_node.register(ops.TimestampBucket) + visit_TimeTruncate = visit_DateTruncate = visit_TimestampTruncate + def visit_TimestampBucket(self, op, *, arg, interval, offset): if offset is not None: raise com.UnsupportedOperationError( @@ -357,7 +344,6 @@ def 
visit_TimestampBucket(self, op, *, arg, interval, offset): return self.f.toStartOfInterval(arg, interval) - @visit_node.register(ops.DateFromYMD) def visit_DateFromYMD(self, op, *, year, month, day): return self.f.toDate( self.f.concat( @@ -369,7 +355,6 @@ def visit_DateFromYMD(self, op, *, year, month, day): ) ) - @visit_node.register(ops.TimestampFromYMDHMS) def visit_TimestampFromYMDHMS( self, op, *, year, month, day, hours, minutes, seconds, **_ ): @@ -392,34 +377,28 @@ def visit_TimestampFromYMDHMS( return self.f.toTimeZone(to_datetime, timezone) return to_datetime - @visit_node.register(ops.StringSplit) def visit_StringSplit(self, op, *, arg, delimiter): return self.f.splitByString( delimiter, self.cast(arg, dt.String(nullable=False)) ) - @visit_node.register(ops.Capitalize) def visit_Capitalize(self, op, *, arg): return self.f.concat( self.f.upper(self.f.substr(arg, 1, 1)), self.f.lower(self.f.substr(arg, 2)) ) - @visit_node.register(ops.GroupConcat) def visit_GroupConcat(self, op, *, arg, sep, where): call = self.agg.groupArray(arg, where=where) return self.if_(self.f.empty(call), NULL, self.f.arrayStringConcat(call, sep)) - @visit_node.register(ops.Cot) def visit_Cot(self, op, *, arg): return 1.0 / self.f.tan(arg) - @visit_node.register(ops.StructColumn) def visit_StructColumn(self, op, *, values, names): # ClickHouse struct types cannot be nullable # (non-nested fields can be nullable) return self.cast(self.f.tuple(*values), op.dtype.copy(nullable=False)) - @visit_node.register(ops.Clip) def visit_Clip(self, op, *, arg, lower, upper): if upper is not None: arg = self.if_(self.f.isNull(arg), NULL, self.f.least(upper, arg)) @@ -429,26 +408,21 @@ def visit_Clip(self, op, *, arg, lower, upper): return arg - @visit_node.register(ops.StructField) def visit_StructField(self, op, *, arg, field: str): arg_dtype = op.arg.dtype idx = arg_dtype.names.index(field) return self.cast(sge.Dot(this=arg, expression=sge.convert(idx + 1)), op.dtype) - @visit_node.register(ops.Repeat) def visit_Repeat(self, op, *, arg, times): return self.f.repeat(arg, self.f.accurateCast(times, "UInt64")) - @visit_node.register(ops.StringContains) def visit_StringContains(self, op, haystack, needle): return self.f.locate(haystack, needle) > 0 - @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): weekdays = len(calendar.day_name) return (((self.f.toDayOfWeek(arg) - 1) % weekdays) + weekdays) % weekdays - @visit_node.register(ops.DayOfWeekName) def visit_DayOfWeekName(self, op, *, arg): # ClickHouse 20 doesn't support dateName # @@ -470,31 +444,24 @@ def visit_DayOfWeekName(self, op, *, arg): default=sge.convert(""), ) - @visit_node.register(ops.Map) def visit_Map(self, op, *, keys, values): # cast here to allow lookups of nullable columns return self.cast(self.f.tuple(keys, values), op.dtype) - @visit_node.register(ops.MapGet) def visit_MapGet(self, op, *, arg, key, default): return self.if_(self.f.mapContains(arg, key), arg[key], default) - @visit_node.register(ops.ArrayConcat) def visit_ArrayConcat(self, op, *, arg): return self.f.arrayConcat(*arg) - @visit_node.register(ops.BitAnd) - @visit_node.register(ops.BitOr) - @visit_node.register(ops.BitXor) def visit_BitAndOrXor(self, op, *, arg, where): if not (dtype := op.arg.dtype).is_unsigned_integer(): nbits = dtype.nbytes * 8 arg = self.f[f"reinterpretAsUInt{nbits}"](arg) return self.agg[f"group{type(op).__name__}"](arg, where=where) - @visit_node.register(ops.StandardDev) - @visit_node.register(ops.Variance) - 
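# --- editor's sketch (illustration only; not part of the upstream patch) -----
# The ClickHouse hunks above swap `@visit_node.register(...)` decorators for
# plainly named `visit_<OperationName>` methods, and share one body between
# several operations with class-level aliases such as
# `visit_MultiQuantile = visit_Quantile` and `visit_Lag = visit_Lead = visit_LagLead`.
# A minimal, hypothetical sketch of that dispatch style (the class and method
# names below are invented for illustration, not the ibis base-compiler API):

class Quantile: ...
class MultiQuantile: ...

class SketchCompiler:
    def visit_node(self, op, **kwargs):
        # look the handler up by the operation's class name instead of
        # relying on functools.singledispatchmethod
        method = getattr(self, f"visit_{type(op).__name__}", self.visit_Undefined)
        return method(op, **kwargs)

    def visit_Undefined(self, op, **_):
        raise NotImplementedError(type(op).__name__)

    def visit_Quantile(self, op, **kwargs):
        return f"quantile({sorted(kwargs)})"

    # one implementation serves both operation types
    visit_MultiQuantile = visit_Quantile

assert SketchCompiler().visit_node(MultiQuantile()) == "quantile([])"
# ------------------------------------------------------------------------------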
@visit_node.register(ops.Covariance) + visit_BitAnd = visit_BitOr = visit_BitXor = visit_BitAndOrXor + def visit_StandardDevVariance(self, op, *, how, where, **kw): funcs = { ops.StandardDev: "stddev", @@ -506,14 +473,14 @@ def visit_StandardDevVariance(self, op, *, how, where, **kw): funcname = variants[how] return self.agg[funcname](*kw.values(), where=where) - @visit_node.register(ops.ArrayDistinct) + visit_StandardDev = visit_Variance = visit_Covariance = visit_StandardDevVariance + def visit_ArrayDistinct(self, op, *, arg): null_element = self.if_( self.f.countEqual(arg, NULL) > 0, self.f.array(NULL), self.f.array() ) return self.f.arrayConcat(self.f.arrayDistinct(arg), null_element) - @visit_node.register(ops.ExtractMicrosecond) def visit_ExtractMicrosecond(self, op, *, arg): dtype = op.dtype return self.cast( @@ -522,7 +489,6 @@ def visit_ExtractMicrosecond(self, op, *, arg): dtype, ) - @visit_node.register(ops.ExtractMillisecond) def visit_ExtractMillisecond(self, op, *, arg): dtype = op.dtype return self.cast( @@ -531,9 +497,7 @@ def visit_ExtractMillisecond(self, op, *, arg): dtype, ) - @visit_node.register(ops.Lag) - @visit_node.register(ops.Lead) - def formatter(self, op, *, arg, offset, default): + def visit_LagLead(self, op, *, arg, offset, default): args = [arg] if default is not None: @@ -548,38 +512,33 @@ def formatter(self, op, *, arg, offset, default): func = self.f[f"{type(op).__name__.lower()}InFrame"] return func(*args) - @visit_node.register(ops.ExtractFile) + visit_Lag = visit_Lead = visit_LagLead + def visit_ExtractFile(self, op, *, arg): return self.f.cutFragment(self.f.pathFull(arg)) - @visit_node.register(ops.ExtractQuery) def visit_ExtractQuery(self, op, *, arg, key): if key is not None: return self.f.extractURLParameter(arg, key) else: return self.f.queryString(arg) - @visit_node.register(ops.ArrayStringJoin) def visit_ArrayStringJoin(self, op, *, arg, sep): return self.f.arrayStringConcat(arg, sep) - @visit_node.register(ops.ArrayMap) def visit_ArrayMap(self, op, *, arg, param, body): func = sge.Lambda(this=body, expressions=[param]) return self.f.arrayMap(func, arg) - @visit_node.register(ops.ArrayFilter) def visit_ArrayFilter(self, op, *, arg, param, body): func = sge.Lambda(this=body, expressions=[param]) return self.f.arrayFilter(func, arg) - @visit_node.register(ops.ArrayRemove) def visit_ArrayRemove(self, op, *, arg, other): x = sg.to_identifier("x") body = x.neq(other) return self.f.arrayFilter(sge.Lambda(this=body, expressions=[x]), arg) - @visit_node.register(ops.ArrayUnion) def visit_ArrayUnion(self, op, *, left, right): arg = self.f.arrayConcat(left, right) null_element = self.if_( @@ -587,11 +546,9 @@ def visit_ArrayUnion(self, op, *, left, right): ) return self.f.arrayConcat(self.f.arrayDistinct(arg), null_element) - @visit_node.register(ops.ArrayZip) def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str: return self.f.arrayZip(*arg) - @visit_node.register(ops.CountDistinctStar) def visit_CountDistinctStar( self, op: ops.CountDistinctStar, *, where, **_: Any ) -> str: @@ -602,7 +559,6 @@ def visit_CountDistinctStar( else: return self.f.countDistinct(columns) - @visit_node.register(ops.TimestampRange) def visit_TimestampRange(self, op, *, start, stop, step): unit = op.step.dtype.unit.name.lower() @@ -627,7 +583,6 @@ def visit_TimestampRange(self, op, *, start, stop, step): func, self.f.range(0, self.f.timestampDiff(unit, start, stop), step_value) ) - @visit_node.register(ops.RegexSplit) def visit_RegexSplit(self, op, *, arg, 
pattern): return self.f.splitByRegexp(pattern, self.cast(arg, dt.String(nullable=False))) @@ -635,16 +590,6 @@ def visit_RegexSplit(self, op, *, arg, pattern): def _generate_groups(groups): return groups - @visit_node.register(ops.RowID) - @visit_node.register(ops.CumeDist) - @visit_node.register(ops.PercentRank) - @visit_node.register(ops.Time) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.StringToTimestamp) - @visit_node.register(ops.Levenshtein) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - _SIMPLE_OPS = { ops.All: "min", @@ -720,13 +665,11 @@ def visit_Undefined(self, op, **_): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @ClickHouseCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @ClickHouseCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index 9c5c131118ff..f9223dc41f29 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -2,7 +2,7 @@ import calendar import math -from functools import partial, singledispatchmethod +from functools import partial from itertools import starmap import sqlglot as sg @@ -33,6 +33,43 @@ class DataFusionCompiler(SQLGlotCompiler): type_mapper = DataFusionType rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.Arbitrary, + ops.ArgMax, + ops.ArgMin, + ops.ArrayDistinct, + ops.ArrayFilter, + ops.ArrayFlatten, + ops.ArrayIntersect, + ops.ArrayMap, + ops.ArraySort, + ops.ArrayUnion, + ops.ArrayZip, + ops.BitwiseNot, + ops.Clip, + ops.CountDistinctStar, + ops.DateDelta, + ops.Greatest, + ops.GroupConcat, + ops.IntervalFromInteger, + ops.Least, + ops.MultiQuantile, + ops.Quantile, + ops.RowID, + ops.Strftime, + ops.TimeDelta, + ops.TimestampBucket, + ops.TimestampDelta, + ops.TimestampNow, + ops.TypeOf, + ops.Unnest, + ops.EndsWith, + ops.StringToTimestamp, + ops.Levenshtein, + ) + ) + def _aggregate(self, funcname: str, *args, where): expr = self.f[funcname](*args) if where is not None: @@ -53,10 +90,6 @@ def _to_timestamp(self, value, target_dtype, literal=False): str_value = str(value) if literal else value return self.f.arrow_cast(str_value, f"Timestamp({unit}, {tz})") - @singledispatchmethod - def visit_node(self, op, **kw): - return super().visit_node(op, **kw) - def visit_NonNullLiteral(self, op, *, value, dtype): if dtype.is_decimal(): return self.cast( @@ -91,7 +124,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype): else: return None - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): if to.is_interval(): unit = to.unit.name.lower() @@ -104,14 +136,12 @@ def visit_Cast(self, op, *, arg, to): return self.f.arrow_cast(arg, f"{PyArrowType.from_ibis(to)}".capitalize()) return self.cast(arg, to) - @visit_node.register(ops.Substring) def visit_Substring(self, op, *, arg, start, length): start = self.if_(start < 0, self.f.length(arg) + start + 1, start + 1) if length is not None: return self.f.substr(arg, start, length) return self.f.substr(arg, start) - @visit_node.register(ops.Variance) def visit_Variance(self, op, *, arg, how, where): if how == "sample": return self.agg.var_samp(arg, where=where) @@ -120,7 +150,6 @@ def visit_Variance(self, op, *, arg, how, where): else: 
raise ValueError(f"Unrecognized how value: {how}") - @visit_node.register(ops.StandardDev) def visit_StandardDev(self, op, *, arg, how, where): if how == "sample": return self.agg.stddev_samp(arg, where=where) @@ -129,7 +158,6 @@ def visit_StandardDev(self, op, *, arg, how, where): else: raise ValueError(f"Unrecognized how value: {how}") - @visit_node.register(ops.ScalarUDF) def visit_ScalarUDF(self, op, **kw): input_type = op.__input_type__ if input_type in (InputType.PYARROW, InputType.BUILTIN): @@ -139,13 +167,11 @@ def visit_ScalarUDF(self, op, **kw): f"DataFusion only supports PyArrow UDFs: got a {input_type.name.lower()} UDF" ) - @visit_node.register(ops.ElementWiseVectorizedUDF) def visit_ElementWiseVectorizedUDF( self, op, *, func, func_args, input_type, return_type ): return self.f[func.__name__](*func_args) - @visit_node.register(ops.RegexExtract) def visit_RegexExtract(self, op, *, arg, pattern, index): if not isinstance(op.index, ops.Literal): raise ValueError( @@ -154,7 +180,6 @@ def visit_RegexExtract(self, op, *, arg, pattern, index): ) return self.f.regexp_match(arg, self.f.concat("(", pattern, ")"))[index] - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): if end is not None: raise NotImplementedError("`end` not yet implemented") @@ -165,7 +190,6 @@ def visit_StringFind(self, op, *, arg, substr, start, end): return self.f.strpos(arg, substr) - @visit_node.register(ops.RegexSearch) def visit_RegexSearch(self, op, *, arg, pattern): return self.if_( sg.or_(arg.is_(NULL), pattern.is_(NULL)), @@ -178,73 +202,59 @@ def visit_RegexSearch(self, op, *, arg, pattern): ), ) - @visit_node.register(ops.StringContains) def visit_StringContains(self, op, *, haystack, needle): return self.f.strpos(haystack, needle) > sg.exp.convert(0) - @visit_node.register(ops.ExtractFragment) def visit_ExtractFragment(self, op, *, arg): return self.f.extract_url_field(arg, "fragment") - @visit_node.register(ops.ExtractProtocol) def visit_ExtractProtocol(self, op, *, arg): return self.f.extract_url_field(arg, "scheme") - @visit_node.register(ops.ExtractAuthority) def visit_ExtractAuthority(self, op, *, arg): return self.f.extract_url_field(arg, "netloc") - @visit_node.register(ops.ExtractPath) def visit_ExtractPath(self, op, *, arg): return self.f.extract_url_field(arg, "path") - @visit_node.register(ops.ExtractHost) def visit_ExtractHost(self, op, *, arg): return self.f.extract_url_field(arg, "hostname") - @visit_node.register(ops.ExtractQuery) def visit_ExtractQuery(self, op, *, arg, key): if key is not None: return self.f.extract_query_param(arg, key) return self.f.extract_query(arg) - @visit_node.register(ops.ExtractUserInfo) def visit_ExtractUserInfo(self, op, *, arg): return self.f.extract_user_info(arg) - @visit_node.register(ops.ExtractYear) - @visit_node.register(ops.ExtractMonth) - @visit_node.register(ops.ExtractQuarter) - @visit_node.register(ops.ExtractDay) def visit_ExtractYearMonthQuarterDay(self, op, *, arg): skip = len("Extract") part = type(op).__name__[skip:].lower() return self.f.date_part(part, arg) - @visit_node.register(ops.ExtractDayOfYear) + visit_ExtractYear = ( + visit_ExtractMonth + ) = visit_ExtractQuarter = visit_ExtractDay = visit_ExtractYearMonthQuarterDay + def visit_ExtractDayOfYear(self, op, *, arg): return self.f.date_part("doy", arg) - @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): return (self.f.date_part("dow", arg) + 6) % 7 - @visit_node.register(ops.DayOfWeekName) def 
visit_DayOfWeekName(self, op, *, arg): return sg.exp.Case( this=paren(self.f.date_part("dow", arg) + 6) % 7, ifs=list(starmap(self.if_, enumerate(calendar.day_name))), ) - @visit_node.register(ops.Date) def visit_Date(self, op, *, arg): return self.f.date_trunc("day", arg) - @visit_node.register(ops.ExtractWeekOfYear) def visit_ExtractWeekOfYear(self, op, *, arg): return self.f.date_part("week", arg) - @visit_node.register(ops.TimestampTruncate) def visit_TimestampTruncate(self, op, *, arg, unit): if unit in ( IntervalUnit.MILLISECOND, @@ -257,7 +267,6 @@ def visit_TimestampTruncate(self, op, *, arg, unit): return self.f.date_trunc(unit.name.lower(), arg) - @visit_node.register(ops.ExtractEpochSeconds) def visit_ExtractEpochSeconds(self, op, *, arg): if op.arg.dtype.is_date(): return self.f.extract_epoch_seconds_date(arg) @@ -268,7 +277,6 @@ def visit_ExtractEpochSeconds(self, op, *, arg): f"The function is not defined for {op.arg.dtype}" ) - @visit_node.register(ops.ExtractMinute) def visit_ExtractMinute(self, op, *, arg): if op.arg.dtype.is_date(): return self.f.date_part("minute", arg) @@ -281,7 +289,6 @@ def visit_ExtractMinute(self, op, *, arg): f"The function is not defined for {op.arg.dtype}" ) - @visit_node.register(ops.ExtractMillisecond) def visit_ExtractMillisecond(self, op, *, arg): if op.arg.dtype.is_time(): return self.f.extract_millisecond_time(arg) @@ -292,7 +299,6 @@ def visit_ExtractMillisecond(self, op, *, arg): f"The function is not defined for {op.arg.dtype}" ) - @visit_node.register(ops.ExtractHour) def visit_ExtractHour(self, op, *, arg): if op.arg.dtype.is_date() or op.arg.dtype.is_timestamp(): return self.f.date_part("hour", arg) @@ -303,7 +309,6 @@ def visit_ExtractHour(self, op, *, arg): f"The function is not defined for {op.arg.dtype}" ) - @visit_node.register(ops.ExtractSecond) def visit_ExtractSecond(self, op, *, arg): if op.arg.dtype.is_date() or op.arg.dtype.is_timestamp(): return self.f.extract_second_timestamp(arg) @@ -314,15 +319,12 @@ def visit_ExtractSecond(self, op, *, arg): f"The function is not defined for {op.arg.dtype}" ) - @visit_node.register(ops.ArrayRepeat) def visit_ArrayRepeat(self, op, *, arg, times): return self.f.flatten(self.f.array_repeat(arg, times)) - @visit_node.register(ops.ArrayPosition) def visit_ArrayPosition(self, op, *, arg, other): return self.f.coalesce(self.f.array_position(arg, other), 0) - @visit_node.register(ops.Covariance) def visit_Covariance(self, op, *, left, right, how, where): x = op.left if x.dtype.is_boolean(): @@ -339,7 +341,6 @@ def visit_Covariance(self, op, *, left, right, how, where): else: raise ValueError(f"Unrecognized how = `{how}` value") - @visit_node.register(ops.Correlation) def visit_Correlation(self, op, *, left, right, where, how): x = op.left if x.dtype.is_boolean(): @@ -351,21 +352,17 @@ def visit_Correlation(self, op, *, left, right, where, how): return self.agg.corr(left, right, where=where) - @visit_node.register(ops.IsNan) def visit_IsNan(self, op, *, arg): return sg.and_(arg.is_(sg.not_(NULL)), self.f.isnan(arg)) - @visit_node.register(ops.ArrayStringJoin) def visit_ArrayStringJoin(self, op, *, sep, arg): return self.f.array_join(arg, sep) - @visit_node.register(ops.FindInSet) def visit_FindInSet(self, op, *, needle, values): return self.f.coalesce( self.f.array_position(self.f.make_array(*values), needle), 0 ) - @visit_node.register(ops.TimestampFromUNIX) def visit_TimestampFromUNIX(self, op, *, arg, unit): if unit == TimestampUnit.SECOND: return self.f.from_unixtime(arg) @@ -378,7 +375,6 
@@ def visit_TimestampFromUNIX(self, op, *, arg, unit): else: raise com.UnsupportedOperationError(f"Unsupported unit {unit}") - @visit_node.register(ops.DateFromYMD) def visit_DateFromYMD(self, op, *, year, month, day): return self.cast( self.f.concat( @@ -391,7 +387,6 @@ def visit_DateFromYMD(self, op, *, year, month, day): dt.date, ) - @visit_node.register(ops.TimestampFromYMDHMS) def visit_TimestampFromYMDHMS( self, op, *, year, month, day, hours, minutes, seconds, **_ ): @@ -412,57 +407,18 @@ def visit_TimestampFromYMDHMS( ) ) - @visit_node.register(ops.IsInf) def visit_IsInf(self, op, *, arg): return sg.and_(sg.not_(self.f.isnan(arg)), self.f.abs(arg).eq(self.POS_INF)) - @visit_node.register(ops.ArrayIndex) def visit_ArrayIndex(self, op, *, arg, index): return self.f.array_element(arg, index + self.cast(index >= 0, op.index.dtype)) - @visit_node.register(ops.StringConcat) def visit_StringConcat(self, op, *, arg): any_args_null = (a.is_(NULL) for a in arg) return self.if_( sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg) ) - @visit_node.register(ops.Arbitrary) - @visit_node.register(ops.ArgMax) - @visit_node.register(ops.ArgMin) - @visit_node.register(ops.ArrayDistinct) - @visit_node.register(ops.ArrayFilter) - @visit_node.register(ops.ArrayFlatten) - @visit_node.register(ops.ArrayIntersect) - @visit_node.register(ops.ArrayMap) - @visit_node.register(ops.ArraySort) - @visit_node.register(ops.ArrayUnion) - @visit_node.register(ops.ArrayZip) - @visit_node.register(ops.BitwiseNot) - @visit_node.register(ops.Clip) - @visit_node.register(ops.CountDistinctStar) - @visit_node.register(ops.DateDelta) - @visit_node.register(ops.Greatest) - @visit_node.register(ops.GroupConcat) - @visit_node.register(ops.IntervalFromInteger) - @visit_node.register(ops.Least) - @visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.RowID) - @visit_node.register(ops.Strftime) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.TimestampBucket) - @visit_node.register(ops.TimestampDelta) - @visit_node.register(ops.TimestampNow) - @visit_node.register(ops.TypeOf) - @visit_node.register(ops.Unnest) - @visit_node.register(ops.EndsWith) - @visit_node.register(ops.StringToTimestamp) - @visit_node.register(ops.Levenshtein) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - - @visit_node.register(ops.Aggregate) def visit_Aggregate(self, op, *, parent, groups, metrics): """Support `GROUP BY` expressions in `SELECT` since DataFusion does not.""" quoted = self.quoted @@ -530,13 +486,11 @@ def visit_Aggregate(self, op, *, parent, groups, metrics): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @DataFusionCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @DataFusionCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/druid/compiler.py b/ibis/backends/druid/compiler.py index c0661cd4cf71..648a9f3ced72 100644 --- a/ibis/backends/druid/compiler.py +++ b/ibis/backends/druid/compiler.py @@ -1,12 +1,9 @@ from __future__ import annotations -from functools import singledispatchmethod - import sqlglot as sg import sqlglot.expressions as sge import toolz -import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops from 
ibis.backends.base.sqlglot.compiler import NULL, SQLGlotCompiler @@ -22,17 +19,60 @@ class DruidCompiler(SQLGlotCompiler): type_mapper = DruidType rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.ApproxMedian, + ops.Arbitrary, + ops.ArgMax, + ops.ArgMin, + ops.ArrayCollect, + ops.ArrayDistinct, + ops.ArrayFilter, + ops.ArrayFlatten, + ops.ArrayIntersect, + ops.ArrayMap, + ops.ArraySort, + ops.ArrayUnion, + ops.ArrayZip, + ops.CountDistinctStar, + ops.Covariance, + ops.DateDelta, + ops.DayOfWeekIndex, + ops.DayOfWeekName, + ops.First, + ops.IntervalFromInteger, + ops.IsNan, + ops.IsInf, + ops.Last, + ops.Levenshtein, + ops.Median, + ops.MultiQuantile, + ops.Quantile, + ops.RegexReplace, + ops.RegexSplit, + ops.RowID, + ops.StandardDev, + ops.Strftime, + ops.StringAscii, + ops.StringSplit, + ops.StringToTimestamp, + ops.TimeDelta, + ops.TimestampBucket, + ops.TimestampDelta, + ops.TimestampNow, + ops.Translate, + ops.TypeOf, + ops.Unnest, + ops.Variance, + ) + ) + def _aggregate(self, funcname: str, *args, where): expr = self.f[funcname](*args) if where is not None: return sg.exp.Filter(this=expr, expression=sg.exp.Where(this=where)) return expr - @singledispatchmethod - def visit_node(self, op, **kw): - return super().visit_node(op, **kw) - - @visit_node.register(ops.InMemoryTable) def visit_InMemoryTable(self, op, *, name, schema, data): # the performance of this is rather terrible tuples = data.to_frame().itertuples(index=False) @@ -49,31 +89,24 @@ def visit_InMemoryTable(self, op, *, name, schema, data): ) return sg.select(*columns).from_(expr) - @visit_node.register(ops.StringJoin) def visit_StringJoin(self, op, *, arg, sep): return self.f.concat(*toolz.interpose(sep, arg)) - @visit_node.register(ops.Pi) def visit_Pi(self, op): return self.f.acos(-1) - @visit_node.register(ops.Sign) def visit_Sign(self, op, *, arg): return self.if_(arg.eq(0), 0, self.if_(arg > 0, 1, -1)) - @visit_node.register(ops.GroupConcat) def visit_GroupConcat(self, op, *, arg, sep, where): return self.agg.string_agg(arg, sep, 1 << 20, where=where) - @visit_node.register(ops.StartsWith) def visit_StartsWith(self, op, *, arg, start): return self.f.left(arg, self.f.length(start)).eq(start) - @visit_node.register(ops.EndsWith) def visit_EndsWith(self, op, *, arg, end): return self.f.right(arg, self.f.length(end)).eq(end) - @visit_node.register(ops.Capitalize) def visit_Capitalize(self, op, *, arg): return self.if_( self.f.length(arg) < 2, @@ -84,17 +117,14 @@ def visit_Capitalize(self, op, *, arg): ), ) - @visit_node.register(ops.RegexSearch) def visit_RegexSearch(self, op, *, arg, pattern): return self.f.anon.regexp_like(arg, pattern) - @visit_node.register(ops.StringSQLILike) def visit_StringSQLILike(self, op, *, arg, pattern, escape): if escape is not None: raise NotImplementedError("non-None escape not supported") return self.f.upper(arg).like(self.f.upper(pattern)) - @visit_node.register(ops.Literal) def visit_Literal(self, op, *, value, dtype): if value is None: return NULL @@ -106,7 +136,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype): return None - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): from_ = op.arg.dtype if from_.is_integer() and to.is_timestamp(): @@ -116,7 +145,6 @@ def visit_Cast(self, op, *, arg, to): return self.f.time_parse(arg) return super().visit_Cast(op, arg=arg, to=to) - @visit_node.register(ops.TimestampFromYMDHMS) def visit_TimestampFromYMDHMS( self, op, *, year, month, day, hours, minutes, 
seconds ): @@ -137,52 +165,6 @@ def visit_TimestampFromYMDHMS( ) ) - @visit_node.register(ops.ApproxMedian) - @visit_node.register(ops.Arbitrary) - @visit_node.register(ops.ArgMax) - @visit_node.register(ops.ArgMin) - @visit_node.register(ops.ArrayCollect) - @visit_node.register(ops.ArrayDistinct) - @visit_node.register(ops.ArrayFilter) - @visit_node.register(ops.ArrayFlatten) - @visit_node.register(ops.ArrayIntersect) - @visit_node.register(ops.ArrayMap) - @visit_node.register(ops.ArraySort) - @visit_node.register(ops.ArrayUnion) - @visit_node.register(ops.ArrayZip) - @visit_node.register(ops.CountDistinctStar) - @visit_node.register(ops.Covariance) - @visit_node.register(ops.DateDelta) - @visit_node.register(ops.DayOfWeekIndex) - @visit_node.register(ops.DayOfWeekName) - @visit_node.register(ops.First) - @visit_node.register(ops.IntervalFromInteger) - @visit_node.register(ops.IsNan) - @visit_node.register(ops.IsInf) - @visit_node.register(ops.Last) - @visit_node.register(ops.Levenshtein) - @visit_node.register(ops.Median) - @visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.RegexReplace) - @visit_node.register(ops.RegexSplit) - @visit_node.register(ops.RowID) - @visit_node.register(ops.StandardDev) - @visit_node.register(ops.Strftime) - @visit_node.register(ops.StringAscii) - @visit_node.register(ops.StringSplit) - @visit_node.register(ops.StringToTimestamp) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.TimestampBucket) - @visit_node.register(ops.TimestampDelta) - @visit_node.register(ops.TimestampNow) - @visit_node.register(ops.Translate) - @visit_node.register(ops.TypeOf) - @visit_node.register(ops.Unnest) - @visit_node.register(ops.Variance) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - _SIMPLE_OPS = { ops.BitAnd: "bit_and", @@ -205,13 +187,11 @@ def visit_Undefined(self, op, **_): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @DruidCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @DruidCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index ee106e12001d..da0c93cbaaf7 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -1,7 +1,7 @@ from __future__ import annotations import math -from functools import partial, reduce, singledispatchmethod +from functools import partial, reduce import sqlglot as sg import sqlglot.expressions as sge @@ -11,7 +11,11 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.base.sqlglot.compiler import NULL, STAR, SQLGlotCompiler +from ibis.backends.base.sqlglot.compiler import ( + NULL, + STAR, + SQLGlotCompiler, +) from ibis.backends.base.sqlglot.datatypes import DuckDBType _INTERVAL_SUFFIXES = { @@ -39,11 +43,6 @@ def _aggregate(self, funcname: str, *args, where): return sge.Filter(this=expr, expression=sge.Where(this=where)) return expr - @singledispatchmethod - def visit_node(self, op, **kwargs): - return super().visit_node(op, **kwargs) - - @visit_node.register(ops.ArrayDistinct) def visit_ArrayDistinct(self, op, *, arg): return self.if_( arg.is_(NULL), @@ -56,17 +55,14 @@ def visit_ArrayDistinct(self, op, *, arg): ), ) - 
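# --- editor's sketch (illustration only; not part of the upstream patch) -----
# In the `_SIMPLE_OPS` loops, the generated `_fmt` functions lose their
# `@<Backend>Compiler.visit_node.register(_op)` decorators; with name-based
# dispatch they only need to end up as `visit_<OperationName>` attributes on
# the compiler class.  One way such a loop can attach them -- an assumption for
# illustration, the actual attachment code is not shown in these hunks -- is a
# plain setattr:

class SketchCompiler:
    pass

# hypothetical mapping: operation name -> SQL function name
_SIMPLE_OPS = {"BitAnd": "bit_and", "Power": "power"}

for _op_name, _func_name in _SIMPLE_OPS.items():
    def _fmt(self, op, *, _name=_func_name, **kw):
        # the default argument binds the current function name, avoiding the
        # classic late-binding-in-a-loop pitfall
        return f"{_name}({', '.join(map(str, kw.values()))})"

    setattr(SketchCompiler, f"visit_{_op_name}", _fmt)

assert SketchCompiler().visit_BitAnd(None, arg="x") == "bit_and(x)"
# ------------------------------------------------------------------------------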
@visit_node.register(ops.ArrayIndex) def visit_ArrayIndex(self, op, *, arg, index): return self.f.list_extract(arg, index + self.cast(index >= 0, op.index.dtype)) - @visit_node.register(ops.ArrayRepeat) def visit_ArrayRepeat(self, op, *, arg, times): func = sge.Lambda(this=arg, expressions=[sg.to_identifier("_")]) return self.f.flatten(self.f.list_apply(self.f.range(times), func)) # TODO(kszucs): this could be moved to the base SQLGlotCompiler - @visit_node.register(ops.Sample) def visit_Sample( self, op, *, parent, fraction: float, method: str, seed: int | None, **_ ): @@ -78,7 +74,6 @@ def visit_Sample( ) return sg.select(STAR).from_(sample) - @visit_node.register(ops.ArraySlice) def visit_ArraySlice(self, op, *, arg, start, stop): arg_length = self.f.len(arg) @@ -94,31 +89,26 @@ def visit_ArraySlice(self, op, *, arg, start, stop): return self.f.list_slice(arg, start + 1, stop) - @visit_node.register(ops.ArrayMap) def visit_ArrayMap(self, op, *, arg, body, param): lamduh = sge.Lambda(this=body, expressions=[sg.to_identifier(param)]) return self.f.list_apply(arg, lamduh) - @visit_node.register(ops.ArrayFilter) def visit_ArrayFilter(self, op, *, arg, body, param): lamduh = sge.Lambda(this=body, expressions=[sg.to_identifier(param)]) return self.f.list_filter(arg, lamduh) - @visit_node.register(ops.ArrayIntersect) def visit_ArrayIntersect(self, op, *, left, right): param = sg.to_identifier("x") body = self.f.list_contains(right, param) lamduh = sge.Lambda(this=body, expressions=[param]) return self.f.list_filter(left, lamduh) - @visit_node.register(ops.ArrayRemove) def visit_ArrayRemove(self, op, *, arg, other): param = sg.to_identifier("x") body = param.neq(other) lamduh = sge.Lambda(this=body, expressions=[param]) return self.f.list_filter(arg, lamduh) - @visit_node.register(ops.ArrayUnion) def visit_ArrayUnion(self, op, *, left, right): arg = self.f.list_concat(left, right) return self.if_( @@ -132,7 +122,6 @@ def visit_ArrayUnion(self, op, *, left, right): ), ) - @visit_node.register(ops.ArrayZip) def visit_ArrayZip(self, op, *, arg): i = sg.to_identifier("i") body = sge.Struct.from_arg_list( @@ -151,27 +140,24 @@ def visit_ArrayZip(self, op, *, arg): func, ) - @visit_node.register(ops.MapGet) def visit_MapGet(self, op, *, arg, key, default): return self.f.ifnull( self.f.list_extract(self.f.element_at(arg, key), 1), default ) - @visit_node.register(ops.MapContains) def visit_MapContains(self, op, *, arg, key): return self.f.len(self.f.element_at(arg, key)).neq(0) - @visit_node.register(ops.ToJSONMap) - @visit_node.register(ops.ToJSONArray) def visit_ToJSONMap(self, op, *, arg): return sge.TryCast(this=arg, to=self.type_mapper.from_ibis(op.dtype)) - @visit_node.register(ops.ArrayConcat) + def visit_ToJSONArray(self, op, *, arg): + return self.visit_ToJSONMap(op, arg=arg) + def visit_ArrayConcat(self, op, *, arg): # TODO(cpcloud): map ArrayConcat to this in sqlglot instead of here return reduce(self.f.list_concat, arg) - @visit_node.register(ops.IntervalFromInteger) def visit_IntervalFromInteger(self, op, *, arg, unit): if unit.short == "ns": raise com.UnsupportedOperationError( @@ -182,11 +168,9 @@ def visit_IntervalFromInteger(self, op, *, arg, unit): return self.f.to_days(arg * 7) return self.f[f"to_{unit.plural}"](arg) - @visit_node.register(ops.FindInSet) def visit_FindInSet(self, op, *, needle, values): return self.f.list_indexof(self.f.array(*values), needle) - @visit_node.register(ops.CountDistinctStar) def visit_CountDistinctStar(self, op, *, where, arg): # use a tuple because 
duckdb doesn't accept COUNT(DISTINCT a, b, c, ...) # @@ -198,17 +182,14 @@ def visit_CountDistinctStar(self, op, *, where, arg): ) return self.agg.count(sge.Distinct(expressions=[row]), where=where) - @visit_node.register(ops.ExtractMillisecond) def visit_ExtractMillisecond(self, op, *, arg): return self.f.mod(self.f.extract("ms", arg), 1_000) # DuckDB extracts subminute microseconds and milliseconds # so we have to finesse it a little bit - @visit_node.register(ops.ExtractMicrosecond) def visit_ExtractMicrosecond(self, op, *, arg): return self.f.mod(self.f.extract("us", arg), 1_000_000) - @visit_node.register(ops.TimestampFromUNIX) def visit_TimestampFromUNIX(self, op, *, arg, unit): unit = unit.short if unit == "ms": @@ -218,7 +199,6 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): else: raise com.UnsupportedOperationError(f"{unit!r} unit is not supported!") - @visit_node.register(ops.TimestampFromYMDHMS) def visit_TimestampFromYMDHMS( self, op, *, year, month, day, hours, minutes, seconds, **_ ): @@ -231,7 +211,6 @@ def visit_TimestampFromYMDHMS( return self.f[func](*args) - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): if to.is_interval(): func = self.f[f"to_{_INTERVAL_SUFFIXES[to.unit.short]}"] @@ -289,7 +268,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype): else: return None - @visit_node.register(ops.Capitalize) def visit_Capitalize(self, op, *, arg): return self.f.concat( self.f.upper(self.f.substr(arg, 1, 1)), self.f.lower(self.f.substr(arg, 2)) @@ -308,7 +286,6 @@ def _neg_idx_to_pos(self, array, idx): arg_length + self.f.greatest(idx, -arg_length), ) - @visit_node.register(ops.Correlation) def visit_Correlation(self, op, *, left, right, how, where): if how == "sample": raise com.UnsupportedOperationError( @@ -324,42 +301,37 @@ def visit_Correlation(self, op, *, left, right, how, where): return self.agg.corr(left, right, where=where) - @visit_node.register(ops.GeoConvert) def visit_GeoConvert(self, op, *, arg, source, target): # 4th argument is to specify that the result is always_xy so that it # matches the behavior of the equivalent geopandas functionality return self.f.st_transform(arg, source, target, True) - @visit_node.register(ops.TimestampNow) def visit_TimestampNow(self, op): """DuckDB current timestamp defaults to timestamp + tz.""" return self.cast(super().visit_TimestampNow(op), dt.timestamp) - @visit_node.register(ops.RegexExtract) def visit_RegexExtract(self, op, *, arg, pattern, index): return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect) - @visit_node.register(ops.RegexReplace) def visit_RegexReplace(self, op, *, arg, pattern, replacement): return self.f.regexp_replace( arg, pattern, replacement, "g", dialect=self.dialect ) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.MultiQuantile) def visit_Quantile(self, op, *, arg, quantile, where): suffix = "cont" if op.arg.dtype.is_numeric() else "disc" funcname = f"percentile_{suffix}" return self.agg[funcname](arg, quantile, where=where) - @visit_node.register(ops.HexDigest) + def visit_MultiQuantile(self, op, *, arg, quantile, where): + return self.visit_Quantile(op, arg=arg, quantile=quantile, where=where) + def visit_HexDigest(self, op, *, arg, how): if how in ("md5", "sha256"): return getattr(self.f, how)(arg) else: raise NotImplementedError(f"No available hashing function for {how}") - @visit_node.register(ops.StringConcat) def visit_StringConcat(self, op, *, arg): return reduce(lambda x, y: sge.DPipe(this=x, expression=y), arg) @@ -418,13 
+390,11 @@ def visit_StringConcat(self, op, *, arg): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @DuckDBCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @DuckDBCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/exasol/compiler.py b/ibis/backends/exasol/compiler.py index 20fedd4dd193..c62580d79e04 100644 --- a/ibis/backends/exasol/compiler.py +++ b/ibis/backends/exasol/compiler.py @@ -1,7 +1,5 @@ from __future__ import annotations -from functools import singledispatchmethod - import sqlglot as sg import sqlglot.expressions as sge @@ -34,6 +32,71 @@ class ExasolCompiler(SQLGlotCompiler): *SQLGlotCompiler.rewrites, ) + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.AnalyticVectorizedUDF, + ops.ApproxMedian, + ops.Arbitrary, + ops.ArgMax, + ops.ArgMin, + ops.ArrayCollect, + ops.ArrayDistinct, + ops.ArrayFilter, + ops.ArrayFlatten, + ops.ArrayIntersect, + ops.ArrayMap, + ops.ArraySort, + ops.ArrayStringJoin, + ops.ArrayUnion, + ops.ArrayZip, + ops.BitwiseNot, + ops.Covariance, + ops.CumeDist, + ops.DateAdd, + ops.DateDelta, + ops.DateSub, + ops.DateFromYMD, + ops.DayOfWeekIndex, + ops.DayOfWeekName, + ops.ElementWiseVectorizedUDF, + ops.ExtractDayOfYear, + ops.ExtractEpochSeconds, + ops.ExtractQuarter, + ops.ExtractWeekOfYear, + ops.First, + ops.IntervalFromInteger, + ops.IsInf, + ops.IsNan, + ops.Last, + ops.Levenshtein, + ops.Median, + ops.MultiQuantile, + ops.Quantile, + ops.ReductionVectorizedUDF, + ops.RegexExtract, + ops.RegexReplace, + ops.RegexSearch, + ops.RegexSplit, + ops.RowID, + ops.StandardDev, + ops.Strftime, + ops.StringJoin, + ops.StringSplit, + ops.StringToTimestamp, + ops.TimeDelta, + ops.TimestampAdd, + ops.TimestampBucket, + ops.TimestampDelta, + ops.TimestampDiff, + ops.TimestampNow, + ops.TimestampSub, + ops.TimestampTruncate, + ops.TypeOf, + ops.Unnest, + ops.Variance, + ) + ) + @staticmethod def _minimize_spec(start, end, spec): if ( @@ -56,10 +119,6 @@ def _gen_valid_name(name: str) -> str: """Exasol does not allow dots in quoted column names.""" return name.replace(".", "_") - @singledispatchmethod - def visit_node(self, op, **kw): - return super().visit_node(op, **kw) - def visit_NonNullLiteral(self, op, *, value, dtype): if dtype.is_date(): return self.cast(value.isoformat(), dtype) @@ -74,105 +133,38 @@ def visit_NonNullLiteral(self, op, *, value, dtype): return sge.convert(str(value)) return super().visit_NonNullLiteral(op, value=value, dtype=dtype) - @visit_node.register(ops.Date) def visit_Date(self, op, *, arg): return self.cast(arg, dt.date) - @visit_node.register(ops.StartsWith) def visit_StartsWith(self, op, *, arg, start): return self.f.left(arg, self.f.length(start)).eq(start) - @visit_node.register(ops.EndsWith) def visit_EndsWith(self, op, *, arg, end): return self.f.right(arg, self.f.length(end)).eq(end) - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): return self.f.locate(substr, arg, (start if start is not None else 0) + 1) - @visit_node.register(ops.StringSQLILike) def visit_StringSQLILike(self, op, *, arg, pattern, escape): return self.f.upper(arg).like(self.f.upper(pattern)) - @visit_node.register(ops.StringContains) def visit_StringContains(self, op, *, haystack, needle): return self.f.locate(needle, haystack) > 0 - @visit_node.register(ops.ExtractSecond) def 
visit_ExtractSecond(self, op, *, arg): return self.f.floor(self.cast(self.f.extract(self.v.second, arg), op.dtype)) - @visit_node.register(ops.StringConcat) def visit_StringConcat(self, op, *, arg): any_args_null = (a.is_(NULL) for a in arg) return self.if_(sg.or_(*any_args_null), NULL, self.f.concat(*arg)) - @visit_node.register(ops.AnalyticVectorizedUDF) - @visit_node.register(ops.ApproxMedian) - @visit_node.register(ops.Arbitrary) - @visit_node.register(ops.ArgMax) - @visit_node.register(ops.ArgMin) - @visit_node.register(ops.ArrayCollect) - @visit_node.register(ops.ArrayDistinct) - @visit_node.register(ops.ArrayFilter) - @visit_node.register(ops.ArrayFlatten) - @visit_node.register(ops.ArrayIntersect) - @visit_node.register(ops.ArrayMap) - @visit_node.register(ops.ArraySort) - @visit_node.register(ops.ArrayStringJoin) - @visit_node.register(ops.ArrayUnion) - @visit_node.register(ops.ArrayZip) - @visit_node.register(ops.BitwiseNot) - @visit_node.register(ops.Covariance) - @visit_node.register(ops.CumeDist) - @visit_node.register(ops.DateAdd) - @visit_node.register(ops.DateDelta) - @visit_node.register(ops.DateSub) - @visit_node.register(ops.DateFromYMD) - @visit_node.register(ops.DayOfWeekIndex) - @visit_node.register(ops.DayOfWeekName) - @visit_node.register(ops.ElementWiseVectorizedUDF) - @visit_node.register(ops.ExtractDayOfYear) - @visit_node.register(ops.ExtractEpochSeconds) - @visit_node.register(ops.ExtractQuarter) - @visit_node.register(ops.ExtractWeekOfYear) - @visit_node.register(ops.First) - @visit_node.register(ops.IntervalFromInteger) - @visit_node.register(ops.IsInf) - @visit_node.register(ops.IsNan) - @visit_node.register(ops.Last) - @visit_node.register(ops.Levenshtein) - @visit_node.register(ops.Median) - @visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.ReductionVectorizedUDF) - @visit_node.register(ops.RegexExtract) - @visit_node.register(ops.RegexReplace) - @visit_node.register(ops.RegexSearch) - @visit_node.register(ops.RegexSplit) - @visit_node.register(ops.RowID) - @visit_node.register(ops.StandardDev) - @visit_node.register(ops.Strftime) - @visit_node.register(ops.StringJoin) - @visit_node.register(ops.StringSplit) - @visit_node.register(ops.StringToTimestamp) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.TimestampAdd) - @visit_node.register(ops.TimestampBucket) - @visit_node.register(ops.TimestampDelta) - @visit_node.register(ops.TimestampDiff) - @visit_node.register(ops.TimestampNow) - @visit_node.register(ops.TimestampSub) - @visit_node.register(ops.TimestampTruncate) - @visit_node.register(ops.TypeOf) - @visit_node.register(ops.Unnest) - @visit_node.register(ops.Variance) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - - @visit_node.register(ops.CountDistinctStar) - def visit_Unsupported(self, op, **_): - raise com.UnsupportedOperationError(type(op).__name__) + def visit_CountDistinctStar(self, op, *, arg, where): + raise com.UnsupportedOperationError( + "COUNT(DISTINCT *) is not supported in Exasol" + ) + + def visit_DateTruncate(self, op, *, arg, unit): + return super().visit_TimestampTruncate(op, arg=arg, unit=unit) _SIMPLE_OPS = { @@ -186,13 +178,11 @@ def visit_Unsupported(self, op, **_): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @ExasolCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - 
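# --- editor's sketch (illustration only; not part of the upstream patch) -----
# The Exasol hunk above preserves two distinct failure modes: operations listed
# in UNSUPPORTED_OPERATIONS simply have no translation, while COUNT(DISTINCT *)
# keeps an explicit visit method that raises UnsupportedOperationError with a
# message.  A toy illustration of the distinction (the exception classes and
# names here are stand-ins, not the ibis ones):

class OperationNotDefinedError(Exception):
    """The backend has no translation for this operation."""

class UnsupportedOperationError(Exception):
    """The backend recognizes the operation but refuses to compile it."""

class SketchExasolCompiler:
    UNSUPPORTED_OPERATIONS = frozenset({"RegexExtract", "Quantile"})

    def compile_op(self, op_name, **kwargs):
        if op_name in self.UNSUPPORTED_OPERATIONS:
            raise OperationNotDefinedError(op_name)
        return getattr(self, f"visit_{op_name}")(**kwargs)

    def visit_CountDistinctStar(self, **kwargs):
        raise UnsupportedOperationError(
            "COUNT(DISTINCT *) is not supported in Exasol"
        )
# ------------------------------------------------------------------------------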
@ExasolCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/flink/compiler.py b/ibis/backends/flink/compiler.py index af89583f0de9..46731e14175f 100644 --- a/ibis/backends/flink/compiler.py +++ b/ibis/backends/flink/compiler.py @@ -2,8 +2,6 @@ from __future__ import annotations -from functools import singledispatchmethod - import sqlglot as sg import sqlglot.expressions as sge @@ -37,6 +35,41 @@ class FlinkCompiler(SQLGlotCompiler): *SQLGlotCompiler.rewrites, ) + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.AnalyticVectorizedUDF, + ops.ApproxMedian, + ops.Arbitrary, + ops.ArgMax, + ops.ArgMin, + ops.ArrayCollect, + ops.ArrayFlatten, + ops.ArraySort, + ops.ArrayStringJoin, + ops.Correlation, + ops.CountDistinctStar, + ops.Covariance, + ops.DateDiff, + ops.ExtractURLField, + ops.FindInSet, + ops.IsInf, + ops.IsNan, + ops.Levenshtein, + ops.MapMerge, + ops.Median, + ops.MultiQuantile, + ops.NthValue, + ops.Quantile, + ops.ReductionVectorizedUDF, + ops.RegexSplit, + ops.RowID, + ops.ScalarUDF, + ops.StringSplit, + ops.Translate, + ops.Unnest, + ) + ) + @property def NAN(self): raise NotImplementedError("Flink does not support NaN") @@ -47,6 +80,10 @@ def POS_INF(self): NEG_INF = POS_INF + @staticmethod + def _generate_groups(groups): + return groups + def _aggregate(self, funcname: str, *args, where): func = self.f[funcname] if where is not None: @@ -75,10 +112,6 @@ def _aggregate(self, funcname: str, *args, where): args = tuple(sge.Case(ifs=[sge.If(this=where, true=arg)]) for arg in args) return func(*args) - @singledispatchmethod - def visit_node(self, op, **kwargs): - return super().visit_node(op, **kwargs) - @staticmethod def _minimize_spec(start, end, spec): if ( @@ -90,7 +123,6 @@ def _minimize_spec(start, end, spec): return None return spec - @visit_node.register(ops.TumbleWindowingTVF) def visit_TumbleWindowingTVF(self, op, *, table, time_col, window_size, offset): args = [ self.v[f"TABLE {table.this.sql(self.dialect)}"], @@ -114,7 +146,6 @@ def visit_TumbleWindowingTVF(self, op, *, table, time_col, window_size, offset): ) ) - @visit_node.register(ops.HopWindowingTVF) def visit_HopWindowingTVF( self, op, *, table, time_col, window_size, window_slide, offset ): @@ -135,7 +166,6 @@ def visit_HopWindowingTVF( ) ) - @visit_node.register(ops.CumulateWindowingTVF) def visit_CumulateWindowingTVF( self, op, *, table, time_col, window_size, window_step, offset ): @@ -156,7 +186,6 @@ def visit_CumulateWindowingTVF( ) ) - @visit_node.register(ops.InMemoryTable) def visit_InMemoryTable(self, op, *, name, schema, data): # the performance of this is rather terrible tuples = data.to_frame().itertuples(index=False) @@ -230,15 +259,12 @@ def visit_NonNullLiteral(self, op, *, value, dtype): return self.cast(value.isoformat(timespec="microseconds"), dtype) return None - @visit_node.register(ops.ArrayIndex) def visit_ArrayIndex(self, op, *, arg, index): return sge.Bracket(this=arg, expressions=[index + 1]) - @visit_node.register(ops.Xor) def visit_Xor(self, op, *, left, right): return sg.or_(sg.and_(left, sg.not_(right)), sg.and_(sg.not_(left), right)) - @visit_node.register(ops.Literal) def visit_Literal(self, op, *, value, dtype): if value is None: assert dtype.nullable, "dtype is not nullable but value is None" @@ -247,13 +273,11 @@ def visit_Literal(self, op, *, value, dtype): return sge.NULL return super().visit_Literal(op, value=value, dtype=dtype) - @visit_node.register(ops.MapGet) def visit_MapGet(self, 
op, *, arg, key, default): if default is sge.NULL: default = self.cast(default, op.dtype) return self.f.coalesce(arg[key], default) - @visit_node.register(ops.ArraySlice) def visit_ArraySlice(self, op, *, arg, start, stop): args = [arg, self.if_(start >= 0, start + 1, start)] @@ -264,15 +288,12 @@ def visit_ArraySlice(self, op, *, arg, start, stop): return self.f.array_slice(*args) - @visit_node.register(ops.Not) def visit_Not(self, op, *, arg): return sg.not_(self.cast(arg, dt.boolean)) - @visit_node.register(ops.Date) def visit_Date(self, op, *, arg): return self.cast(arg, dt.date) - @visit_node.register(ops.TryCast) def visit_TryCast(self, op, *, arg, to): type_mapper = self.type_mapper if op.arg.dtype.is_temporal() and to.is_numeric(): @@ -281,11 +302,9 @@ def visit_TryCast(self, op, *, arg, to): ) return sge.TryCast(this=arg, to=type_mapper.from_ibis(to)) - @visit_node.register(ops.FloorDivide) def visit_FloorDivide(self, op, *, left, right): return self.f.floor(left / right) - @visit_node.register(ops.JSONGetItem) def visit_JSONGetItem(self, op, *, arg, index): assert isinstance(op.index, ops.Literal) idx = op.index @@ -299,7 +318,6 @@ def visit_JSONGetItem(self, op, *, arg, index): key_hack = f"{sge.convert(query_path).sql(self.dialect)} WITH CONDITIONAL ARRAY WRAPPER" return self.f.json_query(arg, self.v[key_hack]) - @visit_node.register(ops.TimestampFromUNIX) def visit_TimestampFromUNIX(self, op, *, arg, unit): from ibis.common.temporal import TimestampUnit @@ -312,11 +330,9 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): return self.cast(self.f.to_timestamp_ltz(arg, precision), dt.timestamp) - @visit_node.register(ops.Time) def visit_Time(self, op, *, arg): return self.cast(arg, op.dtype) - @visit_node.register(ops.TimeFromHMS) def visit_TimeFromHMS(self, op, *, hours, minutes, seconds): padded_hour = self.f.lpad(self.cast(hours, dt.string), 2, "0") padded_minute = self.f.lpad(self.cast(minutes, dt.string), 2, "0") @@ -325,7 +341,6 @@ def visit_TimeFromHMS(self, op, *, hours, minutes, seconds): self.f.concat(padded_hour, ":", padded_minute, ":", padded_second), op.dtype ) - @visit_node.register(ops.DateFromYMD) def visit_DateFromYMD(self, op, *, year, month, day): padded_year = self.f.lpad(self.cast(year, dt.string), 4, "0") padded_month = self.f.lpad(self.cast(month, dt.string), 2, "0") @@ -334,7 +349,6 @@ def visit_DateFromYMD(self, op, *, year, month, day): self.f.concat(padded_year, "-", padded_month, "-", padded_day), op.dtype ) - @visit_node.register(ops.TimestampFromYMDHMS) def visit_TimestampFromYMDHMS( self, op, *, year, month, day, hours, minutes, seconds ): @@ -361,11 +375,9 @@ def visit_TimestampFromYMDHMS( op.dtype, ) - @visit_node.register(ops.ExtractEpochSeconds) def visit_ExtractEpochSeconds(self, op, *, arg): return self.f.unix_timestamp(self.cast(arg, dt.string)) - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): from_ = op.arg.dtype if to.is_timestamp(): @@ -384,7 +396,6 @@ def visit_Cast(self, op, *, arg, to): else: return self.cast(arg, to) - @visit_node.register(ops.IfElse) def visit_IfElse(self, op, *, bool_expr, true_expr, false_null_expr): return self.if_( bool_expr, @@ -396,23 +407,18 @@ def visit_IfElse(self, op, *, bool_expr, true_expr, false_null_expr): ), ) - @visit_node.register(ops.Log10) def visit_Log10(self, op, *, arg): return self.f.anon.log(10, arg) - @visit_node.register(ops.ExtractMillisecond) def visit_ExtractMillisecond(self, op, *, arg): return self.f.extract(self.v.millisecond, arg) - 
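# --- editor's note (illustration only; not part of the upstream patch) -------
# The Flink `visit_Xor` above spells boolean XOR as (a AND NOT b) OR (NOT a AND b),
# presumably because the target dialect has no boolean XOR operator; the Impala
# hunk later in this patch uses the complementary form (a OR b) AND NOT (a AND b).
# A quick check that both rewrites agree with the truth table:

from itertools import product

for a, b in product((False, True), repeat=2):
    flink_form = (a and not b) or (not a and b)
    impala_form = (a or b) and not (a and b)
    assert flink_form == impala_form == (a != b)
# ------------------------------------------------------------------------------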
@visit_node.register(ops.ExtractMicrosecond) def visit_ExtractMicrosecond(self, op, *, arg): return self.f.extract(self.v.microsecond, arg) - @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): return (self.f.dayofweek(arg) + 5) % 7 - @visit_node.register(ops.DayOfWeekName) def visit_DayOfWeekName(self, op, *, arg): index = self.cast(self.f.dayofweek(self.cast(arg, dt.date)), op.dtype) lookup_table = self.f.str_to_map( @@ -420,11 +426,9 @@ def visit_DayOfWeekName(self, op, *, arg): ) return lookup_table[index] - @visit_node.register(ops.TimestampNow) def visit_TimestampNow(self, op): return self.v.current_timestamp - @visit_node.register(ops.TimestampBucket) def visit_TimestampBucket(self, op, *, arg, interval, offset): unit = op.interval.dtype.unit.name unit_var = self.v[unit] @@ -446,9 +450,6 @@ def visit_TimestampBucket(self, op, *, arg, interval, offset): self.v[f"FLOOR({arg.sql(self.dialect)} TO {unit_var.sql(self.dialect)})"], ) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.TimestampDelta) - @visit_node.register(ops.DateDelta) def visit_TemporalDelta(self, op, *, part, left, right): right = self.visit_TemporalTruncate(None, arg=right, unit=part) left = self.visit_TemporalTruncate(None, arg=left, unit=part) @@ -458,20 +459,21 @@ def visit_TemporalDelta(self, op, *, part, left, right): self.cast(left, dt.timestamp), ) - @visit_node.register(ops.TimestampTruncate) - @visit_node.register(ops.DateTruncate) - @visit_node.register(ops.TimeTruncate) + visit_TimeDelta = visit_DateDelta = visit_TimestampDelta = visit_TemporalDelta + def visit_TemporalTruncate(self, op, *, arg, unit): unit_var = self.v[unit.name] arg_sql = arg.sql(self.dialect) unit_sql = unit_var.sql(self.dialect) return self.f.floor(self.v[f"{arg_sql} TO {unit_sql}"]) - @visit_node.register(ops.StringContains) + visit_TimestampTruncate = ( + visit_DateTruncate + ) = visit_TimeTruncate = visit_TemporalTruncate + def visit_StringContains(self, op, *, haystack, needle): return self.f.instr(haystack, needle) > 0 - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): if end is not None: raise com.UnsupportedOperationError( @@ -485,80 +487,39 @@ def visit_StringFind(self, op, *, arg, substr, start, end): return self.f.instr(arg, substr) - @visit_node.register(ops.StartsWith) def visit_StartsWith(self, op, *, arg, start): return self.f.left(arg, self.f.char_length(start)).eq(start) - @visit_node.register(ops.EndsWith) def visit_EndsWith(self, op, *, arg, end): return self.f.right(arg, self.f.char_length(end)).eq(end) - @visit_node.register(ops.ExtractProtocol) - @visit_node.register(ops.ExtractAuthority) - @visit_node.register(ops.ExtractUserInfo) - @visit_node.register(ops.ExtractHost) - @visit_node.register(ops.ExtractFile) - @visit_node.register(ops.ExtractPath) def visit_ExtractUrlField(self, op, *, arg): return self.f.parse_url(arg, type(op).__name__[len("Extract") :].upper()) - @visit_node.register(ops.ExtractQuery) + visit_ExtractAuthority = ( + visit_ExtractHost + ) = ( + visit_ExtractUserInfo + ) = ( + visit_ExtractProtocol + ) = visit_ExtractFile = visit_ExtractPath = visit_ExtractUrlField + def visit_ExtractQuery(self, op, *, arg, key): return self.f.parse_url(*filter(None, (arg, "QUERY", key))) - @visit_node.register(ops.ExtractFragment) def visit_ExtractFragment(self, op, *, arg): return self.f.parse_url(arg, "REF") - @visit_node.register(ops.CountStar) def visit_CountStar(self, op, *, arg, where): if where is None: return 
self.f.count(STAR) return self.f.sum(self.cast(where, dt.int64)) - @visit_node.register(ops.CountDistinct) def visit_CountDistinct(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, self.f.array(arg)[2]) return self.f.count(sge.Distinct(expressions=[arg])) - @visit_node.register(ops.AnalyticVectorizedUDF) - @visit_node.register(ops.ApproxMedian) - @visit_node.register(ops.Arbitrary) - @visit_node.register(ops.ArgMax) - @visit_node.register(ops.ArgMin) - @visit_node.register(ops.ArrayCollect) - @visit_node.register(ops.ArrayFlatten) - @visit_node.register(ops.ArraySort) - @visit_node.register(ops.ArrayStringJoin) - @visit_node.register(ops.Correlation) - @visit_node.register(ops.CountDistinctStar) - @visit_node.register(ops.Covariance) - @visit_node.register(ops.DateDiff) - @visit_node.register(ops.ExtractURLField) - @visit_node.register(ops.FindInSet) - @visit_node.register(ops.IsInf) - @visit_node.register(ops.IsNan) - @visit_node.register(ops.Levenshtein) - @visit_node.register(ops.MapMerge) - @visit_node.register(ops.Median) - @visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.NthValue) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.ReductionVectorizedUDF) - @visit_node.register(ops.RegexSplit) - @visit_node.register(ops.RowID) - @visit_node.register(ops.ScalarUDF) - @visit_node.register(ops.StringSplit) - @visit_node.register(ops.Translate) - @visit_node.register(ops.Unnest) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - - @staticmethod - def _generate_groups(groups): - return groups - _SIMPLE_OPS = { ops.All: "min", @@ -588,13 +549,11 @@ def _generate_groups(groups): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @FlinkCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @FlinkCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/impala/compiler.py b/ibis/backends/impala/compiler.py index f20f48095a5c..21950d32b21a 100644 --- a/ibis/backends/impala/compiler.py +++ b/ibis/backends/impala/compiler.py @@ -1,7 +1,5 @@ from __future__ import annotations -from functools import singledispatchmethod - import sqlglot as sg import sqlglot.expressions as sge @@ -33,6 +31,37 @@ class ImpalaCompiler(SQLGlotCompiler): *SQLGlotCompiler.rewrites, ) + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.Arbitrary, + ops.ArgMax, + ops.ArgMin, + ops.ArrayCollect, + ops.ArrayPosition, + ops.Array, + ops.Covariance, + ops.DateDelta, + ops.ExtractDayOfYear, + ops.First, + ops.Last, + ops.Levenshtein, + ops.Map, + ops.Median, + ops.MultiQuantile, + ops.NthValue, + ops.Quantile, + ops.RegexSplit, + ops.RowID, + ops.StringSplit, + ops.StructColumn, + ops.Time, + ops.TimeDelta, + ops.TimestampBucket, + ops.TimestampDelta, + ops.Unnest, + ) + ) + def _aggregate(self, funcname: str, *args, where): if where is not None: args = tuple(self.if_(where, arg, NULL) for arg in args) @@ -64,23 +93,16 @@ def _minimize_spec(start, end, spec): return None return spec - @singledispatchmethod - def visit_node(self, op, **kw): - return super().visit_node(op, **kw) - - @visit_node.register(ops.Literal) def visit_Literal(self, op, *, value, dtype): if value is None and dtype.is_binary(): return NULL return super().visit_Literal(op, value=value, dtype=dtype) - @visit_node.register(ops.CountStar) def 
visit_CountStar(self, op, *, arg, where): if where is not None: return self.f.sum(self.cast(where, op.dtype)) return self.f.count(STAR) - @visit_node.register(ops.CountDistinctStar) def visit_CountDistinctStar(self, op, *, arg, where): expressions = ( sg.column(name, table=arg.alias_or_name, quoted=self.quoted) @@ -90,53 +112,42 @@ def visit_CountDistinctStar(self, op, *, arg, where): expressions = (self.if_(where, expr, NULL) for expr in expressions) return self.f.count(sge.Distinct(expressions=list(expressions))) - @visit_node.register(ops.CountDistinct) def visit_CountDistinct(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, NULL) return self.f.count(sge.Distinct(expressions=[arg])) - @visit_node.register(ops.Xor) def visit_Xor(self, op, *, left, right): return sg.and_(sg.or_(left, right), sg.not_(sg.and_(left, right))) - @visit_node.register(ops.RandomScalar) def visit_RandomScalar(self, op): return self.f.rand(self.f.utc_to_unix_micros(self.f.utc_timestamp())) - @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): return self.f.pmod(self.f.dayofweek(arg) - 2, 7) - @visit_node.register(ops.ExtractMillisecond) - def viist_ExtractMillisecond(self, op, *, arg): + def visit_ExtractMillisecond(self, op, *, arg): return self.f.extract(self.v.millisecond, arg) % 1_000 - @visit_node.register(ops.ExtractMicrosecond) def visit_ExtractMicrosecond(self, op, *, arg): return self.f.extract(self.v.microsecond, arg) % 1_000_000 - @visit_node.register(ops.Degrees) def visit_Degrees(self, op, *, arg): return 180.0 * arg / self.f.pi() - @visit_node.register(ops.Radians) def visit_Radians(self, op, *, arg): return self.f.pi() * arg / 180.0 - @visit_node.register(ops.HashBytes) def visit_HashBytes(self, op, *, arg, how): if how not in ("md5", "sha1", "sha256", "sha512"): raise com.UnsupportedOperationError(how) return self.f[how](arg) - @visit_node.register(ops.Log) def visit_Log(self, op, *, arg, base): if base is None: return self.f.ln(arg) return self.f.log(base, arg, dialect=self.dialect) - @visit_node.register(ops.DateFromYMD) def visit_DateFromYMD(self, op, *, year, month, day): return self.cast( self.f.concat( @@ -186,7 +197,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype): return sge.convert(str(value)) return None - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): from_ = op.arg.dtype if from_.is_integer() and to.is_interval(): @@ -195,50 +205,43 @@ def visit_Cast(self, op, *, arg, to): return 1_000_000 * self.f.unix_timestamp(arg) return super().visit_Cast(op, arg=arg, to=to) - @visit_node.register(ops.StartsWith) def visit_StartsWith(self, op, *, arg, start): return arg.like(self.f.concat(start, "%")) - @visit_node.register(ops.EndsWith) def visit_EndsWith(self, op, *, arg, end): return arg.like(self.f.concat("%", end)) - @visit_node.register(ops.FindInSet) def visit_FindInSet(self, op, *, needle, values): return self.f.find_in_set(needle, self.f.concat_ws(",", *values)) - @visit_node.register(ops.ExtractProtocol) - @visit_node.register(ops.ExtractAuthority) - @visit_node.register(ops.ExtractUserInfo) - @visit_node.register(ops.ExtractHost) - @visit_node.register(ops.ExtractFile) - @visit_node.register(ops.ExtractPath) def visit_ExtractUrlField(self, op, *, arg): return self.f.parse_url(arg, type(op).__name__[len("Extract") :].upper()) - @visit_node.register(ops.ExtractQuery) + visit_ExtractAuthority = ( + visit_ExtractHost + ) = ( + visit_ExtractUserInfo + ) = ( + visit_ExtractProtocol + ) = visit_ExtractFile 
= visit_ExtractPath = visit_ExtractUrlField + def visit_ExtractQuery(self, op, *, arg, key): return self.f.parse_url(*filter(None, (arg, "QUERY", key))) - @visit_node.register(ops.ExtractFragment) def visit_ExtractFragment(self, op, *, arg): return self.f.parse_url(arg, "REF") - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): if start is not None: return self.f.locate(substr, arg, start + 1) return self.f.locate(substr, arg) - @visit_node.register(ops.StringContains) def visit_StringContains(self, op, *, haystack, needle): return self.f.locate(needle, haystack) > 0 - @visit_node.register(ops.TimestampDiff) def visit_TimestampDiff(self, op, *, left, right): return self.f.unix_timestamp(left) - self.f.unix_timestamp(right) - @visit_node.register(ops.Strftime) def visit_Strftime(self, op, *, arg, format_str): if not isinstance(op.format_str, ops.Literal): raise com.UnsupportedOperationError( @@ -252,11 +255,9 @@ def visit_Strftime(self, op, *, arg, format_str): self.f.unix_timestamp(self.cast(arg, dt.string)), format_str ) - @visit_node.register(ops.ExtractWeekOfYear) def visit_ExtractWeekOfYear(self, op, *, arg): return self.f.anon.weekofyear(arg) - @visit_node.register(ops.TimestampTruncate) def visit_TimestampTruncate(self, op, *, arg, unit): units = { "Y": "YEAR", @@ -277,25 +278,21 @@ def visit_TimestampTruncate(self, op, *, arg, unit): ) return self.f.date_trunc(impala_unit, arg) - @visit_node.register(ops.DateTruncate) def visit_DateTruncate(self, op, *, arg, unit): if unit.short == "Q": return self.f.trunc(arg, "Q") return self.f.date_trunc(unit.name.upper(), arg) - @visit_node.register(ops.TimestampFromUNIX) def visit_TimestampFromUNIX(self, op, *, arg, unit): arg = self.cast(util.convert_unit(arg, unit.short, "s"), dt.int32) return self.cast(self.f.from_unixtime(arg, "yyyy-MM-dd HH:mm:ss"), dt.timestamp) - @visit_node.register(ops.DateAdd) def visit_DateAdd(self, op, *, left, right): return self.cast( super().visit_DateAdd(op, left=self.cast(left, dt.date), right=right), dt.date, ) - @visit_node.register(ops.TimestampAdd) def visit_TimestampAdd(self, op, *, left, right): if not isinstance(right, sge.Interval): raise com.UnsupportedOperationError( @@ -309,19 +306,15 @@ def visit_TimestampAdd(self, op, *, left, right): dt.timestamp, ) - @visit_node.register(ops.DateDiff) def visit_DateDiff(self, op, *, left, right): return self.f.anon.datediff(left, right) - @visit_node.register(ops.Date) def visit_Date(self, op, *, arg): return self.cast(self.f.to_date(arg), dt.date) - @visit_node.register(ops.RegexReplace) def visit_RegexReplace(self, op, *, arg, pattern, replacement): return self.f.regexp_replace(arg, pattern, replacement, dialect=self.dialect) - @visit_node.register(ops.Round) def visit_Round(self, op, *, arg, digits): rounded = self.f.round(*filter(None, (arg, digits))) @@ -330,7 +323,6 @@ def visit_Round(self, op, *, arg, digits): return self.cast(rounded, dtype) return rounded - @visit_node.register(ops.Sign) def visit_Sign(self, op, *, arg): sign = self.f.sign(arg) dtype = op.dtype @@ -338,35 +330,6 @@ def visit_Sign(self, op, *, arg): return self.cast(sign, dtype) return sign - @visit_node.register(ops.Arbitrary) - @visit_node.register(ops.ArgMax) - @visit_node.register(ops.ArgMin) - @visit_node.register(ops.ArrayCollect) - @visit_node.register(ops.ArrayPosition) - @visit_node.register(ops.Array) - @visit_node.register(ops.Covariance) - @visit_node.register(ops.DateDelta) - @visit_node.register(ops.ExtractDayOfYear) - 
@visit_node.register(ops.First) - @visit_node.register(ops.Last) - @visit_node.register(ops.Levenshtein) - @visit_node.register(ops.Map) - @visit_node.register(ops.Median) - @visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.NthValue) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.RegexSplit) - @visit_node.register(ops.RowID) - @visit_node.register(ops.StringSplit) - @visit_node.register(ops.StructColumn) - @visit_node.register(ops.Time) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.TimestampBucket) - @visit_node.register(ops.TimestampDelta) - @visit_node.register(ops.Unnest) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - _SIMPLE_OPS = { ops.All: "min", @@ -396,13 +359,11 @@ def visit_Undefined(self, op, **_): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @ImpalaCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @ImpalaCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index 388bbb5dea32..372836b76ff4 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -1,7 +1,6 @@ from __future__ import annotations import calendar -from functools import singledispatchmethod import sqlglot as sg import sqlglot.expressions as sge @@ -62,6 +61,63 @@ class MSSQLCompiler(SQLGlotCompiler): *SQLGlotCompiler.rewrites, ) + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.Any, + ops.All, + ops.ApproxMedian, + ops.Arbitrary, + ops.ArgMax, + ops.ArgMin, + ops.ArrayCollect, + ops.Array, + ops.ArrayDistinct, + ops.ArrayFlatten, + ops.ArrayMap, + ops.ArraySort, + ops.ArrayUnion, + ops.BitAnd, + ops.BitOr, + ops.BitXor, + ops.Covariance, + ops.CountDistinctStar, + ops.DateAdd, + ops.DateDiff, + ops.DateSub, + ops.EndsWith, + ops.First, + ops.IntervalAdd, + ops.IntervalFromInteger, + ops.IntervalMultiply, + ops.IntervalSubtract, + ops.IsInf, + ops.IsNan, + ops.Last, + ops.LPad, + ops.Levenshtein, + ops.Map, + ops.Median, + ops.Mode, + ops.MultiQuantile, + ops.NthValue, + ops.Quantile, + ops.RegexExtract, + ops.RegexReplace, + ops.RegexSearch, + ops.RegexSplit, + ops.RowID, + ops.RPad, + ops.StartsWith, + ops.StringSplit, + ops.StringToTimestamp, + ops.StructColumn, + ops.TimestampAdd, + ops.TimestampDiff, + ops.TimestampSub, + ops.Unnest, + ) + ) + @property def NAN(self): return self.f.double("NaN") @@ -80,10 +136,6 @@ def _aggregate(self, funcname: str, *args, where): args = tuple(self.if_(where, arg, NULL) for arg in args) return func(*args) - @singledispatchmethod - def visit_node(self, op, **kwargs): - return super().visit_node(op, **kwargs) - @staticmethod def _generate_groups(groups): return groups @@ -99,7 +151,6 @@ def _minimize_spec(start, end, spec): return None return spec - @visit_node.register(ops.StringLength) def visit_StringLength(self, op, *, arg): """The MSSQL LEN function doesn't count trailing spaces. 
@@ -116,7 +167,6 @@ def visit_StringLength(self, op, *, arg): """ return paren(self.f.len(self.f.concat("A", arg, "Z")) - 2) - @visit_node.register(ops.Capitalize) def visit_Capitalize(self, op, *, arg): length = paren(self.f.len(self.f.concat("A", arg, "Z")) - 2) return self.f.concat( @@ -124,29 +174,24 @@ def visit_Capitalize(self, op, *, arg): self.f.lower(self.f.substring(arg, 2, length - 1)), ) - @visit_node.register(ops.GroupConcat) def visit_GroupConcat(self, op, *, arg, sep, where): if where is not None: arg = self.if_(where, arg, NULL) return self.f.group_concat(arg, sep) - @visit_node.register(ops.CountStar) def visit_CountStar(self, op, *, arg, where): if where is not None: return self.f.sum(self.if_(where, 1, 0)) return self.f.count(STAR) - @visit_node.register(ops.CountDistinct) def visit_CountDistinct(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, NULL) return self.f.count(sge.Distinct(expressions=[arg])) - @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): return self.f.datepart(self.v.weekday, arg) - 1 - @visit_node.register(ops.DayOfWeekName) def visit_DayOfWeekName(self, op, *, arg): days = calendar.day_name return sge.Case( @@ -154,8 +199,6 @@ def visit_DayOfWeekName(self, op, *, arg): ifs=list(map(self.if_, *zip(*enumerate(days)))), ) - @visit_node.register(ops.DateTruncate) - @visit_node.register(ops.TimestampTruncate) def visit_DateTimestampTruncate(self, op, *, arg, unit): interval_units = { "us": "microsecond", @@ -174,23 +217,21 @@ def visit_DateTimestampTruncate(self, op, *, arg, unit): return self.f.datetrunc(self.v[unit], arg, dialect=self.dialect) - @visit_node.register(ops.Date) + visit_DateTruncate = visit_TimestampTruncate = visit_DateTimestampTruncate + def visit_Date(self, op, *, arg): return self.cast(arg, dt.date) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.DateDelta) - @visit_node.register(ops.TimestampDelta) def visit_DateTimeDelta(self, op, *, left, right, part): return self.f.datediff( sge.Var(this=part.this.upper()), right, left, dialect=self.dialect ) - @visit_node.register(ops.Xor) + visit_TimeDelta = visit_DateDelta = visit_TimestampDelta = visit_DateTimeDelta + def visit_Xor(self, op, *, left, right): return sg.and_(sg.or_(left, right), sg.not_(sg.and_(left, right))) - @visit_node.register(ops.TimestampBucket) def visit_TimestampBucket(self, op, *, arg, interval, offset): interval_units = { "ms": "millisecond", @@ -223,50 +264,50 @@ def visit_TimestampBucket(self, op, *, arg, interval, offset): return self.f.date_bucket(part, op.interval.value, arg, origin) - @visit_node.register(ops.ExtractEpochSeconds) def visit_ExtractEpochSeconds(self, op, *, arg): return self.cast( self.f.datediff(self.v.s, "1970-01-01 00:00:00", arg, dialect=self.dialect), dt.int64, ) - @visit_node.register(ops.ExtractYear) - @visit_node.register(ops.ExtractMonth) - @visit_node.register(ops.ExtractDay) - @visit_node.register(ops.ExtractDayOfYear) - @visit_node.register(ops.ExtractHour) - @visit_node.register(ops.ExtractMinute) - @visit_node.register(ops.ExtractSecond) - @visit_node.register(ops.ExtractMillisecond) - @visit_node.register(ops.ExtractMicrosecond) - def visit_Extract(self, op, *, arg): + def visit_ExtractTemporalComponent(self, op, *, arg): return self.f.datepart(self.v[type(op).__name__[len("Extract") :].lower()], arg) - @visit_node.register(ops.ExtractWeekOfYear) + visit_ExtractYear = ( + visit_ExtractMonth + ) = ( + visit_ExtractDay + ) = ( + visit_ExtractDayOfYear + ) = ( + 
visit_ExtractHour + ) = ( + visit_ExtractMinute + ) = ( + visit_ExtractSecond + ) = ( + visit_ExtractMillisecond + ) = visit_ExtractMicrosecond = visit_ExtractTemporalComponent + def visit_ExtractWeekOfYear(self, op, *, arg): return self.f.datepart(self.v.iso_week, arg) - @visit_node.register(ops.TimeFromHMS) def visit_TimeFromHMS(self, op, *, hours, minutes, seconds): return self.f.timefromparts(hours, minutes, seconds, 0, 0) - @visit_node.register(ops.TimestampFromYMDHMS) def visit_TimestampFromYMDHMS( self, op, *, year, month, day, hours, minutes, seconds ): return self.f.datetimefromparts(year, month, day, hours, minutes, seconds, 0) - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): if start is not None: return self.f.charindex(substr, arg, start) return self.f.charindex(substr, arg) - @visit_node.register(ops.Round) def visit_Round(self, op, *, arg, digits): return self.f.round(arg, digits if digits is not None else 0) - @visit_node.register(ops.TimestampFromUNIX) def visit_TimestampFromUNIX(self, op, *, arg, unit): unit = unit.short if unit == "s": @@ -321,17 +362,14 @@ def visit_NonNullLiteral(self, op, *, value, dtype): return None - @visit_node.register(ops.Log2) def visit_Log2(self, op, *, arg): return self.f.log(arg, 2, dialect=self.dialect) - @visit_node.register(ops.Log) def visit_Log(self, op, *, arg, base): if base is None: return self.f.log(arg, dialect=self.dialect) return self.f.log(arg, base, dialect=self.dialect) - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): from_ = op.arg.dtype @@ -342,25 +380,21 @@ def visit_Cast(self, op, *, arg, to): return self.f.dateadd(self.v.s, arg, "1970-01-01 00:00:00") return super().visit_Cast(op, arg=arg, to=to) - @visit_node.register(ops.Sum) def visit_Sum(self, op, *, arg, where): if op.arg.dtype.is_boolean(): arg = self.if_(arg, 1, 0) return self.agg.sum(arg, where=where) - @visit_node.register(ops.Mean) def visit_Mean(self, op, *, arg, where): if op.arg.dtype.is_boolean(): arg = self.if_(arg, 1, 0) return self.agg.avg(arg, where=where) - @visit_node.register(ops.Not) def visit_Not(self, op, *, arg): if isinstance(arg, sge.Boolean): return FALSE if arg == TRUE else TRUE return self.if_(arg, 1, 0).eq(0) - @visit_node.register(ops.HashBytes) def visit_HashBytes(self, op, *, arg, how): if how in ("md5", "sha1"): return self.f.hashbytes(how, arg) @@ -371,7 +405,6 @@ def visit_HashBytes(self, op, *, arg, how): else: raise NotImplementedError(how) - @visit_node.register(ops.HexDigest) def visit_HexDigest(self, op, *, arg, how): if how in ("md5", "sha1"): hashbinary = self.f.hashbytes(how, arg) @@ -390,66 +423,10 @@ def visit_HexDigest(self, op, *, arg, how): ) ) - @visit_node.register(ops.StringConcat) def visit_StringConcat(self, op, *, arg): any_args_null = (a.is_(NULL) for a in arg) return self.if_(sg.or_(*any_args_null), NULL, self.f.concat(*arg)) - @visit_node.register(ops.Any) - @visit_node.register(ops.All) - @visit_node.register(ops.ApproxMedian) - @visit_node.register(ops.Arbitrary) - @visit_node.register(ops.ArgMax) - @visit_node.register(ops.ArgMin) - @visit_node.register(ops.ArrayCollect) - @visit_node.register(ops.Array) - @visit_node.register(ops.ArrayDistinct) - @visit_node.register(ops.ArrayFlatten) - @visit_node.register(ops.ArrayMap) - @visit_node.register(ops.ArraySort) - @visit_node.register(ops.ArrayUnion) - @visit_node.register(ops.BitAnd) - @visit_node.register(ops.BitOr) - @visit_node.register(ops.BitXor) - @visit_node.register(ops.Covariance) - 
@visit_node.register(ops.CountDistinctStar) - @visit_node.register(ops.DateAdd) - @visit_node.register(ops.DateDiff) - @visit_node.register(ops.DateSub) - @visit_node.register(ops.EndsWith) - @visit_node.register(ops.First) - @visit_node.register(ops.IntervalAdd) - @visit_node.register(ops.IntervalFromInteger) - @visit_node.register(ops.IntervalMultiply) - @visit_node.register(ops.IntervalSubtract) - @visit_node.register(ops.IsInf) - @visit_node.register(ops.IsNan) - @visit_node.register(ops.Last) - @visit_node.register(ops.LPad) - @visit_node.register(ops.Levenshtein) - @visit_node.register(ops.Map) - @visit_node.register(ops.Median) - @visit_node.register(ops.Mode) - @visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.NthValue) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.RegexExtract) - @visit_node.register(ops.RegexReplace) - @visit_node.register(ops.RegexSearch) - @visit_node.register(ops.RegexSplit) - @visit_node.register(ops.RowID) - @visit_node.register(ops.RPad) - @visit_node.register(ops.StartsWith) - @visit_node.register(ops.StringSplit) - @visit_node.register(ops.StringToTimestamp) - @visit_node.register(ops.StructColumn) - @visit_node.register(ops.TimestampAdd) - @visit_node.register(ops.TimestampDiff) - @visit_node.register(ops.TimestampSub) - @visit_node.register(ops.Unnest) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - _SIMPLE_OPS = { ops.Atan2: "atn2", @@ -472,13 +449,11 @@ def visit_Undefined(self, op, **_): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @MSSQLCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @MSSQLCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/mysql/compiler.py b/ibis/backends/mysql/compiler.py index 641a49a34b1e..32e61425cf4a 100644 --- a/ibis/backends/mysql/compiler.py +++ b/ibis/backends/mysql/compiler.py @@ -1,7 +1,7 @@ from __future__ import annotations import string -from functools import partial, reduce, singledispatchmethod +from functools import partial, reduce import sqlglot as sg import sqlglot.expressions as sge @@ -71,6 +71,35 @@ def POS_INF(self): raise NotImplementedError("MySQL does not support Infinity") NEG_INF = POS_INF + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.ApproxMedian, + ops.Arbitrary, + ops.ArgMax, + ops.ArgMin, + ops.ArrayCollect, + ops.Array, + ops.ArrayFlatten, + ops.ArrayMap, + ops.Covariance, + ops.First, + ops.Last, + ops.Levenshtein, + ops.Median, + ops.Mode, + ops.MultiQuantile, + ops.Quantile, + ops.RegexReplace, + ops.RegexSplit, + ops.RowID, + ops.StringSplit, + ops.StructColumn, + ops.TimestampBucket, + ops.TimestampDelta, + ops.Translate, + ops.Unnest, + ) + ) def _aggregate(self, funcname: str, *args, where): func = self.f[funcname] @@ -78,10 +107,6 @@ def _aggregate(self, funcname: str, *args, where): args = tuple(self.if_(where, arg, NULL) for arg in args) return func(*args) - @singledispatchmethod - def visit_node(self, op, **kwargs): - return super().visit_node(op, **kwargs) - @staticmethod def _minimize_spec(start, end, spec): if ( @@ -93,7 +118,6 @@ def _minimize_spec(start, end, spec): return None return spec - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): from_ = op.arg.dtype if (from_.is_json() or from_.is_string()) and to.is_json(): @@ -108,37 +132,31 @@ def 
visit_Cast(self, op, *, arg, to): return self.f.from_unixtime(arg) return super().visit_Cast(op, arg=arg, to=to) - @visit_node.register(ops.TimestampDiff) def visit_TimestampDiff(self, op, *, left, right): return self.f.timestampdiff( sge.Var(this="SECOND"), right, left, dialect=self.dialect ) - @visit_node.register(ops.DateDiff) def visit_DateDiff(self, op, *, left, right): return self.f.timestampdiff( sge.Var(this="DAY"), right, left, dialect=self.dialect ) - @visit_node.register(ops.ApproxCountDistinct) def visit_ApproxCountDistinct(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg) return self.f.count(sge.Distinct(expressions=[arg])) - @visit_node.register(ops.CountStar) def visit_CountStar(self, op, *, arg, where): if where is not None: return self.f.sum(self.cast(where, op.dtype)) return self.f.count(STAR) - @visit_node.register(ops.CountDistinct) def visit_CountDistinct(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg) return self.f.count(sge.Distinct(expressions=[arg])) - @visit_node.register(ops.CountDistinctStar) def visit_CountDistinctStar(self, op, *, arg, where): if where is not None: raise com.UnsupportedOperationError( @@ -149,7 +167,6 @@ def visit_CountDistinctStar(self, op, *, arg, where): sge.Distinct(expressions=list(map(func, op.arg.schema.keys()))) ) - @visit_node.register(ops.GroupConcat) def visit_GroupConcat(self, op, *, arg, sep, where): if not isinstance(op.sep, ops.Literal): raise com.UnsupportedOperationError( @@ -159,11 +176,9 @@ def visit_GroupConcat(self, op, *, arg, sep, where): arg = self.if_(where, arg) return self.f.group_concat(arg, sep) - @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): return (self.f.dayofweek(arg) + 5) % 7 - @visit_node.register(ops.Literal) def visit_Literal(self, op, *, value, dtype): # avoid casting NULL: the set of types allowed by MySQL and # MariaDB when casting is a strict subset of allowed types in other @@ -195,7 +210,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype): return sge.convert(value.replace("\\", "\\\\")) return None - @visit_node.register(ops.JSONGetItem) def visit_JSONGetItem(self, op, *, arg, index): if op.index.dtype.is_integer(): path = self.f.concat("$[", self.cast(index, dt.string), "]") @@ -203,7 +217,6 @@ def visit_JSONGetItem(self, op, *, arg, index): path = self.f.concat("$.", index) return self.f.json_extract(arg, path) - @visit_node.register(ops.DateFromYMD) def visit_DateFromYMD(self, op, *, year, month, day): return self.f.str_to_date( self.f.concat( @@ -214,25 +227,20 @@ def visit_DateFromYMD(self, op, *, year, month, day): "%Y%m%d", ) - @visit_node.register(ops.FindInSet) def visit_FindInSet(self, op, *, needle, values): return self.f.find_in_set(needle, self.f.concat_ws(",", values)) - @visit_node.register(ops.EndsWith) def visit_EndsWith(self, op, *, arg, end): to = sge.DataType(this=sge.DataType.Type.BINARY) return self.f.right(arg, self.f.char_length(end)).eq(sge.Cast(this=end, to=to)) - @visit_node.register(ops.StartsWith) def visit_StartsWith(self, op, *, arg, start): to = sge.DataType(this=sge.DataType.Type.BINARY) return self.f.left(arg, self.f.length(start)).eq(sge.Cast(this=start, to=to)) - @visit_node.register(ops.RegexSearch) def visit_RegexSearch(self, op, *, arg, pattern): return arg.rlike(pattern) - @visit_node.register(ops.RegexExtract) def visit_RegexExtract(self, op, *, arg, pattern, index): extracted = self.f.regexp_substr(arg, pattern) return self.if_( @@ -247,7 +255,6 @@ def 
visit_RegexExtract(self, op, *, arg, pattern, index): NULL, ) - @visit_node.register(ops.Equals) def visit_Equals(self, op, *, left, right): if op.left.dtype.is_string(): assert op.right.dtype.is_string(), op.right.dtype @@ -255,11 +262,9 @@ def visit_Equals(self, op, *, left, right): return sge.Cast(this=left, to=to).eq(right) return super().visit_Equals(op, left=left, right=right) - @visit_node.register(ops.StringContains) def visit_StringContains(self, op, *, haystack, needle): return self.f.instr(haystack, needle) > 0 - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): if end is not None: raise NotImplementedError( @@ -271,7 +276,6 @@ def visit_StringFind(self, op, *, arg, substr, start, end): return self.f.locate(substr, arg, start + 1) return self.f.locate(substr, arg) - @visit_node.register(ops.Capitalize) def visit_Capitalize(self, op, *, arg): return self.f.concat( self.f.upper(self.f.left(arg, 1)), self.f.lower(self.f.substr(arg, 2)) @@ -287,8 +291,6 @@ def visit_LRStrip(self, op, *, arg, position): arg, ) - @visit_node.register(ops.DateTruncate) - @visit_node.register(ops.TimestampTruncate) def visit_DateTimestampTruncate(self, op, *, arg, unit): truncate_formats = { "s": "%Y-%m-%d %H:%i:%s", @@ -303,38 +305,33 @@ def visit_DateTimestampTruncate(self, op, *, arg, unit): raise com.UnsupportedOperationError(f"Unsupported truncate unit {op.unit}") return self.f.date_format(arg, format) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.DateDelta) + visit_DateTruncate = visit_TimestampTruncate = visit_DateTimestampTruncate + def visit_DateTimeDelta(self, op, *, left, right, part): return self.f.timestampdiff( sge.Var(this=part.this), right, left, dialect=self.dialect ) - @visit_node.register(ops.ExtractMillisecond) + visit_TimeDelta = visit_DateDelta = visit_DateTimeDelta + def visit_ExtractMillisecond(self, op, *, arg): return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg) / 1_000) - @visit_node.register(ops.ExtractMicrosecond) def visit_ExtractMicrosecond(self, op, *, arg): return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg)) - @visit_node.register(ops.Strip) def visit_Strip(self, op, *, arg): return self.visit_LRStrip(op, arg=arg, position="BOTH") - @visit_node.register(ops.LStrip) def visit_LStrip(self, op, *, arg): return self.visit_LRStrip(op, arg=arg, position="LEADING") - @visit_node.register(ops.RStrip) def visit_RStrip(self, op, *, arg): return self.visit_LRStrip(op, arg=arg, position="TRAILING") - @visit_node.register(ops.IntervalFromInteger) def visit_IntervalFromInteger(self, op, *, arg, unit): return sge.Interval(this=arg, unit=sge.convert(op.resolution.upper())) - @visit_node.register(ops.TimestampAdd) def visit_TimestampAdd(self, op, *, left, right): if op.right.dtype.unit.short == "ms": right = sge.Interval( @@ -342,34 +339,6 @@ def visit_TimestampAdd(self, op, *, left, right): ) return self.f.date_add(left, right, dialect=self.dialect) - @visit_node.register(ops.ApproxMedian) - @visit_node.register(ops.Arbitrary) - @visit_node.register(ops.ArgMax) - @visit_node.register(ops.ArgMin) - @visit_node.register(ops.ArrayCollect) - @visit_node.register(ops.Array) - @visit_node.register(ops.ArrayFlatten) - @visit_node.register(ops.ArrayMap) - @visit_node.register(ops.Covariance) - @visit_node.register(ops.First) - @visit_node.register(ops.Last) - @visit_node.register(ops.Levenshtein) - @visit_node.register(ops.Median) - @visit_node.register(ops.Mode) - 
@visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.RegexReplace) - @visit_node.register(ops.RegexSplit) - @visit_node.register(ops.RowID) - @visit_node.register(ops.StringSplit) - @visit_node.register(ops.StructColumn) - @visit_node.register(ops.TimestampBucket) - @visit_node.register(ops.TimestampDelta) - @visit_node.register(ops.Translate) - @visit_node.register(ops.Unnest) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - _SIMPLE_OPS = { ops.BitAnd: "bit_and", @@ -391,13 +360,11 @@ def visit_Undefined(self, op, **_): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @MySQLCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @MySQLCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/oracle/compiler.py b/ibis/backends/oracle/compiler.py index edab44f726a7..428cdd6ee45f 100644 --- a/ibis/backends/oracle/compiler.py +++ b/ibis/backends/oracle/compiler.py @@ -1,7 +1,5 @@ from __future__ import annotations -from functools import singledispatchmethod - import sqlglot as sg import sqlglot.expressions as sge import toolz @@ -13,7 +11,6 @@ from ibis.backends.base.sqlglot.datatypes import OracleType from ibis.backends.base.sqlglot.dialects import Oracle from ibis.backends.base.sqlglot.rewrites import ( - Window, exclude_unsupported_window_frame_from_ops, exclude_unsupported_window_frame_from_row_number, replace_log2, @@ -52,6 +49,42 @@ class OracleCompiler(SQLGlotCompiler): NEG_INF = sge.Literal.number("-binary_double_infinity") """Backend's negative infinity literal.""" + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.Arbitrary, + ops.ArgMax, + ops.ArgMin, + ops.ArrayCollect, + ops.Array, + ops.ArrayFlatten, + ops.ArrayMap, + ops.ArrayStringJoin, + ops.First, + ops.Last, + ops.Mode, + ops.MultiQuantile, + ops.RegexSplit, + ops.StringSplit, + ops.TimeTruncate, + ops.Bucket, + ops.TimestampBucket, + ops.TimeDelta, + ops.DateDelta, + ops.TimestampDelta, + ops.TimestampNow, + ops.TimestampFromYMDHMS, + ops.TimeFromHMS, + ops.IntervalFromInteger, + ops.DayOfWeekIndex, + ops.DayOfWeekName, + ops.DateDiff, + ops.ExtractEpochSeconds, + ops.ExtractWeekOfYear, + ops.ExtractDayOfYear, + ops.RowID, + ) + ) + def _aggregate(self, funcname: str, *args, where): func = self.f[funcname] if where is not None: @@ -62,11 +95,6 @@ def _aggregate(self, funcname: str, *args, where): def _generate_groups(groups): return groups - @singledispatchmethod - def visit_node(self, op, **kwargs): - return super().visit_node(op, **kwargs) - - @visit_node.register(ops.Equals) def visit_Equals(self, op, *, left, right): # Oracle didn't have proper boolean types until recently and we handle them # as integers so we end up with things like "t0"."bool_col" = 1 (for True) @@ -81,7 +109,6 @@ def visit_Equals(self, op, *, left, right): return sg.not_(left) return super().visit_Equals(op, left=left, right=right) - @visit_node.register(ops.IsNull) def visit_IsNull(self, op, *, arg): # TODO(gil): find a better way to handle this # but CASE WHEN (bool_col = 1) IS NULL isn't valid and we can simply check if @@ -90,7 +117,6 @@ def visit_IsNull(self, op, *, arg): return arg.this.is_(NULL) return arg.is_(NULL) - @visit_node.register(ops.Literal) def visit_Literal(self, op, *, value, dtype): # avoid casting NULL -- oracle handling for 
these casts is... complicated if value is None: @@ -122,7 +148,6 @@ def visit_Literal(self, op, *, value, dtype): return super().visit_Literal(op, value=value, dtype=dtype) - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): if to.is_interval(): # CASTing to an INTERVAL in Oracle requires specifying digits of @@ -138,7 +163,6 @@ def visit_Cast(self, op, *, arg, to): ) return self.cast(arg, to) - @visit_node.register(ops.Limit) def visit_Limit(self, op, *, parent, n, offset): # push limit/offset into subqueries if isinstance(parent, sge.Subquery) and parent.this.args.get("limit") is None: @@ -173,61 +197,47 @@ def visit_Limit(self, op, *, parent, n, offset): return result.subquery(alias) return result - @visit_node.register(ops.Date) def visit_Date(self, op, *, arg): return sg.cast(arg, to="date") - @visit_node.register(ops.IsNan) def visit_IsNan(self, op, *, arg): return arg.eq(self.NAN) - @visit_node.register(ops.Log) def visit_Log(self, op, *, arg, base): return self.f.log(base, arg, dialect=self.dialect) - @visit_node.register(ops.IsInf) def visit_IsInf(self, op, *, arg): return arg.isin(self.POS_INF, self.NEG_INF) - @visit_node.register(ops.RandomScalar) def visit_RandomScalar(self, op): # Not using FuncGen here because of dotted function call return sg.func("dbms_random.value") - @visit_node.register(ops.Pi) def visit_Pi(self, op): return self.f.acos(-1) - @visit_node.register(ops.Cot) def visit_Cot(self, op, *, arg): return 1 / self.f.tan(arg) - @visit_node.register(ops.Degrees) def visit_Degrees(self, op, *, arg): return 180 * arg / self.visit_node(ops.Pi()) - @visit_node.register(ops.Radians) def visit_Radians(self, op, *, arg): return self.visit_node(ops.Pi()) * arg / 180 - @visit_node.register(ops.Modulus) def visit_Modulus(self, op, *, left, right): return self.f.mod(left, right) - @visit_node.register(ops.Levenshtein) def visit_Levenshtein(self, op, *, left, right): # Not using FuncGen here because of dotted function call return sg.func("utl_match.edit_distance", left, right) - @visit_node.register(ops.StartsWith) def visit_StartsWith(self, op, *, arg, start): return self.f.substr(arg, 0, self.f.length(start)).eq(start) - @visit_node.register(ops.EndsWith) def visit_EndsWith(self, op, *, arg, end): return self.f.substr(arg, -1 * self.f.length(end), self.f.length(end)).eq(end) - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): if end is not None: raise NotImplementedError("`end` is not implemented") @@ -242,11 +252,9 @@ def visit_StringFind(self, op, *, arg, substr, start, end): return self.f.instr(arg, sub_string) - @visit_node.register(ops.StrRight) def visit_StrRight(self, op, *, arg, nchars): return self.f.substr(arg, -nchars) - @visit_node.register(ops.RegexExtract) def visit_RegexExtract(self, op, *, arg, pattern, index): return self.if_( index.eq(0), @@ -254,21 +262,17 @@ def visit_RegexExtract(self, op, *, arg, pattern, index): self.f.regexp_substr(arg, pattern, 1, 1, "cn", index), ) - @visit_node.register(ops.RegexReplace) def visit_RegexReplace(self, op, *, arg, pattern, replacement): return sge.RegexpReplace(this=arg, expression=pattern, replacement=replacement) - @visit_node.register(ops.StringContains) def visit_StringContains(self, op, *, haystack, needle): return self.f.instr(haystack, needle) > 0 - @visit_node.register(ops.StringJoin) def visit_StringJoin(self, op, *, arg, sep): return self.f.concat(*toolz.interpose(sep, arg)) ## Aggregate stuff - @visit_node.register(ops.Correlation) def 
visit_Correlation(self, op, *, left, right, where, how): if how == "sample": raise ValueError( @@ -276,17 +280,14 @@ def visit_Correlation(self, op, *, left, right, where, how): ) return self.agg.corr(left, right, where=where) - @visit_node.register(ops.Covariance) def visit_Covariance(self, op, *, left, right, where, how): if how == "sample": return self.agg.covar_samp(left, right, where=where) return self.agg.covar_pop(left, right, where=where) - @visit_node.register(ops.ApproxMedian) def visit_ApproxMedian(self, op, *, arg, where): return self.visit_Quantile(op, arg=arg, quantile=0.5, where=where) - @visit_node.register(ops.Quantile) def visit_Quantile(self, op, *, arg, quantile, where): suffix = "cont" if op.arg.dtype.is_numeric() else "disc" funcname = f"percentile_{suffix}" @@ -300,20 +301,17 @@ def visit_Quantile(self, op, *, arg, quantile, where): ) return expr - @visit_node.register(ops.CountDistinct) def visit_CountDistinct(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg) return sge.Count(this=sge.Distinct(expressions=[arg])) - @visit_node.register(ops.CountStar) def visit_CountStar(self, op, *, arg, where): if where is not None: return self.f.count(self.if_(where, 1, NULL)) return self.f.count(STAR) - @visit_node.register(ops.IdenticalTo) def visit_IdenticalTo(self, op, *, left, right): # sqlglot NullSafeEQ uses "is not distinct from" which isn't supported in oracle return ( @@ -323,12 +321,9 @@ def visit_IdenticalTo(self, op, *, left, right): .eq(0) ) - @visit_node.register(ops.Xor) def visit_Xor(self, op, *, left, right): return (left.or_(right)).and_(sg.not_(left.and_(right))) - @visit_node.register(ops.TimestampTruncate) - @visit_node.register(ops.DateTruncate) def visit_DateTruncate(self, op, *, arg, unit): trunc_unit_mapping = { "Y": "year", @@ -359,7 +354,8 @@ def visit_DateTruncate(self, op, *, arg, unit): return self.f.trunc(arg, unyt) - @visit_node.register(Window) + visit_TimestampTruncate = visit_DateTruncate + def visit_Window(self, op, *, how, func, start, end, group_by, order_by): # Oracle has two (more?) types of analytic functions you can use inside OVER. 
# @@ -441,45 +437,10 @@ def visit_Window(self, op, *, how, func, start, end, group_by, order_by): return sge.Window(this=func, partition_by=group_by, order=order, spec=spec) - @visit_node.register(ops.StringConcat) def visit_StringConcat(self, op, *, arg): any_args_null = (a.is_(NULL) for a in arg) return self.if_(sg.or_(*any_args_null), NULL, self.f.concat(*arg)) - @visit_node.register(ops.Arbitrary) - @visit_node.register(ops.ArgMax) - @visit_node.register(ops.ArgMin) - @visit_node.register(ops.ArrayCollect) - @visit_node.register(ops.Array) - @visit_node.register(ops.ArrayFlatten) - @visit_node.register(ops.ArrayMap) - @visit_node.register(ops.ArrayStringJoin) - @visit_node.register(ops.First) - @visit_node.register(ops.Last) - @visit_node.register(ops.Mode) - @visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.RegexSplit) - @visit_node.register(ops.StringSplit) - @visit_node.register(ops.TimeTruncate) - @visit_node.register(ops.Bucket) - @visit_node.register(ops.TimestampBucket) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.DateDelta) - @visit_node.register(ops.TimestampDelta) - @visit_node.register(ops.TimestampNow) - @visit_node.register(ops.TimestampFromYMDHMS) - @visit_node.register(ops.TimeFromHMS) - @visit_node.register(ops.IntervalFromInteger) - @visit_node.register(ops.DayOfWeekIndex) - @visit_node.register(ops.DayOfWeekName) - @visit_node.register(ops.DateDiff) - @visit_node.register(ops.ExtractEpochSeconds) - @visit_node.register(ops.ExtractWeekOfYear) - @visit_node.register(ops.ExtractDayOfYear) - @visit_node.register(ops.RowID) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - _SIMPLE_OPS = { ops.ApproxCountDistinct: "approx_count_distinct", @@ -499,13 +460,11 @@ def visit_Undefined(self, op, **_): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @OracleCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @OracleCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index 819fb995dc7f..b03331fe80d5 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -379,19 +379,19 @@ def get_schema(self, table_name, database=None): @classmethod @lru_cache def _get_operations(cls): - return frozenset(op for op in translate.registry if issubclass(op, ops.Value)) + return tuple(op for op in translate.registry if issubclass(op, ops.Value)) @classmethod def has_operation(cls, operation: type[ops.Value]) -> bool: # Polars doesn't support geospatial ops, but the dispatcher implements # a common base class that makes it appear that it does. Explicitly # exclude these operations. 
- if issubclass(operation, (ops.GeoSpatialUnOp, ops.GeoSpatialBinOp)): + if issubclass( + operation, (ops.GeoSpatialUnOp, ops.GeoSpatialBinOp, ops.GeoUnaryUnion) + ): return False op_classes = cls._get_operations() - return operation in op_classes or any( - issubclass(operation, op_impl) for op_impl in op_classes - ) + return operation in op_classes or issubclass(operation, op_classes) def compile( self, expr: ir.Expr, params: Mapping[ir.Expr, object] | None = None, **_: Any diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index 27fa2cd34122..06980cac77d2 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -1,7 +1,7 @@ from __future__ import annotations import string -from functools import partial, reduce, singledispatchmethod +from functools import partial, reduce import sqlglot as sg import sqlglot.expressions as sge @@ -32,6 +32,13 @@ class PostgresCompiler(SQLGlotCompiler): NAN = sge.Literal.number("'NaN'::double precision") POS_INF = sge.Literal.number("'Inf'::double precision") NEG_INF = sge.Literal.number("'-Inf'::double precision") + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.RowID, + ops.TimeDelta, + ops.ArrayFlatten, + ) + ) def _aggregate(self, funcname: str, *args, where): expr = self.f[funcname](*args) @@ -39,11 +46,6 @@ def _aggregate(self, funcname: str, *args, where): return sge.Filter(this=expr, expression=sge.Where(this=where)) return expr - @singledispatchmethod - def visit_node(self, op, **kwargs): - return super().visit_node(op, **kwargs) - - @visit_node.register(ops.Mode) def visit_Mode(self, op, *, arg, where): expr = self.f.mode() expr = sge.WithinGroup( @@ -66,15 +68,12 @@ def visit_ArgMinMax(self, op, *, arg, key, where, desc: bool): ) return paren(agg)[0] - @visit_node.register(ops.ArgMin) def visit_ArgMin(self, op, *, arg, key, where): return self.visit_ArgMinMax(op, arg=arg, key=key, where=where, desc=False) - @visit_node.register(ops.ArgMax) def visit_ArgMax(self, op, *, arg, key, where): return self.visit_ArgMinMax(op, arg=arg, key=key, where=where, desc=True) - @visit_node.register(ops.Sum) def visit_Sum(self, op, *, arg, where): arg = ( self.cast(self.cast(arg, dt.int32), op.dtype) @@ -83,15 +82,12 @@ def visit_Sum(self, op, *, arg, where): ) return self.agg.sum(arg, where=where) - @visit_node.register(ops.IsNan) def visit_IsNan(self, op, *, arg): return arg.eq(self.cast(sge.convert("NaN"), op.arg.dtype)) - @visit_node.register(ops.IsInf) def visit_IsInf(self, op, *, arg): return arg.isin(self.POS_INF, self.NEG_INF) - @visit_node.register(ops.CountDistinctStar) def visit_CountDistinctStar(self, op, *, where, arg): # use a tuple because postgres doesn't accept COUNT(DISTINCT a, b, c, ...) 
# @@ -103,7 +99,6 @@ def visit_CountDistinctStar(self, op, *, where, arg): ) return self.agg.count(sge.Distinct(expressions=[row]), where=where) - @visit_node.register(ops.Correlation) def visit_Correlation(self, op, *, left, right, how, where): if how == "sample": raise com.UnsupportedOperationError( @@ -119,20 +114,15 @@ def visit_Correlation(self, op, *, left, right, how, where): return self.agg.corr(left, right, where=where) - @visit_node.register(ops.ApproxMedian) def visit_ApproxMedian(self, op, *, arg, where): return self.visit_Median(op, arg=arg, where=where) - @visit_node.register(ops.Median) def visit_Median(self, op, *, arg, where): return self.visit_Quantile(op, arg=arg, quantile=sge.convert(0.5), where=where) - @visit_node.register(ops.ApproxCountDistinct) def visit_ApproxCountDistinct(self, op, *, arg, where): return self.agg.count(sge.Distinct(expressions=[arg]), where=where) - @visit_node.register(ops.IntegerRange) - @visit_node.register(ops.TimestampRange) def visit_Range(self, op, *, start, stop, step): def zero_value(dtype): if dtype.is_interval(): @@ -170,15 +160,14 @@ def _sign(value, dtype): self.cast(self.f.array(), op.dtype), ) - @visit_node.register(ops.StringConcat) + visit_IntegerRange = visit_TimestampRange = visit_Range + def visit_StringConcat(self, op, *, arg): return reduce(lambda x, y: sge.DPipe(this=x, expression=y), arg) - @visit_node.register(ops.ArrayConcat) def visit_ArrayConcat(self, op, *, arg): return reduce(self.f.array_cat, map(partial(self.cast, to=op.dtype), arg)) - @visit_node.register(ops.ArrayContains) def visit_ArrayContains(self, op, *, arg, other): arg_dtype = op.arg.dtype return sge.ArrayContains( @@ -186,7 +175,6 @@ def visit_ArrayContains(self, op, *, arg, other): expression=self.f.array(self.cast(other, arg_dtype.value_type)), ) - @visit_node.register(ops.ArrayFilter) def visit_ArrayFilter(self, op, *, arg, body, param): return self.f.array( sg.select(sg.column(param, quoted=self.quoted)) @@ -194,13 +182,11 @@ def visit_ArrayFilter(self, op, *, arg, body, param): .where(body) ) - @visit_node.register(ops.ArrayMap) def visit_ArrayMap(self, op, *, arg, body, param): return self.f.array( sg.select(body).from_(sge.Unnest(expressions=[arg], alias=param)) ) - @visit_node.register(ops.ArrayPosition) def visit_ArrayPosition(self, op, *, arg, other): t = sge.Unnest(expressions=[arg], alias="value", offset=True) idx = sg.column("ordinality") @@ -209,13 +195,11 @@ def visit_ArrayPosition(self, op, *, arg, other): sg.select(idx).from_(t).where(value.eq(other)).limit(1).subquery(), 0 ) - @visit_node.register(ops.ArraySort) def visit_ArraySort(self, op, *, arg): return self.f.array( sg.select("x").from_(sge.Unnest(expressions=[arg], alias="x")).order_by("x") ) - @visit_node.register(ops.ArrayRepeat) def visit_ArrayRepeat(self, op, *, arg, times): i = sg.to_identifier("i") length = self.f.cardinality(arg) @@ -225,19 +209,16 @@ def visit_ArrayRepeat(self, op, *, arg, times): ) ) - @visit_node.register(ops.ArrayDistinct) def visit_ArrayDistinct(self, op, *, arg): return self.if_( arg.is_(NULL), NULL, self.f.array(sg.select(self.f.explode(arg)).distinct()) ) - @visit_node.register(ops.ArrayUnion) def visit_ArrayUnion(self, op, *, left, right): return self.f.anon.array( sg.union(sg.select(self.f.explode(left)), sg.select(self.f.explode(right))) ) - @visit_node.register(ops.ArrayIntersect) def visit_ArrayIntersect(self, op, *, left, right): return self.f.anon.array( sg.intersect( @@ -245,7 +226,6 @@ def visit_ArrayIntersect(self, op, *, left, right): ) ) 
- @visit_node.register(ops.Log2) def visit_Log2(self, op, *, arg): return self.cast( self.f.log( @@ -255,7 +235,6 @@ def visit_Log2(self, op, *, arg): op.dtype, ) - @visit_node.register(ops.Log) def visit_Log(self, op, *, arg, base): if base is not None: if not op.base.dtype.is_decimal(): @@ -267,7 +246,6 @@ def visit_Log(self, op, *, arg, base): arg = self.cast(arg, dt.decimal) return self.cast(self.f.log(base, arg), op.dtype) - @visit_node.register(ops.StructField) def visit_StructField(self, op, *, arg, field): idx = op.arg.dtype.names.index(field) + 1 # postgres doesn't have anonymous structs :( @@ -282,11 +260,9 @@ def visit_StructField(self, op, *, arg, field): op.dtype, ) - @visit_node.register(ops.StructColumn) def visit_StructColumn(self, op, *, names, values): return self.f.row(*map(self.cast, values, op.dtype.types)) - @visit_node.register(ops.ToJSONArray) def visit_ToJSONArray(self, op, *, arg): return self.if_( self.f.json_typeof(arg).eq(sge.convert("array")), @@ -294,23 +270,18 @@ def visit_ToJSONArray(self, op, *, arg): NULL, ) - @visit_node.register(ops.Map) def visit_Map(self, op, *, keys, values): return self.f.map(self.f.array(*keys), self.f.array(*values)) - @visit_node.register(ops.MapLength) def visit_MapLength(self, op, *, arg): return self.f.cardinality(self.f.akeys(arg)) - @visit_node.register(ops.MapGet) def visit_MapGet(self, op, *, arg, key, default): return self.if_(self.f.exist(arg, key), self.f.json_extract(arg, key), default) - @visit_node.register(ops.MapMerge) def visit_MapMerge(self, op, *, left, right): return sge.DPipe(this=left, expression=right) - @visit_node.register(ops.TypeOf) def visit_TypeOf(self, op, *, arg): typ = self.cast(self.f.pg_typeof(arg), dt.string) return self.if_( @@ -319,7 +290,6 @@ def visit_TypeOf(self, op, *, arg): typ, ) - @visit_node.register(ops.Round) def visit_Round(self, op, *, arg, digits): if digits is None: return self.f.round(arg) @@ -329,7 +299,6 @@ def visit_Round(self, op, *, arg, digits): return result return self.cast(result, dt.float64) - @visit_node.register(ops.Modulus) def visit_Modulus(self, op, *, left, right): # postgres doesn't allow modulus of double precision values, so upcast and # then downcast later if necessary @@ -343,24 +312,20 @@ def visit_Modulus(self, op, *, left, right): else: return result - @visit_node.register(ops.RegexExtract) def visit_RegexExtract(self, op, *, arg, pattern, index): pattern = self.f.concat("(", pattern, ")") matches = self.f.regexp_match(arg, pattern) return self.if_(arg.rlike(pattern), paren(matches)[index], NULL) - @visit_node.register(ops.FindInSet) def visit_FindInSet(self, op, *, needle, values): return self.f.coalesce( self.f.array_position(self.f.array(*values), needle), 0, ) - @visit_node.register(ops.StringContains) def visit_StringContains(self, op, *, haystack, needle): return self.f.strpos(haystack, needle) > 0 - @visit_node.register(ops.EndsWith) def visit_EndsWith(self, op, *, arg, end): return self.f.right(arg, self.f.length(end)).eq(end) @@ -380,7 +345,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype): return self.cast(value, dt.json) return None - @visit_node.register(ops.TimestampFromYMDHMS) def visit_TimestampFromYMDHMS( self, op, *, year, month, day, hours, minutes, seconds ): @@ -394,12 +358,10 @@ def visit_TimestampFromYMDHMS( self.cast(seconds, dt.float64), ) - @visit_node.register(ops.DateFromYMD) def visit_DateFromYMD(self, op, *, year, month, day): to_int32 = partial(self.cast, to=dt.int32) return self.f.datefromparts(to_int32(year), 
to_int32(month), to_int32(day)) - @visit_node.register(ops.TimestampBucket) def visit_TimestampBucket(self, op, *, arg, interval, offset): origin = self.f.make_timestamp( *map(partial(self.cast, to=dt.int32), (1970, 1, 1, 0, 0, 0)) @@ -410,46 +372,36 @@ def visit_TimestampBucket(self, op, *, arg, interval, offset): return self.f.date_bin(interval, arg, origin) - @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): return self.cast(self.f.extract("dow", arg) + 6, dt.int16) % 7 - @visit_node.register(ops.DayOfWeekName) def visit_DayOfWeekName(self, op, *, arg): return self.f.trim(self.f.to_char(arg, "Day"), string.whitespace) - @visit_node.register(ops.ExtractSecond) def visit_ExtractSecond(self, op, *, arg): return self.cast(self.f.floor(self.f.extract("second", arg)), op.dtype) - @visit_node.register(ops.ExtractMillisecond) def visit_ExtractMillisecond(self, op, *, arg): return self.cast( self.f.floor(self.f.extract("millisecond", arg)) % 1_000, op.dtype ) - @visit_node.register(ops.ExtractMicrosecond) def visit_ExtractMicrosecond(self, op, *, arg): return self.f.extract("microsecond", arg) % 1_000_000 - @visit_node.register(ops.ExtractDayOfYear) def visit_ExtractDayOfYear(self, op, *, arg): return self.f.extract("doy", arg) - @visit_node.register(ops.ExtractWeekOfYear) def visit_ExtractWeekOfYear(self, op, *, arg): return self.f.extract("week", arg) - @visit_node.register(ops.ExtractEpochSeconds) def visit_ExtractEpochSeconds(self, op, *, arg): return self.f.extract("epoch", arg) - @visit_node.register(ops.ArrayIndex) def visit_ArrayIndex(self, op, *, arg, index): index = self.if_(index < 0, self.f.cardinality(arg) + index, index) return paren(arg)[index + 1] - @visit_node.register(ops.ArraySlice) def visit_ArraySlice(self, op, *, arg, start, stop): neg_to_pos_index = lambda n, index: self.if_(index < 0, n + index, index) @@ -468,7 +420,6 @@ def visit_ArraySlice(self, op, *, arg, start, stop): slice_expr = sge.Slice(this=start + 1, expression=stop) return paren(arg)[slice_expr] - @visit_node.register(ops.IntervalFromInteger) def visit_IntervalFromInteger(self, op, *, arg, unit): plural = unit.plural if plural == "minutes": @@ -493,8 +444,6 @@ def visit_IntervalFromInteger(self, op, *, arg, unit): return self.f.make_interval(sge.Kwarg(this=key, expression=arg)) - @visit_node.register(ops.Cast) - @visit_node.register(ops.TryCast) def visit_Cast(self, op, *, arg, to): from_ = op.arg.dtype @@ -519,11 +468,7 @@ def visit_Cast(self, op, *, arg, to): return self.cast(arg, op.to) - @visit_node.register(ops.RowID) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.ArrayFlatten) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) + visit_TryCast = visit_Cast _SIMPLE_OPS = { @@ -589,13 +534,11 @@ def visit_Undefined(self, op, **_): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @PostgresCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @PostgresCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/postgres/tests/test_functions.py b/ibis/backends/postgres/tests/test_functions.py index d4929de48b6d..64e3f3cfc25a 100644 --- a/ibis/backends/postgres/tests/test_functions.py +++ b/ibis/backends/postgres/tests/test_functions.py @@ -672,7 +672,8 @@ def test_interactive_repr_shows_error(alltypes): 
with config.option_context("interactive", True): result = repr(expr) - assert "OperationNotDefinedError('BaseConvert')" in result + assert "OperationNotDefinedError" in result + assert "BaseConvert" in result def test_subquery(alltypes, df): diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 9bf9bd1ffe58..29e84b18ba7c 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -3,7 +3,6 @@ import calendar import itertools import re -from functools import singledispatchmethod import sqlglot as sg import sqlglot.expressions as sge @@ -16,7 +15,7 @@ from ibis.backends.base.sqlglot.compiler import FALSE, NULL, STAR, TRUE, SQLGlotCompiler from ibis.backends.base.sqlglot.datatypes import PySparkType from ibis.backends.base.sqlglot.dialects import PySpark -from ibis.backends.base.sqlglot.rewrites import Window, p +from ibis.backends.base.sqlglot.rewrites import p from ibis.common.patterns import replace from ibis.config import options from ibis.util import gen_name @@ -53,16 +52,19 @@ class PySparkCompiler(SQLGlotCompiler): type_mapper = PySparkType rewrites = (offset_to_filter, *SQLGlotCompiler.rewrites) + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.RowID, + ops.TimestampBucket, + ) + ) + def _aggregate(self, funcname: str, *args, where): func = self.f[funcname] if where is not None: args = tuple(self.if_(where, arg, NULL) for arg in args) return func(*args) - @singledispatchmethod - def visit_node(self, op, **kwargs): - return super().visit_node(op, **kwargs) - def visit_NonNullLiteral(self, op, *, value, dtype): if dtype.is_floating(): result = super().visit_NonNullLiteral(op, value=value, dtype=dtype) @@ -85,7 +87,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype): else: return None - @visit_node.register(ops.Field) def visit_Field(self, op, *, rel, name): result = super().visit_Field(op, rel=rel, name=name) if op.dtype.is_floating() and options.pyspark.treat_nan_as_null: @@ -93,7 +94,6 @@ def visit_Field(self, op, *, rel, name): else: return result - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): if to.is_json(): if op.arg.dtype.is_string(): @@ -103,7 +103,6 @@ def visit_Cast(self, op, *, arg, to): else: return self.cast(arg, to) - @visit_node.register(ops.IsNull) def visit_IsNull(self, op, *, arg): is_null = arg.is_(NULL) is_nan = self.f.isnan(arg) @@ -112,7 +111,6 @@ def visit_IsNull(self, op, *, arg): else: return is_null - @visit_node.register(ops.NotNull) def visit_NotNull(self, op, *, arg): is_not_null = arg.is_(sg.not_(NULL)) is_not_nan = sg.not_(self.f.isnan(arg)) @@ -121,56 +119,45 @@ def visit_NotNull(self, op, *, arg): else: return is_not_null - @visit_node.register(ops.IsInf) def visit_IsInf(self, op, *, arg): if op.arg.dtype.is_floating(): return sg.or_(arg == self.POS_INF, arg == self.NEG_INF) return FALSE - @visit_node.register(ops.Xor) def visit_Xor(self, op, left, right): return (left | right) & ~(left & right) - @visit_node.register(ops.Time) def visit_Time(self, op, *, arg): return arg - self.f.anon.date_trunc("day", arg) - @visit_node.register(ops.IntervalFromInteger) def visit_IntervalFromInteger(self, op, *, arg, unit): arg = self.f.concat(arg, sge.convert(f" {unit.plural}")) typ = sge.DataType(this=sge.DataType.Type.INTERVAL) return sg.cast(sge.convert(arg), to=typ) - @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): return (self.f.dayofweek(arg) + 5) % 7 - @visit_node.register(ops.DayOfWeekName) def visit_DayOfWeekName(self, op, *, arg): 
return sge.Case( this=(self.f.dayofweek(arg) + 5) % 7, ifs=list(itertools.starmap(self.if_, enumerate(calendar.day_name))), ) - @visit_node.register(ops.ExtractDayOfYear) def visit_ExtractDayOfYear(self, op, *, arg): return self.cast(self.f.dayofyear(arg), op.dtype) - @visit_node.register(ops.ExtractMillisecond) def visit_ExtractMillisecond(self, op, *, arg): return self.cast(self.f.date_format(arg, "SSS"), op.dtype) - @visit_node.register(ops.ExtractMicrosecond) def visit_ExtractMicrosecond(self, op, *, arg): raise com.UnsupportedOperationError( "PySpark backend does not support extracting microseconds." ) - @visit_node.register(ops.ExtractEpochSeconds) def visit_ExtractEpochSeconds(self, op, *, arg): return self.f.unix_timestamp(self.cast(arg, dt.timestamp)) - @visit_node.register(ops.TimestampFromUNIX) def visit_TimestampFromUNIX(self, op, *, arg, unit): if not op.unit: return self.f.to_timestamp(self.f.from_unixtime(arg)) @@ -183,9 +170,6 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): f"unit {op.unit.short}. Supported unit is s." ) - @visit_node.register(ops.TimestampTruncate) - @visit_node.register(ops.DateTruncate) - @visit_node.register(ops.TimeTruncate) def visit_TimestampTruncate(self, op, *, arg, unit): if unit.short == "ns": raise com.UnsupportedOperationError( @@ -193,19 +177,18 @@ def visit_TimestampTruncate(self, op, *, arg, unit): ) return self.f.anon.date_trunc(unit.singular, arg) - @visit_node.register(ops.CountStar) + visit_TimeTruncate = visit_DateTruncate = visit_TimestampTruncate + def visit_CountStar(self, op, *, arg, where): if where is not None: return self.f.sum(self.cast(where, op.dtype)) return self.f.count(STAR) - @visit_node.register(ops.CountDistinct) def visit_CountDistinct(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, NULL) return self.f.count(sge.Distinct(expressions=[arg])) - @visit_node.register(ops.CountDistinctStar) def visit_CountDistinctStar(self, op, *, arg, where): if where is None: return self.f.count(sge.Distinct(expressions=[STAR])) @@ -220,19 +203,16 @@ def visit_CountDistinctStar(self, op, *, arg, where): ] return self.f.count(sge.Distinct(expressions=cols)) - @visit_node.register(ops.First) def visit_First(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, NULL) return self.f.first(arg, TRUE) - @visit_node.register(ops.Last) def visit_Last(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, NULL) return self.f.last(arg, TRUE) - @visit_node.register(ops.Arbitrary) def visit_Arbitrary(self, op, *, arg, how, where): if where is not None: arg = self.if_(where, arg, NULL) @@ -246,11 +226,9 @@ def visit_Arbitrary(self, op, *, arg, how, where): "Supported values are `first` and `last`." 
) - @visit_node.register(ops.Median) def visit_Median(self, op, *, arg, where): return self.agg.percentile(arg, 0.5, where=where) - @visit_node.register(ops.GroupConcat) def visit_GroupConcat(self, op, *, arg, sep, where): if where is not None: arg = self.if_(where, arg, NULL) @@ -258,7 +236,6 @@ def visit_GroupConcat(self, op, *, arg, sep, where): collected = self.if_(self.f.size(collected).eq(0), NULL, collected) return self.f.array_join(collected, sep) - @visit_node.register(ops.Correlation) def visit_Correlation(self, op, *, left, right, how, where): if (left_type := op.left.dtype).is_boolean(): left = self.cast(left, dt.Int32(nullable=left_type.nullable)) @@ -278,18 +255,15 @@ def _build_sequence(self, start, stop, step, zero): self.f.array(), ) - @visit_node.register(ops.IntegerRange) def visit_IntegerRange(self, op, *, start, stop, step): zero = sge.convert(0) return self._build_sequence(start, stop, step, zero) - @visit_node.register(ops.TimestampRange) def visit_TimestampRange(self, op, *, start, stop, step): unit = op.step.dtype.resolution zero = sge.Interval(this=sge.convert(0), unit=unit) return self._build_sequence(start, stop, step, zero) - @visit_node.register(ops.Sample) def visit_Sample( self, op, *, parent, fraction: float, method: str, seed: int | None, **_ ): @@ -303,7 +277,6 @@ def visit_Sample( ) return sg.select(STAR).from_(sample) - @visit_node.register(ops.WindowBoundary) def visit_WindowBoundary(self, op, *, value, preceding): if isinstance(op.value, ops.Literal) and op.value.value == 0: value = "CURRENT ROW" @@ -333,29 +306,25 @@ def __sql_name__(self, op) -> str: return f"ibis_udf_{name}" - @visit_node.register(ops.ElementWiseVectorizedUDF) - @visit_node.register(ops.ReductionVectorizedUDF) def visit_VectorizedUDF(self, op, *, func, func_args, input_type, return_type): return self.f[self.__sql_name__(op)](*func_args) - @visit_node.register(ops.MapGet) + visit_ElementWiseVectorizedUDF = visit_ReductionVectorizedUDF = visit_VectorizedUDF + def visit_MapGet(self, op, *, arg, key, default): if default is None: return arg[key] else: return self.if_(self.f.map_contains_key(arg, key), arg[key], default) - @visit_node.register(ops.ArrayZip) def visit_ArrayZip(self, op, *, arg): return self.f.arrays_zip(*arg) - @visit_node.register(ops.ArrayMap) def visit_ArrayMap(self, op, *, arg, body, param): param = sge.Identifier(this=param) func = sge.Lambda(this=body, expressions=[param]) return self.f.transform(arg, func) - @visit_node.register(ops.ArrayFilter) def visit_ArrayFilter(self, op, *, arg, body, param): param = sge.Identifier(this=param) func = sge.Lambda(this=self.if_(body, param, NULL), expressions=[param]) @@ -363,19 +332,15 @@ def visit_ArrayFilter(self, op, *, arg, body, param): func = sge.Lambda(this=param.is_(sg.not_(NULL)), expressions=[param]) return self.f.filter(transform, func) - @visit_node.register(ops.ArrayIndex) def visit_ArrayIndex(self, op, *, arg, index): return self.f.element_at(arg, index + 1) - @visit_node.register(ops.ArrayPosition) def visit_ArrayPosition(self, op, *, arg, other): return self.f.array_position(arg, other) - @visit_node.register(ops.ArrayRepeat) def visit_ArrayRepeat(self, op, *, arg, times): return self.f.flatten(self.f.array_repeat(arg, times)) - @visit_node.register(ops.ArraySlice) def visit_ArraySlice(self, op, *, arg, start, stop): size = self.f.array_size(arg) start = self.if_(start < 0, self.if_(start < -size, 0, size + start), start) @@ -387,7 +352,6 @@ def visit_ArraySlice(self, op, *, arg, start, stop): length = 
self.if_(stop < start, 0, stop - start) return self.f.slice(arg, start + 1, length) - @visit_node.register(ops.ArrayContains) def visit_ArrayContains(self, op, *, arg, other): return self.if_( arg.is_(NULL), @@ -395,11 +359,9 @@ def visit_ArrayContains(self, op, *, arg, other): self.f.coalesce(self.f.array_contains(arg, other), FALSE), ) - @visit_node.register(ops.ArrayStringJoin) def visit_ArrayStringJoin(self, op, *, arg, sep): return self.f.concat_ws(sep, arg) - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): if end is not None: raise com.UnsupportedOperationError( @@ -413,11 +375,9 @@ def visit_StringFind(self, op, *, arg, substr, start, end): return self.f.instr(arg, substr) - @visit_node.register(ops.RegexReplace) def visit_RegexReplace(self, op, *, arg, pattern, replacement): return self.f.regexp_replace(arg, pattern, replacement) - @visit_node.register(ops.JSONGetItem) def visit_JSONGetItem(self, op, *, arg, index): if op.index.dtype.is_integer(): fmt = "$[%s]" @@ -426,7 +386,6 @@ def visit_JSONGetItem(self, op, *, arg, index): path = self.f.format_string(fmt, index) return self.f.get_json_object(arg, path) - @visit_node.register(Window) def visit_Window(self, op, *, func, group_by, order_by, **kwargs): if isinstance(op.func, ops.Analytic): # spark disallows specifying boundaries for lead/lag @@ -437,11 +396,10 @@ def visit_Window(self, op, *, func, group_by, order_by, **kwargs): order = sge.Order(expressions=[NULL]) return sge.Window(this=func, partition_by=group_by, order=order) else: - return super().visit_node( + return super().visit_Window( op, func=func, group_by=group_by, order_by=order_by, **kwargs ) - @visit_node.register(ops.JoinLink) def visit_JoinLink(self, op, **kwargs): if op.how == "asof": raise com.UnsupportedOperationError( @@ -452,12 +410,6 @@ def visit_JoinLink(self, op, **kwargs): ) return super().visit_JoinLink(op, **kwargs) - @visit_node.register(ops.RowID) - @visit_node.register(ops.TimestampBucket) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - - @visit_node.register(ops.HexDigest) def visit_HexDigest(self, op, *, arg, how): if how == "md5": return self.f.md5(arg) @@ -492,7 +444,6 @@ def visit_HexDigest(self, op, *, arg, how): for _op, _name in _SIMPLE_OPS.items(): assert isinstance(type(_op), type), type(_op) - @PySparkCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/risingwave/compiler.py b/ibis/backends/risingwave/compiler.py index 6aeb21ec9ddb..f0bd3e2fa40e 100644 --- a/ibis/backends/risingwave/compiler.py +++ b/ibis/backends/risingwave/compiler.py @@ -1,7 +1,5 @@ from __future__ import annotations -from functools import singledispatchmethod - import sqlglot.expressions as sge from public import public @@ -9,6 +7,7 @@ import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis.backends.base.sqlglot.compiler import ALL_OPERATIONS from ibis.backends.base.sqlglot.datatypes import RisingWaveType from ibis.backends.base.sqlglot.dialects import RisingWave from ibis.backends.postgres.compiler import PostgresCompiler @@ -21,11 +20,20 @@ class RisingwaveCompiler(PostgresCompiler): dialect = RisingWave type_mapper = RisingWaveType - @singledispatchmethod - def visit_node(self, op, **kwargs): - return super().visit_node(op, **kwargs) + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.DateFromYMD, + ops.Mode, + *( + op + for op 
in ALL_OPERATIONS + if issubclass( + op, (ops.GeoSpatialUnOp, ops.GeoSpatialBinOp, ops.GeoUnaryUnion) + ) + ), + ) + ) - @visit_node.register(ops.Correlation) def visit_Correlation(self, op, *, left, right, how, where): if how == "sample": raise com.UnsupportedOperationError( @@ -35,9 +43,6 @@ def visit_Correlation(self, op, *, left, right, how, where): op, left=left, right=right, how=how, where=where ) - @visit_node.register(ops.TimestampTruncate) - @visit_node.register(ops.DateTruncate) - @visit_node.register(ops.TimeTruncate) def visit_TimestampTruncate(self, op, *, arg, unit): unit_mapping = { "Y": "year", @@ -57,7 +62,8 @@ def visit_TimestampTruncate(self, op, *, arg, unit): return self.f.date_trunc(unit, arg) - @visit_node.register(ops.IntervalFromInteger) + visit_TimeTruncate = visit_DateTruncate = visit_TimestampTruncate + def visit_IntervalFromInteger(self, op, *, arg, unit): if op.arg.shape == ds.scalar: return sge.Interval(this=arg, unit=self.v[unit.name]) @@ -75,11 +81,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype): return sge.convert(str(value)) return None - @visit_node.register(ops.DateFromYMD) - @visit_node.register(ops.Mode) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - _SIMPLE_OPS = { ops.First: "first_value", @@ -90,13 +91,11 @@ def visit_Undefined(self, op, **_): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @RisingwaveCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @RisingwaveCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/snowflake/compiler.py b/ibis/backends/snowflake/compiler.py index 1604f6f919a0..15875529fe43 100644 --- a/ibis/backends/snowflake/compiler.py +++ b/ibis/backends/snowflake/compiler.py @@ -1,7 +1,7 @@ from __future__ import annotations import itertools -from functools import partial, singledispatchmethod +from functools import partial import sqlglot as sg import sqlglot.expressions as sge @@ -23,13 +23,6 @@ rewrite_first_to_first_value, rewrite_last_to_last_value, ) -from ibis.common.patterns import replace -from ibis.expr.analysis import p - - -@replace(p.ToJSONMap | p.ToJSONArray) -def replace_to_json(_): - return ops.Cast(_.arg, to=_.dtype) class SnowflakeFuncGen(FuncGen): @@ -44,7 +37,6 @@ class SnowflakeCompiler(SQLGlotCompiler): type_mapper = SnowflakeType no_limit_value = NULL rewrites = ( - replace_to_json, exclude_unsupported_window_frame_from_row_number, exclude_unsupported_window_frame_from_ops, rewrite_first_to_first_value, @@ -55,6 +47,18 @@ class SnowflakeCompiler(SQLGlotCompiler): *SQLGlotCompiler.rewrites, ) + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.ArrayMap, + ops.ArrayFilter, + ops.RowID, + ops.MultiQuantile, + ops.IntervalFromInteger, + ops.IntervalAdd, + ops.TimestampDiff, + ) + ) + def __init__(self): super().__init__() self.f = SnowflakeFuncGen() @@ -66,10 +70,6 @@ def _aggregate(self, funcname: str, *args, where): func = self.f[funcname] return func(*args) - @singledispatchmethod - def visit_node(self, op, **kw): - return super().visit_node(op, **kw) - @staticmethod def _minimize_spec(start, end, spec): if ( @@ -81,7 +81,6 @@ def _minimize_spec(start, end, spec): return None return spec - @visit_node.register(ops.Literal) def visit_Literal(self, op, *, value, dtype): if value is None: return super().visit_Literal(op, value=value, 
dtype=dtype) @@ -140,9 +139,8 @@ def visit_Literal(self, op, *, value, dtype): return sge.convert(str(value)) elif dtype.is_binary(): return sge.HexString(this=value.hex()) - return super().visit_node(op, value=value, dtype=dtype) + return super().visit_Literal(op, value=value, dtype=dtype) - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): if to.is_struct() or to.is_map(): return self.if_(self.f.is_object(arg), arg, NULL) @@ -150,30 +148,30 @@ def visit_Cast(self, op, *, arg, to): return self.if_(self.f.is_array(arg), arg, NULL) return self.cast(arg, to) - @visit_node.register(ops.IsNan) + def visit_ToJSONMap(self, op, *, arg): + return self.if_(self.f.is_object(arg), arg, NULL) + + def visit_ToJSONArray(self, op, *, arg): + return self.if_(self.f.is_array(arg), arg, NULL) + def visit_IsNan(self, op, *, arg): return arg.eq(self.NAN) - @visit_node.register(ops.IsInf) def visit_IsInf(self, op, *, arg): return arg.isin(self.POS_INF, self.NEG_INF) - @visit_node.register(ops.JSONGetItem) def visit_JSONGetItem(self, op, *, arg, index): return self.f.get(arg, index) - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): args = [substr, arg] if start is not None: args.append(start + 1) return self.f.position(*args) - @visit_node.register(ops.RegexSplit) def visit_RegexSplit(self, op, *, arg, pattern): return self.f.udf.regexp_split(arg, pattern) - @visit_node.register(ops.Map) def visit_Map(self, op, *, keys, values): return self.if_( sg.and_(self.f.is_array(keys), self.f.is_array(values)), @@ -181,15 +179,12 @@ def visit_Map(self, op, *, keys, values): NULL, ) - @visit_node.register(ops.MapKeys) def visit_MapKeys(self, op, *, arg): return self.if_(self.f.is_object(arg), self.f.object_keys(arg), NULL) - @visit_node.register(ops.MapValues) def visit_MapValues(self, op, *, arg): return self.if_(self.f.is_object(arg), self.f.udf.object_values(arg), NULL) - @visit_node.register(ops.MapGet) def visit_MapGet(self, op, *, arg, key, default): dtype = op.dtype expr = self.f.coalesce(self.f.get(arg, key), self.f.to_variant(default)) @@ -197,14 +192,12 @@ def visit_MapGet(self, op, *, arg, key, default): return expr return self.cast(expr, dtype) - @visit_node.register(ops.MapContains) def visit_MapContains(self, op, *, arg, key): return self.f.array_contains( self.if_(self.f.is_object(arg), self.f.object_keys(arg), NULL), self.f.to_variant(key), ) - @visit_node.register(ops.MapMerge) def visit_MapMerge(self, op, *, left, right): return self.if_( sg.and_(self.f.is_object(left), self.f.is_object(right)), @@ -212,40 +205,31 @@ def visit_MapMerge(self, op, *, left, right): NULL, ) - @visit_node.register(ops.MapLength) def visit_MapLength(self, op, *, arg): return self.if_( self.f.is_object(arg), self.f.array_size(self.f.object_keys(arg)), NULL ) - @visit_node.register(ops.Log) def visit_Log(self, op, *, arg, base): return self.f.log(base, arg, dialect=self.dialect) - @visit_node.register(ops.RandomScalar) def visit_RandomScalar(self, op): return self.f.uniform( self.f.to_double(0.0), self.f.to_double(1.0), self.f.random() ) - @visit_node.register(ops.ApproxMedian) def visit_ApproxMedian(self, op, *, arg, where): return self.agg.approx_percentile(arg, 0.5, where=where) - @visit_node.register(ops.TimeDelta) def visit_TimeDelta(self, op, *, part, left, right): return self.f.timediff(part, right, left, dialect=self.dialect) - @visit_node.register(ops.DateDelta) def visit_DateDelta(self, op, *, part, left, right): return self.f.datediff(part, right, 
left, dialect=self.dialect) - @visit_node.register(ops.TimestampDelta) def visit_TimestampDelta(self, op, *, part, left, right): return self.f.timestampdiff(part, right, left, dialect=self.dialect) - @visit_node.register(ops.TimestampAdd) - @visit_node.register(ops.DateAdd) def visit_TimestampDateAdd(self, op, *, left, right): if not isinstance(op.right, ops.Literal): raise com.OperationNotDefinedError( @@ -253,23 +237,21 @@ def visit_TimestampDateAdd(self, op, *, left, right): ) return sg.exp.Add(this=left, expression=right) - @visit_node.register(ops.IntegerRange) + visit_DateAdd = visit_TimestampAdd = visit_TimestampDateAdd + def visit_IntegerRange(self, op, *, start, stop, step): return self.if_( step.neq(0), self.f.array_generate_range(start, stop, step), self.f.array() ) - @visit_node.register(ops.StructColumn) def visit_StructColumn(self, op, *, names, values): return self.f.object_construct_keep_null( *itertools.chain.from_iterable(zip(names, values)) ) - @visit_node.register(ops.StructField) def visit_StructField(self, op, *, arg, field): return self.cast(self.f.get(arg, field), op.dtype) - @visit_node.register(ops.RegexSearch) def visit_RegexSearch(self, op, *, arg, pattern): return sge.RegexpLike( this=arg, @@ -277,38 +259,30 @@ def visit_RegexSearch(self, op, *, arg, pattern): flag=sge.convert("cs"), ) - @visit_node.register(ops.RegexReplace) def visit_RegexReplace(self, op, *, arg, pattern, replacement): return sge.RegexpReplace(this=arg, expression=pattern, replacement=replacement) - @visit_node.register(ops.TypeOf) def visit_TypeOf(self, op, *, arg): return self.f.typeof(self.f.to_variant(arg)) - @visit_node.register(ops.ArrayRepeat) def visit_ArrayRepeat(self, op, *, arg, times): return self.f.udf.array_repeat(arg, times) - @visit_node.register(ops.ArrayUnion) def visit_ArrayUnion(self, op, *, left, right): return self.f.array_distinct(self.f.array_cat(left, right)) - @visit_node.register(ops.ArrayContains) def visit_ArrayContains(self, op, *, arg, other): return self.f.array_contains(arg, self.f.to_variant(other)) - @visit_node.register(ops.ArrayCollect) def visit_ArrayCollect(self, op, *, arg, where): return self.agg.array_agg( self.f.ifnull(arg, self.f.parse_json("null")), where=where ) - @visit_node.register(ops.ArrayConcat) def visit_ArrayConcat(self, op, *, arg): # array_cat only accepts two arguments return self.f.array_flatten(self.f.array(*arg)) - @visit_node.register(ops.ArrayPosition) def visit_ArrayPosition(self, op, *, arg, other): # snowflake is zero-based here, so we don't need to subtract 1 from the # result @@ -316,7 +290,6 @@ def visit_ArrayPosition(self, op, *, arg, other): self.f.array_position(self.f.to_variant(other), arg) + 1, 0 ) - @visit_node.register(ops.RegexExtract) def visit_RegexExtract(self, op, *, arg, pattern, index): # https://docs.snowflake.com/en/sql-reference/functions/regexp_substr return sge.RegexpExtract( @@ -327,11 +300,9 @@ def visit_RegexExtract(self, op, *, arg, pattern, index): parameters=sge.convert("ce"), ) - @visit_node.register(ops.ArrayZip) def visit_ArrayZip(self, op, *, arg): return self.f.udf.array_zip(self.f.array(*arg)) - @visit_node.register(ops.DayOfWeekName) def visit_DayOfWeekName(self, op, *, arg): return sge.Case( this=self.f.dayname(arg), @@ -347,21 +318,17 @@ def visit_DayOfWeekName(self, op, *, arg): default=NULL, ) - @visit_node.register(ops.TimestampFromUNIX) def visit_TimestampFromUNIX(self, op, *, arg, unit): timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9} return self.f.to_timestamp(arg, 
timestamp_units_to_scale[unit.short]) - @visit_node.register(ops.First) def visit_First(self, op, *, arg, where): return self.f.get(self.agg.array_agg(arg, where=where), 0) - @visit_node.register(ops.Last) def visit_Last(self, op, *, arg, where): expr = self.agg.array_agg(arg, where=where) return self.f.get(expr, self.f.array_size(expr) - 1) - @visit_node.register(ops.GroupConcat) def visit_GroupConcat(self, op, *, arg, where, sep): if where is None: return self.f.listagg(arg, sep) @@ -372,7 +339,6 @@ def visit_GroupConcat(self, op, *, arg, where, sep): NULL, ) - @visit_node.register(ops.TimestampBucket) def visit_TimestampBucket(self, op, *, arg, interval, offset): if offset is not None: raise com.UnsupportedOperationError( @@ -387,7 +353,6 @@ def visit_TimestampBucket(self, op, *, arg, interval, offset): return self.f.time_slice(arg, interval.value, interval.dtype.unit.name) - @visit_node.register(ops.Arbitrary) def visit_Arbitrary(self, op, *, arg, how, where): if how == "first": return self.f.get(self.agg.array_agg(arg, where=where), 0) @@ -397,7 +362,6 @@ def visit_Arbitrary(self, op, *, arg, how, where): else: raise com.UnsupportedOperationError("how must be 'first' or 'last'") - @visit_node.register(ops.ArraySlice) def visit_ArraySlice(self, op, *, arg, start, stop): if start is None: start = 0 @@ -406,19 +370,15 @@ def visit_ArraySlice(self, op, *, arg, start, stop): stop = self.f.array_size(arg) return self.f.array_slice(arg, start, stop) - @visit_node.register(ops.ExtractEpochSeconds) - def visit_ExtractExtractEpochSeconds(self, op, *, arg): + def visit_ExtractEpochSeconds(self, op, *, arg): return self.f.extract("epoch", arg) - @visit_node.register(ops.ExtractMicrosecond) def visit_ExtractMicrosecond(self, op, *, arg): return self.f.extract("epoch_microsecond", arg) % 1_000_000 - @visit_node.register(ops.ExtractMillisecond) def visit_ExtractMillisecond(self, op, *, arg): return self.f.extract("epoch_millisecond", arg) % 1_000 - @visit_node.register(ops.ExtractQuery) def visit_ExtractQuery(self, op, *, arg, key): parsed_url = self.f.parse_url(arg, 1) if key is not None: @@ -427,13 +387,11 @@ def visit_ExtractQuery(self, op, *, arg, key): r = self.f.get(parsed_url, "query") return self.f.nullif(self.f.as_varchar(r), "") - @visit_node.register(ops.ExtractProtocol) def visit_ExtractProtocol(self, op, *, arg): return self.f.nullif( self.f.as_varchar(self.f.get(self.f.parse_url(arg, 1), "scheme")), "" ) - @visit_node.register(ops.ExtractAuthority) def visit_ExtractAuthority(self, op, *, arg): return self.f.concat_ws( ":", @@ -441,7 +399,6 @@ def visit_ExtractAuthority(self, op, *, arg): self.f.as_varchar(self.f.get(self.f.parse_url(arg, 1), "port")), ) - @visit_node.register(ops.ExtractFile) def visit_ExtractFile(self, op, *, arg): return self.f.concat_ws( "?", @@ -449,19 +406,16 @@ def visit_ExtractFile(self, op, *, arg): self.visit_ExtractQuery(op, arg=arg, key=None), ) - @visit_node.register(ops.ExtractPath) def visit_ExtractPath(self, op, *, arg): return self.f.concat( "/", self.f.as_varchar(self.f.get(self.f.parse_url(arg, 1), "path")) ) - @visit_node.register(ops.ExtractFragment) def visit_ExtractFragment(self, op, *, arg): return self.f.nullif( self.f.as_varchar(self.f.get(self.f.parse_url(arg, 1), "fragment")), "" ) - @visit_node.register(ops.Unnest) def visit_Unnest(self, op, *, arg): sep = sge.convert(util.guid()) split = self.f.split( @@ -470,7 +424,6 @@ def visit_Unnest(self, op, *, arg): expr = self.f.nullif(self.f.explode(split), "") return self.cast(expr, op.dtype) - 
@visit_node.register(ops.Quantile) def visit_Quantile(self, op, *, arg, quantile, where): # can't use `self.agg` here because `quantile` must be a constant and # the agg method filters using `where` for every argument which turns @@ -479,19 +432,16 @@ def visit_Quantile(self, op, *, arg, quantile, where): arg = self.if_(where, arg, NULL) return self.f.percentile_cont(arg, quantile) - @visit_node.register(ops.CountStar) def visit_CountStar(self, op, *, arg, where): if where is None: - return super().visit_node(op, arg=arg, where=where) + return super().visit_CountStar(op, arg=arg, where=where) return self.f.count_if(where) - @visit_node.register(ops.CountDistinct) def visit_CountDistinct(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, NULL) return self.f.count(sge.Distinct(expressions=[arg])) - @visit_node.register(ops.CountDistinctStar) def visit_CountDistinctStar(self, op, *, arg, where): columns = op.arg.schema.names quoted = self.quoted @@ -503,12 +453,10 @@ def visit_CountDistinctStar(self, op, *, arg, where): expressions = [self.if_(where, col(name), NULL) for name in columns] return self.f.count(sge.Distinct(expressions=expressions)) - @visit_node.register(ops.Xor) def visit_Xor(self, op, *, left, right): # boolxor accepts numerics ... and returns a boolean? wtf? return self.f.boolxor(self.cast(left, dt.int8), self.cast(right, dt.int8)) - @visit_node.register(ops.WindowBoundary) def visit_WindowBoundary(self, op, *, value, preceding): if not isinstance(op.value, ops.Literal): raise com.OperationNotDefinedError( @@ -516,7 +464,6 @@ def visit_WindowBoundary(self, op, *, value, preceding): ) return super().visit_WindowBoundary(op, value=value, preceding=preceding) - @visit_node.register(ops.Correlation) def visit_Correlation(self, op, *, left, right, how, where): if how == "sample": raise com.UnsupportedOperationError( @@ -532,7 +479,6 @@ def visit_Correlation(self, op, *, left, right, how, where): return self.agg.corr(left, right, where=where) - @visit_node.register(ops.TimestampRange) def visit_TimestampRange(self, op, *, start, stop, step): raw_step = op.step @@ -587,17 +533,6 @@ def visit_TimestampRange(self, op, *, start, stop, step): .subquery() ) - @visit_node.register(ops.ArrayMap) - @visit_node.register(ops.ArrayFilter) - @visit_node.register(ops.RowID) - @visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.IntervalFromInteger) - @visit_node.register(ops.IntervalAdd) - @visit_node.register(ops.TimestampDiff) - @visit_node.register(ops.TryCast) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - _SIMPLE_OPS = { ops.Any: "max", @@ -629,13 +564,11 @@ def visit_Undefined(self, op, **_): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @SnowflakeCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @SnowflakeCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/sqlite/compiler.py b/ibis/backends/sqlite/compiler.py index 52b76a3b9ca4..0dcd701906f0 100644 --- a/ibis/backends/sqlite/compiler.py +++ b/ibis/backends/sqlite/compiler.py @@ -1,7 +1,5 @@ from __future__ import annotations -from functools import singledispatchmethod - import sqlglot as sg import sqlglot.expressions as sge from public import public @@ -37,56 +35,53 @@ class SQLiteCompiler(SQLGlotCompiler): POS_INF = 
sge.Literal.number("1e999") NEG_INF = sge.Literal.number("-1e999") + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.Levenshtein, + ops.RegexSplit, + ops.StringSplit, + ops.IsNan, + ops.IsInf, + ops.Covariance, + ops.Correlation, + ops.Quantile, + ops.MultiQuantile, + ops.Median, + ops.ApproxMedian, + ops.Array, + ops.ArrayConcat, + ops.ArrayStringJoin, + ops.ArrayCollect, + ops.ArrayContains, + ops.ArrayFlatten, + ops.ArrayLength, + ops.ArraySort, + ops.ArrayStringJoin, + ops.CountDistinctStar, + ops.IntervalBinary, + ops.IntervalAdd, + ops.IntervalSubtract, + ops.IntervalMultiply, + ops.IntervalFloorDivide, + ops.IntervalFromInteger, + ops.TimestampBucket, + ops.TimestampAdd, + ops.TimestampSub, + ops.TimestampDiff, + ops.StringToTimestamp, + ops.TimeDelta, + ops.DateDelta, + ops.TimestampDelta, + ops.TryCast, + ) + ) + def _aggregate(self, funcname: str, *args, where): expr = self.f[funcname](*args) if where is not None: return sge.Filter(this=expr, expression=sge.Where(this=where)) return expr - @singledispatchmethod - def visit_node(self, op, **kw): - return super().visit_node(op, **kw) - - @visit_node.register(ops.Levenshtein) - @visit_node.register(ops.RegexSplit) - @visit_node.register(ops.StringSplit) - @visit_node.register(ops.IsNan) - @visit_node.register(ops.IsInf) - @visit_node.register(ops.Covariance) - @visit_node.register(ops.Correlation) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.Median) - @visit_node.register(ops.ApproxMedian) - @visit_node.register(ops.Array) - @visit_node.register(ops.ArrayConcat) - @visit_node.register(ops.ArrayStringJoin) - @visit_node.register(ops.ArrayCollect) - @visit_node.register(ops.ArrayContains) - @visit_node.register(ops.ArrayFlatten) - @visit_node.register(ops.ArrayLength) - @visit_node.register(ops.ArraySort) - @visit_node.register(ops.ArrayStringJoin) - @visit_node.register(ops.CountDistinctStar) - @visit_node.register(ops.IntervalBinary) - @visit_node.register(ops.IntervalAdd) - @visit_node.register(ops.IntervalSubtract) - @visit_node.register(ops.IntervalMultiply) - @visit_node.register(ops.IntervalFloorDivide) - @visit_node.register(ops.IntervalFromInteger) - @visit_node.register(ops.TimestampBucket) - @visit_node.register(ops.TimestampAdd) - @visit_node.register(ops.TimestampSub) - @visit_node.register(ops.TimestampDiff) - @visit_node.register(ops.StringToTimestamp) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.DateDelta) - @visit_node.register(ops.TimestampDelta) - @visit_node.register(ops.TryCast) - def visit_Undefined(self, op, **kwargs): - return super().visit_Undefined(op, **kwargs) - - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to) -> sge.Cast: if to.is_timestamp(): if to.timezone not in (None, "UTC"): @@ -103,7 +98,6 @@ def visit_Cast(self, op, *, arg, to) -> sge.Cast: return self.f.time(arg) return super().visit_Cast(op, arg=arg, to=to) - @visit_node.register(ops.Limit) def visit_Limit(self, op, *, parent, n, offset): # SQLite doesn't support compiling an OFFSET without a LIMIT, but # treats LIMIT -1 as no limit @@ -111,7 +105,6 @@ def visit_Limit(self, op, *, parent, n, offset): op, parent=parent, n=(-1 if n is None else n), offset=offset ) - @visit_node.register(ops.WindowBoundary) def visit_WindowBoundary(self, op, *, value, preceding): if op.value.dtype.is_interval(): raise com.OperationNotDefinedError( @@ -119,7 +112,6 @@ def visit_WindowBoundary(self, op, *, value, preceding): ) return super().visit_WindowBoundary(op, 
value=value, preceding=preceding) - @visit_node.register(ops.JoinLink) def visit_JoinLink(self, op, **kwargs): if op.how == "asof": raise com.UnsupportedOperationError( @@ -127,19 +119,15 @@ def visit_JoinLink(self, op, **kwargs): ) return super().visit_JoinLink(op, **kwargs) - @visit_node.register(ops.StartsWith) def visit_StartsWith(self, op, *, arg, start): return arg.like(self.f.concat(start, "%")) - @visit_node.register(ops.EndsWith) def visit_EndsWith(self, op, *, arg, end): return arg.like(self.f.concat("%", end)) - @visit_node.register(ops.StrRight) def visit_StrRight(self, op, *, arg, nchars): return self.f.substr(arg, -nchars, nchars) - @visit_node.register(ops.StringFind) def visit_StringFind(self, op, *, arg, substr, start, end): if op.end is not None: raise NotImplementedError("`end` not yet implemented") @@ -151,36 +139,29 @@ def visit_StringFind(self, op, *, arg, substr, start, end): return self.f.instr(arg, substr) - @visit_node.register(ops.StringJoin) def visit_StringJoin(self, op, *, arg, sep): args = [arg[0]] for item in arg[1:]: args.extend([sep, item]) return self.f.concat(*args) - @visit_node.register(ops.StringContains) - def visit_Contains(self, op, *, haystack, needle): + def visit_StringContains(self, op, *, haystack, needle): return self.f.instr(haystack, needle) >= 1 - @visit_node.register(ops.ExtractQuery) def visit_ExtractQuery(self, op, *, arg, key): if op.key is None: return self.f._ibis_extract_full_query(arg) return self.f._ibis_extract_query(arg, key) - @visit_node.register(ops.Greatest) def visit_Greatest(self, op, *, arg): return self.f.max(*arg) - @visit_node.register(ops.Least) def visit_Least(self, op, *, arg): return self.f.min(*arg) - @visit_node.register(ops.IdenticalTo) def visit_IdenticalTo(self, op, *, left, right): return sge.Is(this=left, expression=right) - @visit_node.register(ops.Clip) def visit_Clip(self, op, *, arg, lower, upper): if upper is not None: arg = self.if_(arg.is_(NULL), arg, self.f.min(upper, arg)) @@ -190,15 +171,12 @@ def visit_Clip(self, op, *, arg, lower, upper): return arg - @visit_node.register(ops.RandomScalar) def visit_RandomScalar(self, op): return 0.5 + self.f.random() / sge.Literal.number(float(-1 << 64)) - @visit_node.register(ops.Cot) def visit_Cot(self, op, *, arg): return 1 / self.f.tan(arg) - @visit_node.register(ops.Arbitrary) def visit_Arbitrary(self, op, *, arg, how, where): if op.how == "heavy": raise com.OperationNotDefinedError( @@ -207,11 +185,9 @@ def visit_Arbitrary(self, op, *, arg, how, where): return self._aggregate(f"_ibis_arbitrary_{how}", arg, where=where) - @visit_node.register(ops.ArgMin) def visit_ArgMin(self, *args, **kwargs): return self._visit_arg_reduction("min", *args, **kwargs) - @visit_node.register(ops.ArgMax) def visit_ArgMax(self, *args, **kwargs): return self._visit_arg_reduction("max", *args, **kwargs) @@ -224,36 +200,28 @@ def _visit_arg_reduction(self, func, op, *, arg, key, where): agg = self._aggregate(func, key, where=cond) return self.f.anon.json_extract(self.f.json_array(arg, agg), "$[0]") - @visit_node.register(ops.Variance) def visit_Variance(self, op, *, arg, how, where): return self._aggregate(f"_ibis_var_{op.how}", arg, where=where) - @visit_node.register(ops.StandardDev) def visit_StandardDev(self, op, *, arg, how, where): var = self._aggregate(f"_ibis_var_{op.how}", arg, where=where) return self.f.sqrt(var) - @visit_node.register(ops.ApproxCountDistinct) def visit_ApproxCountDistinct(self, op, *, arg, where): return self.agg.count(sge.Distinct(expressions=[arg]), 
where=where) - @visit_node.register(ops.CountDistinct) def visit_CountDistinct(self, op, *, arg, where): return self.agg.count(sge.Distinct(expressions=[arg]), where=where) - @visit_node.register(ops.Strftime) def visit_Strftime(self, op, *, arg, format_str): return self.f.strftime(format_str, arg) - @visit_node.register(ops.DateFromYMD) def visit_DateFromYMD(self, op, *, year, month, day): return self.f.date(self.f.printf("%04d-%02d-%02d", year, month, day)) - @visit_node.register(ops.TimeFromHMS) def visit_TimeFromHMS(self, op, *, hours, minutes, seconds): return self.f.time(self.f.printf("%02d:%02d:%02d", hours, minutes, seconds)) - @visit_node.register(ops.TimestampFromYMDHMS) def visit_TimestampFromYMDHMS( self, op, *, year, month, day, hours, minutes, seconds ): @@ -286,16 +254,12 @@ def _temporal_truncate(self, func, arg, unit): raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}") return func(arg, *params) - @visit_node.register(ops.DateTruncate) def visit_DateTruncate(self, op, *, arg, unit): return self._temporal_truncate(self.f.date, arg, unit) - @visit_node.register(ops.TimestampTruncate) def visit_TimestampTruncate(self, op, *, arg, unit): return self._temporal_truncate(self.f.datetime, arg, unit) - @visit_node.register(ops.DateAdd) - @visit_node.register(ops.DateSub) def visit_DateArithmetic(self, op, *, left, right): unit = op.right.dtype.unit sign = "+" if isinstance(op, ops.DateAdd) else "-" @@ -308,53 +272,43 @@ def visit_DateArithmetic(self, op, *, left, right): else: return self.f.date(left, self.f.concat(sign, right, f" {unit.plural}")) - @visit_node.register(ops.DateDiff) + visit_DateAdd = visit_DateSub = visit_DateArithmetic + def visit_DateDiff(self, op, *, left, right): return self.f.julianday(left) - self.f.julianday(right) - @visit_node.register(ops.ExtractYear) def visit_ExtractYear(self, op, *, arg): return self.cast(self.f.strftime("%Y", arg), dt.int64) - @visit_node.register(ops.ExtractQuarter) def visit_ExtractQuarter(self, op, *, arg): return (self.f.strftime("%m", arg) + 2) / 3 - @visit_node.register(ops.ExtractMonth) def visit_ExtractMonth(self, op, *, arg): return self.cast(self.f.strftime("%m", arg), dt.int64) - @visit_node.register(ops.ExtractDay) def visit_ExtractDay(self, op, *, arg): return self.cast(self.f.strftime("%d", arg), dt.int64) - @visit_node.register(ops.ExtractDayOfYear) def visit_ExtractDayOfYear(self, op, *, arg): return self.cast(self.f.strftime("%j", arg), dt.int64) - @visit_node.register(ops.ExtractHour) def visit_ExtractHour(self, op, *, arg): return self.cast(self.f.strftime("%H", arg), dt.int64) - @visit_node.register(ops.ExtractMinute) def visit_ExtractMinute(self, op, *, arg): return self.cast(self.f.strftime("%M", arg), dt.int64) - @visit_node.register(ops.ExtractSecond) def visit_ExtractSecond(self, op, *, arg): return self.cast(self.f.strftime("%S", arg), dt.int64) - @visit_node.register(ops.ExtractMillisecond) - def visit_Millisecond(self, op, *, arg): + def visit_ExtractMillisecond(self, op, *, arg): return self.cast(self.f.mod(self.f.strftime("%f", arg) * 1000, 1000), dt.int64) - @visit_node.register(ops.ExtractMicrosecond) - def visit_Microsecond(self, op, *, arg): + def visit_ExtractMicrosecond(self, op, *, arg): return self.cast( self.f.mod(self.cast(self.f.strftime("%f", arg), dt.int64), 1000), dt.int64 ) - @visit_node.register(ops.ExtractWeekOfYear) def visit_ExtractWeekOfYear(self, op, *, arg): """ISO week of year. 
@@ -400,17 +354,14 @@ def visit_ExtractWeekOfYear(self, op, *, arg): date = self.f.date(arg, "-3 days", "weekday 4") return (self.f.strftime("%j", date) - 1) / 7 + 1 - @visit_node.register(ops.ExtractEpochSeconds) def visit_ExtractEpochSeconds(self, op, *, arg): return self.cast((self.f.julianday(arg) - 2440587.5) * 86400.0, dt.int64) - @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): return self.cast( self.f.mod(self.cast(self.f.strftime("%w", arg) + 6, dt.int64), 7), dt.int64 ) - @visit_node.register(ops.DayOfWeekName) def visit_DayOfWeekName(self, op, *, arg): return sge.Case( this=self.f.strftime("%w", arg), @@ -425,7 +376,6 @@ def visit_DayOfWeekName(self, op, *, arg): ], ) - @visit_node.register(ops.Xor) def visit_Xor(self, op, *, left, right): return (left.or_(right)).and_(sg.not_(left.and_(right))) @@ -502,13 +452,11 @@ def visit_NonNullLiteral(self, op, *, value, dtype): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @SQLiteCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @SQLiteCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values()) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index ca876e4a3040..2c6d6d5838c8 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -1321,15 +1321,7 @@ def hash_256(col): backend.assert_series_equal(h1, h2) -@pytest.mark.notimpl( - [ - "pandas", - "dask", - "oracle", - "snowflake", - "sqlite", - ] -) +@pytest.mark.notimpl(["pandas", "dask", "oracle", "sqlite"]) @pytest.mark.parametrize( ("from_val", "to_type", "expected"), [ @@ -1344,6 +1336,7 @@ def hash_256(col): marks=[ pytest.mark.notyet(["duckdb", "impala"], reason="casts to NULL"), pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest), + pytest.mark.notyet(["snowflake"], raises=SnowflakeProgrammingError), pytest.mark.notyet(["trino"], raises=TrinoUserError), pytest.mark.notyet(["exasol"], raises=ExaQueryError), pytest.mark.broken( @@ -1379,7 +1372,6 @@ def test_try_cast(con, from_val, to_type, expected): "pandas", "postgres", "risingwave", - "snowflake", "sqlite", ] ) @@ -1395,6 +1387,7 @@ def test_try_cast(con, from_val, to_type, expected): ["clickhouse", "pyspark", "flink"], reason="casts to 1672531200" ), pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest), + pytest.mark.notyet(["snowflake"], raises=SnowflakeProgrammingError), pytest.mark.notyet(["trino"], raises=TrinoUserError), pytest.mark.notyet(["mssql"], raises=PyODBCDataError), pytest.mark.broken(["polars"], reason="casts to 1672531200000000000"), @@ -1443,7 +1436,6 @@ def test_try_cast_table(backend, con): "oracle", "postgres", "risingwave", - "snowflake", "sqlite", "exasol", ] @@ -1463,6 +1455,7 @@ def test_try_cast_table(backend, con): reason="casts this to to a number", ), pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest), + pytest.mark.notyet(["snowflake"], raises=SnowflakeProgrammingError), pytest.mark.notyet(["trino"], raises=TrinoUserError), pytest.mark.notyet(["mssql"], raises=PyODBCDataError), ], diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index a62d59734dba..30151b01b668 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -409,7 +409,6 @@ def uses_java_re(t): id="find_start", marks=[ pytest.mark.notimpl(["polars"], 
raises=com.OperationNotDefinedError), - pytest.mark.notyet(["bigquery"], raises=NotImplementedError), ], ), param( diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py index 80b93e348b04..1fb6207ad007 100644 --- a/ibis/backends/trino/compiler.py +++ b/ibis/backends/trino/compiler.py @@ -1,7 +1,7 @@ from __future__ import annotations import math -from functools import partial, reduce, singledispatchmethod +from functools import partial, reduce import sqlglot as sg import sqlglot.expressions as sge @@ -43,6 +43,16 @@ class TrinoCompiler(SQLGlotCompiler): POS_INF = sg.func("infinity") NEG_INF = -POS_INF + UNSUPPORTED_OPERATIONS = frozenset( + ( + ops.Quantile, + ops.MultiQuantile, + ops.Median, + ops.RowID, + ops.TimestampBucket, + ) + ) + def _aggregate(self, funcname: str, *args, where): expr = self.f[funcname](*args) if where is not None: @@ -60,11 +70,6 @@ def _minimize_spec(start, end, spec): return None return spec - @singledispatchmethod - def visit_node(self, op, **kw): - return super().visit_node(op, **kw) - - @visit_node.register(ops.Sample) def visit_Sample( self, op, *, parent, fraction: float, method: str, seed: int | None, **_ ): @@ -80,7 +85,6 @@ def visit_Sample( ) return sg.select(STAR).from_(sample) - @visit_node.register(ops.Correlation) def visit_Correlation(self, op, *, left, right, how, where): if how == "sample": raise com.UnsupportedOperationError( @@ -94,7 +98,6 @@ def visit_Correlation(self, op, *, left, right, how, where): return self.agg.corr(left, right, where=where) - @visit_node.register(ops.Arbitrary) def visit_Arbitrary(self, op, *, arg, how, where): if how != "first": raise com.UnsupportedOperationError( @@ -102,7 +105,6 @@ def visit_Arbitrary(self, op, *, arg, how, where): ) return self.agg.arbitrary(arg, where=where) - @visit_node.register(ops.BitXor) def visit_BitXor(self, op, *, arg, where): a, b = map(sg.to_identifier, "ab") input_fn = combine_fn = sge.Lambda( @@ -110,11 +112,9 @@ def visit_BitXor(self, op, *, arg, where): ) return self.agg.reduce_agg(arg, 0, input_fn, combine_fn, where=where) - @visit_node.register(ops.ArrayRepeat) def visit_ArrayRepeat(self, op, *, arg, times): return self.f.flatten(self.f.repeat(arg, times)) - @visit_node.register(ops.ArraySlice) def visit_ArraySlice(self, op, *, arg, start, stop): def _neg_idx_to_pos(n, idx): return self.if_(idx < 0, n + self.f.greatest(idx, -n), idx) @@ -133,15 +133,12 @@ def _neg_idx_to_pos(n, idx): return self.f.slice(arg, start + 1, stop - start) - @visit_node.register(ops.ArrayMap) def visit_ArrayMap(self, op, *, arg, param, body): return self.f.transform(arg, sge.Lambda(this=body, expressions=[param])) - @visit_node.register(ops.ArrayFilter) def visit_ArrayFilter(self, op, *, arg, param, body): return self.f.filter(arg, sge.Lambda(this=body, expressions=[param])) - @visit_node.register(ops.ArrayContains) def visit_ArrayContains(self, op, *, arg, other): return self.if_( arg.is_(sg.not_(NULL)), @@ -149,33 +146,25 @@ def visit_ArrayContains(self, op, *, arg, other): NULL, ) - @visit_node.register(ops.JSONGetItem) def visit_JSONGetItem(self, op, *, arg, index): fmt = "%d" if op.index.dtype.is_integer() else '"%s"' return self.f.json_extract(arg, self.f.format(f"$[{fmt}]", index)) - @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): return self.cast(paren(self.f.day_of_week(arg) + 6) % 7, op.dtype) - @visit_node.register(ops.DayOfWeekName) def visit_DayOfWeekName(self, op, *, arg): return self.f.date_format(arg, "%W") - 
@visit_node.register(ops.StrRight) def visit_StrRight(self, op, *, arg, nchars): return self.f.substr(arg, -nchars) - @visit_node.register(ops.EndsWith) def visit_EndsWith(self, op, *, arg, end): return self.f.substr(arg, -self.f.length(end)).eq(end) - @visit_node.register(ops.Repeat) def visit_Repeat(self, op, *, arg, times): return self.f.array_join(self.f.repeat(arg, times), "") - @visit_node.register(ops.DateTruncate) - @visit_node.register(ops.TimestampTruncate) def visit_DateTimestampTruncate(self, op, *, arg, unit): _truncate_precisions = { # ms unit is not yet officially documented but it works @@ -196,19 +185,18 @@ def visit_DateTimestampTruncate(self, op, *, arg, unit): ) return self.f.date_trunc(precision, arg) - @visit_node.register(ops.DateFromYMD) + visit_DateTruncate = visit_TimestampTruncate = visit_DateTimestampTruncate + def visit_DateFromYMD(self, op, *, year, month, day): return self.f.from_iso8601_date( self.f.format("%04d-%02d-%02d", year, month, day) ) - @visit_node.register(ops.TimeFromHMS) def visit_TimeFromHMS(self, op, *, hours, minutes, seconds): return self.cast( self.f.format("%02d:%02d:%02d", hours, minutes, seconds), dt.time ) - @visit_node.register(ops.TimestampFromYMDHMS) def visit_TimestampFromYMDHMS( self, op, *, year, month, day, hours, minutes, seconds ): @@ -227,7 +215,6 @@ def visit_TimestampFromYMDHMS( dt.timestamp, ) - @visit_node.register(ops.TimestampFromUNIX) def visit_TimestampFromUNIX(self, op, *, arg, unit): short = unit.short if short == "ms": @@ -242,7 +229,6 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): raise com.UnsupportedOperationError(f"{unit!r} unit is not supported") return self.cast(res, op.dtype) - @visit_node.register(ops.StructColumn) def visit_StructColumn(self, op, *, names, values): return self.cast(sge.Struct(expressions=list(values)), op.dtype) @@ -272,37 +258,30 @@ def visit_NonNullLiteral(self, op, *, value, dtype): else: return None - @visit_node.register(ops.Log) def visit_Log(self, op, *, arg, base): return self.f.log(base, arg, dialect=self.dialect) - @visit_node.register(ops.MapGet) def visit_MapGet(self, op, *, arg, key, default): return self.f.coalesce(self.f.element_at(arg, key), default) - @visit_node.register(ops.MapContains) def visit_MapContains(self, op, *, arg, key): return self.f.contains(self.f.map_keys(arg), key) - @visit_node.register(ops.ExtractFile) - def visit_ExtractProtocol(self, op, *, arg): + def visit_ExtractFile(self, op, *, arg): return self.f.concat_ws( "?", self.f.nullif(self.f.url_extract_path(arg), ""), self.f.nullif(self.f.url_extract_query(arg), ""), ) - @visit_node.register(ops.ExtractQuery) def visit_ExtractQuery(self, op, *, arg, key): if key is None: return self.f.url_extract_query(arg) return self.f.url_extract_parameter(arg, key) - @visit_node.register(ops.Cot) def visit_Cot(self, op, *, arg): return 1.0 / self.f.tan(arg) - @visit_node.register(ops.StringAscii) def visit_StringAscii(self, op, *, arg): return self.f.codepoint( sge.Cast( @@ -314,19 +293,15 @@ def visit_StringAscii(self, op, *, arg): ) ) - @visit_node.register(ops.ArrayStringJoin) def visit_ArrayStringJoin(self, op, *, sep, arg): return self.f.array_join(arg, sep) - @visit_node.register(ops.First) def visit_First(self, op, *, arg, where): return self.f.element_at(self.agg.array_agg(arg, where=where), 1) - @visit_node.register(ops.Last) def visit_Last(self, op, *, arg, where): return self.f.element_at(self.agg.array_agg(arg, where=where), -1) - @visit_node.register(ops.ArrayZip) def visit_ArrayZip(self, op, *, 
arg): max_zip_arguments = 5 chunks = ( @@ -354,15 +329,11 @@ def combine_zipped(left, right): assert all_n == len(op.dtype.value_type) return chunk - @visit_node.register(ops.ExtractMicrosecond) def visit_ExtractMicrosecond(self, op, *, arg): # trino only seems to store milliseconds, but the result of formatting # always pads the right with 000 return self.cast(self.f.date_format(arg, "%f"), dt.int32) - @visit_node.register(ops.TimeDelta) - @visit_node.register(ops.DateDelta) - @visit_node.register(ops.TimestampDelta) def visit_TemporalDelta(self, op, *, part, left, right): # trino truncates _after_ the delta, whereas many other backends # truncate each operand @@ -374,7 +345,8 @@ def visit_TemporalDelta(self, op, *, part, left, right): dialect=dialect, ) - @visit_node.register(ops.IntervalFromInteger) + visit_TimeDelta = visit_DateDelta = visit_TimestampDelta = visit_TemporalDelta + def visit_IntervalFromInteger(self, op, *, arg, unit): unit = op.unit.short if unit in ("Y", "Q", "M", "W"): @@ -385,8 +357,6 @@ def visit_IntervalFromInteger(self, op, *, arg, unit): ) ) - @visit_node.register(ops.TimestampRange) - @visit_node.register(ops.IntegerRange) def visit_Range(self, op, *, start, stop, step): def zero_value(dtype): if dtype.is_interval(): @@ -420,11 +390,11 @@ def _sign(value, dtype): self.f.array(), ) - @visit_node.register(ops.ArrayIndex) + visit_IntegerRange = visit_TimestampRange = visit_Range + def visit_ArrayIndex(self, op, *, arg, index): return self.f.element_at(arg, index + 1) - @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): from_ = op.arg.dtype if from_.is_integer() and to.is_interval(): @@ -437,34 +407,22 @@ def visit_Cast(self, op, *, arg, to): return self.f.from_unixtime(arg, to.timezone or "UTC") return super().visit_Cast(op, arg=arg, to=to) - @visit_node.register(ops.CountDistinctStar) def visit_CountDistinctStar(self, op, *, arg, where): make_col = partial(sg.column, table=arg.alias_or_name, quoted=self.quoted) row = self.f.row(*map(make_col, op.arg.schema.names)) return self.agg.count(sge.Distinct(expressions=[row]), where=where) - @visit_node.register(ops.ArrayConcat) def visit_ArrayConcat(self, op, *, arg): return self.f.concat(*arg) - @visit_node.register(ops.StringContains) def visit_StringContains(self, op, *, haystack, needle): return self.f.strpos(haystack, needle) > 0 - @visit_node.register(ops.RegexExtract) def visit_RegexpExtract(self, op, *, arg, pattern, index): # sqlglot doesn't support the third `group` argument for trino so work # around that limitation using an anonymous function return self.f.anon.regexp_extract(arg, pattern, index) - @visit_node.register(ops.Quantile) - @visit_node.register(ops.MultiQuantile) - @visit_node.register(ops.Median) - @visit_node.register(ops.RowID) - @visit_node.register(ops.TimestampBucket) - def visit_Undefined(self, op, **kw): - return super().visit_Undefined(op, **kw) - _SIMPLE_OPS = { ops.Pi: "pi", @@ -508,13 +466,11 @@ def visit_Undefined(self, op, **kw): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - @TrinoCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, where, **kw): return self.agg[_name](*kw.values(), where=where) else: - @TrinoCompiler.visit_node.register(_op) def _fmt(self, op, *, _name: str = _name, **kw): return self.f[_name](*kw.values())
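
For readers skimming the per-backend hunks above: each compiler now declares an `UNSUPPORTED_OPERATIONS` frozenset instead of stacking `@visit_node.register(...)` decorators onto a `visit_Undefined` method, and visit methods are resolved by name as `visit_<OpName>`. The snippet below is a minimal, self-contained sketch of that dispatch pattern, not part of the patch; the class and method names are stand-ins chosen for illustration, not the real ibis base classes or operation types.

```python
# Illustrative sketch only: name-based visit dispatch plus an
# UNSUPPORTED_OPERATIONS hook wired up in __init_subclass__.
# All names here (BaseCompiler, ToyCompiler, RowID, CountStar, the
# exception classes) are assumptions for demonstration purposes.
from __future__ import annotations


class OperationNotDefinedError(Exception):
    """Stand-in for an "operation not defined" compiler error."""


class UnsupportedOperationError(Exception):
    """Stand-in for an "operation not supported by this backend" error."""


class BaseCompiler:
    UNSUPPORTED_OPERATIONS: frozenset[type] = frozenset()

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Point every declared-unsupported operation at visit_Unsupported,
        # so subclasses never write that boilerplate by hand.
        for op in cls.UNSUPPORTED_OPERATIONS:
            setattr(cls, f"visit_{op.__name__}", cls.visit_Unsupported)

    def visit_node(self, op, **kwargs):
        # Dispatch on the operation's class name instead of singledispatch.
        method = getattr(self, f"visit_{type(op).__name__}", None)
        if method is None:
            return self.visit_Undefined(op, **kwargs)
        return method(op, **kwargs)

    def visit_Undefined(self, op, **_):
        raise OperationNotDefinedError(type(op).__name__)

    def visit_Unsupported(self, op, **_):
        raise UnsupportedOperationError(type(op).__name__)


# Toy operation types standing in for expression operation classes.
class RowID: ...


class CountStar: ...


class ToyCompiler(BaseCompiler):
    UNSUPPORTED_OPERATIONS = frozenset((RowID,))

    def visit_CountStar(self, op, **_):
        return "COUNT(*)"


compiler = ToyCompiler()
print(compiler.visit_node(CountStar()))  # -> COUNT(*)
try:
    compiler.visit_node(RowID())
except UnsupportedOperationError as e:
    print(f"unsupported: {e}")
```

Because the unsupported set is plain class data rather than a pile of decorator registrations, a capability check can compare the resolved `visit_*` attribute against the two fallback methods, which is what makes the per-backend frozensets in the hunks above sufficient on their own.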