diff --git a/docs/.gitignore b/docs/.gitignore index c6821696481fb..3c9a23d5e7807 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -14,7 +14,6 @@ diamonds.json *.ndjson reference/ objects.json -*support_matrix.csv # generated notebooks and files *.ipynb diff --git a/docs/support_matrix.qmd b/docs/support_matrix.qmd index d805ad99274db..48ec4d07bc4e1 100644 --- a/docs/support_matrix.qmd +++ b/docs/support_matrix.qmd @@ -7,18 +7,49 @@ hide: ```{python} #| echo: false -!python ../gen_matrix.py -``` +from pathlib import Path -```{python} -#| echo: false import pandas as pd -support_matrix = pd.read_csv("./backends/raw_support_matrix.csv") -support_matrix = support_matrix.assign( - Category=support_matrix.Operation.map(lambda op: op.rsplit(".", 1)[0].rsplit(".", 1)[-1]), - Operation=support_matrix.Operation.map(lambda op: op.rsplit(".", 1)[-1]), -).set_index(["Category", "Operation"]) +import ibis +import ibis.expr.operations as ops + + +def get_backends(exclude=()): + entry_points = sorted(ep.name for ep in ibis.util.backend_entry_points()) + return [ + (backend, getattr(ibis, backend)) + for backend in entry_points + if backend not in exclude + ] + + +def get_leaf_classes(op): + for child_class in op.__subclasses__(): + if not child_class.__subclasses__(): + yield child_class + else: + yield from get_leaf_classes(child_class) + + +public_ops = frozenset(get_leaf_classes(ops.Value)) +support = {"Operation": [f"{op.__module__}.{op.__name__}" for op in public_ops]} +support.update( + (name, list(map(backend.has_operation, public_ops))) + for name, backend in get_backends() +) + +support_matrix = ( + pd.DataFrame(support) + .assign(splits=lambda df: df.Operation.str.findall("[a-zA-Z_][a-zA-Z_0-9]*")) + .assign( + Category=lambda df: df.splits.str[-2], + Operation=lambda df: df.splits.str[-1], + ) + .drop(["splits"], axis=1) + .set_index(["Category", "Operation"]) + .sort_index() +) all_visible_ops_count = len(support_matrix) coverage = pd.Index( support_matrix.sum() @@ -70,15 +101,16 @@ dict( #| content: valuebox #| title: "Number of SQL backends" import importlib -from ibis.backends.base.sql import BaseSQLBackend +from ibis.backends.base.sqlglot import SQLGlotBackend sql_backends = sum( issubclass( importlib.import_module(f"ibis.backends.{entry_point.name}").Backend, - BaseSQLBackend + SQLGlotBackend ) for entry_point in ibis.util.backend_entry_points() ) +assert sql_backends > 0 dict(value=sql_backends, color="green", icon="database") ``` diff --git a/gen_matrix.py b/gen_matrix.py deleted file mode 100644 index 9f9745cb72391..0000000000000 --- a/gen_matrix.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from pathlib import Path - -import pandas as pd - -import ibis -import ibis.expr.operations as ops - - -def get_backends(exclude=()): - entry_points = sorted(ep.name for ep in ibis.util.backend_entry_points()) - return [ - (backend, getattr(ibis, backend)) - for backend in entry_points - if backend not in exclude - ] - - -def get_leaf_classes(op): - for child_class in op.__subclasses__(): - if not child_class.__subclasses__(): - yield child_class - else: - yield from get_leaf_classes(child_class) - - -def main(): - public_ops = frozenset(get_leaf_classes(ops.Value)) - support = {"operation": [f"{op.__module__}.{op.__name__}" for op in public_ops]} - support.update( - (name, list(map(backend.has_operation, public_ops))) - for name, backend in get_backends() - ) - - df = pd.DataFrame(support).set_index("operation").sort_index() - - with Path(ibis.__file__).parents[1].joinpath( - "docs", "backends", "raw_support_matrix.csv" - ).open(mode="w") as f: - df.to_csv(f, index_label="Operation") - - -if __name__ == "__main__": - main() diff --git a/ibis/__init__.py b/ibis/__init__.py index 4baf115674654..8688490a66bda 100644 --- a/ibis/__init__.py +++ b/ibis/__init__.py @@ -97,7 +97,6 @@ def __getattr__(name: str) -> BaseBackend: # - add_operation # - _from_url # - _to_sql - # - _sqlglot_dialect (if defined) # # We also copy over the docstring from `do_connect` to the proxy `connect` # method, since that's where all the backend-specific kwargs are currently @@ -120,8 +119,6 @@ def connect(*args, **kwargs): proxy.name = name proxy._from_url = backend._from_url proxy._to_sql = backend._to_sql - if (dialect := getattr(backend, "_sqlglot_dialect", None)) is not None: - proxy._sqlglot_dialect = dialect # Add any additional methods that should be exposed at the top level for name in getattr(backend, "_top_level_methods", ()): setattr(proxy, name, getattr(backend, name)) diff --git a/ibis/backends/base/__init__.py b/ibis/backends/base/__init__.py index 2e7d336ebf9d9..075fb94bb0f2f 100644 --- a/ibis/backends/base/__init__.py +++ b/ibis/backends/base/__init__.py @@ -30,23 +30,11 @@ import pandas as pd import pyarrow as pa + import sqlglot as sg import torch __all__ = ("BaseBackend", "Database", "connect") -# TODO(cpcloud): move these to a place that doesn't require importing -# backend-specific dependencies -_IBIS_TO_SQLGLOT_DIALECT = { - "mssql": "tsql", - "impala": "hive", - "pyspark": "spark", - "polars": "postgres", - "datafusion": "postgres", - # closest match see https://github.com/ibis-project/ibis/pull/7303#discussion_r1350223901 - "exasol": "oracle", - "risingwave": "postgres", -} - class Database: """Generic Database class.""" @@ -805,6 +793,14 @@ def __init__(self, *args, **kwargs): key=lambda expr: expr.op(), ) + @property + @abc.abstractmethod + def dialect(self) -> sg.Dialect | None: + """The sqlglot dialect for this backend, where applicable. + + Returns None if the backend is not a SQL backend. + """ + def __getstate__(self): return dict(_con_args=self._con_args, _con_kwargs=self._con_kwargs) @@ -1272,15 +1268,11 @@ def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str: # only transpile if the backend dialect doesn't match the input dialect name = self.name - if (output_dialect := getattr(self, "_sqlglot_dialect", name)) is None: + if (output_dialect := self.dialect) is None: raise NotImplementedError(f"No known sqlglot dialect for backend {name}") if dialect != output_dialect: - (query,) = sg.transpile( - query, - read=_IBIS_TO_SQLGLOT_DIALECT.get(dialect, dialect), - write=output_dialect, - ) + (query,) = sg.transpile(query, read=dialect, write=output_dialect) return query diff --git a/ibis/backends/base/sql/__init__.py b/ibis/backends/base/sql/__init__.py index ddece41f926de..9e08e0c518ba2 100644 --- a/ibis/backends/base/sql/__init__.py +++ b/ibis/backends/base/sql/__init__.py @@ -29,10 +29,6 @@ class BaseSQLBackend(BaseBackend): compiler = Compiler - @property - def _sqlglot_dialect(self) -> str: - return self.name - def _from_url(self, url: str, **kwargs): """Connect to a backend using a URL `url`. diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index f52942ec50e54..47c715129cf2f 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -30,7 +30,7 @@ class SQLGlotBackend(BaseBackend): name: ClassVar[str] @property - def _sqlglot_dialect(self) -> str: + def dialect(self) -> sg.Dialect: return self.compiler.dialect @classmethod @@ -115,7 +115,7 @@ def compile( ): """Compile an Ibis expression to a SQL string.""" query = self._to_sqlglot(expr, limit=limit, params=params, **kwargs) - sql = query.sql(dialect=self.compiler.dialect, pretty=True) + sql = query.sql(dialect=self.dialect, pretty=True) self._log(sql) return sql @@ -380,6 +380,6 @@ def truncate_table( """ ident = sg.table( name, db=schema, catalog=database, quoted=self.compiler.quoted - ).sql(self.compiler.dialect) + ).sql(self.dialect) with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"): pass diff --git a/ibis/backends/base/sqlglot/datatypes.py b/ibis/backends/base/sqlglot/datatypes.py index bb9e6447d093d..dd9827d2a5498 100644 --- a/ibis/backends/base/sqlglot/datatypes.py +++ b/ibis/backends/base/sqlglot/datatypes.py @@ -942,3 +942,90 @@ def _from_sqlglot_MAP(cls) -> sge.DataType: @classmethod def _from_sqlglot_STRUCT(cls) -> sge.DataType: raise com.UnsupportedBackendType("SQL Server does not support structs") + + +class ClickHouseType(SqlglotType): + dialect = "clickhouse" + default_decimal_precision = None + default_decimal_scale = None + default_nullable = False + + unknown_type_strings = FrozenDict( + { + "ipv4": dt.INET(nullable=default_nullable), + "ipv6": dt.INET(nullable=default_nullable), + "object('json')": dt.JSON(nullable=default_nullable), + "array(null)": dt.Array(dt.null, nullable=default_nullable), + "array(nothing)": dt.Array(dt.null, nullable=default_nullable), + } + ) + + @classmethod + def from_ibis(cls, dtype: dt.DataType) -> sge.DataType: + """Convert a sqlglot type to an ibis type.""" + typ = super().from_ibis(dtype) + if dtype.nullable and not dtype.is_map(): + # map cannot be nullable in clickhouse + return sge.DataType(this=typecode.NULLABLE, expressions=[typ]) + else: + return typ + + @classmethod + def _from_sqlglot_NULLABLE(cls, inner_type: sge.DataType) -> dt.DataType: + return cls.to_ibis(inner_type, nullable=True) + + @classmethod + def _from_sqlglot_DATETIME( + cls, timezone: sge.DataTypeParam | None = None + ) -> dt.Timestamp: + return dt.Timestamp( + scale=0, + timezone=None if timezone is None else timezone.this.this, + nullable=cls.default_nullable, + ) + + @classmethod + def _from_sqlglot_DATETIME64( + cls, + scale: sge.DataTypeSize | None = None, + timezone: sge.Literal | None = None, + ) -> dt.Timestamp: + return dt.Timestamp( + timezone=None if timezone is None else timezone.this.this, + scale=int(scale.this.this), + nullable=cls.default_nullable, + ) + + @classmethod + def _from_sqlglot_LOWCARDINALITY(cls, inner_type: sge.DataType) -> dt.DataType: + return cls.to_ibis(inner_type) + + @classmethod + def _from_sqlglot_NESTED(cls, *fields: sge.DataType) -> dt.Struct: + fields = { + field.name: dt.Array( + cls.to_ibis(field.args["kind"]), nullable=cls.default_nullable + ) + for field in fields + } + return dt.Struct(fields, nullable=cls.default_nullable) + + @classmethod + def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType: + if dtype.timezone is None: + timezone = None + else: + timezone = sge.DataTypeParam(this=sge.Literal.string(dtype.timezone)) + + if dtype.scale is None: + return sge.DataType(this=typecode.DATETIME, expressions=[timezone]) + else: + scale = sge.DataTypeParam(this=sge.Literal.number(dtype.scale)) + return sge.DataType(this=typecode.DATETIME64, expressions=[scale, timezone]) + + @classmethod + def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType: + # key cannot be nullable in clickhouse + key_type = cls.from_ibis(dtype.key_type.copy(nullable=False)) + value_type = cls.from_ibis(dtype.value_type) + return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type]) diff --git a/ibis/backends/base/sqlglot/dialects.py b/ibis/backends/base/sqlglot/dialects.py new file mode 100644 index 0000000000000..74a603b268ba8 --- /dev/null +++ b/ibis/backends/base/sqlglot/dialects.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +import contextlib + +import sqlglot.expressions as sge +from sqlglot import transforms +from sqlglot.dialects import ( + TSQL, + ClickHouse, + Hive, + MySQL, + Oracle, + Postgres, + Snowflake, + Spark, + SQLite, + Trino, +) +from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func + +ClickHouse.Generator.TRANSFORMS |= { + sge.ArraySize: rename_func("length"), + sge.ArraySort: rename_func("arraySort"), + sge.LogicalAnd: rename_func("min"), + sge.LogicalOr: rename_func("max"), +} + + +class DataFusion(Postgres): + class Generator(Postgres.Generator): + TRANSFORMS = Postgres.Generator.TRANSFORMS.copy() | { + sge.Select: transforms.preprocess([transforms.eliminate_qualify]), + sge.Pow: rename_func("pow"), + sge.IsNan: rename_func("isnan"), + sge.CurrentTimestamp: rename_func("now"), + sge.Split: rename_func("string_to_array"), + sge.Array: rename_func("make_array"), + sge.ArrayContains: rename_func("array_has"), + sge.ArraySize: rename_func("array_length"), + } + + +class Druid(Postgres): + class Generator(Postgres.Generator): + TRANSFORMS = Postgres.Generator.TRANSFORMS.copy() | { + sge.ApproxDistinct: rename_func("approx_count_distinct"), + sge.Pow: rename_func("power"), + } + + +def _interval(self, e, quote_arg=True): + """Work around the inability to handle string literals in INTERVAL syntax.""" + arg = e.args["this"].this + with contextlib.suppress(AttributeError): + arg = arg.sql(self.dialect) + + if quote_arg: + arg = f"'{arg}'" + + return f"INTERVAL {arg} {e.args['unit']}" + + +class Exasol(Postgres): + class Generator(Postgres.Generator): + TRANSFORMS = Postgres.Generator.TRANSFORMS.copy() | {sge.Interval: _interval} + TYPE_MAPPING = Postgres.Generator.TYPE_MAPPING.copy() | { + sge.DataType.Type.TIMESTAMPTZ: "TIMESTAMP WITH LOCAL TIME ZONE", + } + + +class Flink(Hive): + class Generator(Hive.Generator): + TRANSFORMS = Hive.Generator.TRANSFORMS.copy() | { + sge.Stddev: rename_func("stddev_samp"), + sge.StddevPop: rename_func("stddev_pop"), + sge.StddevSamp: rename_func("stddev_samp"), + sge.Variance: rename_func("var_samp"), + sge.VariancePop: rename_func("var_pop"), + sge.Array: ( + lambda self, + e: f"ARRAY[{', '.join(arg.sql(self.dialect) for arg in e.expressions)}]" + ), + sge.ArrayConcat: rename_func("array_concat"), + sge.Length: rename_func("char_length"), + } + + +class Impala(Hive): + class Generator(Hive.Generator): + TRANSFORMS = Hive.Generator.TRANSFORMS.copy() | { + sge.ApproxDistinct: rename_func("ndv"), + sge.IsNan: rename_func("is_nan"), + sge.IsInf: rename_func("is_inf"), + sge.DayOfWeek: rename_func("dayofweek"), + sge.Interval: lambda self, e: _interval(self, e, quote_arg=False), + } + + +class MSSQL(TSQL): + class Generator(TSQL.Generator): + TRANSFORMS = TSQL.Generator.TRANSFORMS.copy() | { + sge.ApproxDistinct: rename_func("approx_count_distinct"), + sge.Stddev: rename_func("stdevp"), + sge.StddevPop: rename_func("stdevp"), + sge.StddevSamp: rename_func("stdev"), + sge.Variance: rename_func("var"), + sge.VariancePop: rename_func("varp"), + sge.Ceil: rename_func("ceiling"), + sge.Trim: lambda self, e: f"TRIM({e.this.sql(self.dialect)})", + sge.DateFromParts: rename_func("datefromparts"), + } + + +MySQL.Generator.TRANSFORMS |= { + sge.LogicalOr: rename_func("max"), + sge.LogicalAnd: rename_func("min"), + sge.VariancePop: rename_func("var_pop"), + sge.Variance: rename_func("var_samp"), + sge.Stddev: rename_func("stddev_pop"), + sge.StddevPop: rename_func("stddev_pop"), + sge.StddevSamp: rename_func("stddev_samp"), + sge.RegexpLike: ( + lambda _, e: f"({e.this.sql('mysql')} RLIKE {e.expression.sql('mysql')})" + ), +} + + +def _create_sql(self, expression: sge.Create) -> str: + # TODO: should we use CREATE PRIVATE instead? That will set an implicit + # lower bound of Oracle 18c + properties = expression.args.get("properties") + temporary = any( + isinstance(prop, sge.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ) + + kind = expression.args["kind"] + if (obj := kind.upper()) in ("TABLE", "VIEW") and temporary: + if expression.expression: + return f"CREATE GLOBAL TEMPORARY {obj} {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}" + else: + # TODO: why does autocommit not work here? need to specify the ON COMMIT part... + return f"CREATE GLOBAL TEMPORARY {obj} {self.sql(expression, 'this')} ON COMMIT PRESERVE ROWS" + + return create_with_partitions_sql(self, expression) + + +Oracle.Generator.TRANSFORMS |= { + sge.LogicalOr: rename_func("max"), + sge.LogicalAnd: rename_func("min"), + sge.VariancePop: rename_func("var_pop"), + sge.Variance: rename_func("var_samp"), + sge.Stddev: rename_func("stddev_pop"), + sge.ApproxDistinct: rename_func("approx_count_distinct"), + sge.Create: _create_sql, + sge.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]), +} + +# TODO: can delete this after bumping sqlglot version > 20.9.0 +Oracle.Generator.TYPE_MAPPING |= { + sge.DataType.Type.TIMETZ: "TIME", + sge.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", +} +Oracle.Generator.TZ_TO_WITH_TIME_ZONE = True + + +class Polars(Postgres): + """Subclass of Postgres dialect for Polars. + + This is here to allow referring to the Postgres dialect as "polars" + """ + + +Postgres.Generator.TRANSFORMS |= { + sge.Map: rename_func("hstore"), + sge.Split: rename_func("string_to_array"), + sge.RegexpSplit: rename_func("regexp_split_to_array"), + sge.DateFromParts: rename_func("make_date"), + sge.ArraySize: rename_func("cardinality"), + sge.Pow: rename_func("pow"), +} + + +class PySpark(Spark): + """Subclass of Spark dialect for PySpark. + + This is here to allow referring to the Spark dialect as "pyspark" + """ + + +class RisingWave(Postgres): + # Need to disable timestamp precision + # No "or replace" allowed in create statements + # no "not null" clause for column constraints + + class Generator(Postgres.Generator): + SINGLE_STRING_INTERVAL = True + RENAME_TABLE_WITH_DB = False + LOCKING_READS_SUPPORTED = True + JOIN_HINTS = False + TABLE_HINTS = False + QUERY_HINTS = False + NVL2_SUPPORTED = False + PARAMETER_TOKEN = "$" + TABLESAMPLE_SIZE_IS_ROWS = False + TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" + SUPPORTS_SELECT_INTO = True + JSON_TYPE_REQUIRED_FOR_EXTRACTION = True + SUPPORTS_UNLOGGED_TABLES = True + + TYPE_MAPPING = Postgres.Generator.TYPE_MAPPING.copy() | { + sge.DataType.Type.TIMESTAMPTZ: "TIMESTAMPTZ" + } + + +Snowflake.Generator.TRANSFORMS |= { + sge.ApproxDistinct: rename_func("approx_count_distinct"), + sge.Levenshtein: rename_func("editdistance"), +} + +SQLite.Generator.TYPE_MAPPING |= {sge.DataType.Type.BOOLEAN: "BOOLEAN"} + + +# TODO(cpcloud): remove this hack once +# https://github.com/tobymao/sqlglot/issues/2735 is resolved +def make_cross_joins_explicit(node): + if not (node.kind or node.side): + node.args["kind"] = "CROSS" + return node + + +Trino.Generator.TRANSFORMS |= { + sge.BitwiseLeftShift: rename_func("bitwise_left_shift"), + sge.BitwiseRightShift: rename_func("bitwise_right_shift"), + sge.Join: transforms.preprocess([make_cross_joins_explicit]), +} diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py index cbfb77e10603f..43fe70af5c8a8 100644 --- a/ibis/backends/bigquery/__init__.py +++ b/ibis/backends/bigquery/__init__.py @@ -660,7 +660,7 @@ def raw_sql(self, query: str, params=None): for param, value in (params or {}).items() ] with contextlib.suppress(AttributeError): - query = query.sql(self.compiler.dialect) + query = query.sql(self.dialect) return self._execute(query, query_parameters=query_parameters) @property diff --git a/ibis/backends/bigquery/compiler.py b/ibis/backends/bigquery/compiler.py index c70660a7f739c..ada8ddf901c9d 100644 --- a/ibis/backends/bigquery/compiler.py +++ b/ibis/backends/bigquery/compiler.py @@ -7,6 +7,7 @@ import sqlglot as sg import sqlglot.expressions as sge +from sqlglot.dialects import BigQuery import ibis.common.exceptions as com import ibis.expr.datatypes as dt @@ -28,7 +29,7 @@ class BigQueryCompiler(SQLGlotCompiler): - dialect = "bigquery" + dialect = BigQuery type_mapper = BigQueryType udf_type_mapper = BigQueryUDFType rewrites = ( diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index 8142110397e34..9e6cde739f1b4 100644 --- a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -27,7 +27,6 @@ from ibis.backends.base.sqlglot import SQLGlotBackend from ibis.backends.base.sqlglot.compiler import C from ibis.backends.clickhouse.compiler import ClickHouseCompiler -from ibis.backends.clickhouse.datatypes import ClickhouseType if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping @@ -221,13 +220,14 @@ def _normalize_external_tables(self, external_tables=None) -> ExternalData | Non """Merge registered external tables with any new external tables.""" external_data = ExternalData() n = 0 + type_mapper = self.compiler.type_mapper for name, obj in (external_tables or {}).items(): n += 1 if not (schema := obj.schema): raise TypeError(f"Schema is empty for external table {name}") structure = [ - f"{name} {ClickhouseType.to_string(typ.copy(nullable=not typ.is_nested()))}" + f"{name} {type_mapper.to_string(typ.copy(nullable=not typ.is_nested()))}" for name, typ in schema.items() ] external_data.add_file( @@ -478,7 +478,9 @@ def get_schema( query = sge.Describe(this=sg.table(table_name, db=database)) with self._safe_raw_sql(query) as results: names, types, *_ = results.result_columns - return sch.Schema(dict(zip(names, map(ClickhouseType.from_string, types)))) + return sch.Schema( + dict(zip(names, map(self.compiler.type_mapper.from_string, types))) + ) def _metadata(self, query: str) -> sch.Schema: name = util.gen_name("clickhouse_metadata") @@ -490,7 +492,7 @@ def _metadata(self, query: str) -> sch.Schema: finally: with closing(self.raw_sql(f"DROP VIEW {name}")): pass - return zip(names, map(ClickhouseType.from_string, types)) + return zip(names, map(self.compiler.type_mapper.from_string, types)) def create_database( self, name: str, *, force: bool = False, engine: str = "Atomic" @@ -637,7 +639,8 @@ def create_table( this=sg.table(name, db=database), expressions=[ sge.ColumnDef( - this=sg.to_identifier(name), kind=ClickhouseType.from_ibis(typ) + this=sg.to_identifier(name), + kind=self.compiler.type_mapper.from_ibis(typ), ) for name, typ in schema.items() ], diff --git a/ibis/backends/clickhouse/compiler.py b/ibis/backends/clickhouse/compiler.py index b561d841f8bda..78dd25b012d2a 100644 --- a/ibis/backends/clickhouse/compiler.py +++ b/ibis/backends/clickhouse/compiler.py @@ -7,9 +7,6 @@ import sqlglot as sg import sqlglot.expressions as sge -from sqlglot import exp -from sqlglot.dialects import ClickHouse -from sqlglot.dialects.dialect import rename_func import ibis.common.exceptions as com import ibis.expr.datatypes as dt @@ -21,22 +18,16 @@ SQLGlotCompiler, parenthesize, ) +from ibis.backends.base.sqlglot.datatypes import ClickHouseType +from ibis.backends.base.sqlglot.dialects import ClickHouse from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter -from ibis.backends.clickhouse.datatypes import ClickhouseType - -ClickHouse.Generator.TRANSFORMS |= { - exp.ArraySize: rename_func("length"), - exp.ArraySort: rename_func("arraySort"), - exp.LogicalAnd: rename_func("min"), - exp.LogicalOr: rename_func("max"), -} class ClickHouseCompiler(SQLGlotCompiler): __slots__ = () - dialect = "clickhouse" - type_mapper = ClickhouseType + dialect = ClickHouse + type_mapper = ClickHouseType rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) def _aggregate(self, funcname: str, *args, where): diff --git a/ibis/backends/clickhouse/datatypes.py b/ibis/backends/clickhouse/datatypes.py deleted file mode 100644 index 7250bafd4c0fb..0000000000000 --- a/ibis/backends/clickhouse/datatypes.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import annotations - -from typing import Literal - -import sqlglot.expressions as sge - -import ibis -import ibis.expr.datatypes as dt -from ibis.backends.base.sqlglot.datatypes import SqlglotType -from ibis.common.collections import FrozenDict - -typecode = sge.DataType.Type - - -# TODO(kszucs): add a bool converter method to support different clickhouse bool types -def _bool_type() -> Literal["Bool", "UInt8", "Int8"]: - return getattr(getattr(ibis.options, "clickhouse", None), "bool_type", "Bool") - - -class ClickhouseType(SqlglotType): - dialect = "clickhouse" - default_decimal_precision = None - default_decimal_scale = None - default_nullable = False - - unknown_type_strings = FrozenDict( - { - "ipv4": dt.INET(nullable=default_nullable), - "ipv6": dt.INET(nullable=default_nullable), - "object('json')": dt.JSON(nullable=default_nullable), - "array(null)": dt.Array(dt.null, nullable=default_nullable), - "array(nothing)": dt.Array(dt.null, nullable=default_nullable), - } - ) - - @classmethod - def from_ibis(cls, dtype: dt.DataType) -> sge.DataType: - """Convert a sqlglot type to an ibis type.""" - typ = super().from_ibis(dtype) - if dtype.nullable and not dtype.is_map(): - # map cannot be nullable in clickhouse - return sge.DataType(this=typecode.NULLABLE, expressions=[typ]) - else: - return typ - - @classmethod - def _from_sqlglot_NULLABLE(cls, inner_type: sge.DataType) -> dt.DataType: - return cls.to_ibis(inner_type, nullable=True) - - @classmethod - def _from_sqlglot_DATETIME( - cls, timezone: sge.DataTypeParam | None = None - ) -> dt.Timestamp: - return dt.Timestamp( - scale=0, - timezone=None if timezone is None else timezone.this.this, - nullable=cls.default_nullable, - ) - - @classmethod - def _from_sqlglot_DATETIME64( - cls, - scale: sge.DataTypeSize | None = None, - timezone: sge.Literal | None = None, - ) -> dt.Timestamp: - return dt.Timestamp( - timezone=None if timezone is None else timezone.this.this, - scale=int(scale.this.this), - nullable=cls.default_nullable, - ) - - @classmethod - def _from_sqlglot_LOWCARDINALITY(cls, inner_type: sge.DataType) -> dt.DataType: - return cls.to_ibis(inner_type) - - @classmethod - def _from_sqlglot_NESTED(cls, *fields: sge.DataType) -> dt.Struct: - fields = { - field.name: dt.Array( - cls.to_ibis(field.args["kind"]), nullable=cls.default_nullable - ) - for field in fields - } - return dt.Struct(fields, nullable=cls.default_nullable) - - @classmethod - def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType: - if dtype.timezone is None: - timezone = None - else: - timezone = sge.DataTypeParam(this=sge.Literal.string(dtype.timezone)) - - if dtype.scale is None: - return sge.DataType(this=typecode.DATETIME, expressions=[timezone]) - else: - scale = sge.DataTypeParam(this=sge.Literal.number(dtype.scale)) - return sge.DataType(this=typecode.DATETIME64, expressions=[scale, timezone]) - - @classmethod - def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType: - # key cannot be nullable in clickhouse - key_type = cls.from_ibis(dtype.key_type.copy(nullable=False)) - value_type = cls.from_ibis(dtype.value_type) - return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type]) diff --git a/ibis/backends/clickhouse/tests/test_datatypes.py b/ibis/backends/clickhouse/tests/test_datatypes.py index b5190a3d7d674..69c10c432798d 100644 --- a/ibis/backends/clickhouse/tests/test_datatypes.py +++ b/ibis/backends/clickhouse/tests/test_datatypes.py @@ -8,7 +8,7 @@ import ibis import ibis.expr.datatypes as dt import ibis.tests.strategies as its -from ibis.backends.clickhouse.datatypes import ClickhouseType +from ibis.backends.base.sqlglot.datatypes import ClickHouseType pytest.importorskip("clickhouse_connect") @@ -250,7 +250,7 @@ def test_array_discovery_clickhouse(con): ], ) def test_parse_type(ch_type, ibis_type): - parsed_ibis_type = ClickhouseType.from_string(ch_type) + parsed_ibis_type = ClickHouseType.from_string(ch_type) assert parsed_ibis_type == ibis_type @@ -286,6 +286,6 @@ def test_parse_type(ch_type, ibis_type): @h.given(roundtrippable_types) def test_type_roundtrip(ibis_type): - type_string = ClickhouseType.to_string(ibis_type) - parsed_ibis_type = ClickhouseType.from_string(type_string) + type_string = ClickHouseType.to_string(ibis_type) + parsed_ibis_type = ClickHouseType.from_string(type_string) assert parsed_ibis_type == ibis_type diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index bad07522d99fb..44946ac3abf51 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -46,8 +46,6 @@ class Backend(SQLGlotBackend, CanCreateDatabase, CanCreateSchema, NoUrl): name = "datafusion" - dialect = "datafusion" - builder = None supports_in_memory_tables = True supports_arrays = True compiler = DataFusionCompiler() diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index 84b4ae5dc3e03..173c34dbef0a8 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -7,9 +7,6 @@ import sqlglot as sg import sqlglot.expressions as sge -from sqlglot import exp, transforms -from sqlglot.dialects import Postgres -from sqlglot.dialects.dialect import rename_func import ibis.common.exceptions as com import ibis.expr.datatypes as dt @@ -22,32 +19,18 @@ paren, ) from ibis.backends.base.sqlglot.datatypes import DataFusionType +from ibis.backends.base.sqlglot.dialects import DataFusion from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter from ibis.common.temporal import IntervalUnit, TimestampUnit from ibis.expr.operations.udf import InputType from ibis.formats.pyarrow import PyArrowType -class DataFusion(Postgres): - class Generator(Postgres.Generator): - TRANSFORMS = Postgres.Generator.TRANSFORMS.copy() | { - exp.Select: transforms.preprocess([transforms.eliminate_qualify]), - exp.Pow: rename_func("pow"), - exp.IsNan: rename_func("isnan"), - exp.CurrentTimestamp: rename_func("now"), - exp.Split: rename_func("string_to_array"), - exp.Array: rename_func("make_array"), - exp.ArrayContains: rename_func("array_has"), - exp.ArraySize: rename_func("array_length"), - } - - class DataFusionCompiler(SQLGlotCompiler): __slots__ = () - dialect = "datafusion" + dialect = DataFusion type_mapper = DataFusionType - quoted = True rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) def _aggregate(self, funcname: str, *args, where): diff --git a/ibis/backends/druid/__init__.py b/ibis/backends/druid/__init__.py index fe68977c8a41a..411c83a4866d4 100644 --- a/ibis/backends/druid/__init__.py +++ b/ibis/backends/druid/__init__.py @@ -91,7 +91,7 @@ def do_connect(self, **kwargs: Any) -> None: @contextlib.contextmanager def _safe_raw_sql(self, query, *args, **kwargs): with contextlib.suppress(AttributeError): - query = query.sql(dialect=self.compiler.dialect) + query = query.sql(dialect=self.dialect) with contextlib.closing(self.con.cursor()) as cur: cur.execute(query, *args, **kwargs) @@ -116,7 +116,7 @@ def get_schema( name_type_pairs = self._metadata( sg.select(STAR) .from_(sg.table(table_name, db=schema, catalog=database)) - .sql(self.compiler.dialect) + .sql(self.dialect) ) return sch.Schema.from_tuples(name_type_pairs) diff --git a/ibis/backends/druid/compiler.py b/ibis/backends/druid/compiler.py index 31a60f7e85be4..c0661cd4cf715 100644 --- a/ibis/backends/druid/compiler.py +++ b/ibis/backends/druid/compiler.py @@ -5,35 +5,21 @@ import sqlglot as sg import sqlglot.expressions as sge import toolz -from sqlglot import exp -from sqlglot.dialects import Postgres -from sqlglot.dialects.dialect import rename_func import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.base.sqlglot.compiler import NULL, SQLGlotCompiler from ibis.backends.base.sqlglot.datatypes import DruidType +from ibis.backends.base.sqlglot.dialects import Druid from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter -# Is postgres the best dialect to inherit from? -class Druid(Postgres): - """The druid dialect.""" - - class Generator(Postgres.Generator): - TRANSFORMS = Postgres.Generator.TRANSFORMS.copy() | { - exp.ApproxDistinct: rename_func("approx_count_distinct"), - exp.Pow: rename_func("power"), - } - - class DruidCompiler(SQLGlotCompiler): __slots__ = () - dialect = "druid" + dialect = Druid type_mapper = DruidType - quoted = True rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) def _aggregate(self, funcname: str, *args, where): diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 221a3b281ba00..f760f0a21b63f 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -27,7 +27,7 @@ from ibis.backends.base.sqlglot import SQLGlotBackend from ibis.backends.base.sqlglot.compiler import STAR, C, F from ibis.backends.duckdb.compiler import DuckDBCompiler -from ibis.backends.duckdb.datatypes import DuckDBPandasData +from ibis.backends.duckdb.converter import DuckDBPandasData from ibis.expr.operations.udf import InputType if TYPE_CHECKING: diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index ccd1b6f2191f3..598018823aab5 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -6,6 +6,7 @@ import sqlglot as sg import sqlglot.expressions as sge from public import public +from sqlglot.dialects import DuckDB import ibis.common.exceptions as com import ibis.expr.datatypes as dt @@ -29,7 +30,7 @@ class DuckDBCompiler(SQLGlotCompiler): __slots__ = () - dialect = "duckdb" + dialect = DuckDB type_mapper = DuckDBType def _aggregate(self, funcname: str, *args, where): diff --git a/ibis/backends/duckdb/datatypes.py b/ibis/backends/duckdb/converter.py similarity index 100% rename from ibis/backends/duckdb/datatypes.py rename to ibis/backends/duckdb/converter.py diff --git a/ibis/backends/exasol/__init__.py b/ibis/backends/exasol/__init__.py index 39ad83ab8c188..f72b1d046dc02 100644 --- a/ibis/backends/exasol/__init__.py +++ b/ibis/backends/exasol/__init__.py @@ -131,7 +131,7 @@ def begin(self): @contextlib.contextmanager def _safe_raw_sql(self, query: str, *args, **kwargs): with contextlib.suppress(AttributeError): - query = query.sql(dialect=self.compiler.dialect) + query = query.sql(dialect=self.dialect) with self.begin() as cur: yield cur.execute(query, *args, **kwargs) @@ -165,7 +165,7 @@ def get_schema( table_name, db=schema, catalog=database, quoted=self.compiler.quoted ) ) - .sql(self.compiler.dialect) + .sql(self.dialect) ) return sch.Schema.from_tuples(name_type_pairs) @@ -180,7 +180,7 @@ def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: table = sg.table(util.gen_name("exasol_metadata"), quoted=self.compiler.quoted) - dialect = self.compiler.dialect + dialect = self.dialect create_view = sg.exp.Create( kind="VIEW", this=table, @@ -385,7 +385,7 @@ def drop_schema( ) drop_schema = sg.exp.Drop(kind="SCHEMA", this=name, exists=force) with self.begin() as con: - con.execute(drop_schema.sql(dialect=self.compiler.dialect)) + con.execute(drop_schema.sql(dialect=self.dialect)) def create_schema( self, name: str, database: str | None = None, force: bool = False @@ -397,7 +397,7 @@ def create_schema( create_schema = sg.exp.Create(kind="SCHEMA", this=name, exists=force) open_schema = self.current_schema with self.begin() as con: - con.execute(create_schema.sql(dialect=self.compiler.dialect)) + con.execute(create_schema.sql(dialect=self.dialect)) # Exasol implicitly opens the created schema, therefore we need to restore # the previous context. con.execute( diff --git a/ibis/backends/exasol/compiler.py b/ibis/backends/exasol/compiler.py index 071e597c6f05b..0be420b225b09 100644 --- a/ibis/backends/exasol/compiler.py +++ b/ibis/backends/exasol/compiler.py @@ -1,16 +1,15 @@ from __future__ import annotations -import contextlib from functools import singledispatchmethod import sqlglot.expressions as sge -from sqlglot.dialects import Postgres import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.base.sqlglot.compiler import NULL, SQLGlotCompiler from ibis.backends.base.sqlglot.datatypes import ExasolType +from ibis.backends.base.sqlglot.dialects import Exasol from ibis.backends.base.sqlglot.rewrites import ( exclude_unsupported_window_frame_from_ops, exclude_unsupported_window_frame_from_rank, @@ -20,35 +19,11 @@ ) -def _interval(self, e): - """Work around Exasol's inability to handle string literals in INTERVAL syntax.""" - arg = e.args["this"].this - with contextlib.suppress(AttributeError): - arg = arg.sql(self.dialect) - res = f"INTERVAL '{arg}' {e.args['unit']}" - return res - - -# Is postgres the best dialect to inherit from? -class Exasol(Postgres): - """The exasol dialect.""" - - class Generator(Postgres.Generator): - TRANSFORMS = Postgres.Generator.TRANSFORMS.copy() | { - sge.Interval: _interval, - } - - TYPE_MAPPING = Postgres.Generator.TYPE_MAPPING.copy() | { - sge.DataType.Type.TIMESTAMPTZ: "TIMESTAMP WITH LOCAL TIME ZONE", - } - - class ExasolCompiler(SQLGlotCompiler): __slots__ = () - dialect = "exasol" + dialect = Exasol type_mapper = ExasolType - quoted = True rewrites = ( rewrite_sample_as_filter, exclude_unsupported_window_frame_from_ops, diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py index 2ae904f505ba5..679c7680de62c 100644 --- a/ibis/backends/flink/__init__.py +++ b/ibis/backends/flink/__init__.py @@ -44,6 +44,11 @@ class Backend(BaseBackend, CanCreateDatabase, NoUrl): supports_temporary_tables = True supports_python_udfs = True + @property + def dialect(self): + # TODO: remove when ported to sqlglot + return self.compiler.dialect + def do_connect(self, table_env: TableEnvironment) -> None: """Create a Flink `Backend` for use with Ibis. diff --git a/ibis/backends/flink/compiler/core.py b/ibis/backends/flink/compiler/core.py index f5d4c37748000..be187ebdb664a 100644 --- a/ibis/backends/flink/compiler/core.py +++ b/ibis/backends/flink/compiler/core.py @@ -14,6 +14,7 @@ TableSetFormatter, ) from ibis.backends.base.sql.registry import quote_identifier +from ibis.backends.base.sqlglot.dialects import Flink from ibis.backends.flink.translator import FlinkExprTranslator @@ -96,6 +97,8 @@ class FlinkCompiler(Compiler): cheap_in_memory_tables = True + dialect = Flink + @classmethod def to_sql(cls, node, context=None, params=None): if isinstance(node, ir.Expr): diff --git a/ibis/backends/impala/__init__.py b/ibis/backends/impala/__init__.py index 1c8833d208978..9be53bd1532e3 100644 --- a/ibis/backends/impala/__init__.py +++ b/ibis/backends/impala/__init__.py @@ -69,8 +69,6 @@ class Backend(SQLGlotBackend): supports_in_memory_tables = True - _sqlglot_dialect = "hive" # not 100% accurate, but very close - class Options(ibis.config.Config): """Impala specific options. @@ -269,7 +267,7 @@ def _fetch_from_cursor(self, cursor, schema): def _safe_raw_sql(self, query: str | DDL | DML): if not isinstance(query, str): try: - query = query.sql(dialect=self.compiler.dialect) + query = query.sql(dialect=self.dialect) except AttributeError: query = query.compile() @@ -284,7 +282,7 @@ def _safe_exec_sql(self, *args, **kwargs): def _fully_qualified_name(self, name, database): database = database or self.current_database return sg.table(name, db=database, quoted=self.compiler.quoted).sql( - self.compiler.dialect + self.dialect ) @property diff --git a/ibis/backends/impala/compiler.py b/ibis/backends/impala/compiler.py index 73840a47cc447..f20f48095a5c1 100644 --- a/ibis/backends/impala/compiler.py +++ b/ibis/backends/impala/compiler.py @@ -1,12 +1,9 @@ from __future__ import annotations -import contextlib from functools import singledispatchmethod import sqlglot as sg import sqlglot.expressions as sge -from sqlglot.dialects import Hive -from sqlglot.dialects.dialect import rename_func import ibis.common.exceptions as com import ibis.expr.datatypes as dt @@ -14,6 +11,7 @@ from ibis import util from ibis.backends.base.sqlglot.compiler import NULL, STAR, SQLGlotCompiler from ibis.backends.base.sqlglot.datatypes import ImpalaType +from ibis.backends.base.sqlglot.dialects import Impala from ibis.backends.base.sqlglot.rewrites import ( rewrite_empty_order_by_window, rewrite_first_to_first_value, @@ -22,30 +20,10 @@ ) -def _interval(self, e): - """Work around Impala's inability to handle string literals in INTERVAL syntax.""" - arg = e.args["this"].this - with contextlib.suppress(AttributeError): - arg = arg.sql(self.dialect) - res = f"INTERVAL {arg} {e.args['unit']}" - return res - - -class Impala(Hive): - class Generator(Hive.Generator): - TRANSFORMS = Hive.Generator.TRANSFORMS.copy() | { - sge.ApproxDistinct: rename_func("ndv"), - sge.IsNan: rename_func("is_nan"), - sge.IsInf: rename_func("is_inf"), - sge.DayOfWeek: rename_func("dayofweek"), - sge.Interval: _interval, - } - - class ImpalaCompiler(SQLGlotCompiler): __slots__ = () - dialect = "impala" + dialect = Impala type_mapper = ImpalaType rewrites = ( rewrite_sample_as_filter, @@ -54,7 +32,6 @@ class ImpalaCompiler(SQLGlotCompiler): rewrite_empty_order_by_window, *SQLGlotCompiler.rewrites, ) - quoted = True def _aggregate(self, funcname: str, *args, where): if where is not None: diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_day_of_week/index/out.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_day_of_week/index/out.sql index 9a26667a72fab..3a6964d17de46 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_day_of_week/index/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_day_of_week/index/out.sql @@ -1,2 +1,2 @@ SELECT - PMOD(DAY_OF_WEEK('2015-09-01T01:00:23') - 2, 7) AS `DayOfWeekIndex(datetime.datetime(2015, 9, 1, 1, 0, 23))` \ No newline at end of file + PMOD(DAYOFWEEK('2015-09-01T01:00:23') - 2, 7) AS `DayOfWeekIndex(datetime.datetime(2015, 9, 1, 1, 0, 23))` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/days/out1.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/days/out1.sql index 7237d7b9004d4..b141f4686f278 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/days/out1.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/days/out1.sql @@ -1,3 +1,3 @@ SELECT - CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL '5' DAY AS TIMESTAMP) AS `TimestampAdd(i, 5D)` + CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL 5 DAY AS TIMESTAMP) AS `TimestampAdd(i, 5D)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/days/out2.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/days/out2.sql index eb169d991d790..f9070d84c6657 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/days/out2.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/days/out2.sql @@ -1,3 +1,3 @@ SELECT - `t0`.`i` - INTERVAL '5' DAY AS `TimestampSub(i, 5D)` + `t0`.`i` - INTERVAL 5 DAY AS `TimestampSub(i, 5D)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/hours/out1.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/hours/out1.sql index 328d2ef5a820a..d7e1723c5a49f 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/hours/out1.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/hours/out1.sql @@ -1,3 +1,3 @@ SELECT - CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL '5' HOUR AS TIMESTAMP) AS `TimestampAdd(i, 5h)` + CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL 5 HOUR AS TIMESTAMP) AS `TimestampAdd(i, 5h)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/hours/out2.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/hours/out2.sql index 86951b1a52ab0..d8600b726ebda 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/hours/out2.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/hours/out2.sql @@ -1,3 +1,3 @@ SELECT - `t0`.`i` - INTERVAL '5' HOUR AS `TimestampSub(i, 5h)` + `t0`.`i` - INTERVAL 5 HOUR AS `TimestampSub(i, 5h)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/minutes/out1.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/minutes/out1.sql index a97e1fda3fb7a..a3c3b3497cd6d 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/minutes/out1.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/minutes/out1.sql @@ -1,3 +1,3 @@ SELECT - CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL '5' MINUTE AS TIMESTAMP) AS `TimestampAdd(i, 5m)` + CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL 5 MINUTE AS TIMESTAMP) AS `TimestampAdd(i, 5m)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/minutes/out2.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/minutes/out2.sql index 9646ba677e2f0..1ddb2eb2cdd0a 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/minutes/out2.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/minutes/out2.sql @@ -1,3 +1,3 @@ SELECT - `t0`.`i` - INTERVAL '5' MINUTE AS `TimestampSub(i, 5m)` + `t0`.`i` - INTERVAL 5 MINUTE AS `TimestampSub(i, 5m)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/months/out1.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/months/out1.sql index 2f2d23da1686e..cfb1da6d1fbdf 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/months/out1.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/months/out1.sql @@ -1,3 +1,3 @@ SELECT - CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL '5' MONTH AS TIMESTAMP) AS `TimestampAdd(i, 5M)` + CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL 5 MONTH AS TIMESTAMP) AS `TimestampAdd(i, 5M)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/months/out2.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/months/out2.sql index fc1911b370cb6..0a33736acef4f 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/months/out2.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/months/out2.sql @@ -1,3 +1,3 @@ SELECT - `t0`.`i` - INTERVAL '5' MONTH AS `TimestampSub(i, 5M)` + `t0`.`i` - INTERVAL 5 MONTH AS `TimestampSub(i, 5M)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/seconds/out1.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/seconds/out1.sql index 9fb8b0686200b..7012f22bb9a30 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/seconds/out1.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/seconds/out1.sql @@ -1,3 +1,3 @@ SELECT - CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL '5' SECOND AS TIMESTAMP) AS `TimestampAdd(i, 5s)` + CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL 5 SECOND AS TIMESTAMP) AS `TimestampAdd(i, 5s)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/seconds/out2.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/seconds/out2.sql index 8f89c57c0e983..8bd856c63fe80 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/seconds/out2.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/seconds/out2.sql @@ -1,3 +1,3 @@ SELECT - `t0`.`i` - INTERVAL '5' SECOND AS `TimestampSub(i, 5s)` + `t0`.`i` - INTERVAL 5 SECOND AS `TimestampSub(i, 5s)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/weeks/out1.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/weeks/out1.sql index 2ee339b0a84ee..3bea56a27a7c7 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/weeks/out1.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/weeks/out1.sql @@ -1,3 +1,3 @@ SELECT - CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL '5' WEEK AS TIMESTAMP) AS `TimestampAdd(i, 5W)` + CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL 5 WEEK AS TIMESTAMP) AS `TimestampAdd(i, 5W)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/weeks/out2.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/weeks/out2.sql index 92b9cfeda5bac..bcc3e272b39be 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/weeks/out2.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/weeks/out2.sql @@ -1,3 +1,3 @@ SELECT - `t0`.`i` - INTERVAL '5' WEEK AS `TimestampSub(i, 5W)` + `t0`.`i` - INTERVAL 5 WEEK AS `TimestampSub(i, 5W)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/years/out1.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/years/out1.sql index b0bbfd8bd6ca9..92d9010f536fd 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/years/out1.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/years/out1.sql @@ -1,3 +1,3 @@ SELECT - CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL '5' YEAR AS TIMESTAMP) AS `TimestampAdd(i, 5Y)` + CAST(CAST(`t0`.`i` AS TIMESTAMP) + INTERVAL 5 YEAR AS TIMESTAMP) AS `TimestampAdd(i, 5Y)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/years/out2.sql b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/years/out2.sql index 0778cd8b068ef..64c3c19a6c816 100644 --- a/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/years/out2.sql +++ b/ibis/backends/impala/tests/snapshots/test_value_exprs/test_timestamp_deltas/years/out2.sql @@ -1,3 +1,3 @@ SELECT - `t0`.`i` - INTERVAL '5' YEAR AS `TimestampSub(i, 5Y)` + `t0`.`i` - INTERVAL 5 YEAR AS `TimestampSub(i, 5Y)` FROM `alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/mssql/__init__.py b/ibis/backends/mssql/__init__.py index 9f7f2315c5f64..bbedea3c0e3aa 100644 --- a/ibis/backends/mssql/__init__.py +++ b/ibis/backends/mssql/__init__.py @@ -126,7 +126,7 @@ def get_schema( meta = cur.fetchall() if not meta: - fqn = sg.table(name, db=schema, catalog=database).sql(self.compiler.dialect) + fqn = sg.table(name, db=schema, catalog=database).sql(self.dialect) raise com.IbisError(f"Table not found: {fqn}") mapping = {} @@ -154,7 +154,7 @@ def get_schema( return sch.Schema(mapping) def _metadata(self, query) -> Iterator[tuple[str, dt.DataType]]: - tsql = sge.convert(str(query)).sql(self.compiler.dialect) + tsql = sge.convert(str(query)).sql(self.dialect) query = f"EXEC sp_describe_first_result_set @tsql = N{tsql}" with self._safe_raw_sql(query) as cur: rows = cur.fetchall() @@ -221,7 +221,7 @@ def begin(self): @contextlib.contextmanager def _safe_raw_sql(self, query, *args, **kwargs): with contextlib.suppress(AttributeError): - query = query.sql(self.compiler.dialect) + query = query.sql(self.dialect) with self.begin() as cur: cur.execute(query, *args, **kwargs) @@ -229,7 +229,7 @@ def _safe_raw_sql(self, query, *args, **kwargs): def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: with contextlib.suppress(AttributeError): - query = query.sql(self.compiler.dialect) + query = query.sql(self.dialect) con = self.con cursor = con.cursor() @@ -295,7 +295,7 @@ def create_schema( cur.execute(f"USE {self._quote(current_database)}") def _quote(self, name: str): - return sg.to_identifier(name, quoted=True).sql(self.compiler.dialect) + return sg.to_identifier(name, quoted=True).sql(self.dialect) def drop_schema( self, name: str, database: str | None = None, force: bool = False @@ -342,7 +342,7 @@ def list_tables( if conditions: sql = sql.where(*conditions) - sql = sql.sql(self.compiler.dialect) + sql = sql.sql(self.dialect) with self._safe_raw_sql(sql) as cur: out = cur.fetchall() @@ -432,19 +432,15 @@ def create_table( raw_this = sg.table(name, catalog=database, quoted=False) with self._safe_raw_sql(create_stmt) as cur: if query is not None: - insert_stmt = sge.Insert(this=table, expression=query).sql( - self.compiler.dialect - ) + insert_stmt = sge.Insert(this=table, expression=query).sql(self.dialect) cur.execute(insert_stmt) if overwrite: cur.execute( - sge.Drop(kind="TABLE", this=this, exists=True).sql( - self.compiler.dialect - ) + sge.Drop(kind="TABLE", this=this, exists=True).sql(self.dialect) ) - old = raw_table.sql(self.compiler.dialect) - new = raw_this.sql(self.compiler.dialect) + old = raw_table.sql(self.dialect) + new = raw_this.sql(self.dialect) cur.execute(f"EXEC sp_rename '{old}', '{new}'") if schema is None: @@ -494,14 +490,14 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: df = op.data.to_frame() data = df.itertuples(index=False) cols = ", ".join( - ident.sql(self.compiler.dialect) + ident.sql(self.dialect) for ident in map( partial(sg.to_identifier, quoted=quoted), schema.keys() ) ) specs = ", ".join(repeat("?", len(schema))) table = sg.table(name, quoted=quoted) - sql = f"INSERT INTO {table.sql(self.compiler.dialect)} ({cols}) VALUES ({specs})" + sql = f"INSERT INTO {table.sql(self.dialect)} ({cols}) VALUES ({specs})" with self._safe_raw_sql(create_stmt) as cur: if not df.empty: diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index 604250ba2b3c7..cb74911200d7f 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -6,8 +6,6 @@ import sqlglot as sg import sqlglot.expressions as sge from public import public -from sqlglot.dialects import TSQL -from sqlglot.dialects.dialect import rename_func import ibis.common.exceptions as com import ibis.expr.datatypes as dt @@ -21,6 +19,7 @@ paren, ) from ibis.backends.base.sqlglot.datatypes import MSSQLType +from ibis.backends.base.sqlglot.dialects import MSSQL from ibis.backends.base.sqlglot.rewrites import ( exclude_unsupported_window_frame_from_ops, exclude_unsupported_window_frame_from_row_number, @@ -30,24 +29,6 @@ ) from ibis.common.deferred import var - -class MSSQL(TSQL): - class Generator(TSQL.Generator): - pass - - -MSSQL.Generator.TRANSFORMS |= { - sge.ApproxDistinct: rename_func("approx_count_distinct"), - sge.Stddev: rename_func("stdevp"), - sge.StddevPop: rename_func("stdevp"), - sge.StddevSamp: rename_func("stdev"), - sge.Variance: rename_func("var"), - sge.VariancePop: rename_func("varp"), - sge.Ceil: rename_func("ceiling"), - sge.Trim: lambda self, e: f"TRIM({e.this.sql(self.dialect)})", - sge.DateFromParts: rename_func("datefromparts"), -} - y = var("y") start = var("start") end = var("end") @@ -70,7 +51,7 @@ class Generator(TSQL.Generator): class MSSQLCompiler(SQLGlotCompiler): __slots__ = () - dialect = "mssql" + dialect = MSSQL type_mapper = MSSQLType rewrites = ( rewrite_sample_as_filter, @@ -80,7 +61,6 @@ class MSSQLCompiler(SQLGlotCompiler): exclude_unsupported_window_frame_from_row_number, *SQLGlotCompiler.rewrites, ) - quoted = True @property def NAN(self): diff --git a/ibis/backends/mysql/compiler.py b/ibis/backends/mysql/compiler.py index 2277de215fb68..641a49a34b1e1 100644 --- a/ibis/backends/mysql/compiler.py +++ b/ibis/backends/mysql/compiler.py @@ -6,14 +6,13 @@ import sqlglot as sg import sqlglot.expressions as sge from public import public -from sqlglot.dialects import MySQL -from sqlglot.dialects.dialect import rename_func import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.base.sqlglot.compiler import NULL, STAR, SQLGlotCompiler from ibis.backends.base.sqlglot.datatypes import MySQLType +from ibis.backends.base.sqlglot.dialects import MySQL from ibis.backends.base.sqlglot.rewrites import ( exclude_unsupported_window_frame_from_ops, exclude_unsupported_window_frame_from_rank, @@ -26,19 +25,6 @@ from ibis.common.patterns import replace from ibis.expr.rewrites import p -MySQL.Generator.TRANSFORMS |= { - sge.LogicalOr: rename_func("max"), - sge.LogicalAnd: rename_func("min"), - sge.VariancePop: rename_func("var_pop"), - sge.Variance: rename_func("var_samp"), - sge.Stddev: rename_func("stddev_pop"), - sge.StddevPop: rename_func("stddev_pop"), - sge.StddevSamp: rename_func("stddev_samp"), - sge.RegexpLike: ( - lambda _, e: f"({e.this.sql('mysql')} RLIKE {e.expression.sql('mysql')})" - ), -} - @replace(p.Limit) def rewrite_limit(_, **kwargs): @@ -62,7 +48,7 @@ def rewrite_limit(_, **kwargs): class MySQLCompiler(SQLGlotCompiler): __slots__ = () - dialect = "mysql" + dialect = MySQL type_mapper = MySQLType rewrites = ( rewrite_limit, @@ -75,7 +61,6 @@ class MySQLCompiler(SQLGlotCompiler): rewrite_empty_order_by_window, *SQLGlotCompiler.rewrites, ) - quoted = True @property def NAN(self): diff --git a/ibis/backends/oracle/compiler.py b/ibis/backends/oracle/compiler.py index ad398e7e7d2a7..0f3e0472fac50 100644 --- a/ibis/backends/oracle/compiler.py +++ b/ibis/backends/oracle/compiler.py @@ -6,13 +6,12 @@ import sqlglot.expressions as sge import toolz from public import public -from sqlglot.dialects import Oracle -from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func import ibis.common.exceptions as com import ibis.expr.operations as ops from ibis.backends.base.sqlglot.compiler import NULL, STAR, SQLGlotCompiler from ibis.backends.base.sqlglot.datatypes import OracleType +from ibis.backends.base.sqlglot.dialects import Oracle from ibis.backends.base.sqlglot.rewrites import ( Window, exclude_unsupported_window_frame_from_ops, @@ -26,51 +25,11 @@ ) -def _create_sql(self, expression: sge.Create) -> str: - # TODO: should we use CREATE PRIVATE instead? That will set an implicit - # lower bound of Oracle 18c - properties = expression.args.get("properties") - temporary = any( - isinstance(prop, sge.TemporaryProperty) - for prop in (properties.expressions if properties else []) - ) - - kind = expression.args["kind"] - if (obj := kind.upper()) in ("TABLE", "VIEW") and temporary: - if expression.expression: - return f"CREATE GLOBAL TEMPORARY {obj} {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}" - else: - # TODO: why does autocommit not work here? need to specify the ON COMMIT part... - return f"CREATE GLOBAL TEMPORARY {obj} {self.sql(expression, 'this')} ON COMMIT PRESERVE ROWS" - - return create_with_partitions_sql(self, expression) - - -Oracle.Generator.TRANSFORMS |= { - sge.LogicalOr: rename_func("max"), - sge.LogicalAnd: rename_func("min"), - sge.VariancePop: rename_func("var_pop"), - sge.Variance: rename_func("var_samp"), - sge.Stddev: rename_func("stddev_pop"), - sge.ApproxDistinct: rename_func("approx_count_distinct"), - sge.Create: _create_sql, - sge.Select: sg.transforms.preprocess([sg.transforms.eliminate_semi_and_anti_joins]), -} - -# TODO: can delete this after bumping sqlglot version > 20.9.0 -Oracle.Generator.TYPE_MAPPING |= { - sge.DataType.Type.TIMETZ: "TIME", - sge.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", -} -Oracle.Generator.TZ_TO_WITH_TIME_ZONE = True - - @public class OracleCompiler(SQLGlotCompiler): __slots__ = () - dialect = "oracle" - quoted = True + dialect = Oracle type_mapper = OracleType rewrites = ( exclude_unsupported_window_frame_from_row_number, diff --git a/ibis/backends/pandas/__init__.py b/ibis/backends/pandas/__init__.py index e738ae8018354..4d3cddb665c0d 100644 --- a/ibis/backends/pandas/__init__.py +++ b/ibis/backends/pandas/__init__.py @@ -26,6 +26,7 @@ class BasePandasBackend(BaseBackend, NoUrl): """Base class for backends based on pandas.""" name = "pandas" + dialect = None backend_table_type = pd.DataFrame class Options(ibis.config.Config): diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index 373e68b1012b7..ef3c9b09ecee5 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -13,6 +13,7 @@ import ibis.expr.schema as sch import ibis.expr.types as ir from ibis.backends.base import BaseBackend, Database, NoUrl +from ibis.backends.base.sqlglot.dialects import Polars from ibis.backends.pandas.rewrites import ( bind_unbound_table, replace_parameter, @@ -31,9 +32,7 @@ class Backend(BaseBackend, NoUrl): name = "polars" - builder = None - - _sqlglot_dialect = "postgres" + dialect = Polars def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/ibis/backends/postgres/__init__.py b/ibis/backends/postgres/__init__.py index de9210b0ffb14..0c546f00985f3 100644 --- a/ibis/backends/postgres/__init__.py +++ b/ibis/backends/postgres/__init__.py @@ -44,7 +44,6 @@ def _verify_source_line(func_name: str, line: str): class Backend(SQLGlotBackend): name = "postgres" - dialect = "postgres" compiler = PostgresCompiler() supports_python_udfs = True diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index af640cdd0f989..b421ea51050a5 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -6,8 +6,6 @@ import sqlglot as sg import sqlglot.expressions as sge from public import public -from sqlglot.dialects import Postgres -from sqlglot.dialects.dialect import rename_func import ibis.common.exceptions as com import ibis.expr.datatypes as dt @@ -15,17 +13,9 @@ import ibis.expr.rules as rlz from ibis.backends.base.sqlglot.compiler import NULL, STAR, SQLGlotCompiler, paren from ibis.backends.base.sqlglot.datatypes import PostgresType +from ibis.backends.base.sqlglot.dialects import Postgres from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter -Postgres.Generator.TRANSFORMS |= { - sge.Map: rename_func("hstore"), - sge.Split: rename_func("string_to_array"), - sge.RegexpSplit: rename_func("regexp_split_to_array"), - sge.DateFromParts: rename_func("make_date"), - sge.ArraySize: rename_func("cardinality"), - sge.Pow: rename_func("pow"), -} - class PostgresUDFNode(ops.Value): shape = rlz.shape_like("args") @@ -35,10 +25,9 @@ class PostgresUDFNode(ops.Value): class PostgresCompiler(SQLGlotCompiler): __slots__ = () - dialect = "postgres" + dialect = Postgres type_mapper = PostgresType rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) - quoted = True NAN = sge.Literal.number("'NaN'::double precision") POS_INF = sge.Literal.number("'Inf'::double precision") diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index b186502f4cfed..e77839fa939fc 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -88,7 +88,6 @@ def __exit__(self, exc_type, exc_value, traceback): class Backend(SQLGlotBackend, CanCreateDatabase): name = "pyspark" compiler = PySparkCompiler() - _sqlglot_dialect = "spark" class Options(ibis.config.Config): """PySpark options. @@ -245,7 +244,7 @@ def _safe_raw_sql(self, query: str) -> _PySparkCursor: def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> _PySparkCursor: with contextlib.suppress(AttributeError): - query = query.sql(dialect=self._sqlglot_dialect) + query = query.sql(dialect=self.dialect) query = self._session.sql(query) return _PySparkCursor(query) @@ -482,7 +481,7 @@ def compute_stats( """ maybe_noscan = " NOSCAN" * noscan table = sg.table(name, db=database, quoted=self.compiler.quoted).sql( - dialect=self._sqlglot_dialect + dialect=self.dialect ) return self.raw_sql(f"ANALYZE TABLE {table} COMPUTE STATISTICS{maybe_noscan}") diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index bdb68f968ea50..9bf9bd1ffe583 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -15,6 +15,7 @@ import ibis.expr.operations as ops from ibis.backends.base.sqlglot.compiler import FALSE, NULL, STAR, TRUE, SQLGlotCompiler from ibis.backends.base.sqlglot.datatypes import PySparkType +from ibis.backends.base.sqlglot.dialects import PySpark from ibis.backends.base.sqlglot.rewrites import Window, p from ibis.common.patterns import replace from ibis.config import options @@ -48,10 +49,9 @@ def offset_to_filter(_): class PySparkCompiler(SQLGlotCompiler): __slots__ = () - dialect = "spark" + dialect = PySpark type_mapper = PySparkType rewrites = (offset_to_filter, *SQLGlotCompiler.rewrites) - quoted = True def _aggregate(self, funcname: str, *args, where): func = self.f[funcname] diff --git a/ibis/backends/risingwave/__init__.py b/ibis/backends/risingwave/__init__.py index 43191503b90ef..7016917aca1ec 100644 --- a/ibis/backends/risingwave/__init__.py +++ b/ibis/backends/risingwave/__init__.py @@ -19,22 +19,14 @@ from ibis import util from ibis.backends.postgres import Backend as PostgresBackend from ibis.backends.risingwave.compiler import RisingwaveCompiler -from ibis.backends.risingwave.dialect import RisingWave as RisingWaveDialect if TYPE_CHECKING: import pandas as pd import pyarrow as pa -def _verify_source_line(func_name: str, line: str): - if line.startswith("@"): - raise com.InvalidDecoratorError(func_name, line) - return line - - class Backend(PostgresBackend): name = "risingwave" - dialect = RisingWaveDialect compiler = RisingwaveCompiler() supports_python_udfs = False diff --git a/ibis/backends/risingwave/compiler.py b/ibis/backends/risingwave/compiler.py index 5bc7bfef2f5bc..6aeb21ec9ddbb 100644 --- a/ibis/backends/risingwave/compiler.py +++ b/ibis/backends/risingwave/compiler.py @@ -10,16 +10,15 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.base.sqlglot.datatypes import RisingWaveType +from ibis.backends.base.sqlglot.dialects import RisingWave from ibis.backends.postgres.compiler import PostgresCompiler -from ibis.backends.risingwave.dialect import RisingWave # noqa: F401 @public class RisingwaveCompiler(PostgresCompiler): __slots__ = () - dialect = "risingwave" - name = "risingwave" + dialect = RisingWave type_mapper = RisingWaveType @singledispatchmethod @@ -30,7 +29,7 @@ def visit_node(self, op, **kwargs): def visit_Correlation(self, op, *, left, right, how, where): if how == "sample": raise com.UnsupportedOperationError( - f"{self.name} only implements `pop` correlation coefficient" + "RisingWave only implements `pop` correlation coefficient" ) return super().visit_Correlation( op, left=left, right=right, how=how, where=where diff --git a/ibis/backends/risingwave/dialect.py b/ibis/backends/risingwave/dialect.py deleted file mode 100644 index 2237c2a4d188c..0000000000000 --- a/ibis/backends/risingwave/dialect.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -import sqlglot.expressions as sge -from sqlglot import generator -from sqlglot.dialects import Postgres - - -class RisingWave(Postgres): - # Need to disable timestamp precision - # No "or replace" allowed in create statements - # no "not null" clause for column constraints - - class Generator(generator.Generator): - SINGLE_STRING_INTERVAL = True - RENAME_TABLE_WITH_DB = False - LOCKING_READS_SUPPORTED = True - JOIN_HINTS = False - TABLE_HINTS = False - QUERY_HINTS = False - NVL2_SUPPORTED = False - PARAMETER_TOKEN = "$" - TABLESAMPLE_SIZE_IS_ROWS = False - TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" - SUPPORTS_SELECT_INTO = True - JSON_TYPE_REQUIRED_FOR_EXTRACTION = True - SUPPORTS_UNLOGGED_TABLES = True - - TYPE_MAPPING = { - **Postgres.Generator.TYPE_MAPPING, - sge.DataType.Type.TIMESTAMPTZ: "TIMESTAMPTZ", - } - - TRANSFORMS = { - **Postgres.Generator.TRANSFORMS, - } diff --git a/ibis/backends/snowflake/compiler.py b/ibis/backends/snowflake/compiler.py index da44da4a87ea4..395d223fee7a4 100644 --- a/ibis/backends/snowflake/compiler.py +++ b/ibis/backends/snowflake/compiler.py @@ -6,9 +6,6 @@ import sqlglot as sg import sqlglot.expressions as sge from public import public -from sqlglot import exp -from sqlglot.dialects import Snowflake -from sqlglot.dialects.dialect import rename_func import ibis.common.exceptions as com import ibis.expr.datatypes as dt @@ -16,6 +13,7 @@ from ibis import util from ibis.backends.base.sqlglot.compiler import NULL, C, FuncGen, SQLGlotCompiler from ibis.backends.base.sqlglot.datatypes import SnowflakeType +from ibis.backends.base.sqlglot.dialects import Snowflake from ibis.backends.base.sqlglot.rewrites import ( exclude_unsupported_window_frame_from_ops, exclude_unsupported_window_frame_from_row_number, @@ -28,11 +26,6 @@ from ibis.common.patterns import replace from ibis.expr.analysis import p -Snowflake.Generator.TRANSFORMS |= { - exp.ApproxDistinct: rename_func("approx_count_distinct"), - exp.Levenshtein: rename_func("editdistance"), -} - @replace(p.ToJSONMap | p.ToJSONArray) def replace_to_json(_): @@ -47,8 +40,7 @@ class SnowflakeFuncGen(FuncGen): class SnowflakeCompiler(SQLGlotCompiler): __slots__ = () - dialect = "snowflake" - quoted = True + dialect = Snowflake type_mapper = SnowflakeType no_limit_value = NULL rewrites = ( diff --git a/ibis/backends/sqlite/compiler.py b/ibis/backends/sqlite/compiler.py index f02a77b9b3618..52b76a3b9ca4b 100644 --- a/ibis/backends/sqlite/compiler.py +++ b/ibis/backends/sqlite/compiler.py @@ -5,13 +5,13 @@ import sqlglot as sg import sqlglot.expressions as sge from public import public -from sqlglot.dialects.sqlite import SQLite import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.base.sqlglot.compiler import NULL, SQLGlotCompiler from ibis.backends.base.sqlglot.datatypes import SQLiteType +from ibis.backends.base.sqlglot.dialects import SQLite from ibis.backends.base.sqlglot.rewrites import ( rewrite_first_to_first_value, rewrite_last_to_last_value, @@ -19,17 +19,12 @@ ) from ibis.common.temporal import DateUnit, IntervalUnit -SQLite.Generator.TYPE_MAPPING |= { - sge.DataType.Type.BOOLEAN: "BOOLEAN", -} - @public class SQLiteCompiler(SQLGlotCompiler): __slots__ = () - dialect = "sqlite" - quoted = True + dialect = SQLite type_mapper = SQLiteType rewrites = ( rewrite_sample_as_filter, diff --git a/ibis/backends/tests/test_dot_sql.py b/ibis/backends/tests/test_dot_sql.py index 6e333b4a0e2b7..d1c5e2e140ca7 100644 --- a/ibis/backends/tests/test_dot_sql.py +++ b/ibis/backends/tests/test_dot_sql.py @@ -10,7 +10,20 @@ import ibis import ibis.common.exceptions as com from ibis import _ -from ibis.backends.base import _IBIS_TO_SQLGLOT_DIALECT, _get_backend_names +from ibis.backends.base import _get_backend_names + +# here to load the dialect in to sqlglot so we can use it for transpilation +from ibis.backends.base.sqlglot.dialects import ( # noqa: F401 + MSSQL, + DataFusion, + Druid, + Exasol, + Flink, + Impala, + Polars, + PySpark, + RisingWave, +) from ibis.backends.tests.errors import ( GoogleBadRequest, OracleDatabaseError, @@ -45,12 +58,11 @@ def test_con_dot_sql(backend, con, schema): # pull out the quoted name name = _NAMES.get(con.name, "functional_alltypes") quoted = True - dialect = _IBIS_TO_SQLGLOT_DIALECT.get(con.name, con.name) cols = [ - sg.column("string_col", quoted=quoted).as_("s", quoted=quoted).sql(dialect), + sg.column("string_col", quoted=quoted).as_("s", quoted=quoted).sql(con.dialect), (sg.column("double_col", quoted=quoted) + 1.0) .as_("new_col", quoted=quoted) - .sql(dialect), + .sql(con.dialect), ] t = ( con.sql( @@ -252,7 +264,6 @@ def test_table_dot_sql_transpile(backend, alltypes, dialect, df): name = "foo2" foo = alltypes.select(x=_.bigint_col + 1).alias(name) expr = sg.select(sg.column("x", quoted=True)).from_(sg.table(name, quoted=True)) - dialect = _IBIS_TO_SQLGLOT_DIALECT.get(dialect, dialect) sqlstr = expr.sql(dialect=dialect, pretty=True) dot_sql_expr = foo.sql(sqlstr, dialect=dialect) result = dot_sql_expr.execute() @@ -278,7 +289,6 @@ def test_con_dot_sql_transpile(backend, con, dialect, df): foo = sg.select( sg.alias(sg.column("bigint_col", quoted=True) + 1, "x", quoted=True) ).from_(t) - dialect = _IBIS_TO_SQLGLOT_DIALECT.get(dialect, dialect) sqlstr = foo.sql(dialect=dialect, pretty=True) expr = con.sql(sqlstr, dialect=dialect) result = expr.execute() diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py index 97dd852217440..bb4a5e402ab03 100644 --- a/ibis/backends/trino/compiler.py +++ b/ibis/backends/trino/compiler.py @@ -6,14 +6,13 @@ import sqlglot as sg import sqlglot.expressions as sge import toolz -from sqlglot.dialects import Trino -from sqlglot.dialects.dialect import rename_func import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.base.sqlglot.compiler import FALSE, NULL, SQLGlotCompiler, paren from ibis.backends.base.sqlglot.datatypes import TrinoType +from ibis.backends.base.sqlglot.dialects import Trino from ibis.backends.base.sqlglot.rewrites import ( exclude_unsupported_window_frame_from_ops, rewrite_first_to_first_value, @@ -22,25 +21,10 @@ ) -# TODO(cpcloud): remove this hack once -# https://github.com/tobymao/sqlglot/issues/2735 is resolved -def make_cross_joins_explicit(node): - if not (node.kind or node.side): - node.args["kind"] = "CROSS" - return node - - -Trino.Generator.TRANSFORMS |= { - sge.BitwiseLeftShift: rename_func("bitwise_left_shift"), - sge.BitwiseRightShift: rename_func("bitwise_right_shift"), - sge.Join: sg.transforms.preprocess([make_cross_joins_explicit]), -} - - class TrinoCompiler(SQLGlotCompiler): __slots__ = () - dialect = "trino" + dialect = Trino type_mapper = TrinoType rewrites = ( rewrite_sample_as_filter, diff --git a/ibis/expr/sql.py b/ibis/expr/sql.py index d3b1c7e833425..8eb5a4e57ee5a 100644 --- a/ibis/expr/sql.py +++ b/ibis/expr/sql.py @@ -370,14 +370,14 @@ def to_sql(expr: ir.Expr, dialect: str | None = None, **kwargs) -> SQLString: read = "duckdb" write = ibis.options.sql.default_dialect else: - read = write = getattr(backend, "_sqlglot_dialect", backend.name) + read = write = backend.dialect else: try: backend = getattr(ibis, dialect) except AttributeError: raise ValueError(f"Unknown dialect {dialect}") else: - read = write = getattr(backend, "_sqlglot_dialect", dialect) + read = write = getattr(backend, "dialect", dialect) sql = backend._to_sql(expr.unbind(), **kwargs) (pretty,) = sg.transpile(sql, read=read, write=write, pretty=True) diff --git a/ibis/tests/expr/mocks.py b/ibis/tests/expr/mocks.py index d6344621893be..7a83adfe70604 100644 --- a/ibis/tests/expr/mocks.py +++ b/ibis/tests/expr/mocks.py @@ -16,6 +16,8 @@ import contextlib +from sqlglot.dialects import DuckDB + import ibis.expr.operations as ops import ibis.expr.types as ir from ibis.backends.base import BaseBackend @@ -27,6 +29,7 @@ class MockBackend(BaseBackend): name = "mock" version = "1.0" current_database = "mockdb" + dialect = DuckDB def __init__(self): super().__init__() diff --git a/pyproject.toml b/pyproject.toml index f57748d50e064..6fdf826666cae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -270,6 +270,7 @@ filterwarnings = [ "ignore:Concatenating dataframes with unknown divisions:UserWarning", "ignore:Possible nested set at position:FutureWarning", 'ignore:\s+You did not provide metadata:UserWarning', + "ignore:Minimal version of pyarrow will soon be increased:FutureWarning", # numpy by way of dask 'ignore:np\.find_common_type is deprecated:DeprecationWarning', # pandas