Skip to content

Commit

Permalink
refactor(sql): move dialects to always-importable location (#8279)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Feb 8, 2024
1 parent 977df11 commit e75b229
Show file tree
Hide file tree
Showing 63 changed files with 500 additions and 541 deletions.
1 change: 0 additions & 1 deletion docs/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ diamonds.json
*.ndjson
reference/
objects.json
*support_matrix.csv

# generated notebooks and files
*.ipynb
Expand Down
54 changes: 43 additions & 11 deletions docs/support_matrix.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,49 @@ hide:

```{python}
#| echo: false
!python ../gen_matrix.py
```
from pathlib import Path
```{python}
#| echo: false
import pandas as pd
support_matrix = pd.read_csv("./backends/raw_support_matrix.csv")
support_matrix = support_matrix.assign(
Category=support_matrix.Operation.map(lambda op: op.rsplit(".", 1)[0].rsplit(".", 1)[-1]),
Operation=support_matrix.Operation.map(lambda op: op.rsplit(".", 1)[-1]),
).set_index(["Category", "Operation"])
import ibis
import ibis.expr.operations as ops
def get_backends(exclude=()):
entry_points = sorted(ep.name for ep in ibis.util.backend_entry_points())
return [
(backend, getattr(ibis, backend))
for backend in entry_points
if backend not in exclude
]
def get_leaf_classes(op):
for child_class in op.__subclasses__():
if not child_class.__subclasses__():
yield child_class
else:
yield from get_leaf_classes(child_class)
public_ops = frozenset(get_leaf_classes(ops.Value))
support = {"Operation": [f"{op.__module__}.{op.__name__}" for op in public_ops]}
support.update(
(name, list(map(backend.has_operation, public_ops)))
for name, backend in get_backends()
)
support_matrix = (
pd.DataFrame(support)
.assign(splits=lambda df: df.Operation.str.findall("[a-zA-Z_][a-zA-Z_0-9]*"))
.assign(
Category=lambda df: df.splits.str[-2],
Operation=lambda df: df.splits.str[-1],
)
.drop(["splits"], axis=1)
.set_index(["Category", "Operation"])
.sort_index()
)
all_visible_ops_count = len(support_matrix)
coverage = pd.Index(
support_matrix.sum()
Expand Down Expand Up @@ -70,15 +101,16 @@ dict(
#| content: valuebox
#| title: "Number of SQL backends"
import importlib
from ibis.backends.base.sql import BaseSQLBackend
from ibis.backends.base.sqlglot import SQLGlotBackend
sql_backends = sum(
issubclass(
importlib.import_module(f"ibis.backends.{entry_point.name}").Backend,
BaseSQLBackend
SQLGlotBackend
)
for entry_point in ibis.util.backend_entry_points()
)
assert sql_backends > 0
dict(value=sql_backends, color="green", icon="database")
```

Expand Down
45 changes: 0 additions & 45 deletions gen_matrix.py

This file was deleted.

3 changes: 0 additions & 3 deletions ibis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def __getattr__(name: str) -> BaseBackend:
# - add_operation
# - _from_url
# - _to_sql
# - _sqlglot_dialect (if defined)
#
# We also copy over the docstring from `do_connect` to the proxy `connect`
# method, since that's where all the backend-specific kwargs are currently
Expand All @@ -120,8 +119,6 @@ def connect(*args, **kwargs):
proxy.name = name
proxy._from_url = backend._from_url
proxy._to_sql = backend._to_sql
if (dialect := getattr(backend, "_sqlglot_dialect", None)) is not None:
proxy._sqlglot_dialect = dialect
# Add any additional methods that should be exposed at the top level
for name in getattr(backend, "_top_level_methods", ()):
setattr(proxy, name, getattr(backend, name))
Expand Down
30 changes: 11 additions & 19 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,11 @@

import pandas as pd
import pyarrow as pa
import sqlglot as sg
import torch

__all__ = ("BaseBackend", "Database", "connect")

# TODO(cpcloud): move these to a place that doesn't require importing
# backend-specific dependencies
_IBIS_TO_SQLGLOT_DIALECT = {
"mssql": "tsql",
"impala": "hive",
"pyspark": "spark",
"polars": "postgres",
"datafusion": "postgres",
# closest match see https://github.com/ibis-project/ibis/pull/7303#discussion_r1350223901
"exasol": "oracle",
"risingwave": "postgres",
}


class Database:
"""Generic Database class."""
Expand Down Expand Up @@ -805,6 +793,14 @@ def __init__(self, *args, **kwargs):
key=lambda expr: expr.op(),
)

@property
@abc.abstractmethod
def dialect(self) -> sg.Dialect | None:
"""The sqlglot dialect for this backend, where applicable.
Returns None if the backend is not a SQL backend.
"""

def __getstate__(self):
return dict(_con_args=self._con_args, _con_kwargs=self._con_kwargs)

