Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(sql): move dialects to always-importable location #8279

Merged
merged 4 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ diamonds.json
*.ndjson
reference/
objects.json
*support_matrix.csv

# generated notebooks and files
*.ipynb
Expand Down
54 changes: 43 additions & 11 deletions docs/support_matrix.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,49 @@ hide:

```{python}
#| echo: false
!python ../gen_matrix.py
```
from pathlib import Path

```{python}
#| echo: false
import pandas as pd

support_matrix = pd.read_csv("./backends/raw_support_matrix.csv")
support_matrix = support_matrix.assign(
Category=support_matrix.Operation.map(lambda op: op.rsplit(".", 1)[0].rsplit(".", 1)[-1]),
Operation=support_matrix.Operation.map(lambda op: op.rsplit(".", 1)[-1]),
).set_index(["Category", "Operation"])
import ibis
import ibis.expr.operations as ops


def get_backends(exclude=()):
    """Return ``(name, backend)`` pairs for the registered backends.

    Backend names come from ``ibis.util.backend_entry_points()`` and are
    processed in sorted order; each name is paired with the attribute of
    the same name on the ``ibis`` module. Names found in ``exclude`` are
    left out.
    """
    names = sorted(ep.name for ep in ibis.util.backend_entry_points())
    pairs = []
    for name in names:
        if name in exclude:
            continue
        pairs.append((name, getattr(ibis, name)))
    return pairs


def get_leaf_classes(op):
    """Yield every descendant class of ``op`` that has no subclasses itself.

    The hierarchy is walked depth-first, visiting subclasses in the order
    reported by ``__subclasses__()``. ``op`` itself is never yielded, and a
    class reachable through more than one parent is yielded once per path.
    """
    # Explicit stack; children are pushed reversed so the first subclass is
    # popped (and therefore explored) first.
    pending = list(reversed(op.__subclasses__()))
    while pending:
        candidate = pending.pop()
        children = candidate.__subclasses__()
        if children:
            pending.extend(reversed(children))
        else:
            yield candidate


# Every leaf subclass of ops.Value; the frozenset removes duplicates that
# get_leaf_classes may produce.
public_ops = frozenset(get_leaf_classes(ops.Value))
# Column "Operation": the fully qualified "<module>.<class name>" of each op.
support = {"Operation": [f"{op.__module__}.{op.__name__}" for op in public_ops]}
# One further column per backend name, holding backend.has_operation(op) for
# each op, iterated in the same order as the "Operation" column above.
support.update(
(name, list(map(backend.has_operation, public_ops)))
for name, backend in get_backends()
)

# Reshape into a frame indexed by (Category, Operation) and sorted by that index.
support_matrix = (
pd.DataFrame(support)
# Break the dotted name into its identifier components.
.assign(splits=lambda df: df.Operation.str.findall("[a-zA-Z_][a-zA-Z_0-9]*"))
.assign(
# Second-to-last component (last segment of the module path) is the
# category; the last component (the class name) is the operation.
Category=lambda df: df.splits.str[-2],
Operation=lambda df: df.splits.str[-1],
)
# The helper column is only needed to derive the two columns above.
.drop(["splits"], axis=1)
.set_index(["Category", "Operation"])
.sort_index()
)
all_visible_ops_count = len(support_matrix)
coverage = pd.Index(
support_matrix.sum()
Expand Down Expand Up @@ -70,15 +101,16 @@ dict(
#| content: valuebox
#| title: "Number of SQL backends"
import importlib
from ibis.backends.base.sql import BaseSQLBackend
from ibis.backends.base.sqlglot import SQLGlotBackend

sql_backends = sum(
issubclass(
importlib.import_module(f"ibis.backends.{entry_point.name}").Backend,
BaseSQLBackend
SQLGlotBackend
)
for entry_point in ibis.util.backend_entry_points()
)
assert sql_backends > 0
dict(value=sql_backends, color="green", icon="database")
```

Expand Down
45 changes: 0 additions & 45 deletions gen_matrix.py

This file was deleted.

3 changes: 0 additions & 3 deletions ibis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def __getattr__(name: str) -> BaseBackend:
# - add_operation
# - _from_url
# - _to_sql
# - _sqlglot_dialect (if defined)
#
# We also copy over the docstring from `do_connect` to the proxy `connect`
# method, since that's where all the backend-specific kwargs are currently
Expand All @@ -120,8 +119,6 @@ def connect(*args, **kwargs):
proxy.name = name
proxy._from_url = backend._from_url
proxy._to_sql = backend._to_sql
if (dialect := getattr(backend, "_sqlglot_dialect", None)) is not None:
proxy._sqlglot_dialect = dialect
# Add any additional methods that should be exposed at the top level
for name in getattr(backend, "_top_level_methods", ()):
setattr(proxy, name, getattr(backend, name))
Expand Down
30 changes: 11 additions & 19 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,11 @@

