Skip to content

Commit

Permalink
fixup! chore: fix safe_ordinal call
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Dec 5, 2023
1 parent 2aa2540 commit a14d7ba
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 26 deletions.
7 changes: 3 additions & 4 deletions ibis/backends/base/sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,14 @@ def _to_sqlglot(
sql = sg.select(STAR).from_(sql)

assert not isinstance(sql, sge.Subquery)
return sql
return [sql]

def compile(
self, expr: ir.Expr, limit: str | None = None, params=None, **kwargs: Any
):
"""Compile an Ibis expression to a ClickHouse SQL string."""
queries = ibis.util.promote_list(
self._to_sqlglot(expr, limit=limit, params=params, **kwargs)
)
queries = self._to_sqlglot(expr, limit=limit, params=params, **kwargs)

return ";\n\n".join(
query.sql(dialect=self.name, pretty=True) for query in queries
)
Expand Down
27 changes: 21 additions & 6 deletions ibis/backends/base/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import string
from collections.abc import Mapping
from functools import partial, singledispatchmethod
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Callable

import sqlglot as sg
import sqlglot.expressions as sge
Expand Down Expand Up @@ -267,15 +267,17 @@ def visit_node(self, op: ops.Node, **_):

@visit_node.register(ops.Field)
def visit_Field(self, op, *, rel, name):
return sg.column(name, table=rel.alias_or_name, quoted=self.quoted)
return sg.column(
self._gen_valid_name(name), table=rel.alias_or_name, quoted=self.quoted
)

@visit_node.register(ops.ScalarSubquery)
def visit_ScalarSubquery(self, op, *, rel):
return rel.this.subquery()

@visit_node.register(ops.Alias)
def visit_Alias(self, op, *, arg, name):
return arg.as_(name, quoted=self.quoted)
return arg.as_(self._gen_valid_name(name), quoted=self.quoted)

@visit_node.register(ops.Literal)
def visit_Literal(self, op, *, value, dtype, **kw):
Expand Down Expand Up @@ -923,12 +925,19 @@ def visit_JoinLink(self, op, *, how, table, predicates):
this=table, side=sides[how], kind=kinds[how], on=sg.and_(*predicates)
)

@staticmethod
def _gen_valid_name(name: str) -> str:
return name

@visit_node.register(ops.Project)
def visit_Project(self, op, *, parent, values):
# needs_alias should never be true here in explicitly, but it may get
# passed via a (recursive) call to translate_val
return sg.select(
*(value.as_(key, quoted=self.quoted) for key, value in values.items())
*(
value.as_(self._gen_valid_name(key), quoted=self.quoted)
for key, value in values.items()
)
).from_(parent)

@staticmethod
Expand All @@ -938,8 +947,14 @@ def _generate_groups(groups):
@visit_node.register(ops.Aggregate)
def visit_Aggregate(self, op, *, parent, groups, metrics):
sel = sg.select(
*(value.as_(key, quoted=self.quoted) for key, value in groups.items()),
*(value.as_(key, quoted=self.quoted) for key, value in metrics.items()),
*(
value.as_(self._gen_valid_name(key), quoted=self.quoted)
for key, value in groups.items()
),
*(
value.as_(self._gen_valid_name(key), quoted=self.quoted)
for key, value in metrics.items()
),
).from_(parent)

if groups:
Expand Down
43 changes: 27 additions & 16 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import math
import re
from functools import singledispatchmethod

import sqlglot as sg
Expand Down Expand Up @@ -104,6 +105,13 @@ def _aggregate(self, funcname: str, *args, where):
def visit_node(self, op, **kw):
return super().visit_node(op, **kw)

@staticmethod
def _gen_valid_name(name: str) -> str:
found = re.findall(r'[^!"$()*,./;?@[\\\]^`{}~\n]+', name)
result = "_".join(found)
assert result, f"name {name} produced an empty result"
return result

@visit_node.register(ops.IntegerRange)
def visit_IntegerRange(self, op, *, start, stop, step):
n = self.f.floor((stop - start) / self.f.nullif(step, 0))
Expand Down Expand Up @@ -141,21 +149,6 @@ def visit_TimestampDelta(self, op, *, part, left, right):
def visit_Pi(self, op):
return self.f.acos(-1)

@visit_node.register(ops.FindInSet)
@visit_node.register(ops.CountDistinctStar)
@visit_node.register(ops.DateDiff)
@visit_node.register(ops.TimestampDiff)
@visit_node.register(ops.ExtractAuthority)
@visit_node.register(ops.ExtractFile)
@visit_node.register(ops.ExtractFragment)
@visit_node.register(ops.ExtractHost)
@visit_node.register(ops.ExtractPath)
@visit_node.register(ops.ExtractProtocol)
@visit_node.register(ops.ExtractQuery)
@visit_node.register(ops.ExtractUserInfo)
def visit_Undefined(self, op, **_):
raise com.OperationNotDefinedError(type(op).__name__)

@visit_node.register(ops.WindowBoundary)
def visit_WindowBoundary(self, op, *, value, preceding):
if not isinstance(op.value, ops.Literal):
Expand Down Expand Up @@ -642,7 +635,25 @@ def visit_Cast(self, op, *, arg, to):
elif from_.is_floating() and to.is_integer():
return self.cast(self.f.trunc(arg), dt.int64)
else:
return super().visit_node(op, arg=arg, to=to)
return self.cast(arg, to)

@visit_node.register(ops.FindInSet)
@visit_node.register(ops.CountDistinctStar)
@visit_node.register(ops.DateDiff)
@visit_node.register(ops.TimestampDiff)
@visit_node.register(ops.ExtractAuthority)
@visit_node.register(ops.ExtractFile)
@visit_node.register(ops.ExtractFragment)
@visit_node.register(ops.ExtractHost)
@visit_node.register(ops.ExtractPath)
@visit_node.register(ops.ExtractProtocol)
@visit_node.register(ops.ExtractQuery)
@visit_node.register(ops.ExtractUserInfo)
@visit_node.register(ops.Quantile)
@visit_node.register(ops.MultiQuantile)
@visit_node.register(ops.Median)
def visit_Undefined(self, op, **_):
raise com.OperationNotDefinedError(type(op).__name__)


_SIMPLE_OPS = {
Expand Down

0 comments on commit a14d7ba

Please sign in to comment.