Skip to content

Commit

Permalink
refactor(sql): move dialects to always-importable location
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Feb 8, 2024
1 parent 977df11 commit 8bf20f3
Show file tree
Hide file tree
Showing 43 changed files with 438 additions and 470 deletions.
3 changes: 0 additions & 3 deletions ibis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def __getattr__(name: str) -> BaseBackend:
# - add_operation
# - _from_url
# - _to_sql
# - _sqlglot_dialect (if defined)
#
# We also copy over the docstring from `do_connect` to the proxy `connect`
# method, since that's where all the backend-specific kwargs are currently
Expand All @@ -120,8 +119,6 @@ def connect(*args, **kwargs):
proxy.name = name
proxy._from_url = backend._from_url
proxy._to_sql = backend._to_sql
if (dialect := getattr(backend, "_sqlglot_dialect", None)) is not None:
proxy._sqlglot_dialect = dialect
# Add any additional methods that should be exposed at the top level
for name in getattr(backend, "_top_level_methods", ()):
setattr(proxy, name, getattr(backend, name))
Expand Down
30 changes: 11 additions & 19 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,11 @@

import pandas as pd
import pyarrow as pa
import sqlglot as sg
import torch

__all__ = ("BaseBackend", "Database", "connect")

# TODO(cpcloud): move these to a place that doesn't require importing
# backend-specific dependencies
_IBIS_TO_SQLGLOT_DIALECT = {
"mssql": "tsql",
"impala": "hive",
"pyspark": "spark",
"polars": "postgres",
"datafusion": "postgres",
# closest match see https://github.com/ibis-project/ibis/pull/7303#discussion_r1350223901
"exasol": "oracle",
"risingwave": "postgres",
}


class Database:
"""Generic Database class."""
Expand Down Expand Up @@ -805,6 +793,14 @@ def __init__(self, *args, **kwargs):
key=lambda expr: expr.op(),
)

@property
@abc.abstractmethod
def dialect(self) -> sg.Dialect | None:
"""The sqlglot dialect for this backend, where applicable.
Returns None if the backend is not a SQL backend.
"""

def __getstate__(self):
return dict(_con_args=self._con_args, _con_kwargs=self._con_kwargs)

Expand Down Expand Up @@ -1272,15 +1268,11 @@ def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str:

# only transpile if the backend dialect doesn't match the input dialect
name = self.name
if (output_dialect := getattr(self, "_sqlglot_dialect", name)) is None:
if (output_dialect := self.dialect) is None:
raise NotImplementedError(f"No known sqlglot dialect for backend {name}")

if dialect != output_dialect:
(query,) = sg.transpile(
query,
read=_IBIS_TO_SQLGLOT_DIALECT.get(dialect, dialect),
write=output_dialect,
)
(query,) = sg.transpile(query, read=dialect, write=output_dialect)
return query


Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ class BaseSQLBackend(BaseBackend):

compiler = Compiler

@property
def _sqlglot_dialect(self) -> str:
return self.name

def _from_url(self, url: str, **kwargs):
"""Connect to a backend using a URL `url`.
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/base/sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class SQLGlotBackend(BaseBackend):
name: ClassVar[str]

@property
def _sqlglot_dialect(self) -> str:
def dialect(self) -> sg.Dialect:
return self.compiler.dialect

@classmethod
Expand Down Expand Up @@ -115,7 +115,7 @@ def compile(
):
"""Compile an Ibis expression to a SQL string."""
query = self._to_sqlglot(expr, limit=limit, params=params, **kwargs)
sql = query.sql(dialect=self.compiler.dialect, pretty=True)
sql = query.sql(dialect=self.dialect, pretty=True)
self._log(sql)
return sql

Expand Down Expand Up @@ -380,6 +380,6 @@ def truncate_table(
"""
ident = sg.table(
name, db=schema, catalog=database, quoted=self.compiler.quoted
).sql(self.compiler.dialect)
).sql(self.dialect)
with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"):
pass
95 changes: 94 additions & 1 deletion ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from functools import partial
from typing import NoReturn
from typing import Literal, NoReturn

import sqlglot as sg
import sqlglot.expressions as sge

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
from ibis.common.collections import FrozenDict
Expand Down Expand Up @@ -942,3 +943,95 @@ def _from_sqlglot_MAP(cls) -> sge.DataType:
@classmethod
def _from_sqlglot_STRUCT(cls) -> sge.DataType:
raise com.UnsupportedBackendType("SQL Server does not support structs")


# TODO(kszucs): add a bool converter method to support different clickhouse bool types
def _bool_type() -> Literal["Bool", "UInt8", "Int8"]:
return getattr(getattr(ibis.options, "clickhouse", None), "bool_type", "Bool")


class ClickHouseType(SqlglotType):
dialect = "clickhouse"
default_decimal_precision = None
default_decimal_scale = None
default_nullable = False

unknown_type_strings = FrozenDict(
{
"ipv4": dt.INET(nullable=default_nullable),
"ipv6": dt.INET(nullable=default_nullable),
"object('json')": dt.JSON(nullable=default_nullable),
"array(null)": dt.Array(dt.null, nullable=default_nullable),
"array(nothing)": dt.Array(dt.null, nullable=default_nullable),
}
)

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
"""Convert a sqlglot type to an ibis type."""
typ = super().from_ibis(dtype)
if dtype.nullable and not dtype.is_map():
# map cannot be nullable in clickhouse
return sge.DataType(this=typecode.NULLABLE, expressions=[typ])
else:
return typ

@classmethod
def _from_sqlglot_NULLABLE(cls, inner_type: sge.DataType) -> dt.DataType:
return cls.to_ibis(inner_type, nullable=True)

@classmethod
def _from_sqlglot_DATETIME(
cls, timezone: sge.DataTypeParam | None = None
) -> dt.Timestamp:
return dt.Timestamp(
scale=0,
timezone=None if timezone is None else timezone.this.this,
nullable=cls.default_nullable,
)

@classmethod
def _from_sqlglot_DATETIME64(
cls,
scale: sge.DataTypeSize | None = None,
timezone: sge.Literal | None = None,
) -> dt.Timestamp:
return dt.Timestamp(
timezone=None if timezone is None else timezone.this.this,
scale=int(scale.this.this),
nullable=cls.default_nullable,
)

@classmethod
def _from_sqlglot_LOWCARDINALITY(cls, inner_type: sge.DataType) -> dt.DataType:
return cls.to_ibis(inner_type)

@classmethod
def _from_sqlglot_NESTED(cls, *fields: sge.DataType) -> dt.Struct:
fields = {
field.name: dt.Array(
cls.to_ibis(field.args["kind"]), nullable=cls.default_nullable
)
for field in fields
}
return dt.Struct(fields, nullable=cls.default_nullable)

@classmethod
def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
if dtype.timezone is None:
timezone = None
else:
timezone = sge.DataTypeParam(this=sge.Literal.string(dtype.timezone))

if dtype.scale is None:
return sge.DataType(this=typecode.DATETIME, expressions=[timezone])
else:
scale = sge.DataTypeParam(this=sge.Literal.number(dtype.scale))
return sge.DataType(this=typecode.DATETIME64, expressions=[scale, timezone])

@classmethod
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
# key cannot be nullable in clickhouse
key_type = cls.from_ibis(dtype.key_type.copy(nullable=False))
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type])
Loading

0 comments on commit 8bf20f3

Please sign in to comment.