From a14d7ba19a193e6506ad9780e96fd976283dadc0 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 5 Dec 2023 08:26:45 -0500 Subject: [PATCH] fixup! chore: fix safe_ordinal call --- ibis/backends/base/sqlglot/__init__.py | 7 ++--- ibis/backends/base/sqlglot/compiler.py | 27 ++++++++++++---- ibis/backends/bigquery/compiler.py | 43 ++++++++++++++++---------- 3 files changed, 51 insertions(+), 26 deletions(-) diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index 3c105f58445a..11749d182cc1 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -80,15 +80,14 @@ def _to_sqlglot( sql = sg.select(STAR).from_(sql) assert not isinstance(sql, sge.Subquery) - return sql + return [sql] def compile( self, expr: ir.Expr, limit: str | None = None, params=None, **kwargs: Any ): """Compile an Ibis expression to a ClickHouse SQL string.""" - queries = ibis.util.promote_list( - self._to_sqlglot(expr, limit=limit, params=params, **kwargs) - ) + queries = self._to_sqlglot(expr, limit=limit, params=params, **kwargs) + return ";\n\n".join( query.sql(dialect=self.name, pretty=True) for query in queries ) diff --git a/ibis/backends/base/sqlglot/compiler.py b/ibis/backends/base/sqlglot/compiler.py index 264cab1e22f2..b7126fe82aa7 100644 --- a/ibis/backends/base/sqlglot/compiler.py +++ b/ibis/backends/base/sqlglot/compiler.py @@ -9,7 +9,7 @@ import string from collections.abc import Mapping from functools import partial, singledispatchmethod -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable import sqlglot as sg import sqlglot.expressions as sge @@ -267,7 +267,9 @@ def visit_node(self, op: ops.Node, **_): @visit_node.register(ops.Field) def visit_Field(self, op, *, rel, name): - return sg.column(name, table=rel.alias_or_name, quoted=self.quoted) + return sg.column( + self._gen_valid_name(name), table=rel.alias_or_name, quoted=self.quoted + ) @visit_node.register(ops.ScalarSubquery) def visit_ScalarSubquery(self, op, *, rel): @@ -275,7 +277,7 @@ def visit_ScalarSubquery(self, op, *, rel): @visit_node.register(ops.Alias) def visit_Alias(self, op, *, arg, name): - return arg.as_(name, quoted=self.quoted) + return arg.as_(self._gen_valid_name(name), quoted=self.quoted) @visit_node.register(ops.Literal) def visit_Literal(self, op, *, value, dtype, **kw): @@ -923,12 +925,19 @@ def visit_JoinLink(self, op, *, how, table, predicates): this=table, side=sides[how], kind=kinds[how], on=sg.and_(*predicates) ) + @staticmethod + def _gen_valid_name(name: str) -> str: + return name + @visit_node.register(ops.Project) def visit_Project(self, op, *, parent, values): # needs_alias should never be true here in explicitly, but it may get # passed via a (recursive) call to translate_val return sg.select( - *(value.as_(key, quoted=self.quoted) for key, value in values.items()) + *( + value.as_(self._gen_valid_name(key), quoted=self.quoted) + for key, value in values.items() + ) ).from_(parent) @staticmethod @@ -938,8 +947,14 @@ def _generate_groups(groups): @visit_node.register(ops.Aggregate) def visit_Aggregate(self, op, *, parent, groups, metrics): sel = sg.select( - *(value.as_(key, quoted=self.quoted) for key, value in groups.items()), - *(value.as_(key, quoted=self.quoted) for key, value in metrics.items()), + *( + value.as_(self._gen_valid_name(key), quoted=self.quoted) + for key, value in groups.items() + ), + *( + value.as_(self._gen_valid_name(key), quoted=self.quoted) + for key, value in metrics.items() + ), ).from_(parent) if groups: diff --git a/ibis/backends/bigquery/compiler.py b/ibis/backends/bigquery/compiler.py index b5f88eef5f04..9f7ef125c740 100644 --- a/ibis/backends/bigquery/compiler.py +++ b/ibis/backends/bigquery/compiler.py @@ -2,6 +2,7 @@ import contextlib import math +import re from functools import singledispatchmethod import sqlglot as sg @@ -104,6 +105,13 @@ def _aggregate(self, funcname: str, *args, where): def visit_node(self, op, **kw): return super().visit_node(op, **kw) + @staticmethod + def _gen_valid_name(name: str) -> str: + found = re.findall(r'[^!"$()*,./;?@[\\\]^`{}~\n]+', name) + result = "_".join(found) + assert result, f"name {name} produced an empty result" + return result + @visit_node.register(ops.IntegerRange) def visit_IntegerRange(self, op, *, start, stop, step): n = self.f.floor((stop - start) / self.f.nullif(step, 0)) @@ -141,21 +149,6 @@ def visit_TimestampDelta(self, op, *, part, left, right): def visit_Pi(self, op): return self.f.acos(-1) - @visit_node.register(ops.FindInSet) - @visit_node.register(ops.CountDistinctStar) - @visit_node.register(ops.DateDiff) - @visit_node.register(ops.TimestampDiff) - @visit_node.register(ops.ExtractAuthority) - @visit_node.register(ops.ExtractFile) - @visit_node.register(ops.ExtractFragment) - @visit_node.register(ops.ExtractHost) - @visit_node.register(ops.ExtractPath) - @visit_node.register(ops.ExtractProtocol) - @visit_node.register(ops.ExtractQuery) - @visit_node.register(ops.ExtractUserInfo) - def visit_Undefined(self, op, **_): - raise com.OperationNotDefinedError(type(op).__name__) - @visit_node.register(ops.WindowBoundary) def visit_WindowBoundary(self, op, *, value, preceding): if not isinstance(op.value, ops.Literal): @@ -642,7 +635,25 @@ def visit_Cast(self, op, *, arg, to): elif from_.is_floating() and to.is_integer(): return self.cast(self.f.trunc(arg), dt.int64) else: - return super().visit_node(op, arg=arg, to=to) + return self.cast(arg, to) + + @visit_node.register(ops.FindInSet) + @visit_node.register(ops.CountDistinctStar) + @visit_node.register(ops.DateDiff) + @visit_node.register(ops.TimestampDiff) + @visit_node.register(ops.ExtractAuthority) + @visit_node.register(ops.ExtractFile) + @visit_node.register(ops.ExtractFragment) + @visit_node.register(ops.ExtractHost) + @visit_node.register(ops.ExtractPath) + @visit_node.register(ops.ExtractProtocol) + @visit_node.register(ops.ExtractQuery) + @visit_node.register(ops.ExtractUserInfo) + @visit_node.register(ops.Quantile) + @visit_node.register(ops.MultiQuantile) + @visit_node.register(ops.Median) + def visit_Undefined(self, op, **_): + raise com.OperationNotDefinedError(type(op).__name__) _SIMPLE_OPS = {