From 221b630b9fced0c658fb8cf687aba679f6386399 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 11 Oct 2023 06:19:45 -0400 Subject: [PATCH] refactor(clickhouse): replace `lit` with builtin sqlglot functions --- ibis/backends/base/sqlglot/__init__.py | 18 +++----- ibis/backends/clickhouse/__init__.py | 4 +- ibis/backends/clickhouse/compiler/values.py | 44 +++++++------------ .../false/out.sql | 2 +- .../true/out.sql | 2 +- 5 files changed, 24 insertions(+), 46 deletions(-) diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index d16500d05662..078f16b95b38 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -23,12 +23,8 @@ def __getitem__(self, key: str) -> partial: return getattr(self, key) -def _to_sqlglot(arg): - return arg if isinstance(arg, sg.exp.Expression) else lit(arg) - - def _func(name: str, *args: Any, **kwargs: Any): - return sg.func(name, *map(_to_sqlglot, args), **kwargs) + return sg.func(name, *map(sg.exp.convert, args), **kwargs) class FuncGen: @@ -41,16 +37,16 @@ def __getitem__(self, key: str) -> partial: return getattr(self, key) def array(self, *args): - return sg.exp.Array.from_arg_list(list(map(_to_sqlglot, args))) + return sg.exp.Array.from_arg_list(list(map(sg.exp.convert, args))) def tuple(self, *args): - return sg.func("tuple", *map(_to_sqlglot, args)) + return sg.func("tuple", *map(sg.exp.convert, args)) def exists(self, query): return sg.exp.Exists(this=query) def concat(self, *args): - return sg.exp.Concat.from_arg_list(list(map(_to_sqlglot, args))) + return sg.exp.Concat.from_arg_list(list(map(sg.exp.convert, args))) def map(self, keys, values): return sg.exp.Map(keys=keys, values=values) @@ -66,12 +62,8 @@ def __getitem__(self, key: str) -> sg.exp.Column: return sg.column(key) -def lit(val): - return sg.exp.Literal(this=str(val), is_string=isinstance(val, str)) - - def interval(value, *, unit): - return sg.exp.Interval(this=_to_sqlglot(value), unit=sg.exp.var(unit)) + return sg.exp.Interval(this=sg.exp.convert(value), unit=sg.exp.var(unit)) F = FuncGen() diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index a5a26df44745..6077f58e52c4 100644 --- a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -23,7 +23,7 @@ import ibis.expr.types as ir from ibis import util from ibis.backends.base import BaseBackend, CanCreateDatabase -from ibis.backends.base.sqlglot import STAR, C, F, lit +from ibis.backends.base.sqlglot import STAR, C, F from ibis.backends.clickhouse.compiler import translate from ibis.backends.clickhouse.datatypes import ClickhouseType @@ -201,7 +201,7 @@ def list_tables( if database is None: database = F.currentDatabase() else: - database = lit(database) + database = sg.exp.convert(database) query = query.where(C.database.eq(database).or_(C.is_temporary)) diff --git a/ibis/backends/clickhouse/compiler/values.py b/ibis/backends/clickhouse/compiler/values.py index 85b225399aba..05eac06695b3 100644 --- a/ibis/backends/clickhouse/compiler/values.py +++ b/ibis/backends/clickhouse/compiler/values.py @@ -2,38 +2,20 @@ import calendar import functools +import math import operator from functools import partial from typing import Any import sqlglot as sg -from sqlglot.dialects.dialect import rename_func import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import util -from ibis.backends.base.sqlglot import ( - NULL, - STAR, - AggGen, - C, - F, - interval, - lit, - make_cast, -) +from ibis.backends.base.sqlglot import NULL, STAR, AggGen, C, F, interval, make_cast from ibis.backends.clickhouse.datatypes import ClickhouseType -# TODO: This is a hack to get around the fact that sqlglot 17.8.6 is broken for -# ClickHouse's isNaN -sg.dialects.clickhouse.ClickHouse.Generator.TRANSFORMS.update( - { - sg.exp.IsNan: rename_func("isNaN"), - sg.exp.StartsWith: rename_func("startsWith"), - } -) - def _aggregate(funcname, *args, where): has_filter = where is not None @@ -325,14 +307,14 @@ def _literal(op, *, value, dtype, **kw): return NULL return cast(NULL, dtype) elif dtype.is_boolean(): - return lit(bool(value)) + return sg.exp.convert(bool(value)) elif dtype.is_inet(): v = str(value) return F.toIPv6(v) if ":" in v else F.toIPv4(v) elif dtype.is_string(): - return lit(str(value).replace(r"\0", r"\\0")) + return sg.exp.convert(str(value).replace(r"\0", r"\\0")) elif dtype.is_macaddr(): - return lit(str(value)) + return sg.exp.convert(str(value)) elif dtype.is_decimal(): precision = dtype.precision if precision is None or not 1 <= precision <= 76: @@ -350,10 +332,14 @@ def _literal(op, *, value, dtype, **kw): type_name = F.toDecimal256 return type_name(value, dtype.scale) elif dtype.is_numeric(): - return lit(value) + if math.isnan(value): + return sg.exp.Literal(this="NaN", is_string=False) + elif math.isinf(value): + inf = sg.exp.Literal(this="inf", is_string=False) + return -inf if value < 0 else inf + return sg.exp.convert(value) elif dtype.is_interval(): - dtype = op.dtype - if dtype.unit.short in {"ms", "us", "ns"}: + if dtype.unit.short in ("ms", "us", "ns"): raise com.UnsupportedOperationError( "Clickhouse doesn't support subsecond interval resolutions" ) @@ -393,7 +379,7 @@ def _literal(op, *, value, dtype, **kw): values = [] for k, v in value.items(): - keys.append(lit(k)) + keys.append(sg.exp.convert(k)) values.append( _literal( ops.Literal(v, dtype=value_type), value=v, dtype=value_type, **kw @@ -578,7 +564,7 @@ def _clip(op, *, arg, lower, upper, **_): def _struct_field(op, *, arg, field: str, **_): arg_dtype = op.arg.dtype idx = arg_dtype.names.index(field) - return cast(sg.exp.Dot(this=arg, expression=lit(idx + 1)), op.dtype) + return cast(sg.exp.Dot(this=arg, expression=sg.exp.convert(idx + 1)), op.dtype) @translate_val.register(ops.NthValue) @@ -638,7 +624,7 @@ def day_of_week_name(op, *, arg, **_): sg.exp.Case( this=base, ifs=[if_(day, calendar.day_name[day]) for day in weekdays], - default=lit(""), + default=sg.exp.convert(""), ), "", ) diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/false/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/false/out.sql index 52db1a268bf3..ff725623a2c4 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/false/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/false/out.sql @@ -1,2 +1,2 @@ SELECT - False \ No newline at end of file + FALSE \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/true/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/true/out.sql index 55e1da583d09..fc55522005b9 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/true/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/true/out.sql @@ -1,2 +1,2 @@ SELECT - True \ No newline at end of file + TRUE \ No newline at end of file