Expand Down Expand Up @@ -1272,15 +1268,11 @@ def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str:

# only transpile if the backend dialect doesn't match the input dialect
name = self.name
if (output_dialect := getattr(self, "_sqlglot_dialect", name)) is None:
if (output_dialect := self.dialect) is None:
raise NotImplementedError(f"No known sqlglot dialect for backend {name}")

if dialect != output_dialect:
(query,) = sg.transpile(
query,
read=_IBIS_TO_SQLGLOT_DIALECT.get(dialect, dialect),
write=output_dialect,
)
(query,) = sg.transpile(query, read=dialect, write=output_dialect)
return query


Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ class BaseSQLBackend(BaseBackend):

compiler = Compiler

@property
def _sqlglot_dialect(self) -> str:
return self.name

def _from_url(self, url: str, **kwargs):
"""Connect to a backend using a URL `url`.
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/base/sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class SQLGlotBackend(BaseBackend):
name: ClassVar[str]

@property
def _sqlglot_dialect(self) -> str:
def dialect(self) -> sg.Dialect:
return self.compiler.dialect

@classmethod
Expand Down Expand Up @@ -115,7 +115,7 @@ def compile(
):
"""Compile an Ibis expression to a SQL string."""
query = self._to_sqlglot(expr, limit=limit, params=params, **kwargs)
sql = query.sql(dialect=self.compiler.dialect, pretty=True)
sql = query.sql(dialect=self.dialect, pretty=True)
self._log(sql)
return sql

Expand Down Expand Up @@ -380,6 +380,6 @@ def truncate_table(
"""
ident = sg.table(
name, db=schema, catalog=database, quoted=self.compiler.quoted
).sql(self.compiler.dialect)
).sql(self.dialect)
with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"):
pass
87 changes: 87 additions & 0 deletions ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,3 +942,90 @@ def _from_sqlglot_MAP(cls) -> sge.DataType:
@classmethod
def _from_sqlglot_STRUCT(cls) -> sge.DataType:
raise com.UnsupportedBackendType("SQL Server does not support structs")


class ClickHouseType(SqlglotType):
dialect = "clickhouse"
default_decimal_precision = None
default_decimal_scale = None
default_nullable = False

unknown_type_strings = FrozenDict(
{
"ipv4": dt.INET(nullable=default_nullable),
"ipv6": dt.INET(nullable=default_nullable),
"object('json')": dt.JSON(nullable=default_nullable),
"array(null)": dt.Array(dt.null, nullable=default_nullable),
"array(nothing)": dt.Array(dt.null, nullable=default_nullable),
}
)

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
"""Convert a sqlglot type to an ibis type."""
typ = super().from_ibis(dtype)
if dtype.nullable and not dtype.is_map():
# map cannot be nullable in clickhouse
return sge.DataType(this=typecode.NULLABLE, expressions=[typ])
else:
return typ

@classmethod
def _from_sqlglot_NULLABLE(cls, inner_type: sge.DataType) -> dt.DataType:
return cls.to_ibis(inner_type, nullable=True)

@classmethod
def _from_sqlglot_DATETIME(
cls, timezone: sge.DataTypeParam | None = None
) -> dt.Timestamp:
return dt.Timestamp(
scale=0,
timezone=None if timezone is None else timezone.this.this,
nullable=cls.default_nullable,
)

@classmethod
def _from_sqlglot_DATETIME64(
cls,
scale: sge.DataTypeSize | None = None,
timezone: sge.Literal | None = None,
) -> dt.Timestamp:
return dt.Timestamp(
timezone=None if timezone is None else timezone.this.this,
scale=int(scale.this.this),
nullable=cls.default_nullable,
)

@classmethod
def _from_sqlglot_LOWCARDINALITY(cls, inner_type: sge.DataType) -> dt.DataType:
return cls.to_ibis(inner_type)

@classmethod
def _from_sqlglot_NESTED(cls, *fields: sge.DataType) -> dt.Struct:
fields = {
field.name: dt.Array(
cls.to_ibis(field.args["kind"]), nullable=cls.default_nullable
)
for field in fields
}
return dt.Struct(fields, nullable=cls.default_nullable)

@classmethod
def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
if dtype.timezone is None:
timezone = None
else:
timezone = sge.DataTypeParam(this=sge.Literal.string(dtype.timezone))

if dtype.scale is None:
return sge.DataType(this=typecode.DATETIME, expressions=[timezone])
else:
scale = sge.DataTypeParam(this=sge.Literal.number(dtype.scale))
return sge.DataType(this=typecode.DATETIME64, expressions=[scale, timezone])

@classmethod
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
# key cannot be nullable in clickhouse
key_type = cls.from_ibis(dtype.key_type.copy(nullable=False))
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type])
Loading

0 comments on commit e75b229

Please sign in to comment.