Skip to content

Commit

Permalink
chore: pin newer sqlglot version and update the implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Aug 23, 2023
1 parent d1f9d2d commit f74e8fd
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 353 deletions.
33 changes: 17 additions & 16 deletions ibis/backends/base/sql/glot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def to_ibis(cls, typ: sge.DataType, nullable: bool | None = None) -> dt.DataType
"""Convert a sqlglot type to an ibis type."""
typecode = typ.this

if method := getattr(cls, f"_to_ibis_{typecode.name}", None):
if method := getattr(cls, f"_from_sqlglot_{typecode.name}", None):
dtype = method(*typ.expressions)
else:
dtype = _from_sqlglot_types[typecode](nullable=cls.default_nullable)
Expand Down Expand Up @@ -180,19 +180,21 @@ def to_string(cls, dtype: dt.DataType) -> str:
return cls.from_ibis(dtype).sql(dialect=cls.dialect)

@classmethod
def _to_ibis_ARRAY(cls, value_type: sge.DataType) -> dt.Array:
def _from_sqlglot_ARRAY(cls, value_type: sge.DataType) -> dt.Array:
return dt.Array(cls.to_ibis(value_type), nullable=cls.default_nullable)

@classmethod
def _to_ibis_MAP(cls, key_type: sge.DataType, value_type: sge.DataType) -> dt.Map:
def _from_sqlglot_MAP(
cls, key_type: sge.DataType, value_type: sge.DataType
) -> dt.Map:
return dt.Map(
cls.to_ibis(key_type),
cls.to_ibis(value_type),
nullable=cls.default_nullable,
)

@classmethod
def _to_ibis_STRUCT(cls, *fields: sge.ColumnDef) -> dt.Struct:
def _from_sqlglot_STRUCT(cls, *fields: sge.ColumnDef) -> dt.Struct:
types = {}
for i, field in enumerate(fields):
if isinstance(field, sge.ColumnDef):
Expand All @@ -202,33 +204,33 @@ def _to_ibis_STRUCT(cls, *fields: sge.ColumnDef) -> dt.Struct:
return dt.Struct(types, nullable=cls.default_nullable)

@classmethod
def _to_ibis_TIMESTAMP(cls, scale=None) -> dt.Timestamp:
def _from_sqlglot_TIMESTAMP(cls, scale=None) -> dt.Timestamp:
return dt.Timestamp(
scale=cls.default_temporal_scale if scale is None else int(scale.this.this),
nullable=cls.default_nullable,
)

@classmethod
def _to_ibis_TIMESTAMPTZ(cls, scale=None) -> dt.Timestamp:
def _from_sqlglot_TIMESTAMPTZ(cls, scale=None) -> dt.Timestamp:
return dt.Timestamp(
timezone="UTC",
scale=cls.default_temporal_scale if scale is None else int(scale.this.this),
nullable=cls.default_nullable,
)

