Skip to content

Commit

Permalink
refactor(clickhouse): replace lit with builtin sqlglot functions
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Oct 11, 2023
1 parent a47be79 commit 221b630
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 46 deletions.
18 changes: 5 additions & 13 deletions ibis/backends/base/sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,8 @@ def __getitem__(self, key: str) -> partial:
return getattr(self, key)


def _to_sqlglot(arg):
return arg if isinstance(arg, sg.exp.Expression) else lit(arg)


def _func(name: str, *args: Any, **kwargs: Any):
return sg.func(name, *map(_to_sqlglot, args), **kwargs)
return sg.func(name, *map(sg.exp.convert, args), **kwargs)


class FuncGen:
Expand All @@ -41,16 +37,16 @@ def __getitem__(self, key: str) -> partial:
return getattr(self, key)

def array(self, *args):
return sg.exp.Array.from_arg_list(list(map(_to_sqlglot, args)))
return sg.exp.Array.from_arg_list(list(map(sg.exp.convert, args)))

def tuple(self, *args):
return sg.func("tuple", *map(_to_sqlglot, args))
return sg.func("tuple", *map(sg.exp.convert, args))

def exists(self, query):
return sg.exp.Exists(this=query)

def concat(self, *args):
return sg.exp.Concat.from_arg_list(list(map(_to_sqlglot, args)))
return sg.exp.Concat.from_arg_list(list(map(sg.exp.convert, args)))

def map(self, keys, values):
return sg.exp.Map(keys=keys, values=values)
Expand All @@ -66,12 +62,8 @@ def __getitem__(self, key: str) -> sg.exp.Column:
return sg.column(key)


def lit(val):
return sg.exp.Literal(this=str(val), is_string=isinstance(val, str))


def interval(value, *, unit):
return sg.exp.Interval(this=_to_sqlglot(value), unit=sg.exp.var(unit))
return sg.exp.Interval(this=sg.exp.convert(value), unit=sg.exp.var(unit))


F = FuncGen()
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends.base import BaseBackend, CanCreateDatabase
from ibis.backends.base.sqlglot import STAR, C, F, lit
from ibis.backends.base.sqlglot import STAR, C, F
from ibis.backends.clickhouse.compiler import translate
from ibis.backends.clickhouse.datatypes import ClickhouseType

Expand Down Expand Up @@ -201,7 +201,7 @@ def list_tables(
if database is None:
database = F.currentDatabase()
else:
database = lit(database)
database = sg.exp.convert(database)

query = query.where(C.database.eq(database).or_(C.is_temporary))

Expand Down
44 changes: 15 additions & 29 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,20 @@

import calendar
import functools
import math
import operator
from functools import partial
from typing import Any

import sqlglot as sg
from sqlglot.dialects.dialect import rename_func

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.base.sqlglot import (
NULL,
STAR,
AggGen,
C,
F,
interval,
lit,
make_cast,
)
from ibis.backends.base.sqlglot import NULL, STAR, AggGen, C, F, interval, make_cast
from ibis.backends.clickhouse.datatypes import ClickhouseType

# TODO: This is a hack to get around the fact that sqlglot 17.8.6 is broken for
# ClickHouse's isNaN
sg.dialects.clickhouse.ClickHouse.Generator.TRANSFORMS.update(
{
sg.exp.IsNan: rename_func("isNaN"),
sg.exp.StartsWith: rename_func("startsWith"),
}
)


def _aggregate(funcname, *args, where):
has_filter = where is not None
Expand Down Expand Up @@ -325,14 +307,14 @@ def _literal(op, *, value, dtype, **kw):
return NULL
return cast(NULL, dtype)
elif dtype.is_boolean():
return lit(bool(value))
return sg.exp.convert(bool(value))
elif dtype.is_inet():
v = str(value)
return F.toIPv6(v) if ":" in v else F.toIPv4(v)
elif dtype.is_string():
return lit(str(value).replace(r"\0", r"\\0"))
return sg.exp.convert(str(value).replace(r"\0", r"\\0"))
elif dtype.is_macaddr():
return lit(str(value))
return sg.exp.convert(str(value))
elif dtype.is_decimal():
precision = dtype.precision
if precision is None or not 1 <= precision <= 76:
Expand All @@ -350,10 +332,14 @@ def _literal(op, *, value, dtype, **kw):
type_name = F.toDecimal256
return type_name(value, dtype.scale)
elif dtype.is_numeric():
return lit(value)
if math.isnan(value):
return sg.exp.Literal(this="NaN", is_string=False)
elif math.isinf(value):
inf = sg.exp.Literal(this="inf", is_string=False)
return -inf if value < 0 else inf
return sg.exp.convert(value)
elif dtype.is_interval():
dtype = op.dtype
if dtype.unit.short in {"ms", "us", "ns"}:
if dtype.unit.short in ("ms", "us", "ns"):
raise com.UnsupportedOperationError(
"Clickhouse doesn't support subsecond interval resolutions"
)
Expand Down Expand Up @@ -393,7 +379,7 @@ def _literal(op, *, value, dtype, **kw):
values = []

for k, v in value.items():
keys.append(lit(k))
keys.append(sg.exp.convert(k))
values.append(
_literal(
ops.Literal(v, dtype=value_type), value=v, dtype=value_type, **kw
Expand Down Expand Up @@ -578,7 +564,7 @@ def _clip(op, *, arg, lower, upper, **_):
def _struct_field(op, *, arg, field: str, **_):
arg_dtype = op.arg.dtype
idx = arg_dtype.names.index(field)
return cast(sg.exp.Dot(this=arg, expression=lit(idx + 1)), op.dtype)
return cast(sg.exp.Dot(this=arg, expression=sg.exp.convert(idx + 1)), op.dtype)


@translate_val.register(ops.NthValue)
Expand Down Expand Up @@ -638,7 +624,7 @@ def day_of_week_name(op, *, arg, **_):
sg.exp.Case(
this=base,
ifs=[if_(day, calendar.day_name[day]) for day in weekdays],
default=lit(""),
default=sg.exp.convert(""),
),
"",
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT
False
FALSE
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT
True
TRUE

0 comments on commit 221b630

Please sign in to comment.