import pandas as pd
import pyarrow as pa
import sqlglot as sg
import torch

__all__ = ("BaseBackend", "Database", "connect")

# TODO(cpcloud): move these to a place that doesn't require importing
# backend-specific dependencies
_IBIS_TO_SQLGLOT_DIALECT = {
"mssql": "tsql",
"impala": "hive",
"pyspark": "spark",
"polars": "postgres",
"datafusion": "postgres",
# closest match see https://github.com/ibis-project/ibis/pull/7303#discussion_r1350223901
"exasol": "oracle",
"risingwave": "postgres",
}


class Database:
"""Generic Database class."""
Expand Down Expand Up @@ -805,6 +793,14 @@ def __init__(self, *args, **kwargs):
key=lambda expr: expr.op(),
)

@property
@abc.abstractmethod
def dialect(self) -> sg.Dialect | None:
    """The sqlglot dialect for this backend, where applicable.

    Returns None if the backend is not a SQL backend.

    This property is abstract, so each concrete backend has to define it.
    """

def __getstate__(self):
return dict(_con_args=self._con_args, _con_kwargs=self._con_kwargs)

Expand Down Expand Up @@ -1272,15 +1268,11 @@ def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str:

# only transpile if the backend dialect doesn't match the input dialect
name = self.name
if (output_dialect := getattr(self, "_sqlglot_dialect", name)) is None:
if (output_dialect := self.dialect) is None:
raise NotImplementedError(f"No known sqlglot dialect for backend {name}")

if dialect != output_dialect:
(query,) = sg.transpile(
query,
read=_IBIS_TO_SQLGLOT_DIALECT.get(dialect, dialect),
write=output_dialect,
)
(query,) = sg.transpile(query, read=dialect, write=output_dialect)
return query


Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ class BaseSQLBackend(BaseBackend):

compiler = Compiler

@property
def _sqlglot_dialect(self) -> str:
return self.name

def _from_url(self, url: str, **kwargs):
"""Connect to a backend using a URL `url`.

Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/base/sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class SQLGlotBackend(BaseBackend):
name: ClassVar[str]

@property
def _sqlglot_dialect(self) -> str:
def dialect(self) -> sg.Dialect:
return self.compiler.dialect

@classmethod
Expand Down Expand Up @@ -115,7 +115,7 @@ def compile(
):
"""Compile an Ibis expression to a SQL string."""
query = self._to_sqlglot(expr, limit=limit, params=params, **kwargs)
sql = query.sql(dialect=self.compiler.dialect, pretty=True)
sql = query.sql(dialect=self.dialect, pretty=True)
self._log(sql)
return sql

Expand Down Expand Up @@ -380,6 +380,6 @@ def truncate_table(
"""
ident = sg.table(
name, db=schema, catalog=database, quoted=self.compiler.quoted
).sql(self.compiler.dialect)
).sql(self.dialect)
with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"):
pass
95 changes: 94 additions & 1 deletion ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from functools import partial
from typing import NoReturn
from typing import Literal, NoReturn

import sqlglot as sg
import sqlglot.expressions as sge

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
from ibis.common.collections import FrozenDict
Expand Down Expand Up @@ -942,3 +943,95 @@ def _from_sqlglot_MAP(cls) -> sge.DataType:
@classmethod
def _from_sqlglot_STRUCT(cls) -> sge.DataType:
raise com.UnsupportedBackendType("SQL Server does not support structs")


# TODO(kszucs): add a bool converter method to support different clickhouse bool types
def _bool_type() -> Literal["Bool", "UInt8", "Int8"]:
    """Return ``ibis.options.clickhouse.bool_type``, or ``"Bool"`` if unset.

    Both the ``clickhouse`` options group and its ``bool_type`` attribute
    are looked up defensively, so a missing one yields the default.
    """
    clickhouse_options = getattr(ibis.options, "clickhouse", None)
    return getattr(clickhouse_options, "bool_type", "Bool")