@classmethod
def _to_ibis_INTERVAL(
cls, precision: sge.DataTypeSize | None = None
def _from_sqlglot_INTERVAL(
cls, precision: sge.DataTypeParam | None = None
) -> dt.Interval:
if precision is None:
precision = cls.default_interval_precision
return dt.Interval(str(precision), nullable=cls.default_nullable)

@classmethod
def _to_ibis_DECIMAL(
def _from_sqlglot_DECIMAL(
cls,
precision: sge.DataTypeSize | None = None,
scale: sge.DataTypeSize | None = None,
precision: sge.DataTypeParam | None = None,
scale: sge.DataTypeParam | None = None,
) -> dt.Decimal:
if precision is None:
precision = cls.default_decimal_precision
Expand Down Expand Up @@ -266,17 +268,16 @@ def _from_ibis_Decimal(cls, dtype: dt.Decimal) -> sge.DataType:
return sge.DataType(
this=typecode.DECIMAL,
expressions=[
sge.Literal.number(dtype.precision),
sge.Literal.number(dtype.scale),
sge.DataTypeParam(this=sge.Literal.number(dtype.precision)),
sge.DataTypeParam(this=sge.Literal.number(dtype.scale)),
],
)

@classmethod
def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
code = typecode.TIMESTAMP if dtype.timezone is None else typecode.TIMESTAMPTZ
if dtype.scale:
scale = sge.Literal.number(dtype.scale)
scale = sge.DataTypeSize(this=typecode.DATATYPESIZE, expressions=[scale])
if dtype.scale is not None:
scale = sge.DataTypeParam(this=sge.Literal.number(dtype.scale))
return sge.DataType(this=code, expressions=[scale])
else:
return sge.DataType(this=code)
Expand Down
6 changes: 2 additions & 4 deletions ibis/backends/base/sql/glot/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ def assert_dtype_roundtrip(ibis_type, sqlglot_expected=None):
if sqlglot_expected is not None:
assert sqlglot_result == sqlglot_expected

restored_dtype = SqlglotType.to_ibis(
sqlglot_result
) # , nullable=ibis_type.nullable)
restored_dtype = SqlglotType.to_ibis(sqlglot_result)
assert ibis_type == restored_dtype


Expand All @@ -38,7 +36,6 @@ def assert_dtype_roundtrip(ibis_type, sqlglot_expected=None):
| its.date_dtype(nullable=true)
| its.time_dtype(nullable=true)
| its.timestamp_dtype(timezone=st.none(), nullable=true)
# | its.interval_dtype(nullable=true)
| its.array_dtypes(roundtripable_types, nullable=true)
| its.map_dtypes(roundtripable_types, roundtripable_types, nullable=true)
| its.struct_dtypes(roundtripable_types, nullable=true)
Expand All @@ -48,6 +45,7 @@ def assert_dtype_roundtrip(ibis_type, sqlglot_expected=None):
# not roundtrippable:
# - float16
# - macaddr
# - interval?


@h.given(roundtripable_types)
Expand Down
40 changes: 18 additions & 22 deletions ibis/backends/clickhouse/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Literal

import sqlglot.expressions as sge
from sqlglot.dialects.clickhouse import ClickHouse

import ibis
import ibis.expr.datatypes as dt
Expand All @@ -12,11 +11,6 @@

typecode = sge.DataType.Type

# TODO(kszucs): it may not be the nicest way to do it, should be pushed upstream
# or just use TEXT/VARCHAR which are aliases for `String`, but our test suite
# expects `String` to be used
ClickHouse.Generator.TYPE_MAPPING[typecode.VARCHAR] = "String"


def _bool_type() -> Literal["Bool", "UInt8", "Int8"]:
return getattr(getattr(ibis.options, "clickhouse", None), "bool_type", "Bool")
Expand All @@ -26,6 +20,7 @@ class ClickhouseType(SqlglotType):
dialect = "clickhouse"
default_decimal_precision = None
default_decimal_scale = None
default_temporal_scale = 3
default_nullable = False

unknown_type_strings = FrozenDict(
Expand All @@ -42,42 +37,44 @@ class ClickhouseType(SqlglotType):
def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
"""Convert an ibis type to a sqlglot type."""
typ = super().from_ibis(dtype)
non_nullable_types = (dt.Map,)
if dtype.nullable and not isinstance(dtype, non_nullable_types):
if dtype.nullable and not dtype.is_map():
# Map types cannot be wrapped in Nullable in ClickHouse
return sge.DataType(this=typecode.NULLABLE, expressions=[typ])
else:
return typ

@classmethod
def _to_ibis_NULLABLE(cls, inner_type: sge.DataType) -> dt.DataType:
def _from_sqlglot_NULLABLE(cls, inner_type: sge.DataType) -> dt.DataType:
return cls.to_ibis(inner_type, nullable=True)

@classmethod
def _to_ibis_DATETIME(cls, timezone: sge.Literal | None = None) -> dt.Timestamp:
def _from_sqlglot_DATETIME(
cls, timezone: sge.DataTypeParam | None = None
) -> dt.Timestamp:
return dt.Timestamp(
scale=0,
timezone=None if timezone is None else timezone.this.this,
scale=cls.default_temporal_scale,
nullable=cls.default_nullable,
)

@classmethod
def _to_ibis_DATETIME64(
def _from_sqlglot_DATETIME64(
cls,
scale: sge.DataTypeSize | None = None,
timezone: sge.Literal | None = None,
) -> dt.Timestamp:
return dt.Timestamp(
timezone=None if timezone is None else timezone.this.this,
scale=cls.default_temporal_scale if scale is None else int(scale.this.this),
scale=int(scale.this.this),
nullable=cls.default_nullable,
)

@classmethod
def _to_ibis_LOWCARDINALITY(cls, inner_type: sge.DataType) -> dt.DataType:
def _from_sqlglot_LOWCARDINALITY(cls, inner_type: sge.DataType) -> dt.DataType:
return cls.to_ibis(inner_type)

@classmethod
def _to_ibis_NESTED(cls, *fields: sge.DataType) -> dt.Struct:
def _from_sqlglot_NESTED(cls, *fields: sge.DataType) -> dt.Struct:
fields = {
field.name: dt.Array(
cls.to_ibis(field.args["kind"]), nullable=cls.default_nullable
Expand All @@ -88,16 +85,15 @@ def _to_ibis_NESTED(cls, *fields: sge.DataType) -> dt.Struct:

@classmethod
def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
if dtype.scale is None:
scale = sge.Literal.number(3)
if dtype.timezone is None:
timezone = None
else:
scale = sge.Literal.number(dtype.scale)
timezone = sge.DataTypeParam(this=sge.Literal.string(dtype.timezone))

scale = sge.DataTypeSize(this=typecode.DATATYPESIZE, expressions=[scale])
if dtype.timezone is None:
return sge.DataType(this=typecode.DATETIME64, expressions=[scale])
if dtype.scale is None:
return sge.DataType(this=typecode.DATETIME, expressions=[timezone])
else:
timezone = sge.Literal.string(dtype.timezone)
scale = sge.DataTypeParam(this=sge.Literal.number(dtype.scale))
return sge.DataType(this=typecode.DATETIME64, expressions=[scale, timezone])

@classmethod
Expand Down
54 changes: 47 additions & 7 deletions ibis/backends/clickhouse/tests/test_datatypes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import hypothesis as h
import hypothesis.strategies as st
import pytest
from pytest import param

import ibis.expr.datatypes as dt
import ibis.tests.strategies as its
from ibis.backends.clickhouse.datatypes import ClickhouseType

pytest.importorskip("clickhouse_connect")
Expand Down Expand Up @@ -114,12 +117,12 @@ def test_columns_types_with_additional_argument(con):
),
param(
"Array(DateTime)",
dt.Array(dt.Timestamp(nullable=False), nullable=False),
dt.Array(dt.Timestamp(scale=0, nullable=False), nullable=False),
id="array_datetime",
),
param(
"Array(DateTime64)",
dt.Array(dt.Timestamp(nullable=False), nullable=False),
"Array(DateTime64(9))",
dt.Array(dt.Timestamp(scale=9, nullable=False), nullable=False),
id="array_datetime64",
),
param("Array(Nothing)", dt.Array(dt.null, nullable=False), id="array_nothing"),
Expand Down Expand Up @@ -207,13 +210,12 @@ def test_columns_types_with_additional_argument(con):
),
id="nested",
),
param("DateTime", dt.Timestamp(nullable=False), id="datetime"),
param("DateTime", dt.Timestamp(scale=0, nullable=False), id="datetime"),
param(
"DateTime('Europe/Budapest')",
dt.Timestamp(timezone="Europe/Budapest", nullable=False),
dt.Timestamp(scale=0, timezone="Europe/Budapest", nullable=False),
id="datetime_timezone",
),
param("DateTime64", dt.Timestamp(nullable=False), id="datetime64"),
param(
"DateTime64(0)", dt.Timestamp(scale=0, nullable=False), id="datetime64_zero"
),
Expand All @@ -232,4 +234,42 @@ def test_columns_types_with_additional_argument(con):
],
)
def test_parse_type(ch_type, ibis_type):
assert ClickhouseType.from_string(ch_type) == ibis_type
parsed_ibis_type = ClickhouseType.from_string(ch_type)
assert parsed_ibis_type == ibis_type


false = st.just(False)

map_key_types = (
its.string_dtype(nullable=false)
| its.integer_dtypes(nullable=false)
| its.date_dtype(nullable=false)
| its.timestamp_dtype(scale=st.integers(0, 9), nullable=false)
)

roundtrippable_types = st.deferred(
lambda: (
its.null_dtype
| its.boolean_dtype()
| its.integer_dtypes()
| st.just(dt.Float32())
| st.just(dt.Float64())
| its.decimal_dtypes()
| its.string_dtype()
| its.json_dtype()
| its.inet_dtype()
| its.uuid_dtype()
| its.date_dtype()
| its.time_dtype()
| its.timestamp_dtype(scale=st.integers(0, 9))
| its.array_dtypes(roundtrippable_types)
| its.map_dtypes(map_key_types, roundtrippable_types, nullable=false)
)
)


@h.given(roundtrippable_types)
def test_type_roundtrip(ibis_type):
type_string = ClickhouseType.to_string(ibis_type)
parsed_ibis_type = ClickhouseType.from_string(type_string)
assert parsed_ibis_type == ibis_type
Loading

0 comments on commit f74e8fd

Please sign in to comment.