class ClickHouseType(SqlglotType):
    """Type mapper between ibis data types and ClickHouse sqlglot types.

    Types are treated as non-nullable unless stated otherwise
    (``default_nullable = False``); a nullable ibis type is emitted wrapped
    in a ``NULLABLE`` sqlglot type, see ``from_ibis``.
    """

    dialect = "clickhouse"
    default_decimal_precision = None
    default_decimal_scale = None
    default_nullable = False

    # Type strings mapped directly to an ibis type.
    # NOTE(review): presumably consulted by the base class before/instead of
    # parsing the string with sqlglot -- confirm against SqlglotType.
    unknown_type_strings = FrozenDict(
        {
            "ipv4": dt.INET(nullable=default_nullable),
            "ipv6": dt.INET(nullable=default_nullable),
            "object('json')": dt.JSON(nullable=default_nullable),
            "array(null)": dt.Array(dt.null, nullable=default_nullable),
            "array(nothing)": dt.Array(dt.null, nullable=default_nullable),
        }
    )

    @classmethod
    def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
        """Convert an ibis type to a sqlglot type.

        Nullable ibis types are wrapped in a ``NULLABLE`` sqlglot type,
        except for maps, which are returned unwrapped.
        """
        typ = super().from_ibis(dtype)
        if dtype.nullable and not dtype.is_map():
            # map cannot be nullable in clickhouse
            return sge.DataType(this=typecode.NULLABLE, expressions=[typ])
        else:
            return typ

    @classmethod
    def _from_sqlglot_NULLABLE(cls, inner_type: sge.DataType) -> dt.DataType:
        """Convert the wrapped type, marking the result as nullable."""
        return cls.to_ibis(inner_type, nullable=True)

    @classmethod
    def _from_sqlglot_DATETIME(
        cls, timezone: sge.DataTypeParam | None = None
    ) -> dt.Timestamp:
        """Convert ``DATETIME`` to a scale-0 timestamp with optional timezone."""
        return dt.Timestamp(
            scale=0,
            timezone=None if timezone is None else timezone.this.this,
            nullable=cls.default_nullable,
        )

    @classmethod
    def _from_sqlglot_DATETIME64(
        cls,
        scale: sge.DataTypeSize | None = None,
        timezone: sge.Literal | None = None,
    ) -> dt.Timestamp:
        """Convert ``DATETIME64`` to a timestamp with the given scale/timezone.

        ``scale`` is declared optional, so a missing scale is forwarded to
        ibis as ``None`` (unspecified) instead of failing with an
        ``AttributeError`` on ``scale.this``; this mirrors how a missing
        ``timezone`` is handled.
        """
        return dt.Timestamp(
            timezone=None if timezone is None else timezone.this.this,
            scale=None if scale is None else int(scale.this.this),
            nullable=cls.default_nullable,
        )

    @classmethod
    def _from_sqlglot_LOWCARDINALITY(cls, inner_type: sge.DataType) -> dt.DataType:
        """Convert the wrapped type; the wrapper itself has no ibis equivalent here."""
        return cls.to_ibis(inner_type)

    @classmethod
    def _from_sqlglot_NESTED(cls, *fields: sge.DataType) -> dt.Struct:
        """Convert ``NESTED`` to a struct whose fields are arrays.

        Each field keeps its name and becomes an array of the ibis type
        converted from the field's ``kind`` argument.
        """
        fields = {
            field.name: dt.Array(
                cls.to_ibis(field.args["kind"]), nullable=cls.default_nullable
            )
            for field in fields
        }
        return dt.Struct(fields, nullable=cls.default_nullable)

    @classmethod
    def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
        """Convert an ibis timestamp to ``DATETIME`` or ``DATETIME64``.

        ``DATETIME`` is used when the scale is unspecified; otherwise
        ``DATETIME64`` is used with the scale as its first parameter.
        """
        if dtype.timezone is None:
            timezone = None
        else:
            timezone = sge.DataTypeParam(this=sge.Literal.string(dtype.timezone))

        # NOTE(review): when there is no timezone, ``None`` is placed in
        # ``expressions``; presumably sqlglot skips it when generating SQL --
        # confirm.
        if dtype.scale is None:
            return sge.DataType(this=typecode.DATETIME, expressions=[timezone])
        else:
            scale = sge.DataTypeParam(this=sge.Literal.number(dtype.scale))
            return sge.DataType(this=typecode.DATETIME64, expressions=[scale, timezone])

    @classmethod
    def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
        """Convert an ibis map to ``MAP`` with a non-nullable key type."""
        # key cannot be nullable in clickhouse
        key_type = cls.from_ibis(dtype.key_type.copy(nullable=False))
        value_type = cls.from_ibis(dtype.value_type)
        return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type])
Loading
Loading