From 13f4567f3e1f21e87f3a70acaba02e3548711f6c Mon Sep 17 00:00:00 2001
From: Gil Forsyth
Date: Mon, 10 Jul 2023 16:58:53 -0400
Subject: [PATCH 001/222] feat(duckdb): initial cut of sqlglot DuckDB compiler

it's alive! tests run (and fail)

chore(duckdb): naive port of clickhouse compiler
fix(duckdb): hacky fix for output shape
feat(duckdb): bitwise ops (most of them)
feat(duckdb): handle pandas dtype mapping in execute
feat(duckdb): handle decimal types
feat(duckdb): add Euler's number
test(duckdb): remove duckdb from alchemycon
feat(duckdb): get _most_ of string ops working

still some failures in re_extract

feat(duckdb): add hash
feat(duckdb): add CAST
feat(duckdb): add cot and strright
chore(duckdb): mark all the targets that still need attention (at least)
feat(duckdb): combine binary bitwise ops
chore(datestuff): some datetime ops
feat(duckdb): add levenshtein, use op.dtype instead of output_dtype
feat(duckdb): add blank list_schemas, use old current_database for now
feat(duckdb): basic interval ops
feat(duckdb): timestamp and temporal ops
feat(duckdb): use pyarrow for fetching execute results
feat(duckdb): handle interval casts, broken for columns
feat(duckdb): shove literal handling up top
feat(duckdb): more timestamp ops
feat(duckdb): back to pandas output in execute
feat(duckdb): timezone handling in cast
feat(duckdb): ms and us epoch timestamp support
chore(duckdb): misc cleanup
feat(duckdb): initial create table
feat(duckdb): add _from_url
feat(duckdb): add read_parquet
feat(duckdb): add persistent cache
fix(duckdb): actually insert data if present in create_table
feat(duckdb): use duckdb API read_parquet
feat(duckdb): add read_csv

This, frustratingly, cannot use the Python API for `read_csv`, since that
does not support a list of files, for some reason.
fix(duckdb): don't fully qualify the table names
chore(duckdb): cleanup
chore(duckdb): mark broken test broken
fix(duckdb): fix read_parquet so it works
feat(duckdb): add to_pyarrow, to_pyarrow_batches, sql()
feat(duckdb): null checking
feat(duckdb): translate uints
fix(duckdb): fix file outputs and torch output
fix(duckdb): add rest of integer types
fix(duckdb): ops.InValues
feat(duckdb): use sqlglot expressions (maybe a big mistake)
fix(duckdb): don't stringify strings
feat(duckdb): use sqlglot expr instead of strings for count
fix(duckdb): fix isin
fix(duckdb): fix some agg variance functions
fix(duckdb): for logical equals, use sqlglot not operator
fix(duckdb): struct not tuple for struct type
---
 ibis/backends/conftest.py                  |    1 -
 ibis/backends/duckdb/__init__.py           |  714 ++++++----
 ibis/backends/duckdb/compiler.py           |   67 -
 ibis/backends/duckdb/compiler/__init__.py  |   13 +
 ibis/backends/duckdb/compiler/core.py      |   95 ++
 ibis/backends/duckdb/compiler/relations.py |  220 +++
 ibis/backends/duckdb/compiler/values.py    | 1484 ++++++++++++++++++++
 ibis/backends/duckdb/datatypes.py          |  182 ++-
 ibis/backends/duckdb/tests/conftest.py     |    5 +
 ibis/backends/tests/test_client.py         |    3 +
 ibis/backends/tests/test_numeric.py        |   21 +-
 ibis/backends/tests/test_string.py         |    5 +
 ibis/backends/tests/test_temporal.py       |   16 +-
 13 files changed, 2403 insertions(+), 423 deletions(-)
 delete mode 100644 ibis/backends/duckdb/compiler.py
 create mode 100644 ibis/backends/duckdb/compiler/__init__.py
 create mode 100644 ibis/backends/duckdb/compiler/core.py
 create mode 100644 ibis/backends/duckdb/compiler/relations.py
 create mode 100644 ibis/backends/duckdb/compiler/values.py

diff --git a/ibis/backends/conftest.py b/ibis/backends/conftest.py
index 4f54107643e1..1913fe51d6b6 100644
--- a/ibis/backends/conftest.py
+++ b/ibis/backends/conftest.py
@@ -543,7 +543,6 @@ def ddl_con(ddl_backend):
 @pytest.fixture(
     params=_get_backends_to_test(
         keep=(
-            "duckdb",
             "mssql",
             "mysql",
             "oracle",
diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py
index d1e2b34fdf69..e3b6ad7e554a 100644
--- a/ibis/backends/duckdb/__init__.py
+++ b/ibis/backends/duckdb/__init__.py
@@ -7,17 +7,15 @@
 import os
 import warnings
 from pathlib import Path
-from typing import (
-    TYPE_CHECKING,
-    Any,
-)
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, MutableMapping
 
 import duckdb
 import pyarrow as pa
-import sqlalchemy as sa
+import sqlglot as sg
 import toolz
 from packaging.version import parse as vparse
 
+import ibis
 import ibis.common.exceptions as exc
 import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
@@ -25,16 +23,17 @@
 import ibis.expr.types as ir
 from ibis import util
 from ibis.backends.base import CanCreateSchema
-from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
-from ibis.backends.duckdb.compiler import DuckDBSQLCompiler
-from ibis.backends.duckdb.datatypes import DuckDBType
+from ibis.backends.base.sql import BaseBackend
+from ibis.backends.duckdb.compiler import translate
+from ibis.backends.duckdb.datatypes import parse, serialize, DuckDBType
 from ibis.expr.operations.relations import PandasDataFrameProxy
 from ibis.expr.operations.udf import InputType
+from ibis.formats.pyarrow import PyArrowData
 from ibis.formats.pandas import PandasData
 
 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence
-
+    import ibis.expr.operations as ops
     import pandas as pd
     import torch
@@ -68,46 +67,208 @@ def _format_kwargs(kwargs: Mapping[str, Any]):
 }
 
 
-class
Backend(BaseAlchemyBackend, CanCreateSchema): +class DuckDBTable(ir.Table): + """References a physical table in DuckDB.""" + + @property + def _client(self): + return self.op().source + + @property + def name(self): + return self.op().name + + +class Backend(BaseBackend, CanCreateSchema): name = "duckdb" - compiler = DuckDBSQLCompiler supports_create_or_replace = True + def _define_udf_translation_rules(self, expr): + # TODO: + ... + + def _register_udfs(self, expr): + # TODO: + ... + @property def current_database(self) -> str: - return self._scalar_query(sa.select(sa.func.current_database())) + return "main" + + @property + def current_schema(self) -> str: + return self.raw_sql("SELECT current_schema()") + + def raw_sql(self, query: str, **kwargs: Any) -> Any: + return self.con.execute(query, **kwargs) + + def create_table( + self, + name: str, + obj: pd.DataFrame | pa.Table | ir.Table | None = None, + *, + schema: ibis.Schema | None = None, + database: str | None = None, + temp: bool = False, + overwrite: bool = False, + ): + tmp = "TEMP " * temp + replace = "OR REPLACE" * overwrite + + if temp and overwrite: + raise exc.IbisInputError("Cannot specify both temp and overwrite") + + if not temp: + table = self._fully_qualified_name(name, database) + else: + table = name + database = None + code = f"CREATE {replace}{tmp}TABLE {table}" + + if obj is None and schema is None: + raise exc.IbisError("The schema or obj parameter is required") + + if obj is not None and not isinstance(obj, ir.Expr): + obj = ibis.memtable(obj, schema=schema) + self._register_in_memory_table(obj.op()) + code += f" AS {self.compile(obj)}" + else: + # If both `obj` and `schema` are specified, `obj` overrides `schema` + # DuckDB doesn't support `create table (schema) AS select * ...` + if obj is not None: + code += f" AS {self.compile(obj)}" + else: + serialized_schema = ", ".join( + f"{name} {serialize(typ)}" for name, typ in schema.items() + ) + + code += f" ({serialized_schema})" + + # create the table + self.raw_sql(code) + + return self.table(name, database=database) + + def create_view( + self, + name: str, + obj: ir.Table, + *, + database: str | None = None, + overwrite: bool = False, + ) -> ir.Table: + qualname = self._fully_qualified_name(name, database) + replace = "OR REPLACE " * overwrite + query = self.compile(obj) + code = f"CREATE {replace}VIEW {qualname} AS {query}" + self.raw_sql(code) + + return self.table(name, database=database) + + def drop_table( + self, name: str, database: str | None = None, force: bool = False + ) -> None: + ident = self._fully_qualified_name(name, database) + self.raw_sql(f"DROP TABLE {'IF EXISTS ' * force}{ident}") + + def drop_view( + self, name: str, *, database: str | None = None, force: bool = False + ) -> None: + name = self._fully_qualified_name(name, database) + if_exists = "IF EXISTS " * force + self.raw_sql(f"DROP VIEW {if_exists}{name}") + + def _load_into_cache(self, name, expr): + self.create_table(name, expr, schema=expr.schema(), temp=True) + + def _clean_up_cached_table(self, op): + self.drop_table(op.name) + + def list_schemas(self): + ... + + def table(self, name: str, database: str | None = None) -> ir.Table: + """Construct a table expression. 
+ + Parameters + ---------- + name + Table name + database + Database name + + Returns + ------- + Table + Table expression + """ + schema = self.get_schema(name, database=database) + qname = self._fully_qualified_name(name, database) + return DuckDBTable(ops.DatabaseTable(qname, schema, self)) + + def _fully_qualified_name(self, name: str, database: str | None) -> str: + return name + # TODO: make this less bad + # calls to here from `drop_table` already have `main` prepended to the table name + # so what's the more robust way to deduplicate that identifier? + db = database or self.current_database + if name.startswith(db): + # This is a hack to get around nested quoting of table name + # e.g. '"main._ibis_temp_table_2"' + return name + return sg.table(name, db=db).sql(dialect="duckdb") + + def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema: + """Return a Schema object for the indicated table and database. + + Parameters + ---------- + table_name + May **not** be fully qualified. Use `database` if you want to + qualify the identifier. + database + Database name + + Returns + ------- + sch.Schema + Ibis schema + """ + qualified_name = self._fully_qualified_name(table_name, database) + query = f"DESCRIBE {qualified_name}" + results = self.raw_sql(query) + names, types, *_ = results.fetch_arrow_table() + names = names.to_pylist() + types = types.to_pylist() + return sch.Schema(dict(zip(names, map(parse, types)))) def list_databases(self, like: str | None = None) -> list[str]: - s = sa.table( - "schemata", - sa.column("catalog_name", sa.TEXT()), - schema="information_schema", - ) + result = self.raw_sql("PRAGMA database_list;") + results = result.fetch_arrow_table() - query = sa.select(sa.distinct(s.c.catalog_name)) - with self.begin() as con: - results = list(con.execute(query).scalars()) - return self._filter_with_like(results, like=like) - - def list_schemas( - self, like: str | None = None, database: str | None = None - ) -> list[str]: - # override duckdb because all databases are always visible - text = """\ -SELECT schema_name -FROM information_schema.schemata -WHERE catalog_name = :database""" - query = sa.text(text).bindparams( - database=database if database is not None else self.current_database - ) + if results: + _, databases, *_ = results + databases = databases.to_pylist() + else: + databases = [] + return self._filter_with_like(databases, like) - with self.begin() as con: - schemas = list(con.execute(query).scalars()) - return self._filter_with_like(schemas, like=like) + def list_tables(self, like: str | None = None) -> list[str]: + result = self.raw_sql("PRAGMA show_tables;") + results = result.fetch_arrow_table() - @property - def current_schema(self) -> str: - return self._scalar_query(sa.select(sa.func.current_schema())) + if results: + tables, *_ = results + tables = tables.to_pylist() + else: + tables = [] + return self._filter_with_like(tables, like) + + @classmethod + def has_operation(cls, operation: type[ops.Value]) -> bool: + from ibis.backends.duckdb.compiler.values import translate_val + + return translate_val.dispatch(operation) is not translate_val.dispatch(object) @staticmethod def _convert_kwargs(kwargs: MutableMapping) -> None: @@ -126,47 +287,6 @@ def version(self) -> str: return importlib.metadata.version("duckdb") - @staticmethod - def _new_sa_metadata(): - meta = sa.MetaData() - - # _new_sa_metadata is invoked whenever `_get_sqla_table` is called, so - # it's safe to store columns as keys, that is, columns from different - # 
tables with the same name won't collide - complex_type_info_cache = {} - - @sa.event.listens_for(meta, "column_reflect") - def column_reflect(inspector, table, column_info): - import duckdb_engine.datatypes as ddt - - # duckdb_engine as of 0.7.2 doesn't expose the inner types of any - # complex types so we have to extract it from duckdb directly - ddt_struct_type = getattr(ddt, "Struct", sa.types.NullType) - ddt_map_type = getattr(ddt, "Map", sa.types.NullType) - if isinstance( - column_info["type"], (sa.ARRAY, ddt_struct_type, ddt_map_type) - ): - engine = inspector.engine - colname = column_info["name"] - if (coltype := complex_type_info_cache.get(colname)) is None: - quote = engine.dialect.identifier_preparer.quote - quoted_colname = quote(colname) - quoted_tablename = quote(table.name) - with engine.connect() as con: - # The .connection property is used to avoid creating a - # nested transaction - con.connection.execute( - f"DESCRIBE SELECT {quoted_colname} FROM {quoted_tablename}" - ) - _, typ, *_ = con.connection.fetchone() - complex_type_info_cache[colname] = coltype = DuckDBType.from_string( - typ - ) - - column_info["type"] = DuckDBType.from_ibis(coltype) - - return meta - def do_connect( self, database: str | Path = ":memory:", @@ -216,51 +336,111 @@ def do_connect( Path(temp_directory).mkdir(parents=True, exist_ok=True) config["temp_directory"] = str(temp_directory) - engine = sa.create_engine( - f"duckdb:///{database}", - connect_args=dict(read_only=read_only, config=config), - poolclass=sa.pool.StaticPool, - ) + import duckdb + + self.con = duckdb.connect(str(database)) - @sa.event.listens_for(engine, "connect") - def configure_connection(dbapi_connection, connection_record): - if extensions is not None: - self._sa_load_extensions(dbapi_connection, extensions) - dbapi_connection.execute("SET TimeZone = 'UTC'") - # the progress bar in duckdb <0.8.0 causes kernel crashes in - # jupyterlab, fixed in https://github.com/duckdb/duckdb/pull/6831 - if vparse(duckdb.__version__) < vparse("0.8.0"): - dbapi_connection.execute("SET enable_progress_bar = false") + # TODO: disable progress bar for < 0.8.0 + # TODO: set timezone to UTC self._record_batch_readers_consumed = {} - # TODO(cpcloud): remove this when duckdb is >0.8.1 - # this is here to workaround https://github.com/duckdb/duckdb/issues/8735 - with contextlib.suppress(duckdb.InvalidInputException): - duckdb.execute("SELECT ?", (1,)) + def _from_url(self, url: str, **kwargs) -> BaseBackend: + """Connect to a backend using a URL `url`. - super().do_connect(engine) + Parameters + ---------- + url + URL with which to connect to a backend. 
+ kwargs + Additional keyword arguments - @staticmethod - def _sa_load_extensions(dbapi_con, extensions): - query = """ - WITH exts AS ( - SELECT extension_name AS name, aliases FROM duckdb_extensions() - WHERE installed AND loaded + Returns + ------- + BaseBackend + A backend instance + """ + import sqlalchemy as sa + + url = sa.engine.make_url(url) + + kwargs = toolz.merge( + { + name: value + for name in ("database", "read_only", "temp_directory") + if (value := getattr(url, name, None)) + }, + kwargs, ) - SELECT name FROM exts - UNION (SELECT UNNEST(aliases) AS name FROM exts) + + kwargs.update(url.query) + self._convert_kwargs(kwargs) + return self.connect(**kwargs) + + def compile(self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any): + table_expr = expr.as_table() + + if limit == "default": + limit = ibis.options.sql.default_limit + if limit is not None: + table_expr = table_expr.limit(limit) + + if params is None: + params = {} + + sql = translate(table_expr.op(), params=params) + assert not isinstance(sql, sg.exp.Subquery) + + if isinstance(sql, sg.exp.Table): + sql = sg.select("*").from_(sql) + + assert not isinstance(sql, sg.exp.Subquery) + return sql.sql(dialect="duckdb", pretty=True) + + def _to_sql(self, expr: ir.Expr, **kwargs) -> str: + return str(self.compile(expr, **kwargs)) + + def _log(self, sql: str) -> None: + """Log `sql`. + + This method can be implemented by subclasses. Logging occurs when + `ibis.options.verbose` is `True`. """ - installed = (name for (name,) in dbapi_con.sql(query).fetchall()) - # Install and load all other extensions - todo = set(extensions).difference(installed) - for extension in todo: - dbapi_con.install_extension(extension) - dbapi_con.load_extension(extension) + util.log(sql) - def _load_extensions(self, extensions): - with self.begin() as con: - self._sa_load_extensions(con.connection, extensions) + def execute( + self, + expr: ir.Expr, + limit: str | None = "default", + external_tables: Mapping[str, pd.DataFrame] | None = None, + **kwargs: Any, + ) -> Any: + """Execute an expression.""" + + self._run_pre_execute_hooks(expr) + table = expr.as_table() + sql = self.compile(table, limit=limit, **kwargs) + + schema = table.schema() + self._log(sql) + + try: + result = self.con.execute(sql) + except duckdb.CatalogException as e: + raise exc.IbisError(e) + + # TODO: should we do this in arrow? + # also wth is pandas doing with dates? + pandas_df = result.fetch_df() + result = PandasData.convert_table(pandas_df, schema) + if isinstance(expr, ir.Table): + return result + elif isinstance(expr, ir.Column): + return result.iloc[:, 0] + elif isinstance(expr, ir.Scalar): + return result.iat[0, 0] + else: + raise ValueError def load_extension(self, extension: str) -> None: """Install and load a duckdb extension by name or path. 
@@ -272,6 +452,22 @@ def load_extension(self, extension: str) -> None: """ self._load_extensions([extension]) + def _load_extensions(self, extensions): + query = """ + WITH exts AS ( + SELECT extension_name AS name, aliases FROM duckdb_extensions() + WHERE installed AND loaded + ) + SELECT name FROM exts + UNION (SELECT UNNEST(aliases) AS name FROM exts) + """ + installed = (name for (name,) in self.con.sql(query).fetchall()) + # Install and load all other extensions + todo = set(extensions).difference(installed) + for extension in extensions: + self.con.install_extension(extension) + self.con.load_extension(extension) + def create_schema( self, name: str, database: str | None = None, force: bool = False ) -> None: @@ -279,7 +475,7 @@ def create_schema( raise exc.UnsupportedOperationError( "DuckDB cannot create a schema in another database." ) - name = self._quote(name) + # name = self._quote(name) if_not_exists = "IF NOT EXISTS " * force with self.begin() as con: con.exec_driver_sql(f"CREATE SCHEMA {if_not_exists}{name}") @@ -291,11 +487,26 @@ def drop_schema( raise exc.UnsupportedOperationError( "DuckDB cannot drop a schema in another database." ) - name = self._quote(name) + # name = self._quote(name) if_exists = "IF EXISTS " * force with self.begin() as con: con.exec_driver_sql(f"DROP SCHEMA {if_exists}{name}") + def sql( + self, + query: str, + schema: SupportsSchema | None = None, + dialect: str | None = None, + ) -> ir.Table: + query = self._transpile_sql(query, dialect=dialect) + if schema is None: + schema = self._get_schema_using_query(query) + return ops.SQLQueryResult(query, ibis.schema(schema), self).to_expr() + + def _get_schema_using_query(self, query: str) -> sch.Schema: + """Return an ibis Schema from a backend-specific SQL string.""" + return sch.Schema.from_tuples(self._metadata(query)) + def register( self, source: str | Path | Any, @@ -454,93 +665,12 @@ def read_csv( kwargs.setdefault("header", True) kwargs["auto_detect"] = kwargs.pop("auto_detect", "columns" not in kwargs) - source = sa.select(sa.literal_column("*")).select_from( - sa.func.read_csv(sa.func.list_value(*source_list), _format_kwargs(kwargs)) - ) + options = ", " + ",".join([f"{key}={val}" for key, val in kwargs.items()]) - view = self._compile_temp_view(table_name, source) - with self.begin() as con: - con.exec_driver_sql(view) - return self.table(table_name) - - def _get_sqla_table( - self, - name: str, - schema: str | None = None, - database: str | None = None, - **_: Any, - ) -> sa.Table: - if schema is None: - schema = self.current_schema - *db, schema = schema.split(".") - db = "".join(db) or database - ident = ".".join( - map( - self._quote, - filter(None, (db if db != self.current_database else None, schema)), - ) - ) + sql = f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_csv({source_list}{options})" - s = sa.table( - "columns", - sa.column("table_catalog", sa.TEXT()), - sa.column("table_schema", sa.TEXT()), - sa.column("table_name", sa.TEXT()), - sa.column("column_name", sa.TEXT()), - sa.column("data_type", sa.TEXT()), - sa.column("is_nullable", sa.TEXT()), - sa.column("ordinal_position", sa.INTEGER()), - schema="information_schema", - ) - - where = s.c.table_name == name - - if db: - where &= s.c.table_catalog == db - - if schema: - where &= s.c.table_schema == schema - - query = ( - sa.select( - s.c.column_name, - s.c.data_type, - (s.c.is_nullable == "YES").label("nullable"), - ) - .where(where) - .order_by(sa.asc(s.c.ordinal_position)) - ) - - with self.begin() as con: - # fetch 
metadata with pyarrow, it's much faster for wide tables - meta = con.execute(query).cursor.fetch_arrow_table() - - if not meta: - raise sa.exc.NoSuchTableError(name) - - names = meta["column_name"].to_pylist() - types = meta["data_type"].to_pylist() - nullables = meta["nullable"].to_pylist() - - ibis_schema = sch.Schema( - { - name: DuckDBType.from_string(typ, nullable=nullable) - for name, typ, nullable in zip(names, types, nullables) - } - ) - columns = self._columns_from_schema(name, ibis_schema) - return sa.table(name, *columns, schema=ident) - - def drop_table( - self, name: str, database: str | None = None, force: bool = False - ) -> None: - name = self._quote(name) - # TODO: handle database quoting - if database is not None: - name = f"{database}.{name}" - drop_stmt = "DROP TABLE" + (" IF EXISTS" * force) + f" {name}" - with self.begin() as con: - con.exec_driver_sql(drop_stmt) + self.raw_sql(sql) + return self.table(table_name) def read_parquet( self, @@ -574,13 +704,12 @@ def read_parquet( # Default to using the native duckdb parquet reader # If that fails because of auth issues, fall back to ingesting via # pyarrow dataset - try: - self._read_parquet_duckdb_native(source_list, table_name, **kwargs) - except sa.exc.OperationalError as e: - if isinstance(e.orig, duckdb.IOException): - self._read_parquet_pyarrow_dataset(source_list, table_name, **kwargs) - else: - raise e + self._read_parquet_duckdb_native(source_list, table_name, **kwargs) + # except sa.exc.OperationalError as e: + # if isinstance(e.orig, duckdb.IOException): + # self._read_parquet_pyarrow_dataset(source_list, table_name, **kwargs) + # else: + # raise e return self.table(table_name) @@ -593,14 +722,13 @@ def _read_parquet_duckdb_native( ): self._load_extensions(["httpfs"]) - source = sa.select(sa.literal_column("*")).select_from( - sa.func.read_parquet( - sa.func.list_value(*source_list), _format_kwargs(kwargs) - ) - ) - view = self._compile_temp_view(table_name, source) - with self.begin() as con: - con.exec_driver_sql(view) + options = "" + if kw := kwargs: + options = ", " + ",".join([f"{key}={val}" for key, val in kw.items()]) + + sql = f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet({source_list}{options})" + + self.raw_sql(sql) def _read_parquet_pyarrow_dataset( self, source_list: str | Iterable[str], table_name: str, **kwargs: Any @@ -693,15 +821,6 @@ def read_delta( delta_table.to_pyarrow_dataset(), table_name=table_name ) - def list_tables(self, like=None, database=None): - tables = self.inspector.get_table_names(schema=database) - views = self.inspector.get_view_names(schema=database) - # workaround for GH5503 - temp_views = self.inspector.get_view_names( - schema="temp" if database is None else database - ) - return self._filter_with_like(tables + views + temp_views, like) - def read_postgres( self, uri: str, table_name: str | None = None, schema: str = "public" ) -> ir.Table: @@ -871,8 +990,8 @@ def to_pyarrow_batches( ::: """ self._run_pre_execute_hooks(expr) - query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params) - sql = query_ast.compile() + table = expr.as_table() + sql = self.compile(table, limit=limit, params=params) # handle the argument name change in duckdb 0.8.0 fetch_record_batch = ( @@ -881,15 +1000,17 @@ def to_pyarrow_batches( else (lambda cur: cur.fetch_record_batch(chunk_size=chunk_size)) ) - def batch_producer(con): - with con.begin() as c, contextlib.closing(c.execute(sql)) as cur: - yield from fetch_record_batch(cur.cursor) + def 
batch_producer(table): + yield from fetch_record_batch(table) + # TODO: check that this is still handled correctly # batch_producer keeps the `self.con` member alive long enough to # exhaust the record batch reader, even if the backend or connection # have gone out of scope in the caller + table = self.raw_sql(sql) + return pa.RecordBatchReader.from_batches( - expr.as_table().schema().to_pyarrow(), batch_producer(self.con) + expr.as_table().schema().to_pyarrow(), batch_producer(table) ) def to_pyarrow( @@ -901,20 +1022,10 @@ def to_pyarrow( **_: Any, ) -> pa.Table: self._run_pre_execute_hooks(expr) - query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params) - - # We use `.sql` instead of `.execute` below for performance - in - # certain cases duckdb query -> arrow table can be significantly faster - # in this configuration. Currently `.sql` doesn't support parametrized - # queries, so we need to compile with literal_binds for now. - sql = str( - query_ast.compile().compile( - dialect=self.con.dialect, compile_kwargs={"literal_binds": True} - ) - ) + table = expr.as_table() + sql = self.compile(table, limit=limit, params=params) - with self.begin() as con: - table = con.connection.sql(sql).to_arrow_table() + table = self.raw_sql(sql).fetch_arrow_table() return expr.__pyarrow_result__(table) @@ -946,8 +1057,7 @@ def to_torch( A dictionary of torch tensors, keyed by column name. """ compiled = self.compile(expr, limit=limit, params=params, **kwargs) - with self._safe_raw_sql(compiled) as cur: - return cur.connection.connection.torch() + return self.raw_sql(compiled).torch() @util.experimental def to_parquet( @@ -1005,8 +1115,7 @@ def to_parquet( query = self._to_sql(expr, params=params) args = ["FORMAT 'parquet'", *(f"{k.upper()} {v!r}" for k, v in kwargs.items())] copy_cmd = f"COPY ({query}) TO {str(path)!r} ({', '.join(args)})" - with self.begin() as con: - con.exec_driver_sql(copy_cmd) + self.raw_sql(copy_cmd) @util.experimental def to_csv( @@ -1044,8 +1153,7 @@ def to_csv( *(f"{k.upper()} {v!r}" for k, v in kwargs.items()), ] copy_cmd = f"COPY ({query}) TO {str(path)!r} ({', '.join(args)})" - with self.begin() as con: - con.exec_driver_sql(copy_cmd) + self.raw_sql(copy_cmd) def fetch_from_cursor( self, cursor: duckdb.DuckDBPyConnection, schema: sch.Schema @@ -1074,15 +1182,21 @@ def fetch_from_cursor( return PandasData.convert_table(df, schema) def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]: - with self.begin() as con: - rows = con.exec_driver_sql(f"DESCRIBE {query}") + rows = self.raw_sql(f"DESCRIBE {query}").fetch_arrow_table() + + as_py = lambda val: val.as_py() + for name, type, null in zip( + map(as_py, rows["column_name"]), + map(as_py, rows["column_type"]), + map(as_py, rows["null"]), + ): + ibis_type = parse(type) + # ibis_type = DuckDBType.from_string(type, nullable=nullable) + yield name, ibis_type.copy(nullable=null.lower() == "yes") - for name, type, null in toolz.pluck( - ["column_name", "column_type", "null"], rows.mappings() - ): - nullable = null.lower() == "yes" - ibis_type = DuckDBType.from_string(type, nullable=nullable) - yield name, ibis_type + def _register_in_memory_tables(self, expr: ir.Expr) -> None: + for memtable in expr.op().find(ops.InMemoryTable): + self._register_in_memory_table(memtable) def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: # in theory we could use pandas dataframes, but when using dataframes @@ -1117,8 +1231,7 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: # 
register creates a transaction, and we can't nest transactions so # we create a function to encapsulate the whole shebang def _register(name, table): - with self.begin() as con: - con.connection.register(name, table) + self.con.register(name, table) try: _register(name, table) @@ -1131,36 +1244,37 @@ def _get_temp_view_definition( yield f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}" def _register_udfs(self, expr: ir.Expr) -> None: - import ibis.expr.operations as ops + ... + # import ibis.expr.operations as ops - with self.begin() as con: - for udf_node in expr.op().find(ops.ScalarUDF): - compile_func = getattr( - self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" - ) - with contextlib.suppress(duckdb.InvalidInputException): - con.connection.remove_function(udf_node.__class__.__name__) + # with self.begin() as con: + # for udf_node in expr.op().find(ops.ScalarUDF): + # compile_func = getattr( + # self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" + # ) + # with contextlib.suppress(duckdb.InvalidInputException): + # con.connection.remove_function(udf_node.__class__.__name__) - registration_func = compile_func(udf_node) - if registration_func is not None: - registration_func(con) + # registration_func = compile_func(udf_node) + # registration_func(con) def _compile_udf(self, udf_node: ops.ScalarUDF) -> None: - func = udf_node.__func__ - name = func.__name__ - input_types = [DuckDBType.to_string(arg.dtype) for arg in udf_node.args] - output_type = DuckDBType.to_string(udf_node.dtype) - - def register_udf(con): - return con.connection.create_function( - name, - func, - input_types, - output_type, - type=_UDF_INPUT_TYPE_MAPPING[udf_node.__input_type__], - ) - - return register_udf + ... + # func = udf_node.__func__ + # name = func.__name__ + # input_types = [DuckDBType.to_string(arg.dtype) for arg in udf_node.args] + # output_type = DuckDBType.to_string(udf_node.dtype) + + # def register_udf(con): + # return con.connection.create_function( + # name, + # func, + # input_types, + # output_type, + # type=_UDF_INPUT_TYPE_MAPPING[udf_node.__input_type__], + # ) + + # return register_udf _compile_python_udf = _compile_udf _compile_pyarrow_udf = _compile_udf @@ -1177,17 +1291,19 @@ def _get_compiled_statement(self, view: sa.Table, definition: sa.sql.Selectable) def _insert_dataframe( self, table_name: str, df: pd.DataFrame, overwrite: bool ) -> None: - columns = list(df.columns) - t = sa.table(table_name, *map(sa.column, columns)) + # TODO: reimplement + ... 
+ # columns = list(df.columns) + # t = sa.table(table_name, *map(sa.column, columns)) - table_name = self._quote(table_name) + # table_name = self._quote(table_name) - # the table name df here matters, and *must* match the input variable's - # name because duckdb will look up this name in the outer scope of the - # insert call and pull in that variable's data to scan - source = sa.table("df", *map(sa.column, columns)) + # # the table name df here matters, and *must* match the input variable's + # # name because duckdb will look up this name in the outer scope of the + # # insert call and pull in that variable's data to scan + # source = sa.table("df", *map(sa.column, columns)) - with self.begin() as con: - if overwrite: - con.execute(t.delete()) - con.execute(t.insert().from_select(columns, sa.select(source))) + # with self.begin() as con: + # if overwrite: + # con.execute(t.delete()) + # con.execute(t.insert().from_select(columns, sa.select(source))) diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py deleted file mode 100644 index caf2d6c266a6..000000000000 --- a/ibis/backends/duckdb/compiler.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -import sqlalchemy as sa -from sqlalchemy.ext.compiler import compiles - -import ibis.backends.base.sql.alchemy.datatypes as sat -import ibis.expr.operations as ops -from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator -from ibis.backends.duckdb.datatypes import DuckDBType -from ibis.backends.duckdb.registry import operation_registry - - -class DuckDBSQLExprTranslator(AlchemyExprTranslator): - _registry = operation_registry - _rewrites = AlchemyExprTranslator._rewrites.copy() - _has_reduction_filter_syntax = True - _supports_tuple_syntax = True - _dialect_name = "duckdb" - - type_mapper = DuckDBType - - -@compiles(sat.UInt8, "duckdb") -def compile_uint8(element, compiler, **kw): - return "UTINYINT" - - -@compiles(sat.UInt16, "duckdb") -def compile_uint16(element, compiler, **kw): - return "USMALLINT" - - -@compiles(sat.UInt32, "duckdb") -def compile_uint32(element, compiler, **kw): - return "UINTEGER" - - -@compiles(sat.UInt64, "duckdb") -def compile_uint(element, compiler, **kw): - return "UBIGINT" - - -@compiles(sat.ArrayType, "duckdb") -def compile_array(element, compiler, **kw): - if isinstance(value_type := element.value_type, sa.types.NullType): - # duckdb infers empty arrays with no other context as array - typ = "INTEGER" - else: - typ = compiler.process(value_type, **kw) - return f"{typ}[]" - - -rewrites = DuckDBSQLExprTranslator.rewrites - - -@rewrites(ops.Any) -@rewrites(ops.All) -@rewrites(ops.NotAny) -@rewrites(ops.NotAll) -@rewrites(ops.StringContains) -def _no_op(expr): - return expr - - -class DuckDBSQLCompiler(AlchemyCompiler): - cheap_in_memory_tables = True - translator_class = DuckDBSQLExprTranslator diff --git a/ibis/backends/duckdb/compiler/__init__.py b/ibis/backends/duckdb/compiler/__init__.py new file mode 100644 index 000000000000..dfb2f03acc1a --- /dev/null +++ b/ibis/backends/duckdb/compiler/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from public import public + +from ibis.backends.duckdb.compiler.core import translate +from ibis.backends.duckdb.compiler.relations import translate_rel +from ibis.backends.duckdb.compiler.values import translate_val + +public( + translate=translate, + translate_rel=translate_rel, + translate_val=translate_val, +) diff --git a/ibis/backends/duckdb/compiler/core.py 
b/ibis/backends/duckdb/compiler/core.py
new file mode 100644
index 000000000000..2004d24e20f1
--- /dev/null
+++ b/ibis/backends/duckdb/compiler/core.py
@@ -0,0 +1,95 @@
+"""DuckDB ibis expression to sqlglot compiler.
+
+The compiler is built with a few `singledispatch` functions:
+
+    1. `translate` for table expressions
+    1. `translate` for table nodes
+    1. `translate_rel`
+    1. `translate_val`
+
+## `translate`
+
+### Expression Implementation
+
+The table expression implementation of `translate` is a pass-through to the
+node implementation.
+
+### Node Implementation
+
+There's a single `ops.Node` implementation for `ops.TableNode` instances.
+
+This function:
+
+    1. Topologically sorts the expression graph.
+    1. Seeds the compilation cache with in-degree-zero table names.
+    1. Iterates through nodes with at least one in-degree and places the result
+       in the compilation cache. The cache is used to construct `ops.TableNode`
+       keyword arguments to the current translation rule.
+
+## `translate_rel`
+
+Translates a table operation given already-translated table inputs.
+
+If a table node needs to translate value expressions, for example, an
+`ops.Aggregation`, that rule is responsible for calling `translate_val`.
+
+## `translate_val`
+
+Recurses top-down and translates the arguments of the value expression and uses
+those as input to construct the output.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Mapping
+
+import sqlglot as sg
+
+import ibis.expr.operations as ops
+import ibis.expr.types as ir
+from ibis.backends.duckdb.compiler.relations import translate_rel
+
+
+def translate(op: ops.TableNode, params: Mapping[ir.Value, Any]) -> sg.exp.Expression:
+    """Translate an ibis operation to a sqlglot expression.
+
+    Parameters
+    ----------
+    op
+        An ibis `TableNode`
+    params
+        A mapping of expressions to concrete values
+
+    Returns
+    -------
+    sqlglot.expressions.Expression
+        A sqlglot expression
+    """
+    params = {param.op(): value for param, value in params.items()}
+
+    alias_index = 0
+    aliases = {}
+
+    def fn(node, cache, params=params, **kwargs):
+        nonlocal alias_index
+
+        # don't alias the root node
+        if node is not op:
+            aliases[node] = f"t{alias_index:d}"
+            alias_index += 1
+
+        raw_rel = translate_rel(
+            node, aliases=aliases, params=params, cache=cache, **kwargs
+        )
+
+        if alias := aliases.get(node):
+            try:
+                return raw_rel.subquery(alias)
+            except AttributeError:
+                return sg.alias(raw_rel, alias)
+        else:
+            return raw_rel
+
+    results = op.map(fn, filter=ops.TableNode)
+    node = results[op]
+    return node.this if isinstance(node, sg.exp.Subquery) else node
diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py
new file mode 100644
index 000000000000..202dabc977d2
--- /dev/null
+++ b/ibis/backends/duckdb/compiler/relations.py
@@ -0,0 +1,220 @@
+from __future__ import annotations
+
+import functools
+from functools import partial
+
+import sqlglot as sg
+
+import ibis.common.exceptions as com
+import ibis.expr.datatypes as dt
+import ibis.expr.operations as ops
+from ibis.backends.duckdb.compiler.values import translate_val
+
+
+@functools.singledispatch
+def translate_rel(op: ops.TableNode, **_):
+    """Translate a table node into sqlglot."""
+    raise com.OperationNotDefinedError(f"No translation rule for {type(op)}")
+
+
+@translate_rel.register(ops.DummyTable)
+def _dummy(op: ops.DummyTable, **kw):
+    return sg.select(*map(partial(translate_val, **kw), op.values), dialect="duckdb")
+
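+# A rough usage sketch tying `translate`, `translate_rel`, and
+# `translate_val` together (illustrative only; the table `t` below is
+# hypothetical and this block is not part of the compiler): `translate`
+# walks the table graph and dispatches each relation to a rule like
+# `_dummy` above, passing already-translated inputs.
+#
+#     import ibis
+#     from ibis.backends.duckdb.compiler import translate
+#
+#     t = ibis.table({"a": "int64", "b": "string"}, name="t")
+#     expr = t.filter(t.a > 0).select(t.b)
+#     print(translate(expr.op(), params={}).sql(dialect="duckdb"))
+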
+@translate_rel.register(ops.PhysicalTable) +def _physical_table(op: ops.PhysicalTable, **_): + return sg.parse_one(op.name, into=sg.exp.Table) + + +@translate_rel.register(ops.Selection) +def _selection(op: ops.Selection, *, table, needs_alias=False, **kw): + # needs_alias should never be true here in explicitly, but it may get + # passed via a (recursive) call to translate_val + assert not needs_alias, "needs_alias is True" + if needs_alias := isinstance(op.table, ops.Join) and not isinstance( + op.table, (ops.LeftSemiJoin, ops.LeftAntiJoin) + ): + args = table.this.args + from_ = args["from"] + (join,) = args["joins"] + else: + from_ = join = None + tr_val = partial(translate_val, needs_alias=needs_alias, **kw) + selections = tuple(map(tr_val, op.selections)) or "*" + sel = sg.select(*selections, dialect="duckdb").from_( + from_ if from_ is not None else table, dialect="duckdb" + ) + + if join is not None: + sel = sel.join(join) + + if predicates := op.predicates: + if join is not None: + sel = sg.select("*").from_(sel.subquery(kw["aliases"][op.table])) + res = functools.reduce( + lambda left, right: left.and_(right), + ( + sg.condition(tr_val(predicate), dialect="duckdb") + for predicate in predicates + ), + ) + sel = sel.where(res, dialect="duckdb") + + if sort_keys := op.sort_keys: + sel = sel.order_by(*map(tr_val, sort_keys), dialect="duckdb") + + return sel + + +@translate_rel.register(ops.Aggregation) +def _aggregation(op: ops.Aggregation, *, table, **kw): + tr_val = partial(translate_val, **kw) + tr_val_no_alias = partial(translate_val, render_aliases=False, **kw) + + by = tuple(map(tr_val, op.by)) + metrics = tuple(map(tr_val, op.metrics)) + selections = (by + metrics) or "*" + sel = sg.select(*selections).from_(table) + + if group_keys := op.by: + sel = sel.group_by(*map(tr_val_no_alias, group_keys), dialect="duckdb") + + if predicates := op.predicates: + sel = sel.where(*map(tr_val_no_alias, predicates), dialect="duckdb") + + if having := op.having: + sel = sel.having(*map(tr_val_no_alias, having), dialect="duckdb") + + if sort_keys := op.sort_keys: + sel = sel.order_by(*map(tr_val_no_alias, sort_keys), dialect="duckdb") + + return sel + + +_JOIN_TYPES = { + ops.InnerJoin: "INNER", + ops.AnyInnerJoin: "ANY", + ops.LeftJoin: "LEFT OUTER", + ops.AnyLeftJoin: "LEFT ANY", + ops.RightJoin: "RIGHT OUTER", + ops.OuterJoin: "FULL OUTER", + ops.CrossJoin: "CROSS", + ops.LeftSemiJoin: "LEFT SEMI", + ops.LeftAntiJoin: "LEFT ANTI", + ops.AsOfJoin: "LEFT ASOF", +} + + +@translate_rel.register +def _join(op: ops.Join, *, left, right, **kw): + predicates = op.predicates + if predicates: + on = functools.reduce( + lambda left, right: left.and_(right), + ( + sg.condition(translate_val(predicate, **kw), dialect="duckdb") + for predicate in predicates + ), + ) + else: + on = None + join_type = _JOIN_TYPES[type(op)] + try: + return left.join(right, join_type=join_type, on=on, dialect="duckdb") + except AttributeError: + select_args = [f"{left.alias_or_name}.*"] + + # select from both the left and right side of the join if the join + # is not a filtering join (semi join or anti join); filtering joins + # only return the left side columns + if not isinstance(op, (ops.LeftSemiJoin, ops.LeftAntiJoin)): + select_args.append(f"{right.alias_or_name}.*") + return ( + sg.select(*select_args, dialect="duckdb") + .from_(left, dialect="duckdb") + .join(right, join_type=join_type, on=on, dialect="duckdb") + ) + + +@translate_rel.register +def _self_ref(op: ops.SelfReference, *, table, aliases, **kw): + if 
(name := aliases.get(op)) is None: + return table + return sg.alias(table, name) + + +@translate_rel.register +def _query(op: ops.SQLQueryResult, *, aliases, **_): + res = sg.parse_one(op.query, read="duckdb") + return res.subquery(aliases.get(op, "_")) + + +_SET_OP_FUNC = { + ops.Union: sg.union, + ops.Intersection: sg.intersect, + ops.Difference: sg.except_, +} + + +@translate_rel.register +def _set_op(op: ops.SetOp, *, left, right, **_): + dialect = "duckdb" + + if isinstance(left, sg.exp.Table): + left = sg.select("*", dialect=dialect).from_(left, dialect=dialect) + + if isinstance(right, sg.exp.Table): + right = sg.select("*", dialect=dialect).from_(right, dialect=dialect) + + return _SET_OP_FUNC[type(op)]( + left.args.get("this", left), + right.args.get("this", right), + distinct=op.distinct, + dialect=dialect, + ) + + +@translate_rel.register +def _limit(op: ops.Limit, *, table, **kw): + n = op.n + limited = sg.select("*").from_(table).limit(n) + + if offset := op.offset: + limited = limited.offset(offset) + return limited + + +@translate_rel.register +def _distinct(_: ops.Distinct, *, table, **kw): + return sg.select("*").distinct().from_(table) + + +@translate_rel.register(ops.DropNa) +def _dropna(op: ops.DropNa, *, table, **kw): + how = op.how + + if op.subset is None: + columns = [ops.TableColumn(op.table, name) for name in op.table.schema.names] + else: + columns = op.subset + + if columns: + raw_predicate = functools.reduce( + ops.And if how == "any" else ops.Or, + map(ops.NotNull, columns), + ) + elif how == "all": + raw_predicate = ops.Literal(False, dtype=dt.bool) + else: + raw_predicate = None + + if not raw_predicate: + return table + + tr_val = partial(translate_val, **kw) + predicate = tr_val(raw_predicate) + try: + return table.where(predicate, dialect="duckdb") + except AttributeError: + return sg.select("*").from_(table).where(predicate, dialect="duckdb") diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py new file mode 100644 index 000000000000..df7c1f179d07 --- /dev/null +++ b/ibis/backends/duckdb/compiler/values.py @@ -0,0 +1,1484 @@ +from __future__ import annotations + +import calendar +import contextlib +import functools +import math +from functools import partial +import operator +from operator import add, mul, sub +from typing import Any, Literal, Mapping + +import ibis +import ibis.common.exceptions as com +import ibis.expr.analysis as an +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +import ibis.expr.rules as rlz +import sqlglot as sg +from ibis.backends.base.sql.registry import helpers +from ibis.backends.duckdb.datatypes import serialize +from toolz import flip + +# TODO: Ideally we can translate bottom up a la `relations.py` +# TODO: Find a way to remove all the dialect="duckdb" kwargs + + +@functools.singledispatch +def translate_val(op, **_): + """Translate a value expression into sqlglot.""" + raise com.OperationNotDefinedError(f"No translation rule for {type(op)}") + + +@translate_val.register(dt.DataType) +def _datatype(t, **_): + return serialize(t) + + +@translate_val.register(ops.PhysicalTable) +def _val_physical_table(op, *, aliases, **kw): + return f"{aliases.get(op, op.name)}.*" + + +@translate_val.register(ops.TableNode) +def _val_table_node(op, *, aliases, needs_alias=False, **_): + return f"{aliases[op]}.*" if needs_alias else "*" + + +@translate_val.register(ops.TableColumn) +def _column(op, *, aliases, **_): + table_name = (aliases or {}).get(op.table) + return 
sg.column(op.name, table=table_name) + + +@translate_val.register(ops.Alias) +def _alias(op, render_aliases: bool = True, **kw): + val = translate_val(op.arg, render_aliases=render_aliases, **kw) + if render_aliases: + return sg.alias(val, op.name, dialect="duckdb") + return val + + +### Bitwise Business + +_bitwise_mapping = { + ops.BitwiseLeftShift: "<<", + ops.BitwiseRightShift: ">>", + ops.BitwiseAnd: "&", + ops.BitwiseOr: "|", +} + + +@translate_val.register(ops.BitwiseLeftShift) +@translate_val.register(ops.BitwiseRightShift) +@translate_val.register(ops.BitwiseAnd) +@translate_val.register(ops.BitwiseOr) +def _bitwise_binary(op, **kw): + left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) + _operator = _bitwise_mapping[type(op)] + + return f"{left} {_operator} {right}" + + +@translate_val.register(ops.BitwiseXor) +def _bitwise_xor(op, **kw): + left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) + + return f"xor({left}, {right})" + + +@translate_val.register(ops.BitwiseNot) +def _bitwise_not(op, **kw): + value = translate_val(op.arg, **kw) + + return f"~{value}" + + +### Mathematical Calisthenics + + +@translate_val.register(ops.E) +def _euler(op, **kw): + return sg.func("exp", 1) + + +@translate_val.register(ops.Log) +def _generic_log(op, **kw): + arg, base = op.args + arg = translate_val(arg, **kw) + if base is not None: + base = translate_val(base, **kw) + return f"ln({arg}) / ln({base})" + return f"ln({arg})" + + +### Dtype Dysmorphia + + +_interval_cast_suffixes = { + "s": "Second", + "m": "Minute", + "h": "Hour", + "D": "Day", + "W": "Week", + "M": "Month", + "Q": "Quarter", + "Y": "Year", +} + + +@translate_val.register(ops.Cast) +def _cast(op, **kw): + arg = translate_val(op.arg, **kw) + + if isinstance(op.to, dt.Interval): + suffix = _interval_cast_suffixes[op.to.unit.short] + if isinstance(op.arg, ops.TableColumn): + return ( + f"INTERVAL (i) {suffix} FROM (SELECT {arg.name} FROM {arg.table}) t(i)" + ) + + else: + return f"INTERVAL {arg} {suffix}" + elif isinstance(op.to, dt.Timestamp) and isinstance(op.arg.dtype, dt.Integer): + return sg.func("to_timestamp", arg) + elif isinstance(op.to, dt.Timestamp) and (timezone := op.to.timezone) is not None: + return sg.func("timezone", timezone, arg) + + to = translate_val(op.to, **kw) + return sg.cast(expression=arg, to=to) + + +@translate_val.register(ops.TryCast) +def _try_cast(op, **kw): + return sg.func( + "try_cast", translate_val(op.arg, **kw), serialize(op.to), dialect="duckdb" + ) + + +### Comparator Conundrums + + +@translate_val.register(ops.Between) +def _between(op, **kw): + arg = translate_val(op.arg, **kw) + lower_bound = translate_val(op.lower_bound, **kw) + upper_bound = translate_val(op.upper_bound, **kw) + return f"{arg} BETWEEN {lower_bound} AND {upper_bound}" + + +@translate_val.register(ops.Negate) +def _negate(op, **kw): + arg = translate_val(op.arg, **kw) + return f"-{_parenthesize(op.arg, arg)}" + + +@translate_val.register(ops.Not) +def _not(op, **kw): + arg = translate_val(op.arg, **kw) + return f"NOT {_parenthesize(op.arg, arg)}" + + +def _parenthesize(op, arg): + # function calls don't need parens + if isinstance(op, (ops.Binary, ops.Unary)): + return f"({arg})" + else: + return arg + + +### Timey McTimeFace + + +@translate_val.register(ops.Date) +def _to_date(op, **kw): + arg = translate_val(op.arg, **kw) + return f"DATE {arg}" + + +@translate_val.register(ops.Time) +def _time(op, **kw): + arg = translate_val(op.arg, **kw) + return f"{arg}::TIME" + + 
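+# Rough illustration of what the cast rules above should emit; these
+# renderings are assumptions for orientation, not tested output:
+#
+#     ibis.literal(5).cast("interval('h')")  # INTERVAL 5 Hour
+#     ibis.literal(1).cast("int16")          # CAST(1 AS SMALLINT)
+#
+# Interval casts of *columns* take the special-cased path in `_cast`,
+# which the commit message flags as broken.
+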
+@translate_val.register(ops.Strftime)
+def _strftime(op, **kw):
+    if not isinstance(op.format_str, ops.Literal):
+        raise com.UnsupportedOperationError(
+            f"DuckDB format_str must be a literal `str`; got {type(op.format_str)}"
+        )
+    arg = translate_val(op.arg, **kw)
+    format_str = translate_val(op.format_str, **kw)
+    return sg.func("strftime", arg, format_str)
+
+
+@translate_val.register(ops.TimeFromHMS)
+def _time_from_hms(op, **kw):
+    hours = translate_val(op.hours, **kw)
+    minutes = translate_val(op.minutes, **kw)
+    seconds = translate_val(op.seconds, **kw)
+    return sg.func("make_time", hours, minutes, seconds)
+
+
+@translate_val.register(ops.StringToTimestamp)
+def _string_to_timestamp(op, **kw):
+    arg = translate_val(op.arg, **kw)
+    format_str = translate_val(op.format_str, **kw)
+    return sg.func("strptime", arg, format_str)
+
+
+@translate_val.register(ops.ExtractEpochSeconds)
+def _extract_epoch_seconds(op, **kw):
+    arg = translate_val(op.arg, **kw)
+    # TODO: do we need the TIMESTAMP cast?
+    return f"epoch({arg}::TIMESTAMP)"
+
+
+_extract_mapping = {
+    ops.ExtractYear: "year",
+    ops.ExtractMonth: "month",
+    ops.ExtractDay: "day",
+    ops.ExtractDayOfYear: "dayofyear",
+    ops.ExtractQuarter: "quarter",
+    ops.ExtractWeekOfYear: "week",
+    ops.ExtractHour: "hour",
+    ops.ExtractMinute: "minute",
+    ops.ExtractSecond: "second",
+}
+
+
+@translate_val.register(ops.ExtractYear)
+@translate_val.register(ops.ExtractMonth)
+@translate_val.register(ops.ExtractDay)
+@translate_val.register(ops.ExtractDayOfYear)
+@translate_val.register(ops.ExtractQuarter)
+@translate_val.register(ops.ExtractWeekOfYear)
+@translate_val.register(ops.ExtractHour)
+@translate_val.register(ops.ExtractMinute)
+@translate_val.register(ops.ExtractSecond)
+def _extract_time(op, **kw):
+    part = _extract_mapping[type(op)]
+    timestamp = translate_val(op.arg, **kw)
+    return f"extract({part}, {timestamp})"
+
+
+# DuckDB extracts subminute microseconds and milliseconds
+# so we have to finesse it a little bit
+@translate_val.register(ops.ExtractMicrosecond)
+def _extract_microsecond(op, **kw):
+    arg = translate_val(op.arg, **kw)
+
+    return f"extract('us', {arg}::TIMESTAMP) % 1000000"
+
+
+@translate_val.register(ops.ExtractMillisecond)
+def _extract_millisecond(op, **kw):
+    arg = translate_val(op.arg, **kw)
+
+    return f"extract('ms', {arg}::TIMESTAMP) % 1000"
+
+
+@translate_val.register(ops.Date)
+def _date(op, **kw):
+    arg = translate_val(op.arg, **kw)
+
+    return f"{arg}::DATE"
+
+
+@translate_val.register(ops.DateTruncate)
+@translate_val.register(ops.TimestampTruncate)
+@translate_val.register(ops.TimeTruncate)
+def _truncate(op, **kw):
+    unit_mapping = {
+        "Y": "year",
+        "M": "month",
+        "W": "week",
+        "D": "day",
+        "h": "hour",
+        "m": "minute",
+        "s": "second",
+        "ms": "ms",
+        "us": "us",
+    }
+
+    unit = op.unit.short
+    arg = translate_val(op.arg, **kw)
+    try:
+        duckunit = unit_mapping[unit]
+    except KeyError:
+        raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}")
+
+    return f"date_trunc('{duckunit}', {arg})"
+
+
+@translate_val.register(ops.DateFromYMD)
+def _date_from_ymd(op, **kw):
+    y = translate_val(op.year, **kw)
+    m = translate_val(op.month, **kw)
+    d = translate_val(op.day, **kw)
+    return f"make_date({y}, {m}, {d})"
+
+
+@translate_val.register(ops.DayOfWeekIndex)
+def _day_of_week_index(op, **kw):
+    arg = translate_val(op.arg, **kw)
+    return f"(dayofweek({arg}) + 6) % 7"
+
+
+@translate_val.register(ops.TimestampFromUNIX)
+def _timestamp_from_unix(op, **kw):
+    arg = translate_val(op.arg, **kw)
+    if (unit := op.unit.short) in {"ns"}:
+        raise com.UnsupportedOperationError(f"{unit!r} unit is not supported!")
+
+    if unit == "ms":
+        return f"to_timestamp({arg[:-3]}) + INTERVAL {arg[-3:]} millisecond"
+    elif unit == "us":
+        return f"to_timestamp({arg[:-6]}) + INTERVAL {arg[-6:]} microsecond"
+
+    return f"to_timestamp({arg})"
+
+
+@translate_val.register(ops.TimestampFromYMDHMS)
+def _timestamp_from_ymdhms(op, **kw):
+    year = translate_val(op.year, **kw)
+    month = translate_val(op.month, **kw)
+    day = translate_val(op.day, **kw)
+    hour = translate_val(op.hours, **kw)
+    minute = translate_val(op.minutes, **kw)
+    second = translate_val(op.seconds, **kw)
+
+    if (timezone := op.dtype.timezone) is not None:
+        return f"make_timestamptz({year}, {month}, {day}, {hour}, {minute}, {second}, '{timezone}')"
+    else:
+        return f"make_timestamp({year}, {month}, {day}, {hour}, {minute}, {second})"
+
+
+### Interval Marginalia
+
+
+_interval_mapping = {
+    ops.IntervalAdd: operator.add,
+    ops.IntervalSubtract: operator.sub,
+    ops.IntervalMultiply: operator.mul,
+}
+
+
+@translate_val.register(ops.IntervalAdd)
+@translate_val.register(ops.IntervalSubtract)
+@translate_val.register(ops.IntervalMultiply)
+def _interval_binary(op, **kw):
+    left = translate_val(op.left, **kw)
+    right = translate_val(op.right, **kw)
+    _operator = _interval_mapping[type(op)]
+
+    return _operator(left, right)
+
+
+def _interval_format(op):
+    dtype = op.dtype
+    if dtype.unit.short == "ns":
+        raise com.UnsupportedOperationError(
+            "Duckdb doesn't support nanosecond interval resolutions"
+        )
+
+    return f"INTERVAL {op.value} {dtype.resolution.upper()}"
+
+
+@translate_val.register(ops.IntervalFromInteger)
+def _interval_from_integer(op, **kw):
+    dtype = op.dtype
+    if dtype.unit.short == "ns":
+        raise com.UnsupportedOperationError(
+            "Duckdb doesn't support nanosecond interval resolutions"
+        )
+
+    arg = translate_val(op.arg, **kw)
+    if op.dtype.resolution == "week":
+        return sg.func("to_days", arg * 7)
+    # TODO: make less gross
+    # to_days, to_years, etc...
+ return sg.func(f"to_{op.dtype.resolution}s", arg) + + +### String Instruments + + +@translate_val.register(ops.Substring) +def _substring(op, **kw): + # Duckdb is 1-indexed + arg = translate_val(op.arg, **kw) + start = translate_val(op.start, **kw) + arg_length = f"length({arg})" + if op.length is not None: + length = translate_val(op.length, **kw) + suffix = f", {length}" + else: + suffix = "" + + if_pos = f"substring({arg}, {start} + 1{suffix})" + if_neg = f"substring({arg}, {arg_length} + {start} + 1{suffix})" + return f"if({start} >= 0, {if_pos}, {if_neg})" + + +@translate_val.register(ops.StringFind) +def _string_find(op, **kw): + if op.end is not None: + raise com.UnsupportedOperationError("String find doesn't support end argument") + + arg = translate_val(op.arg, **kw) + substr = translate_val(op.substr, **kw) + + return f"instr({arg}, {substr}) - 1" + + +@translate_val.register(ops.RegexSearch) +def _regex_search(op, **kw): + arg = translate_val(op.arg, **kw) + pattern = translate_val(op.pattern, **kw) + return f"regexp_matches({arg}, {pattern}, 's')" + + +@translate_val.register(ops.RegexReplace) +def _regex_replace(op, **kw): + arg = translate_val(op.arg, **kw) + pattern = translate_val(op.pattern, **kw) + replacement = translate_val(op.replacement, **kw) + return sg.func("regexp_replace", arg, pattern, replacement, "g", dialect="duckdb") + + +@translate_val.register(ops.RegexExtract) +def _regex_extract(op, **kw): + arg = translate_val(op.arg, **kw) + pattern = translate_val(op.pattern, **kw) + group = translate_val(op.index, **kw) + return f"regexp_extract({arg}, {pattern}, {group})" + + +@translate_val.register(ops.Levenshtein) +def _levenshtein(op, **kw): + left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) + return f"levenshtein({left}, {right})" + + +### Simple Ops + +_simple_ops = { + ops.Power: "pow", + # Unary operations + ops.IsNan: "isnan", + ops.IsInf: "isinf", + ops.Abs: "abs", + ops.Ceil: "ceil", + ops.Floor: "floor", + ops.Exp: "exp", + ops.Sqrt: "sqrt", + ops.Ln: "ln", + ops.Log2: "log2", + ops.Log10: "log", + ops.Acos: "acos", + ops.Asin: "asin", + ops.Atan: "atan", + ops.Atan2: "atan2", + ops.Cos: "cos", + ops.Sin: "sin", + ops.Tan: "tan", + ops.Cot: "cot", + ops.Pi: "pi", + ops.RandomScalar: "random", + ops.Sign: "sign", + # Unary aggregates + # ops.ApproxMedian: "median", # TODO + # ops.Median: "quantileExactExclusive", # TODO + ops.ApproxCountDistinct: "list_unique", + ops.Mean: "avg", + ops.Sum: "sum", + ops.Max: "max", + ops.Min: "min", + ops.Any: "any_value", + ops.All: "min", + ops.ArgMin: "arg_min", + ops.Mode: "mode", + ops.ArgMax: "arg_max", + # ops.ArrayCollect: "groupArray", # TODO + ops.Count: "count", + ops.CountDistinct: "list_unique", + ops.First: "first", + ops.Last: "last", + # string operations + ops.StringContains: "contains", + ops.StringLength: "length", + ops.Lowercase: "lower", + ops.Uppercase: "upper", + ops.Reverse: "reverse", + ops.StringReplace: "replace", + ops.StartsWith: "prefix", + ops.EndsWith: "suffix", + ops.LPad: "lpad", + ops.RPad: "rpad", + ops.LStrip: "ltrim", + ops.RStrip: "rtrim", + ops.Strip: "trim", + ops.StringAscii: "ascii", + ops.StrRight: "right", + # Temporal operations + ops.TimestampNow: "current_timestamp", + # Other operations + ops.Where: "if", + ops.ArrayLength: "length", + ops.ArrayConcat: "arrayConcat", # TODO + ops.Unnest: "arrayJoin", # TODO + ops.Degrees: "degrees", + ops.Radians: "radians", + ops.NullIf: "nullIf", + ops.MapContains: "mapContains", # TODO + ops.MapLength: 
"length", + ops.MapKeys: "mapKeys", # TODO + ops.MapValues: "mapValues", # TODO + ops.MapMerge: "mapUpdate", # TODO + ops.ArrayDistinct: "arrayDistinct", # TODO + ops.ArraySort: "arraySort", # TODO + ops.ArrayContains: "has", + ops.FirstValue: "first_value", + ops.LastValue: "last_value", + ops.NTile: "ntile", + ops.Hash: "hash", +} + + +def _agg(func_name): + def formatter(op, **kw): + return _aggregate(op, func_name, where=op.where, **kw) + + return formatter + + +for _op, _name in _simple_ops.items(): + assert isinstance(type(_op), type), type(_op) + if issubclass(_op, ops.Reduction): + translate_val.register(_op)(_agg(_name)) + else: + + @translate_val.register(_op) + def _fmt(op, _name: str = _name, **kw): + return sg.func( + _name, *map(partial(translate_val, **kw), op.args), dialect="duckdb" + ) + + +del _fmt, _name, _op + + +### NULL PLAYER CHARACTER +# ops.IsNull: "isNull", # TODO +# ops.NotNull: "isNotNull", # TODO +# ops.IfNull: "ifNull", # TODO +@translate_val.register(ops.IsNull) +def _is_null(op, **kw): + arg = translate_val(op.arg, **kw) + return arg.is_(sg.expressions.null()) + + +@translate_val.register(ops.NotNull) +def _is_not_null(op, **kw): + arg = translate_val(op.arg, **kw) + return arg.is_(sg.not_(sg.expressions.null())) + + +@translate_val.register(ops.IfNull) +def _if_null(op, **kw): + arg = translate_val(op.arg, **kw) + ifnull = translate_val(op.ifnull_expr, **kw) + return sg.func("ifnull", arg, ifnull, dialect="duckdb") + + +### Definitely Not Tensors + + +@translate_val.register(ops.ArrayIndex) +def _array_index_op(op, **kw): + arg = translate_val(op.arg, **kw) + index = translate_val(op.index, **kw) + correct_idx = f"if({index} >= 0, {index} + 1, {index})" + return f"array_extract({arg}, {correct_idx})" + + +@translate_val.register(ops.InValues) +def _in_values(op, **kw): + if not op.options: + return False + value = translate_val(op.value, **kw) + options = [translate_val(x, **kw) for x in op.options] + return sg.func("list_contains", options, value, dialect="duckdb") + + +@translate_val.register(ops.InColumn) +def _in_column(op, **kw): + value = translate_val(op.value, **kw) + options = translate_val(ops.TableArrayView(op.options.to_expr().as_table()), **kw) + # TODO: fix? + # if not isinstance(options, sa.sql.Selectable): + # options = sg.select(options) + return value.isin(options) + + +### LITERALLY + + +# TODO: need to go through this carefully +@translate_val.register(ops.Literal) +def _literal(op, **kw): + value = op.value + dtype = op.dtype + if value is None and dtype.nullable: + if dtype.is_null(): + return "Null" + return f"CAST(Null AS {serialize(dtype)})" + if dtype.is_boolean(): + return str(int(bool(value))) + elif dtype.is_inet(): + com.UnsupportedOperationError("DuckDB doesn't support an explicit inet dtype") + elif dtype.is_string(): + return value + elif dtype.is_decimal(): + precision = dtype.precision + scale = dtype.scale + if precision is None: + precision = 38 + if scale is None: + scale = 9 + if not 1 <= precision <= 38: + raise NotImplementedError( + f"Unsupported precision. Supported values: [1 : 38]. 
Current value: {precision!r}" + ) + + # TODO: handle if `value` is "Infinity" + + return f"{value!s}::decimal({precision}, {scale})" + elif dtype.is_numeric(): + if math.isinf(value): + return f"'{repr(value)}inity'::FLOAT" + elif math.isnan(value): + return "'NaN'::FLOAT" + return value + elif dtype.is_interval(): + return _interval_format(op) + elif dtype.is_timestamp(): + year = op.value.year + month = op.value.month + day = op.value.day + hour = op.value.hour + minute = op.value.minute + second = op.value.second + if op.value.microsecond: + microsecond = op.value.microsecond / 1e6 + second += microsecond + if (timezone := dtype.timezone) is not None: + return f"make_timestamptz({year}, {month}, {day}, {hour}, {minute}, {second}, '{timezone}')" + else: + return f"make_timestamp({year}, {month}, {day}, {hour}, {minute}, {second})" + elif dtype.is_date(): + return f"make_date({op.value.year}, {op.value.month}, {op.value.day})" + elif dtype.is_array(): + value_type = dtype.value_type + values = ", ".join( + _literal(ops.Literal(v, dtype=value_type), **kw) for v in value + ) + return f"[{values}]" + elif dtype.is_map(): + value_type = dtype.value_type + values = ", ".join( + f"{k!r}, {_literal(ops.Literal(v, dtype=value_type), **kw)}" + for k, v in value.items() + ) + return f"map({values})" + elif dtype.is_struct(): + fields = ", ".join( + _literal(ops.Literal(v, dtype=subdtype), **kw) + for subdtype, v in zip(dtype.types, value.values()) + ) + return f"tuple({fields})" + else: + raise NotImplementedError(f"Unsupported type: {dtype!r}") + + +### BELOW HERE BE DRAGONS + + +# TODO +@translate_val.register(ops.ArrayRepeat) +def _array_repeat_op(op, **kw): + arg = translate_val(op.arg, **kw) + times = translate_val(op.times, **kw) + from_ = f"(SELECT {arg} AS arr FROM system.numbers LIMIT {times})" + query = sg.parse_one( + f"SELECT arrayFlatten(groupArray(arr)) FROM {from_}", read="duckdb" + ) + return query.subquery() + + +# TODO +@translate_val.register(ops.ArraySlice) +def _array_slice_op(op, **kw): + arg = translate_val(op.arg, **kw) + start = translate_val(op.start, **kw) + start = _parenthesize(op.start, start) + start_correct = f"if({start} < 0, {start}, {start} + 1)" + + if (stop := op.stop) is not None: + stop = translate_val(stop, **kw) + stop = _parenthesize(op.stop, stop) + + neg_start = f"(length({arg}) + {start})" + diff_fmt = f"greatest(-0, {stop} - {{}})".format + + length = ( + f"if({stop} < 0, {stop}, " + f"if({start} < 0, {diff_fmt(neg_start)}, {diff_fmt(start)}))" + ) + + return f"arraySlice({arg}, {start_correct}, {length})" + + return f"arraySlice({arg}, {start_correct})" + + +@translate_val.register(ops.CountStar) +def _count_star(op, **kw): + sql = sg.expressions.Count(this=sg.expressions.Star()) + if (predicate := op.where) is not None: + return sg.select(sql).where(predicate) + return sql + + +@translate_val.register(ops.NotAny) +def _not_any(op, **kw): + return translate_val(ops.All(ops.Not(op.arg), where=op.where), **kw) + + +@translate_val.register(ops.NotAll) +def _not_all(op, **kw): + return translate_val(ops.Any(ops.Not(op.arg), where=op.where), **kw) + + +# TODO +def _quantile_like(func_name: str, op: ops.Node, quantile: str, **kw): + args = [_sql(translate_val(op.arg, **kw))] + + if (where := op.where) is not None: + args.append(_sql(translate_val(where, **kw))) + func_name += "If" + + return f"{func_name}({quantile})({', '.join(args)})" + + +@translate_val.register(ops.Quantile) +def _quantile(op, **kw): + quantile = _sql(translate_val(op.quantile, **kw)) + 
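+    # _quantile_like still renders clickhouse's parameterized-aggregate
+    # syntax, quantile(<q>)(<arg>) plus the "If" combinator for filters,
+    # which is not valid duckdb SQL; hence the TODO above.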
return _quantile_like("quantile", op, quantile, **kw) + + +@translate_val.register(ops.MultiQuantile) +def _multi_quantile(op, **kw): + if not isinstance(op.quantile, ops.Literal): + raise TypeError("Duckdb quantile only accepts a list of Python floats") + + quantile = ", ".join(map(str, op.quantile.value)) + return _quantile_like("quantiles", op, quantile, **kw) + + +def _agg_variance_like(func): + variants = {"sample": f"{func}_samp", "pop": f"{func}_pop"} + + def formatter(op, **kw): + return _aggregate(op, variants[op.how], where=op.where, **kw) + + return formatter + + +@translate_val.register(ops.Correlation) +def _corr(op, **kw): + if op.how == "pop": + raise ValueError("Duckdb only implements `sample` correlation coefficient") + return _aggregate(op, "corr", where=op.where, **kw) + + +def _aggregate(op, func, *, where=None, **kw): + args = [ + translate_val(arg, **kw) + for argname, arg in zip(op.argnames, op.args) + if argname not in ("where", "how") + ] + if where is not None: + predicate = translate_val(where, **kw) + return sg.func(func, *args).where(predicate) + + res = sg.func(func, *args) + return res + + +@translate_val.register(ops.Arbitrary) +def _arbitrary(op, **kw): + functions = { + "first": "first", + "last": "last", + } + return _aggregate(op, functions[op.how], where=op.where, **kw) + + +@translate_val.register(ops.FindInSet) +def _index_of(op, **kw): + values = map(partial(translate_val, **kw), op.values) + values = ", ".join(map(_sql, values)) + needle = translate_val(op.needle, **kw) + return f"list_indexof([{values}], {needle}) - 1" + + +@translate_val.register(ops.Round) +def _round(op, **kw): + arg = translate_val(op.arg, **kw) + if (digits := op.digits) is not None: + return f"round({arg}, {translate_val(digits, **kw)})" + return f"round({arg})" + + +@translate_val.register(tuple) +def _node_list(op, punct="()", **kw): + values = ", ".join(map(_sql, map(partial(translate_val, **kw), op))) + return f"{punct[0]}{values}{punct[1]}" + + +def _sql(obj, dialect="duckdb"): + try: + return obj.sql(dialect=dialect) + except AttributeError: + return obj + + +# TODO +@translate_val.register(ops.SimpleCase) +@translate_val.register(ops.SearchedCase) +def _case(op, **kw): + buf = ["CASE"] + + if (base := getattr(op, "base", None)) is not None: + buf.append(translate_val(base, **kw)) + + for when, then in zip(op.cases, op.results): + buf.append(f"WHEN {translate_val(when, **kw)}") + buf.append(f"THEN {translate_val(then, **kw)}") + + if (default := op.default) is not None: + buf.append(f"ELSE {translate_val(default, **kw)}") + + buf.append("END") + return " ".join(map(_sql, buf)) + + +@translate_val.register(ops.TableArrayView) +def _table_array_view(op, *, cache, **kw): + table = op.table + try: + return cache[table] + except KeyError: + from ibis.backends.duckdb.compiler.relations import translate_rel + + # ignore the top level table, so that we can compile its dependencies + (leaf,) = an.find_immediate_parent_tables(table, keep_input=False) + res = translate_rel(table, table=cache[leaf], cache=cache, **kw) + return res.subquery() + + +# TODO +@translate_val.register(ops.ExistsSubquery) +@translate_val.register(ops.NotExistsSubquery) +def _exists_subquery(op, **kw): + from ibis.backends.duckdb.compiler.relations import translate_rel + + foreign_table = translate_rel(op.foreign_table, **kw) + predicates = translate_val(op.predicates, **kw) + subq = ( + sg.select(1) + .from_(foreign_table, dialect="duckdb") + .where(sg.condition(predicates), dialect="duckdb") + ) + prefix 
= "NOT " * isinstance(op, ops.NotExistsSubquery) + return f"{prefix}EXISTS ({subq})" + + +@translate_val.register(ops.StringSplit) +def _string_split(op, **kw): + arg = translate_val(op.arg, **kw) + delimiter = translate_val(op.delimiter, **kw) + return f"string_split({arg}, {delimiter})" + + +@translate_val.register(ops.StringJoin) +def _string_join(op, **kw): + arg = map(partial(translate_val, **kw), op.arg) + sep = translate_val(op.sep, **kw) + elements = ", ".join(map(_sql, arg)) + return f"list_aggregate([{elements}], 'string_agg', {sep})" + + +@translate_val.register(ops.StringConcat) +def _string_concat(op, **kw): + arg = map(partial(translate_val, **kw), op.arg) + return " || ".join(map(_sql, arg)) + + +@translate_val.register(ops.StringSQLLike) +def _string_like(op, **kw): + arg = translate_val(op.arg, **kw) + pattern = translate_val(op.pattern, **kw) + return f"{arg} LIKE {pattern}" + + +@translate_val.register(ops.StringSQLILike) +def _string_ilike(op, **kw): + arg = translate_val(op.arg, **kw) + pattern = translate_val(op.pattern, **kw) + return f"lower({arg}) LIKE lower({pattern})" + + +# TODO +@translate_val.register(ops.Capitalize) +def _string_capitalize(op, **kw): + arg = translate_val(op.arg, **kw) + return f"CONCAT(UPPER(SUBSTR({arg}, 1, 1)), LOWER(SUBSTR({arg}, 2)))" + + +# TODO +@translate_val.register(ops.GroupConcat) +def _group_concat(op, **kw): + arg = translate_val(op.arg, **kw) + sep = translate_val(op.sep, **kw) + + args = [arg] + func = "groupArray" + + if (where := op.where) is not None: + func += "If" + args.append(translate_val(where, **kw)) + + joined_args = ", ".join(map(_sql, args)) + call = f"{func}({joined_args})" + expr = f"list_concat({call}, {sep})" + return f"CASE WHEN empty({call}) THEN NULL ELSE {expr} END" + + +# TODO +def _bit_agg(func): + def _translate(op, **kw): + arg = translate_val(op.arg, **kw) + if not isinstance((type := op.arg.dtype), dt.UnsignedInteger): + nbits = type.nbytes * 8 + arg = f"reinterpretAsUInt{nbits}({arg})" + + if (where := op.where) is not None: + return f"{func}If({arg}, {translate_val(where, **kw)})" + else: + return f"{func}({arg})" + + return _translate + + +@translate_val.register(ops.ArrayColumn) +def _array_column(op, **kw): + cols = map(partial(translate_val, **kw), op.cols) + args = ", ".join(map(_sql, cols)) + return f"[{args}]" + + +# TODO +@translate_val.register(ops.StructColumn) +def _struct_column(op, **kw): + values = translate_val(op.values, **kw) + struct_type = serialize(op.dtype.copy(nullable=False)) + return f"CAST({values} AS {struct_type})" + + +@translate_val.register(ops.Clip) +def _clip(op, **kw): + arg = translate_val(op.arg, **kw) + if (upper := op.upper) is not None: + arg = f"least({translate_val(upper, **kw)}, {arg})" + + if (lower := op.lower) is not None: + arg = f"greatest({translate_val(lower, **kw)}, {arg})" + + return arg + + +@translate_val.register(ops.StructField) +def _struct_field(op, render_aliases: bool = False, **kw): + arg = op.arg + arg_dtype = arg.dtype + arg = translate_val(op.arg, render_aliases=render_aliases, **kw) + idx = arg_dtype.names.index(op.field) + typ = arg_dtype.types[idx] + return f"CAST({arg}.{idx + 1} AS {serialize(typ)})" + + +# TODO +@translate_val.register(ops.NthValue) +def _nth_value(op, **kw): + arg = translate_val(op.arg, **kw) + nth = translate_val(op.nth, **kw) + return f"nth_value({arg}, ({nth}) + 1)" + + +@translate_val.register(ops.Repeat) +def _repeat(op, **kw): + arg = translate_val(op.arg, **kw) + times = translate_val(op.times, **kw) + 
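+    # duckdb's repeat(string, count) lines up with the ibis op directly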
return f"repeat({arg}, {times})" + + +# TODO +@translate_val.register(ops.NullIfZero) +def _null_if_zero(op, **kw): + arg = translate_val(op.arg, **kw) + return f"nullIf({arg}, 0)" + + +# TODO +@translate_val.register(ops.ZeroIfNull) +def _zero_if_null(op, **kw): + arg = translate_val(op.arg, **kw) + return f"ifNull({arg}, 0)" + + +@translate_val.register(ops.FloorDivide) +def _floor_divide(op, **kw): + new_op = ops.Floor(ops.Divide(op.left, op.right)) + return translate_val(new_op, **kw) + + +@translate_val.register(ops.ScalarParameter) +def _scalar_param(op, params: Mapping[ops.Node, Any], **kw): + raw_value = params[op] + dtype = op.dtype + if isinstance(dtype, dt.Struct): + literal = ibis.struct(raw_value, type=dtype) + elif isinstance(dtype, dt.Map): + literal = ibis.map(raw_value, type=dtype) + else: + literal = ibis.literal(raw_value, type=dtype) + return translate_val(literal.op(), **kw) + + +# TODO +def contains(op_string: Literal["IN", "NOT IN"]) -> str: + def tr(op, *, cache, **kw): + from ibis.backends.duckdb.compiler import translate + + value = op.value + options = op.options + if isinstance(options, tuple) and not options: + return {"NOT IN": "TRUE", "IN": "FALSE"}[op_string] + + left_arg = translate_val(value, **kw) + if helpers.needs_parens(value): + left_arg = helpers.parenthesize(left_arg) + + # special case non-foreign isin/notin expressions + if ( + not isinstance(options, tuple) + and options.output_shape is rlz.Shape.COLUMNAR + ): + # this will fail to execute if there's a correlation, but it's too + # annoying to detect so we let it through to enable the + # uncorrelated use case (pandas-style `.isin`) + subquery = translate(options.to_expr().as_table().op(), {}) + right_arg = f"({_sql(subquery)})" + else: + right_arg = _sql(translate_val(options, cache=cache, **kw)) + + # we explicitly do NOT parenthesize the right side because it doesn't + # make sense to do so for Sequence operations + return f"{left_arg} {op_string} {right_arg}" + + return tr + + +# TODO +# translate_val.register(ops.Contains)(contains("IN")) +# translate_val.register(ops.NotContains)(contains("NOT IN")) + + +# TODO +@translate_val.register(ops.DayOfWeekName) +def day_of_week_name(op, **kw): + arg = op.arg + nullable = arg.dtype.nullable + empty_string = ops.Literal("", dtype=dt.String(nullable=nullable)) + weekdays = range(7) + return translate_val( + ops.NullIf( + ops.SimpleCase( + base=ops.DayOfWeekIndex(arg), + cases=[ + ops.Literal(day, dtype=dt.Int8(nullable=nullable)) + for day in weekdays + ], + results=[ + ops.Literal( + calendar.day_name[day], + dtype=dt.String(nullable=nullable), + ) + for day in weekdays + ], + default=empty_string, + ), + empty_string, + ), + **kw, + ) + + +@translate_val.register(ops.IdenticalTo) +def _identical_to(op, **kw): + left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) + return sg.exp.NullSafeEQ(this=left, expression=right) + + +@translate_val.register(ops.Greatest) +@translate_val.register(ops.Least) +@translate_val.register(ops.Coalesce) +def _vararg_func(op, **kw): + return sg.func( + f"{op.__class__.__name__.lower()}", + *map(partial(translate_val, **kw), op.arg), + dialect="duckdb", + ) + + +# TODO +@translate_val.register(ops.Map) +def _map(op, **kw): + keys = translate_val(op.keys, **kw) + values = translate_val(op.values, **kw) + typ = serialize(op.dtype) + return f"CAST(({keys}, {values}) AS {typ})" + + +# TODO +@translate_val.register(ops.MapGet) +def _map_get(op, **kw): + arg = translate_val(op.arg, **kw) + key = 
translate_val(op.key, **kw) + default = translate_val(op.default, **kw) + return f"if(mapContains({arg}, {key}), {arg}[{key}], {default})" + + +def _binary_infix(symbol: str): + def formatter(op, **kw): + left = translate_val(op_left := op.left, **kw) + right = translate_val(op_right := op.right, **kw) + + return symbol(left, right) + + return formatter + + +import operator + +_binary_infix_ops = { + # Binary operations + ops.Add: operator.add, + ops.Subtract: operator.sub, + ops.Multiply: operator.mul, + ops.Divide: operator.truediv, + ops.Modulus: operator.mod, + # Comparisons + ops.GreaterEqual: operator.ge, + ops.Greater: operator.gt, + ops.LessEqual: operator.le, + ops.Less: operator.lt, + # Boolean comparisons + ops.And: operator.and_, + ops.Or: operator.or_, + ops.DateAdd: operator.add, + ops.DateSub: operator.sub, + ops.DateDiff: operator.sub, + ops.TimestampAdd: operator.add, + ops.TimestampSub: operator.sub, + ops.TimestampDiff: operator.sub, +} + + +for _op, _sym in _binary_infix_ops.items(): + translate_val.register(_op)(_binary_infix(_sym)) + +del _op, _sym + + +@translate_val.register(ops.Equals) +def _equals(op, **kw): + left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) + return left.eq(right) + + +@translate_val.register(ops.NotEquals) +def _equals(op, **kw): + left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) + breakpoint() + return left.eq(right) + + +# TODO +translate_val.register(ops.BitAnd)(_bit_agg("groupBitAnd")) +translate_val.register(ops.BitOr)(_bit_agg("groupBitOr")) +translate_val.register(ops.BitXor)(_bit_agg("groupBitXor")) + +translate_val.register(ops.StandardDev)(_agg_variance_like("stddev")) +translate_val.register(ops.Variance)(_agg_variance_like("var")) +translate_val.register(ops.Covariance)(_agg_variance_like("covar")) + + +@translate_val.register +def _sort_key(op: ops.SortKey, **kw): + arg = translate_val(op.expr, **kw) + direction = "ASC" if op.ascending else "DESC" + return f"{_sql(arg)} {direction}" + + +_cumulative_to_reduction = { + ops.CumulativeSum: ops.Sum, + ops.CumulativeMin: ops.Min, + ops.CumulativeMax: ops.Max, + ops.CumulativeMean: ops.Mean, + ops.CumulativeAny: ops.Any, + ops.CumulativeAll: ops.All, +} + + +def cumulative_to_window(func, frame): + klass = _cumulative_to_reduction[type(func)] + new_op = klass(*func.args) + new_frame = frame.copy(start=None, end=0) + new_expr = an.windowize_function(new_op.to_expr(), frame=new_frame) + return new_expr.op() + + +def format_window_boundary(boundary, **kw): + value = translate_val(boundary.value, **kw) + if boundary.preceding: + return f"{value} PRECEDING" + else: + return f"{value} FOLLOWING" + + +# TODO +def format_window_frame(func, frame, **kw): + components = [] + + if frame.how == "rows" and frame.max_lookback is not None: + raise NotImplementedError( + "Rows with max lookback is not implemented for the Duckdb backend." 
+ ) + + if frame.group_by: + partition_args = ", ".join( + map(_sql, map(partial(translate_val, **kw), frame.group_by)) + ) + components.append(f"PARTITION BY {partition_args}") + + if frame.order_by: + order_args = ", ".join( + map(_sql, map(partial(translate_val, **kw), frame.order_by)) + ) + components.append(f"ORDER BY {order_args}") + + frame_clause_not_allowed = ( + ops.Lag, + ops.Lead, + ops.DenseRank, + ops.MinRank, + ops.NTile, + ops.PercentRank, + ops.CumeDist, + ops.RowNumber, + ) + + if frame.start is None and frame.end is None: + # no-op, default is full sample + pass + elif not isinstance(func, frame_clause_not_allowed): + if frame.start is None: + start = "UNBOUNDED PRECEDING" + else: + start = format_window_boundary(frame.start, **kw) + + if frame.end is None: + end = "UNBOUNDED FOLLOWING" + else: + end = format_window_boundary(frame.end, **kw) + + frame = f"{frame.how.upper()} BETWEEN {start} AND {end}" + components.append(frame) + + return f"OVER ({' '.join(components)})" + + +# TODO +_map_interval_to_microseconds = { + "W": 604800000000, + "D": 86400000000, + "h": 3600000000, + "m": 60000000, + "s": 1000000, + "ms": 1000, + "us": 1, + "ns": 0.001, +} + + +# TODO +UNSUPPORTED_REDUCTIONS = ( + ops.ApproxMedian, + ops.GroupConcat, + ops.ApproxCountDistinct, +) + + +# TODO +@translate_val.register(ops.WindowFunction) +def _window(op: ops.WindowFunction, **kw: Any): + if isinstance(op.func, UNSUPPORTED_REDUCTIONS): + raise com.UnsupportedOperationError( + f"{type(op.func)} is not supported in window functions" + ) + + if isinstance(op.func, ops.CumulativeOp): + arg = cumulative_to_window(op.func, op.frame) + return translate_val(arg, **kw) + + window_formatted = format_window_frame(op, op.frame, **kw) + func = op.func.__window_op__ + func_formatted = translate_val(func, **kw) + result = f"{func_formatted} {window_formatted}" + + if isinstance(func, ops.RankBase): + return f"({result} - 1)" + + return result + + +# TODO +def shift_like(op_class, name): + @translate_val.register(op_class) + def formatter(op, **kw): + arg = op.arg + offset = op.offset + default = op.default + + arg_fmt = translate_val(arg, **kw) + pieces = [arg_fmt] + + if default is not None: + if offset is None: + offset_fmt = "1" + else: + offset_fmt = translate_val(offset, **kw) + + default_fmt = translate_val(default, **kw) + + pieces.append(offset_fmt) + pieces.append(default_fmt) + elif offset is not None: + offset_fmt = translate_val(offset, **kw) + pieces.append(offset_fmt) + + return f"{name}({', '.join(map(_sql, pieces))})" + + return formatter + + +# TODO +shift_like(ops.Lag, "lagInFrame") +shift_like(ops.Lead, "leadInFrame") + + +# TODO +@translate_val.register(ops.RowNumber) +def _row_number(_, **kw): + return "row_number()" + + +# TODO +@translate_val.register(ops.DenseRank) +def _dense_rank(_, **kw): + return "dense_rank()" + + +# TODO +@translate_val.register(ops.MinRank) +def _rank(_, **kw): + return "rank()" + + +@translate_val.register(ops.ArrayStringJoin) +def _array_string_join(op, **kw): + arg = translate_val(op.arg, **kw) + sep = translate_val(op.sep, **kw) + return f"list_aggregate({arg}, 'string_agg', {sep})" + + +@translate_val.register(ops.Argument) +def _argument(op, **_): + return op.name + + +# TODO +@translate_val.register(ops.ArrayMap) +def _array_map(op, **kw): + arg = translate_val(op.arg, **kw) + result = translate_val(op.result, **kw) + return f"arrayMap(({op.parameter}) -> {result}, {arg})" + + +# TODO +@translate_val.register(ops.ArrayFilter) +def _array_filter(op, **kw): + 
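+    # Like ArrayMap just above, this still emits clickhouse's arrayFilter
+    # with a "(x) -> expr" lambda; a later commit in this series swaps in
+    # duckdb's list_filter.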
arg = translate_val(op.arg, **kw) + result = translate_val(op.result, **kw) + return f"arrayFilter(({op.parameter}) -> {result}, {arg})" + + +@translate_val.register(ops.ArrayPosition) +def _array_position(op, **kw): + arg = translate_val(op.arg, **kw) + el = translate_val(op.other, **kw) + return f"list_indexof({arg}, {el}) - 1" + + +@translate_val.register(ops.ArrayRemove) +def _array_remove(op, **kw): + return translate_val(ops.ArrayFilter(op.arg, flip(ops.NotEquals, op.other)), **kw) + + +@translate_val.register(ops.ArrayUnion) +def _array_union(op, **kw): + return translate_val(ops.ArrayDistinct(ops.ArrayConcat(op.left, op.right)), **kw) + + +# TODO +@translate_val.register(ops.ArrayZip) +def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: + arglist = [] + for arg in op.arg: + sql_arg = translate_val(arg, **kw) + with contextlib.suppress(AttributeError): + sql_arg = sql_arg.sql(dialect="duckdb") + arglist.append(sql_arg) + return f"arrayZip({', '.join(arglist)})" diff --git a/ibis/backends/duckdb/datatypes.py b/ibis/backends/duckdb/datatypes.py index e931867f4b14..80102f5b53b4 100644 --- a/ibis/backends/duckdb/datatypes.py +++ b/ibis/backends/duckdb/datatypes.py @@ -1,41 +1,42 @@ from __future__ import annotations -import duckdb_engine.datatypes as ducktypes +import functools + import sqlalchemy.dialects.postgresql as psql import ibis.expr.datatypes as dt from ibis.backends.base.sql.alchemy.datatypes import AlchemyType from ibis.backends.base.sqlglot.datatypes import DuckDBType as SqlglotDuckdbType -_from_duckdb_types = { - psql.BYTEA: dt.Binary, - psql.UUID: dt.UUID, - ducktypes.TinyInteger: dt.Int8, - ducktypes.SmallInteger: dt.Int16, - ducktypes.Integer: dt.Int32, - ducktypes.BigInteger: dt.Int64, - ducktypes.HugeInteger: dt.Decimal(38, 0), - ducktypes.UInt8: dt.UInt8, - ducktypes.UTinyInteger: dt.UInt8, - ducktypes.UInt16: dt.UInt16, - ducktypes.USmallInteger: dt.UInt16, - ducktypes.UInt32: dt.UInt32, - ducktypes.UInteger: dt.UInt32, - ducktypes.UInt64: dt.UInt64, - ducktypes.UBigInteger: dt.UInt64, -} - -_to_duckdb_types = { - dt.UUID: psql.UUID, - dt.Int8: ducktypes.TinyInteger, - dt.Int16: ducktypes.SmallInteger, - dt.Int32: ducktypes.Integer, - dt.Int64: ducktypes.BigInteger, - dt.UInt8: ducktypes.UTinyInteger, - dt.UInt16: ducktypes.USmallInteger, - dt.UInt32: ducktypes.UInteger, - dt.UInt64: ducktypes.UBigInteger, -} +# _from_duckdb_types = { +# psql.BYTEA: dt.Binary, +# psql.UUID: dt.UUID, +# ducktypes.TinyInteger: dt.Int8, +# ducktypes.SmallInteger: dt.Int16, +# ducktypes.Integer: dt.Int32, +# ducktypes.BigInteger: dt.Int64, +# ducktypes.HugeInteger: dt.Decimal(38, 0), +# ducktypes.UInt8: dt.UInt8, +# ducktypes.UTinyInteger: dt.UInt8, +# ducktypes.UInt16: dt.UInt16, +# ducktypes.USmallInteger: dt.UInt16, +# ducktypes.UInt32: dt.UInt32, +# ducktypes.UInteger: dt.UInt32, +# ducktypes.UInt64: dt.UInt64, +# ducktypes.UBigInteger: dt.UInt64, +# } + +# _to_duckdb_types = { +# dt.UUID: psql.UUID, +# dt.Int8: ducktypes.TinyInteger, +# dt.Int16: ducktypes.SmallInteger, +# dt.Int32: ducktypes.Integer, +# dt.Int64: ducktypes.BigInteger, +# dt.UInt8: ducktypes.UTinyInteger, +# dt.UInt16: ducktypes.USmallInteger, +# dt.UInt32: ducktypes.UInteger, +# dt.UInt64: ducktypes.UBigInteger, +# } class DuckDBType(AlchemyType): @@ -58,3 +59,124 @@ def from_ibis(cls, dtype): @classmethod def from_string(cls, type_string, nullable=True): return SqlglotDuckdbType.from_string(type_string, nullable=nullable) + + +@functools.singledispatch +def serialize(ty) -> str: + raise NotImplementedError(f"{ty} 
not serializable to DuckDB type string") + + +@serialize.register(dt.DataType) +def _(ty: dt.DataType) -> str: + ser_ty = serialize_raw(ty) + if not ty.nullable: + return f"{ser_ty} NOT NULL" + return ser_ty + + +@serialize.register(dt.Map) +def _(ty: dt.Map) -> str: + return serialize_raw(ty) + + +@functools.singledispatch +def serialize_raw(ty: dt.DataType) -> str: + raise NotImplementedError(f"{ty} not serializable to DuckDB type string") + + +@serialize_raw.register(dt.DataType) +def _(ty: dt.DataType) -> str: + return type(ty).__name__.capitalize() + + +@serialize_raw.register(dt.Int8) +def _(_: dt.Int8) -> str: + return "TINYINT" + + +@serialize_raw.register(dt.Int16) +def _(_: dt.Int16) -> str: + return "SMALLINT" + + +@serialize_raw.register(dt.Int32) +def _(_: dt.Int32) -> str: + return "INTEGER" + + +@serialize_raw.register(dt.Int64) +def _(_: dt.Int64) -> str: + return "BIGINT" + + +@serialize_raw.register(dt.UInt8) +def _(_: dt.UInt8) -> str: + return "UTINYINT" + + +@serialize_raw.register(dt.UInt16) +def _(_: dt.UInt16) -> str: + return "USMALLINT" + + +@serialize_raw.register(dt.UInt32) +def _(_: dt.UInt32) -> str: + return "UINTEGER" + + +@serialize_raw.register(dt.UInt64) +def _(_: dt.UInt64) -> str: + return "UBIGINT" + + +@serialize_raw.register(dt.Float32) +def _(_: dt.Float32) -> str: + return "FLOAT" + + +@serialize_raw.register(dt.Float64) +def _(_: dt.Float64) -> str: + return "DOUBLE" + + +@serialize_raw.register(dt.Binary) +def _(_: dt.Binary) -> str: + return "BLOB" + + +@serialize_raw.register(dt.Boolean) +def _(_: dt.Boolean) -> str: + return "BOOLEAN" + + +@serialize_raw.register(dt.Array) +def _(ty: dt.Array) -> str: + return f"Array({serialize(ty.value_type)})" + + +@serialize_raw.register(dt.Map) +def _(ty: dt.Map) -> str: + # nullable key type is not allowed inside maps + key_type = serialize_raw(ty.key_type) + value_type = serialize(ty.value_type) + return f"Map({key_type}, {value_type})" + + +@serialize_raw.register(dt.Struct) +def _(ty: dt.Struct) -> str: + fields = ", ".join( + f"{name} {serialize(field_ty)}" for name, field_ty in ty.fields.items() + ) + return f"STRUCT({fields})" + + +@serialize_raw.register(dt.Timestamp) +def _(ty: dt.Timestamp) -> str: + if ty.timezone: + return "TIMESTAMPTZ" + return "TIMESTAMP" + + +@serialize_raw.register(dt.Decimal) +def _(ty: dt.Decimal) -> str: + return f"Decimal({ty.precision}, {ty.scale})" diff --git a/ibis/backends/duckdb/tests/conftest.py b/ibis/backends/duckdb/tests/conftest.py index 698004ae9f6d..1f8d70c05a41 100644 --- a/ibis/backends/duckdb/tests/conftest.py +++ b/ibis/backends/duckdb/tests/conftest.py @@ -50,6 +50,11 @@ def load_tpch(self) -> None: with self.connection.begin() as con: con.exec_driver_sql("CALL dbgen(sf=0.1)") + def _load_data(self, **_: Any) -> None: + """Load test data into a backend.""" + for stmt in self.ddl_script: + self.connection.raw_sql(stmt) + @pytest.fixture(scope="session") def con(data_dir, tmp_path_factory, worker_id): diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index df7031ea0c0f..af2de4601095 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -1243,6 +1243,9 @@ def test_persist_expression_repeated_cache(alltypes): assert not nested_cached_table.to_pandas().empty +@mark.broken( + "duckdb", reason="table name has `main` prepended, breaking the match check" +) @mark.notimpl(["datafusion", "bigquery", "impala", "trino", "druid"]) @mark.never( ["mssql"], diff --git 
a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index 0f04f2e5435c..ac4d4995ee59 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -386,8 +386,8 @@ def test_numeric_literal(con, backend, expr, expected_types): ), pytest.mark.broken( ["duckdb"], - "(duckdb.ParserException) Parser Error: Width must be between 1 and 38!", - raises=sa.exc.ProgrammingError, + "Unsupported precision.", + raises=NotImplementedError, ), pytest.mark.notyet(["datafusion"], raises=Exception), pytest.mark.notyet( @@ -427,8 +427,8 @@ def test_numeric_literal(con, backend, expr, expected_types): ), pytest.mark.broken( ["duckdb"], - "duckdb.ConversionException: Conversion Error: Could not cast value inf to DECIMAL(18,3)", - raises=DuckDBConversionException, + "Unsupported precision. Supported values: [1 : 38]. Current value: None", + raises=NotImplementedError, ), pytest.mark.broken( ["trino"], @@ -507,8 +507,8 @@ def test_numeric_literal(con, backend, expr, expected_types): ), pytest.mark.broken( ["duckdb"], - "duckdb.ConversionException: Conversion Error: Could not cast value -inf to DECIMAL(18,3)", - raises=DuckDBConversionException, + "Unsupported precision. Supported values: [1 : 38]. Current value: None", + raises=NotImplementedError, ), pytest.mark.broken( ["trino"], @@ -587,11 +587,7 @@ def test_numeric_literal(con, backend, expr, expected_types): ), pytest.mark.broken( ["duckdb"], - "(duckdb.InvalidInputException) Invalid Input Error: Attempting " - "to execute an unsuccessful or closed pending query result" - "Error: Invalid Input Error: Type DOUBLE with value nan can't be " - "cast because the value is out of range for the destination type INT64", - raises=sa.exc.ProgrammingError, + "Unsupported precision. Supported values: [1 : 38]. 
Current value: None", ), pytest.mark.broken( ["trino"], @@ -1286,7 +1282,6 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): "mysql": 10, "snowflake": 38, "trino": 18, - "duckdb": None, "sqlite": None, "mssql": None, "oracle": 38, @@ -1296,7 +1291,6 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): "mysql": 0, "snowflake": 0, "trino": 3, - "duckdb": None, "sqlite": None, "mssql": None, "oracle": 0, @@ -1310,6 +1304,7 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): "clickhouse", "dask", "datafusion", + "duckdb", "impala", "pandas", "pyspark", diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index ca6ef8ac0ef4..5973c307a1ba 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -860,6 +860,11 @@ def test_string(backend, alltypes, df, result_func, expected_func): ["mysql", "mssql", "druid", "oracle"], raises=com.OperationNotDefinedError, ) +@pytest.mark.broken( + ["duckdb"], + reason="no idea, generated SQL looks very correct but this fails", + raises=AssertionError, +) def test_re_replace_global(con): expr = ibis.literal("aba").re_replace("a", "c") result = con.execute(expr) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 2f3f6eb3ed71..e8c0df5d78f7 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -933,11 +933,6 @@ def convert_to_offset(x): raises=ValidationError, reason="unsupported operand type(s) for -: 'StringColumn' and 'TimestampScalar'", ), - pytest.mark.xfail_version( - duckdb=["duckdb>=0.8.0"], - raises=AssertionError, - reason="duckdb 0.8.0 returns DateOffset columns", - ), ], ), param( @@ -1446,7 +1441,7 @@ def test_strftime(backend, alltypes, df, expr_fn, pandas_pattern): reason="PySpark backend does not support timestamp from unix time with unit us. Supported unit is s.", ), pytest.mark.notimpl( - ["duckdb", "mssql", "clickhouse"], + ["mssql", "clickhouse"], raises=com.UnsupportedOperationError, reason="`us` unit is not supported!", ), @@ -1458,12 +1453,12 @@ def test_strftime(backend, alltypes, df, expr_fn, pandas_pattern): pytest.mark.notimpl( ["pyspark"], raises=com.UnsupportedArgumentError, - reason="PySpark backend does not support timestamp from unix time with unit ms. Supported unit is s.", + reason="PySpark backend does not support timestamp from unix time with unit ns. 
Supported unit is s.", ), pytest.mark.notimpl( ["duckdb", "mssql", "clickhouse"], raises=com.UnsupportedOperationError, - reason="`ms` unit is not supported!", + reason="`ns` unit is not supported!", ), ], ), @@ -1989,11 +1984,6 @@ def test_extract_time_from_timestamp(con, microsecond): reason="Driver doesn't know how to handle intervals", raises=ClickhouseOperationalError, ) -@pytest.mark.xfail_version( - duckdb=["duckdb>=0.8.0"], - raises=AssertionError, - reason="duckdb 0.8.0 returns DateOffset columns", -) def test_interval_literal(con, backend): expr = ibis.interval(1, unit="s") result = con.execute(expr) From ec830d6d5a2fd235c0f9e1c6807bff131f393e25 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 30 Aug 2023 13:24:44 -0400 Subject: [PATCH 002/222] feat(duckdb): simple and searched case support in sqlglot --- ibis/backends/duckdb/compiler/values.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index df7c1f179d07..5abdcc73490f 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -864,20 +864,22 @@ def _sql(obj, dialect="duckdb"): @translate_val.register(ops.SimpleCase) @translate_val.register(ops.SearchedCase) def _case(op, **kw): - buf = ["CASE"] + case = sg.expressions.Case() if (base := getattr(op, "base", None)) is not None: + breakpoint() buf.append(translate_val(base, **kw)) for when, then in zip(op.cases, op.results): - buf.append(f"WHEN {translate_val(when, **kw)}") - buf.append(f"THEN {translate_val(then, **kw)}") + case = case.when( + condition=translate_val(when, **kw), + then=translate_val(then, **kw), + ) if (default := op.default) is not None: - buf.append(f"ELSE {translate_val(default, **kw)}") + case = case.else_(condition=translate_val(default, **kw)) - buf.append("END") - return " ".join(map(_sql, buf)) + return case @translate_val.register(ops.TableArrayView) From 783192f0a8136521ca49c72b7af8e7415c064723 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 30 Aug 2023 14:48:01 -0400 Subject: [PATCH 003/222] fix(duckdb): handle base table in case statement --- ibis/backends/duckdb/compiler/values.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 5abdcc73490f..99efe302af04 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -860,15 +860,13 @@ def _sql(obj, dialect="duckdb"): return obj -# TODO @translate_val.register(ops.SimpleCase) @translate_val.register(ops.SearchedCase) def _case(op, **kw): case = sg.expressions.Case() if (base := getattr(op, "base", None)) is not None: - breakpoint() - buf.append(translate_val(base, **kw)) + case = sg.expressions.Case(this=translate_val(base, **kw)) for when, then in zip(op.cases, op.results): case = case.when( From 610b41850a6e54f148a7cb8b4a79072a5fbd89ba Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 30 Aug 2023 14:48:30 -0400 Subject: [PATCH 004/222] feat(duckdb): ArrayMap and ArrayFilter --- ibis/backends/duckdb/compiler/values.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 99efe302af04..8e36fa3402ed 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1439,20 +1439,18 @@ def _argument(op, **_): return op.name -# TODO 
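+# NB: duckdb's list_transform/list_filter expect the lambda written as
+# "x -> expr", which the f-strings below splice into the call.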
@translate_val.register(ops.ArrayMap) def _array_map(op, **kw): arg = translate_val(op.arg, **kw) result = translate_val(op.result, **kw) - return f"arrayMap(({op.parameter}) -> {result}, {arg})" + return sg.func("list_transform", arg, f"{op.parameter}) -> {result}") -# TODO @translate_val.register(ops.ArrayFilter) def _array_filter(op, **kw): arg = translate_val(op.arg, **kw) result = translate_val(op.result, **kw) - return f"arrayFilter(({op.parameter}) -> {result}, {arg})" + return sg.func("list_filter", arg, f"{op.parameter} -> {result}") @translate_val.register(ops.ArrayPosition) From d3682199152a8a223304b00da4a4a13a4dfebdd5 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 30 Aug 2023 14:48:54 -0400 Subject: [PATCH 005/222] feat(duckdb): fix up window function lags and leads --- ibis/backends/duckdb/compiler/values.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 8e36fa3402ed..8042343056eb 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1399,32 +1399,29 @@ def formatter(op, **kw): offset_fmt = translate_val(offset, **kw) pieces.append(offset_fmt) - return f"{name}({', '.join(map(_sql, pieces))})" + res = f"{name}({', '.join(map(_sql, pieces))})" + return res return formatter -# TODO -shift_like(ops.Lag, "lagInFrame") -shift_like(ops.Lead, "leadInFrame") +shift_like(ops.Lag, "lag") +shift_like(ops.Lead, "lead") -# TODO @translate_val.register(ops.RowNumber) def _row_number(_, **kw): - return "row_number()" + return sg.expressions.RowNumber() -# TODO @translate_val.register(ops.DenseRank) def _dense_rank(_, **kw): - return "dense_rank()" + return sg.func("dense_rank") -# TODO @translate_val.register(ops.MinRank) def _rank(_, **kw): - return "rank()" + return sg.func("rank") @translate_val.register(ops.ArrayStringJoin) From ec585607f85b333256184855f1f2796fa9b66bd9 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 30 Aug 2023 15:06:31 -0400 Subject: [PATCH 006/222] feat(duckdb): implement current_database, list_schemas --- ibis/backends/duckdb/__init__.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index e3b6ad7e554a..d1927f4efabd 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -93,7 +93,11 @@ def _register_udfs(self, expr): @property def current_database(self) -> str: - return "main" + return ( + self.raw_sql("PRAGMA database_size; CALL pragma_database_size();") + .arrow()["database_name"] + .to_pylist()[0] + ) @property def current_schema(self) -> str: @@ -185,7 +189,8 @@ def _clean_up_cached_table(self, op): self.drop_table(op.name) def list_schemas(self): - ... + out = self.raw_sql("SELECT current_schemas(True) as schemas").arrow() + return list(set(out["schemas"].to_pylist()[0])) def table(self, name: str, database: str | None = None) -> ir.Table: """Construct a table expression. 
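For reference, the two queries patch 006 leans on can be sanity-checked
against the plain duckdb Python API. A minimal sketch, assuming duckdb >= 0.8;
the pragma and function names below match duckdb's documentation, but treat
the exact result shapes as an assumption rather than something this series
guarantees:

import duckdb

con = duckdb.connect()

# pragma_database_size() reports one row per attached database; the backend
# picks the first database_name as the current database.
print(con.sql("CALL pragma_database_size()").arrow()["database_name"])

# current_schemas(True) returns a single LIST value holding the schemas on
# the search path; list_schemas() flattens and dedupes it.
print(con.sql("SELECT current_schemas(True) AS schemas").fetchall())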
From f6f2fee3622df14b2036ae4d0e1b21e5ddaa64be Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 30 Aug 2023 15:24:35 -0400 Subject: [PATCH 007/222] feat(duckdb): fix xor chore(duckdb): random crap chore(duckdb): few more str -> sqlglot exprs --- ibis/backends/duckdb/compiler/values.py | 103 +++++++++++------------- 1 file changed, 45 insertions(+), 58 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 8042343056eb..d4490e2581b8 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -86,7 +86,7 @@ def _bitwise_xor(op, **kw): left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) - return f"xor({left}, {right})" + return sg.func("xor", left, right, dialect="duckdb") @translate_val.register(ops.BitwiseNot) @@ -158,6 +158,12 @@ def _try_cast(op, **kw): ) +@translate_val.register(ops.TypeOf) +def _type_of(op, **kw): + arg = translate_val(op.arg, **kw) + return sg.func("typeof", arg) + + ### Comparator Conundrums @@ -166,27 +172,19 @@ def _between(op, **kw): arg = translate_val(op.arg, **kw) lower_bound = translate_val(op.lower_bound, **kw) upper_bound = translate_val(op.upper_bound, **kw) - return f"{arg} BETWEEN {lower_bound} AND {upper_bound}" + return sg.expressions.Between(this=arg, low=lower_bound, high=upper_bound) @translate_val.register(ops.Negate) def _negate(op, **kw): arg = translate_val(op.arg, **kw) - return f"-{_parenthesize(op.arg, arg)}" + return sg.expressions.Neg(this=arg) @translate_val.register(ops.Not) def _not(op, **kw): arg = translate_val(op.arg, **kw) - return f"NOT {_parenthesize(op.arg, arg)}" - - -def _parenthesize(op, arg): - # function calls don't need parens - if isinstance(op, (ops.Binary, ops.Unary)): - return f"({arg})" - else: - return arg + return sg.expressions.Not(this=arg) ### Timey McTimeFace @@ -532,8 +530,8 @@ def _levenshtein(op, **kw): # Other operations ops.Where: "if", ops.ArrayLength: "length", - ops.ArrayConcat: "arrayConcat", # TODO - ops.Unnest: "arrayJoin", # TODO + ops.ArrayConcat: "list_concat", + ops.Unnest: "unnest", ops.Degrees: "degrees", ops.Radians: "radians", ops.NullIf: "nullIf", @@ -576,9 +574,6 @@ def _fmt(op, _name: str = _name, **kw): ### NULL PLAYER CHARACTER -# ops.IsNull: "isNull", # TODO -# ops.NotNull: "isNotNull", # TODO -# ops.IfNull: "ifNull", # TODO @translate_val.register(ops.IsNull) def _is_null(op, **kw): arg = translate_val(op.arg, **kw) @@ -645,7 +640,10 @@ def _literal(op, **kw): elif dtype.is_inet(): com.UnsupportedOperationError("DuckDB doesn't support an explicit inet dtype") elif dtype.is_string(): - return value + # TODO: if this is stringified, then test_select_filter_select breaks + # if it isn't, then try_cast breaks + # There's sqlglot.to_identifer which might help with this + return f"'{value}'" elif dtype.is_decimal(): precision = dtype.precision scale = dtype.scale @@ -666,7 +664,7 @@ def _literal(op, **kw): return f"'{repr(value)}inity'::FLOAT" elif math.isnan(value): return "'NaN'::FLOAT" - return value + return repr(value) elif dtype.is_interval(): return _interval_format(op) elif dtype.is_timestamp(): @@ -711,41 +709,29 @@ def _literal(op, **kw): ### BELOW HERE BE DRAGONS -# TODO -@translate_val.register(ops.ArrayRepeat) -def _array_repeat_op(op, **kw): - arg = translate_val(op.arg, **kw) - times = translate_val(op.times, **kw) - from_ = f"(SELECT {arg} AS arr FROM system.numbers LIMIT {times})" - query = sg.parse_one( - f"SELECT arrayFlatten(groupArray(arr)) FROM 
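+# # (kept for reference from the clickhouse port; system.numbers is a
+# # clickhouse table with no direct duckdb analogue)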
{from_}", read="duckdb" - ) - return query.subquery() +# # TODO +# @translate_val.register(ops.ArrayRepeat) +# def _array_repeat_op(op, **kw): +# arg = translate_val(op.arg, **kw) +# times = translate_val(op.times, **kw) +# from_ = f"(SELECT {arg} AS arr FROM system.numbers LIMIT {times})" +# query = sg.parse_one( +# f"SELECT arrayFlatten(groupArray(arr)) FROM {from_}", read="duckdb" +# ) +# return query.subquery() -# TODO @translate_val.register(ops.ArraySlice) def _array_slice_op(op, **kw): arg = translate_val(op.arg, **kw) start = translate_val(op.start, **kw) - start = _parenthesize(op.start, start) - start_correct = f"if({start} < 0, {start}, {start} + 1)" if (stop := op.stop) is not None: stop = translate_val(stop, **kw) - stop = _parenthesize(op.stop, stop) - - neg_start = f"(length({arg}) + {start})" - diff_fmt = f"greatest(-0, {stop} - {{}})".format - - length = ( - f"if({stop} < 0, {stop}, " - f"if({start} < 0, {diff_fmt(neg_start)}, {diff_fmt(start)}))" - ) - - return f"arraySlice({arg}, {start_correct}, {length})" + else: + stop = sg.expressions.Null() - return f"arraySlice({arg}, {start_correct})" + return sg.func("list_slice", arg, start, stop) @translate_val.register(ops.CountStar) @@ -816,7 +802,10 @@ def _aggregate(op, func, *, where=None, **kw): ] if where is not None: predicate = translate_val(where, **kw) - return sg.func(func, *args).where(predicate) + return sg.expressions.Filter( + this=sg.func(func, *args, dialect="duckdb"), + expression=sg.expressions.Where(this=predicate), + ) res = sg.func(func, *args) return res @@ -900,6 +889,8 @@ def _table_array_view(op, *, cache, **kw): def _exists_subquery(op, **kw): from ibis.backends.duckdb.compiler.relations import translate_rel + if not "table" in kw: + kw["table"] = translate_rel(op.foreign_table.table, **kw) foreign_table = translate_rel(op.foreign_table, **kw) predicates = translate_val(op.predicates, **kw) subq = ( @@ -953,23 +944,21 @@ def _string_capitalize(op, **kw): return f"CONCAT(UPPER(SUBSTR({arg}, 1, 1)), LOWER(SUBSTR({arg}, 2)))" -# TODO @translate_val.register(ops.GroupConcat) def _group_concat(op, **kw): arg = translate_val(op.arg, **kw) sep = translate_val(op.sep, **kw) - args = [arg] - func = "groupArray" + concat = sg.func("array_to_string", arg, sep, dialect="duckdb") if (where := op.where) is not None: - func += "If" - args.append(translate_val(where, **kw)) + predicate = translate_val(where, **kw) + return sg.expressions.Filter( + this=concat, + expression=sg.expressions.Where(this=predicate), + ) - joined_args = ", ".join(map(_sql, args)) - call = f"{func}({joined_args})" - expr = f"list_concat({call}, {sep})" - return f"CASE WHEN empty({call}) THEN NULL ELSE {expr} END" + return concat # TODO @@ -1224,15 +1213,14 @@ def formatter(op, **kw): def _equals(op, **kw): left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) - return left.eq(right) + return sg.expressions.EQ(this=left, expression=right) @translate_val.register(ops.NotEquals) def _equals(op, **kw): left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) - breakpoint() - return left.eq(right) + return sg.expressions.NEQ(this=left, expression=right) # TODO @@ -1374,7 +1362,6 @@ def _window(op: ops.WindowFunction, **kw: Any): return result -# TODO def shift_like(op_class, name): @translate_val.register(op_class) def formatter(op, **kw): From b27a24f565bede960c34eef4a43c627e59257c1a Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 31 Aug 2023 10:58:26 -0400 Subject: [PATCH 008/222] feat(duckdb): 
sqlglot exprs for bitwise ops feat(duckdb): sg.funcify a few more ops feat(duckdb): add logical xor, more sglglot exprs fix(duckdb): remove nested strings for ops.InValues feat(duckdb): current timestamp should be case to no timezone chore(duckdb): date expr and cast numerics feat(duckdb): use sqlglot literals fix(duckdb): use sqlglot for more temporal ops --- ibis/backends/duckdb/compiler/values.py | 181 +++++++++++++----------- 1 file changed, 97 insertions(+), 84 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index d4490e2581b8..aabfa860690c 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -62,10 +62,11 @@ def _alias(op, render_aliases: bool = True, **kw): ### Bitwise Business _bitwise_mapping = { - ops.BitwiseLeftShift: "<<", - ops.BitwiseRightShift: ">>", - ops.BitwiseAnd: "&", - ops.BitwiseOr: "|", + ops.BitwiseLeftShift: sg.expressions.BitwiseLeftShift, + ops.BitwiseRightShift: sg.expressions.BitwiseRightShift, + ops.BitwiseAnd: sg.expressions.BitwiseAnd, + ops.BitwiseOr: sg.expressions.BitwiseOr, + ops.BitwiseXor: sg.expressions.BitwiseXor, } @@ -73,27 +74,20 @@ def _alias(op, render_aliases: bool = True, **kw): @translate_val.register(ops.BitwiseRightShift) @translate_val.register(ops.BitwiseAnd) @translate_val.register(ops.BitwiseOr) -def _bitwise_binary(op, **kw): - left = translate_val(op.left, **kw) - right = translate_val(op.right, **kw) - _operator = _bitwise_mapping[type(op)] - - return f"{left} {_operator} {right}" - - @translate_val.register(ops.BitwiseXor) -def _bitwise_xor(op, **kw): +def _bitwise_binary(op, **kw): left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) + sg_expr = _bitwise_mapping[type(op)] - return sg.func("xor", left, right, dialect="duckdb") + return sg_expr(this=left, expression=right) @translate_val.register(ops.BitwiseNot) def _bitwise_not(op, **kw): value = translate_val(op.arg, **kw) - return f"~{value}" + return sg.expressions.BitwiseNot(this=value) ### Mathematical Calisthenics @@ -110,8 +104,8 @@ def _generic_log(op, **kw): arg = translate_val(arg, **kw) if base is not None: base = translate_val(base, **kw) - return f"ln({arg}) / ln({base})" - return f"ln({arg})" + return sg.func("ln", arg) / sg.func("ln", base) + return sg.func("ln", arg) ### Dtype Dysmorphia @@ -141,11 +135,11 @@ def _cast(op, **kw): ) else: - return f"INTERVAL {arg} {suffix}" + return sg.expressions.Interval(this=arg, unit=suffix, dialect="duckdb") elif isinstance(op.to, dt.Timestamp) and isinstance(op.arg.dtype, dt.Integer): - return sg.func("to_timestamp", arg) + return sg.func("to_timestamp", arg, dialect="duckdb") elif isinstance(op.to, dt.Timestamp) and (timezone := op.to.timezone) is not None: - return sg.func("timezone", timezone, arg) + return sg.func("timezone", f"'{timezone}'", arg, dialect="duckdb") to = translate_val(op.to, **kw) return sg.cast(expression=arg, to=to) @@ -161,7 +155,7 @@ def _try_cast(op, **kw): @translate_val.register(ops.TypeOf) def _type_of(op, **kw): arg = translate_val(op.arg, **kw) - return sg.func("typeof", arg) + return sg.func("typeof", arg, dialect="duckdb") ### Comparator Conundrums @@ -193,6 +187,7 @@ def _not(op, **kw): @translate_val.register(ops.Date) def _to_date(op, **kw): arg = translate_val(op.arg, **kw) + return sg.expressions.Date(this=arg) return f"DATE {arg}" @@ -202,6 +197,12 @@ def _time(op, **kw): return f"{arg}::TIME" +@translate_val.register(ops.TimestampNow) +def _timestamp_now(op, 
**kw): + """DuckDB current timestamp defaults to timestamp + tz""" + return sg.cast(expression=sg.func("current_timestamp"), to="TIMESTAMP") + + @translate_val.register(ops.Strftime) def _strftime(op, **kw): if not isinstance(op.format_str, ops.Literal): @@ -232,7 +233,13 @@ def _string_to_timestamp(op, **kw): def _extract_epoch_seconds(op, **kw): arg = translate_val(op.arg, **kw) # TODO: do we need the TIMESTAMP cast? - return f"epoch({arg}::TIMESTAMP)" + return sg.func( + "epoch", + sg.expressions.cast( + expression=sg.expressions.Literal(this=arg, is_string=True), + to=sg.expressions.DataType.Type.TIMESTAMP, + ), + ) _extract_mapping = { @@ -260,7 +267,9 @@ def _extract_epoch_seconds(op, **kw): def _extract_time(op, **kw): part = _extract_mapping[type(op)] timestamp = translate_val(op.arg, **kw) - return f"extract({part}, {timestamp})" + return sg.func( + "extract", sg.expressions.Literal(this=part, is_string=True), timestamp + ) # DuckDB extracts subminute microseconds and milliseconds @@ -319,7 +328,7 @@ def _date_from_ymd(op, **kw): y = translate_val(op.year, **kw) m = translate_val(op.month, **kw) d = translate_val(op.day, **kw) - return f"make_date({y}, {m}, {d})" + return sg.expressions.DateFromParts(year=y, month=m, day=d) @translate_val.register(ops.DayOfWeekIndex) @@ -331,15 +340,10 @@ def _day_of_week_index(op, **kw): @translate_val.register(ops.TimestampFromUNIX) def _timestamp_from_unix(op, **kw): arg = translate_val(op.arg, **kw) - if (unit := op.unit.short) in {"ns"}: + if (unit := op.unit.short) in {"ms", "us", "ns"}: raise com.UnsupportedOperationError(f"{unit!r} unit is not supported!") - if op.unit.short == "ms": - return f"to_timestamp({arg[:-3]}) + INTERVAL {arg[-3:]} millisecond" - elif op.unit.short == "us": - return f"to_timestamp({arg[:-6]}) + INTERVAL {arg[-6:]} microsecond" - - return f"to_timestamp({arg})" + return sg.expressions.UnixToTime(this=arg) @translate_val.register(ops.TimestampFromYMDHMS) @@ -361,9 +365,9 @@ def _timestamp_from_ymdhms(op, **kw): _interval_mapping = { - ops.IntervalAdd: operator.add, - ops.IntervalSubtract: operator.sub, - ops.IntervalMultiply: operator.mul, + ops.IntervalAdd: sg.expressions.Add, + ops.IntervalSubtract: sg.expressions.Sub, + ops.IntervalMultiply: sg.expressions.Mul, } @@ -373,9 +377,9 @@ def _timestamp_from_ymdhms(op, **kw): def _interval_binary(op, **kw): left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) - _operator = _interval_mapping[type(op)] + sg_expr = _interval_mapping[type(op)] - return operator(left, right) + return sg_expr(this=left, expression=right) def _interval_format(op): @@ -385,7 +389,10 @@ def _interval_format(op): "Duckdb doesn't support nanosecond interval resolutions" ) - return f"INTERVAL {op.value} {dtype.resolution.upper()}" + return sg.expressions.Interval( + this=sg.expressions.Literal(this=op.value, is_string=False), + unit=dtype.resolution.upper(), + ) @translate_val.register(ops.IntervalFromInteger) @@ -525,8 +532,6 @@ def _levenshtein(op, **kw): ops.Strip: "trim", ops.StringAscii: "ascii", ops.StrRight: "right", - # Temporal operations - ops.TimestampNow: "current_timestamp", # Other operations ops.Where: "if", ops.ArrayLength: "length", @@ -609,7 +614,9 @@ def _in_values(op, **kw): if not op.options: return False value = translate_val(op.value, **kw) - options = [translate_val(x, **kw) for x in op.options] + options = sg.expressions.Array().from_arg_list( + [translate_val(x, **kw) for x in op.options] + ) return sg.func("list_contains", options, value, 
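+        # list_contains takes (list, element), hence options before value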
dialect="duckdb") @@ -633,17 +640,14 @@ def _literal(op, **kw): dtype = op.dtype if value is None and dtype.nullable: if dtype.is_null(): - return "Null" + return sg.expressions.Null() return f"CAST(Null AS {serialize(dtype)})" if dtype.is_boolean(): - return str(int(bool(value))) + return sg.expressions.Boolean(value) elif dtype.is_inet(): com.UnsupportedOperationError("DuckDB doesn't support an explicit inet dtype") elif dtype.is_string(): - # TODO: if this is stringified, then test_select_filter_select breaks - # if it isn't, then try_cast breaks - # There's sqlglot.to_identifer which might help with this - return f"'{value}'" + return sg.expressions.Literal(this=f"{value}", is_string=True) elif dtype.is_decimal(): precision = dtype.precision scale = dtype.scale @@ -661,10 +665,17 @@ def _literal(op, **kw): return f"{value!s}::decimal({precision}, {scale})" elif dtype.is_numeric(): if math.isinf(value): - return f"'{repr(value)}inity'::FLOAT" + return sg.expressions.cast( + expression=sg.expressions.Literal(this=value, is_string=True), + to=sg.expressions.DataType.Type.FLOAT, + ) elif math.isnan(value): - return "'NaN'::FLOAT" - return repr(value) + return sg.expressions.cast( + expression=sg.expressions.Literal(this="NaN", is_string=True), + to=sg.expressions.DataType.Type.FLOAT, + ) + # return value + return sg.expressions.Literal(this=f"{value}", is_string=False) elif dtype.is_interval(): return _interval_format(op) elif dtype.is_timestamp(): @@ -1029,18 +1040,16 @@ def _repeat(op, **kw): return f"repeat({arg}, {times})" -# TODO @translate_val.register(ops.NullIfZero) def _null_if_zero(op, **kw): arg = translate_val(op.arg, **kw) - return f"nullIf({arg}, 0)" + return sg.func("nullif", arg, 0, dialect="duckdb") -# TODO @translate_val.register(ops.ZeroIfNull) def _zero_if_null(op, **kw): arg = translate_val(op.arg, **kw) - return f"ifNull({arg}, 0)" + return sg.func("ifnull", arg, 0, dialect="duckdb") @translate_val.register(ops.FloorDivide) @@ -1167,12 +1176,12 @@ def _map_get(op, **kw): return f"if(mapContains({arg}, {key}), {arg}[{key}], {default})" -def _binary_infix(symbol: str): +def _binary_infix(sg_expr: sg.expressions._Expression): def formatter(op, **kw): - left = translate_val(op_left := op.left, **kw) - right = translate_val(op_right := op.right, **kw) + left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) - return symbol(left, right) + return sg_expr(this=left, expression=right, dialect="duckdb") return formatter @@ -1181,25 +1190,28 @@ def formatter(op, **kw): _binary_infix_ops = { # Binary operations - ops.Add: operator.add, - ops.Subtract: operator.sub, - ops.Multiply: operator.mul, - ops.Divide: operator.truediv, - ops.Modulus: operator.mod, + ops.Add: sg.expressions.Add, + ops.Subtract: sg.expressions.Sub, + ops.Multiply: sg.expressions.Mul, + ops.Divide: sg.expressions.Div, + ops.Modulus: sg.expressions.Mod, # Comparisons - ops.GreaterEqual: operator.ge, - ops.Greater: operator.gt, - ops.LessEqual: operator.le, - ops.Less: operator.lt, + ops.GreaterEqual: sg.expressions.GTE, + ops.Greater: sg.expressions.GT, + ops.LessEqual: sg.expressions.LTE, + ops.Less: sg.expressions.LT, + ops.Equals: sg.expressions.EQ, + ops.NotEquals: sg.expressions.NEQ, + ops.Xor: sg.expressions.Xor, # Boolean comparisons - ops.And: operator.and_, - ops.Or: operator.or_, - ops.DateAdd: operator.add, - ops.DateSub: operator.sub, - ops.DateDiff: operator.sub, - ops.TimestampAdd: operator.add, - ops.TimestampSub: operator.sub, - ops.TimestampDiff: operator.sub, + ops.And: 
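+    # NB: some entries here (e.g. ops.Xor) are re-registered with bespoke
+    # handlers after this table; the later singledispatch registration wins.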
sg.expressions.And, + ops.Or: sg.expressions.Or, + ops.DateAdd: sg.expressions.Add, + ops.DateSub: sg.expressions.Sub, + ops.DateDiff: sg.expressions.Sub, + ops.TimestampAdd: sg.expressions.Add, + ops.TimestampSub: sg.expressions.Sub, + ops.TimestampDiff: sg.expressions.Sub, } @@ -1209,18 +1221,19 @@ def formatter(op, **kw): del _op, _sym -@translate_val.register(ops.Equals) -def _equals(op, **kw): +@translate_val.register(ops.Xor) +def _xor(op, **kw): + # TODO: is this really the best way to do this? left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) - return sg.expressions.EQ(this=left, expression=right) - - -@translate_val.register(ops.NotEquals) -def _equals(op, **kw): - left = translate_val(op.left, **kw) - right = translate_val(op.right, **kw) - return sg.expressions.NEQ(this=left, expression=right) + return sg.expressions.And( + this=sg.expressions.Paren(this=sg.expressions.Or(this=left, expression=right)), + expression=sg.expressions.Paren( + this=sg.expressions.Not( + this=sg.expressions.And(this=left, expression=right) + ) + ), + ) # TODO From 9a7da6a23ce5cb8afe2ed15bb085150f40dd9145 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 31 Aug 2023 15:06:04 -0400 Subject: [PATCH 009/222] feat(duckdb): like, ilike, capitalize as sqlglot exprs --- ibis/backends/duckdb/compiler/values.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index aabfa860690c..3db0b0582aa1 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -938,21 +938,24 @@ def _string_concat(op, **kw): def _string_like(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) - return f"{arg} LIKE {pattern}" + return sg.expressions.Like(this=arg, expression=pattern) @translate_val.register(ops.StringSQLILike) def _string_ilike(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) - return f"lower({arg}) LIKE lower({pattern})" + return sg.expressions.Like(this=sg.func("lower", arg), expression=pattern) -# TODO @translate_val.register(ops.Capitalize) def _string_capitalize(op, **kw): arg = translate_val(op.arg, **kw) - return f"CONCAT(UPPER(SUBSTR({arg}, 1, 1)), LOWER(SUBSTR({arg}, 2)))" + return sg.func( + "concat", + sg.func("upper", sg.func("substr", arg, 1, 1)), + sg.func("lower", sg.func("substr", arg, 2)), + ) @translate_val.register(ops.GroupConcat) @@ -1181,7 +1184,7 @@ def formatter(op, **kw): left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) - return sg_expr(this=left, expression=right, dialect="duckdb") + return sg_expr(this=left, expression=right) return formatter From 95f0e43f800041a29391d32f4132ec96180001f1 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 31 Aug 2023 15:07:07 -0400 Subject: [PATCH 010/222] feat(duckdb): literal arrays -> sqlglot --- ibis/backends/duckdb/compiler/values.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 3db0b0582aa1..f3c8832c1bbd 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -693,13 +693,16 @@ def _literal(op, **kw): else: return f"make_timestamp({year}, {month}, {day}, {hour}, {minute}, {second})" elif dtype.is_date(): - return f"make_date({op.value.year}, {op.value.month}, {op.value.day})" + return sg.expressions.DateFromParts( + 
year=op.value.year, month=op.value.month, day=op.value.day + ) elif dtype.is_array(): value_type = dtype.value_type - values = ", ".join( - _literal(ops.Literal(v, dtype=value_type), **kw) for v in value + is_string = isinstance(value_type, dt.String) + values = sg.expressions.Array().from_arg_list( + [sg.expressions.Literal(this=v, is_string=is_string) for v in value] ) - return f"[{values}]" + return values elif dtype.is_map(): value_type = dtype.value_type values = ", ".join( From 06f4806afcb4d66b48d84e44395e091c1bb7d837 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 31 Aug 2023 15:46:52 -0400 Subject: [PATCH 011/222] feat(duckdb): correct types in arrays, arraydistinct, arraysort --- ibis/backends/duckdb/compiler/values.py | 31 ++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index f3c8832c1bbd..a78ba2d03287 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -545,8 +545,7 @@ def _levenshtein(op, **kw): ops.MapKeys: "mapKeys", # TODO ops.MapValues: "mapValues", # TODO ops.MapMerge: "mapUpdate", # TODO - ops.ArrayDistinct: "arrayDistinct", # TODO - ops.ArraySort: "arraySort", # TODO + ops.ArraySort: "list_sort", ops.ArrayContains: "has", ops.FirstValue: "first_value", ops.LastValue: "last_value", @@ -601,6 +600,24 @@ def _if_null(op, **kw): ### Definitely Not Tensors +@translate_val.register(ops.ArrayDistinct) +def _array_sort(op, **kw): + arg = translate_val(op.arg, **kw) + + sg_expr = sg.expressions.If( + this=arg.is_(sg.expressions.Null()), + true=sg.expressions.Null(), + false=sg.func("list_distinct", arg) + + sg.expressions.If( + this=sg.func("list_count", arg) < sg.func("array_length", arg), + true=sg.func("list_value", sg.expressions.Null()), + false=sg.func("list_value"), + ), + ) + # TODO: this is (I think) working but tests fail because of broken NaN / None stuff + return sg_expr + + @translate_val.register(ops.ArrayIndex) def _array_index_op(op, **kw): arg = translate_val(op.arg, **kw) @@ -700,7 +717,15 @@ def _literal(op, **kw): value_type = dtype.value_type is_string = isinstance(value_type, dt.String) values = sg.expressions.Array().from_arg_list( - [sg.expressions.Literal(this=v, is_string=is_string) for v in value] + [ + # TODO: this cast makes for frustrating output + # is there any better way to handle it? 
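+                # e.g. an int64 array literal currently renders element by
+                # element as [CAST(1 AS BIGINT), CAST(2 AS BIGINT), ...]
+                # rather than a plain [1, 2, ...]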
+ sg.cast( + sg.expressions.Literal(this=f"{v}", is_string=is_string), + to=getattr(sg.expressions.DataType.Type, serialize(value_type)), + ) + for v in value + ] ) return values elif dtype.is_map(): From 0a2692de0ba19a6b99259b25abac6d5b24c98828 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 31 Aug 2023 15:47:14 -0400 Subject: [PATCH 012/222] feat(duckdb): sqlglot cast null --- ibis/backends/duckdb/compiler/values.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index a78ba2d03287..b1cef43daea2 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -658,7 +658,10 @@ def _literal(op, **kw): if value is None and dtype.nullable: if dtype.is_null(): return sg.expressions.Null() - return f"CAST(Null AS {serialize(dtype)})" + return sg.cast( + sg.expressions.Null(), + to=getattr(sg.expressions.DataType.Type, serialize(dtype)), + ) if dtype.is_boolean(): return sg.expressions.Boolean(value) elif dtype.is_inet(): From 83af08dd37e745dfc85b27f8c1a6b7e37399d63b Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 31 Aug 2023 16:26:53 -0400 Subject: [PATCH 013/222] feat(duckdb): maps in progress... --- ibis/backends/duckdb/compiler/values.py | 37 +++++++++++++++---------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index b1cef43daea2..72afbcb9909b 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -542,8 +542,8 @@ def _levenshtein(op, **kw): ops.NullIf: "nullIf", ops.MapContains: "mapContains", # TODO ops.MapLength: "length", - ops.MapKeys: "mapKeys", # TODO - ops.MapValues: "mapValues", # TODO + ops.MapKeys: "map_keys", + ops.MapValues: "map_values", ops.MapMerge: "mapUpdate", # TODO ops.ArraySort: "list_sort", ops.ArrayContains: "has", @@ -650,7 +650,6 @@ def _in_column(op, **kw): ### LITERALLY -# TODO: need to go through this carefully @translate_val.register(ops.Literal) def _literal(op, **kw): value = op.value @@ -681,7 +680,14 @@ def _literal(op, **kw): ) # TODO: handle if `value` is "Infinity" - + # precision = sg.expressions.DataTypeParam( + # this=sg.expressions.Literal(this=f"{precision}", is_string=False) + # ) + # scale = sg.expressions.DataTypeParam( + # this=sg.expressions.Literal(this=f"{scale}", is_string=False) + # ) + # need sg.expressions.DataTypeParam to be available + # ... 
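+            # a possible shape once DataTypeParam is available (untested
+            # sketch; sqlglot's DataType.build would parse precision/scale):
+            # return sg.cast(
+            #     sg.expressions.Literal(this=f"{value}", is_string=False),
+            #     to=sg.expressions.DataType.build(f"DECIMAL({precision}, {scale})"),
+            # )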
return f"{value!s}::decimal({precision}, {scale})" elif dtype.is_numeric(): if math.isinf(value): @@ -734,7 +740,7 @@ def _literal(op, **kw): elif dtype.is_map(): value_type = dtype.value_type values = ", ".join( - f"{k!r}, {_literal(ops.Literal(v, dtype=value_type), **kw)}" + f"[{k!r}, {_literal(ops.Literal(v, dtype=value_type), **kw)}]" for k, v in value.items() ) return f"map({values})" @@ -1198,6 +1204,10 @@ def _map(op, **kw): keys = translate_val(op.keys, **kw) values = translate_val(op.values, **kw) typ = serialize(op.dtype) + breakpoint() + sg_expr = sg.expressions.Map(keys=keys, values=values) + breakpoint() + return sg_expr return f"CAST(({keys}, {values}) AS {typ})" @@ -1481,7 +1491,8 @@ def _array_map(op, **kw): def _array_filter(op, **kw): arg = translate_val(op.arg, **kw) result = translate_val(op.result, **kw) - return sg.func("list_filter", arg, f"{op.parameter} -> {result}") + func = sg.func("list_filter", arg, f"{op.parameter} -> {result}") + return func @translate_val.register(ops.ArrayPosition) @@ -1501,13 +1512,11 @@ def _array_union(op, **kw): return translate_val(ops.ArrayDistinct(ops.ArrayConcat(op.left, op.right)), **kw) -# TODO +# TODO: need to do this as a an array map + struct pack -- look at existing +# alchemy backend implementation @translate_val.register(ops.ArrayZip) def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: - arglist = [] - for arg in op.arg: - sql_arg = translate_val(arg, **kw) - with contextlib.suppress(AttributeError): - sql_arg = sql_arg.sql(dialect="duckdb") - arglist.append(sql_arg) - return f"arrayZip({', '.join(arglist)})" + zipped = sg.expressions.ArrayJoin().from_arg_list( + [translate_val(arg, **kw) for arg in op.arg] + ) + return zipped From d1f8d4867428055141efc1efa557a7ac73e39418 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 5 Sep 2023 14:39:32 -0400 Subject: [PATCH 014/222] feat(duckdb): use new TypeMapper class for parsing dtypes --- ibis/backends/duckdb/__init__.py | 22 ++++++------ ibis/backends/duckdb/datatypes.py | 56 ------------------------------- 2 files changed, 11 insertions(+), 67 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index d1927f4efabd..be8ee8f6dc4b 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -9,27 +9,28 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, MutableMapping -import duckdb -import pyarrow as pa -import sqlglot as sg -import toolz -from packaging.version import parse as vparse - import ibis import ibis.common.exceptions as exc import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir +import pyarrow as pa +import sqlglot as sg +import toolz from ibis import util from ibis.backends.base import CanCreateSchema from ibis.backends.base.sql import BaseBackend +from ibis.backends.base.sqlglot.datatypes import DuckDBType from ibis.backends.duckdb.compiler import translate -from ibis.backends.duckdb.datatypes import parse, serialize, DuckDBType +from ibis.backends.duckdb.datatypes import serialize from ibis.expr.operations.relations import PandasDataFrameProxy from ibis.expr.operations.udf import InputType -from ibis.formats.pyarrow import PyArrowData from ibis.formats.pandas import PandasData +from ibis.formats.pyarrow import PyArrowData +from packaging.version import parse as vparse + +import duckdb if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence 
@@ -245,7 +246,7 @@ def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema names, types, *_ = results.fetch_arrow_table() names = names.to_pylist() types = types.to_pylist() - return sch.Schema(dict(zip(names, map(parse, types)))) + return sch.Schema(dict(zip(names, map(DuckDBType.from_string, types)))) def list_databases(self, like: str | None = None) -> list[str]: result = self.raw_sql("PRAGMA database_list;") @@ -1195,8 +1196,7 @@ def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]: map(as_py, rows["column_type"]), map(as_py, rows["null"]), ): - ibis_type = parse(type) - # ibis_type = DuckDBType.from_string(type, nullable=nullable) + ibis_type = DuckDBType.from_string(type, nullable=nullable) yield name, ibis_type.copy(nullable=null.lower() == "yes") def _register_in_memory_tables(self, expr: ir.Expr) -> None: diff --git a/ibis/backends/duckdb/datatypes.py b/ibis/backends/duckdb/datatypes.py index 80102f5b53b4..f8abd82bd531 100644 --- a/ibis/backends/duckdb/datatypes.py +++ b/ibis/backends/duckdb/datatypes.py @@ -2,63 +2,7 @@ import functools -import sqlalchemy.dialects.postgresql as psql - import ibis.expr.datatypes as dt -from ibis.backends.base.sql.alchemy.datatypes import AlchemyType -from ibis.backends.base.sqlglot.datatypes import DuckDBType as SqlglotDuckdbType - -# _from_duckdb_types = { -# psql.BYTEA: dt.Binary, -# psql.UUID: dt.UUID, -# ducktypes.TinyInteger: dt.Int8, -# ducktypes.SmallInteger: dt.Int16, -# ducktypes.Integer: dt.Int32, -# ducktypes.BigInteger: dt.Int64, -# ducktypes.HugeInteger: dt.Decimal(38, 0), -# ducktypes.UInt8: dt.UInt8, -# ducktypes.UTinyInteger: dt.UInt8, -# ducktypes.UInt16: dt.UInt16, -# ducktypes.USmallInteger: dt.UInt16, -# ducktypes.UInt32: dt.UInt32, -# ducktypes.UInteger: dt.UInt32, -# ducktypes.UInt64: dt.UInt64, -# ducktypes.UBigInteger: dt.UInt64, -# } - -# _to_duckdb_types = { -# dt.UUID: psql.UUID, -# dt.Int8: ducktypes.TinyInteger, -# dt.Int16: ducktypes.SmallInteger, -# dt.Int32: ducktypes.Integer, -# dt.Int64: ducktypes.BigInteger, -# dt.UInt8: ducktypes.UTinyInteger, -# dt.UInt16: ducktypes.USmallInteger, -# dt.UInt32: ducktypes.UInteger, -# dt.UInt64: ducktypes.UBigInteger, -# } - - -class DuckDBType(AlchemyType): - dialect = "duckdb" - - @classmethod - def to_ibis(cls, typ, nullable=True): - if dtype := _from_duckdb_types.get(type(typ)): - return dtype(nullable=nullable) - else: - return super().to_ibis(typ, nullable=nullable) - - @classmethod - def from_ibis(cls, dtype): - if typ := _to_duckdb_types.get(type(dtype)): - return typ - else: - return super().from_ibis(dtype) - - @classmethod - def from_string(cls, type_string, nullable=True): - return SqlglotDuckdbType.from_string(type_string, nullable=nullable) @functools.singledispatch From 1863b7912c574be671e0db54b3eaccfe06750db2 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 5 Sep 2023 16:25:58 -0400 Subject: [PATCH 015/222] feat(duckdb): improved map support --- ibis/backends/duckdb/compiler/values.py | 103 +++++++++++++++++++----- 1 file changed, 82 insertions(+), 21 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 72afbcb9909b..7693a2a37d8a 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -4,11 +4,12 @@ import contextlib import functools import math -from functools import partial import operator +from functools import partial from operator import add, mul, sub from typing import Any, Literal, Mapping +import duckdb 
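+# imported only for the version gate below: _SUPPORTS_MAPS switches
+# _map_merge between duckdb's native map_concat (>= 0.8.0) and a
+# JSON-based fallback for older versions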
import ibis import ibis.common.exceptions as com import ibis.expr.analysis as an @@ -17,11 +18,14 @@ import ibis.expr.rules as rlz import sqlglot as sg from ibis.backends.base.sql.registry import helpers +from ibis.backends.base.sqlglot.datatypes import DuckDBType from ibis.backends.duckdb.datatypes import serialize +from packaging.version import parse as vparse from toolz import flip # TODO: Ideally we can translate bottom up a la `relations.py` # TODO: Find a way to remove all the dialect="duckdb" kwargs +_SUPPORTS_MAPS = vparse(duckdb.__version__) >= vparse("0.8.0") @functools.singledispatch @@ -540,13 +544,11 @@ def _levenshtein(op, **kw): ops.Degrees: "degrees", ops.Radians: "radians", ops.NullIf: "nullIf", - ops.MapContains: "mapContains", # TODO - ops.MapLength: "length", + ops.MapLength: "cardinality", ops.MapKeys: "map_keys", ops.MapValues: "map_values", - ops.MapMerge: "mapUpdate", # TODO ops.ArraySort: "list_sort", - ops.ArrayContains: "has", + ops.ArrayContains: "list_contains", ops.FirstValue: "first_value", ops.LastValue: "last_value", ops.NTile: "ntile", @@ -731,19 +733,23 @@ def _literal(op, **kw): # is there any better way to handle it? sg.cast( sg.expressions.Literal(this=f"{v}", is_string=is_string), - to=getattr(sg.expressions.DataType.Type, serialize(value_type)), + to=DuckDBType.from_ibis(value_type), ) for v in value ] ) return values elif dtype.is_map(): + key_type = dtype.key_type value_type = dtype.value_type - values = ", ".join( - f"[{k!r}, {_literal(ops.Literal(v, dtype=value_type), **kw)}]" - for k, v in value.items() + keys = sg.expressions.Array().from_arg_list( + [_literal(ops.Literal(k, dtype=key_type), **kw) for k in value.keys()] + ) + values = sg.expressions.Array().from_arg_list( + [_literal(ops.Literal(v, dtype=value_type), **kw) for v in value.values()] ) - return f"map({values})" + sg_expr = sg.expressions.Map(keys=keys, values=values) + return sg_expr elif dtype.is_struct(): fields = ", ".join( _literal(ops.Literal(v, dtype=subdtype), **kw) @@ -1030,9 +1036,10 @@ def _translate(op, **kw): @translate_val.register(ops.ArrayColumn) def _array_column(op, **kw): - cols = map(partial(translate_val, **kw), op.cols) - args = ", ".join(map(_sql, cols)) - return f"[{args}]" + sg_expr = sg.expressions.Array.from_arg_list( + [translate_val(col, **kw) for col in op.cols] + ) + return sg_expr # TODO @@ -1150,9 +1157,11 @@ def tr(op, *, cache, **kw): # translate_val.register(ops.NotContains)(contains("NOT IN")) -# TODO @translate_val.register(ops.DayOfWeekName) def day_of_week_name(op, **kw): + # day of week number is 0-indexed + # Sunday == 0 + # Saturday == 6 arg = op.arg nullable = arg.dtype.nullable empty_string = ops.Literal("", dtype=dt.String(nullable=nullable)) @@ -1198,26 +1207,78 @@ def _vararg_func(op, **kw): ) -# TODO @translate_val.register(ops.Map) def _map(op, **kw): keys = translate_val(op.keys, **kw) values = translate_val(op.values, **kw) - typ = serialize(op.dtype) - breakpoint() sg_expr = sg.expressions.Map(keys=keys, values=values) - breakpoint() return sg_expr - return f"CAST(({keys}, {values}) AS {typ})" -# TODO @translate_val.register(ops.MapGet) def _map_get(op, **kw): arg = translate_val(op.arg, **kw) key = translate_val(op.key, **kw) default = translate_val(op.default, **kw) - return f"if(mapContains({arg}, {key}), {arg}[{key}], {default})" + sg_expr = sg.func( + "ifnull", + sg.func("element_at", arg, key), + default, + dialect="duckdb", + ) + return sg_expr + + +@translate_val.register(ops.MapContains) +def _map_contains(op, **kw): + 
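+    # DuckDB has no direct map_contains; element_at(map, key) returns a
+    # (possibly empty) list of matching values, so membership reduces to
+    #   array_length(element_at(arg, key)) <> 0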
arg = translate_val(op.arg, **kw) + key = translate_val(op.key, **kw) + sg_expr = sg.expressions.NEQ( + this=sg.func( + "array_length", + sg.func( + "element_at", + arg, + key, + ), + ), + expression=sg.expressions.Literal(this="0", is_string=False), + ) + return sg_expr + + +def _is_map_literal(op): + return isinstance(op, ops.Literal) or ( + isinstance(op, ops.Map) + and isinstance(op.keys, ops.Literal) + and isinstance(op.values, ops.Literal) + ) + + +@translate_val.register(ops.MapMerge) +def _map_merge(op, **kw): + if _SUPPORTS_MAPS: + left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) + return sg.func("map_concat", left, right) + else: + if not (_is_map_literal(op.left) and _is_map_literal(op.right)): + breakpoint() + raise com.UnsupportedOperationError( + "Merging non-literal maps is not yet supported by DuckDB" + ) + left = sg.func("to_json", translate_val(op.left, **kw)) + right = sg.func("to_json", translate_val(op.right, **kw)) + pairs = sg.func("json_merge_patch", left, right) + keys = sg.func("json_keys", pairs) + return sg.cast( + expression=sg.func( + "map", + keys, + sg.func("json_extract_string", pairs, keys), + ), + to=DuckDBType.from_ibis(op.dtype), + ) def _binary_infix(sg_expr: sg.expressions._Expression): From 7b3a208ef8b5cd1a943fc6a4f9d8d7bafb8ec2cb Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 5 Sep 2023 16:50:21 -0400 Subject: [PATCH 016/222] fix(duckdb): few type conversion updates feat(duckdb): structs feat(duckdb): fix list lambdas, add arrayintersect feat(duckdb): add naive arraycollect fix(duckdb): fix arrayconcat feat(duckdb): more sqlglot exprs chore(duckdb): scattered junk fix(duckdb): dialect --- ibis/backends/duckdb/compiler/core.py | 1 + ibis/backends/duckdb/compiler/values.py | 111 +++++++++++++++--------- 2 files changed, 73 insertions(+), 39 deletions(-) diff --git a/ibis/backends/duckdb/compiler/core.py b/ibis/backends/duckdb/compiler/core.py index 2004d24e20f1..c7db3faf51c1 100644 --- a/ibis/backends/duckdb/compiler/core.py +++ b/ibis/backends/duckdb/compiler/core.py @@ -75,6 +75,7 @@ def fn(node, cache, params=params, **kwargs): # don't alias the root node if node is not op: + # TODO: do we want to create sqlglot tables here? aliases[node] = f"t{alias_index:d}" alias_index += 1 diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 7693a2a37d8a..44312b54f11b 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -19,7 +19,6 @@ import sqlglot as sg from ibis.backends.base.sql.registry import helpers from ibis.backends.base.sqlglot.datatypes import DuckDBType -from ibis.backends.duckdb.datatypes import serialize from packaging.version import parse as vparse from toolz import flip @@ -36,7 +35,7 @@ def translate_val(op, **_): @translate_val.register(dt.DataType) def _datatype(t, **_): - return serialize(t) + return DuckDBType.from_ibis(t) @translate_val.register(ops.PhysicalTable) @@ -410,9 +409,7 @@ def _interval_from_integer(op, **kw): arg = translate_val(op.arg, **kw) if op.dtype.resolution == "week": return sg.func("to_days", arg * 7) - # TODO: make less gross - # to_days, to_years, etc... 
- return sg.func(f"to_{op.dtype.resolution}s", arg) + return sg.func(f"to_{op.dtype.resolution}s", arg, dialect="duckdb") ### String Instruments @@ -515,7 +512,6 @@ def _levenshtein(op, **kw): ops.ArgMin: "arg_min", ops.Mode: "mode", ops.ArgMax: "arg_max", - # ops.ArrayCollect: "groupArray", # TODO ops.Count: "count", ops.CountDistinct: "list_unique", ops.First: "first", @@ -539,7 +535,6 @@ def _levenshtein(op, **kw): # Other operations ops.Where: "if", ops.ArrayLength: "length", - ops.ArrayConcat: "list_concat", ops.Unnest: "unnest", ops.Degrees: "degrees", ops.Radians: "radians", @@ -609,11 +604,12 @@ def _array_sort(op, **kw): sg_expr = sg.expressions.If( this=arg.is_(sg.expressions.Null()), true=sg.expressions.Null(), - false=sg.func("list_distinct", arg) + false=sg.func("list_distinct", arg, dialect="duckdb") + sg.expressions.If( - this=sg.func("list_count", arg) < sg.func("array_length", arg), - true=sg.func("list_value", sg.expressions.Null()), - false=sg.func("list_value"), + this=sg.func("list_count", arg, dialect="duckdb") + < sg.func("array_length", arg, dialect="duckdb"), + true=sg.func("list_value", sg.expressions.Null(), dialect="duckdb"), + false=sg.func("list_value", dialect="duckdb"), ), ) # TODO: this is (I think) working but tests fail because of broken NaN / None stuff @@ -649,6 +645,25 @@ def _in_column(op, **kw): return value.isin(options) +@translate_val.register(ops.ArrayCollect) +def _array_collect(op, **kw): + if op.where is not None: + # TODO: handle when op.where is not none + # probably using list_agg? + ... + return sg.func("list", translate_val(op.arg, **kw), dialect="duckdb") + + +@translate_val.register(ops.ArrayConcat) +def _array_concat(op, **kw): + sg_expr = sg.func( + "list_concat", + *(translate_val(arg, **kw) for arg in op.arg), + dialect="duckdb", + ) + return sg_expr + + ### LITERALLY @@ -659,17 +674,15 @@ def _literal(op, **kw): if value is None and dtype.nullable: if dtype.is_null(): return sg.expressions.Null() - return sg.cast( - sg.expressions.Null(), - to=getattr(sg.expressions.DataType.Type, serialize(dtype)), - ) + return sg.cast(sg.expressions.Null(), to=DuckDBType.from_ibis(dtype)) if dtype.is_boolean(): - return sg.expressions.Boolean(value) + return sg.expressions.Boolean(this=value) elif dtype.is_inet(): com.UnsupportedOperationError("DuckDB doesn't support an explicit inet dtype") elif dtype.is_string(): return sg.expressions.Literal(this=f"{value}", is_string=True) elif dtype.is_decimal(): + # TODO: make this a sqlglot expression precision = dtype.precision scale = dtype.scale if precision is None: @@ -702,7 +715,6 @@ def _literal(op, **kw): expression=sg.expressions.Literal(this="NaN", is_string=True), to=sg.expressions.DataType.Type.FLOAT, ) - # return value return sg.expressions.Literal(this=f"{value}", is_string=False) elif dtype.is_interval(): return _interval_format(op) @@ -751,11 +763,18 @@ def _literal(op, **kw): sg_expr = sg.expressions.Map(keys=keys, values=values) return sg_expr elif dtype.is_struct(): - fields = ", ".join( + keys = [ + sg.expressions.Literal(this=key, is_string=True) for key in value.keys() + ] + values = [ _literal(ops.Literal(v, dtype=subdtype), **kw) for subdtype, v in zip(dtype.types, value.values()) - ) - return f"tuple({fields})" + ] + slices = [ + sg.expressions.Slice(this=k, expression=v) for k, v in zip(keys, values) + ] + sg_expr = sg.expressions.Struct.from_arg_list(slices) + return sg_expr else: raise NotImplementedError(f"Unsupported type: {dtype!r}") @@ -1042,12 +1061,13 @@ def 
_array_column(op, **kw): return sg_expr -# TODO @translate_val.register(ops.StructColumn) def _struct_column(op, **kw): values = translate_val(op.values, **kw) - struct_type = serialize(op.dtype.copy(nullable=False)) - return f"CAST({values} AS {struct_type})" + struct_type = DuckDBType.from_ibis(op.dtype) + # TODO: this seems like a workaround + # but maybe it isn't + return sg.cast(expression=values, to=struct_type) @translate_val.register(ops.Clip) @@ -1063,13 +1083,11 @@ def _clip(op, **kw): @translate_val.register(ops.StructField) -def _struct_field(op, render_aliases: bool = False, **kw): - arg = op.arg - arg_dtype = arg.dtype - arg = translate_val(op.arg, render_aliases=render_aliases, **kw) - idx = arg_dtype.names.index(op.field) - typ = arg_dtype.types[idx] - return f"CAST({arg}.{idx + 1} AS {serialize(typ)})" +def _struct_field(op, **kw): + arg = translate_val(op.arg, **kw) + field = sg.expressions.Literal(this=f"{op.field}", is_string=True) + sg_expr = sg.func("struct_extract", arg, field) + return sg_expr # TODO @@ -1263,7 +1281,6 @@ def _map_merge(op, **kw): return sg.func("map_concat", left, right) else: if not (_is_map_literal(op.left) and _is_map_literal(op.right)): - breakpoint() raise com.UnsupportedOperationError( "Merging non-literal maps is not yet supported by DuckDB" ) @@ -1521,12 +1538,12 @@ def _row_number(_, **kw): @translate_val.register(ops.DenseRank) def _dense_rank(_, **kw): - return sg.func("dense_rank") + return sg.func("dense_rank", dialect="duckdb") @translate_val.register(ops.MinRank) def _rank(_, **kw): - return sg.func("rank") + return sg.func("rank", dialect="duckdb") @translate_val.register(ops.ArrayStringJoin) @@ -1538,29 +1555,45 @@ def _array_string_join(op, **kw): @translate_val.register(ops.Argument) def _argument(op, **_): - return op.name + return sg.expressions.Identifier(this=op.name, quoted=False) @translate_val.register(ops.ArrayMap) def _array_map(op, **kw): arg = translate_val(op.arg, **kw) result = translate_val(op.result, **kw) - return sg.func("list_transform", arg, f"{op.parameter}) -> {result}") + lamduh = sg.expressions.Lambda( + this=result, + expressions=[sg.expressions.Identifier(this=f"{op.parameter}", quoted=False)], + ) + sg_expr = sg.func("list_transform", arg, lamduh) + return sg_expr @translate_val.register(ops.ArrayFilter) def _array_filter(op, **kw): arg = translate_val(op.arg, **kw) result = translate_val(op.result, **kw) - func = sg.func("list_filter", arg, f"{op.parameter} -> {result}") - return func + lamduh = sg.expressions.Lambda( + this=result, + expressions=[sg.expressions.Identifier(this=f"{op.parameter}", quoted=False)], + ) + sg_expr = sg.func("list_filter", arg, lamduh) + return sg_expr + + +@translate_val.register(ops.ArrayIntersect) +def _array_intersect(op, **kw): + return translate_val( + ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x)), **kw + ) @translate_val.register(ops.ArrayPosition) def _array_position(op, **kw): arg = translate_val(op.arg, **kw) el = translate_val(op.other, **kw) - return f"list_indexof({arg}, {el}) - 1" + return sg.func("list_indexof", arg, el) - 1 @translate_val.register(ops.ArrayRemove) @@ -1570,7 +1603,7 @@ def _array_remove(op, **kw): @translate_val.register(ops.ArrayUnion) def _array_union(op, **kw): - return translate_val(ops.ArrayDistinct(ops.ArrayConcat(op.left, op.right)), **kw) + return translate_val(ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right))), **kw) # TODO: need to do this as a an array map + struct pack -- look at existing From 
760e92602c7b248d0b9fa97fa2a93ef116641ed4 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 6 Sep 2023 16:01:27 -0400 Subject: [PATCH 017/222] feat(duckdb): quantile and multiquantile --- ibis/backends/duckdb/compiler/values.py | 33 +++++++++---------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 44312b54f11b..4645d06b9f3e 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -825,30 +825,19 @@ def _not_all(op, **kw): return translate_val(ops.Any(ops.Not(op.arg), where=op.where), **kw) -# TODO -def _quantile_like(func_name: str, op: ops.Node, quantile: str, **kw): - args = [_sql(translate_val(op.arg, **kw))] - - if (where := op.where) is not None: - args.append(_sql(translate_val(where, **kw))) - func_name += "If" - - return f"{func_name}({quantile})({', '.join(args)})" - - @translate_val.register(ops.Quantile) -def _quantile(op, **kw): - quantile = _sql(translate_val(op.quantile, **kw)) - return _quantile_like("quantile", op, quantile, **kw) - - @translate_val.register(ops.MultiQuantile) -def _multi_quantile(op, **kw): - if not isinstance(op.quantile, ops.Literal): - raise TypeError("Duckdb quantile only accepts a list of Python floats") - - quantile = ", ".join(map(str, op.quantile.value)) - return _quantile_like("quantiles", op, quantile, **kw) +def _quantile(op, **kw): + arg = translate_val(op.arg, **kw) + quantile = translate_val(op.quantile, **kw) + sg_expr = sg.func("quantile_cont", arg, quantile, dialect="duckdb") + if op.where is not None: + predicate = translate_val(op.where, **kw) + sg_expr = sg.expressions.Filter( + this=sg_expr, + expression=sg.expressions.Where(this=predicate), + ) + return sg_expr def _agg_variance_like(func): From 131e90577208a344688547cf6746985f51c0f5d2 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 6 Sep 2023 17:16:34 -0400 Subject: [PATCH 018/222] feat(duckdb): countdistinct and approxcountdistinct --- ibis/backends/duckdb/compiler/values.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 4645d06b9f3e..e2a5b70dc08a 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -502,7 +502,6 @@ def _levenshtein(op, **kw): # Unary aggregates # ops.ApproxMedian: "median", # TODO # ops.Median: "quantileExactExclusive", # TODO - ops.ApproxCountDistinct: "list_unique", ops.Mean: "avg", ops.Sum: "sum", ops.Max: "max", @@ -513,7 +512,6 @@ def _levenshtein(op, **kw): ops.Mode: "mode", ops.ArgMax: "arg_max", ops.Count: "count", - ops.CountDistinct: "list_unique", ops.First: "first", ops.Last: "last", # string operations @@ -807,6 +805,22 @@ def _array_slice_op(op, **kw): return sg.func("list_slice", arg, start, stop) +@translate_val.register(ops.CountDistinct) +@translate_val.register(ops.ApproxCountDistinct) +def _count_distinct(op, **kw): + arg = translate_val(op.arg, **kw) + on = None + if op.where is not None: + on = translate_val(op.where, **kw) + sg_expr = sg.expressions.Count( + this=sg.expressions.Distinct( + expressions=[arg], + on=on, + ) + ) + return sg_expr + + @translate_val.register(ops.CountStar) def _count_star(op, **kw): sql = sg.expressions.Count(this=sg.expressions.Star()) From a2bbae896c2ad0d28c42c9d57d3fea06be68b472 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 7 Sep 2023 11:37:46 -0400 Subject: [PATCH 
019/222] feat(duckdb): substring sg expr --- ibis/backends/duckdb/compiler/values.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index e2a5b70dc08a..0701153ee891 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -417,19 +417,28 @@ def _interval_from_integer(op, **kw): @translate_val.register(ops.Substring) def _substring(op, **kw): + # TODO: fix expr_slice_begin tests # Duckdb is 1-indexed arg = translate_val(op.arg, **kw) start = translate_val(op.start, **kw) - arg_length = f"length({arg})" + arg_length = sg.expressions.Length(this=arg) if op.length is not None: length = translate_val(op.length, **kw) - suffix = f", {length}" else: - suffix = "" + length = arg_length - if_pos = f"substring({arg}, {start} + 1{suffix})" - if_neg = f"substring({arg}, {arg_length} + {start} + 1{suffix})" - return f"if({start} >= 0, {if_pos}, {if_neg})" + if_pos = sg.expressions.Substring(this=arg, start=start + 1, length=length) + if_neg = sg.expressions.Substring(this=arg, start=start, length=length) + + sg_expr = sg.expressions.If( + this=sg.expressions.GTE( + this=start, expression=sg.expressions.Literal(this="0", is_string=False) + ), + true=if_pos, + false=if_neg, + ) + + return sg_expr @translate_val.register(ops.StringFind) From 954fd7cf29413f9dcfe6f6e411b6692e9d140b73 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 7 Sep 2023 15:24:01 -0400 Subject: [PATCH 020/222] fix(duckdb): handle sum of comparisons using count fix(duckdb): fix count filtering --- ibis/backends/duckdb/compiler/values.py | 51 +++++++++++++++++++------ 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 0701153ee891..b99012cd91b3 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -421,11 +421,10 @@ def _substring(op, **kw): # Duckdb is 1-indexed arg = translate_val(op.arg, **kw) start = translate_val(op.start, **kw) - arg_length = sg.expressions.Length(this=arg) if op.length is not None: length = translate_val(op.length, **kw) else: - length = arg_length + length = None if_pos = sg.expressions.Substring(this=arg, start=start + 1, length=length) if_neg = sg.expressions.Substring(this=arg, start=start, length=length) @@ -512,7 +511,6 @@ def _levenshtein(op, **kw): # ops.ApproxMedian: "median", # TODO # ops.Median: "quantileExactExclusive", # TODO ops.Mean: "avg", - ops.Sum: "sum", ops.Max: "max", ops.Min: "min", ops.Any: "any_value", @@ -520,7 +518,6 @@ def _levenshtein(op, **kw): ops.ArgMin: "arg_min", ops.Mode: "mode", ops.ArgMax: "arg_max", - ops.Count: "count", ops.First: "first", ops.Last: "last", # string operations @@ -639,7 +636,8 @@ def _in_values(op, **kw): options = sg.expressions.Array().from_arg_list( [translate_val(x, **kw) for x in op.options] ) - return sg.func("list_contains", options, value, dialect="duckdb") + sg_expr = sg.func("list_contains", options, value, dialect="duckdb") + return sg_expr @translate_val.register(ops.InColumn) @@ -814,20 +812,29 @@ def _array_slice_op(op, **kw): return sg.func("list_slice", arg, start, stop) +@translate_val.register(ops.Count) +def _count(op, **kw): + arg = translate_val(op.arg, **kw) + count_expr = sg.expressions.Count(this=arg) + if op.where is not None: + where = sg.expressions.Where(this=translate_val(op.where, **kw)) + return 
sg.expressions.Filter(this=count_expr, expression=where) + return count_expr + + @translate_val.register(ops.CountDistinct) @translate_val.register(ops.ApproxCountDistinct) def _count_distinct(op, **kw): arg = translate_val(op.arg, **kw) - on = None - if op.where is not None: - on = translate_val(op.where, **kw) - sg_expr = sg.expressions.Count( + count_expr = sg.expressions.Count( this=sg.expressions.Distinct( expressions=[arg], - on=on, ) ) - return sg_expr + if op.where is not None: + where = sg.expressions.Where(this=translate_val(op.where, **kw)) + return sg.expressions.Filter(this=count_expr, expression=where) + return count_expr @translate_val.register(ops.CountStar) @@ -838,6 +845,28 @@ def _count_star(op, **kw): return sql +@translate_val.register(ops.Sum) +def _sum(op, **kw): + arg = translate_val(op.arg, **kw) + where = None + if op.where is not None: + where = translate_val(op.where, **kw) + + sg_where = sg.expressions.Where(this=where) + + # Handle sum(boolean comparison) + if isinstance(op.arg, ops.Comparison): + sg_count_expr = sg.expressions.Count(this=arg) + if where is not None: + return sg.expressions.Filter(this=sg_count_expr, expression=sg_where) + return sg_count_expr + + sg_sum_expr = sg.expressions.Sum(this=arg) + if where is not None: + return sg.expressions.Filter(this=sg_sum_expr, expression=sg_where) + return sg_sum_expr + + @translate_val.register(ops.NotAny) def _not_any(op, **kw): return translate_val(ops.All(ops.Not(op.arg), where=op.where), **kw) From 7b9381e5ae54a415f6eeea49b961a84c2078a7cd Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 7 Sep 2023 17:07:50 -0400 Subject: [PATCH 021/222] refactor(duckdb): remove remainder of serialize --- ibis/backends/duckdb/__init__.py | 4 +- ibis/backends/duckdb/compiler/values.py | 7 +- ibis/backends/duckdb/datatypes.py | 126 ------------------- ibis/backends/duckdb/tests/test_datatypes.py | 2 +- ibis/tests/benchmarks/test_benchmarks.py | 4 +- 5 files changed, 10 insertions(+), 133 deletions(-) delete mode 100644 ibis/backends/duckdb/datatypes.py diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index be8ee8f6dc4b..b78972668397 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -23,7 +23,6 @@ from ibis.backends.base.sql import BaseBackend from ibis.backends.base.sqlglot.datatypes import DuckDBType from ibis.backends.duckdb.compiler import translate -from ibis.backends.duckdb.datatypes import serialize from ibis.expr.operations.relations import PandasDataFrameProxy from ibis.expr.operations.udf import InputType from ibis.formats.pandas import PandasData @@ -144,7 +143,8 @@ def create_table( code += f" AS {self.compile(obj)}" else: serialized_schema = ", ".join( - f"{name} {serialize(typ)}" for name, typ in schema.items() + f"{name} {DuckDBType.to_string(typ)}" + for name, typ in schema.items() ) code += f" ({serialized_schema})" diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index b99012cd91b3..11af477a2a93 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -151,7 +151,10 @@ def _cast(op, **kw): @translate_val.register(ops.TryCast) def _try_cast(op, **kw): return sg.func( - "try_cast", translate_val(op.arg, **kw), serialize(op.to), dialect="duckdb" + "try_cast", + translate_val(op.arg, **kw), + DuckDBType.to_string(op.to), + dialect="duckdb", ) @@ -280,7 +283,6 @@ def _extract_time(op, **kw): @translate_val.register(ops.ExtractMicrosecond) def 
_extract_microsecond(op, **kw): arg = translate_val(op.arg, **kw) - dtype = serialize(op.dtype) return f"extract('us', {arg}::TIMESTAMP) % 1000000" @@ -288,7 +290,6 @@ def _extract_microsecond(op, **kw): @translate_val.register(ops.ExtractMillisecond) def _extract_microsecond(op, **kw): arg = translate_val(op.arg, **kw) - dtype = serialize(op.dtype) return f"extract('ms', {arg}::TIMESTAMP) % 1000" diff --git a/ibis/backends/duckdb/datatypes.py b/ibis/backends/duckdb/datatypes.py deleted file mode 100644 index f8abd82bd531..000000000000 --- a/ibis/backends/duckdb/datatypes.py +++ /dev/null @@ -1,126 +0,0 @@ -from __future__ import annotations - -import functools - -import ibis.expr.datatypes as dt - - -@functools.singledispatch -def serialize(ty) -> str: - raise NotImplementedError(f"{ty} not serializable to DuckDB type string") - - -@serialize.register(dt.DataType) -def _(ty: dt.DataType) -> str: - ser_ty = serialize_raw(ty) - if not ty.nullable: - return f"{ser_ty} NOT NULL" - return ser_ty - - -@serialize.register(dt.Map) -def _(ty: dt.Map) -> str: - return serialize_raw(ty) - - -@functools.singledispatch -def serialize_raw(ty: dt.DataType) -> str: - raise NotImplementedError(f"{ty} not serializable to DuckDB type string") - - -@serialize_raw.register(dt.DataType) -def _(ty: dt.DataType) -> str: - return type(ty).__name__.capitalize() - - -@serialize_raw.register(dt.Int8) -def _(_: dt.Int8) -> str: - return "TINYINT" - - -@serialize_raw.register(dt.Int16) -def _(_: dt.Int16) -> str: - return "SMALLINT" - - -@serialize_raw.register(dt.Int32) -def _(_: dt.Int32) -> str: - return "INTEGER" - - -@serialize_raw.register(dt.Int64) -def _(_: dt.Int64) -> str: - return "BIGINT" - - -@serialize_raw.register(dt.UInt8) -def _(_: dt.UInt8) -> str: - return "UTINYINT" - - -@serialize_raw.register(dt.UInt16) -def _(_: dt.UInt16) -> str: - return "USMALLINT" - - -@serialize_raw.register(dt.UInt32) -def _(_: dt.UInt32) -> str: - return "UINTEGER" - - -@serialize_raw.register(dt.UInt64) -def _(_: dt.UInt64) -> str: - return "UBIGINT" - - -@serialize_raw.register(dt.Float32) -def _(_: dt.Float32) -> str: - return "FLOAT" - - -@serialize_raw.register(dt.Float64) -def _(_: dt.Float64) -> str: - return "DOUBLE" - - -@serialize_raw.register(dt.Binary) -def _(_: dt.Binary) -> str: - return "BLOB" - - -@serialize_raw.register(dt.Boolean) -def _(_: dt.Boolean) -> str: - return "BOOLEAN" - - -@serialize_raw.register(dt.Array) -def _(ty: dt.Array) -> str: - return f"Array({serialize(ty.value_type)})" - - -@serialize_raw.register(dt.Map) -def _(ty: dt.Map) -> str: - # nullable key type is not allowed inside maps - key_type = serialize_raw(ty.key_type) - value_type = serialize(ty.value_type) - return f"Map({key_type}, {value_type})" - - -@serialize_raw.register(dt.Struct) -def _(ty: dt.Struct) -> str: - fields = ", ".join( - f"{name} {serialize(field_ty)}" for name, field_ty in ty.fields.items() - ) - return f"STRUCT({fields})" - - -@serialize_raw.register(dt.Timestamp) -def _(ty: dt.Timestamp) -> str: - if ty.timezone: - return "TIMESTAMPTZ" - return "TIMESTAMP" - - -@serialize_raw.register(dt.Decimal) -def _(ty: dt.Decimal) -> str: - return f"Decimal({ty.precision}, {ty.scale})" diff --git a/ibis/backends/duckdb/tests/test_datatypes.py b/ibis/backends/duckdb/tests/test_datatypes.py index 2b64d874cdd2..5533c5e08ce3 100644 --- a/ibis/backends/duckdb/tests/test_datatypes.py +++ b/ibis/backends/duckdb/tests/test_datatypes.py @@ -10,7 +10,7 @@ import ibis.backends.base.sql.alchemy.datatypes as sat import 
ibis.common.exceptions as exc import ibis.expr.datatypes as dt -from ibis.backends.duckdb.datatypes import DuckDBType +from ibis.backends.base.sqlglot.datatypes import DuckDBType @pytest.mark.parametrize( diff --git a/ibis/tests/benchmarks/test_benchmarks.py b/ibis/tests/benchmarks/test_benchmarks.py index 91375e7825a3..8ad6266141cc 100644 --- a/ibis/tests/benchmarks/test_benchmarks.py +++ b/ibis/tests/benchmarks/test_benchmarks.py @@ -747,7 +747,9 @@ def test_snowflake_medium_sized_to_pandas(benchmark): def test_parse_many_duckdb_types(benchmark): - parse = pytest.importorskip("ibis.backends.duckdb.datatypes").DuckDBType.from_string + parse = pytest.importorskip( + "ibis.backends.base.sqlglot.datatypes" + ).DuckDBType.from_string def parse_many(types): list(map(parse, types)) From 805aceb1935a04d65e356b1c50bba842f0e34352 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 7 Sep 2023 17:08:20 -0400 Subject: [PATCH 022/222] fix(duckdb): nullability in _metadata --- ibis/backends/duckdb/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index b78972668397..c7719871a91d 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -1196,7 +1196,7 @@ def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]: map(as_py, rows["column_type"]), map(as_py, rows["null"]), ): - ibis_type = DuckDBType.from_string(type, nullable=nullable) + ibis_type = DuckDBType.from_string(type, nullable=null.lower() == "yes") yield name, ibis_type.copy(nullable=null.lower() == "yes") def _register_in_memory_tables(self, expr: ir.Expr) -> None: From b07a49673b7fa16f402e54acc46767a4e258516d Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 7 Sep 2023 17:33:09 -0400 Subject: [PATCH 023/222] refactor(duckdb): redo string ops in sqlglot expressions --- ibis/backends/duckdb/compiler/values.py | 32 +++++++++++++++++-------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 11af477a2a93..f5a6e848de5c 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -449,14 +449,16 @@ def _string_find(op, **kw): arg = translate_val(op.arg, **kw) substr = translate_val(op.substr, **kw) - return f"instr({arg}, {substr}) - 1" + return sg.func("instr", arg, substr) - 1 @translate_val.register(ops.RegexSearch) def _regex_search(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) - return f"regexp_matches({arg}, {pattern}, 's')" + return sg.func( + "regexp_matches", arg, pattern, sg.expressions.Literal(this="s", is_string=True) + ) @translate_val.register(ops.RegexReplace) @@ -464,7 +466,14 @@ def _regex_replace(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) replacement = translate_val(op.replacement, **kw) - return sg.func("regexp_replace", arg, pattern, replacement, "g", dialect="duckdb") + return sg.func( + "regexp_replace", + arg, + pattern, + replacement, + sg.expressions.Literal(this="g", is_string=True), + dialect="duckdb", + ) @translate_val.register(ops.RegexExtract) @@ -473,13 +482,15 @@ def _regex_extract(op, **kw): pattern = translate_val(op.pattern, **kw) group = translate_val(op.index, **kw) return f"regexp_extract({arg}, {pattern}, {group})" + # TODO: make this work -- need to handle pattern escaping? 
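+    # concretely: a pattern literal like '\d+' must reach DuckDB with its
+    # backslash intact; until sqlglot's rendering of such strings is
+    # verified, the f-string return above stays in effect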
+ return sg.func("regexp_extract", arg, pattern, group) @translate_val.register(ops.Levenshtein) def _levenshtein(op, **kw): left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) - return f"levenshtein({left}, {right})" + return sg.func("levenshtein", left, right) ### Simple Ops @@ -1021,7 +1032,7 @@ def _exists_subquery(op, **kw): def _string_split(op, **kw): arg = translate_val(op.arg, **kw) delimiter = translate_val(op.delimiter, **kw) - return f"string_split({arg}, {delimiter})" + return sg.expressions.Split(this=arg, expression=delimiter) @translate_val.register(ops.StringJoin) @@ -1035,7 +1046,7 @@ def _string_join(op, **kw): @translate_val.register(ops.StringConcat) def _string_concat(op, **kw): arg = map(partial(translate_val, **kw), op.arg) - return " || ".join(map(_sql, arg)) + return sg.expressions.Concat(expressions=list(arg)) @translate_val.register(ops.StringSQLLike) @@ -1055,10 +1066,11 @@ def _string_ilike(op, **kw): @translate_val.register(ops.Capitalize) def _string_capitalize(op, **kw): arg = translate_val(op.arg, **kw) - return sg.func( - "concat", - sg.func("upper", sg.func("substr", arg, 1, 1)), - sg.func("lower", sg.func("substr", arg, 2)), + return sg.expressions.Concat( + expressions=[ + sg.func("upper", sg.func("substr", arg, 1, 1)), + sg.func("lower", sg.func("substr", arg, 2)), + ] ) From acb30fd1bd7434db61d3ef9a2eb91aa4e47f94d9 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 7 Sep 2023 17:37:14 -0400 Subject: [PATCH 024/222] fix(duckdb): round --- ibis/backends/duckdb/compiler/values.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index f5a6e848de5c..3377fa351c25 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -958,8 +958,8 @@ def _index_of(op, **kw): def _round(op, **kw): arg = translate_val(op.arg, **kw) if (digits := op.digits) is not None: - return f"round({arg}, {translate_val(digits, **kw)})" - return f"round({arg})" + return sg.expressions.Round(this=arg, decimals=translate_val(digits, **kw)) + return sg.expressions.Round(this=arg) @translate_val.register(tuple) From 570ae7ebb0e50cf1f9b11eae487b3dfe7609eb7a Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Fri, 8 Sep 2023 11:19:31 -0400 Subject: [PATCH 025/222] fix(duckdb): fix up some date/timestamp stuff re: sqlglot --- ibis/backends/duckdb/compiler/values.py | 55 +++++++++++++++++-------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 3377fa351c25..371e7abeec3d 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -238,11 +238,10 @@ def _string_to_timestamp(op, **kw): @translate_val.register(ops.ExtractEpochSeconds) def _extract_epoch_seconds(op, **kw): arg = translate_val(op.arg, **kw) - # TODO: do we need the TIMESTAMP cast? 
return sg.func( "epoch", sg.expressions.cast( - expression=sg.expressions.Literal(this=arg, is_string=True), + expression=arg, to=sg.expressions.DataType.Type.TIMESTAMP, ), ) @@ -284,21 +283,35 @@ def _extract_time(op, **kw): def _extract_microsecond(op, **kw): arg = translate_val(op.arg, **kw) - return f"extract('us', {arg}::TIMESTAMP) % 1000000" + return sg.expressions.Mod( + this=sg.func( + "extract", + sg.expressions.Literal(this="us", is_string=True), + arg, + ), + expression=sg.expressions.Literal(this="1000000", is_string=False), + ) @translate_val.register(ops.ExtractMillisecond) def _extract_microsecond(op, **kw): arg = translate_val(op.arg, **kw) - return f"extract('ms', {arg}::TIMESTAMP) % 1000" + return sg.expressions.Mod( + this=sg.func( + "extract", + sg.expressions.Literal(this="ms", is_string=True), + arg, + ), + expression=sg.expressions.Literal(this="1000", is_string=False), + ) @translate_val.register(ops.Date) def _date(op, **kw): arg = translate_val(op.arg, **kw) - return f"{arg}::DATE" + return sg.expressions.Date(this=arg) @translate_val.register(ops.DateTruncate) @@ -736,23 +749,29 @@ def _literal(op, **kw): elif dtype.is_interval(): return _interval_format(op) elif dtype.is_timestamp(): - year = op.value.year - month = op.value.month - day = op.value.day - hour = op.value.hour - minute = op.value.minute - second = op.value.second + year = sg.expressions.Literal(this=f"{op.value.year}", is_string=False) + month = sg.expressions.Literal(this=f"{op.value.month}", is_string=False) + day = sg.expressions.Literal(this=f"{op.value.day}", is_string=False) + hour = sg.expressions.Literal(this=f"{op.value.hour}", is_string=False) + minute = sg.expressions.Literal(this=f"{op.value.minute}", is_string=False) + second = sg.expressions.Literal(this=f"{op.value.second}", is_string=False) if op.value.microsecond: - microsecond = op.value.microsecond / 1e6 + microsecond = sg.expressions.Literal( + this=f"{op.value.microsecond / 1e6}", is_string=False + ) second += microsecond - if (timezone := dtype.timezone) is not None: - return f"make_timestamptz({year}, {month}, {day}, {hour}, {minute}, {second}, '{timezone}')" + if dtype.timezone is not None: + timezone = sg.expressions.Literal(this=dtype.timezone, is_string=True) + return sg.func( + "make_timestamptz", year, month, day, hour, minute, second, timezone + ) else: - return f"make_timestamp({year}, {month}, {day}, {hour}, {minute}, {second})" + return sg.func("make_timestamp", year, month, day, hour, minute, second) elif dtype.is_date(): - return sg.expressions.DateFromParts( - year=op.value.year, month=op.value.month, day=op.value.day - ) + year = sg.expressions.Literal(this=f"{op.value.year}", is_string=False) + month = sg.expressions.Literal(this=f"{op.value.month}", is_string=False) + day = sg.expressions.Literal(this=f"{op.value.day}", is_string=False) + return sg.expressions.DateFromParts(year=year, month=month, day=day) elif dtype.is_array(): value_type = dtype.value_type is_string = isinstance(value_type, dt.String) From d70cc503f43edff5a6325e69fdbbaeae50956601 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Fri, 8 Sep 2023 11:45:06 -0400 Subject: [PATCH 026/222] fix(duckdb): pass timezones as expressions --- ibis/backends/duckdb/compiler/values.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 371e7abeec3d..248828a2014d 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ 
b/ibis/backends/duckdb/compiler/values.py @@ -138,11 +138,12 @@ def _cast(op, **kw): ) else: - return sg.expressions.Interval(this=arg, unit=suffix, dialect="duckdb") + return sg.expressions.Interval(this=arg, unit=suffix) elif isinstance(op.to, dt.Timestamp) and isinstance(op.arg.dtype, dt.Integer): - return sg.func("to_timestamp", arg, dialect="duckdb") - elif isinstance(op.to, dt.Timestamp) and (timezone := op.to.timezone) is not None: - return sg.func("timezone", f"'{timezone}'", arg, dialect="duckdb") + return sg.func("to_timestamp", arg) + elif isinstance(op.to, dt.Timestamp) and op.to.timezone is not None: + timezone = sg.expressions.Literal(this=op.to.timezone, is_string=True) + return sg.func("timezone", timezone, arg) to = translate_val(op.to, **kw) return sg.cast(expression=arg, to=to) From 84905ce6d14c7cfd1e456683087071e841ccdf2c Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Fri, 8 Sep 2023 11:45:36 -0400 Subject: [PATCH 027/222] fix(duckdb): remove stringified casts --- ibis/backends/duckdb/compiler/values.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 248828a2014d..d6cca108c532 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -151,10 +151,9 @@ def _cast(op, **kw): @translate_val.register(ops.TryCast) def _try_cast(op, **kw): - return sg.func( - "try_cast", - translate_val(op.arg, **kw), - DuckDBType.to_string(op.to), + return sg.expressions.TryCast( + this=translate_val(op.arg, **kw), + to=DuckDBType.to_string(op.to), dialect="duckdb", ) @@ -195,13 +194,12 @@ def _not(op, **kw): def _to_date(op, **kw): arg = translate_val(op.arg, **kw) return sg.expressions.Date(this=arg) - return f"DATE {arg}" @translate_val.register(ops.Time) def _time(op, **kw): arg = translate_val(op.arg, **kw) - return f"{arg}::TIME" + return sg.cast(expression=arg, to=sg.expressions.DataType.Type.TIME) @translate_val.register(ops.TimestampNow) From ad288d2939212a4a1fded855cf6940c203048849 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Fri, 8 Sep 2023 11:56:44 -0400 Subject: [PATCH 028/222] feat(duckdb): bit aggregations --- ibis/backends/duckdb/compiler/values.py | 39 +++++++++++++------------ 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index d6cca108c532..c507c42630f2 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -957,6 +957,8 @@ def _aggregate(op, func, *, where=None, **kw): @translate_val.register(ops.Arbitrary) def _arbitrary(op, **kw): + if op.how == "heavy": + raise com.UnsupportedOperationError("how='heavy' not supported in the backend") functions = { "first": "first", "last": "last", @@ -1110,21 +1112,6 @@ def _group_concat(op, **kw): # TODO -def _bit_agg(func): - def _translate(op, **kw): - arg = translate_val(op.arg, **kw) - if not isinstance((type := op.arg.dtype), dt.UnsignedInteger): - nbits = type.nbytes * 8 - arg = f"reinterpretAsUInt{nbits}({arg})" - - if (where := op.where) is not None: - return f"{func}If({arg}, {translate_val(where, **kw)})" - else: - return f"{func}({arg})" - - return _translate - - @translate_val.register(ops.ArrayColumn) def _array_column(op, **kw): sg_expr = sg.expressions.Array.from_arg_list( @@ -1430,10 +1417,24 @@ def _xor(op, **kw): ) -# TODO -translate_val.register(ops.BitAnd)(_bit_agg("groupBitAnd")) 
-translate_val.register(ops.BitOr)(_bit_agg("groupBitOr")) -translate_val.register(ops.BitXor)(_bit_agg("groupBitXor")) +_bit_agg = { + ops.BitOr: "bit_or", + ops.BitAnd: "bit_and", + ops.BitXor: "bit_xor", +} + + +@translate_val.register(ops.BitAnd) +@translate_val.register(ops.BitOr) +@translate_val.register(ops.BitXor) +def _bitor(op, **kw): + arg = translate_val(op.arg, **kw) + bit_expr = sg.func(_bit_agg[type(op)], arg) + if op.where is not None: + where = sg.expressions.Where(this=translate_val(op.where, **kw)) + return sg.expressions.Filter(this=bit_expr, expression=where) + return bit_expr + translate_val.register(ops.StandardDev)(_agg_variance_like("stddev")) translate_val.register(ops.Variance)(_agg_variance_like("var")) From d65fd3b6d3d19ae51db95872e47036f558ee05ac Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Fri, 8 Sep 2023 14:55:33 -0400 Subject: [PATCH 029/222] feat(duckdb): variance, covariance, stddev aggregation covar_pop is currently broken due to upstream bug in sqlglot, should be fixed in a jiffy --- ibis/backends/duckdb/compiler/values.py | 67 +++++++++++++++++++------ ibis/backends/tests/test_aggregation.py | 7 ++- 2 files changed, 59 insertions(+), 15 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index c507c42630f2..22dfe5e466aa 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -922,15 +922,6 @@ def _quantile(op, **kw): return sg_expr -def _agg_variance_like(func): - variants = {"sample": f"{func}_samp", "pop": f"{func}_pop"} - - def formatter(op, **kw): - return _aggregate(op, variants[op.how], where=op.where, **kw) - - return formatter - - @translate_val.register(ops.Correlation) def _corr(op, **kw): if op.how == "pop": @@ -938,6 +929,59 @@ def _corr(op, **kw): return _aggregate(op, "corr", where=op.where, **kw) +@translate_val.register(ops.Covariance) +def _covariance(op, **kw): + _how = {"sample": "samp", "pop": "pop"} + + left = translate_val(op.left, **kw) + if (left_type := op.left.dtype).is_boolean(): + left = sg.cast( + expression=left, + to=DuckDBType.from_ibis(dt.Int32(nullable=left_type.nullable)), + ) + + right = translate_val(op.right, **kw) + if (right_type := op.right.dtype).is_boolean(): + right = sg.cast( + expression=right, + to=DuckDBType.from_ibis(dt.Int32(nullable=right_type.nullable)), + ) + + funcname = f"covar_{_how[op.how]}" + + sg_func = sg.func(funcname, left, right) + + if (where := op.where) is not None: + predicate = sg.expressions.Where(this=translate_val(op.where, **kw)) + return sg.expressions.Filter(this=sg_func, expression=predicate) + + return sg_func + + +@translate_val.register(ops.Variance) +@translate_val.register(ops.StandardDev) +def _variance(op, **kw): + _how = {"sample": "samp", "pop": "pop"} + _func = {ops.Variance: "var", ops.StandardDev: "stddev"} + + funcname = f"{_func[type(op)]}_{_how[op.how]}" + + arg = translate_val(op.arg, **kw) + if (arg_type := op.arg.dtype).is_boolean(): + arg = sg.cast( + expression=arg, + to=DuckDBType.from_ibis(dt.Int32(nullable=arg_type.nullable)), + ) + + sg_func = sg.func(funcname, arg) + + if (where := op.where) is not None: + predicate = sg.expressions.Where(this=translate_val(op.where, **kw)) + return sg.expressions.Filter(this=sg_func, expression=predicate) + + return sg_func + + def _aggregate(op, func, *, where=None, **kw): args = [ translate_val(arg, **kw) @@ -1436,11 +1480,6 @@ def _bitor(op, **kw): return bit_expr 
-translate_val.register(ops.StandardDev)(_agg_variance_like("stddev")) -translate_val.register(ops.Variance)(_agg_variance_like("var")) -translate_val.register(ops.Covariance)(_agg_variance_like("covar")) - - @translate_val.register def _sort_key(op: ops.SortKey, **kw): arg = translate_val(op.expr, **kw) diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index f0403fb2262d..d0c63501178d 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -522,6 +522,11 @@ def mean_and_std(v): lambda t, where: t.double_col[where].var(ddof=0), id="var_pop", marks=[ + mark.broken( + ["duckdb"], + raises=com.IbisError, + reason="sqlglot mistranslates VariancePop to variance_pop instead of var_pop", + ), mark.notimpl( ["druid"], raises=sa.exc.ProgrammingError, @@ -999,7 +1004,7 @@ def test_quantile( reason="Correlation with how='sample' is not supported.", ), pytest.mark.notyet( - ["trino", "postgres", "duckdb", "snowflake", "oracle"], + ["trino", "postgres", "snowflake", "oracle"], raises=ValueError, reason="XXXXSQLExprTranslator only implements population correlation coefficient", ), From e0c771b2e9728e8d6a891b9f29693e9ad845269a Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Fri, 8 Sep 2023 14:59:37 -0400 Subject: [PATCH 030/222] fix(duckdb): only corr_pop, no corr_samp --- ibis/backends/duckdb/compiler/values.py | 6 ++++-- ibis/backends/tests/test_aggregation.py | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 22dfe5e466aa..0775ef50bddc 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -924,8 +924,10 @@ def _quantile(op, **kw): @translate_val.register(ops.Correlation) def _corr(op, **kw): - if op.how == "pop": - raise ValueError("Duckdb only implements `sample` correlation coefficient") + if op.how == "sample": + raise com.UnsupportedOperationError( + "DuckDB only implements `pop` correlation coefficient" + ) return _aggregate(op, "corr", where=op.where, **kw) diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index d0c63501178d..87502379870b 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -994,6 +994,11 @@ def test_quantile( ["dask", "datafusion", "pandas", "druid"], raises=com.OperationNotDefinedError, ), + pytest.mark.notyet( + ["duckdb"], + raises=com.UnsupportedOperationError, + reason="DuckDB only implements population correlation coefficient", + ), pytest.mark.notyet( ["impala", "mysql", "sqlite"], raises=com.OperationNotDefinedError, From 8198089c33f3c011de6f15b92aa70040b8f23c2d Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Fri, 8 Sep 2023 15:18:41 -0400 Subject: [PATCH 031/222] feat(duckdb): group_concat! 
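This swaps the aggregate over to DuckDB's `string_agg` and wires any
predicate through the same FILTER wrapper the other reductions use. A
rough sketch of the sqlglot tree this builds (the column names here are
illustrative, not taken from the test suite):

    import sqlglot as sg

    agg = sg.func(
        "string_agg",
        sg.column("s"),
        sg.expressions.Literal(this=",", is_string=True),
    )
    filtered = sg.expressions.Filter(
        this=agg,
        expression=sg.expressions.Where(this=sg.column("keep")),
    )
    # Should render roughly as: STRING_AGG(s, ',') FILTER(WHERE keep)
    print(filtered.sql(dialect="duckdb"))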
--- ibis/backends/duckdb/compiler/values.py | 3 +-- ibis/backends/tests/test_aggregation.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 0775ef50bddc..63827be6db40 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1145,7 +1145,7 @@ def _group_concat(op, **kw): arg = translate_val(op.arg, **kw) sep = translate_val(op.sep, **kw) - concat = sg.func("array_to_string", arg, sep, dialect="duckdb") + concat = sg.func("string_agg", arg, sep, dialect="duckdb") if (where := op.where) is not None: predicate = translate_val(where, **kw) @@ -1583,7 +1583,6 @@ def format_window_frame(func, frame, **kw): # TODO UNSUPPORTED_REDUCTIONS = ( ops.ApproxMedian, - ops.GroupConcat, ops.ApproxCountDistinct, ) diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 87502379870b..f6d92cacd685 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -1167,7 +1167,7 @@ def test_median(alltypes, df): id="expr", marks=[ mark.notyet( - ["duckdb", "trino"], + ["trino"], raises=com.UnsupportedOperationError, ), mark.notyet( From dfcd71104cf7fde89d3c3e6050704532e44e21ec Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Fri, 8 Sep 2023 15:47:14 -0400 Subject: [PATCH 032/222] feat(duckdb): median and fix count star --- ibis/backends/duckdb/compiler/values.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 63827be6db40..57611565d82e 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -532,8 +532,8 @@ def _levenshtein(op, **kw): ops.RandomScalar: "random", ops.Sign: "sign", # Unary aggregates - # ops.ApproxMedian: "median", # TODO - # ops.Median: "quantileExactExclusive", # TODO + ops.ApproxMedian: "median", + ops.Median: "median", ops.Mean: "avg", ops.Max: "max", ops.Min: "min", @@ -867,11 +867,20 @@ def _count_distinct(op, **kw): return count_expr +# TODO: implement +@translate_val.register(ops.CountDistinctStar) +def _count_distinct_star(op, **kw): + ... 
+ + @translate_val.register(ops.CountStar) def _count_star(op, **kw): sql = sg.expressions.Count(this=sg.expressions.Star()) if (predicate := op.where) is not None: - return sg.select(sql).where(predicate) + return sg.expressions.Filter( + this=sql, + expression=sg.expressions.Where(this=translate_val(op.where, **kw)), + ) return sql From 42d69dcc885838f6cf6898d31561c325eb688eaf Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Fri, 8 Sep 2023 15:54:24 -0400 Subject: [PATCH 033/222] fix(duckdb): handle boolean comparisons to corr --- ibis/backends/duckdb/compiler/values.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 57611565d82e..69b5b5268c7b 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -937,7 +937,28 @@ def _corr(op, **kw): raise com.UnsupportedOperationError( "DuckDB only implements `pop` correlation coefficient" ) - return _aggregate(op, "corr", where=op.where, **kw) + + left = translate_val(op.left, **kw) + if (left_type := op.left.dtype).is_boolean(): + left = sg.cast( + expression=left, + to=DuckDBType.from_ibis(dt.Int32(nullable=left_type.nullable)), + ) + + right = translate_val(op.right, **kw) + if (right_type := op.right.dtype).is_boolean(): + right = sg.cast( + expression=right, + to=DuckDBType.from_ibis(dt.Int32(nullable=right_type.nullable)), + ) + + sg_func = sg.func("corr", left, right) + + if (where := op.where) is not None: + predicate = sg.expressions.Where(this=translate_val(op.where, **kw)) + return sg.expressions.Filter(this=sg_func, expression=predicate) + + return sg_func @translate_val.register(ops.Covariance) From a0b753d51278d06e72c1925d1549af7df571a876 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Fri, 8 Sep 2023 16:12:43 -0400 Subject: [PATCH 034/222] feat(duckdb): make sure any and all handle filters --- ibis/backends/duckdb/compiler/values.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 69b5b5268c7b..8b11cd128f00 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -187,6 +187,26 @@ def _not(op, **kw): return sg.expressions.Not(this=arg) +@translate_val.register(ops.Any) +def _any(op, **kw): + arg = translate_val(op.arg, **kw) + any_expr = sg.expressions.AnyValue(this=arg) + if op.where is not None: + where = sg.expressions.Where(this=translate_val(op.where, **kw)) + return sg.expressions.Filter(this=any_expr, expression=where) + return any_expr + + +@translate_val.register(ops.All) +def _any(op, **kw): + arg = translate_val(op.arg, **kw) + all_expr = sg.func("bool_and", arg) + if op.where is not None: + where = sg.expressions.Where(this=translate_val(op.where, **kw)) + return sg.expressions.Filter(this=all_expr, expression=where) + return all_expr + + ### Timey McTimeFace @@ -537,8 +557,6 @@ def _levenshtein(op, **kw): ops.Mean: "avg", ops.Max: "max", ops.Min: "min", - ops.Any: "any_value", - ops.All: "min", ops.ArgMin: "arg_min", ops.Mode: "mode", ops.ArgMax: "arg_max", From eda2f8db4e9ed3ec18caf30f6f1193420e7125ed Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 12 Sep 2023 10:38:20 -0400 Subject: [PATCH 035/222] feat(duckdb): port all reg functions to sqlglot Also restore default UTC and progress bar checks to initial connect --- ibis/backends/duckdb/__init__.py | 137 
++++++++++++-------- ibis/backends/duckdb/compiler/relations.py | 2 +- ibis/backends/duckdb/tests/test_register.py | 12 +- ibis/backends/tests/test_register.py | 6 +- 4 files changed, 95 insertions(+), 62 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index c7719871a91d..34b916c7a75f 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -104,6 +104,8 @@ def current_schema(self) -> str: return self.raw_sql("SELECT current_schema()") def raw_sql(self, query: str, **kwargs: Any) -> Any: + if isinstance(query, sg.Expression): + query = query.sql(dialect="duckdb") return self.con.execute(query, **kwargs) def create_table( @@ -222,7 +224,7 @@ def _fully_qualified_name(self, name: str, database: str | None) -> str: # This is a hack to get around nested quoting of table name # e.g. '"main._ibis_temp_table_2"' return name - return sg.table(name, db=db).sql(dialect="duckdb") + return sg.table(name, db=db) # .sql(dialect="duckdb") def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema: """Return a Schema object for the indicated table and database. @@ -241,7 +243,9 @@ def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema Ibis schema """ qualified_name = self._fully_qualified_name(table_name, database) - query = f"DESCRIBE {qualified_name}" + if isinstance(qualified_name, str): + qualified_name = sg.expressions.Identifier(this=qualified_name, quoted=True) + query = sg.expressions.Describe(this=qualified_name) results = self.raw_sql(query) names, types, *_ = results.fetch_arrow_table() names = names.to_pylist() @@ -344,10 +348,18 @@ def do_connect( import duckdb - self.con = duckdb.connect(str(database)) + self.con = duckdb.connect(str(database), config=config) - # TODO: disable progress bar for < 0.8.0 - # TODO: set timezone to UTC + # Load any pre-specified extensions + if extensions is not None: + self._load_extensions(extensions) + + # Default timezone + self.con.execute("SET TimeZone = 'UTC'") + # the progress bar in duckdb <0.8.0 causes kernel crashes in + # jupyterlab, fixed in https://github.com/duckdb/duckdb/pull/6831 + if vparse(duckdb.__version__) < vparse("0.8.0"): + self.con.execute("SET enable_progress_bar = false") self._record_batch_readers_consumed = {} @@ -548,7 +560,7 @@ def register( else: try: return self.read_in_memory(source, table_name=table_name, **kwargs) - except sa.exc.ProgrammingError: + except (duckdb.InvalidInputException, NameError): self._register_failure() if first.startswith(("parquet://", "parq://")) or first.endswith( @@ -580,10 +592,17 @@ def _register_failure(self): ) def _compile_temp_view(self, table_name, source): - raw_source = source.compile( - dialect=self.con.dialect, compile_kwargs=dict(literal_binds=True) + return sg.expressions.Create( + this=sg.expressions.Identifier( + this=table_name, quoted=True + ), # CREATE ... 'table_name' + kind="VIEW", # VIEW + replace=True, # OR REPLACE + properties=sg.expressions.Properties( + expressions=[sg.expressions.TemporaryProperty()] # TEMPORARY + ), + expression=source, # AS ... 
) - return f'CREATE OR REPLACE TEMPORARY VIEW "{table_name}" AS {raw_source}' @util.experimental def read_json( @@ -619,16 +638,16 @@ def read_json( if not table_name: table_name = util.gen_name("read_json") - source = sa.select(sa.literal_column("*")).select_from( - sa.func.read_json_auto( - sa.func.list_value(*normalize_filenames(source_list)), - _format_kwargs(kwargs), - ) + options = [f"{key}={val}" for key, val in kwargs.items()] + + sg_view_expr = self._compile_temp_view( + table_name, + sg.select("*").from_( + sg.func("read_json_auto", normalize_filenames(source_list), *options) + ), ) - view = self._compile_temp_view(table_name, source) - with self.begin() as con: - con.exec_driver_sql(view) + self.raw_sql(sg_view_expr) return self.table(table_name) def read_csv( @@ -671,11 +690,20 @@ def read_csv( kwargs.setdefault("header", True) kwargs["auto_detect"] = kwargs.pop("auto_detect", "columns" not in kwargs) - options = ", " + ",".join([f"{key}={val}" for key, val in kwargs.items()]) + # TODO: clean this up + # We want to _usually_ quote arguments but if we quote `columns` it messes + # up DuckDB's struct parsing. + options = [ + f'{key}="{val}"' if key != "columns" else f"{key}={val}" + for key, val in kwargs.items() + ] - sql = f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_csv({source_list}{options})" + sg_view_expr = self._compile_temp_view( + table_name, + sg.select("*").from_(sg.func("read_csv", source_list, *options)), + ) - self.raw_sql(sql) + self.raw_sql(sg_view_expr) return self.table(table_name) def read_parquet( @@ -710,12 +738,10 @@ def read_parquet( # Default to using the native duckdb parquet reader # If that fails because of auth issues, fall back to ingesting via # pyarrow dataset - self._read_parquet_duckdb_native(source_list, table_name, **kwargs) - # except sa.exc.OperationalError as e: - # if isinstance(e.orig, duckdb.IOException): - # self._read_parquet_pyarrow_dataset(source_list, table_name, **kwargs) - # else: - # raise e + try: + self._read_parquet_duckdb_native(source_list, table_name, **kwargs) + except duckdb.IOException: + self._read_parquet_pyarrow_dataset(source_list, table_name, **kwargs) return self.table(table_name) @@ -728,13 +754,18 @@ def _read_parquet_duckdb_native( ): self._load_extensions(["httpfs"]) - options = "" if kw := kwargs: - options = ", " + ",".join([f"{key}={val}" for key, val in kw.items()]) + options = [f"{key}={val}" for key, val in kwargs.items()] + pq_func = sg.func("read_parquet", source_list, *options) + else: + pq_func = sg.func("read_parquet", source_list) - sql = f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet({source_list}{options})" + sg_view_expr = self._compile_temp_view( + table_name, + sg.select("*").from_(pq_func), + ) - self.raw_sql(sql) + self.raw_sql(sg_view_expr) def _read_parquet_pyarrow_dataset( self, source_list: str | Iterable[str], table_name: str, **kwargs: Any @@ -746,12 +777,11 @@ def _read_parquet_pyarrow_dataset( # We don't create a view since DuckDB special cases Arrow Datasets # so if we also create a view we end up with both a "lazy table" # and a view with the same name - with self.begin() as con: - # DuckDB normally auto-detects Arrow Datasets that are defined - # in local variables but the `dataset` variable won't be local - # by the time we execute against this so we register it - # explicitly. 
- con.connection.register(table_name, dataset) + self.con.register(table_name, dataset) + # DuckDB normally auto-detects Arrow Datasets that are defined + # in local variables but the `dataset` variable won't be local + # by the time we execute against this so we register it + # explicitly. def read_in_memory( self, @@ -774,8 +804,7 @@ def read_in_memory( The just-registered table """ table_name = table_name or util.gen_name("read_in_memory") - with self.begin() as con: - con.connection.register(table_name, source) + self.con.register(table_name, source) if isinstance(source, pa.RecordBatchReader): # Ensure the reader isn't marked as started, in case the name is @@ -851,12 +880,14 @@ def read_postgres( "`table_name` is required when registering a postgres table" ) self._load_extensions(["postgres_scanner"]) - source = sa.select(sa.literal_column("*")).select_from( - sa.func.postgres_scan_pushdown(uri, schema, table_name) + + sg_view_expr = self._compile_temp_view( + table_name, + sg.select("*").from_( + sg.func("postgres_scan_pushdown", uri, schema, table_name) + ), ) - view = self._compile_temp_view(table_name, source) - with self.begin() as con: - con.exec_driver_sql(view) + self.raw_sql(sg_view_expr) return self.table(table_name) @@ -903,12 +934,15 @@ def read_sqlite(self, path: str | Path, table_name: str | None = None) -> ir.Tab raise ValueError("`table_name` is required when registering a sqlite table") self._load_extensions(["sqlite"]) - source = sa.select(sa.literal_column("*")).select_from( - sa.func.sqlite_scan(str(path), table_name) + sg_view_expr = self._compile_temp_view( + table_name, + sg.select("*").from_( + sg.func( + "sqlite_scan", sg.to_identifier(str(path), quoted=True), table_name + ) + ), ) - view = self._compile_temp_view(table_name, source) - with self.begin() as con: - con.exec_driver_sql(view) + self.raw_sql(sg_view_expr) return self.table(table_name) @@ -943,10 +977,9 @@ def attach_sqlite( >>> con.list_tables() ['t'] """ - self._load_extensions(["sqlite"]) - with self.begin() as con: - con.execute(sa.text(f"SET GLOBAL sqlite_all_varchar={all_varchar}")) - con.execute(sa.text(f"CALL sqlite_attach('{path}', overwrite={overwrite})")) + self.load_extension("sqlite") + self.raw_sql(f"SET GLOBAL sqlite_all_varchar={all_varchar}") + self.raw_sql(f"CALL sqlite_attach('{path}', overwrite={overwrite})") def _run_pre_execute_hooks(self, expr: ir.Expr) -> None: # Warn for any tables depending on RecordBatchReaders that have already diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index 202dabc977d2..4ae6c360414d 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -24,7 +24,7 @@ def _dummy(op: ops.DummyTable, **kw): @translate_rel.register(ops.PhysicalTable) def _physical_table(op: ops.PhysicalTable, **_): - return sg.parse_one(op.name, into=sg.exp.Table) + return sg.expressions.Table(this=sg.to_identifier(op.name, quoted=True)) @translate_rel.register(ops.Selection) diff --git a/ibis/backends/duckdb/tests/test_register.py b/ibis/backends/duckdb/tests/test_register.py index 4bdfee7baf62..2e404e016283 100644 --- a/ibis/backends/duckdb/tests/test_register.py +++ b/ibis/backends/duckdb/tests/test_register.py @@ -66,16 +66,15 @@ def test_temp_directory(tmp_path): # 1. 
in-memory + no temp_directory specified con = ibis.duckdb.connect() - with con.begin() as c: - value = c.exec_driver_sql(query).scalar() - assert value # we don't care what the specific value is + + value = con.raw_sql(query).fetchone()[0] + assert value # we don't care what the specific value is temp_directory = Path(tempfile.gettempdir()) / "duckdb" # 2. in-memory + temp_directory specified con = ibis.duckdb.connect(temp_directory=temp_directory) - with con.begin() as c: - value = c.exec_driver_sql(query).scalar() + value = con.raw_sql(query).fetchone()[0] assert value == str(temp_directory) # 3. on-disk + no temp_directory specified @@ -84,8 +83,7 @@ def test_temp_directory(tmp_path): # 4. on-disk + temp_directory specified con = ibis.duckdb.connect(tmp_path / "test2.ddb", temp_directory=temp_directory) - with con.begin() as c: - value = c.exec_driver_sql(query).scalar() + value = con.raw_sql(query).fetchone()[0] assert value == str(temp_directory) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 5fb3155b84bb..7e7de6c9d5f1 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -384,12 +384,14 @@ def test_register_garbage(con, monkeypatch): monkeypatch.setattr(con, "_load_extensions", lambda x: True) sa = pytest.importorskip("sqlalchemy") + duckdb = pytest.importorskip("duckdb") with pytest.raises( - sa.exc.OperationalError, match="No files found that match the pattern" + (sa.exc.OperationalError, duckdb.IOException), + match="No files found that match the pattern", ): con.read_csv("garbage_notafile") - with pytest.raises(FileNotFoundError): + with pytest.raises((FileNotFoundError, duckdb.IOException)): con.read_parquet("garbage_notafile") From d4d802c050c1e1ef81389dce56a72998da15b925 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 12 Sep 2023 11:33:10 -0400 Subject: [PATCH 036/222] fix(duckdb): all substring slicing works! 
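The binary-op formatter now parenthesizes both operands, so the index
arithmetic that substring slicing generates keeps its precedence once it
is nested inside larger expressions. A minimal sketch of the effect
(column names are illustrative):

    import sqlglot as sg

    one = sg.expressions.Literal(this="1", is_string=False)
    start = sg.expressions.Paren(
        this=sg.expressions.Sub(this=sg.column("i"), expression=one)
    )
    length = sg.expressions.Paren(this=sg.func("length", sg.column("s")))
    # Should render roughly as: (i - 1) * (LENGTH(s)); without the Paren
    # wrappers the operands could lose their grouping in deeper expressions.
    print(sg.expressions.Mul(this=start, expression=length).sql(dialect="duckdb"))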
---
 ibis/backends/duckdb/compiler/values.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index 8b11cd128f00..249301845104 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -1456,7 +1456,10 @@ def formatter(op, **kw):
         left = translate_val(op.left, **kw)
         right = translate_val(op.right, **kw)

-        return sg_expr(this=left, expression=right)
+        return sg_expr(
+            this=sg.expressions.Paren(this=left),
+            expression=sg.expressions.Paren(this=right),
+        )

     return formatter

From fe9052f322ac064b99c0d5242a050ffdcbd8caa3 Mon Sep 17 00:00:00 2001
From: Gil Forsyth
Date: Tue, 12 Sep 2023 14:54:41 -0400
Subject: [PATCH 037/222] fix(duckdb): handle negative bounds in arrayslice

---
 ibis/backends/duckdb/compiler/values.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index 249301845104..8124350bb0e7 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -847,17 +847,32 @@ def _literal(op, **kw):
+def _neg_idx_to_pos(array, idx):
+    return sg.expressions.If(
+        this=sg.expressions.LT(this=idx, expression=sg_literal(0, is_string=False)),
+        true=sg.func("len", array) + idx,
+        false=idx,
+    )
+
+
 @translate_val.register(ops.ArraySlice)
 def _array_slice_op(op, **kw):
     arg = translate_val(op.arg, **kw)
-    start = translate_val(op.start, **kw)
-    if (stop := op.stop) is not None:
-        stop = translate_val(stop, **kw)
+    arg_length = sg.func("len", arg)
+
+    if (start := op.start) is None:
+        start = sg_literal(0, is_string=False)
     else:
+        start = translate_val(op.start, **kw)
+        start = sg.func("least", arg_length, _neg_idx_to_pos(arg, start))
+
+    if (stop := op.stop) is None:
         stop = sg.expressions.Null()
+    else:
+        stop = _neg_idx_to_pos(arg, translate_val(stop, **kw))

-    return sg.func("list_slice", arg, start, stop)
+    return sg.func("list_slice", arg, start + 1, stop)

From a9bca69ce9b9927570bb31e32fadcc4191547932 Mon Sep 17 00:00:00 2001
From: Gil Forsyth
Date: Tue, 12 Sep 2023 15:12:38 -0400
Subject: [PATCH 038/222] refactor(duckdb): rearrange and organize by op type

---
 ibis/backends/duckdb/compiler/values.py | 995 ++++++++++++------------
 1 file changed, 496 insertions(+), 499 deletions(-)

diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index 8124350bb0e7..6c0bbe60e587 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -62,6 +62,227 @@ def _alias(op, render_aliases: bool = True, **kw):
     return val


+### Literals
+
+
+def _sql(obj, dialect="duckdb"):
+    try:
+        return obj.sql(dialect=dialect)
+    except AttributeError:
+        return obj
+
+
+def sg_literal(arg, is_string=True):
+    return sg.expressions.Literal(this=f"{arg}", is_string=is_string)
+
+
+@translate_val.register(ops.Literal)
+def _literal(op, **kw):
+    value = op.value
+    dtype = op.dtype
+    if value is None and dtype.nullable:
+        if dtype.is_null():
+            return sg.expressions.Null()
+        return sg.cast(sg.expressions.Null(), to=DuckDBType.from_ibis(dtype))
+    if dtype.is_boolean():
+        return sg.expressions.Boolean(this=value)
+    elif dtype.is_inet():
+        raise com.UnsupportedOperationError("DuckDB doesn't support an explicit inet dtype")
+    elif dtype.is_string():
+        return sg_literal(value)
+    elif dtype.is_decimal():
+ # TODO: make this a sqlglot expression + precision = dtype.precision + scale = dtype.scale + if precision is None: + precision = 38 + if scale is None: + scale = 9 + if not 1 <= precision <= 38: + raise NotImplementedError( + f"Unsupported precision. Supported values: [1 : 38]. Current value: {precision!r}" + ) + + # TODO: handle if `value` is "Infinity" + # precision = sg.expressions.DataTypeParam( + # this=sg.expressions.Literal(this=f"{precision}", is_string=False) + # ) + # scale = sg.expressions.DataTypeParam( + # this=sg.expressions.Literal(this=f"{scale}", is_string=False) + # ) + # need sg.expressions.DataTypeParam to be available + # ... + return f"{value!s}::decimal({precision}, {scale})" + elif dtype.is_numeric(): + if math.isinf(value): + return sg.expressions.cast( + expression=sg_literal(value), + to=sg.expressions.DataType.Type.FLOAT, + ) + elif math.isnan(value): + return sg.expressions.cast( + expression=sg_literal("NaN"), + to=sg.expressions.DataType.Type.FLOAT, + ) + return sg_literal(value, is_string=False) + elif dtype.is_interval(): + return _interval_format(op) + elif dtype.is_timestamp(): + year = sg_literal(op.value.year, is_string=False) + month = sg_literal(op.value.month, is_string=False) + day = sg_literal(op.value.day, is_string=False) + hour = sg_literal(op.value.hour, is_string=False) + minute = sg_literal(op.value.minute, is_string=False) + second = sg_literal(op.value.second, is_string=False) + if op.value.microsecond: + microsecond = sg_literal(op.value.microsecond / 1e6, is_string=False) + second += microsecond + if dtype.timezone is not None: + timezone = sg_literal(dtype.timezone, is_string=True) + return sg.func( + "make_timestamptz", year, month, day, hour, minute, second, timezone + ) + else: + return sg.func("make_timestamp", year, month, day, hour, minute, second) + elif dtype.is_date(): + year = sg_literal(op.value.year, is_string=False) + month = sg_literal(op.value.month, is_string=False) + day = sg_literal(op.value.day, is_string=False) + return sg.expressions.DateFromParts(year=year, month=month, day=day) + elif dtype.is_array(): + value_type = dtype.value_type + is_string = isinstance(value_type, dt.String) + values = sg.expressions.Array().from_arg_list( + [ + # TODO: this cast makes for frustrating output + # is there any better way to handle it? 
+ sg.cast( + sg_literal(v, is_string=is_string), + to=DuckDBType.from_ibis(value_type), + ) + for v in value + ] + ) + return values + elif dtype.is_map(): + key_type = dtype.key_type + value_type = dtype.value_type + keys = sg.expressions.Array().from_arg_list( + [_literal(ops.Literal(k, dtype=key_type), **kw) for k in value.keys()] + ) + values = sg.expressions.Array().from_arg_list( + [_literal(ops.Literal(v, dtype=value_type), **kw) for v in value.values()] + ) + sg_expr = sg.expressions.Map(keys=keys, values=values) + return sg_expr + elif dtype.is_struct(): + keys = [sg_literal(key) for key in value.keys()] + values = [ + _literal(ops.Literal(v, dtype=subdtype), **kw) + for subdtype, v in zip(dtype.types, value.values()) + ] + slices = [ + sg.expressions.Slice(this=k, expression=v) for k, v in zip(keys, values) + ] + sg_expr = sg.expressions.Struct.from_arg_list(slices) + return sg_expr + else: + raise NotImplementedError(f"Unsupported type: {dtype!r}") + + +### Simple Ops + +_simple_ops = { + ops.Power: "pow", + # Unary operations + ops.IsNan: "isnan", + ops.IsInf: "isinf", + ops.Abs: "abs", + ops.Ceil: "ceil", + ops.Floor: "floor", + ops.Exp: "exp", + ops.Sqrt: "sqrt", + ops.Ln: "ln", + ops.Log2: "log2", + ops.Log10: "log", + ops.Acos: "acos", + ops.Asin: "asin", + ops.Atan: "atan", + ops.Atan2: "atan2", + ops.Cos: "cos", + ops.Sin: "sin", + ops.Tan: "tan", + ops.Cot: "cot", + ops.Pi: "pi", + ops.RandomScalar: "random", + ops.Sign: "sign", + # Unary aggregates + ops.ApproxMedian: "median", + ops.Median: "median", + ops.Mean: "avg", + ops.Max: "max", + ops.Min: "min", + ops.ArgMin: "arg_min", + ops.Mode: "mode", + ops.ArgMax: "arg_max", + ops.First: "first", + ops.Last: "last", + # string operations + ops.StringContains: "contains", + ops.StringLength: "length", + ops.Lowercase: "lower", + ops.Uppercase: "upper", + ops.Reverse: "reverse", + ops.StringReplace: "replace", + ops.StartsWith: "prefix", + ops.EndsWith: "suffix", + ops.LPad: "lpad", + ops.RPad: "rpad", + ops.LStrip: "ltrim", + ops.RStrip: "rtrim", + ops.Strip: "trim", + ops.StringAscii: "ascii", + ops.StrRight: "right", + # Other operations + ops.Where: "if", + ops.ArrayLength: "length", + ops.Unnest: "unnest", + ops.Degrees: "degrees", + ops.Radians: "radians", + ops.NullIf: "nullIf", + ops.MapLength: "cardinality", + ops.MapKeys: "map_keys", + ops.MapValues: "map_values", + ops.ArraySort: "list_sort", + ops.ArrayContains: "list_contains", + ops.FirstValue: "first_value", + ops.LastValue: "last_value", + ops.NTile: "ntile", + ops.Hash: "hash", +} + + +def _agg(func_name): + def formatter(op, **kw): + return _aggregate(op, func_name, where=op.where, **kw) + + return formatter + + +for _op, _name in _simple_ops.items(): + assert isinstance(type(_op), type), type(_op) + if issubclass(_op, ops.Reduction): + translate_val.register(_op)(_agg(_name)) + else: + + @translate_val.register(_op) + def _fmt(op, _name: str = _name, **kw): + return sg.func( + _name, *map(partial(translate_val, **kw), op.args), dialect="duckdb" + ) + + +del _fmt, _name, _op ### Bitwise Business _bitwise_mapping = { @@ -111,6 +332,32 @@ def _generic_log(op, **kw): return sg.func("ln", arg) +@translate_val.register(ops.Clip) +def _clip(op, **kw): + arg = translate_val(op.arg, **kw) + if (upper := op.upper) is not None: + arg = f"least({translate_val(upper, **kw)}, {arg})" + + if (lower := op.lower) is not None: + arg = f"greatest({translate_val(lower, **kw)}, {arg})" + + return arg + + +@translate_val.register(ops.FloorDivide) +def _floor_divide(op, **kw): 
+ new_op = ops.Floor(ops.Divide(op.left, op.right)) + return translate_val(new_op, **kw) + + +@translate_val.register(ops.Round) +def _round(op, **kw): + arg = translate_val(op.arg, **kw) + if (digits := op.digits) is not None: + return sg.expressions.Round(this=arg, decimals=translate_val(digits, **kw)) + return sg.expressions.Round(this=arg) + + ### Dtype Dysmorphia @@ -198,7 +445,7 @@ def _any(op, **kw): @translate_val.register(ops.All) -def _any(op, **kw): +def _all(op, **kw): arg = translate_val(op.arg, **kw) all_expr = sg.func("bool_and", arg) if op.where is not None: @@ -207,6 +454,16 @@ def _any(op, **kw): return all_expr +@translate_val.register(ops.NotAny) +def _not_any(op, **kw): + return translate_val(ops.All(ops.Not(op.arg), where=op.where), **kw) + + +@translate_val.register(ops.NotAll) +def _not_all(op, **kw): + return translate_val(ops.Any(ops.Not(op.arg), where=op.where), **kw) + + ### Timey McTimeFace @@ -216,18 +473,58 @@ def _to_date(op, **kw): return sg.expressions.Date(this=arg) +@translate_val.register(ops.DateFromYMD) +def _date_from_ymd(op, **kw): + y = translate_val(op.year, **kw) + m = translate_val(op.month, **kw) + d = translate_val(op.day, **kw) + return sg.expressions.DateFromParts(year=y, month=m, day=d) + + @translate_val.register(ops.Time) def _time(op, **kw): arg = translate_val(op.arg, **kw) return sg.cast(expression=arg, to=sg.expressions.DataType.Type.TIME) +@translate_val.register(ops.TimeFromHMS) +def _time_from_hms(op, **kw): + hours = translate_val(op.hours, **kw) + minutes = translate_val(op.minutes, **kw) + seconds = translate_val(op.seconds, **kw) + return sg.func("make_time", hours, minutes, seconds) + + @translate_val.register(ops.TimestampNow) def _timestamp_now(op, **kw): """DuckDB current timestamp defaults to timestamp + tz""" return sg.cast(expression=sg.func("current_timestamp"), to="TIMESTAMP") +@translate_val.register(ops.TimestampFromUNIX) +def _timestamp_from_unix(op, **kw): + arg = translate_val(op.arg, **kw) + if (unit := op.unit.short) in {"ms", "us", "ns"}: + raise com.UnsupportedOperationError(f"{unit!r} unit is not supported!") + + return sg.expressions.UnixToTime(this=arg) + + +@translate_val.register(ops.TimestampFromYMDHMS) +def _timestamp_from_ymdhms(op, **kw): + year = translate_val(op.year, **kw) + month = translate_val(op.month, **kw) + day = translate_val(op.day, **kw) + hour = translate_val(op.hours, **kw) + minute = translate_val(op.minutes, **kw) + second = translate_val(op.seconds, **kw) + + if (timezone := op.dtype.timezone) is not None: + return f"make_timestamptz({year}, {month}, {day}, {hour}, {minute}, {second}, '{timezone}')" + else: + return f"make_timestamp({year}, {month}, {day}, {hour}, {minute}, {second})" + + @translate_val.register(ops.Strftime) def _strftime(op, **kw): if not isinstance(op.format_str, ops.Literal): @@ -239,14 +536,6 @@ def _strftime(op, **kw): return sg.func("strftime", arg, format_str) -@translate_val.register(ops.TimeFromHMS) -def _time_from_hms(op, **kw): - hours = translate_val(op.hours, **kw) - minutes = translate_val(op.minutes, **kw) - seconds = translate_val(op.seconds, **kw) - return sg.func("make_time", hours, minutes, seconds) - - @translate_val.register(ops.StringToTimestamp) def _string_to_timestamp(op, **kw): arg = translate_val(op.arg, **kw) @@ -326,13 +615,6 @@ def _extract_microsecond(op, **kw): ) -@translate_val.register(ops.Date) -def _date(op, **kw): - arg = translate_val(op.arg, **kw) - - return sg.expressions.Date(this=arg) - - 
@translate_val.register(ops.DateTruncate) @translate_val.register(ops.TimestampTruncate) @translate_val.register(ops.TimeTruncate) @@ -359,42 +641,42 @@ def _truncate(op, **kw): return f"date_trunc('{duckunit}', {arg})" -@translate_val.register(ops.DateFromYMD) -def _date_from_ymd(op, **kw): - y = translate_val(op.year, **kw) - m = translate_val(op.month, **kw) - d = translate_val(op.day, **kw) - return sg.expressions.DateFromParts(year=y, month=m, day=d) - - @translate_val.register(ops.DayOfWeekIndex) def _day_of_week_index(op, **kw): arg = translate_val(op.arg, **kw) return f"(dayofweek({arg}) + 6) % 7" -@translate_val.register(ops.TimestampFromUNIX) -def _timestamp_from_unix(op, **kw): - arg = translate_val(op.arg, **kw) - if (unit := op.unit.short) in {"ms", "us", "ns"}: - raise com.UnsupportedOperationError(f"{unit!r} unit is not supported!") - - return sg.expressions.UnixToTime(this=arg) - - -@translate_val.register(ops.TimestampFromYMDHMS) -def _timestamp_from_ymdhms(op, **kw): - year = translate_val(op.year, **kw) - month = translate_val(op.month, **kw) - day = translate_val(op.day, **kw) - hour = translate_val(op.hours, **kw) - minute = translate_val(op.minutes, **kw) - second = translate_val(op.seconds, **kw) - - if (timezone := op.dtype.timezone) is not None: - return f"make_timestamptz({year}, {month}, {day}, {hour}, {minute}, {second}, '{timezone}')" - else: - return f"make_timestamp({year}, {month}, {day}, {hour}, {minute}, {second})" +@translate_val.register(ops.DayOfWeekName) +def day_of_week_name(op, **kw): + # day of week number is 0-indexed + # Sunday == 0 + # Saturday == 6 + arg = op.arg + nullable = arg.dtype.nullable + empty_string = ops.Literal("", dtype=dt.String(nullable=nullable)) + weekdays = range(7) + return translate_val( + ops.NullIf( + ops.SimpleCase( + base=ops.DayOfWeekIndex(arg), + cases=[ + ops.Literal(day, dtype=dt.Int8(nullable=nullable)) + for day in weekdays + ], + results=[ + ops.Literal( + calendar.day_name[day], + dtype=dt.String(nullable=nullable), + ) + for day in weekdays + ], + default=empty_string, + ), + empty_string, + ), + **kw, + ) ### Interval Marginalia @@ -525,99 +807,50 @@ def _levenshtein(op, **kw): return sg.func("levenshtein", left, right) -### Simple Ops - -_simple_ops = { - ops.Power: "pow", - # Unary operations - ops.IsNan: "isnan", - ops.IsInf: "isinf", - ops.Abs: "abs", - ops.Ceil: "ceil", - ops.Floor: "floor", - ops.Exp: "exp", - ops.Sqrt: "sqrt", - ops.Ln: "ln", - ops.Log2: "log2", - ops.Log10: "log", - ops.Acos: "acos", - ops.Asin: "asin", - ops.Atan: "atan", - ops.Atan2: "atan2", - ops.Cos: "cos", - ops.Sin: "sin", - ops.Tan: "tan", - ops.Cot: "cot", - ops.Pi: "pi", - ops.RandomScalar: "random", - ops.Sign: "sign", - # Unary aggregates - ops.ApproxMedian: "median", - ops.Median: "median", - ops.Mean: "avg", - ops.Max: "max", - ops.Min: "min", - ops.ArgMin: "arg_min", - ops.Mode: "mode", - ops.ArgMax: "arg_max", - ops.First: "first", - ops.Last: "last", - # string operations - ops.StringContains: "contains", - ops.StringLength: "length", - ops.Lowercase: "lower", - ops.Uppercase: "upper", - ops.Reverse: "reverse", - ops.StringReplace: "replace", - ops.StartsWith: "prefix", - ops.EndsWith: "suffix", - ops.LPad: "lpad", - ops.RPad: "rpad", - ops.LStrip: "ltrim", - ops.RStrip: "rtrim", - ops.Strip: "trim", - ops.StringAscii: "ascii", - ops.StrRight: "right", - # Other operations - ops.Where: "if", - ops.ArrayLength: "length", - ops.Unnest: "unnest", - ops.Degrees: "degrees", - ops.Radians: "radians", - ops.NullIf: "nullIf", - 
ops.MapLength: "cardinality", - ops.MapKeys: "map_keys", - ops.MapValues: "map_values", - ops.ArraySort: "list_sort", - ops.ArrayContains: "list_contains", - ops.FirstValue: "first_value", - ops.LastValue: "last_value", - ops.NTile: "ntile", - ops.Hash: "hash", -} +@translate_val.register(ops.StringSplit) +def _string_split(op, **kw): + arg = translate_val(op.arg, **kw) + delimiter = translate_val(op.delimiter, **kw) + return sg.expressions.Split(this=arg, expression=delimiter) -def _agg(func_name): - def formatter(op, **kw): - return _aggregate(op, func_name, where=op.where, **kw) +@translate_val.register(ops.StringJoin) +def _string_join(op, **kw): + arg = map(partial(translate_val, **kw), op.arg) + sep = translate_val(op.sep, **kw) + elements = ", ".join(map(_sql, arg)) + return f"list_aggregate([{elements}], 'string_agg', {sep})" - return formatter +@translate_val.register(ops.StringConcat) +def _string_concat(op, **kw): + arg = map(partial(translate_val, **kw), op.arg) + return sg.expressions.Concat(expressions=list(arg)) -for _op, _name in _simple_ops.items(): - assert isinstance(type(_op), type), type(_op) - if issubclass(_op, ops.Reduction): - translate_val.register(_op)(_agg(_name)) - else: - @translate_val.register(_op) - def _fmt(op, _name: str = _name, **kw): - return sg.func( - _name, *map(partial(translate_val, **kw), op.args), dialect="duckdb" - ) +@translate_val.register(ops.StringSQLLike) +def _string_like(op, **kw): + arg = translate_val(op.arg, **kw) + pattern = translate_val(op.pattern, **kw) + return sg.expressions.Like(this=arg, expression=pattern) -del _fmt, _name, _op +@translate_val.register(ops.StringSQLILike) +def _string_ilike(op, **kw): + arg = translate_val(op.arg, **kw) + pattern = translate_val(op.pattern, **kw) + return sg.expressions.Like(this=sg.func("lower", arg), expression=pattern) + + +@translate_val.register(ops.Capitalize) +def _string_capitalize(op, **kw): + arg = translate_val(op.arg, **kw) + return sg.expressions.Concat( + expressions=[ + sg.func("upper", sg.func("substr", arg, 1, 1)), + sg.func("lower", sg.func("substr", arg, 2)), + ] + ) ### NULL PLAYER CHARACTER @@ -640,6 +873,18 @@ def _if_null(op, **kw): return sg.func("ifnull", arg, ifnull, dialect="duckdb") +@translate_val.register(ops.NullIfZero) +def _null_if_zero(op, **kw): + arg = translate_val(op.arg, **kw) + return sg.func("nullif", arg, 0, dialect="duckdb") + + +@translate_val.register(ops.ZeroIfNull) +def _zero_if_null(op, **kw): + arg = translate_val(op.arg, **kw) + return sg.func("ifnull", arg, 0, dialect="duckdb") + + ### Definitely Not Tensors @@ -711,130 +956,6 @@ def _array_concat(op, **kw): return sg_expr -### LITERALLY - - -@translate_val.register(ops.Literal) -def _literal(op, **kw): - value = op.value - dtype = op.dtype - if value is None and dtype.nullable: - if dtype.is_null(): - return sg.expressions.Null() - return sg.cast(sg.expressions.Null(), to=DuckDBType.from_ibis(dtype)) - if dtype.is_boolean(): - return sg.expressions.Boolean(this=value) - elif dtype.is_inet(): - com.UnsupportedOperationError("DuckDB doesn't support an explicit inet dtype") - elif dtype.is_string(): - return sg.expressions.Literal(this=f"{value}", is_string=True) - elif dtype.is_decimal(): - # TODO: make this a sqlglot expression - precision = dtype.precision - scale = dtype.scale - if precision is None: - precision = 38 - if scale is None: - scale = 9 - if not 1 <= precision <= 38: - raise NotImplementedError( - f"Unsupported precision. Supported values: [1 : 38]. 
Current value: {precision!r}" - ) - - # TODO: handle if `value` is "Infinity" - # precision = sg.expressions.DataTypeParam( - # this=sg.expressions.Literal(this=f"{precision}", is_string=False) - # ) - # scale = sg.expressions.DataTypeParam( - # this=sg.expressions.Literal(this=f"{scale}", is_string=False) - # ) - # need sg.expressions.DataTypeParam to be available - # ... - return f"{value!s}::decimal({precision}, {scale})" - elif dtype.is_numeric(): - if math.isinf(value): - return sg.expressions.cast( - expression=sg.expressions.Literal(this=value, is_string=True), - to=sg.expressions.DataType.Type.FLOAT, - ) - elif math.isnan(value): - return sg.expressions.cast( - expression=sg.expressions.Literal(this="NaN", is_string=True), - to=sg.expressions.DataType.Type.FLOAT, - ) - return sg.expressions.Literal(this=f"{value}", is_string=False) - elif dtype.is_interval(): - return _interval_format(op) - elif dtype.is_timestamp(): - year = sg.expressions.Literal(this=f"{op.value.year}", is_string=False) - month = sg.expressions.Literal(this=f"{op.value.month}", is_string=False) - day = sg.expressions.Literal(this=f"{op.value.day}", is_string=False) - hour = sg.expressions.Literal(this=f"{op.value.hour}", is_string=False) - minute = sg.expressions.Literal(this=f"{op.value.minute}", is_string=False) - second = sg.expressions.Literal(this=f"{op.value.second}", is_string=False) - if op.value.microsecond: - microsecond = sg.expressions.Literal( - this=f"{op.value.microsecond / 1e6}", is_string=False - ) - second += microsecond - if dtype.timezone is not None: - timezone = sg.expressions.Literal(this=dtype.timezone, is_string=True) - return sg.func( - "make_timestamptz", year, month, day, hour, minute, second, timezone - ) - else: - return sg.func("make_timestamp", year, month, day, hour, minute, second) - elif dtype.is_date(): - year = sg.expressions.Literal(this=f"{op.value.year}", is_string=False) - month = sg.expressions.Literal(this=f"{op.value.month}", is_string=False) - day = sg.expressions.Literal(this=f"{op.value.day}", is_string=False) - return sg.expressions.DateFromParts(year=year, month=month, day=day) - elif dtype.is_array(): - value_type = dtype.value_type - is_string = isinstance(value_type, dt.String) - values = sg.expressions.Array().from_arg_list( - [ - # TODO: this cast makes for frustrating output - # is there any better way to handle it? 
- sg.cast( - sg.expressions.Literal(this=f"{v}", is_string=is_string), - to=DuckDBType.from_ibis(value_type), - ) - for v in value - ] - ) - return values - elif dtype.is_map(): - key_type = dtype.key_type - value_type = dtype.value_type - keys = sg.expressions.Array().from_arg_list( - [_literal(ops.Literal(k, dtype=key_type), **kw) for k in value.keys()] - ) - values = sg.expressions.Array().from_arg_list( - [_literal(ops.Literal(v, dtype=value_type), **kw) for v in value.values()] - ) - sg_expr = sg.expressions.Map(keys=keys, values=values) - return sg_expr - elif dtype.is_struct(): - keys = [ - sg.expressions.Literal(this=key, is_string=True) for key in value.keys() - ] - values = [ - _literal(ops.Literal(v, dtype=subdtype), **kw) - for subdtype, v in zip(dtype.types, value.values()) - ] - slices = [ - sg.expressions.Slice(this=k, expression=v) for k, v in zip(keys, values) - ] - sg_expr = sg.expressions.Struct.from_arg_list(slices) - return sg_expr - else: - raise NotImplementedError(f"Unsupported type: {dtype!r}") - - -### BELOW HERE BE DRAGONS - - # # TODO # @translate_val.register(ops.ArrayRepeat) # def _array_repeat_op(op, **kw): @@ -875,6 +996,74 @@ def _array_slice_op(op, **kw): return sg.func("list_slice", arg, start + 1, stop) +@translate_val.register(ops.ArrayStringJoin) +def _array_string_join(op, **kw): + arg = translate_val(op.arg, **kw) + sep = translate_val(op.sep, **kw) + return f"list_aggregate({arg}, 'string_agg', {sep})" + + +@translate_val.register(ops.ArrayMap) +def _array_map(op, **kw): + arg = translate_val(op.arg, **kw) + result = translate_val(op.result, **kw) + lamduh = sg.expressions.Lambda( + this=result, + expressions=[sg.expressions.Identifier(this=f"{op.parameter}", quoted=False)], + ) + sg_expr = sg.func("list_transform", arg, lamduh) + return sg_expr + + +@translate_val.register(ops.ArrayFilter) +def _array_filter(op, **kw): + arg = translate_val(op.arg, **kw) + result = translate_val(op.result, **kw) + lamduh = sg.expressions.Lambda( + this=result, + expressions=[sg.expressions.Identifier(this=f"{op.parameter}", quoted=False)], + ) + sg_expr = sg.func("list_filter", arg, lamduh) + return sg_expr + + +@translate_val.register(ops.ArrayIntersect) +def _array_intersect(op, **kw): + return translate_val( + ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x)), **kw + ) + + +@translate_val.register(ops.ArrayPosition) +def _array_position(op, **kw): + arg = translate_val(op.arg, **kw) + el = translate_val(op.other, **kw) + return sg.func("list_indexof", arg, el) - 1 + + +@translate_val.register(ops.ArrayRemove) +def _array_remove(op, **kw): + return translate_val(ops.ArrayFilter(op.arg, flip(ops.NotEquals, op.other)), **kw) + + +@translate_val.register(ops.ArrayUnion) +def _array_union(op, **kw): + return translate_val(ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right))), **kw) + + +# TODO: need to do this as a an array map + struct pack -- look at existing +# alchemy backend implementation +@translate_val.register(ops.ArrayZip) +def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: + zipped = sg.expressions.ArrayJoin().from_arg_list( + [translate_val(arg, **kw) for arg in op.arg] + ) + return zipped + + +### Counting + + @translate_val.register(ops.Count) def _count(op, **kw): arg = translate_val(op.arg, **kw) @@ -939,14 +1128,22 @@ def _sum(op, **kw): return sg_sum_expr -@translate_val.register(ops.NotAny) -def _not_any(op, **kw): - return translate_val(ops.All(ops.Not(op.arg), where=op.where), **kw) +# TODO 
+@translate_val.register(ops.NthValue) +def _nth_value(op, **kw): + arg = translate_val(op.arg, **kw) + nth = translate_val(op.nth, **kw) + return f"nth_value({arg}, ({nth}) + 1)" -@translate_val.register(ops.NotAll) -def _not_all(op, **kw): - return translate_val(ops.Any(ops.Not(op.arg), where=op.where), **kw) +@translate_val.register(ops.Repeat) +def _repeat(op, **kw): + arg = translate_val(op.arg, **kw) + times = translate_val(op.times, **kw) + return f"repeat({arg}, {times})" + + +### Stats @translate_val.register(ops.Quantile) @@ -1083,25 +1280,10 @@ def _index_of(op, **kw): return f"list_indexof([{values}], {needle}) - 1" -@translate_val.register(ops.Round) -def _round(op, **kw): - arg = translate_val(op.arg, **kw) - if (digits := op.digits) is not None: - return sg.expressions.Round(this=arg, decimals=translate_val(digits, **kw)) - return sg.expressions.Round(this=arg) - - @translate_val.register(tuple) -def _node_list(op, punct="()", **kw): - values = ", ".join(map(_sql, map(partial(translate_val, **kw), op))) - return f"{punct[0]}{values}{punct[1]}" - - -def _sql(obj, dialect="duckdb"): - try: - return obj.sql(dialect=dialect) - except AttributeError: - return obj +def _node_list(op, punct="()", **kw): + values = ", ".join(map(_sql, map(partial(translate_val, **kw), op))) + return f"{punct[0]}{values}{punct[1]}" @translate_val.register(ops.SimpleCase) @@ -1157,52 +1339,6 @@ def _exists_subquery(op, **kw): return f"{prefix}EXISTS ({subq})" -@translate_val.register(ops.StringSplit) -def _string_split(op, **kw): - arg = translate_val(op.arg, **kw) - delimiter = translate_val(op.delimiter, **kw) - return sg.expressions.Split(this=arg, expression=delimiter) - - -@translate_val.register(ops.StringJoin) -def _string_join(op, **kw): - arg = map(partial(translate_val, **kw), op.arg) - sep = translate_val(op.sep, **kw) - elements = ", ".join(map(_sql, arg)) - return f"list_aggregate([{elements}], 'string_agg', {sep})" - - -@translate_val.register(ops.StringConcat) -def _string_concat(op, **kw): - arg = map(partial(translate_val, **kw), op.arg) - return sg.expressions.Concat(expressions=list(arg)) - - -@translate_val.register(ops.StringSQLLike) -def _string_like(op, **kw): - arg = translate_val(op.arg, **kw) - pattern = translate_val(op.pattern, **kw) - return sg.expressions.Like(this=arg, expression=pattern) - - -@translate_val.register(ops.StringSQLILike) -def _string_ilike(op, **kw): - arg = translate_val(op.arg, **kw) - pattern = translate_val(op.pattern, **kw) - return sg.expressions.Like(this=sg.func("lower", arg), expression=pattern) - - -@translate_val.register(ops.Capitalize) -def _string_capitalize(op, **kw): - arg = translate_val(op.arg, **kw) - return sg.expressions.Concat( - expressions=[ - sg.func("upper", sg.func("substr", arg, 1, 1)), - sg.func("lower", sg.func("substr", arg, 2)), - ] - ) - - @translate_val.register(ops.GroupConcat) def _group_concat(op, **kw): arg = translate_val(op.arg, **kw) @@ -1238,18 +1374,6 @@ def _struct_column(op, **kw): return sg.cast(expression=values, to=struct_type) -@translate_val.register(ops.Clip) -def _clip(op, **kw): - arg = translate_val(op.arg, **kw) - if (upper := op.upper) is not None: - arg = f"least({translate_val(upper, **kw)}, {arg})" - - if (lower := op.lower) is not None: - arg = f"greatest({translate_val(lower, **kw)}, {arg})" - - return arg - - @translate_val.register(ops.StructField) def _struct_field(op, **kw): arg = translate_val(op.arg, **kw) @@ -1258,39 +1382,6 @@ def _struct_field(op, **kw): return sg_expr -# TODO 
-@translate_val.register(ops.NthValue) -def _nth_value(op, **kw): - arg = translate_val(op.arg, **kw) - nth = translate_val(op.nth, **kw) - return f"nth_value({arg}, ({nth}) + 1)" - - -@translate_val.register(ops.Repeat) -def _repeat(op, **kw): - arg = translate_val(op.arg, **kw) - times = translate_val(op.times, **kw) - return f"repeat({arg}, {times})" - - -@translate_val.register(ops.NullIfZero) -def _null_if_zero(op, **kw): - arg = translate_val(op.arg, **kw) - return sg.func("nullif", arg, 0, dialect="duckdb") - - -@translate_val.register(ops.ZeroIfNull) -def _zero_if_null(op, **kw): - arg = translate_val(op.arg, **kw) - return sg.func("ifnull", arg, 0, dialect="duckdb") - - -@translate_val.register(ops.FloorDivide) -def _floor_divide(op, **kw): - new_op = ops.Floor(ops.Divide(op.left, op.right)) - return translate_val(new_op, **kw) - - @translate_val.register(ops.ScalarParameter) def _scalar_param(op, params: Mapping[ops.Node, Any], **kw): raw_value = params[op] @@ -1341,40 +1432,6 @@ def tr(op, *, cache, **kw): # TODO # translate_val.register(ops.Contains)(contains("IN")) # translate_val.register(ops.NotContains)(contains("NOT IN")) - - -@translate_val.register(ops.DayOfWeekName) -def day_of_week_name(op, **kw): - # day of week number is 0-indexed - # Sunday == 0 - # Saturday == 6 - arg = op.arg - nullable = arg.dtype.nullable - empty_string = ops.Literal("", dtype=dt.String(nullable=nullable)) - weekdays = range(7) - return translate_val( - ops.NullIf( - ops.SimpleCase( - base=ops.DayOfWeekIndex(arg), - cases=[ - ops.Literal(day, dtype=dt.Int8(nullable=nullable)) - for day in weekdays - ], - results=[ - ops.Literal( - calendar.day_name[day], - dtype=dt.String(nullable=nullable), - ) - for day in weekdays - ], - default=empty_string, - ), - empty_string, - ), - **kw, - ) - - @translate_val.register(ops.IdenticalTo) def _identical_to(op, **kw): left = translate_val(op.left, **kw) @@ -1514,21 +1571,6 @@ def formatter(op, **kw): del _op, _sym -@translate_val.register(ops.Xor) -def _xor(op, **kw): - # TODO: is this really the best way to do this? - left = translate_val(op.left, **kw) - right = translate_val(op.right, **kw) - return sg.expressions.And( - this=sg.expressions.Paren(this=sg.expressions.Or(this=left, expression=right)), - expression=sg.expressions.Paren( - this=sg.expressions.Not( - this=sg.expressions.And(this=left, expression=right) - ) - ), - ) - - _bit_agg = { ops.BitOr: "bit_or", ops.BitAnd: "bit_and", @@ -1548,6 +1590,39 @@ def _bitor(op, **kw): return bit_expr +@translate_val.register(ops.Xor) +def _xor(op, **kw): + # TODO: is this really the best way to do this? 
+ left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) + return sg.expressions.And( + this=sg.expressions.Paren(this=sg.expressions.Or(this=left, expression=right)), + expression=sg.expressions.Paren( + this=sg.expressions.Not( + this=sg.expressions.And(this=left, expression=right) + ) + ), + ) + + +### Ordering + + +@translate_val.register(ops.RowNumber) +def _row_number(_, **kw): + return sg.expressions.RowNumber() + + +@translate_val.register(ops.DenseRank) +def _dense_rank(_, **kw): + return sg.func("dense_rank", dialect="duckdb") + + +@translate_val.register(ops.MinRank) +def _rank(_, **kw): + return sg.func("rank", dialect="duckdb") + + @translate_val.register def _sort_key(op: ops.SortKey, **kw): arg = translate_val(op.expr, **kw) @@ -1555,6 +1630,8 @@ def _sort_key(op: ops.SortKey, **kw): return f"{_sql(arg)} {direction}" +### Window functions + _cumulative_to_reduction = { ops.CumulativeSum: ops.Sum, ops.CumulativeMin: ops.Min, @@ -1710,86 +1787,6 @@ def formatter(op, **kw): shift_like(ops.Lead, "lead") -@translate_val.register(ops.RowNumber) -def _row_number(_, **kw): - return sg.expressions.RowNumber() - - -@translate_val.register(ops.DenseRank) -def _dense_rank(_, **kw): - return sg.func("dense_rank", dialect="duckdb") - - -@translate_val.register(ops.MinRank) -def _rank(_, **kw): - return sg.func("rank", dialect="duckdb") - - -@translate_val.register(ops.ArrayStringJoin) -def _array_string_join(op, **kw): - arg = translate_val(op.arg, **kw) - sep = translate_val(op.sep, **kw) - return f"list_aggregate({arg}, 'string_agg', {sep})" - - @translate_val.register(ops.Argument) def _argument(op, **_): return sg.expressions.Identifier(this=op.name, quoted=False) - - -@translate_val.register(ops.ArrayMap) -def _array_map(op, **kw): - arg = translate_val(op.arg, **kw) - result = translate_val(op.result, **kw) - lamduh = sg.expressions.Lambda( - this=result, - expressions=[sg.expressions.Identifier(this=f"{op.parameter}", quoted=False)], - ) - sg_expr = sg.func("list_transform", arg, lamduh) - return sg_expr - - -@translate_val.register(ops.ArrayFilter) -def _array_filter(op, **kw): - arg = translate_val(op.arg, **kw) - result = translate_val(op.result, **kw) - lamduh = sg.expressions.Lambda( - this=result, - expressions=[sg.expressions.Identifier(this=f"{op.parameter}", quoted=False)], - ) - sg_expr = sg.func("list_filter", arg, lamduh) - return sg_expr - - -@translate_val.register(ops.ArrayIntersect) -def _array_intersect(op, **kw): - return translate_val( - ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x)), **kw - ) - - -@translate_val.register(ops.ArrayPosition) -def _array_position(op, **kw): - arg = translate_val(op.arg, **kw) - el = translate_val(op.other, **kw) - return sg.func("list_indexof", arg, el) - 1 - - -@translate_val.register(ops.ArrayRemove) -def _array_remove(op, **kw): - return translate_val(ops.ArrayFilter(op.arg, flip(ops.NotEquals, op.other)), **kw) - - -@translate_val.register(ops.ArrayUnion) -def _array_union(op, **kw): - return translate_val(ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right))), **kw) - - -# TODO: need to do this as a an array map + struct pack -- look at existing -# alchemy backend implementation -@translate_val.register(ops.ArrayZip) -def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: - zipped = sg.expressions.ArrayJoin().from_arg_list( - [translate_val(arg, **kw) for arg in op.arg] - ) - return zipped From a326c4c7508558700276fd77ad754faa96be7ccc Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: 
Tue, 12 Sep 2023 15:20:47 -0400 Subject: [PATCH 039/222] fix(duckdb): strict dtype cast on numeric literal --- ibis/backends/duckdb/compiler/values.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 6c0bbe60e587..c5229fe7a2ae 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -124,7 +124,10 @@ def _literal(op, **kw): expression=sg_literal("NaN"), to=sg.expressions.DataType.Type.FLOAT, ) - return sg_literal(value, is_string=False) + return sg.cast( + sg_literal(value, is_string=False), + to=DuckDBType.from_ibis(dtype), + ) elif dtype.is_interval(): return _interval_format(op) elif dtype.is_timestamp(): From 5894c347f2662185bab5b8583fbb7fb064de1732 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 12 Sep 2023 15:38:20 -0400 Subject: [PATCH 040/222] feat(duckdb): decimal infinity and nan --- ibis/backends/duckdb/compiler/values.py | 30 ++++++++++++++++--------- ibis/backends/tests/test_numeric.py | 21 +++++------------ 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index c5229fe7a2ae..b89040b30ab3 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -102,17 +102,27 @@ def _literal(op, **kw): raise NotImplementedError( f"Unsupported precision. Supported values: [1 : 38]. Current value: {precision!r}" ) + if math.isinf(value): + return sg.expressions.cast( + expression=sg_literal(value), + to=sg.expressions.DataType.Type.FLOAT, + ) + elif math.isnan(value): + return sg.expressions.cast( + expression=sg_literal("NaN"), + to=sg.expressions.DataType.Type.FLOAT, + ) - # TODO: handle if `value` is "Infinity" - # precision = sg.expressions.DataTypeParam( - # this=sg.expressions.Literal(this=f"{precision}", is_string=False) - # ) - # scale = sg.expressions.DataTypeParam( - # this=sg.expressions.Literal(this=f"{scale}", is_string=False) - # ) - # need sg.expressions.DataTypeParam to be available - # ... - return f"{value!s}::decimal({precision}, {scale})" + precision = sg.expressions.DataTypeParam( + this=sg_literal(precision, is_string=False) + ) + scale = sg.expressions.DataTypeParam(this=sg_literal(scale, is_string=False)) + cast_to = sg.expressions.DataType( + this=sg.expressions.DataType.Type.DECIMAL, + expressions=[precision, scale], + ) + sg_expr = sg.cast(sg_literal(value, is_string=False), to=cast_to) + return sg_expr elif dtype.is_numeric(): if math.isinf(value): return sg.expressions.cast( diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index ac4d4995ee59..7535d5aa77a4 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -409,13 +409,14 @@ def test_numeric_literal(con, backend, expr, expected_types): "pandas": decimal.Decimal("Infinity"), "dask": decimal.Decimal("Infinity"), "impala": float("inf"), + "duckdb": float("inf"), }, { "bigquery": "FLOAT64", "snowflake": "VARCHAR", "sqlite": "real", "trino": "decimal(2,1)", - "duckdb": "DECIMAL(18,3)", + "duckdb": "FLOAT", "postgres": "numeric", "impala": "DOUBLE", }, @@ -425,11 +426,6 @@ def test_numeric_literal(con, backend, expr, expected_types): "Unsupported precision. Supported values: [1 : 76]. Current value: None", raises=NotImplementedError, ), - pytest.mark.broken( - ["duckdb"], - "Unsupported precision. Supported values: [1 : 38]. 
Current value: None", - raises=NotImplementedError, - ), pytest.mark.broken( ["trino"], "(trino.exceptions.TrinoUserError) TrinoUserError(type=USER_ERROR, name=INVALID_LITERAL, " @@ -489,6 +485,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "pandas": decimal.Decimal("-Infinity"), "dask": decimal.Decimal("-Infinity"), "impala": float("-inf"), + "duckdb": float("-inf"), }, { "bigquery": "FLOAT64", @@ -498,6 +495,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "duckdb": "DECIMAL(18,3)", "postgres": "numeric", "impala": "DOUBLE", + "duckdb": "FLOAT", }, marks=[ pytest.mark.broken( @@ -505,11 +503,6 @@ def test_numeric_literal(con, backend, expr, expected_types): "Unsupported precision. Supported values: [1 : 76]. Current value: None", raises=NotImplementedError, ), - pytest.mark.broken( - ["duckdb"], - "Unsupported precision. Supported values: [1 : 38]. Current value: None", - raises=NotImplementedError, - ), pytest.mark.broken( ["trino"], "(trino.exceptions.TrinoUserError) TrinoUserError(type=USER_ERROR, name=INVALID_LITERAL, " @@ -569,6 +562,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "pandas": decimal.Decimal("NaN"), "dask": decimal.Decimal("NaN"), "impala": float("nan"), + "duckdb": float("nan"), }, { "bigquery": "FLOAT64", @@ -578,6 +572,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "duckdb": "DECIMAL(18,3)", "postgres": "numeric", "impala": "DOUBLE", + "duckdb": "FLOAT", }, marks=[ pytest.mark.broken( @@ -585,10 +580,6 @@ def test_numeric_literal(con, backend, expr, expected_types): "Unsupported precision. Supported values: [1 : 76]. Current value: None", raises=NotImplementedError, ), - pytest.mark.broken( - ["duckdb"], - "Unsupported precision. Supported values: [1 : 38]. Current value: None", - ), pytest.mark.broken( ["trino"], "(trino.exceptions.TrinoUserError) TrinoUserError(type=USER_ERROR, name=INVALID_LITERAL, " From 3566c4c1233f6055440b5aa86c563197af822128 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 12 Sep 2023 15:49:39 -0400 Subject: [PATCH 041/222] feat(duckdb): create_ and drop_schema --- ibis/backends/duckdb/__init__.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 34b916c7a75f..1b1c06f75056 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -493,10 +493,13 @@ def create_schema( raise exc.UnsupportedOperationError( "DuckDB cannot create a schema in another database." ) - # name = self._quote(name) - if_not_exists = "IF NOT EXISTS " * force - with self.begin() as con: - con.exec_driver_sql(f"CREATE SCHEMA {if_not_exists}{name}") + + name = sg.to_identifier(database, quoted=True) + return sg.expressions.Create( + this=name, + kind="SCHEMA", + replace=force, + ) def drop_schema( self, name: str, database: str | None = None, force: bool = False @@ -505,10 +508,13 @@ def drop_schema( raise exc.UnsupportedOperationError( "DuckDB cannot drop a schema in another database." 
From f7924e57172a1b73b24856300f15c623c1e6755b Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 12 Sep 2023 17:08:03 -0400 Subject: [PATCH 042/222] feat(duckdb): use sqlglot for create table expression and fix nullability on schema translation --- ibis/backends/duckdb/__init__.py | 81 ++++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 24 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 1b1c06f75056..15d02e3c6a03 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -118,41 +118,61 @@ def create_table( temp: bool = False, overwrite: bool = False, ): - tmp = "TEMP " * temp - replace = "OR REPLACE" * overwrite - if temp and overwrite: raise exc.IbisInputError("Cannot specify both temp and overwrite") - if not temp: - table = self._fully_qualified_name(name, database) - else: - table = name - database = None - code = f"CREATE {replace}{tmp}TABLE {table}" - if obj is None and schema is None: raise exc.IbisError("The schema or obj parameter is required") + table_identifier = sg.to_identifier(name, quoted=True) + create_expr = sg.expressions.Create( + kind="TABLE", # TABLE + replace=overwrite, # OR REPLACE + ) + + if temp: + create_expr.args["properties"] = sg.expressions.Properties( + expressions=[sg.expressions.TemporaryProperty()] # TEMPORARY + ) + if obj is not None and not isinstance(obj, ir.Expr): + # pd.DataFrame or pa.Table obj = ibis.memtable(obj, schema=schema) self._register_in_memory_table(obj.op()) - code += f" AS {self.compile(obj)}" - else: + create_expr.args["expression"] = self.compile(obj) # AS ... + create_expr.args["this"] = table_identifier # t0 + elif obj is not None: + self._register_in_memory_tables(obj) # If both `obj` and `schema` are specified, `obj` overrides `schema` # DuckDB doesn't support `create table (schema) AS select * ...` + create_expr.args["expression"] = self.compile(obj) # AS ...
+ create_expr.args["this"] = table_identifier # t0 + else: + # Schema -> Table -> [ColumnDefs] + schema_expr = sg.expressions.Schema( + this=sg.expressions.Table(this=table_identifier), + expressions=[ + sg.expressions.ColumnDef( + this=sg.to_identifier(key, quoted=False), + kind=DuckDBType.from_ibis(typ), + ) + if typ.nullable + else sg.expressions.ColumnDef( + this=sg.to_identifier(key, quoted=False), + kind=DuckDBType.from_ibis(typ), + constraints=[ + sg.expressions.ColumnConstraint( + kind=sg.expressions.NotNullColumnConstraint() + ) + ], + ) + for key, typ in schema.items() + ], + ) + create_expr.args["this"] = schema_expr # create the table - self.raw_sql(code) + self.raw_sql(create_expr) return self.table(name, database=database) @@ -247,10 +267,23 @@ def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema qualified_name = sg.expressions.Identifier(this=qualified_name, quoted=True) query = sg.expressions.Describe(this=qualified_name) results = self.raw_sql(query) - names, types, *_ = results.fetch_arrow_table() + names, types, nulls, *_ = results.fetch_arrow_table() names = names.to_pylist() types = types.to_pylist() - return sch.Schema(dict(zip(names, map(DuckDBType.from_string, types)))) + # TODO: remove code crime + # DuckDB gives back "YES", "NO" for nullability + nulls = [bool(null[:-2]) for null in nulls.to_pylist()] + return sch.Schema( + dict( + zip( + names, + ( + DuckDBType.from_string(typ, nullable=null) + for typ, null in zip(types, nulls) + ), + ) + ) + ) def list_databases(self, like: str | None = None) -> list[str]: result = self.raw_sql("PRAGMA database_list;") From 1d294944a8a92991d907bb75ac5e8ec0d02b687e Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 12 Sep 2023 17:22:28 -0400 Subject: [PATCH 043/222] fix(duckdb): simplify decimal literal --- ibis/backends/duckdb/compiler/values.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index b89040b30ab3..bff8f2236899 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -113,15 +113,10 @@ def _literal(op, **kw): to=sg.expressions.DataType.Type.FLOAT, ) - precision = sg.expressions.DataTypeParam( - this=sg_literal(precision, is_string=False) + dtype = dt.Decimal(precision=precision, scale=scale, nullable=dtype.nullable) + sg_expr = sg.cast( + sg_literal(value, is_string=False), to=DuckDBType.from_ibis(dtype) ) - scale = sg.expressions.DataTypeParam(this=sg_literal(scale, is_string=False)) - cast_to = sg.expressions.DataType( - this=sg.expressions.DataType.Type.DECIMAL, - expressions=[precision, scale], - ) - sg_expr = sg.cast(sg_literal(value, is_string=False), to=cast_to) return sg_expr elif dtype.is_numeric(): if math.isinf(value): From e658b60005ad728a8bac52cfb0118627bef6256a Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 12 Sep 2023 17:34:12 -0400 Subject: [PATCH 044/222] chore(duckdb): cleanup linting errors --- ibis/backends/duckdb/__init__.py | 41 ++++++++----------------- ibis/backends/duckdb/compiler/core.py | 5 ++- ibis/backends/duckdb/compiler/values.py | 30 +++++++++--------- ibis/backends/duckdb/tests/conftest.py | 1 + ibis/backends/tests/test_client.py | 3 -- ibis/backends/tests/test_numeric.py | 2 -- 6 files changed, 33 insertions(+), 49 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 15d02e3c6a03..3a75340485ae 100644 --- a/ibis/backends/duckdb/__init__.py +++ 
b/ibis/backends/duckdb/__init__.py @@ -7,7 +7,13 @@ import os import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, MutableMapping +from typing import TYPE_CHECKING, Any + +import duckdb +import pyarrow as pa +import sqlglot as sg +import toolz +from packaging.version import parse as vparse import ibis import ibis.common.exceptions as exc @@ -15,9 +21,6 @@ import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir -import pyarrow as pa -import sqlglot as sg -import toolz from ibis import util from ibis.backends.base import CanCreateSchema from ibis.backends.base.sql import BaseBackend @@ -26,17 +29,15 @@ from ibis.expr.operations.relations import PandasDataFrameProxy from ibis.expr.operations.udf import InputType from ibis.formats.pandas import PandasData -from ibis.formats.pyarrow import PyArrowData -from packaging.version import parse as vparse - -import duckdb if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence - import ibis.expr.operations as ops + import pandas as pd import torch + from ibis.common.typing import SupportsSchema + def normalize_filenames(source_list): # Promote to list @@ -45,22 +46,6 @@ def normalize_filenames(source_list): return list(map(util.normalize_filename, source_list)) -def _format_kwargs(kwargs: Mapping[str, Any]): - bindparams, pieces = [], [] - for name, value in kwargs.items(): - bindparam = sa.bindparam(name, value) - if isinstance(paramtype := bindparam.type, sa.String): - # special case strings to avoid double escaping backslashes - pieces.append(f"{name} = '{value!s}'") - elif not isinstance(paramtype, sa.types.NullType): - bindparams.append(bindparam) - pieces.append(f"{name} = :{name}") - else: # fallback to string strategy - pieces.append(f"{name} = {value!r}") - - return sa.text(", ".join(pieces)).bindparams(*bindparams) - - _UDF_INPUT_TYPE_MAPPING = { InputType.PYARROW: duckdb.functional.ARROW, InputType.PYTHON: duckdb.functional.NATIVE, @@ -481,7 +466,7 @@ def execute( raise exc.IbisError(e) # TODO: should we do this in arrow? - # also wth is pandas doing with dates? + # also what is pandas doing with dates? 
pandas_df = result.fetch_df() result = PandasData.convert_table(pandas_df, schema) if isinstance(expr, ir.Table): @@ -515,7 +500,7 @@ def _load_extensions(self, extensions): installed = (name for (name,) in self.con.sql(query).fetchall()) # Install and load all other extensions todo = set(extensions).difference(installed) - for extension in extensions: + for extension in todo: self.con.install_extension(extension) self.con.load_extension(extension) @@ -793,7 +778,7 @@ def _read_parquet_duckdb_native( ): self._load_extensions(["httpfs"]) - if kw := kwargs: + if kwargs: options = [f"{key}={val}" for key, val in kwargs.items()] pq_func = sg.func("read_parquet", source_list, *options) else: diff --git a/ibis/backends/duckdb/compiler/core.py b/ibis/backends/duckdb/compiler/core.py index c7db3faf51c1..9979e63a9eaa 100644 --- a/ibis/backends/duckdb/compiler/core.py +++ b/ibis/backends/duckdb/compiler/core.py @@ -41,7 +41,7 @@ from __future__ import annotations -from typing import Any, Mapping +from typing import TYPE_CHECKING, Any import sqlglot as sg @@ -49,6 +49,9 @@ import ibis.expr.types as ir from ibis.backends.duckdb.compiler.relations import translate_rel +if TYPE_CHECKING: + from collections.abc import Mapping + def translate(op: ops.TableNode, params: Mapping[ir.Value, Any]) -> sg.exp.Expression: """Translate an ibis operation to a sqlglot expression. diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index bff8f2236899..9c9eb1987961 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1,26 +1,27 @@ from __future__ import annotations import calendar -import contextlib import functools import math -import operator from functools import partial -from operator import add, mul, sub -from typing import Any, Literal, Mapping +from typing import TYPE_CHECKING, Any, Literal import duckdb +import sqlglot as sg +from packaging.version import parse as vparse +from toolz import flip + import ibis import ibis.common.exceptions as com import ibis.expr.analysis as an import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.rules as rlz -import sqlglot as sg from ibis.backends.base.sql.registry import helpers from ibis.backends.base.sqlglot.datatypes import DuckDBType -from packaging.version import parse as vparse -from toolz import flip + +if TYPE_CHECKING: + from collections.abc import Mapping # TODO: Ideally we can translate bottom up a la `relations.py` # TODO: Find a way to remove all the dialect="duckdb" kwargs @@ -88,6 +89,7 @@ def _literal(op, **kw): return sg.expressions.Boolean(this=value) elif dtype.is_inet(): com.UnsupportedOperationError("DuckDB doesn't support an explicit inet dtype") + return None elif dtype.is_string(): return sg_literal(value) elif dtype.is_decimal(): @@ -505,7 +507,7 @@ def _time_from_hms(op, **kw): @translate_val.register(ops.TimestampNow) def _timestamp_now(op, **kw): - """DuckDB current timestamp defaults to timestamp + tz""" + """DuckDB current timestamp defaults to timestamp + tz.""" return sg.cast(expression=sg.func("current_timestamp"), to="TIMESTAMP") @@ -1109,7 +1111,7 @@ def _count_star(op, **kw): if (predicate := op.where) is not None: return sg.expressions.Filter( this=sql, - expression=sg.expressions.Where(this=translate_val(op.where, **kw)), + expression=sg.expressions.Where(this=translate_val(predicate, **kw)), ) return sql @@ -1193,7 +1195,7 @@ def _corr(op, **kw): sg_func = sg.func("corr", left, right) if (where := op.where) is not 
None: - predicate = sg.expressions.Where(this=translate_val(op.where, **kw)) + predicate = sg.expressions.Where(this=translate_val(where, **kw)) return sg.expressions.Filter(this=sg_func, expression=predicate) return sg_func @@ -1222,7 +1224,7 @@ def _covariance(op, **kw): sg_func = sg.func(funcname, left, right) if (where := op.where) is not None: - predicate = sg.expressions.Where(this=translate_val(op.where, **kw)) + predicate = sg.expressions.Where(this=translate_val(where, **kw)) return sg.expressions.Filter(this=sg_func, expression=predicate) return sg_func @@ -1246,7 +1248,7 @@ def _variance(op, **kw): sg_func = sg.func(funcname, arg) if (where := op.where) is not None: - predicate = sg.expressions.Where(this=translate_val(op.where, **kw)) + predicate = sg.expressions.Where(this=translate_val(where, **kw)) return sg.expressions.Filter(this=sg_func, expression=predicate) return sg_func @@ -1334,7 +1336,7 @@ def _table_array_view(op, *, cache, **kw): def _exists_subquery(op, **kw): from ibis.backends.duckdb.compiler.relations import translate_rel - if not "table" in kw: + if "table" not in kw: kw["table"] = translate_rel(op.foreign_table.table, **kw) foreign_table = translate_rel(op.foreign_table, **kw) predicates = translate_val(op.predicates, **kw) @@ -1544,8 +1546,6 @@ def formatter(op, **kw): return formatter -import operator - _binary_infix_ops = { # Binary operations ops.Add: sg.expressions.Add, diff --git a/ibis/backends/duckdb/tests/conftest.py b/ibis/backends/duckdb/tests/conftest.py index 1f8d70c05a41..3fe3cfd2cdaa 100644 --- a/ibis/backends/duckdb/tests/conftest.py +++ b/ibis/backends/duckdb/tests/conftest.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from collections.abc import Iterator + from typing import Any from ibis.backends.base import BaseBackend diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index af2de4601095..df7031ea0c0f 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -1243,9 +1243,6 @@ def test_persist_expression_repeated_cache(alltypes): assert not nested_cached_table.to_pandas().empty -@mark.broken( - "duckdb", reason="table name has `main` prepended, breaking the match check" -) @mark.notimpl(["datafusion", "bigquery", "impala", "trino", "druid"]) @mark.never( ["mssql"], diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index 7535d5aa77a4..ee244c9b26c7 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -492,7 +492,6 @@ def test_numeric_literal(con, backend, expr, expected_types): "snowflake": "VARCHAR", "sqlite": "real", "trino": "decimal(2,1)", - "duckdb": "DECIMAL(18,3)", "postgres": "numeric", "impala": "DOUBLE", "duckdb": "FLOAT", @@ -569,7 +568,6 @@ def test_numeric_literal(con, backend, expr, expected_types): "snowflake": "VARCHAR", "sqlite": "null", "trino": "decimal(2,1)", - "duckdb": "DECIMAL(18,3)", "postgres": "numeric", "impala": "DOUBLE", "duckdb": "FLOAT", From d8e5e8a8cae64c75ac696b327c4752228ca04974 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 13 Sep 2023 14:06:37 -0400 Subject: [PATCH 045/222] fix(duckdb): fix MapGet and Map Dict --- ibis/backends/duckdb/__init__.py | 7 ++++++- ibis/backends/duckdb/compiler/values.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 3a75340485ae..6af701086881 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -467,7 
+467,12 @@ def execute( # TODO: should we do this in arrow? # also what is pandas doing with dates? - pandas_df = result.fetch_df() + # TODO: converting to arrow -> pandas + # makes map tests pass because of how the map results + # are parsed out of DucKDB. + # This is stupid and we should fix it. + pandas_df = result.arrow().to_pandas() + # pandas_df = result.fetch_df() result = PandasData.convert_table(pandas_df, schema) if isinstance(expr, ir.Table): return result diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 9c9eb1987961..7b1ed7157691 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1475,7 +1475,7 @@ def _map_get(op, **kw): default = translate_val(op.default, **kw) sg_expr = sg.func( "ifnull", - sg.func("element_at", arg, key), + sg.func("list_extract", sg.func("element_at", arg, key), 1), default, dialect="duckdb", ) From 5f8f2fffa3e0083318dfd64ccfea72093758b760 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 13 Sep 2023 14:24:37 -0400 Subject: [PATCH 046/222] feat(duckdb): array repeat --- ibis/backends/duckdb/compiler/values.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 7b1ed7157691..f6090f50d40f 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -966,16 +966,17 @@ def _array_concat(op, **kw): return sg_expr -# # TODO -# @translate_val.register(ops.ArrayRepeat) -# def _array_repeat_op(op, **kw): -# arg = translate_val(op.arg, **kw) -# times = translate_val(op.times, **kw) -# from_ = f"(SELECT {arg} AS arr FROM system.numbers LIMIT {times})" -# query = sg.parse_one( -# f"SELECT arrayFlatten(groupArray(arr)) FROM {from_}", read="duckdb" -# ) -# return query.subquery() +@translate_val.register(ops.ArrayRepeat) +def _array_repeat_op(op, **kw): + arg = translate_val(op.arg, **kw) + times = translate_val(op.times, **kw) + sg_expr = sg.func( + "flatten", + sg.select( + sg.func("array", sg.select(arg).from_(sg.func("range", times))) + ).subquery(), + ) + return sg_expr def _neg_idx_to_pos(array, idx): From 66132ea0f3433d66423cd7082f4494ddc6c7a244 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 13 Sep 2023 15:56:33 -0400 Subject: [PATCH 047/222] fix(duckdb): handle duckdb -> pandas df -> ibis table conversion --- ibis/backends/duckdb/__init__.py | 17 +++++++---------- ibis/backends/duckdb/datatypes.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 10 deletions(-) create mode 100644 ibis/backends/duckdb/datatypes.py diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 6af701086881..98356b232d51 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -28,7 +28,7 @@ from ibis.backends.duckdb.compiler import translate from ibis.expr.operations.relations import PandasDataFrameProxy from ibis.expr.operations.udf import InputType -from ibis.formats.pandas import PandasData +from ibis.backends.duckdb.datatypes import DuckDBPandasData if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence @@ -255,9 +255,10 @@ def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema names, types, nulls, *_ = results.fetch_arrow_table() names = names.to_pylist() types = types.to_pylist() - # TODO: remove code crime # DuckDB gives back "YES", "NO" for nullability - nulls = 
[bool(null[:-2]) for null in nulls.to_pylist()] + # TODO: remove code crime + # nulls = [bool(null[:-2]) for null in nulls.to_pylist()] + nulls = [null == "YES" for null in nulls.to_pylist()] return sch.Schema( dict( zip( @@ -467,13 +468,9 @@ def execute( # TODO: should we do this in arrow? # also what is pandas doing with dates? - # TODO: converting to arrow -> pandas - # makes map tests pass because of how the map results - # are parsed out of DucKDB. - # This is stupid and we should fix it. - pandas_df = result.arrow().to_pandas() - # pandas_df = result.fetch_df() - result = PandasData.convert_table(pandas_df, schema) + + pandas_df = result.fetch_df() + result = DuckDBPandasData.convert_table(pandas_df, schema) if isinstance(expr, ir.Table): return result elif isinstance(expr, ir.Column): diff --git a/ibis/backends/duckdb/datatypes.py b/ibis/backends/duckdb/datatypes.py new file mode 100644 index 000000000000..9ed770ea26aa --- /dev/null +++ b/ibis/backends/duckdb/datatypes.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import numpy as np +from ibis.formats.pandas import PandasData + + +class DuckDBPandasData(PandasData): + @staticmethod + def convert_Map(s, dtype, pandas_type): + return s.map(lambda x: dict(zip(x["key"], x["value"])), na_action="ignore") + + @staticmethod + def convert_Array(s, dtype, pandas_type): + return s.replace(np.nan, None) From c03859f3aec0f12a687590dd677395c5c30518a7 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 13 Sep 2023 15:57:32 -0400 Subject: [PATCH 048/222] fix(duckdb): fix array slice with neg index > len(array) --- ibis/backends/duckdb/compiler/values.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index f6090f50d40f..55564dae360d 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -7,18 +7,17 @@ from typing import TYPE_CHECKING, Any, Literal import duckdb -import sqlglot as sg -from packaging.version import parse as vparse -from toolz import flip - import ibis import ibis.common.exceptions as com import ibis.expr.analysis as an import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.rules as rlz +import sqlglot as sg from ibis.backends.base.sql.registry import helpers from ibis.backends.base.sqlglot.datatypes import DuckDBType +from packaging.version import parse as vparse +from toolz import flip if TYPE_CHECKING: from collections.abc import Mapping @@ -345,6 +344,7 @@ def _generic_log(op, **kw): @translate_val.register(ops.Clip) def _clip(op, **kw): arg = translate_val(op.arg, **kw) + # TODO expressionize if (upper := op.upper) is not None: arg = f"least({translate_val(upper, **kw)}, {arg})" @@ -805,9 +805,7 @@ def _regex_extract(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) group = translate_val(op.index, **kw) - return f"regexp_extract({arg}, {pattern}, {group})" - # TODO: make this work -- need to handle pattern escaping? 
- return sg.func("regexp_extract", arg, pattern, group) + return sg.func("regexp_extract", arg, pattern, group, dialect="duckdb") @translate_val.register(ops.Levenshtein) @@ -980,9 +978,15 @@ def _array_repeat_op(op, **kw): def _neg_idx_to_pos(array, idx): + arg_length = sg.func("len", array) return sg.expressions.If( this=sg.expressions.LT(this=idx, expression=sg_literal(0, is_string=False)), - true=sg.func("len", array) + idx, + # Need to have the greatest here to handle the case where + # abs(neg_index) > arg_length + # e.g. where the magnitude of the negative index is greater than the + # length of the array + # You cannot index a[:-3] if a = [1, 2] + true=arg_length + sg.func("greatest", idx, -1 * arg_length), false=idx, ) @@ -1020,7 +1024,7 @@ def _array_map(op, **kw): result = translate_val(op.result, **kw) lamduh = sg.expressions.Lambda( this=result, - expressions=[sg.expressions.Identifier(this=f"{op.parameter}", quoted=False)], + expressions=[sg.to_identifier(f"{op.parameter}", quoted=False)], ) sg_expr = sg.func("list_transform", arg, lamduh) return sg_expr From f9c4171c385574c059216a12c6dbef8483013fc3 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 13 Sep 2023 16:31:47 -0400 Subject: [PATCH 049/222] fix(duckdb): fix arrayconcat and array zip co-authored by Phillip Cloud --- ibis/backends/duckdb/compiler/values.py | 36 ++++++++++++++++++++----- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 55564dae360d..ea696fbfd294 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -957,8 +957,11 @@ def _array_collect(op, **kw): @translate_val.register(ops.ArrayConcat) def _array_concat(op, **kw): sg_expr = sg.func( - "list_concat", - *(translate_val(arg, **kw) for arg in op.arg), + "flatten", + sg.func( + "list_value", + *(translate_val(arg, **kw) for arg in op.arg), + ), dialect="duckdb", ) return sg_expr @@ -1066,14 +1069,33 @@ def _array_union(op, **kw): return translate_val(ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right))), **kw) -# TODO: need to do this as a an array map + struct pack -- look at existing -# alchemy backend implementation @translate_val.register(ops.ArrayZip) def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: - zipped = sg.expressions.ArrayJoin().from_arg_list( - [translate_val(arg, **kw) for arg in op.arg] + i = sg.to_identifier("i", quoted=False) + args = [translate_val(arg, **kw) for arg in op.arg] + result = sg.expressions.Struct( + expressions=[ + sg.expressions.Slice( + this=sg_literal(name), + expression=sg.func("list_extract", arg, i), + ) + for name, arg in zip(op.dtype.value_type.names, args) + ] ) - return zipped + lamduh = sg.expressions.Lambda(this=result, expressions=[i]) + sg_expr = sg.func( + "list_transform", + sg.func( + "range", + sg_literal(1, is_string=False), + # DuckDB Range is not inclusive of upper bound + sg.func("greatest", *[sg.func("len", arg) for arg in args]) + 1, + ), + lamduh, + dialect="duckdb", + ) + + return sg_expr ### Counting From 35826e8397e21eb65bebc8e557f4df40bf44a81a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 05:42:46 -0400 Subject: [PATCH 050/222] test: regen cast sql snapshots --- .../snapshots/test_datatypes/test_cast_uints/uint16/out.sql | 2 +- .../snapshots/test_datatypes/test_cast_uints/uint32/out.sql | 2 +- .../snapshots/test_datatypes/test_cast_uints/uint64/out.sql | 2 +- 
.../snapshots/test_datatypes/test_cast_uints/uint8/out.sql | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint16/out.sql b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint16/out.sql index abb420080b20..3cf287053589 100644 --- a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint16/out.sql +++ b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint16/out.sql @@ -1,3 +1,3 @@ SELECT CAST(t0.a AS USMALLINT) AS "Cast(a, uint16)" -FROM t AS t0 \ No newline at end of file +FROM "t" AS t0 \ No newline at end of file diff --git a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint32/out.sql b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint32/out.sql index b2ec0d726884..ec913c277122 100644 --- a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint32/out.sql +++ b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint32/out.sql @@ -1,3 +1,3 @@ SELECT CAST(t0.a AS UINTEGER) AS "Cast(a, uint32)" -FROM t AS t0 \ No newline at end of file +FROM "t" AS t0 \ No newline at end of file diff --git a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint64/out.sql b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint64/out.sql index 6cefd3bb478b..b7ef4009e0de 100644 --- a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint64/out.sql +++ b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint64/out.sql @@ -1,3 +1,3 @@ SELECT CAST(t0.a AS UBIGINT) AS "Cast(a, uint64)" -FROM t AS t0 \ No newline at end of file +FROM "t" AS t0 \ No newline at end of file diff --git a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint8/out.sql b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint8/out.sql index dae9dbdc41cb..1fd18da74bec 100644 --- a/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint8/out.sql +++ b/ibis/backends/duckdb/tests/snapshots/test_datatypes/test_cast_uints/uint8/out.sql @@ -1,3 +1,3 @@ SELECT CAST(t0.a AS UTINYINT) AS "Cast(a, uint8)" -FROM t AS t0 \ No newline at end of file +FROM "t" AS t0 \ No newline at end of file From 73ed6224903a73b1777813117d1d8219bcf2557b Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 05:47:53 -0400 Subject: [PATCH 051/222] feat(duckdb): implement json getitem --- ibis/backends/duckdb/compiler/values.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index ea696fbfd294..5655c5d003fc 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -7,17 +7,18 @@ from typing import TYPE_CHECKING, Any, Literal import duckdb +import sqlglot as sg +from packaging.version import parse as vparse +from toolz import flip + import ibis import ibis.common.exceptions as com import ibis.expr.analysis as an import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.rules as rlz -import sqlglot as sg from ibis.backends.base.sql.registry import helpers from ibis.backends.base.sqlglot.datatypes import DuckDBType -from packaging.version import parse as vparse -from toolz import flip if TYPE_CHECKING: from collections.abc import Mapping @@ -1825,3 +1826,10 @@ def formatter(op, **kw): 
@translate_val.register(ops.Argument) def _argument(op, **_): return sg.expressions.Identifier(this=op.name, quoted=False) + + +@translate_val.register(ops.JSONGetItem) +def _json_getitem(op, **kw): + return sg.exp.JSONExtract( + this=translate_val(op.arg, **kw), expression=translate_val(op.index, **kw) + ) From 3d7cb61f020fdc12349096f20a55efad3d13a60c Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 07:12:37 -0400 Subject: [PATCH 052/222] chore: fix percent_rank and cume_dist --- ibis/backends/duckdb/compiler/values.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 5655c5d003fc..2f13ceeeeab2 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1659,6 +1659,16 @@ def _rank(_, **kw): return sg.func("rank", dialect="duckdb") +@translate_val.register(ops.PercentRank) +def _percent_rank(_, **kw): + return sg.func("percent_rank", dialect="duckdb") + + +@translate_val.register(ops.CumeDist) +def _cume_dist(_, **kw): + return sg.func("cume_dist", dialect="duckdb") + + @translate_val.register def _sort_key(op: ops.SortKey, **kw): arg = translate_val(op.expr, **kw) From cba7b01b458a7d4e3fe058ef36f952de73c1cdac Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 07:13:01 -0400 Subject: [PATCH 053/222] chore: fix rowid --- ibis/backends/duckdb/compiler/values.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 2f13ceeeeab2..9f5337892d17 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1843,3 +1843,14 @@ def _json_getitem(op, **kw): return sg.exp.JSONExtract( this=translate_val(op.arg, **kw), expression=translate_val(op.index, **kw) ) + + +@translate_val.register(ops.RowID) def _rowid(op, *, aliases, **_) -> str: + table = op.table + return sg.column(op.name, (aliases or {}).get(table, table.name)) + + +@translate_val.register(ops.ScalarUDF) +def _scalar_udf(op, **kw) -> str: + return sg.func(op.__full_name__, *(translate_val(arg, **kw) for arg in op.args)) From dd2b981db7589ad77c0549cc17e5812dc81ecd53 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 07:32:52 -0400 Subject: [PATCH 054/222] chore: fix dynamic limit/offset --- ibis/backends/duckdb/compiler/relations.py | 28 +++++++++++++++++----- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index 4ae6c360414d..a3ea1c2a13e0 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -176,13 +176,29 @@ def _set_op(op: ops.SetOp, *, left, right, **_): @translate_rel.register -def _limit(op: ops.Limit, *, table, **kw): - n = op.n - limited = sg.select("*").from_(table).limit(n) +def _limit(op: ops.Limit, *, table, n, offset, **kw): + result = sg.select("*").from_(table) + + if isinstance(n, int): + result = result.limit(n) + elif n is not None: + limit = translate_val(n, **kw) + # TODO: calling `.sql` is a workaround for sqlglot not supporting + # scalar subqueries in limits + limit = sg.select(limit).from_(table).subquery().sql(dialect="duckdb") + result = result.limit(limit) + + assert offset is not None, "offset is None" + + if not isinstance(offset, int): + skip = translate_val(offset, **kw) + skip = sg.select(skip).from_(table).subquery().sql(dialect="duckdb") + elif not offset: + return result + else: + skip = offset - if offset := op.offset: - limited = limited.offset(offset) - return limited + return result.offset(skip) @translate_rel.register
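The `.sql` call in the limit translation above works around sqlglot not accepting a scalar-subquery expression in a LIMIT clause by inlining the subquery as a raw string. A rough sketch of the shape this produces, with an invented table and limit expression (not taken from the test suite):

    import sqlglot as sg

    # inline the scalar subquery as a string, then hand it to .limit()
    limit = sg.select("count(*)").from_("t").subquery().sql(dialect="duckdb")
    query = sg.select("*").from_("t").limit(limit)
    print(query.sql(dialect="duckdb"))
    # roughly: SELECT * FROM t LIMIT (SELECT COUNT(*) FROM t)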
None" + + if not isinstance(offset, int): + skip = translate_val(offset, **kw) + skip = sg.select(skip).from_(table).subquery().sql(dialect="duckdb") + elif not offset: + return result + else: + skip = offset - if offset := op.offset: - limited = limited.offset(offset) - return limited + return result.offset(skip) @translate_rel.register From 6c37d6c2bb7dfc9f1be8f7292119883d2f444374 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 07:40:17 -0400 Subject: [PATCH 055/222] chore: remove SUPPORTS_MAPS from upstream --- ibis/backends/duckdb/compiler/values.py | 27 +++---------------------- 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 9f5337892d17..2890bce96e9e 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -6,9 +6,7 @@ from functools import partial from typing import TYPE_CHECKING, Any, Literal -import duckdb import sqlglot as sg -from packaging.version import parse as vparse from toolz import flip import ibis @@ -25,7 +23,6 @@ # TODO: Ideally we can translate bottom up a la `relations.py` # TODO: Find a way to remove all the dialect="duckdb" kwargs -_SUPPORTS_MAPS = vparse(duckdb.__version__) >= vparse("0.8.0") @functools.singledispatch @@ -1538,27 +1535,9 @@ def _is_map_literal(op): @translate_val.register(ops.MapMerge) def _map_merge(op, **kw): - if _SUPPORTS_MAPS: - left = translate_val(op.left, **kw) - right = translate_val(op.right, **kw) - return sg.func("map_concat", left, right) - else: - if not (_is_map_literal(op.left) and _is_map_literal(op.right)): - raise com.UnsupportedOperationError( - "Merging non-literal maps is not yet supported by DuckDB" - ) - left = sg.func("to_json", translate_val(op.left, **kw)) - right = sg.func("to_json", translate_val(op.right, **kw)) - pairs = sg.func("json_merge_patch", left, right) - keys = sg.func("json_keys", pairs) - return sg.cast( - expression=sg.func( - "map", - keys, - sg.func("json_extract_string", pairs, keys), - ), - to=DuckDBType.from_ibis(op.dtype), - ) + left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) + return sg.func("map_concat", left, right) def _binary_infix(sg_expr: sg.expressions._Expression): From c7b9ac18a021e02b6cfc5f8caa882d2d7df5c6f0 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 07:41:48 -0400 Subject: [PATCH 056/222] chore: fix collect with filter --- ibis/backends/duckdb/compiler/values.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 2890bce96e9e..5a4827d7fd04 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -945,11 +945,10 @@ def _in_column(op, **kw): @translate_val.register(ops.ArrayCollect) def _array_collect(op, **kw): - if op.where is not None: - # TODO: handle when op.where is not none - # probably using list_agg? - ... 
- return sg.func("list", translate_val(op.arg, **kw), dialect="duckdb") + agg = sg.func("list", translate_val(op.arg, **kw), dialect="duckdb") + if (where := op.where) is not None: + return sg.exp.Filter(this=agg, expression=translate_val(where, **kw)) + return agg @translate_val.register(ops.ArrayConcat) From bb96a0299c3a3c3be3392c7c2846c709589023e6 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 08:00:59 -0400 Subject: [PATCH 057/222] chore: conslidate aggregate filter application and fix reduction operations --- ibis/backends/duckdb/compiler/values.py | 151 +++++++----------------- ibis/backends/tests/test_aggregation.py | 5 - 2 files changed, 41 insertions(+), 115 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 5a4827d7fd04..693cdab48354 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -442,24 +442,36 @@ def _not(op, **kw): return sg.expressions.Not(this=arg) +def _apply_agg_filter(expr, *, where, **kw): + if where is not None: + return sg.exp.Filter( + this=expr, expression=sg.exp.Where(this=translate_val(where, **kw)) + ) + return expr + + +def _aggregate(op, func, *, where, **kw): + args = [ + translate_val(arg, **kw) + for argname, arg in zip(op.argnames, op.args) + if argname not in ("where", "how") + ] + agg = sg.func(func, *args, dialect="duckdb") + return _apply_agg_filter(agg, where=op.where, **kw) + + @translate_val.register(ops.Any) def _any(op, **kw): arg = translate_val(op.arg, **kw) - any_expr = sg.expressions.AnyValue(this=arg) - if op.where is not None: - where = sg.expressions.Where(this=translate_val(op.where, **kw)) - return sg.expressions.Filter(this=any_expr, expression=where) - return any_expr + any_expr = sg.func("bool_or", arg) + return _apply_agg_filter(any_expr, where=op.where, **kw) @translate_val.register(ops.All) def _all(op, **kw): arg = translate_val(op.arg, **kw) all_expr = sg.func("bool_and", arg) - if op.where is not None: - where = sg.expressions.Where(this=translate_val(op.where, **kw)) - return sg.expressions.Filter(this=all_expr, expression=where) - return all_expr + return _apply_agg_filter(all_expr, where=op.where, **kw) @translate_val.register(ops.NotAny) @@ -946,9 +958,7 @@ def _in_column(op, **kw): @translate_val.register(ops.ArrayCollect) def _array_collect(op, **kw): agg = sg.func("list", translate_val(op.arg, **kw), dialect="duckdb") - if (where := op.where) is not None: - return sg.exp.Filter(this=agg, expression=translate_val(where, **kw)) - return agg + return _apply_agg_filter(agg, where=op.where, **kw) @translate_val.register(ops.ArrayConcat) @@ -1102,10 +1112,7 @@ def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: def _count(op, **kw): arg = translate_val(op.arg, **kw) count_expr = sg.expressions.Count(this=arg) - if op.where is not None: - where = sg.expressions.Where(this=translate_val(op.where, **kw)) - return sg.expressions.Filter(this=count_expr, expression=where) - return count_expr + return _apply_agg_filter(count_expr, where=op.where, **kw) @translate_val.register(ops.CountDistinct) @@ -1117,10 +1124,7 @@ def _count_distinct(op, **kw): expressions=[arg], ) ) - if op.where is not None: - where = sg.expressions.Where(this=translate_val(op.where, **kw)) - return sg.expressions.Filter(this=count_expr, expression=where) - return count_expr + return _apply_agg_filter(count_expr, where=op.where, **kw) # TODO: implement @@ -1132,34 +1136,15 @@ def 
_count_distinct_star(op, **kw): sql = sg.expressions.Count(this=sg.expressions.Star()) - if (predicate := op.where) is not None: - return sg.expressions.Filter( - this=sql, - expression=sg.expressions.Where(this=translate_val(predicate, **kw)), - ) - return sql + return _apply_agg_filter(sql, where=op.where, **kw) @translate_val.register(ops.Sum) def _sum(op, **kw): - arg = translate_val(op.arg, **kw) - where = None - if op.where is not None: - where = translate_val(op.where, **kw) - - sg_where = sg.expressions.Where(this=where) - - # Handle sum(boolean comparison) - if isinstance(op.arg, ops.Comparison): - sg_count_expr = sg.expressions.Count(this=arg) - if where is not None: - return sg.expressions.Filter(this=sg_count_expr, expression=sg_where) - return sg_count_expr - - sg_sum_expr = sg.expressions.Sum(this=arg) - if where is not None: - return sg.expressions.Filter(this=sg_sum_expr, expression=sg_where) - return sg_sum_expr + arg = translate_val( + ops.Cast(arg, to=op.dtype) if (arg := op.arg).dtype.is_boolean() else arg, **kw + ) + return _apply_agg_filter(sg.expressions.Sum(this=arg), where=op.where, **kw) # TODO @@ -1186,13 +1171,7 @@ def _quantile(op, **kw): arg = translate_val(op.arg, **kw) quantile = translate_val(op.quantile, **kw) sg_expr = sg.func("quantile_cont", arg, quantile, dialect="duckdb") - if op.where is not None: - predicate = translate_val(op.where, **kw) - sg_expr = sg.expressions.Filter( - this=sg_expr, - expression=sg.expressions.Where(this=predicate), - ) - return sg_expr + return _apply_agg_filter(sg_expr, where=op.where, **kw) @translate_val.register(ops.Correlation) @@ -1217,12 +1196,7 @@ def _corr(op, **kw): ) sg_func = sg.func("corr", left, right) - - if (where := op.where) is not None: - predicate = sg.expressions.Where(this=translate_val(where, **kw)) - return sg.expressions.Filter(this=sg_func, expression=predicate) - - return sg_func + return _apply_agg_filter(sg_func, where=op.where, **kw) @translate_val.register(ops.Covariance) @@ -1243,15 +1217,8 @@ def _covariance(op, **kw): to=DuckDBType.from_ibis(dt.Int32(nullable=right_type.nullable)), ) - funcname = f"covar_{_how[op.how]}" - - sg_func = sg.func(funcname, left, right) - - if (where := op.where) is not None: - predicate = sg.expressions.Where(this=translate_val(where, **kw)) - return sg.expressions.Filter(this=sg_func, expression=predicate) - - return sg_func + sg_func = sg.func(f"covar_{_how[op.how]}", left, right) + return _apply_agg_filter(sg_func, where=op.where, **kw) @translate_val.register(ops.Variance) @@ -1260,39 +1227,14 @@ def _variance(op, **kw): _how = {"sample": "samp", "pop": "pop"} _func = {ops.Variance: "var", ops.StandardDev: "stddev"} - funcname = f"{_func[type(op)]}_{_how[op.how]}" - - arg = translate_val(op.arg, **kw) - if (arg_type := op.arg.dtype).is_boolean(): - arg = sg.cast( - expression=arg, - to=DuckDBType.from_ibis(dt.Int32(nullable=arg_type.nullable)), - ) - - sg_func = sg.func(funcname, arg) - - if (where := op.where) is not None: - predicate = sg.expressions.Where(this=translate_val(where, **kw)) - return sg.expressions.Filter(this=sg_func, expression=predicate) - - return sg_func - + arg = op.arg + if (arg_dtype := arg.dtype).is_boolean(): + arg = ops.Cast(arg, to=dt.Int32(nullable=arg_dtype.nullable)) -def _aggregate(op, func, *, where=None, **kw): - args = [ - translate_val(arg, **kw) - for argname, arg in zip(op.argnames, op.args) - if argname not in ("where", "how") - ] - if where is not None: -
predicate = translate_val(where, **kw) - return sg.expressions.Filter( - this=sg.func(func, *args, dialect="duckdb"), - expression=sg.expressions.Where(this=predicate), - ) + arg = translate_val(arg, **kw) - res = sg.func(func, *args) - return res + sg_func = sg.func(f"{_func[type(op)]}_{_how[op.how]}", arg) + return _apply_agg_filter(sg_func, where=op.where, **kw) @translate_val.register(ops.Arbitrary) @@ -1379,15 +1321,7 @@ def _group_concat(op, **kw): sep = translate_val(op.sep, **kw) concat = sg.func("string_agg", arg, sep, dialect="duckdb") - - if (where := op.where) is not None: - predicate = translate_val(where, **kw) - return sg.expressions.Filter( - this=concat, - expression=sg.expressions.Where(this=predicate), - ) - - return concat + return _apply_agg_filter(concat, where=op.where, **kw) # TODO @@ -1598,10 +1532,7 @@ def formatter(op, **kw): def _bitor(op, **kw): arg = translate_val(op.arg, **kw) bit_expr = sg.func(_bit_agg[type(op)], arg) - if op.where is not None: - where = sg.expressions.Where(this=translate_val(op.where, **kw)) - return sg.expressions.Filter(this=bit_expr, expression=where) - return bit_expr + return _apply_agg_filter(bit_expr, where=op.where, **kw) @translate_val.register(ops.Xor) diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index f6d92cacd685..ea07595734b9 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -522,11 +522,6 @@ def mean_and_std(v): lambda t, where: t.double_col[where].var(ddof=0), id="var_pop", marks=[ - mark.broken( - ["duckdb"], - raises=com.IbisError, - reason="sqlglot mistranslates VariancePop to variance_pop instead of var_pop", - ), mark.notimpl( ["druid"], raises=sa.exc.ProgrammingError,
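The consolidated `_apply_agg_filter` helper above gives every aggregate the same FILTER (WHERE ...) treatment instead of hand-rolling the wrapping at each call site. A small sketch of the SQL shape it produces, with an invented column and predicate:

    import sqlglot as sg

    # wrap an aggregate in FILTER (WHERE ...), mirroring _apply_agg_filter
    agg = sg.func("sum", sg.column("amount"))
    filtered = sg.exp.Filter(
        this=agg,
        expression=sg.exp.Where(this=sg.condition("amount > 0")),
    )
    print(filtered.sql(dialect="duckdb"))
    # roughly: SUM(amount) FILTER(WHERE amount > 0)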
From 8fbb33ce7396ae756c3dae9acbd9d92008dda49d Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 08:11:59 -0400 Subject: [PATCH 058/222] chore: implement `CountDistinctStar` --- ibis/backends/duckdb/compiler/values.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 693cdab48354..1e15eead1218 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1127,10 +1127,14 @@ def _count_distinct(op, **kw): return _apply_agg_filter(count_expr, where=op.where, **kw) -# TODO: implement @translate_val.register(ops.CountDistinctStar) def _count_distinct_star(op, **kw): - ... + # use a tuple because duckdb doesn't accept COUNT(DISTINCT a, b, c, ...) + # + # this turns the expression into COUNT(DISTINCT (a, b, c, ...)) + row = sg.exp.Tuple(expressions=list(map(sg.column, op.arg.schema.keys()))) + expr = sg.exp.Count(this=sg.exp.Distinct(expressions=[row])) + return _apply_agg_filter(expr, where=op.where, **kw) @translate_val.register(ops.CountStar) From 6c201fd51465ad8609dfa7341f03715fca3d3731 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 08:23:20 -0400 Subject: [PATCH 059/222] chore: clean up approximate aggs and count --- ibis/backends/duckdb/compiler/values.py | 31 ++++++------------------- 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 1e15eead1218..670d9b6c3e78 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -224,7 +224,7 @@ def _literal(op, **kw): ops.RandomScalar: "random", ops.Sign: "sign", # Unary aggregates - ops.ApproxMedian: "median", + ops.ApproxCountDistinct: "approx_count_distinct", ops.Median: "median", ops.Mean: "avg", ops.Max: "max", @@ -234,6 +234,7 @@ def _literal(op, **kw): ops.ArgMax: "arg_max", ops.First: "first", ops.Last: "last", + ops.Count: "count", # string operations ops.StringContains: "contains", ops.StringLength: "length", @@ -1108,22 +1109,10 @@ def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: ### Counting -@translate_val.register(ops.Count) -def _count(op, **kw): - arg = translate_val(op.arg, **kw) - count_expr = sg.expressions.Count(this=arg) - return _apply_agg_filter(count_expr, where=op.where, **kw) - - @translate_val.register(ops.CountDistinct) -@translate_val.register(ops.ApproxCountDistinct) def _count_distinct(op, **kw): arg = translate_val(op.arg, **kw) - count_expr = sg.expressions.Count( - this=sg.expressions.Distinct( - expressions=[arg], - ) - ) + count_expr = sg.expressions.Count(this=sg.expressions.Distinct(expressions=[arg])) return _apply_agg_filter(count_expr, where=op.where, **kw) @@ -1682,21 +1671,15 @@ def format_window_frame(func, frame, **kw): } -# TODO -UNSUPPORTED_REDUCTIONS = ( - ops.ApproxMedian, - ops.ApproxCountDistinct, -) +@translate_val.register(ops.ApproxMedian) +def _approx_median(op, **kw): + expr = sg.func("approx_quantile", translate_val(op.arg, **kw), sg_literal(0.5, is_string=False)) + return _apply_agg_filter(expr, where=op.where, **kw) # TODO @translate_val.register(ops.WindowFunction) def _window(op: ops.WindowFunction, **kw: Any): - if isinstance(op.func, UNSUPPORTED_REDUCTIONS): - raise com.UnsupportedOperationError( - f"{type(op.func)} is not supported in window functions" - ) - if isinstance(op.func, ops.CumulativeOp): arg = cumulative_to_window(op.func, op.frame) return translate_val(arg, **kw) From f4273255ee35fb503d247db23c36d28089dd67de Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 08:23:55 -0400 Subject: [PATCH 060/222] chore: use sg.func in more places --- ibis/backends/duckdb/compiler/values.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 670d9b6c3e78..272ec30bb56d 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1145,14 +1145,14 @@ def _sum(op, **kw): def _nth_value(op, **kw): arg = translate_val(op.arg, **kw) nth = translate_val(op.nth, **kw) - return f"nth_value({arg}, ({nth}) + 1)" + return sg.func("nth_value", arg, nth
+ 1) @translate_val.register(ops.Repeat) def _repeat(op, **kw): arg = translate_val(op.arg, **kw) times = translate_val(op.times, **kw) - return f"repeat({arg}, {times})" + return sg.func("repeat", arg, times) ### Stats @@ -1243,10 +1243,13 @@ def _arbitrary(op, **kw): @translate_val.register(ops.FindInSet) def _index_of(op, **kw): - values = map(partial(translate_val, **kw), op.values) - values = ", ".join(map(_sql, values)) needle = translate_val(op.needle, **kw) - return f"list_indexof([{values}], {needle}) - 1" + return ( + sg.func( + "list_indexof", list(map(partial(translate_val, **kw), op.values)), needle + ) + - 1 + ) @translate_val.register(tuple) @@ -1719,8 +1722,7 @@ def formatter(op, **kw): offset_fmt = translate_val(offset, **kw) pieces.append(offset_fmt) - res = f"{name}({', '.join(map(_sql, pieces))})" - return res + return sg.func(name, *pieces) return formatter From bdacfa9e2b62c1e89dfbfd31a9ed0c309899797a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 08:24:19 -0400 Subject: [PATCH 061/222] chore: use `sg.exp.Tuple` instead of string formatting --- ibis/backends/duckdb/compiler/values.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 272ec30bb56d..788473970a4b 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1253,9 +1253,8 @@ def _index_of(op, **kw): @translate_val.register(tuple) -def _node_list(op, punct="()", **kw): - values = ", ".join(map(_sql, map(partial(translate_val, **kw), op))) - return f"{punct[0]}{values}{punct[1]}" +def _node_list(op, **kw): + return sg.exp.Tuple(expressions=list(map(partial(translate_val, **kw), op))) @translate_val.register(ops.SimpleCase) From 30c50891755b69c5aad87cd7421ddecfb5c94ca2 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 08:24:33 -0400 Subject: [PATCH 062/222] chore: commit ugly sql for now --- .../duckdb/out.sql | 79 ++++++++++++++----- .../test_group_by_has_index/duckdb/out.sql | 20 ++++- 2 files changed, 77 insertions(+), 22 deletions(-) diff --git a/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/duckdb/out.sql b/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/duckdb/out.sql index 66a947699796..dbde514e7b86 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/duckdb/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/duckdb/out.sql @@ -1,22 +1,61 @@ -WITH t0 AS ( - SELECT - t4.key AS key - FROM leaf AS t4 - WHERE - CAST(TRUE AS BOOLEAN) -), t1 AS ( +SELECT + t4.key +FROM ( SELECT - t0.key AS key - FROM t0 -), t2 AS ( + t1.key + FROM ( + SELECT + * + FROM "leaf" AS t0 + WHERE + TRUE + ) AS t1 + INNER JOIN ( + SELECT + t1.key + FROM ( + SELECT + * + FROM "leaf" AS t0 + WHERE + TRUE + ) AS t1 + ) AS t2 + ON ( + t1.key + ) = ( + t2.key + ) +) AS t4 +INNER JOIN ( SELECT - t0.key AS key - FROM t0 - JOIN t1 - ON t0.key = t1.key -) -SELECT - t2.key -FROM t2 -JOIN t2 AS t3 - ON t2.key = t3.key \ No newline at end of file + t1.key + FROM ( + SELECT + * + FROM "leaf" AS t0 + WHERE + TRUE + ) AS t1 + INNER JOIN ( + SELECT + t1.key + FROM ( + SELECT + * + FROM "leaf" AS t0 + WHERE + TRUE + ) AS t1 + ) AS t2 + ON ( + t1.key + ) = ( + t2.key + ) +) AS t5 + ON ( + t4.key + ) = ( + t5.key + ) \ No newline at end of file diff --git 
a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/duckdb/out.sql b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/duckdb/out.sql index fc16f2428d16..241f1095fd1e 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/duckdb/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/duckdb/out.sql @@ -17,6 +17,22 @@ SELECT ELSE 'Unknown continent' END AS cont, SUM(t0.population) AS total_pop -FROM countries AS t0 +FROM "countries" AS t0 GROUP BY - 1 \ No newline at end of file + CASE t0.continent + WHEN 'NA' + THEN 'North America' + WHEN 'SA' + THEN 'South America' + WHEN 'EU' + THEN 'Europe' + WHEN 'AF' + THEN 'Africa' + WHEN 'AS' + THEN 'Asia' + WHEN 'OC' + THEN 'Oceania' + WHEN 'AN' + THEN 'Antarctica' + ELSE 'Unknown continent' + END \ No newline at end of file From 09114e0331c9fc5f7895fc3325f4b501a308d26a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 08:28:37 -0400 Subject: [PATCH 063/222] chore: fix scalar parameter translation --- ibis/backends/duckdb/compiler/values.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 788473970a4b..867720dc061a 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1352,7 +1352,7 @@ def _scalar_param(op, params: Mapping[ops.Node, Any], **kw): if isinstance(dtype, dt.Struct): literal = ibis.struct(raw_value, type=dtype) elif isinstance(dtype, dt.Map): - literal = ibis.map(raw_value, type=dtype) + literal = ibis.map(raw_value) else: literal = ibis.literal(raw_value, type=dtype) return translate_val(literal.op(), **kw) @@ -1750,4 +1750,5 @@ def _rowid(op, *, aliases, **_) -> str: @translate_val.register(ops.ScalarUDF) def _scalar_udf(op, **kw) -> str: - return sg.func(op.__full_name__, *(translate_val(arg, **kw) for arg in op.args)) + funcname = op.__class__.__name__ + return sg.func(funcname, *(translate_val(arg, **kw) for arg in op.args)) From db3ede46879263ed7a28ce060412c249f434f206 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 10:46:26 -0400 Subject: [PATCH 064/222] chore: implement fillna --- ibis/backends/duckdb/compiler/relations.py | 26 ++++++++++++++++++++++ ibis/backends/duckdb/compiler/values.py | 11 +++++++++ 2 files changed, 37 insertions(+) diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index a3ea1c2a13e0..86f064d657a0 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +from collections.abc import Mapping from functools import partial import sqlglot as sg @@ -234,3 +235,28 @@ def _dropna(op: ops.DropNa, *, table, **kw): return table.where(predicate, dialect="duckdb") except AttributeError: return sg.select("*").from_(table).where(predicate, dialect="duckdb") + + +@translate_rel.register +def _fillna(op: ops.FillNa, *, table, **kw): + replacements = op.replacements + if isinstance(replacements, Mapping): + mapping = replacements + else: + mapping = { + name: replacements for name, dtype in op.schema.items() if dtype.nullable + } + exprs = [ + ( + sg.alias( + sg.exp.Coalesce( + this=sg.column(col), expressions=[translate_val(alt, **kw)] + ), + col, + ) + if (alt := mapping.get(col)) is not None + else sg.column(col) + ) 
+ for col in op.schema.keys() + ] + return sg.select(*exprs).from_(table) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 867720dc061a..5086bf5fdaba 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1752,3 +1752,14 @@ def _rowid(op, *, aliases, **_) -> str: def _scalar_udf(op, **kw) -> str: funcname = op.__class__.__name__ return sg.func(funcname, *(translate_val(arg, **kw) for arg in op.args)) + + +@translate_val.register(int) +@translate_val.register(float) +def _int_float(val, **kw): + return sg.exp.Literal(this=str(val), is_string=False) + + +@translate_val.register(str) +def _str(val, **kw): + return sg.exp.Literal(this=val, is_string=True) From f57ea96c0e363ad6f25133862f3fb6cade4cc23d Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 10:46:51 -0400 Subject: [PATCH 065/222] chore: fix struct column construction --- ibis/backends/duckdb/compiler/values.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 5086bf5fdaba..9b1497e01d72 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1330,11 +1330,15 @@ def _array_column(op, **kw): @translate_val.register(ops.StructColumn) def _struct_column(op, **kw): - values = translate_val(op.values, **kw) - struct_type = DuckDBType.from_ibis(op.dtype) - # TODO: this seems like a workaround - # but maybe it isn't - return sg.cast(expression=values, to=struct_type) + return sg.exp.Struct( + expressions=[ + sg.exp.Slice( + this=sg.exp.Literal(this=name, is_string=True), + expression=translate_val(value, **kw), + ) + for name, value in zip(op.names, op.values) + ] + ) @translate_val.register(ops.StructField) From 91c17b91376bbcb44dffdae6e2d47e6cd4f0269e Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 10:49:21 -0400 Subject: [PATCH 066/222] chore: clean up struct extract a bit --- ibis/backends/duckdb/compiler/values.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 9b1497e01d72..c0c319f96180 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1344,9 +1344,9 @@ def _struct_column(op, **kw): @translate_val.register(ops.StructField) def _struct_field(op, **kw): arg = translate_val(op.arg, **kw) - field = sg.expressions.Literal(this=f"{op.field}", is_string=True) - sg_expr = sg.func("struct_extract", arg, field) - return sg_expr + return sg.exp.StructExtract( + this=arg, expression=sg.exp.Literal(this=op.field, is_string=True) + ) @translate_val.register(ops.ScalarParameter) From 5eb104eeaf4e5db216817ab29dd8bdfb0c128b38 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 10:51:57 -0400 Subject: [PATCH 067/222] chore: use sg.exp instead of sg.expressions --- ibis/backends/duckdb/compiler/values.py | 247 ++++++++++++------------ 1 file changed, 119 insertions(+), 128 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index c0c319f96180..8a1b79becccb 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -71,19 +71,20 @@ def _sql(obj, 
dialect="duckdb"): def sg_literal(arg, is_string=True): - return sg.expressions.Literal(this=f"{arg}", is_string=is_string) + return sg.exp.Literal(this=f"{arg}", is_string=is_string) @translate_val.register(ops.Literal) def _literal(op, **kw): value = op.value dtype = op.dtype + if value is None and dtype.nullable: if dtype.is_null(): - return sg.expressions.Null() - return sg.cast(sg.expressions.Null(), to=DuckDBType.from_ibis(dtype)) + return sg.exp.Null() + return sg.cast(sg.exp.Null(), to=DuckDBType.from_ibis(dtype)) if dtype.is_boolean(): - return sg.expressions.Boolean(this=value) + return sg.exp.Boolean(this=value) elif dtype.is_inet(): com.UnsupportedOperationError("DuckDB doesn't support an explicit inet dtype") return None @@ -102,14 +103,14 @@ def _literal(op, **kw): f"Unsupported precision. Supported values: [1 : 38]. Current value: {precision!r}" ) if math.isinf(value): - return sg.expressions.cast( + return sg.exp.cast( expression=sg_literal(value), - to=sg.expressions.DataType.Type.FLOAT, + to=sg.exp.DataType.Type.FLOAT, ) elif math.isnan(value): - return sg.expressions.cast( + return sg.exp.cast( expression=sg_literal("NaN"), - to=sg.expressions.DataType.Type.FLOAT, + to=sg.exp.DataType.Type.FLOAT, ) dtype = dt.Decimal(precision=precision, scale=scale, nullable=dtype.nullable) @@ -119,14 +120,14 @@ def _literal(op, **kw): return sg_expr elif dtype.is_numeric(): if math.isinf(value): - return sg.expressions.cast( + return sg.exp.cast( expression=sg_literal(value), - to=sg.expressions.DataType.Type.FLOAT, + to=sg.exp.DataType.Type.FLOAT, ) elif math.isnan(value): - return sg.expressions.cast( + return sg.exp.cast( expression=sg_literal("NaN"), - to=sg.expressions.DataType.Type.FLOAT, + to=sg.exp.DataType.Type.FLOAT, ) return sg.cast( sg_literal(value, is_string=False), @@ -155,11 +156,11 @@ def _literal(op, **kw): year = sg_literal(op.value.year, is_string=False) month = sg_literal(op.value.month, is_string=False) day = sg_literal(op.value.day, is_string=False) - return sg.expressions.DateFromParts(year=year, month=month, day=day) + return sg.exp.DateFromParts(year=year, month=month, day=day) elif dtype.is_array(): value_type = dtype.value_type is_string = isinstance(value_type, dt.String) - values = sg.expressions.Array().from_arg_list( + values = sg.exp.Array().from_arg_list( [ # TODO: this cast makes for frustrating output # is there any better way to handle it? 
@@ -174,13 +175,13 @@ def _literal(op, **kw): elif dtype.is_map(): key_type = dtype.key_type value_type = dtype.value_type - keys = sg.expressions.Array().from_arg_list( + keys = sg.exp.Array().from_arg_list( [_literal(ops.Literal(k, dtype=key_type), **kw) for k in value.keys()] ) - values = sg.expressions.Array().from_arg_list( + values = sg.exp.Array().from_arg_list( [_literal(ops.Literal(v, dtype=value_type), **kw) for v in value.values()] ) - sg_expr = sg.expressions.Map(keys=keys, values=values) + sg_expr = sg.exp.Map(keys=keys, values=values) return sg_expr elif dtype.is_struct(): keys = [sg_literal(key) for key in value.keys()] @@ -188,10 +189,8 @@ def _literal(op, **kw): _literal(ops.Literal(v, dtype=subdtype), **kw) for subdtype, v in zip(dtype.types, value.values()) ] - slices = [ - sg.expressions.Slice(this=k, expression=v) for k, v in zip(keys, values) - ] - sg_expr = sg.expressions.Struct.from_arg_list(slices) + slices = [sg.exp.Slice(this=k, expression=v) for k, v in zip(keys, values)] + sg_expr = sg.exp.Struct.from_arg_list(slices) return sg_expr else: raise NotImplementedError(f"Unsupported type: {dtype!r}") @@ -294,11 +293,11 @@ def _fmt(op, _name: str = _name, **kw): ### Bitwise Business _bitwise_mapping = { - ops.BitwiseLeftShift: sg.expressions.BitwiseLeftShift, - ops.BitwiseRightShift: sg.expressions.BitwiseRightShift, - ops.BitwiseAnd: sg.expressions.BitwiseAnd, - ops.BitwiseOr: sg.expressions.BitwiseOr, - ops.BitwiseXor: sg.expressions.BitwiseXor, + ops.BitwiseLeftShift: sg.exp.BitwiseLeftShift, + ops.BitwiseRightShift: sg.exp.BitwiseRightShift, + ops.BitwiseAnd: sg.exp.BitwiseAnd, + ops.BitwiseOr: sg.exp.BitwiseOr, + ops.BitwiseXor: sg.exp.BitwiseXor, } @@ -319,7 +318,7 @@ def _bitwise_binary(op, **kw): def _bitwise_not(op, **kw): value = translate_val(op.arg, **kw) - return sg.expressions.BitwiseNot(this=value) + return sg.exp.BitwiseNot(this=value) ### Mathematical Calisthenics @@ -363,8 +362,8 @@ def _floor_divide(op, **kw): def _round(op, **kw): arg = translate_val(op.arg, **kw) if (digits := op.digits) is not None: - return sg.expressions.Round(this=arg, decimals=translate_val(digits, **kw)) - return sg.expressions.Round(this=arg) + return sg.exp.Round(this=arg, decimals=translate_val(digits, **kw)) + return sg.exp.Round(this=arg) ### Dtype Dysmorphia @@ -394,11 +393,11 @@ def _cast(op, **kw): ) else: - return sg.expressions.Interval(this=arg, unit=suffix) + return sg.exp.Interval(this=arg, unit=suffix) elif isinstance(op.to, dt.Timestamp) and isinstance(op.arg.dtype, dt.Integer): return sg.func("to_timestamp", arg) elif isinstance(op.to, dt.Timestamp) and op.to.timezone is not None: - timezone = sg.expressions.Literal(this=op.to.timezone, is_string=True) + timezone = sg.exp.Literal(this=op.to.timezone, is_string=True) return sg.func("timezone", timezone, arg) to = translate_val(op.to, **kw) @@ -407,7 +406,7 @@ def _cast(op, **kw): @translate_val.register(ops.TryCast) def _try_cast(op, **kw): - return sg.expressions.TryCast( + return sg.exp.TryCast( this=translate_val(op.arg, **kw), to=DuckDBType.to_string(op.to), dialect="duckdb", @@ -428,19 +427,19 @@ def _between(op, **kw): arg = translate_val(op.arg, **kw) lower_bound = translate_val(op.lower_bound, **kw) upper_bound = translate_val(op.upper_bound, **kw) - return sg.expressions.Between(this=arg, low=lower_bound, high=upper_bound) + return sg.exp.Between(this=arg, low=lower_bound, high=upper_bound) @translate_val.register(ops.Negate) def _negate(op, **kw): arg = translate_val(op.arg, **kw) - return 
sg.expressions.Neg(this=arg) + return sg.exp.Neg(this=arg) @translate_val.register(ops.Not) def _not(op, **kw): arg = translate_val(op.arg, **kw) - return sg.expressions.Not(this=arg) + return sg.exp.Not(this=arg) def _apply_agg_filter(expr, *, where, **kw): @@ -491,7 +490,7 @@ def _not_all(op, **kw): @translate_val.register(ops.Date) def _to_date(op, **kw): arg = translate_val(op.arg, **kw) - return sg.expressions.Date(this=arg) + return sg.exp.Date(this=arg) @translate_val.register(ops.DateFromYMD) @@ -499,13 +498,13 @@ def _date_from_ymd(op, **kw): y = translate_val(op.year, **kw) m = translate_val(op.month, **kw) d = translate_val(op.day, **kw) - return sg.expressions.DateFromParts(year=y, month=m, day=d) + return sg.exp.DateFromParts(year=y, month=m, day=d) @translate_val.register(ops.Time) def _time(op, **kw): arg = translate_val(op.arg, **kw) - return sg.cast(expression=arg, to=sg.expressions.DataType.Type.TIME) + return sg.cast(expression=arg, to=sg.exp.DataType.Type.TIME) @translate_val.register(ops.TimeFromHMS) @@ -528,7 +527,7 @@ def _timestamp_from_unix(op, **kw): if (unit := op.unit.short) in {"ms", "us", "ns"}: raise com.UnsupportedOperationError(f"{unit!r} unit is not supported!") - return sg.expressions.UnixToTime(this=arg) + return sg.exp.UnixToTime(this=arg) @translate_val.register(ops.TimestampFromYMDHMS) @@ -569,9 +568,9 @@ def _extract_epoch_seconds(op, **kw): arg = translate_val(op.arg, **kw) return sg.func( "epoch", - sg.expressions.cast( + sg.exp.cast( expression=arg, - to=sg.expressions.DataType.Type.TIMESTAMP, + to=sg.exp.DataType.Type.TIMESTAMP, ), ) @@ -601,9 +600,7 @@ def _extract_epoch_seconds(op, **kw): def _extract_time(op, **kw): part = _extract_mapping[type(op)] timestamp = translate_val(op.arg, **kw) - return sg.func( - "extract", sg.expressions.Literal(this=part, is_string=True), timestamp - ) + return sg.func("extract", sg.exp.Literal(this=part, is_string=True), timestamp) # DuckDB extracts subminute microseconds and milliseconds @@ -612,13 +609,13 @@ def _extract_time(op, **kw): def _extract_microsecond(op, **kw): arg = translate_val(op.arg, **kw) - return sg.expressions.Mod( + return sg.exp.Mod( this=sg.func( "extract", - sg.expressions.Literal(this="us", is_string=True), + sg.exp.Literal(this="us", is_string=True), arg, ), - expression=sg.expressions.Literal(this="1000000", is_string=False), + expression=sg.exp.Literal(this="1000000", is_string=False), ) @@ -626,13 +623,13 @@ def _extract_microsecond(op, **kw): def _extract_microsecond(op, **kw): arg = translate_val(op.arg, **kw) - return sg.expressions.Mod( + return sg.exp.Mod( this=sg.func( "extract", - sg.expressions.Literal(this="ms", is_string=True), + sg.exp.Literal(this="ms", is_string=True), arg, ), - expression=sg.expressions.Literal(this="1000", is_string=False), + expression=sg.exp.Literal(this="1000", is_string=False), ) @@ -704,9 +701,9 @@ def day_of_week_name(op, **kw): _interval_mapping = { - ops.IntervalAdd: sg.expressions.Add, - ops.IntervalSubtract: sg.expressions.Sub, - ops.IntervalMultiply: sg.expressions.Mul, + ops.IntervalAdd: sg.exp.Add, + ops.IntervalSubtract: sg.exp.Sub, + ops.IntervalMultiply: sg.exp.Mul, } @@ -728,8 +725,8 @@ def _interval_format(op): "Duckdb doesn't support nanosecond interval resolutions" ) - return sg.expressions.Interval( - this=sg.expressions.Literal(this=op.value, is_string=False), + return sg.exp.Interval( + this=sg.exp.Literal(this=op.value, is_string=False), unit=dtype.resolution.upper(), ) @@ -762,12 +759,12 @@ def _substring(op, **kw): else: 
length = None - if_pos = sg.expressions.Substring(this=arg, start=start + 1, length=length) - if_neg = sg.expressions.Substring(this=arg, start=start, length=length) + if_pos = sg.exp.Substring(this=arg, start=start + 1, length=length) + if_neg = sg.exp.Substring(this=arg, start=start, length=length) - sg_expr = sg.expressions.If( - this=sg.expressions.GTE( - this=start, expression=sg.expressions.Literal(this="0", is_string=False) + sg_expr = sg.exp.If( + this=sg.exp.GTE( + this=start, expression=sg.exp.Literal(this="0", is_string=False) ), true=if_pos, false=if_neg, @@ -792,7 +789,7 @@ def _regex_search(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) return sg.func( - "regexp_matches", arg, pattern, sg.expressions.Literal(this="s", is_string=True) + "regexp_matches", arg, pattern, sg.exp.Literal(this="s", is_string=True) ) @@ -806,7 +803,7 @@ def _regex_replace(op, **kw): arg, pattern, replacement, - sg.expressions.Literal(this="g", is_string=True), + sg.exp.Literal(this="g", is_string=True), dialect="duckdb", ) @@ -830,7 +827,7 @@ def _levenshtein(op, **kw): def _string_split(op, **kw): arg = translate_val(op.arg, **kw) delimiter = translate_val(op.delimiter, **kw) - return sg.expressions.Split(this=arg, expression=delimiter) + return sg.exp.Split(this=arg, expression=delimiter) @translate_val.register(ops.StringJoin) @@ -844,27 +841,27 @@ def _string_join(op, **kw): @translate_val.register(ops.StringConcat) def _string_concat(op, **kw): arg = map(partial(translate_val, **kw), op.arg) - return sg.expressions.Concat(expressions=list(arg)) + return sg.exp.Concat(expressions=list(arg)) @translate_val.register(ops.StringSQLLike) def _string_like(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) - return sg.expressions.Like(this=arg, expression=pattern) + return sg.exp.Like(this=arg, expression=pattern) @translate_val.register(ops.StringSQLILike) def _string_ilike(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) - return sg.expressions.Like(this=sg.func("lower", arg), expression=pattern) + return sg.exp.Like(this=sg.func("lower", arg), expression=pattern) @translate_val.register(ops.Capitalize) def _string_capitalize(op, **kw): arg = translate_val(op.arg, **kw) - return sg.expressions.Concat( + return sg.exp.Concat( expressions=[ sg.func("upper", sg.func("substr", arg, 1, 1)), sg.func("lower", sg.func("substr", arg, 2)), @@ -876,13 +873,13 @@ def _string_capitalize(op, **kw): @translate_val.register(ops.IsNull) def _is_null(op, **kw): arg = translate_val(op.arg, **kw) - return arg.is_(sg.expressions.null()) + return arg.is_(sg.exp.null()) @translate_val.register(ops.NotNull) def _is_not_null(op, **kw): arg = translate_val(op.arg, **kw) - return arg.is_(sg.not_(sg.expressions.null())) + return arg.is_(sg.not_(sg.exp.null())) @translate_val.register(ops.IfNull) @@ -911,14 +908,14 @@ def _zero_if_null(op, **kw): def _array_sort(op, **kw): arg = translate_val(op.arg, **kw) - sg_expr = sg.expressions.If( - this=arg.is_(sg.expressions.Null()), - true=sg.expressions.Null(), + sg_expr = sg.exp.If( + this=arg.is_(sg.exp.Null()), + true=sg.exp.Null(), false=sg.func("list_distinct", arg, dialect="duckdb") - + sg.expressions.If( + + sg.exp.If( this=sg.func("list_count", arg, dialect="duckdb") < sg.func("array_length", arg, dialect="duckdb"), - true=sg.func("list_value", sg.expressions.Null(), dialect="duckdb"), + true=sg.func("list_value", sg.exp.Null(), dialect="duckdb"), 
false=sg.func("list_value", dialect="duckdb"), ), ) @@ -939,9 +936,7 @@ def _in_values(op, **kw): if not op.options: return False value = translate_val(op.value, **kw) - options = sg.expressions.Array().from_arg_list( - [translate_val(x, **kw) for x in op.options] - ) + options = sg.exp.Array().from_arg_list([translate_val(x, **kw) for x in op.options]) sg_expr = sg.func("list_contains", options, value, dialect="duckdb") return sg_expr @@ -990,8 +985,8 @@ def _array_repeat_op(op, **kw): def _neg_idx_to_pos(array, idx): arg_length = sg.func("len", array) - return sg.expressions.If( - this=sg.expressions.LT(this=idx, expression=sg_literal(0, is_string=False)), + return sg.exp.If( + this=sg.exp.LT(this=idx, expression=sg_literal(0, is_string=False)), # Need to have the greatest here to handle the case where # abs(neg_index) > arg_length # e.g. where the magnitude of the negative index is greater than the @@ -1015,7 +1010,7 @@ def _array_slice_op(op, **kw): start = sg.func("least", arg_length, _neg_idx_to_pos(arg, start)) if (stop := op.stop) is None: - stop = sg.expressions.Null() + stop = sg.exp.Null() else: stop = _neg_idx_to_pos(arg, translate_val(stop, **kw)) @@ -1033,7 +1028,7 @@ def _array_string_join(op, **kw): def _array_map(op, **kw): arg = translate_val(op.arg, **kw) result = translate_val(op.result, **kw) - lamduh = sg.expressions.Lambda( + lamduh = sg.exp.Lambda( this=result, expressions=[sg.to_identifier(f"{op.parameter}", quoted=False)], ) @@ -1045,9 +1040,9 @@ def _array_map(op, **kw): def _array_filter(op, **kw): arg = translate_val(op.arg, **kw) result = translate_val(op.result, **kw) - lamduh = sg.expressions.Lambda( + lamduh = sg.exp.Lambda( this=result, - expressions=[sg.expressions.Identifier(this=f"{op.parameter}", quoted=False)], + expressions=[sg.exp.Identifier(this=f"{op.parameter}", quoted=False)], ) sg_expr = sg.func("list_filter", arg, lamduh) return sg_expr @@ -1081,16 +1076,16 @@ def _array_union(op, **kw): def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: i = sg.to_identifier("i", quoted=False) args = [translate_val(arg, **kw) for arg in op.arg] - result = sg.expressions.Struct( + result = sg.exp.Struct( expressions=[ - sg.expressions.Slice( + sg.exp.Slice( this=sg_literal(name), expression=sg.func("list_extract", arg, i), ) for name, arg in zip(op.dtype.value_type.names, args) ] ) - lamduh = sg.expressions.Lambda(this=result, expressions=[i]) + lamduh = sg.exp.Lambda(this=result, expressions=[i]) sg_expr = sg.func( "list_transform", sg.func( @@ -1112,7 +1107,7 @@ def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: @translate_val.register(ops.CountDistinct) def _count_distinct(op, **kw): arg = translate_val(op.arg, **kw) - count_expr = sg.expressions.Count(this=sg.expressions.Distinct(expressions=[arg])) + count_expr = sg.exp.Count(this=sg.exp.Distinct(expressions=[arg])) return _apply_agg_filter(count_expr, where=op.where, **kw) @@ -1128,7 +1123,7 @@ def _count_distinct_star(op, **kw): @translate_val.register(ops.CountStar) def _count_star(op, **kw): - sql = sg.expressions.Count(this=sg.expressions.Star()) + sql = sg.exp.Count(this=sg.exp.Star()) return _apply_agg_filter(sql, where=op.where, **kw) @@ -1137,7 +1132,7 @@ def _sum(op, **kw): arg = translate_val( ops.Cast(arg, to=op.dtype) if (arg := op.arg).dtype.is_boolean() else arg, **kw ) - return _apply_agg_filter(sg.expressions.Sum(this=arg), where=op.where, **kw) + return _apply_agg_filter(sg.exp.Sum(this=arg), where=op.where, **kw) # TODO @@ -1260,10 +1255,10 @@ def _node_list(op, **kw): 
@translate_val.register(ops.SimpleCase) @translate_val.register(ops.SearchedCase) def _case(op, **kw): - case = sg.expressions.Case() + case = sg.exp.Case() if (base := getattr(op, "base", None)) is not None: - case = sg.expressions.Case(this=translate_val(base, **kw)) + case = sg.exp.Case(this=translate_val(base, **kw)) for when, then in zip(op.cases, op.results): case = case.when( @@ -1322,9 +1317,7 @@ def _group_concat(op, **kw): # TODO @translate_val.register(ops.ArrayColumn) def _array_column(op, **kw): - sg_expr = sg.expressions.Array.from_arg_list( - [translate_val(col, **kw) for col in op.cols] - ) + sg_expr = sg.exp.Array.from_arg_list([translate_val(col, **kw) for col in op.cols]) return sg_expr @@ -1421,7 +1414,7 @@ def _vararg_func(op, **kw): def _map(op, **kw): keys = translate_val(op.keys, **kw) values = translate_val(op.values, **kw) - sg_expr = sg.expressions.Map(keys=keys, values=values) + sg_expr = sg.exp.Map(keys=keys, values=values) return sg_expr @@ -1443,7 +1436,7 @@ def _map_get(op, **kw): def _map_contains(op, **kw): arg = translate_val(op.arg, **kw) key = translate_val(op.key, **kw) - sg_expr = sg.expressions.NEQ( + sg_expr = sg.exp.NEQ( this=sg.func( "array_length", sg.func( @@ -1452,7 +1445,7 @@ def _map_contains(op, **kw): key, ), ), - expression=sg.expressions.Literal(this="0", is_string=False), + expression=sg.exp.Literal(this="0", is_string=False), ) return sg_expr @@ -1472,14 +1465,14 @@ def _map_merge(op, **kw): return sg.func("map_concat", left, right) -def _binary_infix(sg_expr: sg.expressions._Expression): +def _binary_infix(sg_expr: sg.exp._Expression): def formatter(op, **kw): left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) return sg_expr( - this=sg.expressions.Paren(this=left), - expression=sg.expressions.Paren(this=right), + this=sg.exp.Paren(this=left), + expression=sg.exp.Paren(this=right), ) return formatter @@ -1487,28 +1480,28 @@ def formatter(op, **kw): _binary_infix_ops = { # Binary operations - ops.Add: sg.expressions.Add, - ops.Subtract: sg.expressions.Sub, - ops.Multiply: sg.expressions.Mul, - ops.Divide: sg.expressions.Div, - ops.Modulus: sg.expressions.Mod, + ops.Add: sg.exp.Add, + ops.Subtract: sg.exp.Sub, + ops.Multiply: sg.exp.Mul, + ops.Divide: sg.exp.Div, + ops.Modulus: sg.exp.Mod, # Comparisons - ops.GreaterEqual: sg.expressions.GTE, - ops.Greater: sg.expressions.GT, - ops.LessEqual: sg.expressions.LTE, - ops.Less: sg.expressions.LT, - ops.Equals: sg.expressions.EQ, - ops.NotEquals: sg.expressions.NEQ, - ops.Xor: sg.expressions.Xor, + ops.GreaterEqual: sg.exp.GTE, + ops.Greater: sg.exp.GT, + ops.LessEqual: sg.exp.LTE, + ops.Less: sg.exp.LT, + ops.Equals: sg.exp.EQ, + ops.NotEquals: sg.exp.NEQ, + ops.Xor: sg.exp.Xor, # Boolean comparisons - ops.And: sg.expressions.And, - ops.Or: sg.expressions.Or, - ops.DateAdd: sg.expressions.Add, - ops.DateSub: sg.expressions.Sub, - ops.DateDiff: sg.expressions.Sub, - ops.TimestampAdd: sg.expressions.Add, - ops.TimestampSub: sg.expressions.Sub, - ops.TimestampDiff: sg.expressions.Sub, + ops.And: sg.exp.And, + ops.Or: sg.exp.Or, + ops.DateAdd: sg.exp.Add, + ops.DateSub: sg.exp.Sub, + ops.DateDiff: sg.exp.Sub, + ops.TimestampAdd: sg.exp.Add, + ops.TimestampSub: sg.exp.Sub, + ops.TimestampDiff: sg.exp.Sub, } @@ -1539,12 +1532,10 @@ def _xor(op, **kw): # TODO: is this really the best way to do this? 
left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) - return sg.expressions.And( - this=sg.expressions.Paren(this=sg.expressions.Or(this=left, expression=right)), - expression=sg.expressions.Paren( - this=sg.expressions.Not( - this=sg.expressions.And(this=left, expression=right) - ) + return sg.exp.And( + this=sg.exp.Paren(this=sg.exp.Or(this=left, expression=right)), + expression=sg.exp.Paren( + this=sg.exp.Not(this=sg.exp.And(this=left, expression=right)) ), ) @@ -1554,7 +1545,7 @@ def _xor(op, **kw): @translate_val.register(ops.RowNumber) def _row_number(_, **kw): - return sg.expressions.RowNumber() + return sg.exp.RowNumber() @translate_val.register(ops.DenseRank) @@ -1736,7 +1727,7 @@ def formatter(op, **kw): @translate_val.register(ops.Argument) def _argument(op, **_): - return sg.expressions.Identifier(this=op.name, quoted=False) + return sg.exp.Identifier(this=op.name, quoted=False) @translate_val.register(ops.JSONGetItem) From f86aa56a9a30cecbf9bc768035ac32a8d10bb5d1 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 10:57:53 -0400 Subject: [PATCH 068/222] chore: fix inet and uuid translation --- ibis/backends/duckdb/compiler/values.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 8a1b79becccb..8357d9f33f07 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -83,11 +83,10 @@ def _literal(op, **kw): if dtype.is_null(): return sg.exp.Null() return sg.cast(sg.exp.Null(), to=DuckDBType.from_ibis(dtype)) - if dtype.is_boolean(): + elif dtype.is_boolean(): return sg.exp.Boolean(this=value) elif dtype.is_inet(): - com.UnsupportedOperationError("DuckDB doesn't support an explicit inet dtype") - return None + return sg.exp.Literal(this=str(value), is_string=True) elif dtype.is_string(): return sg_literal(value) elif dtype.is_decimal(): @@ -192,6 +191,11 @@ def _literal(op, **kw): slices = [sg.exp.Slice(this=k, expression=v) for k, v in zip(keys, values)] sg_expr = sg.exp.Struct.from_arg_list(slices) return sg_expr + elif dtype.is_uuid(): + return sg.cast( + sg.exp.Literal(this=str(value), is_string=True), + to=sg.exp.DataType.Type.UUID, + ) else: raise NotImplementedError(f"Unsupported type: {dtype!r}") From 836acd0005005b822ef31fd21f706e77b8a5d07c Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 10:58:22 -0400 Subject: [PATCH 069/222] chore: promote replacement values to expressions in fillna --- ibis/backends/duckdb/compiler/values.py | 11 ----------- ibis/expr/operations/relations.py | 4 +++- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 8357d9f33f07..dfd80f2d1711 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1751,14 +1751,3 @@ def _rowid(op, *, aliases, **_) -> str: def _scalar_udf(op, **kw) -> str: funcname = op.__class__.__name__ return sg.func(funcname, *(translate_val(arg, **kw) for arg in op.args)) - - -@translate_val.register(int) -@translate_val.register(float) -def _int_float(val, **kw): - return sg.exp.Literal(this=str(val), is_string=False) - - -@translate_val.register(str) -def _str(val, **kw): - return sg.exp.Literal(this=val, is_string=True) diff --git a/ibis/expr/operations/relations.py 
b/ibis/expr/operations/relations.py index ffc510889ed1..6f3c633a0caf 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -587,7 +587,9 @@ class FillNa(Relation): """Fill null values in the table.""" table: Relation - replacements: UnionType[Value[dt.Numeric | dt.String], FrozenDict[str, Any]] + replacements: UnionType[ + Value[dt.Numeric | dt.String], FrozenDict[str, Value[dt.Any]] + ] @attribute def schema(self): From 2661d3cb066baea2d61b6613a92d9b378bd315c4 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:14:44 -0400 Subject: [PATCH 070/222] chore: allow tuple translation in duckdb --- ibis/backends/tests/test_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index df7031ea0c0f..94fe24a9b247 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -1032,9 +1032,9 @@ def test_has_operation_no_geo(con, op): for name, obj in sorted(inspect.getmembers(builtins), key=itemgetter(0)) for backend in sorted(ALL_BACKENDS) # filter out builtins that are types, except for tuples on ClickHouse - # because tuples are used to represent lists of expressions + # and duckdb because tuples are used to represent lists of expressions if isinstance(obj, type) - if (obj != tuple or backend != "clickhouse") + if (obj != tuple or backend not in ("clickhouse", "duckdb")) if (backend != "pyspark" or vparse(pd.__version__) < vparse("2")) ], ) From b10025a2f284f11900d2a42e0bc78efeea52b933 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:15:15 -0400 Subject: [PATCH 071/222] chore: expressionize clip --- ibis/backends/duckdb/compiler/values.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index dfd80f2d1711..7103ef30106c 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -346,12 +346,11 @@ def _generic_log(op, **kw): @translate_val.register(ops.Clip) def _clip(op, **kw): arg = translate_val(op.arg, **kw) - # TODO expressionize if (upper := op.upper) is not None: - arg = f"least({translate_val(upper, **kw)}, {arg})" + arg = sg.exp.Least.from_arg_list(translate_val(upper, **kw), arg) if (lower := op.lower) is not None: - arg = f"greatest({translate_val(lower, **kw)}, {arg})" + arg = sg.exp.Greatest.from_arg_list(translate_val(lower, **kw), arg) return arg From 61a5ba62f91624710d9f15e3b92f5a67be8edfbf Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:15:29 -0400 Subject: [PATCH 072/222] chore: remove duplicate TODO --- ibis/backends/duckdb/compiler/values.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 7103ef30106c..5e377f438061 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -21,7 +21,6 @@ if TYPE_CHECKING: from collections.abc import Mapping -# TODO: Ideally we can translate bottom up a la `relations.py` # TODO: Find a way to remove all the dialect="duckdb" kwargs From f546860e1274d0019e97607d4f9927b296d063c6 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:15:51 -0400 Subject: [PATCH 073/222] 
chore: clean up literal casting --- ibis/backends/duckdb/compiler/values.py | 50 ++++++------------------- 1 file changed, 11 insertions(+), 39 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 5e377f438061..a88f6b6e30d3 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -77,11 +77,12 @@ def sg_literal(arg, is_string=True): def _literal(op, **kw): value = op.value dtype = op.dtype + sg_type = DuckDBType.from_ibis(dtype) if value is None and dtype.nullable: if dtype.is_null(): return sg.exp.Null() - return sg.cast(sg.exp.Null(), to=DuckDBType.from_ibis(dtype)) + return sg.cast(sg.exp.Null(), to=sg_type) elif dtype.is_boolean(): return sg.exp.Boolean(this=value) elif dtype.is_inet(): @@ -101,36 +102,19 @@ def _literal(op, **kw): f"Unsupported precision. Supported values: [1 : 38]. Current value: {precision!r}" ) if math.isinf(value): - return sg.exp.cast( - expression=sg_literal(value), - to=sg.exp.DataType.Type.FLOAT, - ) + return sg.exp.cast(expression=sg_literal(value), to=sg_type) elif math.isnan(value): - return sg.exp.cast( - expression=sg_literal("NaN"), - to=sg.exp.DataType.Type.FLOAT, - ) + return sg.exp.cast(expression=sg_literal("NaN"), to=sg_type) dtype = dt.Decimal(precision=precision, scale=scale, nullable=dtype.nullable) - sg_expr = sg.cast( - sg_literal(value, is_string=False), to=DuckDBType.from_ibis(dtype) - ) + sg_expr = sg.cast(sg_literal(value, is_string=False), to=sg_type) return sg_expr elif dtype.is_numeric(): if math.isinf(value): - return sg.exp.cast( - expression=sg_literal(value), - to=sg.exp.DataType.Type.FLOAT, - ) + return sg.exp.cast(expression=sg_literal(value), to=sg_type) elif math.isnan(value): - return sg.exp.cast( - expression=sg_literal("NaN"), - to=sg.exp.DataType.Type.FLOAT, - ) - return sg.cast( - sg_literal(value, is_string=False), - to=DuckDBType.from_ibis(dtype), - ) + return sg.exp.cast(expression=sg_literal("NaN"), to=sg_type) + return sg.cast(sg_literal(value, is_string=False), sg_type) elif dtype.is_interval(): return _interval_format(op) elif dtype.is_timestamp(): @@ -157,17 +141,8 @@ def _literal(op, **kw): return sg.exp.DateFromParts(year=year, month=month, day=day) elif dtype.is_array(): value_type = dtype.value_type - is_string = isinstance(value_type, dt.String) - values = sg.exp.Array().from_arg_list( - [ - # TODO: this cast makes for frustrating output - # is there any better way to handle it? 
-                sg.cast(
-                    sg_literal(v, is_string=is_string),
-                    to=DuckDBType.from_ibis(value_type),
-                )
-                for v in value
-            ]
+        values = sg.exp.Array.from_arg_list(
+            [_literal(ops.Literal(v, dtype=value_type), **kw) for v in value]
         )
         return values
     elif dtype.is_map():

From 9735605c3b6402b5e3ddbbaf51f870ae9e298c9c Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Sat, 16 Sep 2023 11:16:02 -0400
Subject: [PATCH 074/222] chore: remove already-fixed comment

---
 ibis/backends/duckdb/compiler/values.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index a88f6b6e30d3..d3c4dceec9da 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -724,8 +724,6 @@ def _interval_from_integer(op, **kw):

 @translate_val.register(ops.Substring)
 def _substring(op, **kw):
-    # TODO: fix expr_slice_begin tests
-    # Duckdb is 1-indexed
     arg = translate_val(op.arg, **kw)
     start = translate_val(op.start, **kw)
     if op.length is not None:

From 84d0411a0102b247ad6e1fb89c5ca61f7d491a1a Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Sat, 16 Sep 2023 11:18:05 -0400
Subject: [PATCH 075/222] chore: expressionize list_aggr

---
 ibis/backends/duckdb/compiler/values.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index d3c4dceec9da..28e7aa89936c 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -804,10 +804,11 @@ def _string_split(op, **kw):

 @translate_val.register(ops.StringJoin)
 def _string_join(op, **kw):
-    arg = map(partial(translate_val, **kw), op.arg)
+    elements = list(map(partial(translate_val, **kw), op.arg))
     sep = translate_val(op.sep, **kw)
-    elements = ", ".join(map(_sql, arg))
-    return f"list_aggregate([{elements}], 'string_agg', {sep})"
+    return sg.func(
+        "list_aggr", sg.exp.Array(expressions=elements), sg_literal("string_agg"), sep
+    )

From 18b3773f3264b6e158a801caffe8bcc435df3b89 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Sat, 16 Sep 2023 11:18:30 -0400
Subject: [PATCH 076/222] chore: remove TODOs for TODones

---
 ibis/backends/duckdb/compiler/values.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index 28e7aa89936c..8422f17c6a45 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -892,7 +892,6 @@ def _array_sort(op, **kw):
             false=sg.func("list_value", dialect="duckdb"),
         ),
     )
-    # TODO: this is (I think) working but tests fail because of broken NaN / None stuff
     return sg_expr


@@ -1108,7 +1107,6 @@ def _sum(op, **kw):
     return _apply_agg_filter(sg.exp.Sum(this=arg), where=op.where, **kw)


-# TODO
 @translate_val.register(ops.NthValue)
 def _nth_value(op, **kw):
     arg = translate_val(op.arg, **kw)
@@ -1287,11 +1285,9 @@ def _group_concat(op, **kw):
return _apply_agg_filter(concat, where=op.where, **kw) -# TODO @translate_val.register(ops.ArrayColumn) def _array_column(op, **kw): - sg_expr = sg.exp.Array.from_arg_list([translate_val(col, **kw) for col in op.cols]) - return sg_expr + return sg.exp.Array.from_arg_list([translate_val(col, **kw) for col in op.cols]) @translate_val.register(ops.StructColumn) From 240e943c12adc0559f0955a37b2bf8fc862da00d Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:18:44 -0400 Subject: [PATCH 077/222] chore: make xor a bit less annoying --- ibis/backends/duckdb/compiler/values.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 8422f17c6a45..83385f0f60f2 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1498,15 +1498,10 @@ def _bitor(op, **kw): @translate_val.register(ops.Xor) def _xor(op, **kw): - # TODO: is this really the best way to do this? - left = translate_val(op.left, **kw) - right = translate_val(op.right, **kw) - return sg.exp.And( - this=sg.exp.Paren(this=sg.exp.Or(this=left, expression=right)), - expression=sg.exp.Paren( - this=sg.exp.Not(this=sg.exp.And(this=left, expression=right)) - ), - ) + left = translate_val(ops.Cast(op.left, to=dt.int8), **kw) + right = translate_val(ops.Cast(op.right, to=dt.int8), **kw) + arg = sg.func("xor", left, right) + return sg.cast(arg, to=DuckDBType.from_ibis(dt.boolean)) ### Ordering From f1a1a35e492fe5700567fe59385204c8a3fda042 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:20:21 -0400 Subject: [PATCH 078/222] chore: clean up ilike --- ibis/backends/duckdb/compiler/values.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 83385f0f60f2..353cbe4afb4f 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -828,7 +828,7 @@ def _string_like(op, **kw): def _string_ilike(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) - return sg.exp.Like(this=sg.func("lower", arg), expression=pattern) + return sg.exp.ILike(this=arg, expression=pattern) @translate_val.register(ops.Capitalize) From 7922117f6a0262c63f2e3b6f080ef8504ba61be9 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:20:31 -0400 Subject: [PATCH 079/222] revert: chore: clean up literal casting This reverts commit 0c2f8b14062032782bba0dbd77e10273d5f65b92. 
--- ibis/backends/duckdb/compiler/values.py | 50 +++++++++++++++++++------ 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 353cbe4afb4f..1f9d9e797b5e 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -77,12 +77,11 @@ def sg_literal(arg, is_string=True): def _literal(op, **kw): value = op.value dtype = op.dtype - sg_type = DuckDBType.from_ibis(dtype) if value is None and dtype.nullable: if dtype.is_null(): return sg.exp.Null() - return sg.cast(sg.exp.Null(), to=sg_type) + return sg.cast(sg.exp.Null(), to=DuckDBType.from_ibis(dtype)) elif dtype.is_boolean(): return sg.exp.Boolean(this=value) elif dtype.is_inet(): @@ -102,19 +101,36 @@ def _literal(op, **kw): f"Unsupported precision. Supported values: [1 : 38]. Current value: {precision!r}" ) if math.isinf(value): - return sg.exp.cast(expression=sg_literal(value), to=sg_type) + return sg.exp.cast( + expression=sg_literal(value), + to=sg.exp.DataType.Type.FLOAT, + ) elif math.isnan(value): - return sg.exp.cast(expression=sg_literal("NaN"), to=sg_type) + return sg.exp.cast( + expression=sg_literal("NaN"), + to=sg.exp.DataType.Type.FLOAT, + ) dtype = dt.Decimal(precision=precision, scale=scale, nullable=dtype.nullable) - sg_expr = sg.cast(sg_literal(value, is_string=False), to=sg_type) + sg_expr = sg.cast( + sg_literal(value, is_string=False), to=DuckDBType.from_ibis(dtype) + ) return sg_expr elif dtype.is_numeric(): if math.isinf(value): - return sg.exp.cast(expression=sg_literal(value), to=sg_type) + return sg.exp.cast( + expression=sg_literal(value), + to=sg.exp.DataType.Type.FLOAT, + ) elif math.isnan(value): - return sg.exp.cast(expression=sg_literal("NaN"), to=sg_type) - return sg.cast(sg_literal(value, is_string=False), sg_type) + return sg.exp.cast( + expression=sg_literal("NaN"), + to=sg.exp.DataType.Type.FLOAT, + ) + return sg.cast( + sg_literal(value, is_string=False), + to=DuckDBType.from_ibis(dtype), + ) elif dtype.is_interval(): return _interval_format(op) elif dtype.is_timestamp(): @@ -141,8 +157,17 @@ def _literal(op, **kw): return sg.exp.DateFromParts(year=year, month=month, day=day) elif dtype.is_array(): value_type = dtype.value_type - values = sg.exp.Array.from_arg_list( - [_literal(ops.Literal(v, dtype=value_type), **kw) for v in value] + is_string = isinstance(value_type, dt.String) + values = sg.exp.Array().from_arg_list( + [ + # TODO: this cast makes for frustrating output + # is there any better way to handle it? 
+ sg.cast( + sg_literal(v, is_string=is_string), + to=DuckDBType.from_ibis(value_type), + ) + for v in value + ] ) return values elif dtype.is_map(): @@ -166,7 +191,10 @@ def _literal(op, **kw): sg_expr = sg.exp.Struct.from_arg_list(slices) return sg_expr elif dtype.is_uuid(): - return sg.cast(sg.exp.Literal(this=str(value), is_string=True), to=sg_type) + return sg.cast( + sg.exp.Literal(this=str(value), is_string=True), + to=sg.exp.DataType.Type.UUID, + ) else: raise NotImplementedError(f"Unsupported type: {dtype!r}") From 75e2ec46be2e5a75c5297443b2f76559b6f4a5be Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:24:46 -0400 Subject: [PATCH 080/222] chore: expressionize clip --- ibis/backends/duckdb/compiler/values.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 1f9d9e797b5e..dd28975136b8 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -346,10 +346,10 @@ def _generic_log(op, **kw): def _clip(op, **kw): arg = translate_val(op.arg, **kw) if (upper := op.upper) is not None: - arg = sg.exp.Least.from_arg_list(translate_val(upper, **kw), arg) + arg = sg.exp.Least.from_arg_list([translate_val(upper, **kw), arg]) if (lower := op.lower) is not None: - arg = sg.exp.Greatest.from_arg_list(translate_val(lower, **kw), arg) + arg = sg.exp.Greatest.from_arg_list([translate_val(lower, **kw), arg]) return arg From 2cb5f121592548cb680521b0d8b4d71293af7646 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:25:03 -0400 Subject: [PATCH 081/222] chore: somehow the dynamic slice test is passing ... 
but y tho --- ibis/backends/tests/test_generic.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index c1448a2e719c..44f215b364e4 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -1492,11 +1492,6 @@ def test_dynamic_table_slice(backend, slc, expected_count_fn): raises=HiveServer2Error, ) @pytest.mark.notyet(["pyspark"], reason="pyspark doesn't support dynamic limit/offset") -@pytest.mark.xfail_version( - duckdb=["duckdb<=0.8.1"], - raises=AssertionError, - reason="https://github.com/duckdb/duckdb/issues/8412", -) def test_dynamic_table_slice_with_computed_offset(backend): t = backend.functional_alltypes From 59bdc67527d9381043f0a06708d6a98e04c71122 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:27:30 -0400 Subject: [PATCH 082/222] test(duckdb): fix default backend test --- ibis/backends/tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 94fe24a9b247..048a958c6f30 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -903,7 +903,7 @@ def test_default_backend(): rx = """\ SELECT SUM\\((\\w+)\\.a\\) AS ".+" -FROM \\w+ AS \\1""" +FROM "\\w+" AS \\1""" assert re.match(rx, sql) is not None From 42b177d740eec1684f8196dd103dc5ba404458bc Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:33:59 -0400 Subject: [PATCH 083/222] chore: remove sqlalchemy type annotations --- ibis/backends/duckdb/__init__.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 98356b232d51..2f742d742f61 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -3,7 +3,6 @@ from __future__ import annotations import ast -import contextlib import os import warnings from pathlib import Path @@ -26,9 +25,9 @@ from ibis.backends.base.sql import BaseBackend from ibis.backends.base.sqlglot.datatypes import DuckDBType from ibis.backends.duckdb.compiler import translate +from ibis.backends.duckdb.datatypes import DuckDBPandasData from ibis.expr.operations.relations import PandasDataFrameProxy from ibis.expr.operations.udf import InputType -from ibis.backends.duckdb.datatypes import DuckDBPandasData if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence @@ -1244,7 +1243,7 @@ def fetch_from_cursor( for name, col in zip(table.column_names, table.columns) } ) - return PandasData.convert_table(df, schema) + return DuckDBPandasData.convert_table(df, schema) def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]: rows = self.raw_sql(f"DESCRIBE {query}").fetch_arrow_table() @@ -1302,9 +1301,7 @@ def _register(name, table): except duckdb.NotImplementedException: _register(name, data.to_pyarrow(schema)) - def _get_temp_view_definition( - self, name: str, definition: sa.sql.compiler.Compiled - ) -> str: + def _get_temp_view_definition(self, name: str, definition) -> str: yield f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}" def _register_udfs(self, expr: ir.Expr) -> None: @@ -1346,7 +1343,7 @@ def _compile_udf(self, udf_node: ops.ScalarUDF) -> None: def _compile_pandas_udf(self, _: ops.ScalarUDF) -> None: raise NotImplementedError("duckdb doesn't 
support pandas UDFs") - def _get_compiled_statement(self, view: sa.Table, definition: sa.sql.Selectable): + def _get_compiled_statement(self, view, definition): # TODO: remove this once duckdb supports CTAS prepared statements return super()._get_compiled_statement( view, definition, compile_kwargs={"literal_binds": True} From 9db95650478cb16f966fe46d784facabd1d98426 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:34:19 -0400 Subject: [PATCH 084/222] chore: shorten raw_sql a bit --- ibis/backends/duckdb/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 2f742d742f61..60188e3fe8d9 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import ast +import contextlib import os import warnings from pathlib import Path @@ -87,8 +88,8 @@ def current_database(self) -> str: def current_schema(self) -> str: return self.raw_sql("SELECT current_schema()") - def raw_sql(self, query: str, **kwargs: Any) -> Any: - if isinstance(query, sg.Expression): + def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: + with contextlib.suppress(AttributeError): query = query.sql(dialect="duckdb") return self.con.execute(query, **kwargs) From 80eae44ff516e21dd9abe79cf020875cc664a8f2 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:34:45 -0400 Subject: [PATCH 085/222] chore: clean up current_database and current_schema --- ibis/backends/duckdb/__init__.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 60188e3fe8d9..e4d34df31287 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -78,15 +78,11 @@ def _register_udfs(self, expr): @property def current_database(self) -> str: - return ( - self.raw_sql("PRAGMA database_size; CALL pragma_database_size();") - .arrow()["database_name"] - .to_pylist()[0] - ) + return self.raw_sql("SELECT CURRENT_DATABASE()").arrow()[0][0].as_py() @property def current_schema(self) -> str: - return self.raw_sql("SELECT current_schema()") + return self.raw_sql("SELECT CURRENT_SCHEMA()").arrow()[0][0].as_py() def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: with contextlib.suppress(AttributeError): From f34174864b77eec1f148546c7a18432ebcd8d567 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:39:14 -0400 Subject: [PATCH 086/222] chore: fix udfs --- ibis/backends/duckdb/__init__.py | 57 +++++++++++++++----------------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index e4d34df31287..58593130e4f7 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -72,10 +72,6 @@ def _define_udf_translation_rules(self, expr): # TODO: ... - def _register_udfs(self, expr): - # TODO: - ... - @property def current_database(self) -> str: return self.raw_sql("SELECT CURRENT_DATABASE()").arrow()[0][0].as_py() @@ -1302,37 +1298,36 @@ def _get_temp_view_definition(self, name: str, definition) -> str: yield f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}" def _register_udfs(self, expr: ir.Expr) -> None: - ... 
- # import ibis.expr.operations as ops + import ibis.expr.operations as ops - # with self.begin() as con: - # for udf_node in expr.op().find(ops.ScalarUDF): - # compile_func = getattr( - # self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" - # ) - # with contextlib.suppress(duckdb.InvalidInputException): - # con.connection.remove_function(udf_node.__class__.__name__) + con = self.con - # registration_func = compile_func(udf_node) - # registration_func(con) + for udf_node in expr.op().find(ops.ScalarUDF): + compile_func = getattr( + self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" + ) + with contextlib.suppress(duckdb.InvalidInputException): + con.remove_function(udf_node.__class__.__name__) + + registration_func = compile_func(udf_node) + registration_func(con) def _compile_udf(self, udf_node: ops.ScalarUDF) -> None: - ... - # func = udf_node.__func__ - # name = func.__name__ - # input_types = [DuckDBType.to_string(arg.dtype) for arg in udf_node.args] - # output_type = DuckDBType.to_string(udf_node.dtype) - - # def register_udf(con): - # return con.connection.create_function( - # name, - # func, - # input_types, - # output_type, - # type=_UDF_INPUT_TYPE_MAPPING[udf_node.__input_type__], - # ) - - # return register_udf + func = udf_node.__func__ + name = func.__name__ + input_types = [DuckDBType.to_string(arg.dtype) for arg in udf_node.args] + output_type = DuckDBType.to_string(udf_node.dtype) + + def register_udf(con): + return con.create_function( + name, + func, + input_types, + output_type, + type=_UDF_INPUT_TYPE_MAPPING[udf_node.__input_type__], + ) + + return register_udf _compile_python_udf = _compile_udf _compile_pyarrow_udf = _compile_udf From 9298c75e16278ae8c7ef880a4908b529a1313c7b Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:50:49 -0400 Subject: [PATCH 087/222] chore: remove xor mapping --- ibis/backends/duckdb/compiler/values.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index dd28975136b8..570ec691e203 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1489,7 +1489,6 @@ def formatter(op, **kw): ops.Less: sg.exp.LT, ops.Equals: sg.exp.EQ, ops.NotEquals: sg.exp.NEQ, - ops.Xor: sg.exp.Xor, # Boolean comparisons ops.And: sg.exp.And, ops.Or: sg.exp.Or, From 8f14c2999da159b3f935234aa12a2207aabe4250 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:51:04 -0400 Subject: [PATCH 088/222] revert: chore: make xor a bit less annoying This reverts commit 550c69f22650c40ecc09f8bda2cb8b4e65486d0b. --- ibis/backends/duckdb/compiler/values.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 570ec691e203..fd8d18a287ce 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1525,10 +1525,15 @@ def _bitor(op, **kw): @translate_val.register(ops.Xor) def _xor(op, **kw): - left = translate_val(ops.Cast(op.left, to=dt.int8), **kw) - right = translate_val(ops.Cast(op.right, to=dt.int8), **kw) - arg = sg.func("xor", left, right) - return sg.cast(arg, to=DuckDBType.from_ibis(dt.boolean)) + # TODO: is this really the best way to do this? 
+ left = translate_val(op.left, **kw) + right = translate_val(op.right, **kw) + return sg.exp.And( + this=sg.exp.Paren(this=sg.exp.Or(this=left, expression=right)), + expression=sg.exp.Paren( + this=sg.exp.Not(this=sg.exp.And(this=left, expression=right)) + ), + ) ### Ordering From f24bcb00424a9e348b7ff59af7a7891766b65b26 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:51:43 -0400 Subject: [PATCH 089/222] chore: comment about upstream issue --- ibis/backends/duckdb/compiler/values.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index fd8d18a287ce..22185ce001a6 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1525,7 +1525,7 @@ def _bitor(op, **kw): @translate_val.register(ops.Xor) def _xor(op, **kw): - # TODO: is this really the best way to do this? + # https://github.com/tobymao/sqlglot/issues/2238 left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) return sg.exp.And( From 08c84ae6de7f2127dbcfd95d4d0c6c6485919046 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 12:23:20 -0400 Subject: [PATCH 090/222] chore: implement and test `con.insert` --- ibis/backends/duckdb/__init__.py | 56 ++++++++++++++++------- ibis/backends/duckdb/tests/test_client.py | 28 ++++++++++++ 2 files changed, 68 insertions(+), 16 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 58593130e4f7..ca81980dfa39 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any import duckdb +import pandas as pd import pyarrow as pa import sqlglot as sg import toolz @@ -33,7 +34,6 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence - import pandas as pd import torch from ibis.common.typing import SupportsSchema @@ -1341,22 +1341,46 @@ def _get_compiled_statement(self, view, definition): view, definition, compile_kwargs={"literal_binds": True} ) - def _insert_dataframe( - self, table_name: str, df: pd.DataFrame, overwrite: bool + def insert( + self, + table_name: str, + obj: pd.DataFrame | ir.Table | list | dict, + database: str | None = None, + overwrite: bool = False, ) -> None: - # TODO: reimplement - ... - # columns = list(df.columns) - # t = sa.table(table_name, *map(sa.column, columns)) + """Insert data into a table. + + Parameters + ---------- + table_name + The name of the table into which data will be inserted + obj + The source data or expression to insert + database + Name of the attached database that the table is located in.
+ overwrite + If `True` then replace existing contents of table + + Raises + ------ + NotImplementedError + If inserting data from a different database + ValueError + If the type of `obj` isn't supported + """ + con = self.con - # table_name = self._quote(table_name) + table = sg.table(table_name, db=database) - # # the table name df here matters, and *must* match the input variable's - # # name because duckdb will look up this name in the outer scope of the - # # insert call and pull in that variable's data to scan - # source = sa.table("df", *map(sa.column, columns)) + if overwrite: + con.execute(f"TRUNCATE TABLE {table.sql('duckdb')}") - # with self.begin() as con: - # if overwrite: - # con.execute(t.delete()) - # con.execute(t.insert().from_select(columns, sa.select(source))) + if isinstance(obj, ir.Table): + query = sg.exp.insert( + expression=self.compile(obj), into=table, dialect="duckdb" + ) + con.execute(query.sql("duckdb")) + elif isinstance(obj, pd.DataFrame): + con.append(table_name, obj) + else: + con.append(table_name, pd.DataFrame(obj)) diff --git a/ibis/backends/duckdb/tests/test_client.py b/ibis/backends/duckdb/tests/test_client.py index 690612915d4a..39cd09e7f8d4 100644 --- a/ibis/backends/duckdb/tests/test_client.py +++ b/ibis/backends/duckdb/tests/test_client.py @@ -55,3 +55,31 @@ def test_load_extension(ext_directory): """ ).fetchall() assert all(loaded for (loaded,) in results) + + +def test_insert(con): + import pandas as pd + + name = ibis.util.guid() + + t = con.create_table(name, schema=ibis.schema({"a": "int64"})) + con.insert(name, obj=pd.DataFrame({"a": [1, 2]})) + assert t.count().execute() == 2 + + con.insert(name, obj=pd.DataFrame({"a": [1, 2]})) + assert t.count().execute() == 4 + + con.insert(name, obj=pd.DataFrame({"a": [1, 2]}), overwrite=True) + assert t.count().execute() == 2 + + con.insert(name, t) + assert t.count().execute() == 4 + + con.insert(name, [{"a": 1}, {"a": 2}], overwrite=True) + assert t.count().execute() == 2 + + con.insert(name, [(1,), (2,)]) + assert t.count().execute() == 4 + + con.insert(name, {"a": [1, 2]}, overwrite=True) + assert t.count().execute() == 2 From 32c26d5913d158349fb9213f1aa0bbb9816d4067 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 16 Sep 2023 12:54:49 -0400 Subject: [PATCH 091/222] chore: fix all contains except InColumn --- ibis/backends/duckdb/compiler/values.py | 61 ++++--------------------- 1 file changed, 10 insertions(+), 51 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 22185ce001a6..34742cc2d863 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -4,7 +4,7 @@ import functools import math from functools import partial -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any import sqlglot as sg from toolz import flip @@ -14,8 +14,6 @@ import ibis.expr.analysis as an import ibis.expr.datatypes as dt import ibis.expr.operations as ops -import ibis.expr.rules as rlz -from ibis.backends.base.sql.registry import helpers from ibis.backends.base.sqlglot.datatypes import DuckDBType if TYPE_CHECKING: @@ -934,11 +932,13 @@ def _array_index_op(op, **kw): @translate_val.register(ops.InValues) def _in_values(op, **kw): if not op.options: - return False + return sg.exp.FALSE + value = translate_val(op.value, **kw) - options = sg.exp.Array().from_arg_list([translate_val(x, **kw) for x in op.options]) - sg_expr = 
sg.func("list_contains", options, value, dialect="duckdb") - return sg_expr + return sg.exp.In( + this=value, + expressions=[translate_val(opt, **kw) for opt in op.options], + ) @translate_val.register(ops.InColumn) @@ -1236,14 +1236,10 @@ def _arbitrary(op, **kw): @translate_val.register(ops.FindInSet) -def _index_of(op, **kw): +def _index_of(op: ops.FindInSet, **kw): needle = translate_val(op.needle, **kw) - return ( - sg.func( - "list_indexof", list(map(partial(translate_val, **kw), op.values)), needle - ) - - 1 - ) + args = sg.exp.Array(expressions=list(map(partial(translate_val, **kw), op.values))) + return sg.func("list_indexof", args, needle) - 1 @translate_val.register(tuple) @@ -1352,43 +1348,6 @@ def _scalar_param(op, params: Mapping[ops.Node, Any], **kw): return translate_val(literal.op(), **kw) -# TODO -def contains(op_string: Literal["IN", "NOT IN"]) -> str: - def tr(op, *, cache, **kw): - from ibis.backends.duckdb.compiler import translate - - value = op.value - options = op.options - if isinstance(options, tuple) and not options: - return {"NOT IN": "TRUE", "IN": "FALSE"}[op_string] - - left_arg = translate_val(value, **kw) - if helpers.needs_parens(value): - left_arg = helpers.parenthesize(left_arg) - - # special case non-foreign isin/notin expressions - if ( - not isinstance(options, tuple) - and options.output_shape is rlz.Shape.COLUMNAR - ): - # this will fail to execute if there's a correlation, but it's too - # annoying to detect so we let it through to enable the - # uncorrelated use case (pandas-style `.isin`) - subquery = translate(options.to_expr().as_table().op(), {}) - right_arg = f"({_sql(subquery)})" - else: - right_arg = _sql(translate_val(options, cache=cache, **kw)) - - # we explicitly do NOT parenthesize the right side because it doesn't - # make sense to do so for Sequence operations - return f"{left_arg} {op_string} {right_arg}" - - return tr - - -# TODO -# translate_val.register(ops.Contains)(contains("IN")) -# translate_val.register(ops.NotContains)(contains("NOT IN")) @translate_val.register(ops.IdenticalTo) def _identical_to(op, **kw): left = translate_val(op.left, **kw) From 8cc2c68d9b4556b0a9285de596e09ffeba2d0d3e Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 05:03:21 -0400 Subject: [PATCH 092/222] chore: clean up `_JOIN_TYPES` --- ibis/backends/duckdb/compiler/relations.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index 86f064d657a0..430b112e7e7f 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -95,15 +95,13 @@ def _aggregation(op: ops.Aggregation, *, table, **kw): _JOIN_TYPES = { ops.InnerJoin: "INNER", - ops.AnyInnerJoin: "ANY", - ops.LeftJoin: "LEFT OUTER", - ops.AnyLeftJoin: "LEFT ANY", - ops.RightJoin: "RIGHT OUTER", - ops.OuterJoin: "FULL OUTER", + ops.LeftJoin: "LEFT", + ops.RightJoin: "RIGHT", + ops.OuterJoin: "FULL", ops.CrossJoin: "CROSS", - ops.LeftSemiJoin: "LEFT SEMI", - ops.LeftAntiJoin: "LEFT ANTI", - ops.AsOfJoin: "LEFT ASOF", + ops.LeftSemiJoin: "SEMI", + ops.LeftAntiJoin: "ANTI", + ops.AsOfJoin: "ASOF", } From c1bd46f3d437d44b0d5b62dc49f85424c00e67b8 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 05:09:49 -0400 Subject: [PATCH 093/222] style: imports --- ibis/backends/duckdb/datatypes.py | 1 + 1 file changed, 1 insertion(+) 
diff --git a/ibis/backends/duckdb/datatypes.py b/ibis/backends/duckdb/datatypes.py index 9ed770ea26aa..c45aac11780b 100644 --- a/ibis/backends/duckdb/datatypes.py +++ b/ibis/backends/duckdb/datatypes.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np + from ibis.formats.pandas import PandasData From 2ad80f2820b843ce54c5cbe9af2116e9dd8d36b5 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 05:12:34 -0400 Subject: [PATCH 094/222] chore: eliminate dialect argument in duckdb sqlglot compiler --- ibis/backends/duckdb/compiler/relations.py | 38 ++++++--------- ibis/backends/duckdb/compiler/values.py | 55 +++++++++------------- 2 files changed, 37 insertions(+), 56 deletions(-) diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index 430b112e7e7f..723752be3df7 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -20,7 +20,7 @@ def translate_rel(op: ops.TableNode, **_): @translate_rel.register(ops.DummyTable) def _dummy(op: ops.DummyTable, **kw): - return sg.select(*map(partial(translate_val, **kw), op.values), dialect="duckdb") + return sg.select(*map(partial(translate_val, **kw), op.values)) @translate_rel.register(ops.PhysicalTable) @@ -43,9 +43,7 @@ def _selection(op: ops.Selection, *, table, needs_alias=False, **kw): from_ = join = None tr_val = partial(translate_val, needs_alias=needs_alias, **kw) selections = tuple(map(tr_val, op.selections)) or "*" - sel = sg.select(*selections, dialect="duckdb").from_( - from_ if from_ is not None else table, dialect="duckdb" - ) + sel = sg.select(*selections).from_(from_ if from_ is not None else table) if join is not None: sel = sel.join(join) @@ -55,15 +53,12 @@ def _selection(op: ops.Selection, *, table, needs_alias=False, **kw): sel = sg.select("*").from_(sel.subquery(kw["aliases"][op.table])) res = functools.reduce( lambda left, right: left.and_(right), - ( - sg.condition(tr_val(predicate), dialect="duckdb") - for predicate in predicates - ), + (sg.condition(tr_val(predicate)) for predicate in predicates), ) - sel = sel.where(res, dialect="duckdb") + sel = sel.where(res) if sort_keys := op.sort_keys: - sel = sel.order_by(*map(tr_val, sort_keys), dialect="duckdb") + sel = sel.order_by(*map(tr_val, sort_keys)) return sel @@ -79,16 +74,16 @@ def _aggregation(op: ops.Aggregation, *, table, **kw): sel = sg.select(*selections).from_(table) if group_keys := op.by: - sel = sel.group_by(*map(tr_val_no_alias, group_keys), dialect="duckdb") + sel = sel.group_by(*map(tr_val_no_alias, group_keys)) if predicates := op.predicates: - sel = sel.where(*map(tr_val_no_alias, predicates), dialect="duckdb") + sel = sel.where(*map(tr_val_no_alias, predicates)) if having := op.having: - sel = sel.having(*map(tr_val_no_alias, having), dialect="duckdb") + sel = sel.having(*map(tr_val_no_alias, having)) if sort_keys := op.sort_keys: - sel = sel.order_by(*map(tr_val_no_alias, sort_keys), dialect="duckdb") + sel = sel.order_by(*map(tr_val_no_alias, sort_keys)) return sel @@ -111,16 +106,13 @@ def _join(op: ops.Join, *, left, right, **kw): if predicates: on = functools.reduce( lambda left, right: left.and_(right), - ( - sg.condition(translate_val(predicate, **kw), dialect="duckdb") - for predicate in predicates - ), + (sg.condition(translate_val(predicate, **kw)) for predicate in predicates), ) else: on = None join_type = _JOIN_TYPES[type(op)] try: - return left.join(right, join_type=join_type, 
on=on, dialect="duckdb") + return left.join(right, join_type=join_type, on=on) except AttributeError: select_args = [f"{left.alias_or_name}.*"] @@ -130,9 +122,7 @@ def _join(op: ops.Join, *, left, right, **kw): if not isinstance(op, (ops.LeftSemiJoin, ops.LeftAntiJoin)): select_args.append(f"{right.alias_or_name}.*") return ( - sg.select(*select_args, dialect="duckdb") - .from_(left, dialect="duckdb") - .join(right, join_type=join_type, on=on, dialect="duckdb") + sg.select(*select_args).from_(left).join(right, join_type=join_type, on=on) ) @@ -230,9 +220,9 @@ def _dropna(op: ops.DropNa, *, table, **kw): tr_val = partial(translate_val, **kw) predicate = tr_val(raw_predicate) try: - return table.where(predicate, dialect="duckdb") + return table.where(predicate) except AttributeError: - return sg.select("*").from_(table).where(predicate, dialect="duckdb") + return sg.select("*").from_(table).where(predicate) @translate_rel.register diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 34742cc2d863..11c43561566a 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -19,8 +19,6 @@ if TYPE_CHECKING: from collections.abc import Mapping -# TODO: Find a way to remove all the dialect="duckdb" kwargs - @functools.singledispatch def translate_val(op, **_): @@ -53,16 +51,16 @@ def _column(op, *, aliases, **_): def _alias(op, render_aliases: bool = True, **kw): val = translate_val(op.arg, render_aliases=render_aliases, **kw) if render_aliases: - return sg.alias(val, op.name, dialect="duckdb") + return sg.alias(val, op.name) return val ### Literals -def _sql(obj, dialect="duckdb"): +def _sql(obj): try: - return obj.sql(dialect=dialect) + return obj.sql(dialect="duckdb") except AttributeError: return obj @@ -285,9 +283,7 @@ def formatter(op, **kw): @translate_val.register(_op) def _fmt(op, _name: str = _name, **kw): - return sg.func( - _name, *map(partial(translate_val, **kw), op.args), dialect="duckdb" - ) + return sg.func(_name, *map(partial(translate_val, **kw), op.args)) del _fmt, _name, _op @@ -416,7 +412,7 @@ def _try_cast(op, **kw): @translate_val.register(ops.TypeOf) def _type_of(op, **kw): arg = translate_val(op.arg, **kw) - return sg.func("typeof", arg, dialect="duckdb") + return sg.func("typeof", arg) ### Comparator Conundrums @@ -456,7 +452,7 @@ def _aggregate(op, func, *, where, **kw): for argname, arg in zip(op.argnames, op.args) if argname not in ("where", "how") ] - agg = sg.func(func, *args, dialect="duckdb") + agg = sg.func(func, *args) return _apply_agg_filter(agg, where=op.where, **kw) @@ -742,7 +738,7 @@ def _interval_from_integer(op, **kw): arg = translate_val(op.arg, **kw) if op.dtype.resolution == "week": return sg.func("to_days", arg * 7) - return sg.func(f"to_{op.dtype.resolution}s", arg, dialect="duckdb") + return sg.func(f"to_{op.dtype.resolution}s", arg) ### String Instruments @@ -811,7 +807,7 @@ def _regex_extract(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) group = translate_val(op.index, **kw) - return sg.func("regexp_extract", arg, pattern, group, dialect="duckdb") + return sg.func("regexp_extract", arg, pattern, group) @translate_val.register(ops.Levenshtein) @@ -885,19 +881,19 @@ def _is_not_null(op, **kw): def _if_null(op, **kw): arg = translate_val(op.arg, **kw) ifnull = translate_val(op.ifnull_expr, **kw) - return sg.func("ifnull", arg, ifnull, dialect="duckdb") + return sg.func("ifnull", arg, ifnull) 
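# A minimal sketch (assuming only sqlglot's public builder API) of why the
# dialect="duckdb" arguments removed in this patch are redundant: the tree
# that sg.func builds is dialect-agnostic, and the target dialect only
# matters when the tree is rendered at the end of compilation, e.g.
#
#     import sqlglot as sg
#     expr = sg.func("ifnull", sg.column("x"), sg.exp.Literal.number(0))
#     expr.sql(dialect="duckdb")  # the same tree can render for any dialect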
@translate_val.register(ops.NullIfZero) def _null_if_zero(op, **kw): arg = translate_val(op.arg, **kw) - return sg.func("nullif", arg, 0, dialect="duckdb") + return sg.func("nullif", arg, 0) @translate_val.register(ops.ZeroIfNull) def _zero_if_null(op, **kw): arg = translate_val(op.arg, **kw) - return sg.func("ifnull", arg, 0, dialect="duckdb") + return sg.func("ifnull", arg, 0) ### Definitely Not Tensors @@ -910,12 +906,11 @@ def _array_sort(op, **kw): sg_expr = sg.exp.If( this=arg.is_(sg.exp.Null()), true=sg.exp.Null(), - false=sg.func("list_distinct", arg, dialect="duckdb") + false=sg.func("list_distinct", arg) + sg.exp.If( - this=sg.func("list_count", arg, dialect="duckdb") - < sg.func("array_length", arg, dialect="duckdb"), - true=sg.func("list_value", sg.exp.Null(), dialect="duckdb"), - false=sg.func("list_value", dialect="duckdb"), + this=sg.func("list_count", arg) < sg.func("array_length", arg), + true=sg.func("list_value", sg.exp.Null()), + false=sg.func("list_value"), ), ) return sg_expr @@ -953,7 +948,7 @@ def _in_column(op, **kw): @translate_val.register(ops.ArrayCollect) def _array_collect(op, **kw): - agg = sg.func("list", translate_val(op.arg, **kw), dialect="duckdb") + agg = sg.func("list", translate_val(op.arg, **kw)) return _apply_agg_filter(agg, where=op.where, **kw) @@ -1157,7 +1152,7 @@ def _repeat(op, **kw): def _quantile(op, **kw): arg = translate_val(op.arg, **kw) quantile = translate_val(op.quantile, **kw) - sg_expr = sg.func("quantile_cont", arg, quantile, dialect="duckdb") + sg_expr = sg.func("quantile_cont", arg, quantile) return _apply_agg_filter(sg_expr, where=op.where, **kw) @@ -1291,11 +1286,7 @@ def _exists_subquery(op, **kw): kw["table"] = translate_rel(op.foreign_table.table, **kw) foreign_table = translate_rel(op.foreign_table, **kw) predicates = translate_val(op.predicates, **kw) - subq = ( - sg.select(1) - .from_(foreign_table, dialect="duckdb") - .where(sg.condition(predicates), dialect="duckdb") - ) + subq = sg.select(1).from_(foreign_table).where(sg.condition(predicates)) prefix = "NOT " * isinstance(op, ops.NotExistsSubquery) return f"{prefix}EXISTS ({subq})" @@ -1305,7 +1296,7 @@ def _group_concat(op, **kw): arg = translate_val(op.arg, **kw) sep = translate_val(op.sep, **kw) - concat = sg.func("string_agg", arg, sep, dialect="duckdb") + concat = sg.func("string_agg", arg, sep) return _apply_agg_filter(concat, where=op.where, **kw) @@ -1505,22 +1496,22 @@ def _row_number(_, **kw): @translate_val.register(ops.DenseRank) def _dense_rank(_, **kw): - return sg.func("dense_rank", dialect="duckdb") + return sg.func("dense_rank") @translate_val.register(ops.MinRank) def _rank(_, **kw): - return sg.func("rank", dialect="duckdb") + return sg.func("rank") @translate_val.register(ops.PercentRank) def _percent_rank(_, **kw): - return sg.func("percent_rank", dialect="duckdb") + return sg.func("percent_rank") @translate_val.register(ops.CumeDist) def _cume_dist(_, **kw): - return sg.func("percent_rank", dialect="duckdb") + return sg.func("percent_rank") @translate_val.register From 76bbe34d90d0e61e7ddb823cb2b45f045827e44f Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 05:28:11 -0400 Subject: [PATCH 095/222] chore: lift DuckDBType.from_ibis out and fix binary literal test --- ibis/backends/duckdb/compiler/values.py | 26 ++++++++++--------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 
11c43561566a..e2d61438c72f 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -73,11 +73,12 @@ def sg_literal(arg, is_string=True): def _literal(op, **kw): value = op.value dtype = op.dtype + sg_type = DuckDBType.from_ibis(dtype) if value is None and dtype.nullable: if dtype.is_null(): return sg.exp.Null() - return sg.cast(sg.exp.Null(), to=DuckDBType.from_ibis(dtype)) + return sg.cast(sg.exp.Null(), to=sg_type) elif dtype.is_boolean(): return sg.exp.Boolean(this=value) elif dtype.is_inet(): @@ -108,9 +109,7 @@ def _literal(op, **kw): ) dtype = dt.Decimal(precision=precision, scale=scale, nullable=dtype.nullable) - sg_expr = sg.cast( - sg_literal(value, is_string=False), to=DuckDBType.from_ibis(dtype) - ) + sg_expr = sg.cast(sg_literal(value, is_string=False), to=sg_type) return sg_expr elif dtype.is_numeric(): if math.isinf(value): @@ -123,10 +122,7 @@ def _literal(op, **kw): expression=sg_literal("NaN"), to=sg.exp.DataType.Type.FLOAT, ) - return sg.cast( - sg_literal(value, is_string=False), - to=DuckDBType.from_ibis(dtype), - ) + return sg.cast(sg_literal(value, is_string=False), to=sg_type) elif dtype.is_interval(): return _interval_format(op) elif dtype.is_timestamp(): @@ -158,10 +154,7 @@ def _literal(op, **kw): [ # TODO: this cast makes for frustrating output # is there any better way to handle it? - sg.cast( - sg_literal(v, is_string=is_string), - to=DuckDBType.from_ibis(value_type), - ) + sg.cast(sg_literal(v, is_string=is_string), to=sg_type) for v in value ] ) @@ -187,10 +180,11 @@ def _literal(op, **kw): sg_expr = sg.exp.Struct.from_arg_list(slices) return sg_expr elif dtype.is_uuid(): - return sg.cast( - sg.exp.Literal(this=str(value), is_string=True), - to=sg.exp.DataType.Type.UUID, - ) + return sg.cast(sg_literal(value, is_string=True), to=sg_type) + elif dtype.is_binary(): + bytestring = "".join(map("\\x{:02x}".format, value)) + lit = sg_literal(bytestring) + return sg.cast(lit, to=sg_type) else: raise NotImplementedError(f"Unsupported type: {dtype!r}") From 73f276bfdbae20d0cf705a092f965f372d7d0f1e Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 05:30:10 -0400 Subject: [PATCH 096/222] chore: fix macaddr literal translation --- ibis/backends/duckdb/compiler/values.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index e2d61438c72f..24b80b894a93 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -81,7 +81,7 @@ def _literal(op, **kw): return sg.cast(sg.exp.Null(), to=sg_type) elif dtype.is_boolean(): return sg.exp.Boolean(this=value) - elif dtype.is_inet(): + elif dtype.is_inet() or dtype.is_macaddr(): return sg.exp.Literal(this=str(value), is_string=True) elif dtype.is_string(): return sg_literal(value) From f19d13dca313e38717b0926aeeeb6e3c7bdd723b Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 05:38:17 -0400 Subject: [PATCH 097/222] chore: shorten literal translation code --- ibis/backends/duckdb/compiler/values.py | 98 +++++++++---------------- 1 file changed, 35 insertions(+), 63 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 24b80b894a93..0fe6a1ffd23f 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -76,14 +76,11 @@ def 
_literal(op, **kw): sg_type = DuckDBType.from_ibis(dtype) if value is None and dtype.nullable: - if dtype.is_null(): - return sg.exp.Null() - return sg.cast(sg.exp.Null(), to=sg_type) + null = sg.exp.Null() + return null if dtype.is_null() else sg.cast(null, to=sg_type) elif dtype.is_boolean(): return sg.exp.Boolean(this=value) - elif dtype.is_inet() or dtype.is_macaddr(): - return sg.exp.Literal(this=str(value), is_string=True) - elif dtype.is_string(): + elif dtype.is_string() or dtype.is_inet() or dtype.is_macaddr(): return sg_literal(value) elif dtype.is_decimal(): # TODO: make this a sqlglot expression @@ -98,93 +95,68 @@ def _literal(op, **kw): f"Unsupported precision. Supported values: [1 : 38]. Current value: {precision!r}" ) if math.isinf(value): - return sg.exp.cast( - expression=sg_literal(value), - to=sg.exp.DataType.Type.FLOAT, - ) + return sg.cast(expression=sg_literal(value), to=sg.exp.DataType.Type.FLOAT) elif math.isnan(value): - return sg.exp.cast( - expression=sg_literal("NaN"), - to=sg.exp.DataType.Type.FLOAT, - ) - - dtype = dt.Decimal(precision=precision, scale=scale, nullable=dtype.nullable) - sg_expr = sg.cast(sg_literal(value, is_string=False), to=sg_type) - return sg_expr + return sg.cast(expression=sg_literal("NaN"), to=sg.exp.DataType.Type.FLOAT) + return sg.cast(sg_literal(value, is_string=False), to=sg_type) elif dtype.is_numeric(): if math.isinf(value): - return sg.exp.cast( - expression=sg_literal(value), - to=sg.exp.DataType.Type.FLOAT, - ) + return sg.cast(expression=sg_literal(value), to=sg.exp.DataType.Type.FLOAT) elif math.isnan(value): - return sg.exp.cast( - expression=sg_literal("NaN"), - to=sg.exp.DataType.Type.FLOAT, - ) + return sg.cast(expression=sg_literal("NaN"), to=sg.exp.DataType.Type.FLOAT) return sg.cast(sg_literal(value, is_string=False), to=sg_type) elif dtype.is_interval(): return _interval_format(op) elif dtype.is_timestamp(): - year = sg_literal(op.value.year, is_string=False) - month = sg_literal(op.value.month, is_string=False) - day = sg_literal(op.value.day, is_string=False) - hour = sg_literal(op.value.hour, is_string=False) - minute = sg_literal(op.value.minute, is_string=False) - second = sg_literal(op.value.second, is_string=False) - if op.value.microsecond: - microsecond = sg_literal(op.value.microsecond / 1e6, is_string=False) + year = sg_literal(value.year, is_string=False) + month = sg_literal(value.month, is_string=False) + day = sg_literal(value.day, is_string=False) + hour = sg_literal(value.hour, is_string=False) + minute = sg_literal(value.minute, is_string=False) + second = sg_literal(value.second, is_string=False) + if us := value.microsecond: + microsecond = sg_literal(us / 1e6, is_string=False) second += microsecond - if dtype.timezone is not None: - timezone = sg_literal(dtype.timezone, is_string=True) + if (tz := dtype.timezone) is not None: + timezone = sg_literal(tz, is_string=True) return sg.func( "make_timestamptz", year, month, day, hour, minute, second, timezone ) else: return sg.func("make_timestamp", year, month, day, hour, minute, second) elif dtype.is_date(): - year = sg_literal(op.value.year, is_string=False) - month = sg_literal(op.value.month, is_string=False) - day = sg_literal(op.value.day, is_string=False) + year = sg_literal(value.year, is_string=False) + month = sg_literal(value.month, is_string=False) + day = sg_literal(value.day, is_string=False) return sg.exp.DateFromParts(year=year, month=month, day=day) elif dtype.is_array(): - value_type = dtype.value_type - is_string = 
isinstance(value_type, dt.String) - values = sg.exp.Array().from_arg_list( - [ - # TODO: this cast makes for frustrating output - # is there any better way to handle it? - sg.cast(sg_literal(v, is_string=is_string), to=sg_type) - for v in value - ] + is_string = dtype.value_type.is_string() + return sg.exp.Array.from_arg_list( + [sg.cast(sg_literal(v, is_string=is_string), to=sg_type) for v in value] ) - return values elif dtype.is_map(): key_type = dtype.key_type value_type = dtype.value_type - keys = sg.exp.Array().from_arg_list( + keys = sg.exp.Array.from_arg_list( [_literal(ops.Literal(k, dtype=key_type), **kw) for k in value.keys()] ) - values = sg.exp.Array().from_arg_list( + values = sg.exp.Array.from_arg_list( [_literal(ops.Literal(v, dtype=value_type), **kw) for v in value.values()] ) - sg_expr = sg.exp.Map(keys=keys, values=values) - return sg_expr + return sg.exp.Map(keys=keys, values=values) elif dtype.is_struct(): - keys = [sg_literal(key) for key in value.keys()] + keys = list(map(sg_literal, value.keys())) values = [ - _literal(ops.Literal(v, dtype=subdtype), **kw) - for subdtype, v in zip(dtype.types, value.values()) + _literal(ops.Literal(v, dtype=field_dtype), **kw) + for field_dtype, v in zip(dtype.types, value.values()) ] - slices = [sg.exp.Slice(this=k, expression=v) for k, v in zip(keys, values)] - sg_expr = sg.exp.Struct.from_arg_list(slices) - return sg_expr + return sg.exp.Struct.from_arg_list( + [sg.exp.Slice(this=k, expression=v) for k, v in zip(keys, values)] + ) elif dtype.is_uuid(): - return sg.cast(sg_literal(value, is_string=True), to=sg_type) + return sg.cast(sg_literal(value), to=sg_type) elif dtype.is_binary(): - bytestring = "".join(map("\\x{:02x}".format, value)) - lit = sg_literal(bytestring) - return sg.cast(lit, to=sg_type) + return sg.cast(sg_literal("".join(map("\\x{:02x}".format, value))), to=sg_type) else: raise NotImplementedError(f"Unsupported type: {dtype!r}") From c1edafd13a0485be539b249b48de95d3b18f658a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 05:49:54 -0400 Subject: [PATCH 098/222] chore: fix literal array translation of value type --- ibis/backends/duckdb/compiler/values.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 0fe6a1ffd23f..d168721764b1 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -130,9 +130,9 @@ def _literal(op, **kw): day = sg_literal(value.day, is_string=False) return sg.exp.DateFromParts(year=year, month=month, day=day) elif dtype.is_array(): - is_string = dtype.value_type.is_string() + value_type = dtype.value_type return sg.exp.Array.from_arg_list( - [sg.cast(sg_literal(v, is_string=is_string), to=sg_type) for v in value] + [_literal(ops.Literal(v, dtype=value_type), **kw) for v in value] ) elif dtype.is_map(): key_type = dtype.key_type From dcef166252cf38c0258ca4c9593743e1e41bf65a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 05:57:42 -0400 Subject: [PATCH 099/222] chore: fix approx median --- ibis/backends/duckdb/compiler/values.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index d168721764b1..399881261cec 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1582,7 
+1582,9 @@ def format_window_frame(func, frame, **kw): @translate_val.register(ops.ApproxMedian) def _approx_median(op, **kw): - expr = sg.func("approx_quantile", "0.5", translate_val(op.arg)) + expr = sg.func( + "approx_quantile", translate_val(op.arg, **kw), sg_literal(0.5, is_string=False) + ) return _apply_agg_filter(expr, where=op.where, **kw) From e0f50deefb48cb9ea15c766e817087eee3639857 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 06:12:33 -0400 Subject: [PATCH 100/222] chore: unify decimal and non-decimal literal translation --- ibis/backends/duckdb/compiler/values.py | 30 ++++++++----------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 399881261cec..b264ed526f08 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -82,28 +82,16 @@ def _literal(op, **kw): return sg.exp.Boolean(this=value) elif dtype.is_string() or dtype.is_inet() or dtype.is_macaddr(): return sg_literal(value) - elif dtype.is_decimal(): - # TODO: make this a sqlglot expression - precision = dtype.precision - scale = dtype.scale - if precision is None: - precision = 38 - if scale is None: - scale = 9 - if not 1 <= precision <= 38: - raise NotImplementedError( - f"Unsupported precision. Supported values: [1 : 38]. Current value: {precision!r}" - ) - if math.isinf(value): - return sg.cast(expression=sg_literal(value), to=sg.exp.DataType.Type.FLOAT) - elif math.isnan(value): - return sg.cast(expression=sg_literal("NaN"), to=sg.exp.DataType.Type.FLOAT) - return sg.cast(sg_literal(value, is_string=False), to=sg_type) elif dtype.is_numeric(): - if math.isinf(value): - return sg.cast(expression=sg_literal(value), to=sg.exp.DataType.Type.FLOAT) - elif math.isnan(value): - return sg.cast(expression=sg_literal("NaN"), to=sg.exp.DataType.Type.FLOAT) + # cast non finite values to float because that's the behavior of + # duckdb when a mixed decimal/float operation is performed + # + # float will be upcast to double if necessary by duckdb + if not math.isfinite(value): + return sg.cast( + sg_literal(value), + to=sg.exp.DataType.Type.FLOAT if dtype.is_decimal() else sg_type, + ) return sg.cast(sg_literal(value, is_string=False), to=sg_type) elif dtype.is_interval(): return _interval_format(op) From 5e9544548bd1087f843756592fc89af3fa82a826 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 06:16:38 -0400 Subject: [PATCH 101/222] chore: fix isin/notin column --- ibis/backends/duckdb/compiler/values.py | 7 +++---- .../snapshots/test_sql/test_isin_bug/duckdb/out.sql | 12 ++++++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index b264ed526f08..c112827987e4 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -892,11 +892,10 @@ def _in_values(op, **kw): @translate_val.register(ops.InColumn) def _in_column(op, **kw): + from ibis.backends.duckdb.compiler import translate + value = translate_val(op.value, **kw) - options = translate_val(ops.TableArrayView(op.options.to_expr().as_table()), **kw) - # TODO: fix? 
- # if not isinstance(options, sa.sql.Selectable): - # options = sg.select(options) + options = translate(op.options.to_expr().as_table().op(), {}) return value.isin(options) diff --git a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/duckdb/out.sql b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/duckdb/out.sql index 218ccb1d5c46..3c8533744eab 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/duckdb/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/duckdb/out.sql @@ -4,10 +4,14 @@ SELECT t1.x FROM ( SELECT - t0.x AS x - FROM t AS t0 + * + FROM "t" AS t0 WHERE - t0.x > CAST(2 AS TINYINT) + ( + t0.x + ) > ( + CAST(2 AS TINYINT) + ) ) AS t1 ) AS "InColumn(x, x)" -FROM t AS t0 \ No newline at end of file +FROM "t" AS t0 \ No newline at end of file From 4d0e68d73b7d1706a2eb4333bf208ff289aaf6ff Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 06:19:17 -0400 Subject: [PATCH 102/222] chore: fix regexp_extract; dialect="duckdb" is required to indicate support for the `position` argument --- ibis/backends/duckdb/compiler/values.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index c112827987e4..87bdfb331179 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -761,7 +761,7 @@ def _regex_extract(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) group = translate_val(op.index, **kw) - return sg.func("regexp_extract", arg, pattern, group) + return sg.func("regexp_extract", arg, pattern, group, dialect="duckdb") @translate_val.register(ops.Levenshtein) From e806aa6b3eb710ce5dbd0edbd89fda8f1af7fa65 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 06:28:28 -0400 Subject: [PATCH 103/222] chore: extract sg.exp.Null() into NULL --- ibis/backends/duckdb/compiler/values.py | 74 +++++++++---------------- 1 file changed, 25 insertions(+), 49 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 87bdfb331179..bc459637b9b1 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -20,6 +20,9 @@ from collections.abc import Mapping +NULL = sg.exp.Null() + + @functools.singledispatch def translate_val(op, **_): """Translate a value expression into sqlglot.""" @@ -76,7 +79,7 @@ def _literal(op, **kw): sg_type = DuckDBType.from_ibis(dtype) if value is None and dtype.nullable: - null = sg.exp.Null() + null = NULL return null if dtype.is_null() else sg.cast(null, to=sg_type) elif dtype.is_boolean(): return sg.exp.Boolean(this=value) @@ -359,7 +362,6 @@ def _try_cast(op, **kw): return sg.exp.TryCast( this=translate_val(op.arg, **kw), to=DuckDBType.to_string(op.to), - dialect="duckdb", ) @@ -407,7 +409,7 @@ def _aggregate(op, func, *, where, **kw): if argname not in ("where", "how") ] agg = sg.func(func, *args) - return _apply_agg_filter(agg, where=op.where, **kw) + return _apply_agg_filter(agg, where=where, **kw) @translate_val.register(ops.Any) @@ -710,7 +712,7 @@ def _substring(op, **kw): if_pos = sg.exp.Substring(this=arg, start=start + 1, length=length) if_neg = sg.exp.Substring(this=arg, start=start, length=length) - sg_expr = sg.exp.If( + return sg.exp.If( this=sg.exp.GTE( this=start, expression=sg.exp.Literal(this="0", is_string=False) ), @@ -718,8 +720,6 @@ def 
_substring(op, **kw): false=if_neg, ) - return sg_expr - @translate_val.register(ops.StringFind) def _string_find(op, **kw): @@ -752,7 +752,6 @@ def _regex_replace(op, **kw): pattern, replacement, sg.exp.Literal(this="g", is_string=True), - dialect="duckdb", ) @@ -789,8 +788,7 @@ def _string_join(op, **kw): @translate_val.register(ops.StringConcat) def _string_concat(op, **kw): - arg = map(partial(translate_val, **kw), op.arg) - return sg.exp.Concat(expressions=list(arg)) + return sg.exp.Concat(expressions=list(map(partial(translate_val, **kw), op.arg))) @translate_val.register(ops.StringSQLLike) @@ -821,14 +819,12 @@ def _string_capitalize(op, **kw): ### NULL PLAYER CHARACTER @translate_val.register(ops.IsNull) def _is_null(op, **kw): - arg = translate_val(op.arg, **kw) - return arg.is_(sg.exp.null()) + return translate_val(op.arg, **kw).is_(sg.exp.null()) @translate_val.register(ops.NotNull) def _is_not_null(op, **kw): - arg = translate_val(op.arg, **kw) - return arg.is_(sg.not_(sg.exp.null())) + return translate_val(op.arg, **kw).is_(sg.not_(sg.exp.null())) @translate_val.register(ops.IfNull) @@ -857,17 +853,16 @@ def _zero_if_null(op, **kw): def _array_sort(op, **kw): arg = translate_val(op.arg, **kw) - sg_expr = sg.exp.If( - this=arg.is_(sg.exp.Null()), - true=sg.exp.Null(), + return sg.exp.If( + this=arg.is_(NULL), + true=NULL, false=sg.func("list_distinct", arg) + sg.exp.If( this=sg.func("list_count", arg) < sg.func("array_length", arg), - true=sg.func("list_value", sg.exp.Null()), + true=sg.func("list_value", NULL), false=sg.func("list_value"), ), ) - return sg_expr @translate_val.register(ops.ArrayIndex) @@ -907,28 +902,22 @@ def _array_collect(op, **kw): @translate_val.register(ops.ArrayConcat) def _array_concat(op, **kw): - sg_expr = sg.func( + return sg.func( "flatten", - sg.func( - "list_value", - *(translate_val(arg, **kw) for arg in op.arg), - ), - dialect="duckdb", + sg.func("list_value", *(translate_val(arg, **kw) for arg in op.arg)), ) - return sg_expr @translate_val.register(ops.ArrayRepeat) def _array_repeat_op(op, **kw): arg = translate_val(op.arg, **kw) times = translate_val(op.times, **kw) - sg_expr = sg.func( + return sg.func( "flatten", sg.select( sg.func("array", sg.select(arg).from_(sg.func("range", times))) ).subquery(), ) - return sg_expr def _neg_idx_to_pos(array, idx): @@ -958,7 +947,7 @@ def _array_slice_op(op, **kw): start = sg.func("least", arg_length, _neg_idx_to_pos(arg, start)) if (stop := op.stop) is None: - stop = sg.exp.Null() + stop = NULL else: stop = _neg_idx_to_pos(arg, translate_val(stop, **kw)) @@ -980,8 +969,7 @@ def _array_map(op, **kw): this=result, expressions=[sg.to_identifier(f"{op.parameter}", quoted=False)], ) - sg_expr = sg.func("list_transform", arg, lamduh) - return sg_expr + return sg.func("list_transform", arg, lamduh) @translate_val.register(ops.ArrayFilter) @@ -992,8 +980,7 @@ def _array_filter(op, **kw): this=result, expressions=[sg.exp.Identifier(this=f"{op.parameter}", quoted=False)], ) - sg_expr = sg.func("list_filter", arg, lamduh) - return sg_expr + return sg.func("list_filter", arg, lamduh) @translate_val.register(ops.ArrayIntersect) @@ -1034,7 +1021,7 @@ def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: ] ) lamduh = sg.exp.Lambda(this=result, expressions=[i]) - sg_expr = sg.func( + return sg.func( "list_transform", sg.func( "range", @@ -1043,11 +1030,8 @@ def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: sg.func("greatest", *[sg.func("len", arg) for arg in args]) + 1, ), lamduh, - dialect="duckdb", ) - return sg_expr 
- ### Counting @@ -1304,9 +1288,7 @@ def _identical_to(op, **kw): @translate_val.register(ops.Coalesce) def _vararg_func(op, **kw): return sg.func( - f"{op.__class__.__name__.lower()}", - *map(partial(translate_val, **kw), op.arg), - dialect="duckdb", + f"{op.__class__.__name__.lower()}", *map(partial(translate_val, **kw), op.arg) ) @@ -1314,8 +1296,7 @@ def _vararg_func(op, **kw): def _map(op, **kw): keys = translate_val(op.keys, **kw) values = translate_val(op.values, **kw) - sg_expr = sg.exp.Map(keys=keys, values=values) - return sg_expr + return sg.exp.Map(keys=keys, values=values) @translate_val.register(ops.MapGet) @@ -1323,20 +1304,16 @@ def _map_get(op, **kw): arg = translate_val(op.arg, **kw) key = translate_val(op.key, **kw) default = translate_val(op.default, **kw) - sg_expr = sg.func( - "ifnull", - sg.func("list_extract", sg.func("element_at", arg, key), 1), - default, - dialect="duckdb", + return sg.func( + "ifnull", sg.func("list_extract", sg.func("element_at", arg, key), 1), default ) - return sg_expr @translate_val.register(ops.MapContains) def _map_contains(op, **kw): arg = translate_val(op.arg, **kw) key = translate_val(op.key, **kw) - sg_expr = sg.exp.NEQ( + return sg.exp.NEQ( this=sg.func( "array_length", sg.func( @@ -1347,7 +1324,6 @@ def _map_contains(op, **kw): ), expression=sg.exp.Literal(this="0", is_string=False), ) - return sg_expr def _is_map_literal(op): From d2e9113581bfbfd5e431a4e11b951f833fc514fe Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 07:19:59 -0400 Subject: [PATCH 104/222] chore: move a bunch of functions to automatic registration --- ibis/backends/duckdb/compiler/values.py | 147 +++++------------------- 1 file changed, 30 insertions(+), 117 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index bc459637b9b1..efd5418ca4a5 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -190,6 +190,13 @@ def _literal(op, **kw): ops.First: "first", ops.Last: "last", ops.Count: "count", + ops.All: "bool_and", + ops.Any: "bool_or", + ops.ArrayCollect: "list", + ops.GroupConcat: "string_agg", + ops.BitOr: "bit_or", + ops.BitAnd: "bit_and", + ops.BitXor: "bit_xor", # string operations ops.StringContains: "contains", ops.StringLength: "length", @@ -222,6 +229,13 @@ def _literal(op, **kw): ops.LastValue: "last_value", ops.NTile: "ntile", ops.Hash: "hash", + ops.TimeFromHMS: "make_time", + ops.StringToTimestamp: "strptime", + ops.Levenshtein: "levenshtein", + ops.Repeat: "repeat", + ops.Map: "map", + ops.MapMerge: "map_concat", + ops.JSONGetItem: "json_extract", } @@ -412,20 +426,6 @@ def _aggregate(op, func, *, where, **kw): return _apply_agg_filter(agg, where=where, **kw) -@translate_val.register(ops.Any) -def _any(op, **kw): - arg = translate_val(op.arg, **kw) - any_expr = sg.func("bool_or", arg) - return _apply_agg_filter(any_expr, where=op.where, **kw) - - -@translate_val.register(ops.All) -def _all(op, **kw): - arg = translate_val(op.arg, **kw) - all_expr = sg.func("bool_and", arg) - return _apply_agg_filter(all_expr, where=op.where, **kw) - - @translate_val.register(ops.NotAny) def _not_any(op, **kw): return translate_val(ops.All(ops.Not(op.arg), where=op.where), **kw) @@ -459,14 +459,6 @@ def _time(op, **kw): return sg.cast(expression=arg, to=sg.exp.DataType.Type.TIME) -@translate_val.register(ops.TimeFromHMS) -def _time_from_hms(op, **kw): - hours = translate_val(op.hours, **kw) - minutes = 
translate_val(op.minutes, **kw) - seconds = translate_val(op.seconds, **kw) - return sg.func("make_time", hours, minutes, seconds) - - @translate_val.register(ops.TimestampNow) def _timestamp_now(op, **kw): """DuckDB current timestamp defaults to timestamp + tz.""" @@ -491,10 +483,13 @@ def _timestamp_from_ymdhms(op, **kw): minute = translate_val(op.minutes, **kw) second = translate_val(op.seconds, **kw) + args = [year, month, day, hour, minute, second] + + func = "make_timestamp" if (timezone := op.dtype.timezone) is not None: - return f"make_timestamptz({year}, {month}, {day}, {hour}, {minute}, {second}, '{timezone}')" - else: - return f"make_timestamp({year}, {month}, {day}, {hour}, {minute}, {second})" + func += "tz" + args.append(sg.exp.Literal(this=timezone, is_string=True)) + return sg.func(func, *args) @translate_val.register(ops.Strftime) @@ -508,13 +503,6 @@ def _strftime(op, **kw): return sg.func("strftime", arg, format_str) -@translate_val.register(ops.StringToTimestamp) -def _string_to_timestamp(op, **kw): - arg = translate_val(op.arg, **kw) - format_str = translate_val(op.format_str, **kw) - return sg.func("strptime", arg, format_str) - - @translate_val.register(ops.ExtractEpochSeconds) def _extract_epoch_seconds(op, **kw): arg = translate_val(op.arg, **kw) @@ -608,13 +596,13 @@ def _truncate(op, **kw): except KeyError: raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}") - return f"date_trunc('{duckunit}', {arg})" + return sg.func("date_trunc", duckunit, arg) @translate_val.register(ops.DayOfWeekIndex) def _day_of_week_index(op, **kw): arg = translate_val(op.arg, **kw) - return f"(dayofweek({arg}) + 6) % 7" + return (sg.func("dayofweek", arg) + 6) % 7 @translate_val.register(ops.DayOfWeekName) @@ -763,13 +751,6 @@ def _regex_extract(op, **kw): return sg.func("regexp_extract", arg, pattern, group, dialect="duckdb") -@translate_val.register(ops.Levenshtein) -def _levenshtein(op, **kw): - left = translate_val(op.left, **kw) - right = translate_val(op.right, **kw) - return sg.func("levenshtein", left, right) - - @translate_val.register(ops.StringSplit) def _string_split(op, **kw): arg = translate_val(op.arg, **kw) @@ -869,8 +850,8 @@ def _array_sort(op, **kw): def _array_index_op(op, **kw): arg = translate_val(op.arg, **kw) index = translate_val(op.index, **kw) - correct_idx = f"if({index} >= 0, {index} + 1, {index})" - return f"array_extract({arg}, {correct_idx})" + correct_idx = sg.func("if", index >= 0, index + 1, index) + return sg.func("list_extract", arg, correct_idx) @translate_val.register(ops.InValues) @@ -894,12 +875,6 @@ def _in_column(op, **kw): return value.isin(options) -@translate_val.register(ops.ArrayCollect) -def _array_collect(op, **kw): - agg = sg.func("list", translate_val(op.arg, **kw)) - return _apply_agg_filter(agg, where=op.where, **kw) - - @translate_val.register(ops.ArrayConcat) def _array_concat(op, **kw): return sg.func( @@ -958,7 +933,7 @@ def _array_slice_op(op, **kw): def _array_string_join(op, **kw): arg = translate_val(op.arg, **kw) sep = translate_val(op.sep, **kw) - return f"list_aggregate({arg}, 'string_agg', {sep})" + return sg.func("list_aggr", arg, sg_literal("string_agg"), sep) @translate_val.register(ops.ArrayMap) @@ -1055,8 +1030,7 @@ def _count_distinct_star(op, **kw): @translate_val.register(ops.CountStar) def _count_star(op, **kw): - sql = sg.exp.Count(this=sg.exp.Star()) - return _apply_agg_filter(sql, where=op.where, **kw) + return _apply_agg_filter(sg.exp.Count(this=sg.exp.Star()), where=op.where, **kw) 
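# A sketch of the mapping-driven registration that replaces one-off
# translators like the `_bitor` deleted above; the loop shape follows the
# generated `_fmt` translators shown earlier in this series, and the mapping
# name `_simple_ops` is hypothetical:
#
#     for _op, _name in _simple_ops.items():
#         @translate_val.register(_op)
#         def _fmt(op, _name=_name, **kw):
#             return sg.func(_name, *map(partial(translate_val, **kw), op.args))
#
# Reductions such as ops.BitAnd ("bit_and") still need their FILTER clause,
# presumably via the `_aggregate`/`_apply_agg_filter` path rather than this
# plain one.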
@translate_val.register(ops.Sum) @@ -1074,13 +1048,6 @@ def _nth_value(op, **kw): return sg.func("nth_value", arg, nth + 1) -@translate_val.register(ops.Repeat) -def _repeat(op, **kw): - arg = translate_val(op.arg, **kw) - times = translate_val(op.times, **kw) - return sg.func("repeat", arg, times) - - ### Stats @@ -1228,24 +1195,15 @@ def _exists_subquery(op, **kw): return f"{prefix}EXISTS ({subq})" -@translate_val.register(ops.GroupConcat) -def _group_concat(op, **kw): - arg = translate_val(op.arg, **kw) - sep = translate_val(op.sep, **kw) - - concat = sg.func("string_agg", arg, sep) - return _apply_agg_filter(concat, where=op.where, **kw) - - @translate_val.register(ops.ArrayColumn) def _array_column(op, **kw): - return sg.exp.Array.from_arg_list([translate_val(col, **kw) for col in op.cols]) + return sg.exp.Array(expressions=[translate_val(col, **kw) for col in op.cols]) @translate_val.register(ops.StructColumn) def _struct_column(op, **kw): - return sg.exp.Struct( - expressions=[ + return sg.exp.Struct.from_arg_list( + [ sg.exp.Slice( this=sg.exp.Literal(this=name, is_string=True), expression=translate_val(value, **kw), @@ -1292,13 +1250,6 @@ def _vararg_func(op, **kw): ) -@translate_val.register(ops.Map) -def _map(op, **kw): - keys = translate_val(op.keys, **kw) - values = translate_val(op.values, **kw) - return sg.exp.Map(keys=keys, values=values) - - @translate_val.register(ops.MapGet) def _map_get(op, **kw): arg = translate_val(op.arg, **kw) @@ -1326,21 +1277,6 @@ def _map_contains(op, **kw): ) -def _is_map_literal(op): - return isinstance(op, ops.Literal) or ( - isinstance(op, ops.Map) - and isinstance(op.keys, ops.Literal) - and isinstance(op.values, ops.Literal) - ) - - -@translate_val.register(ops.MapMerge) -def _map_merge(op, **kw): - left = translate_val(op.left, **kw) - right = translate_val(op.right, **kw) - return sg.func("map_concat", left, right) - - def _binary_infix(sg_expr: sg.exp._Expression): def formatter(op, **kw): left = translate_val(op.left, **kw) @@ -1386,22 +1322,6 @@ def formatter(op, **kw): del _op, _sym -_bit_agg = { - ops.BitOr: "bit_or", - ops.BitAnd: "bit_and", - ops.BitXor: "bit_xor", -} - - -@translate_val.register(ops.BitAnd) -@translate_val.register(ops.BitOr) -@translate_val.register(ops.BitXor) -def _bitor(op, **kw): - arg = translate_val(op.arg, **kw) - bit_expr = sg.func(_bit_agg[type(op)], arg) - return _apply_agg_filter(bit_expr, where=op.where, **kw) - - @translate_val.register(ops.Xor) def _xor(op, **kw): # https://github.com/tobymao/sqlglot/issues/2238 @@ -1604,14 +1524,7 @@ def formatter(op, **kw): @translate_val.register(ops.Argument) def _argument(op, **_): - return sg.exp.Identifier(this=op.name, quoted=False) - - -@translate_val.register(ops.JSONGetItem) -def _json_getitem(op, **kw): - return sg.exp.JSONExtract( - this=translate_val(op.arg, **kw), expression=translate_val(op.index, **kw) - ) + return sg.to_identifier(op.name) @translate_val.register(ops.RowID) From 06a763e264a61d029952146f1c387e06377e6819 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 07:20:18 -0400 Subject: [PATCH 105/222] chore: fix cume_dist translation --- ibis/backends/duckdb/compiler/values.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index efd5418ca4a5..9f3361daba29 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1360,7 +1360,7 @@ def 
_percent_rank(_, **kw): @translate_val.register(ops.CumeDist) def _cume_dist(_, **kw): - return sg.func("percent_rank") + return sg.func("cume_dist") @translate_val.register From dce28cb62a9797ff96fda3b79ae67cb42c889403 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 07:20:39 -0400 Subject: [PATCH 106/222] chore: move order by keys to sqlglot construct --- ibis/backends/duckdb/compiler/values.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 9f3361daba29..4e38699f1489 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1366,8 +1366,7 @@ def _cume_dist(_, **kw): @translate_val.register def _sort_key(op: ops.SortKey, **kw): arg = translate_val(op.expr, **kw) - direction = "ASC" if op.ascending else "DESC" - return f"{_sql(arg)} {direction}" + return sg.exp.Ordered(this=arg, desc=not op.ascending) ### Window functions From f3e62bc96aa66cccdbbc48d1b0c81f9446cfe34f Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 07:21:19 -0400 Subject: [PATCH 107/222] chore: translate window functions to sqlglot objects --- ibis/backends/duckdb/compiler/values.py | 113 ++++++++++-------------- 1 file changed, 48 insertions(+), 65 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 4e38699f1489..360426f98f27 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1389,64 +1389,16 @@ def cumulative_to_window(func, frame): return new_expr.op() -def format_window_boundary(boundary, **kw): - value = translate_val(boundary.value, **kw) - if boundary.preceding: - return f"{value} PRECEDING" - else: - return f"{value} FOLLOWING" - - -# TODO -def format_window_frame(func, frame, **kw): - components = [] - - if frame.how == "rows" and frame.max_lookback is not None: - raise NotImplementedError( - "Rows with max lookback is not implemented for the Duckdb backend." 
- ) - - if frame.group_by: - partition_args = ", ".join( - map(_sql, map(partial(translate_val, **kw), frame.group_by)) - ) - components.append(f"PARTITION BY {partition_args}") - - if frame.order_by: - order_args = ", ".join( - map(_sql, map(partial(translate_val, **kw), frame.order_by)) - ) - components.append(f"ORDER BY {order_args}") - - frame_clause_not_allowed = ( - ops.Lag, - ops.Lead, - ops.DenseRank, - ops.MinRank, - ops.NTile, - ops.PercentRank, - ops.CumeDist, - ops.RowNumber, - ) - - if frame.start is None and frame.end is None: - # no-op, default is full sample - pass - elif not isinstance(func, frame_clause_not_allowed): - if frame.start is None: - start = "UNBOUNDED PRECEDING" - else: - start = format_window_boundary(frame.start, **kw) - - if frame.end is None: - end = "UNBOUNDED FOLLOWING" - else: - end = format_window_boundary(frame.end, **kw) - - frame = f"{frame.how.upper()} BETWEEN {start} AND {end}" - components.append(frame) - - return f"OVER ({' '.join(components)})" +_FRAME_CLAUSE_NOT_ALLOWED = ( + ops.Lag, + ops.Lead, + ops.DenseRank, + ops.MinRank, + ops.NTile, + ops.PercentRank, + ops.CumeDist, + ops.RowNumber, +) # TODO @@ -1477,15 +1429,46 @@ def _window(op: ops.WindowFunction, **kw: Any): arg = cumulative_to_window(op.func, op.frame) return translate_val(arg, **kw) - window_formatted = format_window_frame(op, op.frame, **kw) - func = op.func.__window_op__ - func_formatted = translate_val(func, **kw) - result = f"{func_formatted} {window_formatted}" + func = op.func + frame = op.frame + tr_val = partial(translate_val, **kw) - if isinstance(func, ops.RankBase): - return f"({result} - 1)" + if frame.start is None: + start = "UNBOUNDED" + else: + start = tr_val(frame.start.value, **kw) + + if frame.end is None: + end = "UNBOUNDED" + else: + end = tr_val(frame.end.value, **kw) + + spec = sg.exp.WindowSpec( + kind=frame.how, + start=start, + start_side="preceding", + end=end, + end_side="following", + over="OVER", + ) - return result + if frame.group_by: + partition_by = list(map(tr_val, frame.group_by)) + else: + partition_by = None + + if frame.order_by: + order = sg.exp.Order(expressions=list(map(tr_val, frame.order_by))) + else: + order = None + + this = tr_val(func, **kw) + window = sg.exp.Window(this=this, partition_by=partition_by, order=order, spec=spec) + + # preserve zero-based indexing + if isinstance(func, ops.RankBase): + return window - 1 + return window def shift_like(op_class, name): From d4bdd4884396219ba968e052958a3e5d17621be5 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 07:27:07 -0400 Subject: [PATCH 108/222] chore: regen window function snapshots --- ibis/backends/duckdb/compiler/values.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 360426f98f27..ca22add525c2 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1444,7 +1444,7 @@ def _window(op: ops.WindowFunction, **kw: Any): end = tr_val(frame.end.value, **kw) spec = sg.exp.WindowSpec( - kind=frame.how, + kind=frame.how.upper(), start=start, start_side="preceding", end=end, From 8cd86292cce058f066d97301bbef6f944e487ac4 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 07:29:11 -0400 Subject: [PATCH 109/222] chore: delete window function TODO comment --- ibis/backends/duckdb/compiler/values.py | 1 - 1 file changed, 1 
deletion(-)

diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index ca22add525c2..b061ca3a1074 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -1422,7 +1422,6 @@ def _approx_median(op, **kw):
     return _apply_agg_filter(expr, where=op.where, **kw)
 
 
-# TODO
 @translate_val.register(ops.WindowFunction)
 def _window(op: ops.WindowFunction, **kw: Any):
     if isinstance(op.func, ops.CumulativeOp):

From a1d3e44c30022642de1d23e4da801829cea96542 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Sun, 17 Sep 2023 07:38:17 -0400
Subject: [PATCH 110/222] chore: implement exists/not exists

---
 ibis/backends/duckdb/compiler/values.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index b061ca3a1074..6f978f0f8df3 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -1180,19 +1180,22 @@ def _table_array_view(op, *, cache, **kw):
     return res.subquery()
 
 
-# TODO
 @translate_val.register(ops.ExistsSubquery)
-@translate_val.register(ops.NotExistsSubquery)
 def _exists_subquery(op, **kw):
-    from ibis.backends.duckdb.compiler.relations import translate_rel
+    from ibis.backends.clickhouse.compiler import translate
 
-    if "table" not in kw:
-        kw["table"] = translate_rel(op.foreign_table.table, **kw)
-    foreign_table = translate_rel(op.foreign_table, **kw)
+    foreign_table = translate(op.foreign_table, {})
     predicates = translate_val(op.predicates, **kw)
-    subq = sg.select(1).from_(foreign_table).where(sg.condition(predicates))
-    prefix = "NOT " * isinstance(op, ops.NotExistsSubquery)
-    return f"{prefix}EXISTS ({subq})"
+    return sg.exp.Exists(
+        this=sg.select(1)
+        .from_(foreign_table.subquery())
+        .where(sg.condition(predicates))
+    )
+
+
+@translate_val.register(ops.NotExistsSubquery)
+def _not_exists_subquery(op, **kw):
+    return sg.not_(_exists_subquery(op, **kw))
 
 
 @translate_val.register(ops.ArrayColumn)
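Note: with `sg.exp.Exists` plus `sg.not_`, the EXISTS/NOT EXISTS logic now stays in
sqlglot's expression tree instead of being string-formatted. A rough sketch of the
shapes this produces (illustrative table names; the rendered SQL may vary slightly
across sqlglot versions):

    import sqlglot as sg

    # build SELECT 1 FROM t2 WHERE t2.a = t1.a, then wrap it in EXISTS
    subq = sg.select(1).from_("t2").where(sg.condition("t2.a = t1.a"))
    print(sg.exp.Exists(this=subq).sql("duckdb"))
    # EXISTS(SELECT 1 FROM t2 WHERE t2.a = t1.a)
    print(sg.not_(sg.exp.Exists(this=subq)).sql("duckdb"))
    # NOT EXISTS(SELECT 1 FROM t2 WHERE t2.a = t1.a)

From d6657bb7ea71627af26a16c8e6b93c7f549b5e23 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Sun, 17 Sep 2023 07:40:25 -0400
Subject: [PATCH 111/222] chore: remove unnecessary `_FRAME_CLAUSE_NOT_ALLOWED` constant

---
 ibis/backends/duckdb/compiler/values.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index 6f978f0f8df3..b405d5334d94 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -1392,18 +1392,6 @@ def cumulative_to_window(func, frame):
     return new_expr.op()
 
 
-_FRAME_CLAUSE_NOT_ALLOWED = (
-    ops.Lag,
-    ops.Lead,
-    ops.DenseRank,
-    ops.MinRank,
-    ops.NTile,
-    ops.PercentRank,
-    ops.CumeDist,
-    ops.RowNumber,
-)
-
-
 # TODO
 _map_interval_to_microseconds = {
     "W": 604800000000,

From de6717e52174421326256c9e8110cc2f8486fbe7 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Sun, 17 Sep 2023 07:42:08 -0400
Subject: [PATCH 112/222] chore: fix date_trunc

---
 ibis/backends/duckdb/compiler/values.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index b405d5334d94..816806f82d9c 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -596,7 +596,7 @@ 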
def _truncate(op, **kw): except KeyError: raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}") - return sg.func("date_trunc", duckunit, arg) + return sg.func("date_trunc", sg_literal(duckunit), arg) @translate_val.register(ops.DayOfWeekIndex) From 730a12f1a06dd147d4eaf7839f2d167fc0dc20ef Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 07:49:04 -0400 Subject: [PATCH 113/222] chore: fix column order, which may be returned differently by duckdb --- ibis/formats/pandas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibis/formats/pandas.py b/ibis/formats/pandas.py index 32d531ae10e8..d5616a200294 100644 --- a/ibis/formats/pandas.py +++ b/ibis/formats/pandas.py @@ -108,8 +108,8 @@ def convert_table(cls, df, schema): "schema column count does not match input data column count" ) - for (name, series), dtype in zip(df.items(), schema.types): - df[name] = cls.convert_column(series, dtype) + for name, dtype in schema.items(): + df[name] = cls.convert_column(df[name], dtype) # return data with the schema's columns which may be different than the # input columns From 039e9bd45bcd1932f67a8105cc633bb0180efed6 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:07:25 -0400 Subject: [PATCH 114/222] chore: make sure all literal -> table transformations have a consistent alias, so that extraction of columns does not depend on the backend naming conventions --- ibis/expr/types/generic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index d5312cbea576..67ed1e847979 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -972,9 +972,11 @@ def as_table(self) -> ir.Table: from ibis.expr.analysis import find_first_base_table op = self.op() + name = op.name + op = ops.Alias(op, name) table = find_first_base_table(op) if table is not None: - return table.to_expr().aggregate([self]) + return table.to_expr().aggregate([self.name(name)]) else: return ops.DummyTable(values=(op,)).to_expr() From 3d7cb7f772959556233f9250bef1d2d521e73478 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:07:41 -0400 Subject: [PATCH 115/222] chore: use pandas result getter instead of branch --- ibis/backends/duckdb/__init__.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index ca81980dfa39..66c7f93937e3 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -460,17 +460,11 @@ def execute( # TODO: should we do this in arrow? # also what is pandas doing with dates? + # 🡅 is because of https://github.com/duckdb/duckdb/issues/8539 pandas_df = result.fetch_df() result = DuckDBPandasData.convert_table(pandas_df, schema) - if isinstance(expr, ir.Table): - return result - elif isinstance(expr, ir.Column): - return result.iloc[:, 0] - elif isinstance(expr, ir.Scalar): - return result.iat[0, 0] - else: - raise ValueError + return expr.__pandas_result__(result) def load_extension(self, extension: str) -> None: """Install and load a duckdb extension by name or path. 
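Note: the `expr.__pandas_result__(result)` call introduced above replaces manual
isinstance branching over table/column/scalar expressions. A minimal sketch of the
protocol it relies on (hypothetical class bodies that mirror the removed branches,
not the actual ibis implementations):

    import pandas as pd

    class Table:
        def __pandas_result__(self, df: pd.DataFrame) -> pd.DataFrame:
            return df  # table expressions keep the whole frame

    class Column:
        def __pandas_result__(self, df: pd.DataFrame) -> pd.Series:
            return df.iloc[:, 0]  # column expressions take the only column

    class Scalar:
        def __pandas_result__(self, df: pd.DataFrame):
            return df.iat[0, 0]  # scalar expressions take the single cell

Each expression type knows its own output shape, so the backend's `execute` no
longer has to.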
From 244512170280ed09bfc2b30af11c1b0925c1992c Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:08:22 -0400 Subject: [PATCH 116/222] chore: reorder columns to schema order instead of forcing them to be equivalent; this works around https://github.com/duckdb/duckdb/issues/8539 safely --- ibis/formats/pandas.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ibis/formats/pandas.py b/ibis/formats/pandas.py index d5616a200294..ab8e0b3a1151 100644 --- a/ibis/formats/pandas.py +++ b/ibis/formats/pandas.py @@ -113,8 +113,7 @@ def convert_table(cls, df, schema): # return data with the schema's columns which may be different than the # input columns - df.columns = schema.names - return df + return df.loc[:, list(schema.names)] @classmethod def convert_column(cls, obj, dtype): From 379e0336625bab0be515ac8721ed0717d0d2c150 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:29:09 -0400 Subject: [PATCH 117/222] chore: handle analytic window functions that require an order by --- ibis/backends/duckdb/compiler/values.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 816806f82d9c..bf0535fc5c12 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1415,13 +1415,15 @@ def _approx_median(op, **kw): @translate_val.register(ops.WindowFunction) def _window(op: ops.WindowFunction, **kw: Any): - if isinstance(op.func, ops.CumulativeOp): - arg = cumulative_to_window(op.func, op.frame) - return translate_val(arg, **kw) - func = op.func frame = op.frame + + if isinstance(func, ops.CumulativeOp): + arg = cumulative_to_window(func, op.frame) + return translate_val(arg, **kw) + tr_val = partial(translate_val, **kw) + this = tr_val(func, **kw) if frame.start is None: start = "UNBOUNDED" @@ -1447,12 +1449,16 @@ def _window(op: ops.WindowFunction, **kw: Any): else: partition_by = None - if frame.order_by: - order = sg.exp.Order(expressions=list(map(tr_val, frame.order_by))) + order_bys = list(map(tr_val, frame.order_by)) + + if isinstance(func, ops.Analytic) and not isinstance(func, ops.ShiftBase): + order_bys.extend(tr_val(ops.SortKey(arg, ascending=True)) for arg in func.args) + + if order_bys: + order = sg.exp.Order(expressions=order_bys) else: order = None - this = tr_val(func, **kw) window = sg.exp.Window(this=this, partition_by=partition_by, order=order, spec=spec) # preserve zero-based indexing From f51ec451aca54b597cca0ad32bcffd1a4ebf4d9d Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:37:40 -0400 Subject: [PATCH 118/222] chore: shorten code a bit --- ibis/backends/duckdb/compiler/values.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index bf0535fc5c12..10f4bce3057a 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1444,20 +1444,14 @@ def _window(op: ops.WindowFunction, **kw: Any): over="OVER", ) - if frame.group_by: - partition_by = list(map(tr_val, frame.group_by)) - else: - partition_by = None + partition_by = list(map(tr_val, frame.group_by)) or None order_bys = list(map(tr_val, frame.order_by)) if isinstance(func, ops.Analytic) and not isinstance(func, 
ops.ShiftBase):
-        order_bys.extend(tr_val(ops.SortKey(arg, ascending=True)) for arg in func.args)
+        order_bys.extend(map(tr_val, func.args))
 
-    if order_bys:
-        order = sg.exp.Order(expressions=order_bys)
-    else:
-        order = None
+    order = sg.exp.Order(expressions=order_bys) if order_bys else None
 
     window = sg.exp.Window(this=this, partition_by=partition_by, order=order, spec=spec)

From e5e7a8d4411db0bc033d6d9b054a2bbd3da7265d Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Sun, 17 Sep 2023 09:58:23 -0400
Subject: [PATCH 119/222] chore: fix ms conversion

---
 ibis/backends/duckdb/compiler/values.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index 10f4bce3057a..7bc7f533c122 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -465,14 +465,25 @@ def _timestamp_now(op, **kw):
     return sg.cast(expression=sg.func("current_timestamp"), to="TIMESTAMP")
 
 
+_POWERS_OF_TEN = {
+    "s": 0,
+    "ms": 3,
+    "us": 6,
+    "ns": 9,
+}
+
+
 @translate_val.register(ops.TimestampFromUNIX)
 def _timestamp_from_unix(op, **kw):
     arg = translate_val(op.arg, **kw)
-    if (unit := op.unit.short) in {"ms", "us", "ns"}:
+    unit = op.unit.short
+    if unit == "ms":
+        return sg.func("epoch_ms", arg)
+    elif unit == "s":
+        return sg.exp.UnixToTime(this=arg)
+    else:
         raise com.UnsupportedOperationError(f"{unit!r} unit is not supported!")
 
-    return sg.exp.UnixToTime(this=arg)
-
 
 @translate_val.register(ops.TimestampFromYMDHMS)
 def _timestamp_from_ymdhms(op, **kw):
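Note: DuckDB's `epoch_ms` interprets its argument as milliseconds since the epoch,
while `to_timestamp` interprets it as seconds, which is why the two units need
different translations. A quick check of what the two branches render to (a sketch;
exact SQL casing and text may vary across sqlglot versions):

    import sqlglot as sg

    ms = sg.func("epoch_ms", sg.exp.Literal.number(1694959200000))
    s = sg.exp.UnixToTime(this=sg.exp.Literal.number(1694959200))

    print(ms.sql("duckdb"))  # epoch_ms(1694959200000)
    print(s.sql("duckdb"))   # roughly TO_TIMESTAMP(1694959200)

From 52bd89cbe4d892ab873545cc22a090a23da95c33 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Sun, 17 Sep 2023 10:02:39 -0400
Subject: [PATCH 120/222] chore: xfail already xfailing test

---
 ibis/backends/tests/test_temporal.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py
index e8c0df5d78f7..f932db18c73c 100644
--- a/ibis/backends/tests/test_temporal.py
+++ b/ibis/backends/tests/test_temporal.py
@@ -1441,7 +1441,7 @@ def test_strftime(backend, alltypes, df, expr_fn, pandas_pattern):
             reason="PySpark backend does not support timestamp from unix time with unit us. 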
Supported unit is s.", ), pytest.mark.notimpl( - ["mssql", "clickhouse"], + ["mssql", "clickhouse", "duckdb"], raises=com.UnsupportedOperationError, reason="`us` unit is not supported!", ), From b5d3ec1ab1b6f9d1fea979ea108e5b8acc110e55 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 10:05:15 -0400 Subject: [PATCH 121/222] chore: fix intervals --- ibis/backends/duckdb/compiler/values.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 7bc7f533c122..f73605bab34e 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -76,6 +76,10 @@ def sg_literal(arg, is_string=True): def _literal(op, **kw): value = op.value dtype = op.dtype + + if dtype.is_interval() and value is not None: + return _interval_format(op) + sg_type = DuckDBType.from_ibis(dtype) if value is None and dtype.nullable: From c6680e3b27312c99828824b6c53c0f34f342c49e Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 10:15:31 -0400 Subject: [PATCH 122/222] chore: remove explicit typeof rule --- ibis/backends/duckdb/compiler/values.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index f73605bab34e..ed98b5c481ce 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -240,6 +240,7 @@ def _literal(op, **kw): ops.Map: "map", ops.MapMerge: "map_concat", ops.JSONGetItem: "json_extract", + ops.TypeOf: "typeof", } @@ -383,12 +384,6 @@ def _try_cast(op, **kw): ) -@translate_val.register(ops.TypeOf) -def _type_of(op, **kw): - arg = translate_val(op.arg, **kw) - return sg.func("typeof", arg) - - ### Comparator Conundrums From 4f8b6d305809a257a75af89390b2825c72235de6 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 10:15:46 -0400 Subject: [PATCH 123/222] chore: fix casting integers to intervals --- ibis/backends/duckdb/compiler/values.py | 46 +++++++++++-------------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index ed98b5c481ce..5824703d970c 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -341,46 +341,40 @@ def _round(op, **kw): ### Dtype Dysmorphia -_interval_cast_suffixes = { - "s": "Second", - "m": "Minute", - "h": "Hour", - "D": "Day", - "W": "Week", - "M": "Month", - "Q": "Quarter", - "Y": "Year", +_interval_suffixes = { + "ms": "milliseconds", + "us": "microseconds", + "s": "seconds", + "m": "minutes", + "h": "hours", + "D": "days", + "M": "months", + "Y": "years", } @translate_val.register(ops.Cast) def _cast(op, **kw): arg = translate_val(op.arg, **kw) + to = op.to - if isinstance(op.to, dt.Interval): - suffix = _interval_cast_suffixes[op.to.unit.short] - if isinstance(op.arg, ops.TableColumn): - return ( - f"INTERVAL (i) {suffix} FROM (SELECT {arg.name} FROM {arg.table}) t(i)" - ) - - else: - return sg.exp.Interval(this=arg, unit=suffix) - elif isinstance(op.to, dt.Timestamp) and isinstance(op.arg.dtype, dt.Integer): + if to.is_interval(): + return sg.func( + f"to_{_interval_suffixes[to.unit.short]}", + sg.cast(arg, to=DuckDBType.from_ibis(dt.int32)), + ) + elif to.is_timestamp() and op.arg.dtype.is_integer(): 
return sg.func("to_timestamp", arg) - elif isinstance(op.to, dt.Timestamp) and op.to.timezone is not None: - timezone = sg.exp.Literal(this=op.to.timezone, is_string=True) - return sg.func("timezone", timezone, arg) + elif to.is_timestamp() and (tz := to.timezone) is not None: + return sg.func("timezone", sg_literal(tz), arg) - to = translate_val(op.to, **kw) - return sg.cast(expression=arg, to=to) + return sg.cast(expression=arg, to=translate_val(to, **kw)) @translate_val.register(ops.TryCast) def _try_cast(op, **kw): return sg.exp.TryCast( - this=translate_val(op.arg, **kw), - to=DuckDBType.to_string(op.to), + this=translate_val(op.arg, **kw), to=DuckDBType.from_ibis(op.to) ) From a7ca399efa90e72cf5f2468b50dc5ea6b8e3ae2c Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 10:29:33 -0400 Subject: [PATCH 124/222] chore: use arrow --- ibis/backends/duckdb/__init__.py | 5 +++-- ibis/backends/duckdb/datatypes.py | 4 ---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 66c7f93937e3..b7b2a9a3de4a 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -462,8 +462,9 @@ def execute( # also what is pandas doing with dates? # 🡅 is because of https://github.com/duckdb/duckdb/issues/8539 - pandas_df = result.fetch_df() - result = DuckDBPandasData.convert_table(pandas_df, schema) + t = result.arrow() + df = t.to_pandas(date_as_object=True, timestamp_as_object=True) + result = DuckDBPandasData.convert_table(df, schema) return expr.__pandas_result__(result) def load_extension(self, extension: str) -> None: diff --git a/ibis/backends/duckdb/datatypes.py b/ibis/backends/duckdb/datatypes.py index c45aac11780b..a4277ca82760 100644 --- a/ibis/backends/duckdb/datatypes.py +++ b/ibis/backends/duckdb/datatypes.py @@ -6,10 +6,6 @@ class DuckDBPandasData(PandasData): - @staticmethod - def convert_Map(s, dtype, pandas_type): - return s.map(lambda x: dict(zip(x["key"], x["value"])), na_action="ignore") - @staticmethod def convert_Array(s, dtype, pandas_type): return s.replace(np.nan, None) From f0c0a9ea8d5d2d15d61eb28826066f37a6a21a51 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 10:38:46 -0400 Subject: [PATCH 125/222] chore: fix duckdb xfail --- ibis/backends/tests/test_numeric.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index ee244c9b26c7..c427f4eb14d0 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -24,10 +24,10 @@ import duckdb DuckDBConversionException = duckdb.ConversionException + DuckDBParserException = duckdb.ParserException except ImportError: duckdb = None - DuckDBConversionException = None - + DuckDBConversionException = DuckDBParserException = None try: import clickhouse_connect as cc @@ -385,9 +385,7 @@ def test_numeric_literal(con, backend, expr, expected_types): raises=ImpalaHiveServer2Error, ), pytest.mark.broken( - ["duckdb"], - "Unsupported precision.", - raises=NotImplementedError, + ["duckdb"], "Unsupported precision.", raises=DuckDBParserException ), pytest.mark.notyet(["datafusion"], raises=Exception), pytest.mark.notyet( From e5f9c8855fc9d9b5dfa345eb7e157cfb69961864 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 11:05:04 -0400 Subject: [PATCH 
126/222] chore: fix from_ibis decimal --- ibis/backends/base/sqlglot/datatypes.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ibis/backends/base/sqlglot/datatypes.py b/ibis/backends/base/sqlglot/datatypes.py index 5665783a847e..781d62e31ebb 100644 --- a/ibis/backends/base/sqlglot/datatypes.py +++ b/ibis/backends/base/sqlglot/datatypes.py @@ -275,11 +275,17 @@ def _from_ibis_Struct(cls, dtype: dt.Struct) -> sge.DataType: @classmethod def _from_ibis_Decimal(cls, dtype: dt.Decimal) -> sge.DataType: + if (precision := dtype.precision) is None: + precision = cls.default_decimal_precision + + if (scale := dtype.scale) is None: + scale = cls.default_decimal_scale + return sge.DataType( this=typecode.DECIMAL, expressions=[ - sge.DataTypeParam(this=sge.Literal.number(dtype.precision)), - sge.DataTypeParam(this=sge.Literal.number(dtype.scale)), + sge.DataTypeParam(this=sge.Literal.number(precision)), + sge.DataTypeParam(this=sge.Literal.number(scale)), ], ) From 6615b849bfffd78655db661cdf4c52c674bb8513 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 11:12:19 -0400 Subject: [PATCH 127/222] chore: replace `list_value` calls with `sg.exp.Array` --- ibis/backends/duckdb/compiler/values.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 5824703d970c..266c27d38786 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -843,9 +843,9 @@ def _array_sort(op, **kw): true=NULL, false=sg.func("list_distinct", arg) + sg.exp.If( - this=sg.func("list_count", arg) < sg.func("array_length", arg), - true=sg.func("list_value", NULL), - false=sg.func("list_value"), + this=sg.func("list_count", arg) < sg.func("len", arg), + true=sg.exp.Array.from_arg_list([NULL]), + false=sg.exp.Array.from_arg_list([]), ), ) @@ -883,7 +883,7 @@ def _in_column(op, **kw): def _array_concat(op, **kw): return sg.func( "flatten", - sg.func("list_value", *(translate_val(arg, **kw) for arg in op.arg)), + sg.exp.Array.from_arg_list([translate_val(arg, **kw) for arg in op.arg]), ) From 97309722410ddbef469c48b6715c16fefa4faaf0 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 11:12:34 -0400 Subject: [PATCH 128/222] chore: use original conversion code --- ibis/backends/duckdb/__init__.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index b7b2a9a3de4a..c5593cc38fe2 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -458,12 +458,27 @@ def execute( except duckdb.CatalogException as e: raise exc.IbisError(e) - # TODO: should we do this in arrow? - # also what is pandas doing with dates? - # 🡅 is because of https://github.com/duckdb/duckdb/issues/8539 + import pyarrow.types as pat + + table = result.fetch_arrow_table() + + df = pd.DataFrame( + { + name: ( + col.to_pylist() + if ( + pat.is_nested(col.type) + or + # pyarrow / duckdb type null literals columns as int32? 
+ # but calling `to_pylist()` will render it as None + col.null_count + ) + else col.to_pandas(timestamp_as_object=True) + ) + for name, col in zip(table.column_names, table.columns) + } + ) - t = result.arrow() - df = t.to_pandas(date_as_object=True, timestamp_as_object=True) result = DuckDBPandasData.convert_table(df, schema) return expr.__pandas_result__(result) From 9aff71ce33c056e9f9d559f09f1aae9dad8d17fa Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 11:18:19 -0400 Subject: [PATCH 129/222] chore: use original conversion code --- ibis/backends/duckdb/__init__.py | 27 +++------------------------ 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index c5593cc38fe2..bd77b70fd0d7 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -454,32 +454,11 @@ def execute( self._log(sql) try: - result = self.con.execute(sql) + cur = self.con.execute(sql) except duckdb.CatalogException as e: raise exc.IbisError(e) - import pyarrow.types as pat - - table = result.fetch_arrow_table() - - df = pd.DataFrame( - { - name: ( - col.to_pylist() - if ( - pat.is_nested(col.type) - or - # pyarrow / duckdb type null literals columns as int32? - # but calling `to_pylist()` will render it as None - col.null_count - ) - else col.to_pandas(timestamp_as_object=True) - ) - for name, col in zip(table.column_names, table.columns) - } - ) - - result = DuckDBPandasData.convert_table(df, schema) + result = self.fetch_from_cursor(cur, schema) return expr.__pandas_result__(result) def load_extension(self, extension: str) -> None: @@ -1228,7 +1207,7 @@ def fetch_from_cursor( import pandas as pd import pyarrow.types as pat - table = cursor.cursor.fetch_arrow_table() + table = cursor.fetch_arrow_table() df = pd.DataFrame( { From 6207eac1c1473ba9b45a6790e267dd5407bae0ad Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 11:28:09 -0400 Subject: [PATCH 130/222] INVESTIGATE BEFORE MERGE: duckdb returns dateoffsets for interval objects --- ibis/backends/tests/test_temporal.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index f932db18c73c..ba0e2af2fc16 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -933,6 +933,11 @@ def convert_to_offset(x): raises=ValidationError, reason="unsupported operand type(s) for -: 'StringColumn' and 'TimestampScalar'", ), + pytest.mark.broken( + ["duckdb"], + raises=AssertionError, + reason="duckdb returns dateoffsets", + ), ], ), param( @@ -1972,7 +1977,9 @@ def test_extract_time_from_timestamp(con, microsecond): raises=(NotImplementedError, AttributeError), ) @pytest.mark.broken( - ["bigquery"], reason="BigQuery returns DateOffset arrays", raises=AssertionError + ["bigquery", "duckdb"], + reason="BigQuery returns DateOffset arrays", + raises=AssertionError, ) @pytest.mark.xfail_version( datafusion=["datafusion<31"], From b6f1fd02615fc696ff032e4c82253c2b008cb165 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 11:58:03 -0400 Subject: [PATCH 131/222] chore: get most of dot sql tests passing --- ibis/backends/duckdb/compiler/relations.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/ibis/backends/duckdb/compiler/relations.py 
b/ibis/backends/duckdb/compiler/relations.py index 723752be3df7..1cf1a19079a2 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -3,6 +3,7 @@ import functools from collections.abc import Mapping from functools import partial +from typing import Any import sqlglot as sg @@ -248,3 +249,21 @@ def _fillna(op: ops.FillNa, *, table, **kw): for col in op.schema.keys() ] return sg.select(*exprs).from_(table) + + +@translate_rel.register +def _view(op: ops.View, *, child, name: str, **_): + # TODO: find a better way to do this + backend = op.child.to_expr()._find_backend() + temp_view_src = backend._compile_temp_view( + table_name=name, source=sg.select("*").from_(child) + ) + backend.con.execute(temp_view_src.sql("duckdb")) + return sg.table(name) + + +@translate_rel.register +def _sql_string_view(op: ops.SQLStringView, query: str, **_: Any): + table = sg.table(op.name) + src = sg.parse_one(query, read="duckdb") + return sg.select("*").from_(table).with_(table, as_=src) From 61b92f8180da725957e064b885b29bb1ca4c5416 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 12:00:05 -0400 Subject: [PATCH 132/222] chore: avoid depending on physical table; eventually the classes should be decoupled --- ibis/backends/duckdb/compiler/relations.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index 1cf1a19079a2..fff0aa4115a3 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -24,8 +24,10 @@ def _dummy(op: ops.DummyTable, **kw): return sg.select(*map(partial(translate_val, **kw), op.values)) -@translate_rel.register(ops.PhysicalTable) -def _physical_table(op: ops.PhysicalTable, **_): +@translate_rel.register(ops.DatabaseTable) +@translate_rel.register(ops.UnboundTable) +@translate_rel.register(ops.InMemoryTable) +def _physical_table(op, **_): return sg.expressions.Table(this=sg.to_identifier(op.name, quoted=True)) From f1e9b1a6e30d9a80e7c3595404908f13f56a9364 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 12:01:29 -0400 Subject: [PATCH 133/222] chore: let parsing fall through to sqlglot --- ibis/backends/duckdb/compiler/relations.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index fff0aa4115a3..91f54170c73e 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -267,5 +267,4 @@ def _view(op: ops.View, *, child, name: str, **_): @translate_rel.register def _sql_string_view(op: ops.SQLStringView, query: str, **_: Any): table = sg.table(op.name) - src = sg.parse_one(query, read="duckdb") - return sg.select("*").from_(table).with_(table, as_=src) + return sg.select("*").from_(table).with_(table, as_=query, dialect="duckdb") From 8d1c119623dfe53ff1e688d8eb0cb3e1323bb860 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 12:11:04 -0400 Subject: [PATCH 134/222] chore: dot sql all working yay --- ibis/backends/duckdb/__init__.py | 17 ++++++++++++----- ibis/backends/duckdb/compiler/relations.py | 5 +---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 
bd77b70fd0d7..3feafbd1590a 100644
--- a/ibis/backends/duckdb/__init__.py
+++ b/ibis/backends/duckdb/__init__.py
@@ -373,6 +373,7 @@ def do_connect(
         self.con.execute("SET enable_progress_bar = false")
 
         self._record_batch_readers_consumed = {}
+        self._temp_views: set[str] = set()
 
     def _from_url(self, url: str, **kwargs) -> BaseBackend:
         """Connect to a backend using a URL `url`.
@@ -598,18 +599,24 @@ def _register_failure(self):
             f"please call one of {msg} directly"
         )
 
-    def _compile_temp_view(self, table_name, source):
-        return sg.expressions.Create(
-            this=sg.expressions.Identifier(
+    def _create_temp_view(self, table_name, source):
+        if table_name not in self._temp_views and table_name in self.list_tables():
+            raise ValueError(
+                f"{table_name} already exists as a non-temporary table or view"
+            )
+        src = sg.exp.Create(
+            this=sg.exp.Identifier(
                 this=table_name, quoted=True
             ),  # CREATE ... 'table_name'
             kind="VIEW",  # VIEW
             replace=True,  # OR REPLACE
-            properties=sg.expressions.Properties(
-                expressions=[sg.expressions.TemporaryProperty()]  # TEMPORARY
+            properties=sg.exp.Properties(
+                expressions=[sg.exp.TemporaryProperty()]  # TEMPORARY
             ),
             expression=source,  # AS ...
         )
+        self.raw_sql(src.sql("duckdb"))
+        self._temp_views.add(table_name)
 
     @util.experimental
     def read_json(

diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py
index 91f54170c73e..fcaa47822ddd 100644
--- a/ibis/backends/duckdb/compiler/relations.py
+++ b/ibis/backends/duckdb/compiler/relations.py
@@ -257,10 +257,7 @@ def _fillna(op: ops.FillNa, *, table, **kw):
 def _view(op: ops.View, *, child, name: str, **_):
     # TODO: find a better way to do this
     backend = op.child.to_expr()._find_backend()
-    temp_view_src = backend._compile_temp_view(
-        table_name=name, source=sg.select("*").from_(child)
-    )
-    backend.con.execute(temp_view_src.sql("duckdb"))
+    backend._create_temp_view(table_name=name, source=sg.select("*").from_(child))
     return sg.table(name)
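Note: the sqlglot construction in `_create_temp_view` can be sanity-checked in
isolation. A sketch that mirrors the expression built above, using a throwaway view
name (the exact rendered text may vary by sqlglot version):

    import sqlglot as sg

    src = sg.exp.Create(
        this=sg.exp.Identifier(this="my_view", quoted=True),  # CREATE ... "my_view"
        kind="VIEW",  # VIEW
        replace=True,  # OR REPLACE
        properties=sg.exp.Properties(
            expressions=[sg.exp.TemporaryProperty()]  # TEMPORARY
        ),
        expression=sg.select("*").from_("some_table"),  # AS ...
    )
    print(src.sql("duckdb"))
    # CREATE OR REPLACE TEMPORARY VIEW "my_view" AS SELECT * FROM some_table

From 7cb799007b3b21e3563666b6bd5274613b5bd2ab Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Sun, 17 Sep 2023 12:11:26 -0400
Subject: [PATCH 135/222] chore: `sg.expressions` -> `sg.exp`

---
 ibis/backends/duckdb/__init__.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py
index 3feafbd1590a..bd3a7fe757c6 100644
--- a/ibis/backends/duckdb/__init__.py
+++ b/ibis/backends/duckdb/__init__.py
@@ -102,14 +102,14 @@ def create_table(
             raise exc.IbisError("The schema or obj parameter is required")
 
         table_identifier = sg.to_identifier(name, quoted=True)
-        create_expr = sg.expressions.Create(
+        create_expr = sg.exp.Create(
             kind="TABLE",  # TABLE
             replace=overwrite,  # OR REPLACE
         )
 
         if temp:
-            create_expr.args["properties"] = sg.expressions.Properties(
-                expressions=[sg.expressions.TemporaryProperty()]  # TEMPORARY
+            create_expr.args["properties"] = sg.exp.Properties(
+                expressions=[sg.exp.TemporaryProperty()]  # TEMPORARY
             )
 
         if obj is not None and not isinstance(obj, ir.Expr):
@@ -126,20 +126,20 @@ def create_table(
             create_expr.args["this"] = table_identifier  # t0
         else:
             # Schema -> Table -> [ColumnDefs]
-            schema_expr = sg.expressions.Schema(
-                this=sg.expressions.Table(this=table_identifier),
+            schema_expr = sg.exp.Schema(
+                this=sg.exp.Table(this=table_identifier),
                 expressions=[
-                    sg.expressions.ColumnDef(
+                    sg.exp.ColumnDef(
                         this=sg.to_identifier(key, quoted=False),
                         kind=DuckDBType.from_ibis(typ),
                     )
                     if 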
typ.nullable - else sg.expressions.ColumnDef( + else sg.exp.ColumnDef( this=sg.to_identifier(key, quoted=False), kind=DuckDBType.from_ibis(typ), constraints=[ - sg.expressions.ColumnConstraint( - kind=sg.expressions.NotNullColumnConstraint() + sg.exp.ColumnConstraint( + kind=sg.exp.NotNullColumnConstraint() ) ], ) @@ -241,8 +241,8 @@ def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema """ qualified_name = self._fully_qualified_name(table_name, database) if isinstance(qualified_name, str): - qualified_name = sg.expressions.Identifier(this=qualified_name, quoted=True) - query = sg.expressions.Describe(this=qualified_name) + qualified_name = sg.exp.Identifier(this=qualified_name, quoted=True) + query = sg.exp.Describe(this=qualified_name) results = self.raw_sql(query) names, types, nulls, *_ = results.fetch_arrow_table() names = names.to_pylist() @@ -497,7 +497,7 @@ def create_schema( ) name = sg.to_identifier(database, quoted=True) - return sg.expressions.Create( + return sg.exp.Create( this=name, kind="SCHEMA", replace=force, @@ -512,7 +512,7 @@ def drop_schema( ) name = sg.to_identifier(database, quoted=True) - return sg.expressions.Drop( + return sg.exp.Drop( this=name, kind="SCHEMA", replace=force, From cf503d754126a1224d6e3d48fc8bbcda87b342dd Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 12:14:45 -0400 Subject: [PATCH 136/222] chore: reuse `_create_temp_view`; down to a single test failure locally (!) --- ibis/backends/duckdb/__init__.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index bd3a7fe757c6..84445bf30a1a 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -654,14 +654,13 @@ def read_json( options = [f"{key}={val}" for key, val in kwargs.items()] - sg_view_expr = self._compile_temp_view( + self._create_temp_view( table_name, sg.select("*").from_( sg.func("read_json_auto", normalize_filenames(source_list), *options) ), ) - self.raw_sql(sg_view_expr) return self.table(table_name) def read_csv( @@ -712,12 +711,11 @@ def read_csv( for key, val in kwargs.items() ] - sg_view_expr = self._compile_temp_view( + self._create_temp_view( table_name, sg.select("*").from_(sg.func("read_csv", source_list, *options)), ) - self.raw_sql(sg_view_expr) return self.table(table_name) def read_parquet( @@ -774,13 +772,11 @@ def _read_parquet_duckdb_native( else: pq_func = sg.func("read_parquet", source_list) - sg_view_expr = self._compile_temp_view( + self._create_temp_view( table_name, sg.select("*").from_(pq_func), ) - self.raw_sql(sg_view_expr) - def _read_parquet_pyarrow_dataset( self, source_list: str | Iterable[str], table_name: str, **kwargs: Any ) -> None: @@ -895,13 +891,12 @@ def read_postgres( ) self._load_extensions(["postgres_scanner"]) - sg_view_expr = self._compile_temp_view( + self._create_temp_view( table_name, sg.select("*").from_( sg.func("postgres_scan_pushdown", uri, schema, table_name) ), ) - self.raw_sql(sg_view_expr) return self.table(table_name) @@ -948,7 +943,7 @@ def read_sqlite(self, path: str | Path, table_name: str | None = None) -> ir.Tab raise ValueError("`table_name` is required when registering a sqlite table") self._load_extensions(["sqlite"]) - sg_view_expr = self._compile_temp_view( + self._create_temp_view( table_name, sg.select("*").from_( sg.func( @@ -956,7 +951,6 @@ def read_sqlite(self, path: str | Path, table_name: str | None 
= None) -> ir.Tab ) ), ) - self.raw_sql(sg_view_expr) return self.table(table_name) From b9e3aae4eea3375f01a4ca07ff4312854b28cd49 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 12:15:59 -0400 Subject: [PATCH 137/222] chore: slightly better comment --- ibis/backends/duckdb/compiler/relations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index fcaa47822ddd..98192c70b811 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -255,7 +255,7 @@ def _fillna(op: ops.FillNa, *, table, **kw): @translate_rel.register def _view(op: ops.View, *, child, name: str, **_): - # TODO: find a better way to do this + # TODO: find a way to do this without creating a temporary view backend = op.child.to_expr()._find_backend() backend._create_temp_view(table_name=name, source=sg.select("*").from_(child)) return sg.table(name) From 17932f1646b49c4a4d023a7c4839df80f924b16d Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 13:11:41 -0400 Subject: [PATCH 138/222] chore: remove bogus `external_tables` argument --- ibis/backends/duckdb/__init__.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 84445bf30a1a..b34c73b1e851 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -439,11 +439,7 @@ def _log(self, sql: str) -> None: util.log(sql) def execute( - self, - expr: ir.Expr, - limit: str | None = "default", - external_tables: Mapping[str, pd.DataFrame] | None = None, - **kwargs: Any, + self, expr: ir.Expr, limit: str | None = "default", **kwargs: Any ) -> Any: """Execute an expression.""" From 9b2ffe7964d1a9f25567a447230660c45b92fe76 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 13:14:04 -0400 Subject: [PATCH 139/222] chore: remove dialect from set ops --- ibis/backends/duckdb/compiler/relations.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index 98192c70b811..09f900387726 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -151,19 +151,16 @@ def _query(op: ops.SQLQueryResult, *, aliases, **_): @translate_rel.register def _set_op(op: ops.SetOp, *, left, right, **_): - dialect = "duckdb" - if isinstance(left, sg.exp.Table): - left = sg.select("*", dialect=dialect).from_(left, dialect=dialect) + left = sg.select("*").from_(left) if isinstance(right, sg.exp.Table): - right = sg.select("*", dialect=dialect).from_(right, dialect=dialect) + right = sg.select("*").from_(right) return _SET_OP_FUNC[type(op)]( left.args.get("this", left), right.args.get("this", right), distinct=op.distinct, - dialect=dialect, ) From 0fef49db98301f82296e0b603b61a065db586eeb Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 13:35:32 -0400 Subject: [PATCH 140/222] chore: remove timezone handling code; it may actually be correct, file an issue --- ibis/backends/duckdb/compiler/values.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 
266c27d38786..156dd44c26ca 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -365,8 +365,6 @@ def _cast(op, **kw): ) elif to.is_timestamp() and op.arg.dtype.is_integer(): return sg.func("to_timestamp", arg) - elif to.is_timestamp() and (tz := to.timezone) is not None: - return sg.func("timezone", sg_literal(tz), arg) return sg.cast(expression=arg, to=translate_val(to, **kw)) From d3f8d9190d68df92a48fe9262c2398e4809ccea5 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 13:47:50 -0400 Subject: [PATCH 141/222] chore: get builtin agg udfs working --- ibis/backends/duckdb/__init__.py | 6 +++++- ibis/backends/duckdb/compiler/values.py | 5 +++++ ibis/backends/duckdb/tests/test_udf.py | 13 ++++++------- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index b34c73b1e851..0cde035d108d 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -1296,7 +1296,8 @@ def _register_udfs(self, expr: ir.Expr) -> None: con.remove_function(udf_node.__class__.__name__) registration_func = compile_func(udf_node) - registration_func(con) + if registration_func is not None: + registration_func(con) def _compile_udf(self, udf_node: ops.ScalarUDF) -> None: func = udf_node.__func__ @@ -1318,6 +1319,9 @@ def register_udf(con): _compile_python_udf = _compile_udf _compile_pyarrow_udf = _compile_udf + def _compile_builtin_udf(self, udf_node: ops.ScalarUDF) -> None: + """No op.""" + def _compile_pandas_udf(self, _: ops.ScalarUDF) -> None: raise NotImplementedError("duckdb doesn't support pandas UDFs") diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 156dd44c26ca..b89924dd6257 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1511,3 +1511,8 @@ def _rowid(op, *, aliases, **_) -> str: def _scalar_udf(op, **kw) -> str: funcname = op.__class__.__name__ return sg.func(funcname, *(translate_val(arg, **kw) for arg in op.args)) + + +@translate_val.register(ops.AggUDF) +def _scalar_udf(op, **kw) -> str: + return _aggregate(op, op.__class__.__name__, where=op.where, **kw) diff --git a/ibis/backends/duckdb/tests/test_udf.py b/ibis/backends/duckdb/tests/test_udf.py index 0292796c749d..8ba95f1feb78 100644 --- a/ibis/backends/duckdb/tests/test_udf.py +++ b/ibis/backends/duckdb/tests/test_udf.py @@ -50,9 +50,7 @@ def test_builtin_scalar(con, func): a, b = "duck", "luck" expr = func(a, b) - with con.begin() as c: - expected = c.exec_driver_sql(f"SELECT {func.__name__}({a!r}, {b!r})").scalar() - + expected = con.raw_sql(f"SELECT {func.__name__}({a!r}, {b!r})").df().squeeze() assert con.execute(expr) == expected @@ -79,9 +77,10 @@ def test_builtin_agg(con, func): data = ibis.memtable({"a": raw_data}) expr = func(data.a) - with con.begin() as c: - expected = c.exec_driver_sql( - f"SELECT {func.__name__}(a) FROM UNNEST({raw_data!r}) _ (a)" - ).scalar() + expected = ( + con.raw_sql(f"SELECT {func.__name__}(a) FROM UNNEST({raw_data!r}) _ (a)") + .df() + .squeeze() + ) assert con.execute(expr) == expected From 6b37361fe34e812c2b5a82296083329968ddb0a5 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 13:48:38 -0400 Subject: [PATCH 142/222] chore: remove the need to pass where --- ibis/backends/duckdb/compiler/values.py | 8 ++++---- 1 file changed, 4 insertions(+), 
4 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index b89924dd6257..06dcf53dcc39 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -407,14 +407,14 @@ def _apply_agg_filter(expr, *, where, **kw): return expr -def _aggregate(op, func, *, where, **kw): +def _aggregate(op, func, **kw): args = [ translate_val(arg, **kw) for argname, arg in zip(op.argnames, op.args) if argname not in ("where", "how") ] agg = sg.func(func, *args) - return _apply_agg_filter(agg, where=where, **kw) + return _apply_agg_filter(agg, where=op.where, **kw) @translate_val.register(ops.NotAny) @@ -1133,7 +1133,7 @@ def _arbitrary(op, **kw): "first": "first", "last": "last", } - return _aggregate(op, functions[op.how], where=op.where, **kw) + return _aggregate(op, functions[op.how], **kw) @translate_val.register(ops.FindInSet) @@ -1515,4 +1515,4 @@ def _scalar_udf(op, **kw) -> str: @translate_val.register(ops.AggUDF) def _scalar_udf(op, **kw) -> str: - return _aggregate(op, op.__class__.__name__, where=op.where, **kw) + return _aggregate(op, op.__class__.__name__, **kw) From a23c33ce86628620ca168c99a65077c04bc0b976 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 13:56:09 -0400 Subject: [PATCH 143/222] chore: clean up aggregates --- ibis/backends/duckdb/compiler/values.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 06dcf53dcc39..0ec74fd07edb 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -244,17 +244,20 @@ def _literal(op, **kw): } -def _agg(func_name): - def formatter(op, **kw): - return _aggregate(op, func_name, where=op.where, **kw) - - return formatter +def _aggregate(op, func, **kw): + args = [ + translate_val(arg, **kw) + for argname, arg in zip(op.argnames, op.args) + if argname not in ("where", "how") + ] + agg = sg.func(func, *args) + return _apply_agg_filter(agg, where=op.where, **kw) for _op, _name in _simple_ops.items(): assert isinstance(type(_op), type), type(_op) if issubclass(_op, ops.Reduction): - translate_val.register(_op)(_agg(_name)) + translate_val.register(_op)(partial(_aggregate, func=_name)) else: @translate_val.register(_op) @@ -407,16 +410,6 @@ def _apply_agg_filter(expr, *, where, **kw): return expr -def _aggregate(op, func, **kw): - args = [ - translate_val(arg, **kw) - for argname, arg in zip(op.argnames, op.args) - if argname not in ("where", "how") - ] - agg = sg.func(func, *args) - return _apply_agg_filter(agg, where=op.where, **kw) - - @translate_val.register(ops.NotAny) def _not_any(op, **kw): return translate_val(ops.All(ops.Not(op.arg), where=op.where), **kw) From 473e33126a04abc0c930a8509ad6a63ac906a45e Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 13:56:22 -0400 Subject: [PATCH 144/222] chore: unwrap aliases in struct field access --- ibis/backends/duckdb/compiler/values.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 0ec74fd07edb..1e0dd82cf421 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1211,9 +1211,13 @@ def _struct_column(op, **kw): ) +def _unwrap_alias(op): + return op.arg if 
isinstance(op, ops.Alias) else op + + @translate_val.register(ops.StructField) def _struct_field(op, **kw): - arg = translate_val(op.arg, **kw) + arg = translate_val(_unwrap_alias(op.arg), **kw) return sg.exp.StructExtract( this=arg, expression=sg.exp.Literal(this=op.field, is_string=True) ) From 207152d9c6cd280ee5752b779525b40739e1adc5 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 14:46:40 -0400 Subject: [PATCH 145/222] chore: remove last vestiges of `Schema.apply_to` --- ibis/backends/pyspark/__init__.py | 3 ++- ibis/expr/schema.py | 16 ++-------------- ibis/expr/tests/test_schema.py | 17 ----------------- 3 files changed, 4 insertions(+), 32 deletions(-) diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index 9599fde7a73b..5b82be143853 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -31,6 +31,7 @@ from ibis.backends.pyspark.client import PySparkTable from ibis.backends.pyspark.compiler import PySparkExprTranslator from ibis.backends.pyspark.datatypes import PySparkType +from ibis.formats.pandas import PandasData if TYPE_CHECKING: from collections.abc import Sequence @@ -233,7 +234,7 @@ def close(self): def fetch_from_cursor(self, cursor, schema): df = cursor.query.toPandas() # blocks until finished - return schema.apply_to(df) + return PandasData.convert_table(df, schema) def raw_sql(self, query: str) -> _PySparkCursor: query = self._session.sql(query) diff --git a/ibis/expr/schema.py b/ibis/expr/schema.py index d3b1841d517b..5eaeec19f9ec 100644 --- a/ibis/expr/schema.py +++ b/ibis/expr/schema.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterable, Iterator, Mapping -from typing import TYPE_CHECKING, Any +from typing import Any import ibis.expr.datatypes as dt from ibis.common.annotations import attribute @@ -10,10 +10,7 @@ from ibis.common.exceptions import InputTypeError, IntegrityError from ibis.common.grounds import Concrete from ibis.common.patterns import Coercible -from ibis.util import deprecated, indent - -if TYPE_CHECKING: - import pandas as pd +from ibis.util import indent class Schema(Concrete, Coercible, MapSet): @@ -207,15 +204,6 @@ def name_at_position(self, i: int) -> str: """ return self.names[i] - @deprecated( - as_of="6.0", - instead="use ibis.formats.pandas.PandasConverter.convert_frame() instead", - ) - def apply_to(self, df: pd.DataFrame) -> pd.DataFrame: - from ibis.formats.pandas import PandasData - - return PandasData.convert_table(df, self) - @lazy_singledispatch def schema(value: Any) -> Schema: diff --git a/ibis/expr/tests/test_schema.py b/ibis/expr/tests/test_schema.py index 7d926a1ee198..99e3cebe0956 100644 --- a/ibis/expr/tests/test_schema.py +++ b/ibis/expr/tests/test_schema.py @@ -5,7 +5,6 @@ from typing import NamedTuple import numpy as np -import pandas.testing as tm import pyarrow as pa import pytest @@ -180,22 +179,6 @@ def df(): return pd.DataFrame({"A": pd.Series([1], dtype="int8"), "b": ["x"]}) -def test_apply_to_column_rename(df): - schema = sch.Schema({"a": "int8", "B": "string"}) - expected = df.rename({"A": "a", "b": "B"}, axis=1) - with pytest.warns(FutureWarning): - df = schema.apply_to(df.copy()) - tm.assert_frame_equal(df, expected) - - -def test_apply_to_column_order(df): - schema = sch.Schema({"a": "int8", "b": "string"}) - expected = df.rename({"A": "a"}, axis=1) - with pytest.warns(FutureWarning): - new_df = schema.apply_to(df.copy()) - 
tm.assert_frame_equal(new_df, expected) - - def test_api_accepts_schema_objects(): s1 = sch.schema(dict(a="int", b="str")) s2 = sch.schema(s1) From 9ea9f970930b3f89768e36ad20c4eb0351470f4d Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 15:10:40 -0400 Subject: [PATCH 146/222] chore: fix strip functions --- ibis/backends/duckdb/compiler/values.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 1e0dd82cf421..28a3847dee0e 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -3,6 +3,7 @@ import calendar import functools import math +import string from functools import partial from typing import TYPE_CHECKING, Any @@ -212,9 +213,6 @@ def _literal(op, **kw): ops.EndsWith: "suffix", ops.LPad: "lpad", ops.RPad: "rpad", - ops.LStrip: "ltrim", - ops.RStrip: "rtrim", - ops.Strip: "trim", ops.StringAscii: "ascii", ops.StrRight: "right", # Other operations @@ -683,6 +681,21 @@ def _interval_from_integer(op, **kw): ### String Instruments +@translate_val.register(ops.Strip) +def _strip(op, **kw): + return sg.func("trim", translate_val(op.arg, **kw), sg_literal(string.whitespace)) + + +@translate_val.register(ops.RStrip) +def _rstrip(op, **kw): + return sg.func("rtrim", translate_val(op.arg, **kw), sg_literal(string.whitespace)) + + +@translate_val.register(ops.LStrip) +def _lstrip(op, **kw): + return sg.func("ltrim", translate_val(op.arg, **kw), sg_literal(string.whitespace)) + + @translate_val.register(ops.Substring) def _substring(op, **kw): arg = translate_val(op.arg, **kw) From 7635882505529847a4d51bc62023fe359ee06c61 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 15:11:10 -0400 Subject: [PATCH 147/222] chore: ensure that null handling is done with IdenticalTo when `==` and `!=` are used --- ibis/expr/types/generic.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 67ed1e847979..ef40fdc98eac 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -851,9 +851,13 @@ def __hash__(self) -> int: return super().__hash__() def __eq__(self, other: Value) -> ir.BooleanValue: + if other is None: + return _binop(ops.IdenticalTo, self, other) return _binop(ops.Equals, self, other) def __ne__(self, other: Value) -> ir.BooleanValue: + if other is None: + return ~self.__eq__(other) return _binop(ops.NotEquals, self, other) def __ge__(self, other: Value) -> ir.BooleanValue: From 544aa9ddba56b78210073b4a75f1a63fc954a2a7 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 15:20:35 -0400 Subject: [PATCH 148/222] chore: fix array concat to preserve existing behavior --- ibis/backends/duckdb/compiler/values.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 28a3847dee0e..21df941be87e 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -221,7 +221,7 @@ def _literal(op, **kw): ops.Unnest: "unnest", ops.Degrees: "degrees", ops.Radians: "radians", - ops.NullIf: "nullIf", + ops.NullIf: "nullif", ops.MapLength: "cardinality", ops.MapKeys: "map_keys", ops.MapValues: "map_values", @@ -885,10 +885,10 @@ def _in_column(op, 
**kw): @translate_val.register(ops.ArrayConcat) def _array_concat(op, **kw): - return sg.func( - "flatten", - sg.exp.Array.from_arg_list([translate_val(arg, **kw) for arg in op.arg]), - ) + result, *rest = map(partial(translate_val, **kw), op.arg) + for arg in rest: + result = sg.func("list_concat", result, arg) + return result @translate_val.register(ops.ArrayRepeat) From 2c3ad0a413dbfc5eaafbce67404af3ad30595ad2 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 15:23:13 -0400 Subject: [PATCH 149/222] chore: fix __xor__ --- ibis/backends/duckdb/compiler/values.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 21df941be87e..71ced716796e 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1342,12 +1342,11 @@ def _xor(op, **kw): # https://github.com/tobymao/sqlglot/issues/2238 left = translate_val(op.left, **kw) right = translate_val(op.right, **kw) - return sg.exp.And( - this=sg.exp.Paren(this=sg.exp.Or(this=left, expression=right)), - expression=sg.exp.Paren( - this=sg.exp.Not(this=sg.exp.And(this=left, expression=right)) - ), + result = sg.exp.And( + this=sg.exp.Or(this=left, expression=right), + expression=sg.exp.Not(this=sg.exp.And(this=left, expression=right)), ) + return result ### Ordering From 286b3bc5ec035a8e02af9a2cc305a71697bb634d Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 15:24:30 -0400 Subject: [PATCH 150/222] chore: adjust `show_sql` doctest --- ibis/expr/sql.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ibis/expr/sql.py b/ibis/expr/sql.py index 75b4845c5aaa..6989869317eb 100644 --- a/ibis/expr/sql.py +++ b/ibis/expr/sql.py @@ -314,12 +314,12 @@ def show_sql( >>> expr = t.select(c=_.a * 2) >>> ibis.show_sql(expr) # duckdb dialect by default SELECT - t0.a * CAST(2 AS TINYINT) AS c - FROM t AS t0 - >>> ibis.show_sql(expr, dialect="mysql") - SELECT - t0.a * 2 AS c - FROM t AS t0 + ( + t0.a + ) * ( + CAST(2 AS TINYINT) + ) AS c + FROM "t" AS t0 """ print(to_sql(expr, dialect=dialect), file=file) From 3b464ffb7b90031901470ead2930fc0bdcd1d54c Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Sep 2023 15:26:03 -0400 Subject: [PATCH 151/222] chore: get outta here clickhouse --- ibis/backends/duckdb/compiler/values.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 71ced716796e..e0d3b556651c 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1190,7 +1190,7 @@ def _table_array_view(op, *, cache, **kw): @translate_val.register(ops.ExistsSubquery) def _exists_subquery(op, **kw): - from ibis.backends.clickhouse.compiler import translate + from ibis.backends.duckdb.compiler import translate foreign_table = translate(op.foreign_table, {}) predicates = translate_val(op.predicates, **kw) From 706d881067d42bc64c47e1138ca11498672d0d4f Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 05:09:23 -0400 Subject: [PATCH 152/222] chore: remove a bunch of version checking in the duckdb backend --- ibis/backends/duckdb/__init__.py | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff 
--git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 0cde035d108d..b09cf3f102e6 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -14,7 +14,6 @@ import pyarrow as pa import sqlglot as sg import toolz -from packaging.version import parse as vparse import ibis import ibis.common.exceptions as exc @@ -367,10 +366,6 @@ def do_connect( # Default timezone self.con.execute("SET TimeZone = 'UTC'") - # the progress bar in duckdb <0.8.0 causes kernel crashes in - # jupyterlab, fixed in https://github.com/duckdb/duckdb/pull/6831 - if vparse(duckdb.__version__) < vparse("0.8.0"): - self.con.execute("SET enable_progress_bar = false") self._record_batch_readers_consumed = {} self._temp_views: set[str] = set() @@ -641,10 +636,6 @@ def read_json( Table An ibis table expression """ - if (version := vparse(self.version)) < vparse("0.7.0"): - raise exc.IbisError( - f"`read_json` requires duckdb >= 0.7.0, duckdb {version} is installed" - ) if not table_name: table_name = util.gen_name("read_json") @@ -1036,24 +1027,17 @@ def to_pyarrow_batches( table = expr.as_table() sql = self.compile(table, limit=limit, params=params) - # handle the argument name change in duckdb 0.8.0 - fetch_record_batch = ( - (lambda cur: cur.fetch_record_batch(rows_per_batch=chunk_size)) - if vparse(duckdb.__version__) >= vparse("0.8.0") - else (lambda cur: cur.fetch_record_batch(chunk_size=chunk_size)) - ) - - def batch_producer(table): - yield from fetch_record_batch(table) + def batch_producer(cur): + yield from cur.fetch_record_batch(rows_per_batch=chunk_size) # TODO: check that this is still handled correctly # batch_producer keeps the `self.con` member alive long enough to # exhaust the record batch reader, even if the backend or connection # have gone out of scope in the caller - table = self.raw_sql(sql) + result = self.raw_sql(sql) return pa.RecordBatchReader.from_batches( - expr.as_table().schema().to_pyarrow(), batch_producer(table) + expr.as_table().schema().to_pyarrow(), batch_producer(result) ) def to_pyarrow( From baf7e7580b54609d0d39a32776a2d52ffa57b0a8 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 05:09:53 -0400 Subject: [PATCH 153/222] chore: note why `_define_udf_translation_rules` is not necessary --- ibis/backends/duckdb/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index b09cf3f102e6..e5b641c3404a 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -68,8 +68,7 @@ class Backend(BaseBackend, CanCreateSchema): supports_create_or_replace = True def _define_udf_translation_rules(self, expr): - # TODO: - ... 
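Aside (editor's sketch, not part of the patch series): the `pa.RecordBatchReader.from_batches` pattern used in `to_pyarrow_batches` above, reduced to a standalone example with a stand-in producer:

```python
import pyarrow as pa

schema = pa.schema([("a", pa.int64())])

def batch_producer():
    # stand-in for cur.fetch_record_batch(rows_per_batch=chunk_size)
    yield pa.record_batch([pa.array([1, 2, 3])], schema=schema)

reader = pa.RecordBatchReader.from_batches(schema, batch_producer())
assert reader.read_all().num_rows == 3
```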
+ """No-op: the rules are defined in the compiler.""" @property def current_database(self) -> str: From 1cb2f521bb3d18483c00b6b4aff81cf70569ab6d Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 05:17:31 -0400 Subject: [PATCH 154/222] chore: avoid overkill use of arrow for current_database and current_schema --- ibis/backends/duckdb/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index e5b641c3404a..205b0660533b 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -72,11 +72,13 @@ def _define_udf_translation_rules(self, expr): @property def current_database(self) -> str: - return self.raw_sql("SELECT CURRENT_DATABASE()").arrow()[0][0].as_py() + (db,) = self.raw_sql("SELECT CURRENT_DATABASE()").fetchone() + return db @property def current_schema(self) -> str: - return self.raw_sql("SELECT CURRENT_SCHEMA()").arrow()[0][0].as_py() + (schema,) = self.raw_sql("SELECT CURRENT_SCHEMA()").fetchone() + return schema def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: with contextlib.suppress(AttributeError): From 569eda2718c2c7467caa60461e193172a606ba95 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 05:18:23 -0400 Subject: [PATCH 155/222] chore: remove unnecessary `DuckDBTable` class --- ibis/backends/duckdb/__init__.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 205b0660533b..69343e48a369 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -51,18 +51,6 @@ def normalize_filenames(source_list): } -class DuckDBTable(ir.Table): - """References a physical table in DuckDB.""" - - @property - def _client(self): - return self.op().source - - @property - def name(self): - return self.op().name - - class Backend(BaseBackend, CanCreateSchema): name = "duckdb" supports_create_or_replace = True @@ -209,7 +197,7 @@ def table(self, name: str, database: str | None = None) -> ir.Table: """ schema = self.get_schema(name, database=database) qname = self._fully_qualified_name(name, database) - return DuckDBTable(ops.DatabaseTable(qname, schema, self)) + return ops.DatabaseTable(qname, schema, self).to_expr() def _fully_qualified_name(self, name: str, database: str | None) -> str: return name From 457270278ccfb487d68b54adc5d81be28247dc73 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 05:48:26 -0400 Subject: [PATCH 156/222] chore: fix broken xor --- ibis/backends/duckdb/compiler/values.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index e0d3b556651c..e618a6026a6a 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1340,13 +1340,11 @@ def formatter(op, **kw): @translate_val.register(ops.Xor) def _xor(op, **kw): # https://github.com/tobymao/sqlglot/issues/2238 - left = translate_val(op.left, **kw) - right = translate_val(op.right, **kw) - result = sg.exp.And( - this=sg.exp.Or(this=left, expression=right), - expression=sg.exp.Not(this=sg.exp.And(this=left, expression=right)), + left = translate_val(op.left, **kw).sql("duckdb") + right = translate_val(op.right, 
**kw).sql("duckdb") + return sg.parse_one( + f"({left} OR {right}) AND NOT ({left} AND {right})", read="duckdb" ) - return result ### Ordering From 20142449e088f61cf8c81b0ac08574c43bf2ffa0 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 05:53:26 -0400 Subject: [PATCH 157/222] chore: remove `aliases or {}` pattern --- ibis/backends/duckdb/compiler/values.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index e618a6026a6a..298679688679 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -47,8 +47,7 @@ def _val_table_node(op, *, aliases, needs_alias=False, **_): @translate_val.register(ops.TableColumn) def _column(op, *, aliases, **_): - table_name = (aliases or {}).get(op.table) - return sg.column(op.name, table=table_name) + return sg.column(op.name, table=aliases.get(op.table)) @translate_val.register(ops.Alias) @@ -1511,7 +1510,7 @@ def _argument(op, **_): @translate_val.register(ops.RowID) def _rowid(op, *, aliases, **_) -> str: table = op.table - return sg.column(op.name, (aliases or {}).get(table, table.name)) + return sg.column(op.name, aliases.get(table, table.name)) @translate_val.register(ops.ScalarUDF) From 325128588efb097578e2fff31461943e0a218dd8 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 05:53:45 -0400 Subject: [PATCH 158/222] style: rename `pieces` to `args` --- ibis/backends/duckdb/compiler/values.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 298679688679..2302bbea5d44 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1477,7 +1477,7 @@ def formatter(op, **kw): default = op.default arg_fmt = translate_val(arg, **kw) - pieces = [arg_fmt] + args = [arg_fmt] if default is not None: if offset is None: @@ -1487,13 +1487,13 @@ def formatter(op, **kw): default_fmt = translate_val(default, **kw) - pieces.append(offset_fmt) - pieces.append(default_fmt) + args.append(offset_fmt) + args.append(default_fmt) elif offset is not None: offset_fmt = translate_val(offset, **kw) - pieces.append(offset_fmt) + args.append(offset_fmt) - return sg.func(name, *pieces) + return sg.func(name, *args) return formatter From 085c21296953941d5eec4a244b81cb022211b3d9 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 05:54:07 -0400 Subject: [PATCH 159/222] chore: rename duplicate `_scalar_udf` rule to `_agg_udf` --- ibis/backends/duckdb/compiler/values.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 2302bbea5d44..8c30dc9fe4dc 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1520,5 +1520,5 @@ def _scalar_udf(op, **kw) -> str: @translate_val.register(ops.AggUDF) -def _scalar_udf(op, **kw) -> str: +def _agg_udf(op, **kw) -> str: return _aggregate(op, op.__class__.__name__, **kw) From 1c4e84000262f711f734bddea2c6c9a57c7dc8c4 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 05:59:45 -0400 Subject: [PATCH 160/222] chore: remove _fully_qualified_name use --- 
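Aside (editor's sketch, not part of the patch series): a quick truth-table check that the string built above really is XOR:

```python
# (l OR r) AND NOT (l AND r) is the standard boolean desugaring of XOR
for l in (False, True):
    for r in (False, True):
        assert ((l or r) and not (l and r)) == (l ^ r)
```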
ibis/backends/duckdb/__init__.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 69343e48a369..deb6db9f3c71 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -149,7 +149,7 @@ def create_view( database: str | None = None, overwrite: bool = False, ) -> ir.Table: - qualname = self._fully_qualified_name(name, database) + qualname = sg.table(name, db=database).sql(self.name) replace = "OR REPLACE " * overwrite query = self.compile(obj) code = f"CREATE {replace}VIEW {qualname} AS {query}" @@ -160,13 +160,13 @@ def create_view( def drop_table( self, name: str, database: str | None = None, force: bool = False ) -> None: - ident = self._fully_qualified_name(name, database) + ident = sg.table(name, db=database).sql(self.name) self.raw_sql(f"DROP TABLE {'IF EXISTS ' * force}{ident}") def drop_view( self, name: str, *, database: str | None = None, force: bool = False ) -> None: - name = self._fully_qualified_name(name, database) + name = sg.table(name, db=database).sql(self.name) if_exists = "IF EXISTS " * force self.raw_sql(f"DROP VIEW {if_exists}{name}") @@ -196,20 +196,7 @@ def table(self, name: str, database: str | None = None) -> ir.Table: Table expression """ schema = self.get_schema(name, database=database) - qname = self._fully_qualified_name(name, database) - return ops.DatabaseTable(qname, schema, self).to_expr() - - def _fully_qualified_name(self, name: str, database: str | None) -> str: - return name - # TODO: make this less bad - # calls to here from `drop_table` already have `main` prepended to the table name - # so what's the more robust way to deduplicate that identifier? - db = database or self.current_database - if name.startswith(db): - # This is a hack to get around nested quoting of table name - # e.g. '"main._ibis_temp_table_2"' - return name - return sg.table(name, db=db) # .sql(dialect="duckdb") + return ops.DatabaseTable(name, schema, self, namespace=database).to_expr() def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema: """Return a Schema object for the indicated table and database. 
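Aside (editor's sketch, not part of the patch series): the `sg.table(...).sql(...)` idiom that replaces `_fully_qualified_name` above, shown with hypothetical names:

```python
import sqlglot as sg

print(sg.table("t", db="main").sql("duckdb"))  # main.t
print(sg.table("t").sql("duckdb"))             # t
```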
@@ -227,9 +214,7 @@ def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema sch.Schema Ibis schema """ - qualified_name = self._fully_qualified_name(table_name, database) - if isinstance(qualified_name, str): - qualified_name = sg.exp.Identifier(this=qualified_name, quoted=True) + qualified_name = sg.table(table_name, database).sql(self.name) query = sg.exp.Describe(this=qualified_name) results = self.raw_sql(query) names, types, nulls, *_ = results.fetch_arrow_table() From 7add1664e73c83537ffe5d9fc6dab84d6112d2de Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 06:21:35 -0400 Subject: [PATCH 161/222] chore: fix listing tables, schemas and databases --- ibis/backends/duckdb/__init__.py | 81 +++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 21 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index deb6db9f3c71..f9456bf22679 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -176,10 +176,6 @@ def _load_into_cache(self, name, expr): def _clean_up_cached_table(self, op): self.drop_table(op.name) - def list_schemas(self): - out = self.raw_sql("SELECT current_schemas(True) as schemas").arrow() - return list(set(out["schemas"].to_pylist()[0])) - def table(self, name: str, database: str | None = None) -> ir.Table: """Construct a table expression. @@ -237,26 +233,69 @@ def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema ) def list_databases(self, like: str | None = None) -> list[str]: - result = self.raw_sql("PRAGMA database_list;") - results = result.fetch_arrow_table() + col = "catalog_name" + query = sg.select(sg.exp.Distinct(expressions=[sg.column(col)])).from_( + sg.table("schemata", db="information_schema") + ) + result = self.raw_sql(query) + dbs = result.fetch_arrow_table()[col] + return self._filter_with_like(dbs.to_pylist(), like) + + def list_schemas( + self, like: str | None = None, database: str | None = None + ) -> list[str]: + col = "schema_name" + query = sg.select(sg.exp.Distinct(expressions=[sg.column(col)])).from_( + sg.table("schemata", db="information_schema") + ) - if results: - _, databases, *_ = results - databases = databases.to_pylist() - else: - databases = [] - return self._filter_with_like(databases, like) + if database is not None: + query = query.where( + sg.condition( + sg.exp.EQ( + this=sg.column("catalog_name"), + expression=sg.exp.Literal(this=database, is_string=True), + ) + ) + ) - def list_tables(self, like: str | None = None) -> list[str]: - result = self.raw_sql("PRAGMA show_tables;") - results = result.fetch_arrow_table() + out = self.raw_sql(query).arrow() + return self._filter_with_like(out[col].to_pylist(), like=like) - if results: - tables, *_ = results - tables = tables.to_pylist() - else: - tables = [] - return self._filter_with_like(tables, like) + def list_tables( + self, + like: str | None = None, + database: str | None = None, + schema: str | None = None, + ) -> list[str]: + col = "table_name" + query = sg.select(sg.exp.Distinct(expressions=[sg.column(col)])).from_( + sg.table("tables", db="information_schema") + ) + + conditions = [] + + if database is not None: + conditions.append( + sg.exp.EQ( + this=sg.column("table_catalog"), + expression=sg.exp.Literal(this=database, is_string=True), + ) + ) + + if schema is not None: + conditions.append( + sg.exp.EQ( + this=sg.column("table_schema"), + expression=sg.exp.Literal(this=schema, 
is_string=True), + ) + ) + + if conditions: + query = query.where(sg.condition(sg.and_(*conditions))) + + out = self.raw_sql(query).arrow() + return self._filter_with_like(out["table_name"].to_pylist(), like) @classmethod def has_operation(cls, operation: type[ops.Value]) -> bool: From c155b44719206521b4cc07ab08591d984b19aff2 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 06:54:34 -0400 Subject: [PATCH 162/222] chore: make accessing schema information the same as upstream because of DESCRIBE bug in 0.8.1 --- ibis/backends/duckdb/__init__.py | 97 ++++++++++++++++++-------------- 1 file changed, 56 insertions(+), 41 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index f9456bf22679..a628c74f6881 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -6,6 +6,7 @@ import contextlib import os import warnings +from operator import itemgetter from pathlib import Path from typing import TYPE_CHECKING, Any @@ -45,6 +46,14 @@ def normalize_filenames(source_list): return list(map(util.normalize_filename, source_list)) +def strlit(s: str) -> sg.exp.Literal: + return sg.exp.Literal(this=s, is_string=True) + + +def eq(left, right) -> sg.exp.EQ: + return sg.exp.EQ(this=left, expression=right) + + _UDF_INPUT_TYPE_MAPPING = { InputType.PYARROW: duckdb.functional.ARROW, InputType.PYTHON: duckdb.functional.NATIVE, @@ -194,14 +203,18 @@ def table(self, name: str, database: str | None = None) -> ir.Table: schema = self.get_schema(name, database=database) return ops.DatabaseTable(name, schema, self, namespace=database).to_expr() - def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema: - """Return a Schema object for the indicated table and database. + def get_schema( + self, table_name: str, schema: str | None = None, database: str | None = None + ) -> sch.Schema: + """Compute the schema of a `table`. Parameters ---------- table_name May **not** be fully qualified. Use `database` if you want to qualify the identifier. 
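Aside (editor's sketch, not part of the patch series): all three `list_*` rewrites above reduce to `information_schema` queries built with sqlglot; for example:

```python
import sqlglot as sg

col = "table_name"
query = sg.select(sg.exp.Distinct(expressions=[sg.column(col)])).from_(
    sg.table("tables", db="information_schema")
)
print(query.sql("duckdb"))
# roughly: SELECT DISTINCT table_name FROM information_schema.tables
```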
+ schema + Schema name database Database name @@ -210,26 +223,45 @@ def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema sch.Schema Ibis schema """ - qualified_name = sg.table(table_name, database).sql(self.name) - query = sg.exp.Describe(this=qualified_name) - results = self.raw_sql(query) - names, types, nulls, *_ = results.fetch_arrow_table() - names = names.to_pylist() - types = types.to_pylist() - # DuckDB gives back "YES", "NO" for nullability - # TODO: remove code crime - # nulls = [bool(null[:-2]) for null in nulls.to_pylist()] - nulls = [null == "YES" for null in nulls.to_pylist()] + conditions = [eq(sg.column("table_name"), strlit(table_name))] + + if database is not None: + conditions.append(eq(sg.column("table_catalog"), strlit(database))) + + if schema is not None: + conditions.append(eq(sg.column("table_schema"), strlit(schema))) + + query = ( + sg.select( + "column_name", + "data_type", + sg.alias(eq(sg.column("is_nullable"), strlit("YES")), "nullable"), + # see https://github.com/tobymao/sqlglot/issues/2253 for why + # this column is included + "ordinal_position", + ) + .from_(sg.table("columns", db="information_schema")) + .where(sg.and_(*conditions)) + ) + + result = self.raw_sql(query) + meta = result.arrow() + + if not meta: + raise exc.IbisError(f"Table not found: {table_name!r}") + + names = meta["column_name"].to_pylist() + types = meta["data_type"].to_pylist() + nullables = meta["nullable"].to_pylist() + pos = meta["ordinal_position"].to_pylist() + return sch.Schema( - dict( - zip( - names, - ( - DuckDBType.from_string(typ, nullable=null) - for typ, null in zip(types, nulls) - ), + { + name: DuckDBType.from_string(typ, nullable=nullable) + for _, name, typ, nullable in sorted( + zip(pos, names, types, nullables), key=itemgetter(0) ) - ) + } ) def list_databases(self, like: str | None = None) -> list[str]: @@ -250,14 +282,7 @@ def list_schemas( ) if database is not None: - query = query.where( - sg.condition( - sg.exp.EQ( - this=sg.column("catalog_name"), - expression=sg.exp.Literal(this=database, is_string=True), - ) - ) - ) + query = query.where(eq(sg.column("catalog_name"), strlit(database))) out = self.raw_sql(query).arrow() return self._filter_with_like(out[col].to_pylist(), like=like) @@ -276,23 +301,13 @@ def list_tables( conditions = [] if database is not None: - conditions.append( - sg.exp.EQ( - this=sg.column("table_catalog"), - expression=sg.exp.Literal(this=database, is_string=True), - ) - ) + conditions.append(eq(sg.column("table_catalog"), strlit(database))) if schema is not None: - conditions.append( - sg.exp.EQ( - this=sg.column("table_schema"), - expression=sg.exp.Literal(this=schema, is_string=True), - ) - ) + conditions.append(eq(sg.column("table_schema"), strlit(schema))) if conditions: - query = query.where(sg.condition(sg.and_(*conditions))) + query = query.where(sg.and_(*conditions)) out = self.raw_sql(query).arrow() return self._filter_with_like(out["table_name"].to_pylist(), like) From b7c9cb3e6c9c063b21c958e3a3c41a497cf3e312 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 06:54:56 -0400 Subject: [PATCH 163/222] style: improve udf translation rule docstring --- ibis/backends/duckdb/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index a628c74f6881..3c16e9139c2d 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -65,7 
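Aside (editor's sketch, not part of the patch series): `ordinal_position` is selected above because `information_schema.columns` rows may come back in any order, so the schema is rebuilt by sorting on position first:

```python
from operator import itemgetter

pos, names = [2, 1, 3], ["b", "a", "c"]
ordered = [name for _, name in sorted(zip(pos, names), key=itemgetter(0))]
assert ordered == ["a", "b", "c"]
```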
+65,7 @@ class Backend(BaseBackend, CanCreateSchema): supports_create_or_replace = True def _define_udf_translation_rules(self, expr): - """No-op: the rules are defined in the compiler.""" + """No-op: UDF translation rules are defined in the compiler.""" @property def current_database(self) -> str: From 4845aa49887b90c84ed6be89225a9ab96b7a4ced Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:01:21 -0400 Subject: [PATCH 164/222] chore: allow running a single doctest with `just doctest` --- justfile | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/justfile b/justfile index 0b25bdaa8a6c..04e312c8eef4 100644 --- a/justfile +++ b/justfile @@ -58,17 +58,22 @@ test +backends: doctest *args: #!/usr/bin/env bash - # TODO(cpcloud): why doesn't pytest --ignore-glob=test_*.py work? - pytest --doctest-modules {{ args }} $( - find \ - ibis \ - -wholename '*.py' \ - -and -not -wholename '*test*.py' \ - -and -not -wholename '*__init__*' \ - -and -not -wholename '*gen_*.py' \ - -and -not -wholename '*ibis/expr/selectors.py' \ - -and -not -wholename '*ibis/backends/flink/*' # FIXME(deepyaman) - ) + if [ -z "{{ args }}" ]; then + # TODO(cpcloud): why doesn't pytest --ignore-glob=test_*.py work? + args=($( + find \ + ibis \ + -wholename '*.py' \ + -and -not -wholename '*test*.py' \ + -and -not -wholename '*__init__*' \ + -and -not -wholename '*gen_*.py' \ + -and -not -wholename '*ibis/expr/selectors.py' \ + -and -not -wholename '*ibis/backends/flink/*' # FIXME(deepyaman) + )) + else + args=({{ args }}) + fi + pytest --doctest-modules "${args[@]}" # download testing data download-data owner="ibis-project" repo="testing-data" rev="master": From dca77c5aaefbd30cde9d74eef47f55940658ca6a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:01:30 -0400 Subject: [PATCH 165/222] chore: handle time literals --- ibis/backends/duckdb/compiler/values.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 8c30dc9fe4dc..a3ff44c256fe 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -100,8 +100,8 @@ def _literal(op, **kw): to=sg.exp.DataType.Type.FLOAT if dtype.is_decimal() else sg_type, ) return sg.cast(sg_literal(value, is_string=False), to=sg_type) - elif dtype.is_interval(): - return _interval_format(op) + elif dtype.is_time(): + return sg.cast(sg_literal(value), to=sg_type) elif dtype.is_timestamp(): year = sg_literal(value.year, is_string=False) month = sg_literal(value.month, is_string=False) From d1bd037a8660d6323463616895e4a8993f1c7eec Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:06:21 -0400 Subject: [PATCH 166/222] revert: chore: allow running a single doctest with `just doctest` This reverts commit b1f706dbdce2621db91a44ea76c54ca9144c50f4. --- justfile | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/justfile b/justfile index 04e312c8eef4..0b25bdaa8a6c 100644 --- a/justfile +++ b/justfile @@ -58,22 +58,17 @@ test +backends: doctest *args: #!/usr/bin/env bash - if [ -z "{{ args }}" ]; then - # TODO(cpcloud): why doesn't pytest --ignore-glob=test_*.py work? 
- args=($( - find \ - ibis \ - -wholename '*.py' \ - -and -not -wholename '*test*.py' \ - -and -not -wholename '*__init__*' \ - -and -not -wholename '*gen_*.py' \ - -and -not -wholename '*ibis/expr/selectors.py' \ - -and -not -wholename '*ibis/backends/flink/*' # FIXME(deepyaman) - )) - else - args=({{ args }}) - fi - pytest --doctest-modules "${args[@]}" + # TODO(cpcloud): why doesn't pytest --ignore-glob=test_*.py work? + pytest --doctest-modules {{ args }} $( + find \ + ibis \ + -wholename '*.py' \ + -and -not -wholename '*test*.py' \ + -and -not -wholename '*__init__*' \ + -and -not -wholename '*gen_*.py' \ + -and -not -wholename '*ibis/expr/selectors.py' \ + -and -not -wholename '*ibis/backends/flink/*' # FIXME(deepyaman) + ) # download testing data download-data owner="ibis-project" repo="testing-data" rev="master": From 10aa6c445fa19d768233661b2b2ff973369b5e8a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:17:19 -0400 Subject: [PATCH 167/222] chore: remove redundant str call --- ibis/backends/duckdb/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 3c16e9139c2d..8b2110dbceab 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -451,7 +451,7 @@ def compile(self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any return sql.sql(dialect="duckdb", pretty=True) def _to_sql(self, expr: ir.Expr, **kwargs) -> str: - return str(self.compile(expr, **kwargs)) + return self.compile(expr, **kwargs) def _log(self, sql: str) -> None: """Log `sql`. From 6a4873166ef3c8d868b4e70e21ed94a5af57a3f4 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:17:46 -0400 Subject: [PATCH 168/222] chore: unbind in `to_sql` call to decouple backend translation --- .../test_client/test_to_other_sql/out.sql | 15 +++++++++++++++ ibis/backends/duckdb/tests/test_client.py | 7 +++++++ ibis/expr/sql.py | 2 +- 3 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 ibis/backends/duckdb/tests/snapshots/test_client/test_to_other_sql/out.sql diff --git a/ibis/backends/duckdb/tests/snapshots/test_client/test_to_other_sql/out.sql b/ibis/backends/duckdb/tests/snapshots/test_client/test_to_other_sql/out.sql new file mode 100644 index 000000000000..41237f8decfe --- /dev/null +++ b/ibis/backends/duckdb/tests/snapshots/test_client/test_to_other_sql/out.sql @@ -0,0 +1,15 @@ +SELECT + t0."id", + t0."bool_col", + t0."tinyint_col", + t0."smallint_col", + t0."int_col", + t0."bigint_col", + t0."float_col", + t0."double_col", + t0."date_string_col", + t0."string_col", + t0."timestamp_col", + t0."year", + t0."month" +FROM "functional_alltypes" AS t0 \ No newline at end of file diff --git a/ibis/backends/duckdb/tests/test_client.py b/ibis/backends/duckdb/tests/test_client.py index 39cd09e7f8d4..2bd080a87aa1 100644 --- a/ibis/backends/duckdb/tests/test_client.py +++ b/ibis/backends/duckdb/tests/test_client.py @@ -83,3 +83,10 @@ def test_insert(con): con.insert(name, {"a": [1, 2]}, overwrite=True) assert t.count().execute() == 2 + + +def test_to_other_sql(con, snapshot): + t = con.table("functional_alltypes") + + sql = ibis.to_sql(t, dialect="snowflake") + snapshot.assert_match(sql, "out.sql") diff --git a/ibis/expr/sql.py b/ibis/expr/sql.py index 6989869317eb..82bb7dbfcd63 100644 --- a/ibis/expr/sql.py +++ b/ibis/expr/sql.py @@ -376,6 +376,6 @@ def to_sql(expr: 
ir.Expr, dialect: str | None = None, **kwargs) -> SQLString: else: read = write = getattr(backend, "_sqlglot_dialect", dialect) - sql = backend._to_sql(expr, **kwargs) + sql = backend._to_sql(expr.unbind(), **kwargs) (pretty,) = sg.transpile(sql, read=read, write=write, pretty=True) return SQLString(pretty) From 10729fa1889c57d8a9df6af4b3badbb21969e71e Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:20:29 -0400 Subject: [PATCH 169/222] chore: update for enforced naming for all expressions including literals --- .../snapshots/test_functions/test_timestamp_truncate/d/out.sql | 2 +- .../snapshots/test_functions/test_timestamp_truncate/h/out.sql | 2 +- .../snapshots/test_functions/test_timestamp_truncate/m/out.sql | 2 +- .../test_functions/test_timestamp_truncate/minute/out.sql | 2 +- .../snapshots/test_functions/test_timestamp_truncate/w/out.sql | 2 +- .../snapshots/test_functions/test_timestamp_truncate/y/out.sql | 2 +- .../test_select/test_scalar_exprs_no_table_refs/add/out.sql | 2 +- .../test_select/test_scalar_exprs_no_table_refs/now/out.sql | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/d/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/d/out.sql index a47bccf56e8e..62efc1ba5da3 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/d/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/d/out.sql @@ -1,2 +1,2 @@ SELECT - toDate(toDateTime('2009-05-17T12:34:56')) \ No newline at end of file + toDate(toDateTime('2009-05-17T12:34:56')) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/h/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/h/out.sql index 258cc32ad0b9..151ab1446fcc 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/h/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/h/out.sql @@ -1,2 +1,2 @@ SELECT - toStartOfHour(toDateTime('2009-05-17T12:34:56')) \ No newline at end of file + toStartOfHour(toDateTime('2009-05-17T12:34:56')) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/m/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/m/out.sql index 5eafd07e0c26..d8b699a2fad5 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/m/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/m/out.sql @@ -1,2 +1,2 @@ SELECT - toStartOfMinute(toDateTime('2009-05-17T12:34:56')) \ No newline at end of file + toStartOfMinute(toDateTime('2009-05-17T12:34:56')) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/minute/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/minute/out.sql index 5eafd07e0c26..d8b699a2fad5 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/minute/out.sql +++ 
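Aside (editor's sketch, not part of the patch series): with the `expr.unbind()` call above, `to_sql` can transpile without consulting the expression's original backend; a hypothetical usage:

```python
import ibis

t = ibis.table(dict(a="int64"), name="t")  # unbound table, no live connection
print(ibis.to_sql(t.a.sum(), dialect="duckdb"))
```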
b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/minute/out.sql @@ -1,2 +1,2 @@ SELECT - toStartOfMinute(toDateTime('2009-05-17T12:34:56')) \ No newline at end of file + toStartOfMinute(toDateTime('2009-05-17T12:34:56')) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/w/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/w/out.sql index 62f73f0f6223..c470182f8e7f 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/w/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/w/out.sql @@ -1,2 +1,2 @@ SELECT - toMonday(toDateTime('2009-05-17T12:34:56')) \ No newline at end of file + toMonday(toDateTime('2009-05-17T12:34:56')) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/y/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/y/out.sql index 21150a0d7d2f..627ecf3e76d4 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/y/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_truncate/y/out.sql @@ -1,2 +1,2 @@ SELECT - toStartOfYear(toDateTime('2009-05-17T12:34:56')) \ No newline at end of file + toStartOfYear(toDateTime('2009-05-17T12:34:56')) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_select/test_scalar_exprs_no_table_refs/add/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_select/test_scalar_exprs_no_table_refs/add/out.sql index 23c737b05fe6..c1273e145900 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_select/test_scalar_exprs_no_table_refs/add/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_select/test_scalar_exprs_no_table_refs/add/out.sql @@ -1,2 +1,2 @@ SELECT - 1 + 2 \ No newline at end of file + 1 + 2 AS "Add(1, 2)" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_select/test_scalar_exprs_no_table_refs/now/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_select/test_scalar_exprs_no_table_refs/now/out.sql index 057c9d542a56..8a90bb0bb098 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_select/test_scalar_exprs_no_table_refs/now/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_select/test_scalar_exprs_no_table_refs/now/out.sql @@ -1,2 +1,2 @@ SELECT - now() \ No newline at end of file + now() AS "TimestampNow()" \ No newline at end of file From 626fdc6a36f22d90ac8f7911bbfbb1e3d13d431f Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:36:29 -0400 Subject: [PATCH 170/222] docs: avoid executing `.close()` as it is not necessary for the new duckdb backend --- docs/how-to/extending/sql.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/how-to/extending/sql.qmd b/docs/how-to/extending/sql.qmd index 1a07e02a77a6..447abf8a7024 100644 --- a/docs/how-to/extending/sql.qmd +++ b/docs/how-to/extending/sql.qmd @@ -177,7 +177,7 @@ with closing(con.raw_sql("CREATE TEMP TABLE my_table AS SELECT * FROM RANGE(10)" Here's an example: -```{python} +```python cur = con.raw_sql("CREATE TEMP 
TABLE t AS SELECT * FROM RANGE(10)") cur.close() # <1> ``` From fdf2579c7c5bd5722180264bdc91a81481c2c594 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:37:03 -0400 Subject: [PATCH 171/222] chore: avoid unnecessary additional node before finding base table --- ibis/expr/types/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index ef40fdc98eac..05bbac217dc8 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -976,9 +976,9 @@ def as_table(self) -> ir.Table: from ibis.expr.analysis import find_first_base_table op = self.op() + table = find_first_base_table(op) name = op.name op = ops.Alias(op, name) - table = find_first_base_table(op) if table is not None: return table.to_expr().aggregate([self.name(name)]) else: From 6a582c89d719b1cf1eb8dcf7ca3dbad94bb01488 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:48:42 -0400 Subject: [PATCH 172/222] chore: factor out unaliasing --- .../backends/clickhouse/compiler/relations.py | 27 +++++++++++++++++++ ibis/backends/duckdb/compiler/values.py | 7 ++--- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/ibis/backends/clickhouse/compiler/relations.py b/ibis/backends/clickhouse/compiler/relations.py index 5e209935830a..8435aa9c067e 100644 --- a/ibis/backends/clickhouse/compiler/relations.py +++ b/ibis/backends/clickhouse/compiler/relations.py @@ -7,7 +7,12 @@ import ibis.common.exceptions as com import ibis.expr.operations as ops +<<<<<<< HEAD from ibis.backends.base.sqlglot import FALSE, NULL, STAR +======= +from ibis.backends.base.sqlglot import unalias +from ibis.backends.clickhouse.compiler.values import translate_val +>>>>>>> 4ae077d66 (chore: factor out unaliasing) @functools.singledispatch @@ -85,6 +90,7 @@ def _selection( @translate_rel.register(ops.Aggregation) +<<<<<<< HEAD def _aggregation( op: ops.Aggregation, *, table, metrics, by, having, predicates, sort_keys, **_ ): @@ -104,6 +110,27 @@ def _aggregation( if sort_keys: sel = sel.order_by(*sort_keys) +======= +def _aggregation(op: ops.Aggregation, *, table, **kw): + tr_val = partial(translate_val, **kw) + + by = tuple(map(tr_val, op.by)) + metrics = tuple(map(tr_val, op.metrics)) + selections = (by + metrics) or "*" + sel = sg.select(*selections).from_(table) + + if group_keys := op.by: + sel = sel.group_by(*map(tr_val, map(unalias, group_keys)), dialect="clickhouse") + + if predicates := op.predicates: + sel = sel.where(*map(tr_val, map(unalias, predicates)), dialect="clickhouse") + + if having := op.having: + sel = sel.having(*map(tr_val, map(unalias, having)), dialect="clickhouse") + + if sort_keys := op.sort_keys: + sel = sel.order_by(*map(tr_val, map(unalias, sort_keys)), dialect="clickhouse") +>>>>>>> 4ae077d66 (chore: factor out unaliasing) return sel diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index a3ff44c256fe..49c05b42b12e 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -15,6 +15,7 @@ import ibis.expr.analysis as an import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis.backends.base.sqlglot import unalias from ibis.backends.base.sqlglot.datatypes import DuckDBType if TYPE_CHECKING: @@ -1223,13 +1224,9 @@ def _struct_column(op, **kw): ) -def _unwrap_alias(op): - return op.arg if isinstance(op, ops.Alias) else 
op - - @translate_val.register(ops.StructField) def _struct_field(op, **kw): - arg = translate_val(_unwrap_alias(op.arg), **kw) + arg = translate_val(unalias(op.arg), **kw) return sg.exp.StructExtract( this=arg, expression=sg.exp.Literal(this=op.field, is_string=True) ) From 25527901db4cd26ad531c330da0666076ab46ae3 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:58:23 -0400 Subject: [PATCH 173/222] chore: fix renaming of pandas backend single column results --- ibis/backends/pandas/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/pandas/core.py b/ibis/backends/pandas/core.py index ef29b2bb29cc..efd31757afb3 100644 --- a/ibis/backends/pandas/core.py +++ b/ibis/backends/pandas/core.py @@ -502,7 +502,7 @@ def _apply_schema(op: ops.Node, result: pd.DataFrame | pd.Series): return PandasData.convert_table(df, op.schema) elif isinstance(result, pd.Series): schema = op.to_expr().as_table().schema() - df = PandasData.convert_table(result.to_frame(), schema) + df = PandasData.convert_table(result.to_frame(name=schema.names[0]), schema) return df.iloc[:, 0].reset_index(drop=True) else: return result From c3ef2dd179e45337f3aa7ff771cb5cd27acd57f6 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 08:01:29 -0400 Subject: [PATCH 174/222] chore: fix insert by runing pre execute hooks before compiling --- ibis/backends/duckdb/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 8b2110dbceab..8c80b1801975 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -1378,6 +1378,7 @@ def insert( con.execute(f"TRUNCATE TABLE {table.sql('duckdb')}") if isinstance(obj, ir.Table): + self._run_pre_execute_hooks(obj) query = sg.exp.insert( expression=self.compile(obj), into=table, dialect="duckdb" ) From 86fc6f4bebfc1a303191025c5db826c93da7e5e3 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 08:03:10 -0400 Subject: [PATCH 175/222] chore: clean up append --- ibis/backends/duckdb/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 8c80b1801975..a7c144c00311 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -1383,7 +1383,8 @@ def insert( expression=self.compile(obj), into=table, dialect="duckdb" ) con.execute(query.sql("duckdb")) - elif isinstance(obj, pd.DataFrame): - con.append(table_name, obj) else: - con.append(table_name, pd.DataFrame(obj)) + con.append( + table_name, + obj if isinstance(obj, pd.DataFrame) else pd.DataFrame(obj), + ) From f4ee1e12d8fbda284170cf03175d0841d57601a5 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 08:04:28 -0400 Subject: [PATCH 176/222] chore: consolidate under the banner of `raw_sql` --- ibis/backends/duckdb/__init__.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index a7c144c00311..ae7540e49c14 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -393,7 +393,7 @@ def do_connect( self._load_extensions(extensions) # Default timezone - self.con.execute("SET TimeZone = 'UTC'") + self.raw_sql("SET TimeZone = 'UTC'") 
self._record_batch_readers_consumed = {} self._temp_views: set[str] = set() @@ -1370,21 +1370,19 @@ def insert( ValueError If the type of `obj` isn't supported """ - con = self.con - table = sg.table(table_name, db=database) if overwrite: - con.execute(f"TRUNCATE TABLE {table.sql('duckdb')}") + self.raw_sql(f"TRUNCATE TABLE {table.sql('duckdb')}") if isinstance(obj, ir.Table): self._run_pre_execute_hooks(obj) query = sg.exp.insert( expression=self.compile(obj), into=table, dialect="duckdb" ) - con.execute(query.sql("duckdb")) + self.raw_sql(query) else: - con.append( + self.con.append( table_name, obj if isinstance(obj, pd.DataFrame) else pd.DataFrame(obj), ) From 800d36cc6228389974b49c21bbea171a835ee19b Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 08:32:50 -0400 Subject: [PATCH 177/222] chore: skip missing snowflake deps --- ibis/backends/duckdb/tests/test_client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ibis/backends/duckdb/tests/test_client.py b/ibis/backends/duckdb/tests/test_client.py index 2bd080a87aa1..59a1435d92b4 100644 --- a/ibis/backends/duckdb/tests/test_client.py +++ b/ibis/backends/duckdb/tests/test_client.py @@ -86,6 +86,9 @@ def test_insert(con): def test_to_other_sql(con, snapshot): + pytest.importorskip("snowflake.connector") + pytest.importorskip("snowflake.sqlalchemy") + t = con.table("functional_alltypes") sql = ibis.to_sql(t, dialect="snowflake") From bc701e0e5ab76a40671b8326121cced5464f0dbb Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 08:40:50 -0400 Subject: [PATCH 178/222] chore: handle expressions in values of replacements mapping for fillna --- ibis/backends/polars/compiler.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 71a5075d3481..ecb0132d2eb4 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -330,13 +330,29 @@ def fillna(op, **kw): table = translate(op.table, **kw) columns = [] + + repls = op.replacements + + if isinstance(repls, Mapping): + + def get_replacement(name): + repl = repls.get(name) + if repl is not None: + _assert_literal(repl) + return repl.value + else: + return None + + else: + _assert_literal(repls) + value = repls.value + + def get_replacement(_): + return value + for name, dtype in op.table.schema.items(): column = pl.col(name) - if isinstance(op.replacements, Mapping): - value = op.replacements.get(name) - else: - _assert_literal(op.replacements) - value = op.replacements.value + value = get_replacement(name) if value is not None: if dtype.is_floating(): From f41d379d8a604f101a1fb0ae3e354a92b2df417f Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 08:43:49 -0400 Subject: [PATCH 179/222] chore: handle expressions in values of replacements mapping for fillna for pandas --- ibis/backends/pandas/execution/generic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ibis/backends/pandas/execution/generic.py b/ibis/backends/pandas/execution/generic.py index 0a64f60e0f51..80d17e151e76 100644 --- a/ibis/backends/pandas/execution/generic.py +++ b/ibis/backends/pandas/execution/generic.py @@ -1346,7 +1346,8 @@ def execute_node_fillna_dataframe_scalar(op, df, replacements, **kwargs): @execute_node.register(ops.FillNa, pd.DataFrame) def execute_node_fillna_dataframe_dict(op, 
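Aside (editor's sketch, assuming the IR accepts expression values, which the `execute(repl, **kwargs)` call above implies): what these `fillna` changes cater to at the API level:

```python
import ibis

t = ibis.table(dict(a="float64", b="string"), name="t")
t.fillna({"a": 0.0, "b": "unknown"})  # literal replacements, as before
t.fillna({"a": t.a.mean()})           # an expression value
```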
df, **kwargs): - return df.fillna(dict(op.replacements)) + replmap = {col: execute(repl, **kwargs) for col, repl in op.replacements.items()} + return df.fillna(replmap) @execute_node.register(ops.IfNull, pd.Series, simple_types) From de0755d86deb4676e093e2e486694f2c45927244 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 08:46:28 -0400 Subject: [PATCH 180/222] chore: handle expressions in values of replacements mapping for fillna for pyspark --- ibis/backends/pyspark/compiler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 81f5fe2447fc..ac7c7d50b4de 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -1771,11 +1771,11 @@ def compile_dropna_table(t, op, **kwargs): @compiles(ops.FillNa) def compile_fillna_table(t, op, **kwargs): table = t.translate(op.table, **kwargs) - raw_replacements = op.replacements + repls = op.replacements replacements = ( - dict(raw_replacements) - if isinstance(raw_replacements, frozendict) - else raw_replacements.value + {name: t.translate(value, **kwargs) for name, value in repls.items()} + if isinstance(repls, frozendict) + else repls.value ) return table.fillna(replacements) From 1dee6af22bd581a11d0f81a9f203d082d0867e97 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 08:46:56 -0400 Subject: [PATCH 181/222] fixup! chore: handle expressions in values of replacements mapping for fillna for pyspark --- ibis/backends/pyspark/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index ac7c7d50b4de..79ab2c7043af 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -1773,7 +1773,7 @@ def compile_fillna_table(t, op, **kwargs): table = t.translate(op.table, **kwargs) repls = op.replacements replacements = ( - {name: t.translate(value, **kwargs) for name, value in repls.items()} + {name: t.translate(value, raw=True, **kwargs) for name, value in repls.items()} if isinstance(repls, frozendict) else repls.value ) From 84a766219c087b3dbcf3ab75caf9037bca188e71 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:01:48 -0400 Subject: [PATCH 182/222] chore(clickhouse): force column names to match the schema because column names with characters that need escaping are a headache --- ibis/backends/clickhouse/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index 39582d8fd16f..5435258db63d 100644 --- a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -368,6 +368,8 @@ def execute( if df.empty: df = pd.DataFrame(columns=schema.names) + else: + df.columns = list(schema.names) # TODO: remove the extra conversion # From 696ecb1d4e3daaa0b3156a8069b329854b8bcef2 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:04:28 -0400 Subject: [PATCH 183/222] chore: bring back annoying workaround for ddb 8735 --- ibis/backends/duckdb/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index ae7540e49c14..1b33bd7bd616 100644 --- a/ibis/backends/duckdb/__init__.py +++
b/ibis/backends/duckdb/__init__.py @@ -386,6 +386,11 @@ def do_connect( import duckdb + # TODO(cpcloud): remove this when duckdb is >0.8.1 + # this is here to workaround https://github.com/duckdb/duckdb/issues/8735 + with contextlib.suppress(duckdb.InvalidInputException): + duckdb.execute("SELECT ?", (1,)) + self.con = duckdb.connect(str(database), config=config) # Load any pre-specified extensions From 2a9d9184866d987805a9f3043da55846b891226c Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:13:36 -0400 Subject: [PATCH 184/222] chore(bigquery): force column names to match the expression schema --- ibis/backends/bigquery/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py index 0675fb3bf498..1276d00aa2c9 100644 --- a/ibis/backends/bigquery/__init__.py +++ b/ibis/backends/bigquery/__init__.py @@ -437,6 +437,7 @@ def execute(self, expr, params=None, limit="default", **kwargs): def fetch_from_cursor(self, cursor, schema): arrow_t = self._cursor_to_arrow(cursor) df = arrow_t.to_pandas(timestamp_as_object=True) + df.columns = list(schema.names) return PandasData.convert_table(df, schema) def _cursor_to_arrow( From c55fb67fda4d5db8b2bb4d4a607da3d9d0271385 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:45:42 -0400 Subject: [PATCH 185/222] chore(pyspark): handle intervals before raw --- ibis/backends/pyspark/compiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 79ab2c7043af..e7f1119a9edf 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -381,13 +381,13 @@ def compile_literal(t, op, *, raw=False, **kwargs): if value is None: return F.lit(None) - if raw: - return value - if dtype.is_interval(): # execute returns a Timedelta and value is nanoseconds return execute(op).value + if raw: + return value + if isinstance(value, collections.abc.Set): # Don't wrap set with F.lit if isinstance(value, frozenset): From 3dc9ccf13be5e3bf3aa640206630464ca064b3cb Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:45:54 -0400 Subject: [PATCH 186/222] chore(snowflake): force column names to match the expression schema --- ibis/backends/snowflake/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index f875ef66cda2..b4a17dd1957d 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -372,6 +372,7 @@ def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: if (table := cursor.cursor.fetch_arrow_all()) is None: table = schema.to_pyarrow().empty_table() df = table.to_pandas(timestamp_as_object=True) + df.columns = list(schema.names) return SnowflakePandasData.convert_table(df, schema) def to_pyarrow_batches( From b68dfe6f33ddcb8d2658d2e4b0eb692aa1036da5 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:22:37 -0400 Subject: [PATCH 187/222] test(duckdb): fix error type on extension download failure when sandboxed --- ibis/backends/duckdb/tests/test_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/tests/test_client.py b/ibis/backends/duckdb/tests/test_client.py 
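Aside (editor's sketch, not part of the patch series): the recurring `df.columns = list(schema.names)` fix (clickhouse, bigquery, snowflake) relies on the driver returning columns in schema order and renames them positionally (hypothetical names):

```python
import pandas as pd

schema_names = ["a b", "c"]                # names as ibis knows them
df = pd.DataFrame({"a_b": [1], "c": [2]})  # names as the driver returned them
df.columns = list(schema_names)            # positional rename
assert list(df.columns) == ["a b", "c"]
```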
index 59a1435d92b4..d0863f6b2b53 100644 --- a/ibis/backends/duckdb/tests/test_client.py +++ b/ibis/backends/duckdb/tests/test_client.py @@ -2,7 +2,6 @@ import duckdb import pytest -import sqlalchemy as sa import ibis from ibis.conftest import LINUX, SANDBOXED @@ -21,7 +20,7 @@ def ext_directory(tmpdir_factory): @pytest.mark.xfail( LINUX and SANDBOXED, reason="nix on linux cannot download duckdb extensions or data due to sandboxing", - raises=sa.exc.OperationalError, + raises=duckdb.IOException, ) @pytest.mark.xdist_group(name="duckdb-extensions") def test_connect_extensions(ext_directory): From 566e399dcee1920005bedf633702d7885238ebc8 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:24:32 -0400 Subject: [PATCH 188/222] chore(deps): remove duckdb-engine from deps --- pyproject.toml | 13 +------------ requirements-dev.txt | 1 - 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 05dfb14d9904..e5a48620741b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,6 @@ datafusion = { version = ">=0.6,<32", optional = true } db-dtypes = { version = ">=0.3,<2", optional = true } deltalake = { version = ">=0.9.0,<1", optional = true } duckdb = { version = ">=0.8.1,<1", optional = true } -duckdb-engine = { version = ">=0.1.8,<1", optional = true } fsspec = { version = ">=2022.1.0", optional = true } GeoAlchemy2 = { version = ">=0.6.3,<1,!=0.13.0,!=0.14.0,!=0.14.1", optional = true } geopandas = { version = ">=0.6,<1", optional = true } @@ -149,7 +148,6 @@ all = [ "datafusion", "db-dtypes", "duckdb", - "duckdb-engine", "deltalake", "fsspec", "GeoAlchemy2", @@ -186,13 +184,7 @@ clickhouse = ["clickhouse-connect", "sqlalchemy"] dask = ["dask", "regex"] datafusion = ["datafusion"] druid = ["pydruid", "sqlalchemy"] -duckdb = [ - "duckdb", - "duckdb-engine", - "packaging", - "sqlalchemy", - "sqlalchemy-views", -] +duckdb = ["duckdb", "packaging", "sqlalchemy"] flink = [] geospatial = ["GeoAlchemy2", "geopandas", "shapely"] impala = ["fsspec", "impyla", "requests", "sqlalchemy"] @@ -293,9 +285,6 @@ filterwarnings = [ 'ignore:`np\.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning', # numpy, coming from a pandas call 'ignore:In the future `np\.bool` will be defined as the corresponding NumPy scalar:FutureWarning', - # duckdb-engine - 'ignore:Dialect .+ does \*not\* support Decimal:', - "ignore:duckdb-engine doesn't yet support reflection on indices:", # druid 'ignore:Dialect druid.rest will not make use of SQL compilation caching:', # ibis diff --git a/requirements-dev.txt b/requirements-dev.txt index 89e6eb39400b..5069352d72fe 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -39,7 +39,6 @@ debugpy==1.6.7.post1 ; python_version >= "3.10" and python_version < "4.0" decorator==5.1.1 ; python_version >= "3.9" and python_version < "4.0" deltalake==0.10.1 ; python_version >= "3.9" and python_version < "4.0" distlib==0.3.7 ; python_version >= "3.9" and python_version < "4.0" -duckdb-engine==0.9.2 ; python_version >= "3.9" and python_version < "4.0" duckdb==0.8.1 ; python_version >= "3.9" and python_version < "4.0" dunamai==1.18.0 ; python_version >= "3.9" and python_version < "4.0" exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11" From 6779ee3aca62ad8ea3c77eac6dfef14e963ecebd Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:25:42 -0400 Subject: [PATCH 189/222] fixup! 
chore(deps): remove duckdb-engine from deps --- ibis/backends/duckdb/tests/conftest.py | 2 +- ibis/backends/duckdb/tests/test_datatypes.py | 39 -------------------- ibis/tests/benchmarks/test_benchmarks.py | 3 -- 3 files changed, 1 insertion(+), 43 deletions(-) diff --git a/ibis/backends/duckdb/tests/conftest.py b/ibis/backends/duckdb/tests/conftest.py index 3fe3cfd2cdaa..81eeffe5ea5c 100644 --- a/ibis/backends/duckdb/tests/conftest.py +++ b/ibis/backends/duckdb/tests/conftest.py @@ -18,7 +18,7 @@ class TestConf(BackendTest, RoundAwayFromZero): supports_map = True - deps = "duckdb", "duckdb_engine" + deps = ("duckdb",) stateful = False supports_tpch = True diff --git a/ibis/backends/duckdb/tests/test_datatypes.py b/ibis/backends/duckdb/tests/test_datatypes.py index 5533c5e08ce3..65a5b50f91c0 100644 --- a/ibis/backends/duckdb/tests/test_datatypes.py +++ b/ibis/backends/duckdb/tests/test_datatypes.py @@ -1,13 +1,10 @@ from __future__ import annotations -import duckdb_engine import pytest -import sqlalchemy as sa import sqlglot as sg from packaging.version import parse as vparse from pytest import param -import ibis.backends.base.sql.alchemy.datatypes as sat import ibis.common.exceptions as exc import ibis.expr.datatypes as dt from ibis.backends.base.sqlglot.datatypes import DuckDBType @@ -96,39 +93,3 @@ def test_parse_quoted_struct_field(): assert DuckDBType.from_string('STRUCT("a" INTEGER, "a b c" INTEGER)') == dt.Struct( {"a": dt.int32, "a b c": dt.int32} ) - - -def test_generate_quoted_struct(): - typ = sat.StructType( - {"in come": sa.VARCHAR(), "my count": sa.BIGINT(), "thing": sa.INTEGER()} - ) - result = typ.compile(dialect=duckdb_engine.Dialect()) - expected = 'STRUCT("in come" VARCHAR, "my count" BIGINT, thing INTEGER)' - assert result == expected - - -@pytest.mark.xfail( - condition=vparse(duckdb_engine.__version__) < vparse("0.9.2"), - raises=AssertionError, - reason="mapping from UINTEGER query metadata fixed in 0.9.2", -) -def test_read_uint8_from_parquet(tmp_path): - import numpy as np - - import ibis - - con = ibis.duckdb.connect() - - # There is an incorrect mapping in duckdb-engine from UInteger -> UInt8 - # In order to get something that reads as a UInt8, we cast to UInt32 (UInteger) - t = ibis.memtable({"a": np.array([1, 2, 3, 4], dtype="uint32")}) - assert t.a.type() == dt.uint32 - - parqpath = tmp_path / "uint.parquet" - - con.to_parquet(t, parqpath) - - # If this doesn't fail, then things are working - t2 = con.read_parquet(parqpath) - - assert t2.schema() == t.schema() diff --git a/ibis/tests/benchmarks/test_benchmarks.py b/ibis/tests/benchmarks/test_benchmarks.py index 8ad6266141cc..4b341077e6ba 100644 --- a/ibis/tests/benchmarks/test_benchmarks.py +++ b/ibis/tests/benchmarks/test_benchmarks.py @@ -715,7 +715,6 @@ def test_repr_join(benchmark, customers, orders, orders_items, products): @pytest.mark.parametrize("overwrite", [True, False], ids=["overwrite", "no_overwrite"]) def test_insert_duckdb(benchmark, overwrite, tmp_path): pytest.importorskip("duckdb") - pytest.importorskip("duckdb_engine") n_rows = int(1e4) table_name = "t" @@ -806,7 +805,6 @@ def test_duckdb_to_pyarrow(benchmark, sql, ddb) -> None: def test_ibis_duckdb_to_pyarrow(benchmark, sql, ddb) -> None: pytest.importorskip("duckdb") - pytest.importorskip("duckdb_engine") con = ibis.duckdb.connect(ddb, read_only=True) @@ -876,7 +874,6 @@ def test_big_join_expr(benchmark, src, diff): def test_big_join_execute(benchmark, nrels): pytest.importorskip("duckdb") - pytest.importorskip("duckdb_engine") con = 
ibis.duckdb.connect() From 5ce93ddfb66782940a29071bc6517cc46645ed7e Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:27:07 -0400 Subject: [PATCH 190/222] ci: remove duckdb-engine from ci --- .github/renovate.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/renovate.json b/.github/renovate.json index 1ce10ad261a6..b7a005f0e255 100644 --- a/.github/renovate.json +++ b/.github/renovate.json @@ -46,7 +46,7 @@ "addLabels": ["bigquery"] }, { - "matchPackagePatterns": ["duckdb", "duckdb-engine"], + "matchPackagePatterns": ["duckdb"], "addLabels": ["duckdb"] }, { From 787acdc065e7f2cffba2afafe5984bbf4f87a76d Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:39:43 -0400 Subject: [PATCH 191/222] chore: translate the table full stop, for table array view, instead of mucking around with `translate_rel` --- ibis/backends/duckdb/compiler/values.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 49c05b42b12e..80a64c8e17ec 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1180,12 +1180,9 @@ def _table_array_view(op, *, cache, **kw): try: return cache[table] except KeyError: - from ibis.backends.duckdb.compiler.relations import translate_rel + from ibis.backends.duckdb.compiler import translate - # ignore the top level table, so that we can compile its dependencies - (leaf,) = an.find_immediate_parent_tables(table, keep_input=False) - res = translate_rel(table, table=cache[leaf], cache=cache, **kw) - return res.subquery() + return translate(table, {}) @translate_val.register(ops.ExistsSubquery) From 2eab3b8ace16e95d540f4ce6d2af3064b29e7d73 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:40:14 -0400 Subject: [PATCH 192/222] chore: handle subquery only on selects --- ibis/backends/duckdb/compiler/values.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 80a64c8e17ec..341d71fef1cf 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -1190,11 +1190,14 @@ def _exists_subquery(op, **kw): from ibis.backends.duckdb.compiler import translate foreign_table = translate(op.foreign_table, {}) + + # only construct a subquery if we cannot refer to the table directly + if isinstance(foreign_table, sg.exp.Select): + foreign_table = foreign_table.subquery() + predicates = translate_val(op.predicates, **kw) return sg.exp.Exists( - this=sg.select(1) - .from_(foreign_table.subquery()) - .where(sg.condition(predicates)) + this=sg.select(1).from_(foreign_table).where(sg.condition(predicates)) ) From eeaa2aac18d3cfd83c5db1f5b09ae8fcce5fa54a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:43:30 -0400 Subject: [PATCH 193/222] chore: bring back decimal warning --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e5a48620741b..57339f3062c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -299,6 +299,7 @@ filterwarnings = [ # sqlalchemy "ignore:Class ST_.+ will not make use of SQL compilation caching:", "ignore:UserDefinedType Geometry:", + 'ignore:Dialect .+ does \*not\* 
support Decimal:', # google "ignore:Deprecated call to `pkg_resources\\.declare_namespace\\('.*'\\):DeprecationWarning", # pyspark on python 3.11 From 76d78ce38ae3cbc63774b31f80d258526244cdd3 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 11:30:55 -0400 Subject: [PATCH 194/222] chore: remove unused `_sql` function --- ibis/backends/duckdb/compiler/values.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 341d71fef1cf..4e0aede62bad 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -62,13 +62,6 @@ def _alias(op, render_aliases: bool = True, **kw): ### Literals -def _sql(obj): - try: - return obj.sql(dialect="duckdb") - except AttributeError: - return obj - - def sg_literal(arg, is_string=True): return sg.exp.Literal(this=f"{arg}", is_string=is_string) From c87f9dd3fe53c1e4175ac1afb976bdf2c1114907 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 18 Sep 2023 11:50:36 -0400 Subject: [PATCH 195/222] chore: avoid reprojection of dataframe when converting to ibis types --- ibis/formats/pandas.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ibis/formats/pandas.py b/ibis/formats/pandas.py index ab8e0b3a1151..6f0977722bb2 100644 --- a/ibis/formats/pandas.py +++ b/ibis/formats/pandas.py @@ -111,9 +111,7 @@ def convert_table(cls, df, schema): for name, dtype in schema.items(): df[name] = cls.convert_column(df[name], dtype) - # return data with the schema's columns which may be different than the - # input columns - return df.loc[:, list(schema.names)] + return df @classmethod def convert_column(cls, obj, dtype): From 8bfe95bccd1477d822e9c88cfca560b9e138175c Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 19 Sep 2023 06:33:43 -0400 Subject: [PATCH 196/222] chore: impala fix naming --- ibis/backends/impala/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ibis/backends/impala/__init__.py b/ibis/backends/impala/__init__.py index e92f06dc44bd..c0d1c8e6e546 100644 --- a/ibis/backends/impala/__init__.py +++ b/ibis/backends/impala/__init__.py @@ -348,10 +348,11 @@ def list_tables(self, like=None, database=None): tables = [row[0] for row in cursor.fetchall()] return self._filter_with_like(tables) - def fetch_from_cursor(self, cursor, schema): + def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: batches = cursor.fetchall(columnar=True) names = [name for name, *_ in cursor.description] df = _column_batches_to_dataframe(names, batches) + df.columns = list(schema.names) if schema: return PandasData.convert_table(df, schema) return df From 74618476efb3710f8c49f44e6134f92e62905542 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 19 Sep 2023 07:51:21 -0400 Subject: [PATCH 197/222] chore: generate group by indices instead of repeating the expression --- ibis/backends/duckdb/compiler/relations.py | 11 +++++++++-- .../test_group_by_has_index/duckdb/out.sql | 18 +----------------- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index 09f900387726..5e83494dd7a2 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -76,8 +76,15 @@ def 
_aggregation(op: ops.Aggregation, *, table, **kw):
     selections = (by + metrics) or "*"
     sel = sg.select(*selections).from_(table)
 
-    if group_keys := op.by:
-        sel = sel.group_by(*map(tr_val_no_alias, group_keys))
+    if op.by:
+        # avoids translating the group by keys twice and makes the output
+        # more concise
+        sel = sel.group_by(
+            *(
+                sg.exp.Literal(this=str(key), is_string=False)
+                for key in range(1, len(op.by) + 1)
+            )
+        )
 
     if predicates := op.predicates:
         sel = sel.where(*map(tr_val_no_alias, predicates))
diff --git a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/duckdb/out.sql b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/duckdb/out.sql
index 241f1095fd1e..c59492dbc306 100644
--- a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/duckdb/out.sql
+++ b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/duckdb/out.sql
@@ -19,20 +19,4 @@ SELECT
   SUM(t0.population) AS total_pop
 FROM "countries" AS t0
 GROUP BY
-  CASE t0.continent
-    WHEN 'NA'
-    THEN 'North America'
-    WHEN 'SA'
-    THEN 'South America'
-    WHEN 'EU'
-    THEN 'Europe'
-    WHEN 'AF'
-    THEN 'Africa'
-    WHEN 'AS'
-    THEN 'Asia'
-    WHEN 'OC'
-    THEN 'Oceania'
-    WHEN 'AN'
-    THEN 'Antarctica'
-    ELSE 'Unknown continent'
-  END
\ No newline at end of file
+  1
\ No newline at end of file

From aaa25dd8cbfdd6be885d604ba872afcefe3c117f Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Tue, 19 Sep 2023 07:52:01 -0400
Subject: [PATCH 198/222] fixup! chore: generate group by indices instead of
 repeating the expression

---
 ibis/backends/impala/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ibis/backends/impala/__init__.py b/ibis/backends/impala/__init__.py
index c0d1c8e6e546..a8f95d10c0f7 100644
--- a/ibis/backends/impala/__init__.py
+++ b/ibis/backends/impala/__init__.py
@@ -352,8 +352,8 @@ def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
         batches = cursor.fetchall(columnar=True)
         names = [name for name, *_ in cursor.description]
         df = _column_batches_to_dataframe(names, batches)
-        df.columns = list(schema.names)
         if schema:
+            df.columns = list(schema.names)
             return PandasData.convert_table(df, schema)
         return df
 

From 46c0cdaa452142d57ec2f2a660e359e7dd6e5d88 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Tue, 19 Sep 2023 08:00:01 -0400
Subject: [PATCH 199/222] chore: remove use of `sg.condition` in favor of
 `sg.and_`

---
 ibis/backends/duckdb/compiler/relations.py | 17 +++++------------
 ibis/backends/duckdb/compiler/values.py    |  6 ++----
 2 files changed, 7 insertions(+), 16 deletions(-)

diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py
index 5e83494dd7a2..acbf03564053 100644
--- a/ibis/backends/duckdb/compiler/relations.py
+++ b/ibis/backends/duckdb/compiler/relations.py
@@ -54,11 +54,7 @@ def _selection(op: ops.Selection, *, table, needs_alias=False, **kw):
     if predicates := op.predicates:
         if join is not None:
             sel = sg.select("*").from_(sel.subquery(kw["aliases"][op.table]))
-        res = functools.reduce(
-            lambda left, right: left.and_(right),
-            (sg.condition(tr_val(predicate)) for predicate in predicates),
-        )
-        sel = sel.where(res)
+        sel = sel.where(sg.and_(*map(tr_val, predicates)))
 
     if sort_keys := op.sort_keys:
         sel = sel.order_by(*map(tr_val, sort_keys))
@@ -82,6 +78,7 @@ def _aggregation(op: ops.Aggregation, *, table, **kw):
         sel = sel.group_by(
             *(
                 sg.exp.Literal(this=str(key), is_string=False)
+                # keys are referenced by 1-based index
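+                # (SQL GROUP BY ordinals are 1-based positions in the select list)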
                for key in range(1, len(op.by) + 1)
             )
         )
@@ -113,13 +110,9 @@ def _aggregation(op: ops.Aggregation, *, table, **kw):
 @translate_rel.register
 def _join(op: ops.Join, *, left, right, **kw):
     predicates = op.predicates
-    if predicates:
-        on = functools.reduce(
-            lambda left, right: left.and_(right),
-            (sg.condition(translate_val(predicate, **kw)) for predicate in predicates),
-        )
-    else:
-        on = None
+
+    on = sg.and_(*map(partial(translate_val, **kw), predicates)) if predicates else None
+
     join_type = _JOIN_TYPES[type(op)]
     try:
         return left.join(right, join_type=join_type, on=on)
diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py
index 4e0aede62bad..4cc9319dad24 100644
--- a/ibis/backends/duckdb/compiler/values.py
+++ b/ibis/backends/duckdb/compiler/values.py
@@ -1188,10 +1188,8 @@ def _exists_subquery(op, **kw):
     if isinstance(foreign_table, sg.exp.Select):
         foreign_table = foreign_table.subquery()
 
-    predicates = translate_val(op.predicates, **kw)
-    return sg.exp.Exists(
-        this=sg.select(1).from_(foreign_table).where(sg.condition(predicates))
-    )
+    predicate = sg.and_(*map(partial(translate_val, **kw), op.predicates))
+    return sg.exp.Exists(this=sg.select(1).from_(foreign_table).where(predicate))
 
 
 @translate_val.register(ops.NotExistsSubquery)

From fe72a4172193843fe9872be81efd6285281e1969 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Tue, 19 Sep 2023 08:03:45 -0400
Subject: [PATCH 200/222] chore: pluck out `aliases` kwarg

---
 ibis/backends/duckdb/compiler/relations.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py
index acbf03564053..26e80c4dbe9e 100644
--- a/ibis/backends/duckdb/compiler/relations.py
+++ b/ibis/backends/duckdb/compiler/relations.py
@@ -32,7 +32,7 @@
 
 
 @translate_rel.register(ops.Selection)
-def _selection(op: ops.Selection, *, table, needs_alias=False, **kw):
+def _selection(op: ops.Selection, *, table, needs_alias=False, aliases, **kw):
     # needs_alias should never be true here in explicitly, but it may get
     # passed via a (recursive) call to translate_val
     assert not needs_alias, "needs_alias is True"
@@ -44,7 +44,7 @@ def _selection(op: ops.Selection, *, table, needs_alias=False, aliases, **kw):
         (join,) = args["joins"]
     else:
         from_ = join = None
-    tr_val = partial(translate_val, needs_alias=needs_alias, **kw)
+    tr_val = partial(translate_val, needs_alias=needs_alias, aliases=aliases, **kw)
     selections = tuple(map(tr_val, op.selections)) or "*"
     sel = sg.select(*selections).from_(from_ if from_ is not None else table)
 
@@ -53,7 +53,7 @@ def _selection(op: ops.Selection, *, table, needs_alias=False, aliases, **kw):
     if predicates := op.predicates:
         if join is not None:
-            sel = sg.select("*").from_(sel.subquery(kw["aliases"][op.table]))
+            sel = sg.select("*").from_(sel.subquery(aliases[op.table]))
         sel = sel.where(sg.and_(*map(tr_val, predicates)))
 
     if sort_keys := op.sort_keys:
         sel = sel.order_by(*map(tr_val, sort_keys))

From 9b9b5e4d0fd7cf944e6e4e6c47ae409a02779f34 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Tue, 19 Sep 2023 08:04:45 -0400
Subject: [PATCH 201/222] chore: remove `remove_aliases` kwarg from duckdb
 compiler

---
 ibis/backends/duckdb/compiler/relations.py | 2 +-
 ibis/backends/duckdb/compiler/values.py    | 8 +++-----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py
index 26e80c4dbe9e..da84e0665b21 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -65,7 +65,7 @@ def _selection(op: ops.Selection, *, table, needs_alias=False, aliases, **kw): @translate_rel.register(ops.Aggregation) def _aggregation(op: ops.Aggregation, *, table, **kw): tr_val = partial(translate_val, **kw) - tr_val_no_alias = partial(translate_val, render_aliases=False, **kw) + tr_val_no_alias = partial(translate_val, **kw) by = tuple(map(tr_val, op.by)) metrics = tuple(map(tr_val, op.metrics)) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 4cc9319dad24..44c2a3d86aa3 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -52,11 +52,9 @@ def _column(op, *, aliases, **_): @translate_val.register(ops.Alias) -def _alias(op, render_aliases: bool = True, **kw): - val = translate_val(op.arg, render_aliases=render_aliases, **kw) - if render_aliases: - return sg.alias(val, op.name) - return val +def _alias(op, **kw): + val = translate_val(op.arg, **kw) + return sg.alias(val, op.name) ### Literals From d9e94840bf21ac1b0fa0a3bfa4bea0c30af5275e Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 19 Sep 2023 08:16:40 -0400 Subject: [PATCH 202/222] chore(duckdb): use `sg_literal` everywhere --- ibis/backends/duckdb/compiler/values.py | 59 ++++++------------------- 1 file changed, 14 insertions(+), 45 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 44c2a3d86aa3..66020b5879ff 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -472,7 +472,7 @@ def _timestamp_from_ymdhms(op, **kw): func = "make_timestamp" if (timezone := op.dtype.timezone) is not None: func += "tz" - args.append(sg.exp.Literal(this=timezone, is_string=True)) + args.append(sg_literal(timezone)) return sg.func(func, *args) @@ -524,7 +524,7 @@ def _extract_epoch_seconds(op, **kw): def _extract_time(op, **kw): part = _extract_mapping[type(op)] timestamp = translate_val(op.arg, **kw) - return sg.func("extract", sg.exp.Literal(this=part, is_string=True), timestamp) + return sg.func("extract", sg_literal(part), timestamp) # DuckDB extracts subminute microseconds and milliseconds @@ -534,12 +534,8 @@ def _extract_microsecond(op, **kw): arg = translate_val(op.arg, **kw) return sg.exp.Mod( - this=sg.func( - "extract", - sg.exp.Literal(this="us", is_string=True), - arg, - ), - expression=sg.exp.Literal(this="1000000", is_string=False), + this=sg.func("extract", sg_literal("us"), arg), + expression=sg_literal(1_000_000, is_string=False), ) @@ -548,12 +544,8 @@ def _extract_microsecond(op, **kw): arg = translate_val(op.arg, **kw) return sg.exp.Mod( - this=sg.func( - "extract", - sg.exp.Literal(this="ms", is_string=True), - arg, - ), - expression=sg.exp.Literal(this="1000", is_string=False), + this=sg.func("extract", sg_literal("ms"), arg), + expression=sg_literal(1_000, is_string=False), ) @@ -650,8 +642,7 @@ def _interval_format(op): ) return sg.exp.Interval( - this=sg.exp.Literal(this=op.value, is_string=False), - unit=dtype.resolution.upper(), + this=sg_literal(op.value, is_string=False), unit=dtype.resolution.upper() ) @@ -700,9 +691,7 @@ def _substring(op, **kw): if_neg = sg.exp.Substring(this=arg, start=start, length=length) return sg.exp.If( - this=sg.exp.GTE( - this=start, expression=sg.exp.Literal(this="0", is_string=False) - 
), + this=sg.exp.GTE(this=start, expression=sg_literal(0, is_string=False)), true=if_pos, false=if_neg, ) @@ -723,9 +712,7 @@ def _string_find(op, **kw): def _regex_search(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) - return sg.func( - "regexp_matches", arg, pattern, sg.exp.Literal(this="s", is_string=True) - ) + return sg.func("regexp_matches", arg, pattern, sg_literal("s")) @translate_val.register(ops.RegexReplace) @@ -733,13 +720,7 @@ def _regex_replace(op, **kw): arg = translate_val(op.arg, **kw) pattern = translate_val(op.pattern, **kw) replacement = translate_val(op.replacement, **kw) - return sg.func( - "regexp_replace", - arg, - pattern, - replacement, - sg.exp.Literal(this="g", is_string=True), - ) + return sg.func("regexp_replace", arg, pattern, replacement, sg_literal("g")) @translate_val.register(ops.RegexExtract) @@ -1204,10 +1185,7 @@ def _array_column(op, **kw): def _struct_column(op, **kw): return sg.exp.Struct.from_arg_list( [ - sg.exp.Slice( - this=sg.exp.Literal(this=name, is_string=True), - expression=translate_val(value, **kw), - ) + sg.exp.Slice(this=sg_literal(name), expression=translate_val(value, **kw)) for name, value in zip(op.names, op.values) ] ) @@ -1216,9 +1194,7 @@ def _struct_column(op, **kw): @translate_val.register(ops.StructField) def _struct_field(op, **kw): arg = translate_val(unalias(op.arg), **kw) - return sg.exp.StructExtract( - this=arg, expression=sg.exp.Literal(this=op.field, is_string=True) - ) + return sg.exp.StructExtract(this=arg, expression=sg_literal(op.field)) @translate_val.register(ops.ScalarParameter) @@ -1265,15 +1241,8 @@ def _map_contains(op, **kw): arg = translate_val(op.arg, **kw) key = translate_val(op.key, **kw) return sg.exp.NEQ( - this=sg.func( - "array_length", - sg.func( - "element_at", - arg, - key, - ), - ), - expression=sg.exp.Literal(this="0", is_string=False), + this=sg.func("array_length", sg.func("element_at", arg, key)), + expression=sg_literal(0, is_string=False), ) From 5be5c91904559aa56b6388c2d1f99ca0941febbb Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 19 Sep 2023 08:31:42 -0400 Subject: [PATCH 203/222] chore: unalias everything that should not have `expr AS name` --- ibis/backends/duckdb/compiler/relations.py | 22 +++++++++++++--------- ibis/backends/duckdb/compiler/values.py | 6 +++--- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index da84e0665b21..aa130c254a69 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -10,6 +10,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis.backends.base.sqlglot import unalias from ibis.backends.duckdb.compiler.values import translate_val @@ -54,10 +55,10 @@ def _selection(op: ops.Selection, *, table, needs_alias=False, aliases, **kw): if predicates := op.predicates: if join is not None: sel = sg.select("*").from_(sel.subquery(aliases[op.table])) - sel = sel.where(sg.and_(*map(tr_val, predicates))) + sel = sel.where(sg.and_(*map(tr_val, map(unalias, predicates)))) if sort_keys := op.sort_keys: - sel = sel.order_by(*map(tr_val, sort_keys)) + sel = sel.order_by(*map(tr_val, map(unalias, sort_keys))) return sel @@ -65,7 +66,6 @@ def _selection(op: ops.Selection, *, table, needs_alias=False, aliases, **kw): @translate_rel.register(ops.Aggregation) def 
_aggregation(op: ops.Aggregation, *, table, **kw): tr_val = partial(translate_val, **kw) - tr_val_no_alias = partial(translate_val, **kw) by = tuple(map(tr_val, op.by)) metrics = tuple(map(tr_val, op.metrics)) @@ -84,13 +84,13 @@ def _aggregation(op: ops.Aggregation, *, table, **kw): ) if predicates := op.predicates: - sel = sel.where(*map(tr_val_no_alias, predicates)) + sel = sel.where(*map(tr_val, map(unalias, predicates))) if having := op.having: - sel = sel.having(*map(tr_val_no_alias, having)) + sel = sel.having(*map(tr_val, map(unalias, having))) if sort_keys := op.sort_keys: - sel = sel.order_by(*map(tr_val_no_alias, sort_keys)) + sel = sel.order_by(*map(tr_val, map(unalias, sort_keys))) return sel @@ -111,7 +111,11 @@ def _aggregation(op: ops.Aggregation, *, table, **kw): def _join(op: ops.Join, *, left, right, **kw): predicates = op.predicates - on = sg.and_(*map(partial(translate_val, **kw), predicates)) if predicates else None + on = ( + sg.and_(*map(partial(translate_val, **kw), map(unalias, predicates))) + if predicates + else None + ) join_type = _JOIN_TYPES[type(op)] try: @@ -220,9 +224,9 @@ def _dropna(op: ops.DropNa, *, table, **kw): tr_val = partial(translate_val, **kw) predicate = tr_val(raw_predicate) try: - return table.where(predicate) + return table.where(unalias(predicate)) except AttributeError: - return sg.select("*").from_(table).where(predicate) + return sg.select("*").from_(table).where(unalias(predicate)) @translate_rel.register diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 66020b5879ff..1e06d810c739 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -394,7 +394,7 @@ def _not(op, **kw): def _apply_agg_filter(expr, *, where, **kw): if where is not None: return sg.exp.Filter( - this=expr, expression=sg.exp.Where(this=translate_val(where, **kw)) + this=expr, expression=sg.exp.Where(this=translate_val(unalias(where), **kw)) ) return expr @@ -870,7 +870,7 @@ def _array_repeat_op(op, **kw): return sg.func( "flatten", sg.select( - sg.func("array", sg.select(arg).from_(sg.func("range", times))) + sg.func("array", sg.select(arg).from_(sg.func("range", unalias(times)))) ).subquery(), ) @@ -1167,7 +1167,7 @@ def _exists_subquery(op, **kw): if isinstance(foreign_table, sg.exp.Select): foreign_table = foreign_table.subquery() - predicate = sg.and_(*map(partial(translate_val, **kw), op.predicates)) + predicate = sg.and_(*map(partial(translate_val, **kw), map(unalias, op.predicates))) return sg.exp.Exists(this=sg.select(1).from_(foreign_table).where(predicate)) From 282bdfb0115a6067430cd05bd796018501e47b77 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 10:56:04 -0400 Subject: [PATCH 204/222] chore: remove rebase screwups --- .../backends/clickhouse/compiler/relations.py | 27 ------------------- 1 file changed, 27 deletions(-) diff --git a/ibis/backends/clickhouse/compiler/relations.py b/ibis/backends/clickhouse/compiler/relations.py index 8435aa9c067e..5e209935830a 100644 --- a/ibis/backends/clickhouse/compiler/relations.py +++ b/ibis/backends/clickhouse/compiler/relations.py @@ -7,12 +7,7 @@ import ibis.common.exceptions as com import ibis.expr.operations as ops -<<<<<<< HEAD from ibis.backends.base.sqlglot import FALSE, NULL, STAR -======= -from ibis.backends.base.sqlglot import unalias -from ibis.backends.clickhouse.compiler.values import translate_val ->>>>>>> 4ae077d66 (chore: factor out unaliasing) 
@functools.singledispatch @@ -90,7 +85,6 @@ def _selection( @translate_rel.register(ops.Aggregation) -<<<<<<< HEAD def _aggregation( op: ops.Aggregation, *, table, metrics, by, having, predicates, sort_keys, **_ ): @@ -110,27 +104,6 @@ def _aggregation( if sort_keys: sel = sel.order_by(*sort_keys) -======= -def _aggregation(op: ops.Aggregation, *, table, **kw): - tr_val = partial(translate_val, **kw) - - by = tuple(map(tr_val, op.by)) - metrics = tuple(map(tr_val, op.metrics)) - selections = (by + metrics) or "*" - sel = sg.select(*selections).from_(table) - - if group_keys := op.by: - sel = sel.group_by(*map(tr_val, map(unalias, group_keys)), dialect="clickhouse") - - if predicates := op.predicates: - sel = sel.where(*map(tr_val, map(unalias, predicates)), dialect="clickhouse") - - if having := op.having: - sel = sel.having(*map(tr_val, map(unalias, having)), dialect="clickhouse") - - if sort_keys := op.sort_keys: - sel = sel.order_by(*map(tr_val, map(unalias, sort_keys)), dialect="clickhouse") ->>>>>>> 4ae077d66 (chore: factor out unaliasing) return sel From 189e427a38eda80122cd8467999eae8727fbb8c0 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 10:59:54 -0400 Subject: [PATCH 205/222] chore: bring back unalias temporarily --- ibis/backends/base/sqlglot/__init__.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index cc6d016bfa81..0ae090a8f960 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: import ibis.expr.datatypes as dt + import ibis.expr.operations as ops from ibis.backends.base.sqlglot.datatypes import SqlglotType @@ -67,3 +68,14 @@ def cast(arg: sg.exp.Expression, to: dt.DataType) -> sg.exp.Cast: return sg.cast(arg, to=converter.from_ibis(to)) return cast + + +def unalias(op: ops.Value) -> ops.Value: + """Unwrap `Alias` objects. + + Necessary when rendering `WHERE`, `GROUP BY` and `ORDER BY` and other + clauses. 
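+
+    For example, an aliased grouping key would otherwise render as
+    `expr AS name`, which is not valid inside a `GROUP BY` clause.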
+ """ + import ibis.expr.operations as ops + + return op.arg if isinstance(op, ops.Alias) else op From 29aeac5458ad5ac1aebda144eb9b4d8dd418568e Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:11:06 -0400 Subject: [PATCH 206/222] chore: regen sql --- .../test_many_subqueries/duckdb/out.sql | 102 ++++++++--- .../test_union_aliasing/duckdb/out.sql | 171 +++++++++++------- 2 files changed, 173 insertions(+), 100 deletions(-) diff --git a/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/duckdb/out.sql b/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/duckdb/out.sql index fb8c40cd69ba..de6339220a50 100644 --- a/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/duckdb/out.sql +++ b/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/duckdb/out.sql @@ -1,32 +1,74 @@ -WITH t0 AS ( - SELECT - t5.street AS street, - ROW_NUMBER() OVER (ORDER BY t5.street ASC) - 1 AS key - FROM data AS t5 -), t1 AS ( - SELECT - t0.key AS key - FROM t0 -), t2 AS ( - SELECT - t0.street AS street, - t0.key AS key - FROM t0 - JOIN t1 - ON t0.key = t1.key -), t3 AS ( +SELECT + t5.street, + t5.key +FROM ( SELECT - t2.street AS street, - ROW_NUMBER() OVER (ORDER BY t2.street ASC) - 1 AS key - FROM t2 -), t4 AS ( + t4.street, + ROW_NUMBER() OVER (ORDER BY t4.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + FROM ( + SELECT + t1.street, + t1.key + FROM ( + SELECT + t0.*, + ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + FROM "data" AS t0 + ) AS t1 + INNER JOIN ( + SELECT + t1.key + FROM ( + SELECT + t0.*, + ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + FROM "data" AS t0 + ) AS t1 + ) AS t2 + ON ( + t1.key + ) = ( + t2.key + ) + ) AS t4 +) AS t5 +INNER JOIN ( SELECT - t3.key AS key - FROM t3 -) -SELECT - t3.street, - t3.key -FROM t3 -JOIN t4 - ON t3.key = t4.key \ No newline at end of file + t5.key + FROM ( + SELECT + t4.street, + ROW_NUMBER() OVER (ORDER BY t4.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + FROM ( + SELECT + t1.street, + t1.key + FROM ( + SELECT + t0.*, + ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + FROM "data" AS t0 + ) AS t1 + INNER JOIN ( + SELECT + t1.key + FROM ( + SELECT + t0.*, + ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + FROM "data" AS t0 + ) AS t1 + ) AS t2 + ON ( + t1.key + ) = ( + t2.key + ) + ) AS t4 + ) AS t5 +) AS t6 + ON ( + t5.key + ) = ( + t6.key + ) \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql b/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql index 8335befe6765..62adad32f7e3 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql @@ -1,75 +1,106 @@ -WITH t0 AS ( - SELECT - t7.field_of_study AS field_of_study, - UNNEST( - CAST([{'years': '1970-71', 'degrees': t7."1970-71"}, {'years': '1975-76', 'degrees': t7."1975-76"}, {'years': '1980-81', 'degrees': t7."1980-81"}, {'years': '1985-86', 'degrees': t7."1985-86"}, {'years': '1990-91', 'degrees': t7."1990-91"}, {'years': '1995-96', 'degrees': t7."1995-96"}, {'years': '2000-01', 'degrees': 
t7."2000-01"}, {'years': '2005-06', 'degrees': t7."2005-06"}, {'years': '2010-11', 'degrees': t7."2010-11"}, {'years': '2011-12', 'degrees': t7."2011-12"}, {'years': '2012-13', 'degrees': t7."2012-13"}, {'years': '2013-14', 'degrees': t7."2013-14"}, {'years': '2014-15', 'degrees': t7."2014-15"}, {'years': '2015-16', 'degrees': t7."2015-16"}, {'years': '2016-17', 'degrees': t7."2016-17"}, {'years': '2017-18', 'degrees': t7."2017-18"}, {'years': '2018-19', 'degrees': t7."2018-19"}, {'years': '2019-20', 'degrees': t7."2019-20"}] AS STRUCT(years TEXT, degrees BIGINT)[]) - ) AS __pivoted__ - FROM humanities AS t7 -), t1 AS ( - SELECT - t0.field_of_study AS field_of_study, - STRUCT_EXTRACT(t0.__pivoted__, 'years') AS years, - STRUCT_EXTRACT(t0.__pivoted__, 'degrees') AS degrees - FROM t0 -), t2 AS ( - SELECT - t1.field_of_study AS field_of_study, - t1.years AS years, - t1.degrees AS degrees, - FIRST_VALUE(t1.degrees) OVER (PARTITION BY t1.field_of_study ORDER BY t1.years ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS earliest_degrees, - LAST_VALUE(t1.degrees) OVER (PARTITION BY t1.field_of_study ORDER BY t1.years ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS latest_degrees - FROM t1 -), t3 AS ( - SELECT - t2.field_of_study AS field_of_study, - t2.years AS years, - t2.degrees AS degrees, - t2.earliest_degrees AS earliest_degrees, - t2.latest_degrees AS latest_degrees, - t2.latest_degrees - t2.earliest_degrees AS diff - FROM t2 -), t4 AS ( - SELECT - t3.field_of_study AS field_of_study, - FIRST(t3.diff) AS diff - FROM t3 - GROUP BY - 1 -), anon_1 AS ( - SELECT - t4.field_of_study AS field_of_study, - t4.diff AS diff - FROM t4 - ORDER BY - t4.diff DESC - LIMIT 10 -), t5 AS ( - SELECT - t4.field_of_study AS field_of_study, - t4.diff AS diff - FROM t4 - WHERE - t4.diff < CAST(0 AS TINYINT) -), anon_2 AS ( - SELECT - t5.field_of_study AS field_of_study, - t5.diff AS diff - FROM t5 - ORDER BY - t5.diff ASC - LIMIT 10 -) SELECT - t6.field_of_study, - t6.diff + t11.field_of_study, + t11.diff FROM ( SELECT - anon_1.field_of_study AS field_of_study, - anon_1.diff AS diff - FROM anon_1 + * + FROM ( + SELECT + * + FROM ( + SELECT + t4.field_of_study, + FIRST(t4.diff) AS diff + FROM ( + SELECT + *, + ( + t3.latest_degrees + ) - ( + t3.earliest_degrees + ) AS diff + FROM ( + SELECT + *, + FIRST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) AS earliest_degrees, + LAST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) AS latest_degrees + FROM ( + SELECT + t1.field_of_study, + STRUCT_EXTRACT(t1.__pivoted__, 'years') AS years, + STRUCT_EXTRACT(t1.__pivoted__, 'degrees') AS degrees + FROM ( + SELECT + t0.field_of_study, + UNNEST( + [{'years': '1970-71', 'degrees': t0."1970-71"}, {'years': '1975-76', 'degrees': t0."1975-76"}, {'years': '1980-81', 'degrees': t0."1980-81"}, {'years': '1985-86', 'degrees': t0."1985-86"}, {'years': '1990-91', 'degrees': t0."1990-91"}, {'years': '1995-96', 'degrees': t0."1995-96"}, {'years': '2000-01', 'degrees': t0."2000-01"}, {'years': '2005-06', 'degrees': t0."2005-06"}, {'years': '2010-11', 'degrees': t0."2010-11"}, {'years': '2011-12', 'degrees': t0."2011-12"}, {'years': '2012-13', 'degrees': t0."2012-13"}, {'years': '2013-14', 'degrees': t0."2013-14"}, {'years': '2014-15', 'degrees': t0."2014-15"}, {'years': '2015-16', 'degrees': t0."2015-16"}, {'years': '2016-17', 'degrees': 
t0."2016-17"}, {'years': '2017-18', 'degrees': t0."2017-18"}, {'years': '2018-19', 'degrees': t0."2018-19"}, {'years': '2019-20', 'degrees': t0."2019-20"}] + ) AS __pivoted__ + FROM "humanities" AS t0 + ) AS t1 + ) AS t2 + ) AS t3 + ) AS t4 + GROUP BY + 1 + ) AS t5 + ORDER BY + t5.diff DESC + ) AS t6 + LIMIT 10 UNION ALL SELECT - anon_2.field_of_study AS field_of_study, - anon_2.diff AS diff - FROM anon_2 -) AS t6 \ No newline at end of file + * + FROM ( + SELECT + * + FROM ( + SELECT + * + FROM ( + SELECT + t4.field_of_study, + FIRST(t4.diff) AS diff + FROM ( + SELECT + *, + ( + t3.latest_degrees + ) - ( + t3.earliest_degrees + ) AS diff + FROM ( + SELECT + *, + FIRST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) AS earliest_degrees, + LAST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) AS latest_degrees + FROM ( + SELECT + t1.field_of_study, + STRUCT_EXTRACT(t1.__pivoted__, 'years') AS years, + STRUCT_EXTRACT(t1.__pivoted__, 'degrees') AS degrees + FROM ( + SELECT + t0.field_of_study, + UNNEST( + [{'years': '1970-71', 'degrees': t0."1970-71"}, {'years': '1975-76', 'degrees': t0."1975-76"}, {'years': '1980-81', 'degrees': t0."1980-81"}, {'years': '1985-86', 'degrees': t0."1985-86"}, {'years': '1990-91', 'degrees': t0."1990-91"}, {'years': '1995-96', 'degrees': t0."1995-96"}, {'years': '2000-01', 'degrees': t0."2000-01"}, {'years': '2005-06', 'degrees': t0."2005-06"}, {'years': '2010-11', 'degrees': t0."2010-11"}, {'years': '2011-12', 'degrees': t0."2011-12"}, {'years': '2012-13', 'degrees': t0."2012-13"}, {'years': '2013-14', 'degrees': t0."2013-14"}, {'years': '2014-15', 'degrees': t0."2014-15"}, {'years': '2015-16', 'degrees': t0."2015-16"}, {'years': '2016-17', 'degrees': t0."2016-17"}, {'years': '2017-18', 'degrees': t0."2017-18"}, {'years': '2018-19', 'degrees': t0."2018-19"}, {'years': '2019-20', 'degrees': t0."2019-20"}] + ) AS __pivoted__ + FROM "humanities" AS t0 + ) AS t1 + ) AS t2 + ) AS t3 + ) AS t4 + GROUP BY + 1 + ) AS t5 + WHERE + ( + t5.diff + ) < ( + CAST(0 AS TINYINT) + ) + ) AS t7 + ORDER BY + t7.diff ASC + ) AS t9 + LIMIT 10 +) AS t11 \ No newline at end of file From f46cb0f5e8778ab63d66d541d539c33752f494a7 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:11:21 -0400 Subject: [PATCH 207/222] chore: raw_sql --- ibis/backends/duckdb/tests/conftest.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ibis/backends/duckdb/tests/conftest.py b/ibis/backends/duckdb/tests/conftest.py index 81eeffe5ea5c..5748167cc6ff 100644 --- a/ibis/backends/duckdb/tests/conftest.py +++ b/ibis/backends/duckdb/tests/conftest.py @@ -48,8 +48,7 @@ def connect(*, tmpdir, worker_id, **kw) -> BaseBackend: ) def load_tpch(self) -> None: - with self.connection.begin() as con: - con.exec_driver_sql("CALL dbgen(sf=0.1)") + self.connection.raw_sql("CALL dbgen(sf=0.1)") def _load_data(self, **_: Any) -> None: """Load test data into a backend.""" From 1360eaa0dd69889e892f3f264fabe1e6ee616be7 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:18:46 -0400 Subject: [PATCH 208/222] chore: fix for new array function representation --- ibis/backends/duckdb/compiler/values.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py 
b/ibis/backends/duckdb/compiler/values.py index 1e06d810c739..679548af0453 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any import sqlglot as sg -from toolz import flip import ibis import ibis.common.exceptions as com @@ -919,10 +918,10 @@ def _array_string_join(op, **kw): @translate_val.register(ops.ArrayMap) def _array_map(op, **kw): arg = translate_val(op.arg, **kw) - result = translate_val(op.result, **kw) + result = translate_val(op.body, **kw) lamduh = sg.exp.Lambda( this=result, - expressions=[sg.to_identifier(f"{op.parameter}", quoted=False)], + expressions=[sg.to_identifier(op.param, quoted=False)], ) return sg.func("list_transform", arg, lamduh) @@ -930,19 +929,20 @@ def _array_map(op, **kw): @translate_val.register(ops.ArrayFilter) def _array_filter(op, **kw): arg = translate_val(op.arg, **kw) - result = translate_val(op.result, **kw) + result = translate_val(op.body, **kw) lamduh = sg.exp.Lambda( this=result, - expressions=[sg.exp.Identifier(this=f"{op.parameter}", quoted=False)], + expressions=[sg.to_identifier(op.param, quoted=False)], ) return sg.func("list_filter", arg, lamduh) @translate_val.register(ops.ArrayIntersect) def _array_intersect(op, **kw): - return translate_val( - ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x)), **kw - ) + param = "x" + x = ops.Argument(name=param, shape=op.left.shape, dtype=op.left.dtype.value_type) + body = ops.ArrayContains(op.right, x) + return translate_val(ops.ArrayFilter(arg=op.left, body=body, param=param), **kw) @translate_val.register(ops.ArrayPosition) @@ -954,7 +954,11 @@ def _array_position(op, **kw): @translate_val.register(ops.ArrayRemove) def _array_remove(op, **kw): - return translate_val(ops.ArrayFilter(op.arg, flip(ops.NotEquals, op.other)), **kw) + param = "x" + arg = op.arg + x = ops.Argument(name=param, shape=arg.shape, dtype=arg.dtype.value_type) + body = ops.NotEquals(x, op.other) + return translate_val(ops.ArrayFilter(arg=arg, body=body, param=param), **kw) @translate_val.register(ops.ArrayUnion) From a2b848da07355d302bf9ef70bb9e23527ee6eba5 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:19:27 -0400 Subject: [PATCH 209/222] chore: regen passing tpch queries --- .../test_h01/test_tpc_h01/duckdb/h01.sql | 71 ++++--- .../test_h04/test_tpc_h04/duckdb/h04.sql | 36 +++- .../test_h06/test_tpc_h06/duckdb/h06.sql | 28 ++- .../test_h12/test_tpc_h12/duckdb/h12.sql | 53 +++-- .../test_h13/test_tpc_h13/duckdb/h13.sql | 42 ++-- .../test_h14/test_tpc_h14/duckdb/h14.sql | 69 ++++-- .../test_h15/test_tpc_h15/duckdb/h15.sql | 137 +++++++----- .../test_h16/test_tpc_h16/duckdb/h16.sql | 70 +++--- .../test_h17/test_tpc_h17/duckdb/h17.sql | 57 +++-- .../test_h19/test_tpc_h19/duckdb/h19.sql | 199 +++++++++++++++--- .../test_h20/test_tpc_h20/duckdb/h20.sql | 146 +++++++------ 11 files changed, 622 insertions(+), 286 deletions(-) diff --git a/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/duckdb/h01.sql b/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/duckdb/h01.sql index d62bcc0045f3..6bf4e97fbfac 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/duckdb/h01.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/duckdb/h01.sql @@ -1,41 +1,56 @@ SELECT - t0.l_returnflag, - t0.l_linestatus, - t0.sum_qty, - t0.sum_base_price, - t0.sum_disc_price, - t0.sum_charge, - t0.avg_qty, - t0.avg_price, - 
t0.avg_disc, - t0.count_order + * FROM ( SELECT - t1.l_returnflag AS l_returnflag, - t1.l_linestatus AS l_linestatus, - SUM(t1.l_quantity) AS sum_qty, - SUM(t1.l_extendedprice) AS sum_base_price, - SUM(t1.l_extendedprice * ( - CAST(1 AS TINYINT) - t1.l_discount - )) AS sum_disc_price, + t0.l_returnflag, + t0.l_linestatus, + SUM(t0.l_quantity) AS sum_qty, + SUM(t0.l_extendedprice) AS sum_base_price, SUM( - t1.l_extendedprice * ( - CAST(1 AS TINYINT) - t1.l_discount + ( + t0.l_extendedprice ) * ( - t1.l_tax + CAST(1 AS TINYINT) + ( + CAST(1 AS TINYINT) + ) - ( + t0.l_discount + ) + ) + ) AS sum_disc_price, + SUM( + ( + ( + t0.l_extendedprice + ) * ( + ( + CAST(1 AS TINYINT) + ) - ( + t0.l_discount + ) + ) + ) * ( + ( + t0.l_tax + ) + ( + CAST(1 AS TINYINT) + ) ) ) AS sum_charge, - AVG(t1.l_quantity) AS avg_qty, - AVG(t1.l_extendedprice) AS avg_price, - AVG(t1.l_discount) AS avg_disc, + AVG(t0.l_quantity) AS avg_qty, + AVG(t0.l_extendedprice) AS avg_price, + AVG(t0.l_discount) AS avg_disc, COUNT(*) AS count_order - FROM main.lineitem AS t1 + FROM "lineitem" AS t0 WHERE - t1.l_shipdate <= CAST('1998-09-02' AS DATE) + ( + t0.l_shipdate + ) <= ( + MAKE_DATE(1998, 9, 2) + ) GROUP BY 1, 2 -) AS t0 +) AS t1 ORDER BY - t0.l_returnflag ASC, - t0.l_linestatus ASC \ No newline at end of file + t1.l_returnflag ASC, + t1.l_linestatus ASC \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/duckdb/h04.sql b/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/duckdb/h04.sql index 4c0294c32c0d..f58d71f9e961 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/duckdb/h04.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/duckdb/h04.sql @@ -1,19 +1,33 @@ SELECT t0.o_orderpriority, COUNT(*) AS order_count -FROM main.orders AS t0 +FROM "orders" AS t0 WHERE - ( - EXISTS( - SELECT - CAST(1 AS TINYINT) AS anon_1 - FROM main.lineitem AS t1 - WHERE - t1.l_orderkey = t0.o_orderkey AND t1.l_commitdate < t1.l_receiptdate - ) + EXISTS( + SELECT + 1 + FROM "lineitem" + WHERE + ( + l_orderkey + ) = ( + t0.o_orderkey + ) AND ( + l_commitdate + ) < ( + l_receiptdate + ) + ) + AND ( + t0.o_orderdate + ) >= ( + MAKE_DATE(1993, 7, 1) + ) + AND ( + t0.o_orderdate + ) < ( + MAKE_DATE(1993, 10, 1) ) - AND t0.o_orderdate >= CAST('1993-07-01' AS DATE) - AND t0.o_orderdate < CAST('1993-10-01' AS DATE) GROUP BY 1 ORDER BY diff --git a/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql b/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql index 2dc5fbbe6aca..8c780e28b15a 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql @@ -1,8 +1,24 @@ SELECT - SUM(t0.l_extendedprice * t0.l_discount) AS revenue -FROM main.lineitem AS t0 + SUM(( + t0.l_extendedprice + ) * ( + t0.l_discount + )) AS revenue +FROM "lineitem" AS t0 WHERE - t0.l_shipdate >= CAST('1994-01-01' AS DATE) - AND t0.l_shipdate < CAST('1995-01-01' AS DATE) - AND t0.l_discount BETWEEN CAST(0.05 AS REAL(53)) AND CAST(0.07 AS REAL(53)) - AND t0.l_quantity < CAST(24 AS TINYINT) \ No newline at end of file + ( + t0.l_shipdate + ) >= ( + MAKE_DATE(1994, 1, 1) + ) + AND ( + t0.l_shipdate + ) < ( + MAKE_DATE(1995, 1, 1) + ) + AND t0.l_discount BETWEEN CAST(0.05 AS DOUBLE) AND CAST(0.07 AS DOUBLE) + AND ( + t0.l_quantity + ) < ( + CAST(24 AS TINYINT) + ) \ No newline at end of file diff --git 
a/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/duckdb/h12.sql b/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/duckdb/h12.sql index 1673eceba76c..af51e5041edf 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/duckdb/h12.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/duckdb/h12.sql @@ -1,12 +1,10 @@ SELECT - t0.l_shipmode, - t0.high_line_count, - t0.low_line_count + * FROM ( SELECT - t2.l_shipmode AS l_shipmode, + t2.l_shipmode, SUM( - CASE t1.o_orderpriority + CASE t2.o_orderpriority WHEN '1-URGENT' THEN CAST(1 AS TINYINT) WHEN '2-HIGH' @@ -15,7 +13,7 @@ FROM ( END ) AS high_line_count, SUM( - CASE t1.o_orderpriority + CASE t2.o_orderpriority WHEN '1-URGENT' THEN CAST(0 AS TINYINT) WHEN '2-HIGH' @@ -23,17 +21,42 @@ FROM ( ELSE CAST(1 AS TINYINT) END ) AS low_line_count - FROM main.orders AS t1 - JOIN main.lineitem AS t2 - ON t1.o_orderkey = t2.l_orderkey + FROM ( + SELECT + t0.*, + t1.* + FROM "orders" AS t0 + INNER JOIN "lineitem" AS t1 + ON ( + t0.o_orderkey + ) = ( + t1.l_orderkey + ) + ) AS t2 WHERE t2.l_shipmode IN ('MAIL', 'SHIP') - AND t2.l_commitdate < t2.l_receiptdate - AND t2.l_shipdate < t2.l_commitdate - AND t2.l_receiptdate >= CAST('1994-01-01' AS DATE) - AND t2.l_receiptdate < CAST('1995-01-01' AS DATE) + AND ( + t2.l_commitdate + ) < ( + t2.l_receiptdate + ) + AND ( + t2.l_shipdate + ) < ( + t2.l_commitdate + ) + AND ( + t2.l_receiptdate + ) >= ( + MAKE_DATE(1994, 1, 1) + ) + AND ( + t2.l_receiptdate + ) < ( + MAKE_DATE(1995, 1, 1) + ) GROUP BY 1 -) AS t0 +) AS t3 ORDER BY - t0.l_shipmode ASC \ No newline at end of file + t3.l_shipmode ASC \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h13/test_tpc_h13/duckdb/h13.sql b/ibis/backends/tests/tpch/snapshots/test_h13/test_tpc_h13/duckdb/h13.sql index ddc206f3e537..b0275e931a23 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h13/test_tpc_h13/duckdb/h13.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h13/test_tpc_h13/duckdb/h13.sql @@ -1,24 +1,32 @@ -WITH t0 AS ( - SELECT - t2.c_custkey AS c_custkey, - COUNT(t3.o_orderkey) AS c_count - FROM main.customer AS t2 - LEFT OUTER JOIN main.orders AS t3 - ON t2.c_custkey = t3.o_custkey AND NOT t3.o_comment LIKE '%special%requests%' - GROUP BY - 1 -) SELECT - t1.c_count, - t1.custdist + * FROM ( SELECT - t0.c_count AS c_count, + t3.c_count, COUNT(*) AS custdist - FROM t0 + FROM ( + SELECT + t2.c_custkey, + COUNT(t2.o_orderkey) AS c_count + FROM ( + SELECT + t0.*, + t1.* + FROM "customer" AS t0 + LEFT JOIN "orders" AS t1 + ON ( + t0.c_custkey + ) = ( + t1.o_custkey + ) + AND NOT t1.o_comment LIKE '%special%requests%' + ) AS t2 + GROUP BY + 1 + ) AS t3 GROUP BY 1 -) AS t1 +) AS t4 ORDER BY - t1.custdist DESC, - t1.c_count DESC \ No newline at end of file + t4.custdist DESC, + t4.c_count DESC \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/duckdb/h14.sql b/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/duckdb/h14.sql index 3728c34ea65b..aba4eafae950 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/duckdb/h14.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/duckdb/h14.sql @@ -1,22 +1,57 @@ SELECT ( + ( + SUM( + CASE + WHEN t2.p_type LIKE 'PROMO%' + THEN ( + t2.l_extendedprice + ) * ( + ( + CAST(1 AS TINYINT) + ) - ( + t2.l_discount + ) + ) + ELSE CAST(0 AS TINYINT) + END + ) + ) * ( + CAST(100 AS TINYINT) + ) + ) / ( SUM( - CASE - WHEN ( - t1.p_type LIKE 'PROMO%' + ( + t2.l_extendedprice + ) * ( 
+ ( + CAST(1 AS TINYINT) + ) - ( + t2.l_discount ) - THEN t0.l_extendedprice * ( - CAST(1 AS TINYINT) - t0.l_discount - ) - ELSE CAST(0 AS TINYINT) - END - ) * CAST(100 AS TINYINT) - ) / SUM(t0.l_extendedprice * ( - CAST(1 AS TINYINT) - t0.l_discount - )) AS promo_revenue -FROM main.lineitem AS t0 -JOIN main.part AS t1 - ON t0.l_partkey = t1.p_partkey + ) + ) + ) AS promo_revenue +FROM ( + SELECT + t0.*, + t1.* + FROM "lineitem" AS t0 + INNER JOIN "part" AS t1 + ON ( + t0.l_partkey + ) = ( + t1.p_partkey + ) +) AS t2 WHERE - t0.l_shipdate >= CAST('1995-09-01' AS DATE) - AND t0.l_shipdate < CAST('1995-10-01' AS DATE) \ No newline at end of file + ( + t2.l_shipdate + ) >= ( + MAKE_DATE(1995, 9, 1) + ) + AND ( + t2.l_shipdate + ) < ( + MAKE_DATE(1995, 10, 1) + ) \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/duckdb/h15.sql b/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/duckdb/h15.sql index 52161b72fd26..a04eb32850b1 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/duckdb/h15.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/duckdb/h15.sql @@ -1,54 +1,91 @@ -WITH t0 AS ( - SELECT - t3.l_suppkey AS l_suppkey, - SUM(t3.l_extendedprice * ( - CAST(1 AS TINYINT) - t3.l_discount - )) AS total_revenue - FROM main.lineitem AS t3 - WHERE - t3.l_shipdate >= CAST('1996-01-01' AS DATE) - AND t3.l_shipdate < CAST('1996-04-01' AS DATE) - GROUP BY - 1 -), t1 AS ( - SELECT - t3.s_suppkey AS s_suppkey, - t3.s_name AS s_name, - t3.s_address AS s_address, - t3.s_nationkey AS s_nationkey, - t3.s_phone AS s_phone, - t3.s_acctbal AS s_acctbal, - t3.s_comment AS s_comment, - t0.l_suppkey AS l_suppkey, - t0.total_revenue AS total_revenue - FROM main.supplier AS t3 - JOIN t0 - ON t3.s_suppkey = t0.l_suppkey - WHERE - t0.total_revenue = ( - SELECT - MAX(t0.total_revenue) AS "Max(total_revenue)" - FROM t0 - ) -) SELECT - t2.s_suppkey, - t2.s_name, - t2.s_address, - t2.s_phone, - t2.total_revenue + t5.s_suppkey, + t5.s_name, + t5.s_address, + t5.s_phone, + t5.total_revenue FROM ( SELECT - t1.s_suppkey AS s_suppkey, - t1.s_name AS s_name, - t1.s_address AS s_address, - t1.s_nationkey AS s_nationkey, - t1.s_phone AS s_phone, - t1.s_acctbal AS s_acctbal, - t1.s_comment AS s_comment, - t1.l_suppkey AS l_suppkey, - t1.total_revenue AS total_revenue - FROM t1 + * + FROM ( + SELECT + * + FROM ( + SELECT + * + FROM "supplier" AS t0 + INNER JOIN ( + SELECT + t1.l_suppkey, + SUM( + ( + t1.l_extendedprice + ) * ( + ( + CAST(1 AS TINYINT) + ) - ( + t1.l_discount + ) + ) + ) AS total_revenue + FROM "lineitem" AS t1 + WHERE + ( + t1.l_shipdate + ) >= ( + MAKE_DATE(1996, 1, 1) + ) + AND ( + t1.l_shipdate + ) < ( + MAKE_DATE(1996, 4, 1) + ) + GROUP BY + 1 + ) AS t2 + ON ( + t0.s_suppkey + ) = ( + t2.l_suppkey + ) + ) AS t3 + WHERE + ( + t3.total_revenue + ) = ( + SELECT + MAX(t1.total_revenue) AS "Max(total_revenue)" + FROM ( + SELECT + t0.l_suppkey, + SUM( + ( + t0.l_extendedprice + ) * ( + ( + CAST(1 AS TINYINT) + ) - ( + t0.l_discount + ) + ) + ) AS total_revenue + FROM "lineitem" AS t0 + WHERE + ( + t0.l_shipdate + ) >= ( + MAKE_DATE(1996, 1, 1) + ) + AND ( + t0.l_shipdate + ) < ( + MAKE_DATE(1996, 4, 1) + ) + GROUP BY + 1 + ) AS t1 + ) + ) AS t4 ORDER BY - t1.s_suppkey ASC -) AS t2 \ No newline at end of file + t4.s_suppkey ASC +) AS t5 \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h16/test_tpc_h16/duckdb/h16.sql b/ibis/backends/tests/tpch/snapshots/test_h16/test_tpc_h16/duckdb/h16.sql index 
5f9ebafc8322..e8c3a55986ad 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h16/test_tpc_h16/duckdb/h16.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h16/test_tpc_h16/duckdb/h16.sql @@ -1,47 +1,49 @@ SELECT - t0.p_brand, - t0.p_type, - t0.p_size, - t0.supplier_cnt + * FROM ( SELECT - t2.p_brand AS p_brand, - t2.p_type AS p_type, - t2.p_size AS p_size, - COUNT(DISTINCT t1.ps_suppkey) AS supplier_cnt - FROM main.partsupp AS t1 - JOIN main.part AS t2 - ON t2.p_partkey = t1.ps_partkey + t2.p_brand, + t2.p_type, + t2.p_size, + COUNT(DISTINCT t2.ps_suppkey) AS supplier_cnt + FROM ( + SELECT + t0.*, + t1.* + FROM "partsupp" AS t0 + INNER JOIN "part" AS t1 + ON ( + t1.p_partkey + ) = ( + t0.ps_partkey + ) + ) AS t2 WHERE - t2.p_brand <> 'Brand#45' + ( + t2.p_brand + ) <> ( + 'Brand#45' + ) AND NOT t2.p_type LIKE 'MEDIUM POLISHED%' AND t2.p_size IN (CAST(49 AS TINYINT), CAST(14 AS TINYINT), CAST(23 AS TINYINT), CAST(45 AS TINYINT), CAST(19 AS TINYINT), CAST(3 AS TINYINT), CAST(36 AS TINYINT), CAST(9 AS TINYINT)) - AND ( - NOT t1.ps_suppkey IN ( + AND NOT t2.ps_suppkey IN ( + SELECT + t1.s_suppkey + FROM ( SELECT - t3.s_suppkey - FROM ( - SELECT - t4.s_suppkey AS s_suppkey, - t4.s_name AS s_name, - t4.s_address AS s_address, - t4.s_nationkey AS s_nationkey, - t4.s_phone AS s_phone, - t4.s_acctbal AS s_acctbal, - t4.s_comment AS s_comment - FROM main.supplier AS t4 - WHERE - t4.s_comment LIKE '%Customer%Complaints%' - ) AS t3 - ) + * + FROM "supplier" AS t0 + WHERE + t0.s_comment LIKE '%Customer%Complaints%' + ) AS t1 ) GROUP BY 1, 2, 3 -) AS t0 +) AS t3 ORDER BY - t0.supplier_cnt DESC, - t0.p_brand ASC, - t0.p_type ASC, - t0.p_size ASC \ No newline at end of file + t3.supplier_cnt DESC, + t3.p_brand ASC, + t3.p_type ASC, + t3.p_size ASC \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql b/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql index 4e9c6e9f6da4..9017582a898a 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql @@ -1,15 +1,46 @@ SELECT - SUM(t0.l_extendedprice) / CAST(7.0 AS REAL(53)) AS avg_yearly -FROM main.lineitem AS t0 -JOIN main.part AS t1 - ON t1.p_partkey = t0.l_partkey + ( + SUM(t2.l_extendedprice) + ) / ( + CAST(7.0 AS DOUBLE) + ) AS avg_yearly +FROM ( + SELECT + t0.*, + t1.* + FROM "lineitem" AS t0 + INNER JOIN "part" AS t1 + ON ( + t1.p_partkey + ) = ( + t0.l_partkey + ) +) AS t2 WHERE - t1.p_brand = 'Brand#23' - AND t1.p_container = 'MED BOX' - AND t0.l_quantity < ( - SELECT - AVG(t0.l_quantity) AS "Mean(l_quantity)" - FROM main.lineitem AS t0 - WHERE - t0.l_partkey = t1.p_partkey - ) * CAST(0.2 AS REAL(53)) \ No newline at end of file + ( + t2.p_brand + ) = ( + 'Brand#23' + ) + AND ( + t2.p_container + ) = ( + 'MED BOX' + ) + AND ( + t2.l_quantity + ) < ( + ( + SELECT + AVG(t0.l_quantity) AS "Mean(l_quantity)" + FROM "lineitem" AS t0 + WHERE + ( + t0.l_partkey + ) = ( + p_partkey + ) + ) * ( + CAST(0.2 AS DOUBLE) + ) + ) \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h19/test_tpc_h19/duckdb/h19.sql b/ibis/backends/tests/tpch/snapshots/test_h19/test_tpc_h19/duckdb/h19.sql index e5d84f17ac70..ae042784a4bd 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h19/test_tpc_h19/duckdb/h19.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h19/test_tpc_h19/duckdb/h19.sql @@ -1,29 +1,174 @@ SELECT - SUM(t0.l_extendedprice * ( - CAST(1 AS TINYINT) - 
t0.l_discount - )) AS revenue -FROM main.lineitem AS t0 -JOIN main.part AS t1 - ON t1.p_partkey = t0.l_partkey + SUM( + ( + t2.l_extendedprice + ) * ( + ( + CAST(1 AS TINYINT) + ) - ( + t2.l_discount + ) + ) + ) AS revenue +FROM ( + SELECT + t0.*, + t1.* + FROM "lineitem" AS t0 + INNER JOIN "part" AS t1 + ON ( + t1.p_partkey + ) = ( + t0.l_partkey + ) +) AS t2 WHERE - t1.p_brand = 'Brand#12' - AND t1.p_container IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') - AND t0.l_quantity >= CAST(1 AS TINYINT) - AND t0.l_quantity <= CAST(11 AS TINYINT) - AND t1.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(5 AS TINYINT) - AND t0.l_shipmode IN ('AIR', 'AIR REG') - AND t0.l_shipinstruct = 'DELIVER IN PERSON' - OR t1.p_brand = 'Brand#23' - AND t1.p_container IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') - AND t0.l_quantity >= CAST(10 AS TINYINT) - AND t0.l_quantity <= CAST(20 AS TINYINT) - AND t1.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(10 AS TINYINT) - AND t0.l_shipmode IN ('AIR', 'AIR REG') - AND t0.l_shipinstruct = 'DELIVER IN PERSON' - OR t1.p_brand = 'Brand#34' - AND t1.p_container IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') - AND t0.l_quantity >= CAST(20 AS TINYINT) - AND t0.l_quantity <= CAST(30 AS TINYINT) - AND t1.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(15 AS TINYINT) - AND t0.l_shipmode IN ('AIR', 'AIR REG') - AND t0.l_shipinstruct = 'DELIVER IN PERSON' \ No newline at end of file + ( + ( + ( + ( + ( + ( + ( + ( + ( + t2.p_brand + ) = ( + 'Brand#12' + ) + ) + AND ( + t2.p_container IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + ) + ) + AND ( + ( + t2.l_quantity + ) >= ( + CAST(1 AS TINYINT) + ) + ) + ) + AND ( + ( + t2.l_quantity + ) <= ( + CAST(11 AS TINYINT) + ) + ) + ) + AND ( + t2.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(5 AS TINYINT) + ) + ) + AND ( + t2.l_shipmode IN ('AIR', 'AIR REG') + ) + ) + AND ( + ( + t2.l_shipinstruct + ) = ( + 'DELIVER IN PERSON' + ) + ) + ) + OR ( + ( + ( + ( + ( + ( + ( + ( + t2.p_brand + ) = ( + 'Brand#23' + ) + ) + AND ( + t2.p_container IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + ) + ) + AND ( + ( + t2.l_quantity + ) >= ( + CAST(10 AS TINYINT) + ) + ) + ) + AND ( + ( + t2.l_quantity + ) <= ( + CAST(20 AS TINYINT) + ) + ) + ) + AND ( + t2.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(10 AS TINYINT) + ) + ) + AND ( + t2.l_shipmode IN ('AIR', 'AIR REG') + ) + ) + AND ( + ( + t2.l_shipinstruct + ) = ( + 'DELIVER IN PERSON' + ) + ) + ) + ) + OR ( + ( + ( + ( + ( + ( + ( + ( + t2.p_brand + ) = ( + 'Brand#34' + ) + ) + AND ( + t2.p_container IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + ) + ) + AND ( + ( + t2.l_quantity + ) >= ( + CAST(20 AS TINYINT) + ) + ) + ) + AND ( + ( + t2.l_quantity + ) <= ( + CAST(30 AS TINYINT) + ) + ) + ) + AND ( + t2.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(15 AS TINYINT) + ) + ) + AND ( + t2.l_shipmode IN ('AIR', 'AIR REG') + ) + ) + AND ( + ( + t2.l_shipinstruct + ) = ( + 'DELIVER IN PERSON' + ) + ) + ) \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/duckdb/h20.sql b/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/duckdb/h20.sql index 9714fecea9d5..7776153680bc 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/duckdb/h20.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/duckdb/h20.sql @@ -1,73 +1,83 @@ -WITH t0 AS ( +SELECT + * +FROM ( SELECT - t2.s_suppkey AS s_suppkey, - t2.s_name AS s_name, - t2.s_address AS s_address, - t2.s_nationkey AS s_nationkey, - t2.s_phone AS s_phone, - t2.s_acctbal AS s_acctbal, - t2.s_comment AS 
s_comment, - t3.n_nationkey AS n_nationkey, - t3.n_name AS n_name, - t3.n_regionkey AS n_regionkey, - t3.n_comment AS n_comment - FROM main.supplier AS t2 - JOIN main.nation AS t3 - ON t2.s_nationkey = t3.n_nationkey - WHERE - t3.n_name = 'CANADA' - AND t2.s_suppkey IN ( + t3.s_name, + t3.s_address + FROM ( + SELECT + * + FROM ( SELECT - t4.ps_suppkey - FROM ( + * + FROM "supplier" AS t0 + INNER JOIN "nation" AS t1 + ON ( + t0.s_nationkey + ) = ( + t1.n_nationkey + ) + ) AS t2 + WHERE + ( + t2.n_name + ) = ( + 'CANADA' + ) + AND t2.s_suppkey IN ( SELECT - t5.ps_partkey AS ps_partkey, - t5.ps_suppkey AS ps_suppkey, - t5.ps_availqty AS ps_availqty, - t5.ps_supplycost AS ps_supplycost, - t5.ps_comment AS ps_comment - FROM main.partsupp AS t5 - WHERE - t5.ps_partkey IN ( - SELECT - t6.p_partkey - FROM ( + t1.ps_suppkey + FROM ( + SELECT + * + FROM "partsupp" AS t0 + WHERE + t0.ps_partkey IN ( SELECT - t7.p_partkey AS p_partkey, - t7.p_name AS p_name, - t7.p_mfgr AS p_mfgr, - t7.p_brand AS p_brand, - t7.p_type AS p_type, - t7.p_size AS p_size, - t7.p_container AS p_container, - t7.p_retailprice AS p_retailprice, - t7.p_comment AS p_comment - FROM main.part AS t7 - WHERE - t7.p_name LIKE 'forest%' - ) AS t6 - ) - AND t5.ps_availqty > ( - SELECT - SUM(t6.l_quantity) AS "Sum(l_quantity)" - FROM main.lineitem AS t6 - WHERE - t6.l_partkey = t5.ps_partkey - AND t6.l_suppkey = t5.ps_suppkey - AND t6.l_shipdate >= CAST('1994-01-01' AS DATE) - AND t6.l_shipdate < CAST('1995-01-01' AS DATE) - ) * CAST(0.5 AS REAL(53)) - ) AS t4 - ) -) -SELECT - t1.s_name, - t1.s_address -FROM ( - SELECT - t0.s_name AS s_name, - t0.s_address AS s_address - FROM t0 -) AS t1 + t1.p_partkey + FROM ( + SELECT + * + FROM "part" AS t0 + WHERE + t0.p_name LIKE 'forest%' + ) AS t1 + ) + AND ( + t0.ps_availqty + ) > ( + ( + SELECT + SUM(t0.l_quantity) AS "Sum(l_quantity)" + FROM "lineitem" AS t0 + WHERE + ( + t0.l_partkey + ) = ( + ps_partkey + ) + AND ( + t0.l_suppkey + ) = ( + ps_suppkey + ) + AND ( + t0.l_shipdate + ) >= ( + MAKE_DATE(1994, 1, 1) + ) + AND ( + t0.l_shipdate + ) < ( + MAKE_DATE(1995, 1, 1) + ) + ) * ( + CAST(0.5 AS DOUBLE) + ) + ) + ) AS t1 + ) + ) AS t3 +) AS t4 ORDER BY - t1.s_name ASC \ No newline at end of file + t4.s_name ASC \ No newline at end of file From ff5855e3f1b58b0346ac9fc61555a5b51ccdb022 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:25:41 -0400 Subject: [PATCH 210/222] chore: fix clip --- ibis/backends/duckdb/compiler/values.py | 12 ++++++++++-- ibis/backends/tests/test_numeric.py | 3 ++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 679548af0453..44a198e897af 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -307,10 +307,18 @@ def _generic_log(op, **kw): def _clip(op, **kw): arg = translate_val(op.arg, **kw) if (upper := op.upper) is not None: - arg = sg.exp.Least.from_arg_list([translate_val(upper, **kw), arg]) + arg = sg.exp.If( + this=arg.is_(NULL), + true=sg.exp.NULL, + false=sg.exp.Least.from_arg_list([translate_val(upper, **kw), arg]), + ) if (lower := op.lower) is not None: - arg = sg.exp.Greatest.from_arg_list([translate_val(lower, **kw), arg]) + arg = sg.exp.If( + this=arg.is_(NULL), + true=sg.exp.NULL, + false=sg.exp.Greatest.from_arg_list([translate_val(lower, **kw), arg]), + ) return arg diff --git a/ibis/backends/tests/test_numeric.py 
b/ibis/backends/tests/test_numeric.py
index c427f4eb14d0..f86b5b18c8bd 100644
--- a/ibis/backends/tests/test_numeric.py
+++ b/ibis/backends/tests/test_numeric.py
@@ -1386,7 +1386,8 @@ def test_random(con):
 @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
 def test_clip(backend, alltypes, df, ibis_func, pandas_func):
     result = ibis_func(alltypes.int_col).execute()
-    expected = pandas_func(df.int_col).astype(result.dtype)
+    raw_expected = pandas_func(df.int_col)
+    expected = raw_expected.astype(result.dtype)
     # Names won't match in the PySpark backend since PySpark
     # gives 'tmp' name when executing a Column
     backend.assert_series_equal(result, expected, check_names=False)

From 995ff1907a15f0ba508bb765b282c42ac221cd2f Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Tue, 26 Sep 2023 15:05:24 -0400
Subject: [PATCH 211/222] chore: add base compiler functionality

---
 .../base/sqlglot/compiler/__init__.py       |   0
 ibis/backends/base/sqlglot/compiler/core.py | 168 ++++++++++++++++++
 2 files changed, 168 insertions(+)
 create mode 100644 ibis/backends/base/sqlglot/compiler/__init__.py
 create mode 100644 ibis/backends/base/sqlglot/compiler/core.py

diff --git a/ibis/backends/base/sqlglot/compiler/__init__.py b/ibis/backends/base/sqlglot/compiler/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/ibis/backends/base/sqlglot/compiler/core.py b/ibis/backends/base/sqlglot/compiler/core.py
new file mode 100644
index 000000000000..e37f4f211f82
--- /dev/null
+++ b/ibis/backends/base/sqlglot/compiler/core.py
@@ -0,0 +1,168 @@
+"""Ibis expression to sqlglot compiler.
+
+The compiler is built with a few `singledispatch` functions:
+
+    1. `translate` for table expressions
+    1. `translate` for table nodes
+    1. `translate_rel`
+    1. `translate_val`
+
+## `translate`
+
+### Expression Implementation
+
+The table expression implementation of `translate` is a pass-through to the
+node implementation.
+
+### Node Implementation
+
+There's a single `ops.Node` implementation for `ops.TableNode` instances.
+
+This function:
+
+    1. Topologically sorts the expression graph.
+    1. Seeds the compilation cache with in-degree-zero table names.
+    1. Iterates through nodes with at least one in-degree and places the result
+       in the compilation cache. The cache is used to construct `ops.TableNode`
+       keyword arguments to the current translation rule.
+
+## `translate_rel`
+
+Translates a table operation given already-translated table inputs.
+
+If a table node needs to translate value expressions, for example an
+`ops.Aggregation`, that rule is responsible for calling `translate_val`.
+
+## `translate_val`
+
+Recurses top-down and translates the arguments of the value expression and uses
+those as input to construct the output.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import sqlglot as sg
+
+import ibis.expr.analysis as an
+import ibis.expr.operations as ops
+import ibis.expr.types as ir
+from ibis.common.patterns import Call
+from ibis.expr.analysis import c, p, x, y
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+
+a = Call.namespace(an)
+
+
+def translate(
+    op: ops.TableNode,
+    *,
+    params: Mapping[ir.Value, Any],
+    translate_rel,
+    translate_val,
+) -> sg.exp.Expression:
+    """Translate an ibis operation to a sqlglot expression.
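+
+    Backends are expected to bind their own `translate_rel` and
+    `translate_val` implementations via partial application, along these
+    lines (a sketch; it mirrors the per-backend wiring in the two commits
+    that follow, where this function is imported as `_translate`):
+
+        import functools
+
+        translate = functools.partial(
+            _translate, translate_rel=translate_rel, translate_val=translate_val
+        )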
+ + Parameters + ---------- + op + An ibis `TableNode` + params + A mapping of expressions to concrete values + translate_rel + Relation node translator + translate_val + Value node translator + + Returns + ------- + sqlglot.expressions.Expression + A sqlglot expression + """ + + def _translate_node(node, *args, **kwargs): + if isinstance(node, ops.Value): + return translate_val(node, *args, **kwargs) + assert isinstance(node, ops.TableNode) + return translate_rel(node, *args, **kwargs) + + alias_index = 0 + aliases = {} + + def fn(node, _, **kwargs): + nonlocal alias_index + + result = _translate_node(node, aliases=aliases, **kwargs) + + if not isinstance(node, ops.TableNode): + return result + + # don't alias the root node + if node is not op: + aliases[node] = f"t{alias_index:d}" + alias_index += 1 + + if alias := aliases.get(node): + try: + return result.subquery(alias=alias) + except AttributeError: + return sg.alias(result, alias=alias) + else: + return result + + # substitute parameters immediately to avoid having to define a + # ScalarParameter translation rule + # + # this lets us avoid threading `params` through every `translate_val` call + # only to be used in the one place it would be needed: the ScalarParameter + # `translate_val` rule + params = {param.op(): value for param, value in params.items()} + replace_literals = p.ScalarParameter >> ( + lambda op, _: ops.Literal(value=params[op], dtype=op.dtype) + ) + + # rewrite cumulative functions to window functions, so that we don't have + # to think about handling them in the compiler, we need only compile window + # functions + replace_cumulative_ops = p.WindowFunction( + x @ p.Cumulative, y + ) >> a.cumulative_to_window(x, y) + + # replace the right side of InColumn into a scalar subquery for sql + # backends + replace_in_column_with_table_array_view = p.InColumn >> ( + lambda op, _: op.__class__( + op.value, + ops.TableArrayView( + ops.Selection( + table=an.find_first_base_table(op.options), selections=(op.options,) + ) + ), + ) + ) + + # replace any checks against an empty right side of the IN operation with + # `False` + replace_empty_in_values_with_false = p.InValues(x, ()) >> c.Literal( + False, dtype="bool" + ) + + replace_notexists_subquery_with_not_exists = p.NotExistsSubquery(x) >> c.Not( + c.ExistsSubquery(x) + ) + + op = op.replace( + replace_literals + | replace_cumulative_ops + | replace_in_column_with_table_array_view + | replace_empty_in_values_with_false + | replace_notexists_subquery_with_not_exists + ) + # apply translate rules in topological order + results = op.map(fn, filter=(ops.TableNode, ops.Value)) + node = results[op] + return node.this if isinstance(node, sg.exp.Subquery) else node From d61d530361ce7913350b20cd5dd2fe9b31e7ac09 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 15:05:38 -0400 Subject: [PATCH 212/222] chore(clickhouse): move to base compiler functionality --- ibis/backends/clickhouse/compiler/core.py | 160 +--------------------- 1 file changed, 6 insertions(+), 154 deletions(-) diff --git a/ibis/backends/clickhouse/compiler/core.py b/ibis/backends/clickhouse/compiler/core.py index 26f4a050efe5..013073c245b7 100644 --- a/ibis/backends/clickhouse/compiler/core.py +++ b/ibis/backends/clickhouse/compiler/core.py @@ -1,161 +1,13 @@ -"""ClickHouse ibis expression to sqlglot compiler. - -The compiler is built with a few `singledispatch` functions: - - 1. `translate` for table expressions - 1. `translate` for table nodes - 1. 
`translate_rel` - 1. `translate_val` - -## `translate` - -### Expression Implementation - -The table expression implementation of `translate` is a pass through to the -node implementation. - -### Node Implementation - -There's a single `ops.Node` implementation for `ops.TableNode`s instances. - -This function: - - 1. Topologically sorts the expression graph. - 1. Seeds the compilation cache with in-degree-zero table names. - 1. Iterates though nodes with at least one in-degree and places the result - in the compilation cache. The cache is used to construct `ops.TableNode` - keyword arguments to the current translation rule. - -## `translate_rel` - -Translates a table operation given already-translated table inputs. - -If a table node needs to translate value expressions, for example, an -`ops.Aggregation` that rule is responsible for calling `translate_val`. - -## `translate_val` - -Recurses top-down and translates the arguments of the value expression and uses -those as input to construct the output. -""" +"""ClickHouse ibis expression to sqlglot compiler.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any +import functools -import sqlglot as sg - -import ibis.expr.analysis as an -import ibis.expr.operations as ops -import ibis.expr.types as ir +from ibis.backends.base.sqlglot.compiler.core import translate as _translate from ibis.backends.clickhouse.compiler.relations import translate_rel from ibis.backends.clickhouse.compiler.values import translate_val -from ibis.common.patterns import Call -from ibis.expr.analysis import c, p, x, y - -if TYPE_CHECKING: - from collections.abc import Mapping - - -a = Call.namespace(an) - - -def _translate_node(node, *args, **kwargs): - if isinstance(node, ops.Value): - return translate_val(node, *args, **kwargs) - assert isinstance(node, ops.TableNode) - return translate_rel(node, *args, **kwargs) - - -def translate(op: ops.TableNode, params: Mapping[ir.Value, Any]) -> sg.exp.Expression: - """Translate an ibis operation to a sqlglot expression. 
- - Parameters - ---------- - op - An ibis `TableNode` - params - A mapping of expressions to concrete values - - Returns - ------- - sqlglot.expressions.Expression - A sqlglot expression - """ - - alias_index = 0 - aliases = {} - - def fn(node, _, **kwargs): - nonlocal alias_index - - result = _translate_node(node, aliases=aliases, **kwargs) - - if not isinstance(node, ops.TableNode): - return result - - # don't alias the root node - if node is not op: - aliases[node] = f"t{alias_index:d}" - alias_index += 1 - - if alias := aliases.get(node): - try: - return result.subquery(alias=alias) - except AttributeError: - return sg.alias(result, alias=alias) - else: - return result - - # substitute parameters immediately to avoid having to define a - # ScalarParameter translation rule - # - # this lets us avoid threading `params` through every `translate_val` call - # only to be used in the one place it would be needed: the ScalarParameter - # `translate_val` rule - params = {param.op(): value for param, value in params.items()} - replace_literals = p.ScalarParameter >> ( - lambda op, _: ops.Literal(value=params[op], dtype=op.dtype) - ) - - # rewrite cumulative functions to window functions, so that we don't have - # to think about handling them in the compiler, we need only compile window - # functions - replace_cumulative_ops = p.WindowFunction( - x @ p.Cumulative, y - ) >> a.cumulative_to_window(x, y) - - # replace the right side of InColumn into a scalar subquery for sql - # backends - replace_in_column_with_table_array_view = p.InColumn >> ( - lambda op, _: op.__class__( - op.value, - ops.TableArrayView( - ops.Selection( - table=an.find_first_base_table(op.options), selections=(op.options,) - ) - ), - ) - ) - - # replace any checks against an empty right side of the IN operation with - # `False` - replace_empty_in_values_with_false = p.InValues(x, ()) >> c.Literal( - False, dtype="bool" - ) - - replace_notexists_subquery_with_not_exists = p.NotExistsSubquery(x) >> c.Not( - c.ExistsSubquery(x) - ) - op = op.replace( - replace_literals - | replace_cumulative_ops - | replace_in_column_with_table_array_view - | replace_empty_in_values_with_false - | replace_notexists_subquery_with_not_exists - ) - # apply translate rules in topological order - results = op.map(fn, filter=(ops.TableNode, ops.Value)) - node = results[op] - return node.this if isinstance(node, sg.exp.Subquery) else node +translate = functools.partial( + _translate, translate_rel=translate_rel, translate_val=translate_val +) From fe96efc4ed6568a8b2e4aa832f10885781e1e721 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 15:06:14 -0400 Subject: [PATCH 213/222] chore(duckdb): move to base compiler functionality --- ibis/backends/duckdb/compiler/core.py | 100 +- ibis/backends/duckdb/compiler/relations.py | 159 ++- ibis/backends/duckdb/compiler/values.py | 1008 +++++++------------- 3 files changed, 444 insertions(+), 823 deletions(-) diff --git a/ibis/backends/duckdb/compiler/core.py b/ibis/backends/duckdb/compiler/core.py index 9979e63a9eaa..8df85a9e0db9 100644 --- a/ibis/backends/duckdb/compiler/core.py +++ b/ibis/backends/duckdb/compiler/core.py @@ -1,99 +1,13 @@ -"""DuckDB ibis expression to sqlglot compiler. - -The compiler is built with a few `singledispatch` functions: - - 1. `translate` for table expressions - 1. `translate` for table nodes - 1. `translate_rel` - 1. 
`translate_val` - -## `translate` - -### Expression Implementation - -The table expression implementation of `translate` is a pass through to the -node implementation. - -### Node Implementation - -There's a single `ops.Node` implementation for `ops.TableNode`s instances. - -This function: - - 1. Topologically sorts the expression graph. - 1. Seeds the compilation cache with in-degree-zero table names. - 1. Iterates though nodes with at least one in-degree and places the result - in the compilation cache. The cache is used to construct `ops.TableNode` - keyword arguments to the current translation rule. - -## `translate_rel` - -Translates a table operation given already-translated table inputs. - -If a table node needs to translate value expressions, for example, an -`ops.Aggregation` that rule is responsible for calling `translate_val`. - -## `translate_val` - -Recurses top-down and translates the arguments of the value expression and uses -those as input to construct the output. -""" +"""DuckDB ibis expression to sqlglot compiler.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any - -import sqlglot as sg +import functools -import ibis.expr.operations as ops -import ibis.expr.types as ir +from ibis.backends.base.sqlglot.compiler.core import translate as _translate from ibis.backends.duckdb.compiler.relations import translate_rel +from ibis.backends.duckdb.compiler.values import translate_val -if TYPE_CHECKING: - from collections.abc import Mapping - - -def translate(op: ops.TableNode, params: Mapping[ir.Value, Any]) -> sg.exp.Expression: - """Translate an ibis operation to a sqlglot expression. - - Parameters - ---------- - op - An ibis `TableNode` - params - A mapping of expressions to concrete values - - Returns - ------- - sqlglot.expressions.Expression - A sqlglot expression - """ - params = {param.op(): value for param, value in params.items()} - - alias_index = 0 - aliases = {} - - def fn(node, cache, params=params, **kwargs): - nonlocal alias_index - - # don't alias the root node - if node is not op: - # TODO: do we want to create sqlglot tables here? 
- aliases[node] = f"t{alias_index:d}" - alias_index += 1 - - raw_rel = translate_rel( - node, aliases=aliases, params=params, cache=cache, **kwargs - ) - - if alias := aliases.get(node): - try: - return raw_rel.subquery(alias) - except AttributeError: - return sg.alias(raw_rel, alias) - else: - return raw_rel - - results = op.map(fn, filter=ops.TableNode) - node = results[op] - return node.this if isinstance(node, sg.exp.Subquery) else node +translate = functools.partial( + _translate, translate_rel=translate_rel, translate_val=translate_val +) diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index aa130c254a69..641cf1f3df4a 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -2,16 +2,13 @@ import functools from collections.abc import Mapping -from functools import partial from typing import Any import sqlglot as sg import ibis.common.exceptions as com -import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.base.sqlglot import unalias -from ibis.backends.duckdb.compiler.values import translate_val +from ibis.backends.base.sqlglot import FALSE, NULL, STAR, lit @functools.singledispatch @@ -21,23 +18,41 @@ def translate_rel(op: ops.TableNode, **_): @translate_rel.register(ops.DummyTable) -def _dummy(op: ops.DummyTable, **kw): - return sg.select(*map(partial(translate_val, **kw), op.values)) +def _dummy(op: ops.DummyTable, *, values, **_): + return sg.select(*values) -@translate_rel.register(ops.DatabaseTable) @translate_rel.register(ops.UnboundTable) @translate_rel.register(ops.InMemoryTable) def _physical_table(op, **_): return sg.expressions.Table(this=sg.to_identifier(op.name, quoted=True)) -@translate_rel.register(ops.Selection) -def _selection(op: ops.Selection, *, table, needs_alias=False, aliases, **kw): +@translate_rel.register(ops.DatabaseTable) +def _database_table(op, *, name, namespace, **_): + try: + db, catalog = namespace.split(".") + except AttributeError: + db = catalog = None + return sg.table(name, db=db, catalog=catalog) + + +def replace_tables_with_star_selection(node, alias=None): + if isinstance(node, (sg.exp.Subquery, sg.exp.Table, sg.exp.CTE)): + return sg.exp.Column( + this=STAR, + table=sg.to_identifier(alias if alias is not None else node.alias_or_name), + ) + return node + + +@translate_rel.register +def _selection( + op: ops.Selection, *, table, selections, predicates, sort_keys, aliases, **_ +): # needs_alias should never be true here in explicitly, but it may get # passed via a (recursive) call to translate_val - assert not needs_alias, "needs_alias is True" - if needs_alias := isinstance(op.table, ops.Join) and not isinstance( + if isinstance(op.table, ops.Join) and not isinstance( op.table, (ops.LeftSemiJoin, ops.LeftAntiJoin) ): args = table.this.args @@ -45,52 +60,54 @@ def _selection(op: ops.Selection, *, table, needs_alias=False, aliases, **kw): (join,) = args["joins"] else: from_ = join = None - tr_val = partial(translate_val, needs_alias=needs_alias, aliases=aliases, **kw) - selections = tuple(map(tr_val, op.selections)) or "*" + + selections = tuple( + replace_tables_with_star_selection( + node, + # replace the table name with the alias if the table is **not** a + # join, because we may be selecting from a subquery or an aliased + # table; otherwise we'll select from the _unaliased_ table or the + # _child_ table, which may have a different alias than the one we + # generated for the input table + 
table.alias_or_name if from_ is None and join is None else None, + ) + for node in selections + ) or (STAR,) + sel = sg.select(*selections).from_(from_ if from_ is not None else table) if join is not None: sel = sel.join(join) - if predicates := op.predicates: + if predicates: if join is not None: - sel = sg.select("*").from_(sel.subquery(aliases[op.table])) - sel = sel.where(sg.and_(*map(tr_val, map(unalias, predicates)))) + sel = sg.select(STAR).from_(sel.subquery(aliases[op.table])) + sel = sel.where(*predicates) - if sort_keys := op.sort_keys: - sel = sel.order_by(*map(tr_val, map(unalias, sort_keys))) + if sort_keys: + sel = sel.order_by(*sort_keys) return sel @translate_rel.register(ops.Aggregation) -def _aggregation(op: ops.Aggregation, *, table, **kw): - tr_val = partial(translate_val, **kw) - - by = tuple(map(tr_val, op.by)) - metrics = tuple(map(tr_val, op.metrics)) - selections = (by + metrics) or "*" +def _aggregation( + op: ops.Aggregation, *, table, metrics, by, having, predicates, sort_keys, **_ +): + selections = (by + metrics) or (STAR,) sel = sg.select(*selections).from_(table) - if op.by: - # avoids translation of group by keys twice and makes the output more - # concise - sel = sel.group_by( - *( - sg.exp.Literal(this=str(key), is_string=False) - # keys are refer - for key in range(1, len(op.by) + 1) - ) - ) + if by: + sel = sel.group_by(*map(lit, range(1, len(by) + 1))) - if predicates := op.predicates: - sel = sel.where(*map(tr_val, map(unalias, predicates))) + if predicates: + sel = sel.where(*predicates) - if having := op.having: - sel = sel.having(*map(tr_val, map(unalias, having))) + if having: + sel = sel.having(*having) - if sort_keys := op.sort_keys: - sel = sel.order_by(*map(tr_val, map(unalias, sort_keys))) + if sort_keys: + sel = sel.order_by(*sort_keys) return sel @@ -108,14 +125,8 @@ def _aggregation(op: ops.Aggregation, *, table, **kw): @translate_rel.register -def _join(op: ops.Join, *, left, right, **kw): - predicates = op.predicates - - on = ( - sg.and_(*map(partial(translate_val, **kw), map(unalias, predicates))) - if predicates - else None - ) +def _join(op: ops.Join, *, left, right, predicates, **_): + on = sg.and_(*predicates) if predicates else None join_type = _JOIN_TYPES[type(op)] try: @@ -175,7 +186,7 @@ def _limit(op: ops.Limit, *, table, n, offset, **kw): if isinstance(n, int): result = result.limit(n) elif n is not None: - limit = translate_val(n, **kw) + limit = n # TODO: calling `.sql` is a workaround for sqlglot not supporting # scalar subqueries in limits limit = sg.select(limit).from_(table).subquery().sql(dialect="duckdb") @@ -184,7 +195,7 @@ def _limit(op: ops.Limit, *, table, n, offset, **kw): assert offset is not None, "offset is None" if not isinstance(offset, int): - skip = translate_val(offset, **kw) + skip = offset skip = sg.select(skip).from_(table).subquery().sql(dialect="duckdb") elif not offset: return result @@ -196,42 +207,35 @@ def _limit(op: ops.Limit, *, table, n, offset, **kw): @translate_rel.register def _distinct(_: ops.Distinct, *, table, **kw): - return sg.select("*").distinct().from_(table) + return sg.select(STAR).distinct().from_(table) @translate_rel.register(ops.DropNa) -def _dropna(op: ops.DropNa, *, table, **kw): - how = op.how - - if op.subset is None: - columns = [ops.TableColumn(op.table, name) for name in op.table.schema.names] - else: - columns = op.subset - - if columns: - raw_predicate = functools.reduce( - ops.And if how == "any" else ops.Or, - map(ops.NotNull, columns), +def _dropna(op: ops.DropNa, 
*, table, how, subset, **_): + if subset is None: + subset = [sg.column(name, table=table) for name in op.table.schema.names] + + if subset: + predicate = functools.reduce( + sg.and_ if how == "any" else sg.or_, + (sg.not_(col.is_(NULL)) for col in subset), ) elif how == "all": - raw_predicate = ops.Literal(False, dtype=dt.bool) + predicate = FALSE else: - raw_predicate = None + predicate = None - if not raw_predicate: + if predicate is None: return table - tr_val = partial(translate_val, **kw) - predicate = tr_val(raw_predicate) try: - return table.where(unalias(predicate)) + return table.where(predicate) except AttributeError: - return sg.select("*").from_(table).where(unalias(predicate)) + return sg.select(STAR).from_(table).where(predicate) @translate_rel.register -def _fillna(op: ops.FillNa, *, table, **kw): - replacements = op.replacements +def _fillna(op: ops.FillNa, *, table, replacements, **_): if isinstance(replacements, Mapping): mapping = replacements else: @@ -240,12 +244,7 @@ def _fillna(op: ops.FillNa, *, table, **kw): } exprs = [ ( - sg.alias( - sg.exp.Coalesce( - this=sg.column(col), expressions=[translate_val(alt, **kw)] - ), - col, - ) + sg.alias(sg.exp.Coalesce(this=sg.column(col), expressions=[alt]), col) if (alt := mapping.get(col)) is not None else sg.column(col) ) @@ -258,11 +257,11 @@ def _fillna(op: ops.FillNa, *, table, **kw): def _view(op: ops.View, *, child, name: str, **_): # TODO: find a way to do this without creating a temporary view backend = op.child.to_expr()._find_backend() - backend._create_temp_view(table_name=name, source=sg.select("*").from_(child)) + backend._create_temp_view(table_name=name, source=sg.select(STAR).from_(child)) return sg.table(name) @translate_rel.register def _sql_string_view(op: ops.SQLStringView, query: str, **_: Any): table = sg.table(op.name) - return sg.select("*").from_(table).with_(table, as_=query, dialect="duckdb") + return sg.select(STAR).from_(table).with_(table, as_=query, dialect="duckdb") diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 44a198e897af..0a3af382dd5c 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -5,23 +5,29 @@ import math import string from functools import partial -from typing import TYPE_CHECKING, Any +from typing import Any import sqlglot as sg -import ibis import ibis.common.exceptions as com import ibis.expr.analysis as an import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.base.sqlglot import unalias +from ibis.backends.base.sqlglot import NULL, STAR, AggGen, FuncGen, lit, make_cast from ibis.backends.base.sqlglot.datatypes import DuckDBType -if TYPE_CHECKING: - from collections.abc import Mapping +def _aggregate(funcname, *args, where=None): + expr = f[funcname](*args) + if where is not None: + return sg.exp.Filter(this=expr, expression=sg.exp.Where(this=where)) + return expr -NULL = sg.exp.Null() + +f = FuncGen() +if_ = f["if"] +cast = make_cast(DuckDBType) +agg = AggGen(aggfunc=_aggregate) @functools.singledispatch @@ -41,8 +47,8 @@ def _val_physical_table(op, *, aliases, **kw): @translate_val.register(ops.TableNode) -def _val_table_node(op, *, aliases, needs_alias=False, **_): - return f"{aliases[op]}.*" if needs_alias else "*" +def _val_table_node(op, *, aliases, **_): + return f"{aliases[op]}.*" @translate_val.register(ops.TableColumn) @@ -51,98 +57,99 @@ def _column(op, *, aliases, **_): @translate_val.register(ops.Alias) -def _alias(op, **kw): - 
val = translate_val(op.arg, **kw) - return sg.alias(val, op.name) +def _alias(op, *, arg, name, **_): + return sg.alias(arg, name) ### Literals -def sg_literal(arg, is_string=True): - return sg.exp.Literal(this=f"{arg}", is_string=is_string) - - @translate_val.register(ops.Literal) -def _literal(op, **kw): - value = op.value - dtype = op.dtype - +def _literal(op, *, value, dtype, **kw): if dtype.is_interval() and value is not None: return _interval_format(op) - sg_type = DuckDBType.from_ibis(dtype) - if value is None and dtype.nullable: - null = NULL - return null if dtype.is_null() else sg.cast(null, to=sg_type) + return NULL if dtype.is_null() else cast(NULL, dtype) elif dtype.is_boolean(): return sg.exp.Boolean(this=value) elif dtype.is_string() or dtype.is_inet() or dtype.is_macaddr(): - return sg_literal(value) + return lit(value) elif dtype.is_numeric(): # cast non finite values to float because that's the behavior of # duckdb when a mixed decimal/float operation is performed # # float will be upcast to double if necessary by duckdb if not math.isfinite(value): - return sg.cast( - sg_literal(value), - to=sg.exp.DataType.Type.FLOAT if dtype.is_decimal() else sg_type, - ) - return sg.cast(sg_literal(value, is_string=False), to=sg_type) + return cast(lit(value), to=dt.float32 if dtype.is_decimal() else dtype) + return cast(lit(value), dtype) elif dtype.is_time(): - return sg.cast(sg_literal(value), to=sg_type) + return cast(lit(value), dtype) elif dtype.is_timestamp(): - year = sg_literal(value.year, is_string=False) - month = sg_literal(value.month, is_string=False) - day = sg_literal(value.day, is_string=False) - hour = sg_literal(value.hour, is_string=False) - minute = sg_literal(value.minute, is_string=False) - second = sg_literal(value.second, is_string=False) + year = lit(value.year) + month = lit(value.month) + day = lit(value.day) + hour = lit(value.hour) + minute = lit(value.minute) + second = lit(value.second) if us := value.microsecond: - microsecond = sg_literal(us / 1e6, is_string=False) + microsecond = lit(us / 1e6) second += microsecond if (tz := dtype.timezone) is not None: - timezone = sg_literal(tz, is_string=True) + timezone = lit(tz) return sg.func( "make_timestamptz", year, month, day, hour, minute, second, timezone ) else: return sg.func("make_timestamp", year, month, day, hour, minute, second) elif dtype.is_date(): - year = sg_literal(value.year, is_string=False) - month = sg_literal(value.month, is_string=False) - day = sg_literal(value.day, is_string=False) + year = lit(value.year) + month = lit(value.month) + day = lit(value.day) return sg.exp.DateFromParts(year=year, month=month, day=day) elif dtype.is_array(): value_type = dtype.value_type return sg.exp.Array.from_arg_list( - [_literal(ops.Literal(v, dtype=value_type), **kw) for v in value] + [ + _literal( + ops.Literal(v, dtype=value_type), value=v, dtype=value_type, **kw + ) + for v in value + ] ) elif dtype.is_map(): key_type = dtype.key_type value_type = dtype.value_type keys = sg.exp.Array.from_arg_list( - [_literal(ops.Literal(k, dtype=key_type), **kw) for k in value.keys()] + [ + _literal(ops.Literal(k, dtype=key_type), value=k, dtype=key_type, **kw) + for k in value.keys() + ] ) values = sg.exp.Array.from_arg_list( - [_literal(ops.Literal(v, dtype=value_type), **kw) for v in value.values()] + [ + _literal( + ops.Literal(v, dtype=value_type), value=v, dtype=value_type, **kw + ) + for v in value.values() + ] ) return sg.exp.Map(keys=keys, values=values) elif dtype.is_struct(): - keys = 
list(map(sg_literal, value.keys()))
+        keys = list(map(lit, value.keys()))
         values = [
-            _literal(ops.Literal(v, dtype=field_dtype), **kw)
+            _literal(
+                ops.Literal(v, dtype=field_dtype), value=v, dtype=field_dtype, **kw
+            )
             for field_dtype, v in zip(dtype.types, value.values())
         ]
         return sg.exp.Struct.from_arg_list(
             [sg.exp.Slice(this=k, expression=v) for k, v in zip(keys, values)]
         )
     elif dtype.is_uuid():
-        return sg.cast(sg_literal(value), to=sg_type)
+        return cast(lit(str(value)), dtype)
     elif dtype.is_binary():
-        return sg.cast(sg_literal("".join(map("\\x{:02x}".format, value))), to=sg_type)
+        return cast(lit("".join(map("\\x{:02x}".format, value))), dtype)
     else:
         raise NotImplementedError(f"Unsupported type: {dtype!r}")
@@ -232,25 +239,19 @@ def _literal(op, **kw):
 }
 
 
-def _aggregate(op, func, **kw):
-    args = [
-        translate_val(arg, **kw)
-        for argname, arg in zip(op.argnames, op.args)
-        if argname not in ("where", "how")
-    ]
-    agg = sg.func(func, *args)
-    return _apply_agg_filter(agg, where=op.where, **kw)
-
-
 for _op, _name in _simple_ops.items():
     assert isinstance(type(_op), type), type(_op)
     if issubclass(_op, ops.Reduction):
-        translate_val.register(_op)(partial(_aggregate, func=_name))
+
+        @translate_val.register(_op)
+        def _fmt(op, _name: str = _name, *, where=None, aliases, **kw):
+            return agg[_name](*kw.values(), where=where)
+
     else:
 
         @translate_val.register(_op)
-        def _fmt(op, _name: str = _name, **kw):
-            return sg.func(_name, *map(partial(translate_val, **kw), op.args))
+        def _fmt(op, _name: str = _name, *, aliases, **kw):
+            return f[_name](*kw.values())
 
 
 del _fmt, _name, _op
@@ -270,70 +271,54 @@ def _fmt(op, _name: str = _name, **kw):
 @translate_val.register(ops.BitwiseAnd)
 @translate_val.register(ops.BitwiseOr)
 @translate_val.register(ops.BitwiseXor)
-def _bitwise_binary(op, **kw):
-    left = translate_val(op.left, **kw)
-    right = translate_val(op.right, **kw)
+def _bitwise_binary(op, *, left, right, **_):
     sg_expr = _bitwise_mapping[type(op)]
-
     return sg_expr(this=left, expression=right)
 
 
 @translate_val.register(ops.BitwiseNot)
-def _bitwise_not(op, **kw):
-    value = translate_val(op.arg, **kw)
-
-    return sg.exp.BitwiseNot(this=value)
+def _bitwise_not(op, *, arg, **_):
+    return sg.exp.BitwiseNot(this=arg)
 
 
 ### Mathematical Calisthenics
 
 
 @translate_val.register(ops.E)
-def _euler(op, **kw):
-    return sg.func("exp", 1)
+def _euler(op, **_):
+    return f.exp(1)
 
 
 @translate_val.register(ops.Log)
-def _generic_log(op, **kw):
-    arg, base = op.args
-    arg = translate_val(arg, **kw)
-    if base is not None:
-        base = translate_val(base, **kw)
-        return sg.func("ln", arg) / sg.func("ln", base)
-    return sg.func("ln", arg)
+def _generic_log(op, *, arg, base, **_):
+    if base is None:
+        return f.ln(arg)
+    elif str(base) in ("2", "10"):
+        return f[f"log{base}"](arg)
+    else:
+        return f.ln(arg) / f.ln(base)
 
 
 @translate_val.register(ops.Clip)
-def _clip(op, **kw):
-    arg = translate_val(op.arg, **kw)
-    if (upper := op.upper) is not None:
-        arg = sg.exp.If(
-            this=arg.is_(NULL),
-            true=sg.exp.NULL,
-            false=sg.exp.Least.from_arg_list([translate_val(upper, **kw), arg]),
-        )
+def _clip(op, *, arg, lower, upper, **_):
+    if upper is not None:
+        arg = if_(arg.is_(NULL), arg, f.least(upper, arg))
 
-    if (lower := op.lower) is not None:
-        arg = sg.exp.If(
-            this=arg.is_(NULL),
-            true=sg.exp.NULL,
-            false=sg.exp.Greatest.from_arg_list([translate_val(lower, **kw), arg]),
-        )
+    if lower is not None:
+        arg = if_(arg.is_(NULL), arg, f.greatest(lower, arg))
 
     return arg
 
 
 @translate_val.register(ops.FloorDivide)
-def
_floor_divide(op, **kw): - new_op = ops.Floor(ops.Divide(op.left, op.right)) - return translate_val(new_op, **kw) +def _floor_divide(op, *, left, right, **_): + return cast(f.fdiv(left, right), op.dtype) @translate_val.register(ops.Round) -def _round(op, **kw): - arg = translate_val(op.arg, **kw) - if (digits := op.digits) is not None: - return sg.exp.Round(this=arg, decimals=translate_val(digits, **kw)) +def _round(op, *, arg, digits, **_): + if digits is not None: + return sg.exp.Round(this=arg, decimals=digits) return sg.exp.Round(this=arg) @@ -353,96 +338,72 @@ def _round(op, **kw): @translate_val.register(ops.Cast) -def _cast(op, **kw): - arg = translate_val(op.arg, **kw) - to = op.to - +def _cast(op, *, arg, to, **kw): if to.is_interval(): - return sg.func( - f"to_{_interval_suffixes[to.unit.short]}", - sg.cast(arg, to=DuckDBType.from_ibis(dt.int32)), + return f[f"to_{_interval_suffixes[to.unit.short]}"]( + sg.cast(arg, to=DuckDBType.from_ibis(dt.int32)) ) elif to.is_timestamp() and op.arg.dtype.is_integer(): - return sg.func("to_timestamp", arg) + return f.to_timestamp(arg) - return sg.cast(expression=arg, to=translate_val(to, **kw)) + return cast(arg, to) @translate_val.register(ops.TryCast) -def _try_cast(op, **kw): - return sg.exp.TryCast( - this=translate_val(op.arg, **kw), to=DuckDBType.from_ibis(op.to) - ) +def _try_cast(op, *, arg, to, **_): + return sg.exp.TryCast(this=arg, to=DuckDBType.from_ibis(to)) ### Comparator Conundrums @translate_val.register(ops.Between) -def _between(op, **kw): - arg = translate_val(op.arg, **kw) - lower_bound = translate_val(op.lower_bound, **kw) - upper_bound = translate_val(op.upper_bound, **kw) +def _between(op, *, arg, lower_bound, upper_bound, **_): return sg.exp.Between(this=arg, low=lower_bound, high=upper_bound) @translate_val.register(ops.Negate) -def _negate(op, **kw): - arg = translate_val(op.arg, **kw) +def _negate(op, *, arg, **_): return sg.exp.Neg(this=arg) @translate_val.register(ops.Not) -def _not(op, **kw): - arg = translate_val(op.arg, **kw) +def _not(op, *, arg, **_): return sg.exp.Not(this=arg) -def _apply_agg_filter(expr, *, where, **kw): - if where is not None: - return sg.exp.Filter( - this=expr, expression=sg.exp.Where(this=translate_val(unalias(where), **kw)) - ) - return expr - - @translate_val.register(ops.NotAny) -def _not_any(op, **kw): - return translate_val(ops.All(ops.Not(op.arg), where=op.where), **kw) +def _not_any(op, *, arg, where, **kw): + return agg.bool_and(sg.not_(arg), where=where) @translate_val.register(ops.NotAll) -def _not_all(op, **kw): - return translate_val(ops.Any(ops.Not(op.arg), where=op.where), **kw) +def _not_all(op, *, arg, where, **kw): + return agg.bool_or(sg.not_(arg), where=where) ### Timey McTimeFace @translate_val.register(ops.Date) -def _to_date(op, **kw): - arg = translate_val(op.arg, **kw) +def _to_date(op, *, arg, **_): return sg.exp.Date(this=arg) @translate_val.register(ops.DateFromYMD) -def _date_from_ymd(op, **kw): - y = translate_val(op.year, **kw) - m = translate_val(op.month, **kw) - d = translate_val(op.day, **kw) - return sg.exp.DateFromParts(year=y, month=m, day=d) +def _date_from_ymd(op, *, year, month, day, **_): + return sg.exp.DateFromParts(year=year, month=month, day=day) @translate_val.register(ops.Time) -def _time(op, **kw): - arg = translate_val(op.arg, **kw) - return sg.cast(expression=arg, to=sg.exp.DataType.Type.TIME) +def _time(op, *, arg, **_): + return cast(arg, to=dt.time) @translate_val.register(ops.TimestampNow) -def _timestamp_now(op, **kw): +def 
_timestamp_now(op, **_):
     """DuckDB current timestamp defaults to timestamp + tz."""
-    return sg.cast(expression=sg.func("current_timestamp"), to="TIMESTAMP")
+    return cast(f.current_timestamp(), dt.timestamp)
 
 
 _POWERS_OF_TEN = {
@@ -454,11 +415,10 @@ def _timestamp_now(op, **kw):
 
 
 @translate_val.register(ops.TimestampFromUNIX)
-def _timestamp_from_unix(op, **kw):
-    arg = translate_val(op.arg, **kw)
-    unit = op.unit.short
+def _timestamp_from_unix(op, *, arg, unit, **_):
+    unit = unit.short
     if unit == "ms":
-        return sg.func("epoch_ms", arg)
+        return f.epoch_ms(arg)
     elif unit == "s":
         return sg.exp.UnixToTime(this=arg)
     else:
@@ -466,44 +426,28 @@ def _timestamp_from_unix(op, **kw):
 
 
 @translate_val.register(ops.TimestampFromYMDHMS)
-def _timestamp_from_ymdhms(op, **kw):
-    year = translate_val(op.year, **kw)
-    month = translate_val(op.month, **kw)
-    day = translate_val(op.day, **kw)
-    hour = translate_val(op.hours, **kw)
-    minute = translate_val(op.minutes, **kw)
-    second = translate_val(op.seconds, **kw)
-
-    args = [year, month, day, hour, minute, second]
+def _timestamp_from_ymdhms(op, *, year, month, day, hours, minutes, seconds, **_):
+    args = [year, month, day, hours, minutes, seconds]
     func = "make_timestamp"
     if (timezone := op.dtype.timezone) is not None:
         func += "tz"
-        args.append(sg_literal(timezone))
-    return sg.func(func, *args)
+        args.append(lit(timezone))
+    return f[func](*args)
 
 
 @translate_val.register(ops.Strftime)
-def _strftime(op, **kw):
+def _strftime(op, *, arg, format_str, **_):
     if not isinstance(op.format_str, ops.Literal):
         raise com.UnsupportedOperationError(
            f"DuckDB format_str must be a literal `str`; got {type(op.format_str)}"
        )
-    arg = translate_val(op.arg, **kw)
-    format_str = translate_val(op.format_str, **kw)
-    return sg.func("strftime", arg, format_str)
+    return f.strftime(arg, format_str)
 
 
 @translate_val.register(ops.ExtractEpochSeconds)
-def _extract_epoch_seconds(op, **kw):
-    arg = translate_val(op.arg, **kw)
-    return sg.func(
-        "epoch",
-        sg.exp.cast(
-            expression=arg,
-            to=sg.exp.DataType.Type.TIMESTAMP,
-        ),
-    )
+def _extract_epoch_seconds(op, *, arg, **_):
+    return f.epoch(cast(arg, dt.timestamp))
 
 
 _extract_mapping = {
@@ -528,38 +472,27 @@ def _extract_epoch_seconds(op, **kw):
 @translate_val.register(ops.ExtractHour)
 @translate_val.register(ops.ExtractMinute)
 @translate_val.register(ops.ExtractSecond)
-def _extract_time(op, **kw):
+def _extract_time(op, *, arg, **_):
     part = _extract_mapping[type(op)]
-    timestamp = translate_val(op.arg, **kw)
-    return sg.func("extract", sg_literal(part), timestamp)
+    return f.extract(part, arg)
 
 
 # DuckDB extracts subminute microseconds and milliseconds
 # so we have to finesse it a little bit
 @translate_val.register(ops.ExtractMicrosecond)
-def _extract_microsecond(op, **kw):
-    arg = translate_val(op.arg, **kw)
-
-    return sg.exp.Mod(
-        this=sg.func("extract", sg_literal("us"), arg),
-        expression=sg_literal(1_000_000, is_string=False),
-    )
+def _extract_microsecond(op, *, arg, **_):
+    return sg.exp.Mod(this=f.extract("us", arg), expression=lit(1_000_000))
 
 
 @translate_val.register(ops.ExtractMillisecond)
-def _extract_microsecond(op, **kw):
-    arg = translate_val(op.arg, **kw)
-
-    return sg.exp.Mod(
-        this=sg.func("extract", sg_literal("ms"), arg),
-        expression=sg_literal(1_000, is_string=False),
-    )
+def _extract_millisecond(op, *, arg, **_):
+    return sg.exp.Mod(this=f.extract("ms", arg), expression=lit(1_000))
 
 
 @translate_val.register(ops.DateTruncate)
 @translate_val.register(ops.TimestampTruncate)
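 # NB: date, timestamp, and time truncation share this one implementation;
 # every unit is mapped onto duckdb's date_trunc below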
@translate_val.register(ops.TimeTruncate) -def _truncate(op, **kw): +def _truncate(op, *, arg, unit, **_): unit_mapping = { "Y": "year", "M": "month", @@ -572,51 +505,29 @@ def _truncate(op, **kw): "us": "us", } - unit = op.unit.short - arg = translate_val(op.arg, **kw) + unit = unit.short try: duckunit = unit_mapping[unit] except KeyError: raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}") - return sg.func("date_trunc", sg_literal(duckunit), arg) + return f.date_trunc(duckunit, arg) @translate_val.register(ops.DayOfWeekIndex) -def _day_of_week_index(op, **kw): - arg = translate_val(op.arg, **kw) - return (sg.func("dayofweek", arg) + 6) % 7 +def _day_of_week_index(op, *, arg, **_): + return (f.dayofweek(arg) + 6) % 7 @translate_val.register(ops.DayOfWeekName) -def day_of_week_name(op, **kw): +def day_of_week_name(op, *, arg, **_): # day of week number is 0-indexed # Sunday == 0 # Saturday == 6 - arg = op.arg - nullable = arg.dtype.nullable - empty_string = ops.Literal("", dtype=dt.String(nullable=nullable)) weekdays = range(7) - return translate_val( - ops.NullIf( - ops.SimpleCase( - base=ops.DayOfWeekIndex(arg), - cases=[ - ops.Literal(day, dtype=dt.Int8(nullable=nullable)) - for day in weekdays - ], - results=[ - ops.Literal( - calendar.day_name[day], - dtype=dt.String(nullable=nullable), - ) - for day in weekdays - ], - default=empty_string, - ), - empty_string, - ), - **kw, + return sg.exp.Case( + this=(f.dayofweek(arg) + 6) % 7, + ifs=[if_(day, calendar.day_name[day]) for day in weekdays], ) @@ -633,11 +544,8 @@ def day_of_week_name(op, **kw): @translate_val.register(ops.IntervalAdd) @translate_val.register(ops.IntervalSubtract) @translate_val.register(ops.IntervalMultiply) -def _interval_binary(op, **kw): - left = translate_val(op.left, **kw) - right = translate_val(op.right, **kw) +def _interval_binary(op, *, left, right, **_): sg_expr = _interval_mapping[type(op)] - return sg_expr(this=left, expression=right) @@ -648,244 +556,184 @@ def _interval_format(op): "Duckdb doesn't support nanosecond interval resolutions" ) - return sg.exp.Interval( - this=sg_literal(op.value, is_string=False), unit=dtype.resolution.upper() - ) + return sg.exp.Interval(this=lit(op.value), unit=dtype.resolution.upper()) @translate_val.register(ops.IntervalFromInteger) -def _interval_from_integer(op, **kw): +def _interval_from_integer(op, *, arg, **_): dtype = op.dtype if dtype.unit.short == "ns": raise com.UnsupportedOperationError( "Duckdb doesn't support nanosecond interval resolutions" ) - arg = translate_val(op.arg, **kw) if op.dtype.resolution == "week": - return sg.func("to_days", arg * 7) - return sg.func(f"to_{op.dtype.resolution}s", arg) + return f.to_days(arg * 7) + return f[f"to_{op.dtype.resolution}s"](arg) ### String Instruments @translate_val.register(ops.Strip) -def _strip(op, **kw): - return sg.func("trim", translate_val(op.arg, **kw), sg_literal(string.whitespace)) +def _strip(op, *, arg, **_): + return f.trim(arg, string.whitespace) @translate_val.register(ops.RStrip) -def _rstrip(op, **kw): - return sg.func("rtrim", translate_val(op.arg, **kw), sg_literal(string.whitespace)) +def _rstrip(op, *, arg, **_): + return f.rtrim(arg, string.whitespace) @translate_val.register(ops.LStrip) -def _lstrip(op, **kw): - return sg.func("ltrim", translate_val(op.arg, **kw), sg_literal(string.whitespace)) +def _lstrip(op, *, arg, **_): + return f.ltrim(arg, string.whitespace) @translate_val.register(ops.Substring) -def _substring(op, **kw): - arg = translate_val(op.arg, **kw) - start 
= translate_val(op.start, **kw) - if op.length is not None: - length = translate_val(op.length, **kw) - else: - length = None - +def _substring(op, *, arg, start, length, **_): if_pos = sg.exp.Substring(this=arg, start=start + 1, length=length) if_neg = sg.exp.Substring(this=arg, start=start, length=length) - return sg.exp.If( - this=sg.exp.GTE(this=start, expression=sg_literal(0, is_string=False)), - true=if_pos, - false=if_neg, - ) + return if_(start >= 0, if_pos, if_neg) @translate_val.register(ops.StringFind) -def _string_find(op, **kw): - if op.end is not None: - raise com.UnsupportedOperationError("String find doesn't support end argument") +def _string_find(op, *, arg, substr, start, end, **_): + if end is not None: + raise com.UnsupportedOperationError( + "String find doesn't support `end` argument" + ) - arg = translate_val(op.arg, **kw) - substr = translate_val(op.substr, **kw) + if start is not None: + arg = f.substr(arg, start + 1) + pos = f.strpos(arg, substr) + return if_(pos > 0, pos - 1 + start, -1) - return sg.func("instr", arg, substr) - 1 + return f.strpos(arg, substr) - 1 @translate_val.register(ops.RegexSearch) -def _regex_search(op, **kw): - arg = translate_val(op.arg, **kw) - pattern = translate_val(op.pattern, **kw) - return sg.func("regexp_matches", arg, pattern, sg_literal("s")) +def _regex_search(op, *, arg, pattern, **_): + return f.regexp_matches(arg, pattern, "s") @translate_val.register(ops.RegexReplace) -def _regex_replace(op, **kw): - arg = translate_val(op.arg, **kw) - pattern = translate_val(op.pattern, **kw) - replacement = translate_val(op.replacement, **kw) - return sg.func("regexp_replace", arg, pattern, replacement, sg_literal("g")) +def _regex_replace(op, *, arg, pattern, replacement, **_): + return f.regexp_replace(arg, pattern, replacement, "g") @translate_val.register(ops.RegexExtract) -def _regex_extract(op, **kw): - arg = translate_val(op.arg, **kw) - pattern = translate_val(op.pattern, **kw) - group = translate_val(op.index, **kw) - return sg.func("regexp_extract", arg, pattern, group, dialect="duckdb") +def _regex_extract(op, *, arg, pattern, index, **_): + return f.regexp_extract(arg, pattern, index, dialect="duckdb") @translate_val.register(ops.StringSplit) -def _string_split(op, **kw): - arg = translate_val(op.arg, **kw) - delimiter = translate_val(op.delimiter, **kw) +def _string_split(op, *, arg, delimiter, **_): return sg.exp.Split(this=arg, expression=delimiter) @translate_val.register(ops.StringJoin) -def _string_join(op, **kw): - elements = list(map(partial(translate_val, **kw), op.arg)) - sep = translate_val(op.sep, **kw) - return sg.func( - "list_aggr", sg.exp.Array(expressions=elements), sg_literal("string_agg"), sep - ) +def _string_join(op, *, arg, sep, **_): + return f.list_aggr(f.array(*arg), "string_agg", sep) @translate_val.register(ops.StringConcat) -def _string_concat(op, **kw): - return sg.exp.Concat(expressions=list(map(partial(translate_val, **kw), op.arg))) +def _string_concat(op, *, arg, **_): + return sg.exp.Concat.from_arg_list(list(arg)) @translate_val.register(ops.StringSQLLike) -def _string_like(op, **kw): - arg = translate_val(op.arg, **kw) - pattern = translate_val(op.pattern, **kw) - return sg.exp.Like(this=arg, expression=pattern) +def _string_like(op, *, arg, pattern, **_): + return arg.like(pattern) @translate_val.register(ops.StringSQLILike) -def _string_ilike(op, **kw): - arg = translate_val(op.arg, **kw) - pattern = translate_val(op.pattern, **kw) - return sg.exp.ILike(this=arg, expression=pattern) +def 
_string_ilike(op, *, arg, pattern, **_):
+    return arg.ilike(pattern)
 
 
 @translate_val.register(ops.Capitalize)
-def _string_capitalize(op, **kw):
-    arg = translate_val(op.arg, **kw)
+def _string_capitalize(op, *, arg, **_):
     return sg.exp.Concat(
-        expressions=[
-            sg.func("upper", sg.func("substr", arg, 1, 1)),
-            sg.func("lower", sg.func("substr", arg, 2)),
-        ]
+        expressions=[f.upper(f.substr(arg, 1, 1)), f.lower(f.substr(arg, 2))]
     )
 
 
 ### NULL PLAYER CHARACTER
 
 
 @translate_val.register(ops.IsNull)
-def _is_null(op, **kw):
-    return translate_val(op.arg, **kw).is_(sg.exp.null())
+def _is_null(op, *, arg, **_):
+    return arg.is_(NULL)
 
 
 @translate_val.register(ops.NotNull)
-def _is_not_null(op, **kw):
-    return translate_val(op.arg, **kw).is_(sg.not_(sg.exp.null()))
+def _is_not_null(op, *, arg, **_):
+    return sg.not_(arg.is_(NULL))
 
 
 @translate_val.register(ops.IfNull)
-def _if_null(op, **kw):
-    arg = translate_val(op.arg, **kw)
-    ifnull = translate_val(op.ifnull_expr, **kw)
-    return sg.func("ifnull", arg, ifnull)
+def _if_null(op, *, arg, ifnull_expr, **_):
+    return f.ifnull(arg, ifnull_expr)
 
 
 @translate_val.register(ops.NullIfZero)
-def _null_if_zero(op, **kw):
-    arg = translate_val(op.arg, **kw)
-    return sg.func("nullif", arg, 0)
+def _null_if_zero(op, *, arg, **_):
+    return f.nullif(arg, 0)
 
 
 @translate_val.register(ops.ZeroIfNull)
-def _zero_if_null(op, **kw):
-    arg = translate_val(op.arg, **kw)
-    return sg.func("ifnull", arg, 0)
+def _zero_if_null(op, *, arg, **_):
+    return f.ifnull(arg, 0)
 
 
 ### Definitely Not Tensors
 
 
 @translate_val.register(ops.ArrayDistinct)
-def _array_sort(op, **kw):
-    arg = translate_val(op.arg, **kw)
-
-    return sg.exp.If(
-        this=arg.is_(NULL),
-        true=NULL,
-        false=sg.func("list_distinct", arg)
-        + sg.exp.If(
-            this=sg.func("list_count", arg) < sg.func("len", arg),
-            true=sg.exp.Array.from_arg_list([NULL]),
-            false=sg.exp.Array.from_arg_list([]),
-        ),
+def _array_distinct(op, *, arg, **_):
+    return if_(
+        arg.is_(NULL),
+        NULL,
+        f.list_distinct(arg)
+        + if_(f.list_count(arg) < f.len(arg), f.array(NULL), f.array()),
    )
 
 
 @translate_val.register(ops.ArrayIndex)
-def _array_index_op(op, **kw):
-    arg = translate_val(op.arg, **kw)
-    index = translate_val(op.index, **kw)
+def _array_index_op(op, *, arg, index, **_):
     correct_idx = sg.func("if", index >= 0, index + 1, index)
-    return sg.func("list_extract", arg, correct_idx)
+    return f.list_element(arg, correct_idx)
 
 
 @translate_val.register(ops.InValues)
-def _in_values(op, **kw):
-    if not op.options:
-        return sg.exp.FALSE
-
-    value = translate_val(op.value, **kw)
-    return sg.exp.In(
-        this=value,
-        expressions=[translate_val(opt, **kw) for opt in op.options],
-    )
+def _in_values(op, *, value, options, **_):
+    return value.isin(*options)
 
 
 @translate_val.register(ops.InColumn)
-def _in_column(op, **kw):
-    from ibis.backends.duckdb.compiler import translate
-
-    value = translate_val(op.value, **kw)
-    options = translate(op.options.to_expr().as_table().op(), {})
-    return value.isin(options)
+def _in_column(op, *, value, options, **_):
+    return value.isin(options.this if isinstance(options, sg.exp.Subquery) else options)
 
 
 @translate_val.register(ops.ArrayConcat)
-def _array_concat(op, **kw):
-    result, *rest = map(partial(translate_val, **kw), op.arg)
+def _array_concat(op, *, arg, **_):
+    result, *rest = arg
     for arg in rest:
-        result = sg.func("list_concat", result, arg)
+        result = f.list_concat(result, arg)
     return result
 
 
 @translate_val.register(ops.ArrayRepeat)
-def _array_repeat_op(op, **kw):
-    arg = translate_val(op.arg, **kw)
-    times =
translate_val(op.times, **kw) - return sg.func( - "flatten", - sg.select( - sg.func("array", sg.select(arg).from_(sg.func("range", unalias(times)))) - ).subquery(), +def _array_repeat_op(op, *, arg, times, **_): + return f.flatten( + sg.select(sg.func("array", sg.select(arg).from_(f.range(times)))).subquery(), ) def _neg_idx_to_pos(array, idx): - arg_length = sg.func("len", array) + arg_length = f.len(array) return sg.exp.If( - this=sg.exp.LT(this=idx, expression=sg_literal(0, is_string=False)), + this=idx < 0, # Need to have the greatest here to handle the case where # abs(neg_index) > arg_length # e.g. where the magnitude of the negative index is greater than the @@ -897,104 +745,86 @@ def _neg_idx_to_pos(array, idx): @translate_val.register(ops.ArraySlice) -def _array_slice_op(op, **kw): - arg = translate_val(op.arg, **kw) - - arg_length = sg.func("len", arg) +def _array_slice_op(op, *, arg, start, stop, **_): + arg_length = f.len(arg) - if (start := op.start) is None: - start = sg_literal(0, is_string=False) + if start is None: + start = 0 else: - start = translate_val(op.start, **kw) - start = sg.func("least", arg_length, _neg_idx_to_pos(arg, start)) + start = f.least(arg_length, _neg_idx_to_pos(arg, start)) - if (stop := op.stop) is None: + if stop is None: stop = NULL else: - stop = _neg_idx_to_pos(arg, translate_val(stop, **kw)) + stop = _neg_idx_to_pos(arg, stop) - return sg.func("list_slice", arg, start + 1, stop) + return f.list_slice(arg, start + 1, stop) @translate_val.register(ops.ArrayStringJoin) -def _array_string_join(op, **kw): - arg = translate_val(op.arg, **kw) - sep = translate_val(op.sep, **kw) - return sg.func("list_aggr", arg, sg_literal("string_agg"), sep) +def _array_string_join(op, *, sep, arg, **_): + return f.list_aggr(arg, "string_agg", sep) @translate_val.register(ops.ArrayMap) -def _array_map(op, **kw): - arg = translate_val(op.arg, **kw) - result = translate_val(op.body, **kw) - lamduh = sg.exp.Lambda( - this=result, - expressions=[sg.to_identifier(op.param, quoted=False)], - ) - return sg.func("list_transform", arg, lamduh) +def _array_map(op, *, arg, body, param, **_): + lamduh = sg.exp.Lambda(this=body, expressions=[sg.to_identifier(param)]) + return f.list_apply(arg, lamduh) @translate_val.register(ops.ArrayFilter) -def _array_filter(op, **kw): - arg = translate_val(op.arg, **kw) - result = translate_val(op.body, **kw) - lamduh = sg.exp.Lambda( - this=result, - expressions=[sg.to_identifier(op.param, quoted=False)], - ) - return sg.func("list_filter", arg, lamduh) +def _array_filter(op, *, arg, body, param, **_): + lamduh = sg.exp.Lambda(this=body, expressions=[sg.to_identifier(param)]) + return f.list_filter(arg, lamduh) @translate_val.register(ops.ArrayIntersect) -def _array_intersect(op, **kw): - param = "x" - x = ops.Argument(name=param, shape=op.left.shape, dtype=op.left.dtype.value_type) - body = ops.ArrayContains(op.right, x) - return translate_val(ops.ArrayFilter(arg=op.left, body=body, param=param), **kw) +def _array_intersect(op, *, left, right, **_): + param = sg.to_identifier("x") + body = f.list_contains(right, param) + lamduh = sg.exp.Lambda(this=body, expressions=[param]) + return f.list_filter(left, lamduh) @translate_val.register(ops.ArrayPosition) -def _array_position(op, **kw): - arg = translate_val(op.arg, **kw) - el = translate_val(op.other, **kw) - return sg.func("list_indexof", arg, el) - 1 +def _array_position(op, *, arg, other, **_): + return f.list_indexof(arg, other) - 1 @translate_val.register(ops.ArrayRemove) -def 
_array_remove(op, **kw): - param = "x" - arg = op.arg - x = ops.Argument(name=param, shape=arg.shape, dtype=arg.dtype.value_type) - body = ops.NotEquals(x, op.other) - return translate_val(ops.ArrayFilter(arg=arg, body=body, param=param), **kw) +def _array_remove(op, *, arg, other, **_): + param = sg.to_identifier("x") + body = param.neq(other) + lamduh = sg.exp.Lambda(this=body, expressions=[param]) + return f.list_filter(arg, lamduh) @translate_val.register(ops.ArrayUnion) -def _array_union(op, **kw): - return translate_val(ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right))), **kw) +def _array_union(op, *, left, right, **_): + arg = f.list_concat(left, right) + return if_( + arg.is_(NULL), + NULL, + f.list_distinct(arg) + + if_(f.list_count(arg) < f.len(arg), f.array(NULL), f.array()), + ) @translate_val.register(ops.ArrayZip) -def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: - i = sg.to_identifier("i", quoted=False) - args = [translate_val(arg, **kw) for arg in op.arg] +def _array_zip(op: ops.ArrayZip, *, arg, **_) -> str: + i = sg.to_identifier("i") result = sg.exp.Struct( expressions=[ - sg.exp.Slice( - this=sg_literal(name), - expression=sg.func("list_extract", arg, i), - ) - for name, arg in zip(op.dtype.value_type.names, args) + sg.exp.Slice(this=lit(name), expression=arg[i]) + for name, arg in zip(op.dtype.value_type.names, arg) ] ) lamduh = sg.exp.Lambda(this=result, expressions=[i]) - return sg.func( - "list_transform", - sg.func( - "range", - sg_literal(1, is_string=False), - # DuckDB Range is not inclusive of upper bound - sg.func("greatest", *[sg.func("len", arg) for arg in args]) + 1, + return f.list_apply( + f.range( + 1, + # DuckDB Range excludes upper bound + f.greatest(*map(f.len, arg)) + 1, ), lamduh, ) @@ -1004,40 +834,33 @@ def _array_zip(op: ops.ArrayZip, **kw: Any) -> str: @translate_val.register(ops.CountDistinct) -def _count_distinct(op, **kw): - arg = translate_val(op.arg, **kw) - count_expr = sg.exp.Count(this=sg.exp.Distinct(expressions=[arg])) - return _apply_agg_filter(count_expr, where=op.where, **kw) +def _count_distinct(op, *, arg, where, **_): + return agg.count(sg.exp.Distinct(expressions=[arg]), where=where) @translate_val.register(ops.CountDistinctStar) -def _count_distinct_star(op, **kw): +def _count_distinct_star(op, *, arg, where, **_): # use a tuple because duckdb doesn't accept COUNT(DISTINCT a, b, c, ...) 
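+    # (sg.exp.Tuple renders the column list as a single parenthesized row value)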
# # this turns the expression into COUNT(DISTINCT (a, b, c, ...)) row = sg.exp.Tuple(expressions=list(map(sg.column, op.arg.schema.keys()))) - expr = sg.exp.Count(this=sg.exp.Distinct(expressions=[row])) - return _apply_agg_filter(expr, where=op.where, **kw) + return agg.count(sg.exp.Distinct(expressions=[row]), where=where) @translate_val.register(ops.CountStar) -def _count_star(op, **kw): - return _apply_agg_filter(sg.exp.Count(this=sg.exp.Star()), where=op.where, **kw) +def _count_star(op, *, where, **_): + return agg.count(STAR, where=where) @translate_val.register(ops.Sum) -def _sum(op, **kw): - arg = translate_val( - ops.Cast(arg, to=op.dtype) if (arg := op.arg).dtype.is_boolean() else arg, **kw - ) - return _apply_agg_filter(sg.exp.Sum(this=arg), where=op.where, **kw) +def _sum(op, *, arg, where, **_): + arg = cast(arg, op.dtype) if op.arg.dtype.is_boolean() else arg + return agg.sum(arg, where=where) @translate_val.register(ops.NthValue) -def _nth_value(op, **kw): - arg = translate_val(op.arg, **kw) - nth = translate_val(op.nth, **kw) - return sg.func("nth_value", arg, nth + 1) +def _nth_value(op, *, arg, nth, **_): + return f.nth_value(arg, nth + 1) ### Stats @@ -1045,228 +868,128 @@ def _nth_value(op, **kw): @translate_val.register(ops.Quantile) @translate_val.register(ops.MultiQuantile) -def _quantile(op, **kw): - arg = translate_val(op.arg, **kw) - quantile = translate_val(op.quantile, **kw) - sg_expr = sg.func("quantile_cont", arg, quantile) - return _apply_agg_filter(sg_expr, where=op.where, **kw) +def _quantile(op, *, arg, quantile, where, **_): + return agg.quantile_cont(arg, quantile, where=where) @translate_val.register(ops.Correlation) -def _corr(op, **kw): - if op.how == "sample": +def _corr(op, *, left, right, how, where, **_): + if how == "sample": raise com.UnsupportedOperationError( "DuckDB only implements `pop` correlation coefficient" ) - left = translate_val(op.left, **kw) + # TODO: rewrite rule? if (left_type := op.left.dtype).is_boolean(): - left = sg.cast( - expression=left, - to=DuckDBType.from_ibis(dt.Int32(nullable=left_type.nullable)), - ) + left = cast(left, dt.Int32(nullable=left_type.nullable)) - right = translate_val(op.right, **kw) if (right_type := op.right.dtype).is_boolean(): - right = sg.cast( - expression=right, - to=DuckDBType.from_ibis(dt.Int32(nullable=right_type.nullable)), - ) + right = cast(right, dt.Int32(nullable=right_type.nullable)) - sg_func = sg.func("corr", left, right) - return _apply_agg_filter(sg_func, where=op.where, **kw) + return agg.corr(left, right, where=where) @translate_val.register(ops.Covariance) -def _covariance(op, **kw): +def _covariance(op, *, left, right, how, where, **_): _how = {"sample": "samp", "pop": "pop"} - left = translate_val(op.left, **kw) + # TODO: rewrite rule? 
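+    # cast boolean operands to INT32 so the covar_* aggregate receives numeric inputs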
if (left_type := op.left.dtype).is_boolean():
-        left = sg.cast(
-            expression=left,
-            to=DuckDBType.from_ibis(dt.Int32(nullable=left_type.nullable)),
-        )
+        left = cast(left, dt.Int32(nullable=left_type.nullable))
 
-    right = translate_val(op.right, **kw)
     if (right_type := op.right.dtype).is_boolean():
-        right = sg.cast(
-            expression=right,
-            to=DuckDBType.from_ibis(dt.Int32(nullable=right_type.nullable)),
-        )
+        right = cast(right, dt.Int32(nullable=right_type.nullable))
 
-    sg_func = sg.func(f"covar_{_how[op.how]}", left, right)
-    return _apply_agg_filter(sg_func, where=op.where, **kw)
+    return agg[f"covar_{_how[how]}"](left, right, where=where)
 
 
 @translate_val.register(ops.Variance)
 @translate_val.register(ops.StandardDev)
-def _variance(op, **kw):
+def _variance(op, *, arg, how, where, **_):
     _how = {"sample": "samp", "pop": "pop"}
     _func = {ops.Variance: "var", ops.StandardDev: "stddev"}
 
-    arg = op.arg
-    if (arg_dtype := arg.dtype).is_boolean():
-        arg = ops.Cast(arg, to=dt.Int32(nullable=arg_dtype))
+    if (arg_dtype := op.arg.dtype).is_boolean():
+        arg = cast(arg, dt.Int32(nullable=arg_dtype.nullable))
 
-    arg = translate_val(arg, **kw)
-
-    sg_func = sg.func(f"{_func[type(op)]}_{_how[op.how]}", arg)
-    return _apply_agg_filter(sg_func, where=op.where, **kw)
+    return agg[f"{_func[type(op)]}_{_how[how]}"](arg, where=where)
 
 
 @translate_val.register(ops.Arbitrary)
-def _arbitrary(op, **kw):
-    if op.how == "heavy":
+def _arbitrary(op, *, arg, how, where, **_):
+    if how == "heavy":
         raise com.UnsupportedOperationError("how='heavy' not supported in the backend")
-    functions = {
-        "first": "first",
-        "last": "last",
-    }
-    return _aggregate(op, functions[op.how], **kw)
+    funcs = {"first": agg.first, "last": agg.last}
+    return funcs[how](arg, where=where)
 
 
 @translate_val.register(ops.FindInSet)
-def _index_of(op: ops.FindInSet, **kw):
-    needle = translate_val(op.needle, **kw)
-    args = sg.exp.Array(expressions=list(map(partial(translate_val, **kw), op.values)))
-    return sg.func("list_indexof", args, needle) - 1
-
-
-@translate_val.register(tuple)
-def _node_list(op, **kw):
-    return sg.exp.Tuple(expressions=list(map(partial(translate_val, **kw), op)))
+def _index_of(op, *, needle, values, **_):
+    return f.list_indexof(f.array(*values), needle) - 1
 
 
 @translate_val.register(ops.SimpleCase)
 @translate_val.register(ops.SearchedCase)
-def _case(op, **kw):
-    case = sg.exp.Case()
-
-    if (base := getattr(op, "base", None)) is not None:
-        case = sg.exp.Case(this=translate_val(base, **kw))
-
-    for when, then in zip(op.cases, op.results):
-        case = case.when(
-            condition=translate_val(when, **kw),
-            then=translate_val(then, **kw),
-        )
-
-    if (default := op.default) is not None:
-        case = case.else_(condition=translate_val(default, **kw))
-
-    return case
+def _case(op, *, base=None, cases, results, default, **_):
+    return sg.exp.Case(this=base, ifs=list(map(if_, cases, results)), default=default)
 
 
 @translate_val.register(ops.TableArrayView)
-def _table_array_view(op, *, cache, **kw):
-    table = op.table
-    try:
-        return cache[table]
-    except KeyError:
-        from ibis.backends.duckdb.compiler import translate
-
-        return translate(table, {})
+def _table_array_view(op, *, table, **_):
+    return table.args["this"].subquery()
 
 
 @translate_val.register(ops.ExistsSubquery)
-def _exists_subquery(op, **kw):
-    from ibis.backends.duckdb.compiler import translate
-
-    foreign_table = translate(op.foreign_table, {})
-
-    # only construct a subquery if we cannot refer to the table directly
-    if isinstance(foreign_table, sg.exp.Select):
-        foreign_table = 
foreign_table.subquery() - - predicate = sg.and_(*map(partial(translate_val, **kw), map(unalias, op.predicates))) - return sg.exp.Exists(this=sg.select(1).from_(foreign_table).where(predicate)) - - -@translate_val.register(ops.NotExistsSubquery) -def _not_exists_subquery(op, **kw): - return sg.not_(_exists_subquery(op, **kw)) +def _exists_subquery(op, *, foreign_table, predicates, **_): + subq = sg.select(1).from_(foreign_table).where(sg.condition(predicates)).subquery() + return f.exists(subq) @translate_val.register(ops.ArrayColumn) -def _array_column(op, **kw): - return sg.exp.Array(expressions=[translate_val(col, **kw) for col in op.cols]) +def _array_column(op, *, cols, **kw): + return f.array(*cols) @translate_val.register(ops.StructColumn) -def _struct_column(op, **kw): +def _struct_column(op, *, names, values, **_): return sg.exp.Struct.from_arg_list( [ - sg.exp.Slice(this=sg_literal(name), expression=translate_val(value, **kw)) - for name, value in zip(op.names, op.values) + sg.exp.Slice(this=lit(name), expression=value) + for name, value in zip(names, values) ] ) @translate_val.register(ops.StructField) -def _struct_field(op, **kw): - arg = translate_val(unalias(op.arg), **kw) - return sg.exp.StructExtract(this=arg, expression=sg_literal(op.field)) - - -@translate_val.register(ops.ScalarParameter) -def _scalar_param(op, params: Mapping[ops.Node, Any], **kw): - raw_value = params[op] - dtype = op.dtype - if isinstance(dtype, dt.Struct): - literal = ibis.struct(raw_value, type=dtype) - elif isinstance(dtype, dt.Map): - literal = ibis.map(raw_value) - else: - literal = ibis.literal(raw_value, type=dtype) - return translate_val(literal.op(), **kw) +def _struct_field(op, *, arg, field, **_): + return sg.exp.StructExtract(this=arg, expression=lit(field)) @translate_val.register(ops.IdenticalTo) -def _identical_to(op, **kw): - left = translate_val(op.left, **kw) - right = translate_val(op.right, **kw) +def _identical_to(op, *, left, right, **_): return sg.exp.NullSafeEQ(this=left, expression=right) @translate_val.register(ops.Greatest) @translate_val.register(ops.Least) @translate_val.register(ops.Coalesce) -def _vararg_func(op, **kw): - return sg.func( - f"{op.__class__.__name__.lower()}", *map(partial(translate_val, **kw), op.arg) - ) +def _vararg_func(op, *, arg, **_): + return f[op.__class__.__name__.lower()](*arg) @translate_val.register(ops.MapGet) -def _map_get(op, **kw): - arg = translate_val(op.arg, **kw) - key = translate_val(op.key, **kw) - default = translate_val(op.default, **kw) - return sg.func( - "ifnull", sg.func("list_extract", sg.func("element_at", arg, key), 1), default - ) +def _map_get(op, *, arg, key, default, **_): + return f.ifnull(f.list_extract(f.element_at(arg, key), 1), default) @translate_val.register(ops.MapContains) -def _map_contains(op, **kw): - arg = translate_val(op.arg, **kw) - key = translate_val(op.key, **kw) - return sg.exp.NEQ( - this=sg.func("array_length", sg.func("element_at", arg, key)), - expression=sg_literal(0, is_string=False), - ) +def _map_contains(op, *, arg, key, **_): + return sg.exp.NEQ(this=f.array_length(f.element_at(arg, key)), expression=lit(0)) def _binary_infix(sg_expr: sg.exp._Expression): - def formatter(op, **kw): - left = translate_val(op.left, **kw) - right = translate_val(op.right, **kw) - - return sg_expr( - this=sg.exp.Paren(this=left), - expression=sg.exp.Paren(this=right), - ) + def formatter(op, *, left, right, **_): + return sg.exp.Paren(this=sg_expr(this=left, expression=right)) return formatter @@ -1288,6 +1011,7 
@@ def formatter(op, **kw): # Boolean comparisons ops.And: sg.exp.And, ops.Or: sg.exp.Or, + ops.Xor: sg.exp.Xor, ops.DateAdd: sg.exp.Add, ops.DateSub: sg.exp.Sub, ops.DateDiff: sg.exp.Sub, @@ -1303,14 +1027,14 @@ def formatter(op, **kw): del _op, _sym -@translate_val.register(ops.Xor) -def _xor(op, **kw): - # https://github.com/tobymao/sqlglot/issues/2238 - left = translate_val(op.left, **kw).sql("duckdb") - right = translate_val(op.right, **kw).sql("duckdb") - return sg.parse_one( - f"({left} OR {right}) AND NOT ({left} AND {right})", read="duckdb" - ) +# @translate_val.register(ops.Xor) +# def _xor(op, **kw): +# # https://github.com/tobymao/sqlglot/issues/2238 +# left = translate_val(op.left, **kw).sql("duckdb") +# right = translate_val(op.right, **kw).sql("duckdb") +# return sg.parse_one( +# f"({left} OR {right}) AND NOT ({left} AND {right})", read="duckdb" +# ) ### Ordering @@ -1342,9 +1066,8 @@ def _cume_dist(_, **kw): @translate_val.register -def _sort_key(op: ops.SortKey, **kw): - arg = translate_val(op.expr, **kw) - return sg.exp.Ordered(this=arg, desc=not op.ascending) +def _sort_key(op: ops.SortKey, *, expr, ascending: bool, **_): + return sg.exp.Ordered(this=expr, desc=not ascending) ### Window functions @@ -1381,84 +1104,70 @@ def cumulative_to_window(func, frame): @translate_val.register(ops.ApproxMedian) -def _approx_median(op, **kw): - expr = sg.func( - "approx_quantile", translate_val(op.arg, **kw), sg_literal(0.5, is_string=False) - ) - return _apply_agg_filter(expr, where=op.where, **kw) +def _approx_median(op, *, arg, where, **_): + return agg.approx_quantile(arg, lit(0.5), where=where) -@translate_val.register(ops.WindowFunction) -def _window(op: ops.WindowFunction, **kw: Any): - func = op.func - frame = op.frame +@translate_val.register(ops.WindowBoundary) +def _window_boundary(op, *, value, preceding, **_): + # TODO: bit of a hack to return a dict, but there's no sqlglot expression + # that corresponds to _only_ this information + return {"value": value, "side": "preceding" if preceding else "following"} - if isinstance(func, ops.CumulativeOp): - arg = cumulative_to_window(func, op.frame) - return translate_val(arg, **kw) - tr_val = partial(translate_val, **kw) - this = tr_val(func, **kw) +@translate_val.register(ops.WindowFrame) +def _window_frame(op, *, group_by, order_by, start=None, end=None, **_): + if start is None: + start = {} - if frame.start is None: - start = "UNBOUNDED" - else: - start = tr_val(frame.start.value, **kw) + start_value = start.get("value", "UNBOUNDED") + start_side = start.get("side", "PRECEDING") - if frame.end is None: - end = "UNBOUNDED" - else: - end = tr_val(frame.end.value, **kw) + if end is None: + end = {} + + end_value = end.get("value", "UNBOUNDED") + end_side = end.get("side", "FOLLOWING") spec = sg.exp.WindowSpec( - kind=frame.how.upper(), - start=start, - start_side="preceding", - end=end, - end_side="following", + kind=op.how.upper(), + start=start_value, + start_side=start_side, + end=end_value, + end_side=end_side, over="OVER", ) - partition_by = list(map(tr_val, frame.group_by)) or None - - order_bys = list(map(tr_val, frame.order_by)) + order = sg.exp.Order(expressions=order_by) if order_by else None - if isinstance(func, ops.Analytic) and not isinstance(func, ops.ShiftBase): - order_bys.extend(map(tr_val, func.args)) + # TODO: bit of a hack to return a partial, but similar to `WindowBoundary` + # there's no sqlglot expression that corresponds to _only_ this information + return partial(sg.exp.Window, partition_by=group_by, 
order=order, spec=spec) - order = sg.exp.Order(expressions=order_bys) if order_bys else None - window = sg.exp.Window(this=this, partition_by=partition_by, order=order, spec=spec) +@translate_val.register(ops.WindowFunction) +def _window(op: ops.WindowFunction, *, func, frame, **_: Any): + window = frame(this=func) # preserve zero-based indexing - if isinstance(func, ops.RankBase): + if isinstance(op.func, ops.RankBase): return window - 1 return window def shift_like(op_class, name): @translate_val.register(op_class) - def formatter(op, **kw): - arg = op.arg - offset = op.offset - default = op.default - - arg_fmt = translate_val(arg, **kw) - args = [arg_fmt] + def formatter(op, *, arg, offset, default, **_): + args = [arg] if default is not None: if offset is None: - offset_fmt = "1" - else: - offset_fmt = translate_val(offset, **kw) - - default_fmt = translate_val(default, **kw) + offset = lit(1) - args.append(offset_fmt) - args.append(default_fmt) + args.append(offset) + args.append(default) elif offset is not None: - offset_fmt = translate_val(offset, **kw) - args.append(offset_fmt) + args.append(offset) return sg.func(name, *args) @@ -1475,17 +1184,16 @@ def _argument(op, **_): @translate_val.register(ops.RowID) -def _rowid(op, *, aliases, **_) -> str: - table = op.table - return sg.column(op.name, aliases.get(table, table.name)) +def _rowid(op, *, table, **_) -> str: + return sg.column(op.name, table=table.alias_or_name) @translate_val.register(ops.ScalarUDF) -def _scalar_udf(op, **kw) -> str: - funcname = op.__class__.__name__ - return sg.func(funcname, *(translate_val(arg, **kw) for arg in op.args)) +def _scalar_udf(op, *, aliases, **kw) -> str: + funcname = op.__full_name__ + return f[funcname](*kw.values()) @translate_val.register(ops.AggUDF) -def _agg_udf(op, **kw) -> str: - return _aggregate(op, op.__class__.__name__, **kw) +def _agg_udf(op, *, aliases, where, **kw) -> str: + return agg[op.__class__.__name__](*kw.values(), where=where) From b68937451737c336834d4b11f54c609edfc7e4da Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 15:16:13 -0400 Subject: [PATCH 214/222] chore: get almost everything working --- ibis/backends/base/sqlglot/compiler/core.py | 4 +- ibis/backends/duckdb/compiler/relations.py | 4 +- ibis/backends/duckdb/compiler/values.py | 47 ++++++++++----------- 3 files changed, 28 insertions(+), 27 deletions(-) diff --git a/ibis/backends/base/sqlglot/compiler/core.py b/ibis/backends/base/sqlglot/compiler/core.py index e37f4f211f82..3afdf890959d 100644 --- a/ibis/backends/base/sqlglot/compiler/core.py +++ b/ibis/backends/base/sqlglot/compiler/core.py @@ -151,8 +151,8 @@ def fn(node, _, **kwargs): False, dtype="bool" ) - replace_notexists_subquery_with_not_exists = p.NotExistsSubquery(x) >> c.Not( - c.ExistsSubquery(x) + replace_notexists_subquery_with_not_exists = p.NotExistsSubquery(x, y) >> c.Not( + c.ExistsSubquery(x, y) ) op = op.replace( diff --git a/ibis/backends/duckdb/compiler/relations.py b/ibis/backends/duckdb/compiler/relations.py index 641cf1f3df4a..700d5a044e30 100644 --- a/ibis/backends/duckdb/compiler/relations.py +++ b/ibis/backends/duckdb/compiler/relations.py @@ -213,7 +213,9 @@ def _distinct(_: ops.Distinct, *, table, **kw): @translate_rel.register(ops.DropNa) def _dropna(op: ops.DropNa, *, table, how, subset, **_): if subset is None: - subset = [sg.column(name, table=table) for name in op.table.schema.names] + subset = [ + sg.column(name, table=table.alias_or_name) for name in 
op.table.schema.names + ] if subset: predicate = functools.reduce( diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 0a3af382dd5c..d970a105002a 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -74,14 +74,14 @@ def _literal(op, *, value, dtype, **kw): elif dtype.is_boolean(): return sg.exp.Boolean(this=value) elif dtype.is_string() or dtype.is_inet() or dtype.is_macaddr(): - return lit(value) + return lit(str(value)) elif dtype.is_numeric(): # cast non finite values to float because that's the behavior of # duckdb when a mixed decimal/float operation is performed # # float will be upcast to double if necessary by duckdb if not math.isfinite(value): - return cast(lit(value), to=dt.float32 if dtype.is_decimal() else dtype) + return cast(lit(str(value)), to=dt.float32 if dtype.is_decimal() else dtype) return cast(lit(value), dtype) elif dtype.is_time(): return cast(lit(value), dtype) @@ -97,11 +97,9 @@ def _literal(op, *, value, dtype, **kw): second += microsecond if (tz := dtype.timezone) is not None: timezone = lit(tz) - return sg.func( - "make_timestamptz", year, month, day, hour, minute, second, timezone - ) + return f.make_timestamptz(year, month, day, hour, minute, second, timezone) else: - return sg.func("make_timestamp", year, month, day, hour, minute, second) + return f.make_timestamp(year, month, day, hour, minute, second) elif dtype.is_date(): year = lit(value.year) month = lit(value.month) @@ -120,19 +118,19 @@ def _literal(op, *, value, dtype, **kw): elif dtype.is_map(): key_type = dtype.key_type value_type = dtype.value_type - keys = sg.exp.Array.from_arg_list( - [ + keys = f.array( + *( _literal(ops.Literal(k, dtype=key_type), value=k, dtype=key_type, **kw) for k in value.keys() - ] + ) ) - values = sg.exp.Array.from_arg_list( - [ + values = f.array( + *( _literal( ops.Literal(v, dtype=value_type), value=v, dtype=value_type, **kw ) for v in value.values() - ] + ) ) return sg.exp.Map(keys=keys, values=values) elif dtype.is_struct(): @@ -305,7 +303,7 @@ def _clip(op, *, arg, lower, upper, **_): arg = if_(arg.is_(NULL), arg, f.least(upper, arg)) if lower is not None: - arg = if_(arg.is_(NULL), arg, f.greatest(upper, arg)) + arg = if_(arg.is_(NULL), arg, f.greatest(lower, arg)) return arg @@ -701,8 +699,8 @@ def _array_sort(op, *, arg, **_): @translate_val.register(ops.ArrayIndex) def _array_index_op(op, *, arg, index, **_): - correct_idx = sg.func("if", index >= 0, index + 1, index) - return f.list_element(arg, correct_idx) + correct_idx = if_(index >= 0, index + 1, index) + return f.list_extract(arg, correct_idx) @translate_val.register(ops.InValues) @@ -726,7 +724,7 @@ def _array_concat(op, *, arg, **_): @translate_val.register(ops.ArrayRepeat) def _array_repeat_op(op, *, arg, times, **_): return f.flatten( - sg.select(sg.func("array", sg.select(arg).from_(f.range(times)))).subquery(), + sg.select(f.array(sg.select(arg).from_(f.range(times)))).subquery() ) @@ -739,7 +737,7 @@ def _neg_idx_to_pos(array, idx): # e.g. 
where the magnitude of the negative index is greater than the # length of the array # You cannot index a[:-3] if a = [1, 2] - true=arg_length + sg.func("greatest", idx, -1 * arg_length), + true=arg_length + f.greatest(idx, -1 * arg_length), false=idx, ) @@ -941,7 +939,7 @@ def _table_array_view(op, *, table, **_): @translate_val.register(ops.ExistsSubquery) def _exists_subquery(op, *, foreign_table, predicates, **_): - subq = sg.select(1).from_(foreign_table).where(sg.condition(predicates)).subquery() + subq = sg.select(1).from_(foreign_table).where(sg.and_(*predicates)).subquery() return f.exists(subq) @@ -962,7 +960,8 @@ def _struct_column(op, *, names, values, **_): @translate_val.register(ops.StructField) def _struct_field(op, *, arg, field, **_): - return sg.exp.StructExtract(this=arg, expression=lit(field)) + val = arg.this if isinstance(op.arg, ops.Alias) else arg + return val[lit(field)] @translate_val.register(ops.IdenticalTo) @@ -1047,22 +1046,22 @@ def _row_number(_, **kw): @translate_val.register(ops.DenseRank) def _dense_rank(_, **kw): - return sg.func("dense_rank") + return f.dense_rank() @translate_val.register(ops.MinRank) def _rank(_, **kw): - return sg.func("rank") + return f.rank() @translate_val.register(ops.PercentRank) def _percent_rank(_, **kw): - return sg.func("percent_rank") + return f.percent_rank() @translate_val.register(ops.CumeDist) def _cume_dist(_, **kw): - return sg.func("cume_dist") + return f.cume_dist() @translate_val.register @@ -1169,7 +1168,7 @@ def formatter(op, *, arg, offset, default, **_): elif offset is not None: args.append(offset) - return sg.func(name, *args) + return f[name](*args) return formatter From ed749687a8f5845a0d8db55a131bb5dbdd95b37f Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 15:18:26 -0400 Subject: [PATCH 215/222] chore: regen sql --- .../test_many_subqueries/duckdb/out.sql | 24 ++-- .../duckdb/out.sql | 12 +- .../test_sql/test_isin_bug/duckdb/out.sql | 4 +- .../test_union_aliasing/duckdb/out.sql | 36 ++--- .../test_h01/test_tpc_h01/duckdb/h01.sql | 34 ++--- .../test_h04/test_tpc_h04/duckdb/h04.sql | 32 ++--- .../test_h06/test_tpc_h06/duckdb/h06.sql | 16 +-- .../test_h12/test_tpc_h12/duckdb/h12.sql | 20 +-- .../test_h13/test_tpc_h13/duckdb/h13.sql | 7 +- .../test_h14/test_tpc_h14/duckdb/h14.sql | 40 ++---- .../test_h15/test_tpc_h15/duckdb/h15.sql | 92 +++++-------- .../test_h16/test_tpc_h16/duckdb/h16.sql | 42 +++--- .../test_h17/test_tpc_h17/duckdb/h17.sql | 40 ++---- .../test_h19/test_tpc_h19/duckdb/h19.sql | 128 ++++-------------- 14 files changed, 167 insertions(+), 360 deletions(-) diff --git a/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/duckdb/out.sql b/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/duckdb/out.sql index de6339220a50..0f73b931355a 100644 --- a/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/duckdb/out.sql +++ b/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/duckdb/out.sql @@ -4,7 +4,7 @@ SELECT FROM ( SELECT t4.street, - ROW_NUMBER() OVER (ORDER BY t4.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + ROW_NUMBER() OVER (ORDER BY t4.street ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS key FROM ( SELECT t1.street, @@ -12,7 +12,7 @@ FROM ( FROM ( SELECT t0.*, - ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS 
BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS key FROM "data" AS t0 ) AS t1 INNER JOIN ( @@ -21,14 +21,12 @@ FROM ( FROM ( SELECT t0.*, - ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS key FROM "data" AS t0 ) AS t1 ) AS t2 ON ( - t1.key - ) = ( - t2.key + t1.key = t2.key ) ) AS t4 ) AS t5 @@ -38,7 +36,7 @@ INNER JOIN ( FROM ( SELECT t4.street, - ROW_NUMBER() OVER (ORDER BY t4.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + ROW_NUMBER() OVER (ORDER BY t4.street ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS key FROM ( SELECT t1.street, @@ -46,7 +44,7 @@ INNER JOIN ( FROM ( SELECT t0.*, - ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS key FROM "data" AS t0 ) AS t1 INNER JOIN ( @@ -55,20 +53,16 @@ INNER JOIN ( FROM ( SELECT t0.*, - ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) - 1 AS key + ROW_NUMBER() OVER (ORDER BY t0.street ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS key FROM "data" AS t0 ) AS t1 ) AS t2 ON ( - t1.key - ) = ( - t2.key + t1.key = t2.key ) ) AS t4 ) AS t5 ) AS t6 ON ( - t5.key - ) = ( - t6.key + t5.key = t6.key ) \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/duckdb/out.sql b/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/duckdb/out.sql index dbde514e7b86..81f9864f6880 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/duckdb/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/duckdb/out.sql @@ -22,9 +22,7 @@ FROM ( ) AS t1 ) AS t2 ON ( - t1.key - ) = ( - t2.key + t1.key = t2.key ) ) AS t4 INNER JOIN ( @@ -49,13 +47,9 @@ INNER JOIN ( ) AS t1 ) AS t2 ON ( - t1.key - ) = ( - t2.key + t1.key = t2.key ) ) AS t5 ON ( - t4.key - ) = ( - t5.key + t4.key = t5.key ) \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/duckdb/out.sql b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/duckdb/out.sql index 3c8533744eab..1f22a63f400d 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/duckdb/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/duckdb/out.sql @@ -8,9 +8,7 @@ SELECT FROM "t" AS t0 WHERE ( - t0.x - ) > ( - CAST(2 AS TINYINT) + t0.x > CAST(2 AS TINYINT) ) ) AS t1 ) AS "InColumn(x, x)" diff --git a/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql b/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql index 62adad32f7e3..eadfa35c468c 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql @@ -13,22 +13,20 @@ FROM ( FIRST(t4.diff) AS diff FROM ( SELECT - *, + t3.*, ( - t3.latest_degrees - ) - ( - t3.earliest_degrees + t3.latest_degrees - t3.earliest_degrees ) AS diff FROM ( SELECT - *, - FIRST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) AS earliest_degrees, - LAST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN 
UNBOUNDED preceding AND UNBOUNDED following) AS latest_degrees + t2.*, + FIRST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS earliest_degrees, + LAST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS latest_degrees FROM ( SELECT t1.field_of_study, - STRUCT_EXTRACT(t1.__pivoted__, 'years') AS years, - STRUCT_EXTRACT(t1.__pivoted__, 'degrees') AS degrees + t1.__pivoted__['years'] AS years, + t1.__pivoted__['degrees'] AS degrees FROM ( SELECT t0.field_of_study, @@ -62,22 +60,20 @@ FROM ( FIRST(t4.diff) AS diff FROM ( SELECT - *, + t3.*, ( - t3.latest_degrees - ) - ( - t3.earliest_degrees + t3.latest_degrees - t3.earliest_degrees ) AS diff FROM ( SELECT - *, - FIRST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) AS earliest_degrees, - LAST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN UNBOUNDED preceding AND UNBOUNDED following) AS latest_degrees + t2.*, + FIRST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS earliest_degrees, + LAST(t2.degrees) OVER (PARTITION BY t2.field_of_study ORDER BY t2.years ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS latest_degrees FROM ( SELECT t1.field_of_study, - STRUCT_EXTRACT(t1.__pivoted__, 'years') AS years, - STRUCT_EXTRACT(t1.__pivoted__, 'degrees') AS degrees + t1.__pivoted__['years'] AS years, + t1.__pivoted__['degrees'] AS degrees FROM ( SELECT t0.field_of_study, @@ -94,9 +90,7 @@ FROM ( ) AS t5 WHERE ( - t5.diff - ) < ( - CAST(0 AS TINYINT) + t5.diff < CAST(0 AS TINYINT) ) ) AS t7 ORDER BY diff --git a/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/duckdb/h01.sql b/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/duckdb/h01.sql index 6bf4e97fbfac..7485b03ad808 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/duckdb/h01.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/duckdb/h01.sql @@ -6,33 +6,19 @@ FROM ( t0.l_linestatus, SUM(t0.l_quantity) AS sum_qty, SUM(t0.l_extendedprice) AS sum_base_price, - SUM( - ( - t0.l_extendedprice - ) * ( - ( - CAST(1 AS TINYINT) - ) - ( - t0.l_discount - ) + SUM(( + t0.l_extendedprice * ( + CAST(1 AS TINYINT) - t0.l_discount ) - ) AS sum_disc_price, + )) AS sum_disc_price, SUM( ( ( - t0.l_extendedprice - ) * ( - ( - CAST(1 AS TINYINT) - ) - ( - t0.l_discount + t0.l_extendedprice * ( + CAST(1 AS TINYINT) - t0.l_discount ) - ) - ) * ( - ( - t0.l_tax - ) + ( - CAST(1 AS TINYINT) + ) * ( + t0.l_tax + CAST(1 AS TINYINT) ) ) ) AS sum_charge, @@ -43,9 +29,7 @@ FROM ( FROM "lineitem" AS t0 WHERE ( - t0.l_shipdate - ) <= ( - MAKE_DATE(1998, 9, 2) + t0.l_shipdate <= MAKE_DATE(1998, 9, 2) ) GROUP BY 1, diff --git a/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/duckdb/h04.sql b/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/duckdb/h04.sql index f58d71f9e961..4c751c9af259 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/duckdb/h04.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/duckdb/h04.sql @@ -4,29 +4,23 @@ SELECT FROM "orders" AS t0 WHERE EXISTS( - SELECT - 1 - FROM "lineitem" - WHERE - ( - l_orderkey - ) = ( - t0.o_orderkey - ) AND ( - l_commitdate - ) < ( - l_receiptdate - ) + ( + SELECT + 1 + FROM "lineitem" AS t1 + WHERE + ( + 
t1.l_orderkey = t0.o_orderkey + ) AND ( + t1.l_commitdate < t1.l_receiptdate + ) + ) ) AND ( - t0.o_orderdate - ) >= ( - MAKE_DATE(1993, 7, 1) + t0.o_orderdate >= MAKE_DATE(1993, 7, 1) ) AND ( - t0.o_orderdate - ) < ( - MAKE_DATE(1993, 10, 1) + t0.o_orderdate < MAKE_DATE(1993, 10, 1) ) GROUP BY 1 diff --git a/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql b/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql index 8c780e28b15a..f93d82bdf649 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/duckdb/h06.sql @@ -1,24 +1,16 @@ SELECT SUM(( - t0.l_extendedprice - ) * ( - t0.l_discount + t0.l_extendedprice * t0.l_discount )) AS revenue FROM "lineitem" AS t0 WHERE ( - t0.l_shipdate - ) >= ( - MAKE_DATE(1994, 1, 1) + t0.l_shipdate >= MAKE_DATE(1994, 1, 1) ) AND ( - t0.l_shipdate - ) < ( - MAKE_DATE(1995, 1, 1) + t0.l_shipdate < MAKE_DATE(1995, 1, 1) ) AND t0.l_discount BETWEEN CAST(0.05 AS DOUBLE) AND CAST(0.07 AS DOUBLE) AND ( - t0.l_quantity - ) < ( - CAST(24 AS TINYINT) + t0.l_quantity < CAST(24 AS TINYINT) ) \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/duckdb/h12.sql b/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/duckdb/h12.sql index af51e5041edf..683d7dcacfe7 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/duckdb/h12.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/duckdb/h12.sql @@ -28,32 +28,22 @@ FROM ( FROM "orders" AS t0 INNER JOIN "lineitem" AS t1 ON ( - t0.o_orderkey - ) = ( - t1.l_orderkey + t0.o_orderkey = t1.l_orderkey ) ) AS t2 WHERE t2.l_shipmode IN ('MAIL', 'SHIP') AND ( - t2.l_commitdate - ) < ( - t2.l_receiptdate + t2.l_commitdate < t2.l_receiptdate ) AND ( - t2.l_shipdate - ) < ( - t2.l_commitdate + t2.l_shipdate < t2.l_commitdate ) AND ( - t2.l_receiptdate - ) >= ( - MAKE_DATE(1994, 1, 1) + t2.l_receiptdate >= MAKE_DATE(1994, 1, 1) ) AND ( - t2.l_receiptdate - ) < ( - MAKE_DATE(1995, 1, 1) + t2.l_receiptdate < MAKE_DATE(1995, 1, 1) ) GROUP BY 1 diff --git a/ibis/backends/tests/tpch/snapshots/test_h13/test_tpc_h13/duckdb/h13.sql b/ibis/backends/tests/tpch/snapshots/test_h13/test_tpc_h13/duckdb/h13.sql index b0275e931a23..3eaf29f2e69c 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h13/test_tpc_h13/duckdb/h13.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h13/test_tpc_h13/duckdb/h13.sql @@ -15,11 +15,8 @@ FROM ( FROM "customer" AS t0 LEFT JOIN "orders" AS t1 ON ( - t0.c_custkey - ) = ( - t1.o_custkey - ) - AND NOT t1.o_comment LIKE '%special%requests%' + t0.c_custkey = t1.o_custkey + ) AND NOT t1.o_comment LIKE '%special%requests%' ) AS t2 GROUP BY 1 diff --git a/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/duckdb/h14.sql b/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/duckdb/h14.sql index aba4eafae950..f6c370f5548b 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/duckdb/h14.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/duckdb/h14.sql @@ -5,32 +5,18 @@ SELECT CASE WHEN t2.p_type LIKE 'PROMO%' THEN ( - t2.l_extendedprice - ) * ( - ( - CAST(1 AS TINYINT) - ) - ( - t2.l_discount + t2.l_extendedprice * ( + CAST(1 AS TINYINT) - t2.l_discount ) ) ELSE CAST(0 AS TINYINT) END + ) * CAST(100 AS TINYINT) + ) / SUM(( + t2.l_extendedprice * ( + CAST(1 AS TINYINT) - t2.l_discount ) - ) * ( - CAST(100 AS TINYINT) - ) - ) / ( - SUM( - ( - t2.l_extendedprice - ) * ( - ( - CAST(1 
AS TINYINT) - ) - ( - t2.l_discount - ) - ) - ) + )) ) AS promo_revenue FROM ( SELECT @@ -39,19 +25,13 @@ FROM ( FROM "lineitem" AS t0 INNER JOIN "part" AS t1 ON ( - t0.l_partkey - ) = ( - t1.p_partkey + t0.l_partkey = t1.p_partkey ) ) AS t2 WHERE ( - t2.l_shipdate - ) >= ( - MAKE_DATE(1995, 9, 1) + t2.l_shipdate >= MAKE_DATE(1995, 9, 1) ) AND ( - t2.l_shipdate - ) < ( - MAKE_DATE(1995, 10, 1) + t2.l_shipdate < MAKE_DATE(1995, 10, 1) ) \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/duckdb/h15.sql b/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/duckdb/h15.sql index a04eb32850b1..4e5a8d725a2b 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/duckdb/h15.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/duckdb/h15.sql @@ -1,9 +1,9 @@ SELECT - t5.s_suppkey, - t5.s_name, - t5.s_address, - t5.s_phone, - t5.total_revenue + t6.s_suppkey, + t6.s_name, + t6.s_address, + t6.s_phone, + t6.total_revenue FROM ( SELECT * @@ -17,75 +17,53 @@ FROM ( INNER JOIN ( SELECT t1.l_suppkey, - SUM( - ( - t1.l_extendedprice - ) * ( - ( - CAST(1 AS TINYINT) - ) - ( - t1.l_discount - ) + SUM(( + t1.l_extendedprice * ( + CAST(1 AS TINYINT) - t1.l_discount ) - ) AS total_revenue + )) AS total_revenue FROM "lineitem" AS t1 WHERE ( - t1.l_shipdate - ) >= ( - MAKE_DATE(1996, 1, 1) + t1.l_shipdate >= MAKE_DATE(1996, 1, 1) ) AND ( - t1.l_shipdate - ) < ( - MAKE_DATE(1996, 4, 1) + t1.l_shipdate < MAKE_DATE(1996, 4, 1) ) GROUP BY 1 ) AS t2 ON ( - t0.s_suppkey - ) = ( - t2.l_suppkey + t0.s_suppkey = t2.l_suppkey ) ) AS t3 WHERE ( - t3.total_revenue - ) = ( - SELECT - MAX(t1.total_revenue) AS "Max(total_revenue)" - FROM ( + t3.total_revenue = ( SELECT - t0.l_suppkey, - SUM( - ( - t0.l_extendedprice - ) * ( - ( - CAST(1 AS TINYINT) - ) - ( - t0.l_discount + MAX(t2.total_revenue) AS "Max(total_revenue)" + FROM ( + SELECT + t1.l_suppkey, + SUM(( + t1.l_extendedprice * ( + CAST(1 AS TINYINT) - t1.l_discount ) + )) AS total_revenue + FROM "lineitem" AS t1 + WHERE + ( + t1.l_shipdate >= MAKE_DATE(1996, 1, 1) ) - ) AS total_revenue - FROM "lineitem" AS t0 - WHERE - ( - t0.l_shipdate - ) >= ( - MAKE_DATE(1996, 1, 1) - ) - AND ( - t0.l_shipdate - ) < ( - MAKE_DATE(1996, 4, 1) - ) - GROUP BY - 1 - ) AS t1 + AND ( + t1.l_shipdate < MAKE_DATE(1996, 4, 1) + ) + GROUP BY + 1 + ) AS t2 + ) ) - ) AS t4 + ) AS t5 ORDER BY - t4.s_suppkey ASC -) AS t5 \ No newline at end of file + t5.s_suppkey ASC +) AS t6 \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h16/test_tpc_h16/duckdb/h16.sql b/ibis/backends/tests/tpch/snapshots/test_h16/test_tpc_h16/duckdb/h16.sql index e8c3a55986ad..901e8269f406 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h16/test_tpc_h16/duckdb/h16.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h16/test_tpc_h16/duckdb/h16.sql @@ -2,10 +2,10 @@ SELECT * FROM ( SELECT - t2.p_brand, - t2.p_type, - t2.p_size, - COUNT(DISTINCT t2.ps_suppkey) AS supplier_cnt + t3.p_brand, + t3.p_type, + t3.p_size, + COUNT(DISTINCT t3.ps_suppkey) AS supplier_cnt FROM ( SELECT t0.*, @@ -13,37 +13,33 @@ FROM ( FROM "partsupp" AS t0 INNER JOIN "part" AS t1 ON ( - t1.p_partkey - ) = ( - t0.ps_partkey + t1.p_partkey = t0.ps_partkey ) - ) AS t2 + ) AS t3 WHERE ( - t2.p_brand - ) <> ( - 'Brand#45' + t3.p_brand <> 'Brand#45' ) - AND NOT t2.p_type LIKE 'MEDIUM POLISHED%' - AND t2.p_size IN (CAST(49 AS TINYINT), CAST(14 AS TINYINT), CAST(23 AS TINYINT), CAST(45 AS TINYINT), CAST(19 AS TINYINT), CAST(3 AS TINYINT), CAST(36 AS 
TINYINT), CAST(9 AS TINYINT)) - AND NOT t2.ps_suppkey IN ( + AND NOT t3.p_type LIKE 'MEDIUM POLISHED%' + AND t3.p_size IN (CAST(49 AS TINYINT), CAST(14 AS TINYINT), CAST(23 AS TINYINT), CAST(45 AS TINYINT), CAST(19 AS TINYINT), CAST(3 AS TINYINT), CAST(36 AS TINYINT), CAST(9 AS TINYINT)) + AND NOT t3.ps_suppkey IN ( SELECT - t1.s_suppkey + t4.s_suppkey FROM ( SELECT * - FROM "supplier" AS t0 + FROM "supplier" AS t2 WHERE - t0.s_comment LIKE '%Customer%Complaints%' - ) AS t1 + t2.s_comment LIKE '%Customer%Complaints%' + ) AS t4 ) GROUP BY 1, 2, 3 -) AS t3 +) AS t6 ORDER BY - t3.supplier_cnt DESC, - t3.p_brand ASC, - t3.p_type ASC, - t3.p_size ASC \ No newline at end of file + t6.supplier_cnt DESC, + t6.p_brand ASC, + t6.p_type ASC, + t6.p_size ASC \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql b/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql index 9017582a898a..6b8d88f46d87 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h17/test_tpc_h17/duckdb/h17.sql @@ -1,8 +1,6 @@ SELECT ( - SUM(t2.l_extendedprice) - ) / ( - CAST(7.0 AS DOUBLE) + SUM(t2.l_extendedprice) / CAST(7.0 AS DOUBLE) ) AS avg_yearly FROM ( SELECT @@ -11,36 +9,26 @@ FROM ( FROM "lineitem" AS t0 INNER JOIN "part" AS t1 ON ( - t1.p_partkey - ) = ( - t0.l_partkey + t1.p_partkey = t0.l_partkey ) ) AS t2 WHERE ( - t2.p_brand - ) = ( - 'Brand#23' + t2.p_brand = 'Brand#23' ) AND ( - t2.p_container - ) = ( - 'MED BOX' + t2.p_container = 'MED BOX' ) AND ( - t2.l_quantity - ) < ( - ( - SELECT - AVG(t0.l_quantity) AS "Mean(l_quantity)" - FROM "lineitem" AS t0 - WHERE - ( - t0.l_partkey - ) = ( - p_partkey - ) - ) * ( - CAST(0.2 AS DOUBLE) + t2.l_quantity < ( + ( + SELECT + AVG(t0.l_quantity) AS "Mean(l_quantity)" + FROM "lineitem" AS t0 + WHERE + ( + t0.l_partkey = t2.p_partkey + ) + ) * CAST(0.2 AS DOUBLE) ) ) \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h19/test_tpc_h19/duckdb/h19.sql b/ibis/backends/tests/tpch/snapshots/test_h19/test_tpc_h19/duckdb/h19.sql index ae042784a4bd..aee0ad664340 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h19/test_tpc_h19/duckdb/h19.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h19/test_tpc_h19/duckdb/h19.sql @@ -1,15 +1,9 @@ SELECT - SUM( - ( - t2.l_extendedprice - ) * ( - ( - CAST(1 AS TINYINT) - ) - ( - t2.l_discount - ) + SUM(( + t2.l_extendedprice * ( + CAST(1 AS TINYINT) - t2.l_discount ) - ) AS revenue + )) AS revenue FROM ( SELECT t0.*, @@ -17,9 +11,7 @@ FROM ( FROM "lineitem" AS t0 INNER JOIN "part" AS t1 ON ( - t1.p_partkey - ) = ( - t0.l_partkey + t1.p_partkey = t0.l_partkey ) ) AS t2 WHERE @@ -32,143 +24,79 @@ WHERE ( ( ( - t2.p_brand - ) = ( - 'Brand#12' + t2.p_brand = 'Brand#12' ) + AND t2.p_container IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') ) AND ( - t2.p_container IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + t2.l_quantity >= CAST(1 AS TINYINT) ) ) AND ( - ( - t2.l_quantity - ) >= ( - CAST(1 AS TINYINT) - ) + t2.l_quantity <= CAST(11 AS TINYINT) ) ) - AND ( - ( - t2.l_quantity - ) <= ( - CAST(11 AS TINYINT) - ) - ) - ) - AND ( - t2.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(5 AS TINYINT) + AND t2.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(5 AS TINYINT) ) + AND t2.l_shipmode IN ('AIR', 'AIR REG') ) AND ( - t2.l_shipmode IN ('AIR', 'AIR REG') + t2.l_shipinstruct = 'DELIVER IN PERSON' ) ) - AND ( - ( - t2.l_shipinstruct - ) = ( - 'DELIVER IN PERSON' - ) - ) - ) - OR ( - ( 
+ OR ( ( ( ( ( ( ( - t2.p_brand - ) = ( - 'Brand#23' + t2.p_brand = 'Brand#23' ) + AND t2.p_container IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') ) AND ( - t2.p_container IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + t2.l_quantity >= CAST(10 AS TINYINT) ) ) AND ( - ( - t2.l_quantity - ) >= ( - CAST(10 AS TINYINT) - ) - ) - ) - AND ( - ( - t2.l_quantity - ) <= ( - CAST(20 AS TINYINT) + t2.l_quantity <= CAST(20 AS TINYINT) ) ) + AND t2.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(10 AS TINYINT) ) - AND ( - t2.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(10 AS TINYINT) - ) + AND t2.l_shipmode IN ('AIR', 'AIR REG') ) AND ( - t2.l_shipmode IN ('AIR', 'AIR REG') - ) - ) - AND ( - ( - t2.l_shipinstruct - ) = ( - 'DELIVER IN PERSON' + t2.l_shipinstruct = 'DELIVER IN PERSON' ) ) ) - ) - OR ( - ( + OR ( ( ( ( ( ( ( - t2.p_brand - ) = ( - 'Brand#34' + t2.p_brand = 'Brand#34' ) + AND t2.p_container IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') ) AND ( - t2.p_container IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + t2.l_quantity >= CAST(20 AS TINYINT) ) ) AND ( - ( - t2.l_quantity - ) >= ( - CAST(20 AS TINYINT) - ) - ) - ) - AND ( - ( - t2.l_quantity - ) <= ( - CAST(30 AS TINYINT) + t2.l_quantity <= CAST(30 AS TINYINT) ) ) + AND t2.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(15 AS TINYINT) ) - AND ( - t2.p_size BETWEEN CAST(1 AS TINYINT) AND CAST(15 AS TINYINT) - ) + AND t2.l_shipmode IN ('AIR', 'AIR REG') ) AND ( - t2.l_shipmode IN ('AIR', 'AIR REG') - ) - ) - AND ( - ( - t2.l_shipinstruct - ) = ( - 'DELIVER IN PERSON' + t2.l_shipinstruct = 'DELIVER IN PERSON' ) ) ) \ No newline at end of file From 8676b650f1b5d4d642c3992ff8484ddb82321288 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 15:44:14 -0400 Subject: [PATCH 216/222] chore: remove duplicate clickhouse xfail --- ibis/backends/tests/test_window.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index d8d7fadc1cf9..d3f0605ab59f 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -891,8 +891,7 @@ def gb_fn(df): @pytest.mark.notimpl( - ["clickhouse", "dask", "datafusion", "polars"], - raises=com.OperationNotDefinedError, + ["dask", "datafusion", "polars"], raises=com.OperationNotDefinedError ) @pytest.mark.notimpl(["pyspark"], raises=AnalysisException) @pytest.mark.notyet( From 9da49770f7b054a6a749137db153833fd39cbfea Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 15:45:01 -0400 Subject: [PATCH 217/222] chore: remove dead code and unnecessary `lit` calls --- ibis/backends/duckdb/compiler/values.py | 93 ++++++------------------- 1 file changed, 23 insertions(+), 70 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index d970a105002a..96a82dcabed3 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -3,6 +3,7 @@ import calendar import functools import math +import operator import string from functools import partial from typing import Any @@ -10,7 +11,6 @@ import sqlglot as sg import ibis.common.exceptions as com -import ibis.expr.analysis as an import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.base.sqlglot import NULL, STAR, AggGen, FuncGen, lit, make_cast @@ -66,11 +66,17 @@ def _alias(op, *, arg, name, **_): @translate_val.register(ops.Literal) def 
_literal(op, *, value, dtype, **kw):
-    if dtype.is_interval() and value is not None:
-        return _interval_format(op)
+    if value is None:
+        if dtype.nullable:
+            return NULL if dtype.is_null() else cast(NULL, dtype)
+        raise NotImplementedError(f"Unsupported NULL for non-nullable type: {dtype!r}")
+    elif dtype.is_interval():
+        if dtype.unit.short == "ns":
+            raise com.UnsupportedOperationError(
+                "DuckDB doesn't support nanosecond interval resolutions"
+            )
 
-    if value is None and dtype.nullable:
-        return NULL if dtype.is_null() else cast(NULL, dtype)
+        return sg.exp.Interval(this=lit(value), unit=dtype.resolution.upper())
     elif dtype.is_boolean():
         return sg.exp.Boolean(this=value)
     elif dtype.is_string() or dtype.is_inet() or dtype.is_macaddr():
@@ -430,7 +436,7 @@ def _timestamp_from_ymdhms(op, *, year, month, day, hours, minutes, seconds, **_
     func = "make_timestamp"
     if (timezone := op.dtype.timezone) is not None:
         func += "tz"
-        args.append(lit(timezone))
+        args.append(timezone)
 
     return f[func](*args)
 
@@ -479,12 +485,12 @@ def _extract_time(op, *, arg, **_):
 # so we have to finesse it a little bit
 @translate_val.register(ops.ExtractMicrosecond)
 def _extract_microsecond(op, *, arg, **_):
-    return sg.exp.Mod(this=f.extract("us", arg), expression=lit(1_000_000))
+    return f.mod(f.extract("us", arg), 1_000_000)
 
 
 @translate_val.register(ops.ExtractMillisecond)
 def _extract_millisecond(op, *, arg, **_):
-    return sg.exp.Mod(this=f.extract("ms", arg), expression=lit(1_000))
+    return f.mod(f.extract("ms", arg), 1_000)
 
 
 @translate_val.register(ops.DateTruncate)
@@ -533,9 +539,9 @@ def day_of_week_name(op, *, arg, **_):
 
 
 _interval_mapping = {
-    ops.IntervalAdd: sg.exp.Add,
-    ops.IntervalSubtract: sg.exp.Sub,
-    ops.IntervalMultiply: sg.exp.Mul,
+    ops.IntervalAdd: operator.add,
+    ops.IntervalSubtract: operator.sub,
+    ops.IntervalMultiply: operator.mul,
 }
 
 
@@ -543,18 +549,8 @@ def day_of_week_name(op, *, arg, **_):
 @translate_val.register(ops.IntervalSubtract)
 @translate_val.register(ops.IntervalMultiply)
 def _interval_binary(op, *, left, right, **_):
-    sg_expr = _interval_mapping[type(op)]
-    return sg_expr(this=left, expression=right)
-
-
-def _interval_format(op):
-    dtype = op.dtype
-    if dtype.unit.short == "ns":
-        raise com.UnsupportedOperationError(
-            "Duckdb doesn't support nanosecond interval resolutions"
-        )
-
-    return sg.exp.Interval(this=lit(op.value), unit=dtype.resolution.upper())
+    func = _interval_mapping[type(op)]
+    return func(left, right)
 
 
 @translate_val.register(ops.IntervalFromInteger)
@@ -983,7 +979,7 @@ def _map_get(op, *, arg, key, default, **_):
 
 @translate_val.register(ops.MapContains)
 def _map_contains(op, *, arg, key, **_):
-    return sg.exp.NEQ(this=f.array_length(f.element_at(arg, key)), expression=lit(0))
+    return f.len(f.element_at(arg, key)).neq(lit(0))
 
 
 def _binary_infix(sg_expr: sg.exp._Expression):
@@ -1026,17 +1022,7 @@ def formatter(op, *, left, right, **_):
 del _op, _sym
 
 
-# @translate_val.register(ops.Xor)
-# def _xor(op, **kw):
-#     # https://github.com/tobymao/sqlglot/issues/2238
-#     left = translate_val(op.left, **kw).sql("duckdb")
-#     right = translate_val(op.right, **kw).sql("duckdb")
-#     return sg.parse_one(
-#         f"({left} OR {right}) AND NOT ({left} AND {right})", read="duckdb"
-#     )
-
-
-### Ordering
+### Ordering and window functions
 
 
 @translate_val.register(ops.RowNumber)
@@ -1069,42 +1055,9 @@ def _sort_key(op: ops.SortKey, *, expr, ascending: bool, **_):
     return sg.exp.Ordered(this=expr, desc=not ascending)
 
 
-### Window functions
-
-_cumulative_to_reduction = {
-    ops.CumulativeSum: 
ops.Sum, - ops.CumulativeMin: ops.Min, - ops.CumulativeMax: ops.Max, - ops.CumulativeMean: ops.Mean, - ops.CumulativeAny: ops.Any, - ops.CumulativeAll: ops.All, -} - - -def cumulative_to_window(func, frame): - klass = _cumulative_to_reduction[type(func)] - new_op = klass(*func.args) - new_frame = frame.copy(start=None, end=0) - new_expr = an.windowize_function(new_op.to_expr(), frame=new_frame) - return new_expr.op() - - -# TODO -_map_interval_to_microseconds = { - "W": 604800000000, - "D": 86400000000, - "h": 3600000000, - "m": 60000000, - "s": 1000000, - "ms": 1000, - "us": 1, - "ns": 0.001, -} - - @translate_val.register(ops.ApproxMedian) def _approx_median(op, *, arg, where, **_): - return agg.approx_quantile(arg, lit(0.5), where=where) + return agg.approx_quantile(arg, 0.5, where=where) @translate_val.register(ops.WindowBoundary) @@ -1161,7 +1114,7 @@ def formatter(op, *, arg, offset, default, **_): if default is not None: if offset is None: - offset = lit(1) + offset = 1 args.append(offset) args.append(default) From c50f4d8b85a9315226953fa958e9d55f622ad8b9 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:36:51 -0400 Subject: [PATCH 218/222] chore: clean up kw --- ibis/backends/duckdb/compiler/values.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 96a82dcabed3..13b75aaafc1e 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -36,13 +36,8 @@ def translate_val(op, **_): raise com.OperationNotDefinedError(f"No translation rule for {type(op)}") -@translate_val.register(dt.DataType) -def _datatype(t, **_): - return DuckDBType.from_ibis(t) - - @translate_val.register(ops.PhysicalTable) -def _val_physical_table(op, *, aliases, **kw): +def _val_physical_table(op, *, aliases, **_): return f"{aliases.get(op, op.name)}.*" @@ -52,8 +47,8 @@ def _val_table_node(op, *, aliases, **_): @translate_val.register(ops.TableColumn) -def _column(op, *, aliases, **_): - return sg.column(op.name, table=aliases.get(op.table)) +def _column(op, *, table, name, **_): + return sg.column(name, table=table.alias_or_name) @translate_val.register(ops.Alias) @@ -342,7 +337,7 @@ def _round(op, *, arg, digits, **_): @translate_val.register(ops.Cast) -def _cast(op, *, arg, to, **kw): +def _cast(op, *, arg, to, **_): if to.is_interval(): return f[f"to_{_interval_suffixes[to.unit.short]}"]( sg.cast(arg, to=DuckDBType.from_ibis(dt.int32)) From 159fe978a97730f3dc006a45db920bccc81eecf9 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:43:17 -0400 Subject: [PATCH 219/222] chore: remove unalias function --- ibis/backends/base/sqlglot/__init__.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index 0ae090a8f960..cc6d016bfa81 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -7,7 +7,6 @@ if TYPE_CHECKING: import ibis.expr.datatypes as dt - import ibis.expr.operations as ops from ibis.backends.base.sqlglot.datatypes import SqlglotType @@ -68,14 +67,3 @@ def cast(arg: sg.exp.Expression, to: dt.DataType) -> sg.exp.Cast: return sg.cast(arg, to=converter.from_ibis(to)) return cast - - -def unalias(op: ops.Value) -> ops.Value: - """Unwrap `Alias` objects. 
- - Necessary when rendering `WHERE`, `GROUP BY` and `ORDER BY` and other - clauses. - """ - import ibis.expr.operations as ops - - return op.arg if isinstance(op, ops.Alias) else op From f4f40549da133128898d0a1638c2cc690bdbc23a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:43:57 -0400 Subject: [PATCH 220/222] chore: add pattern to transform empty percent rank et al --- ibis/backends/base/sqlglot/compiler/core.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ibis/backends/base/sqlglot/compiler/core.py b/ibis/backends/base/sqlglot/compiler/core.py index 3afdf890959d..9625782f1794 100644 --- a/ibis/backends/base/sqlglot/compiler/core.py +++ b/ibis/backends/base/sqlglot/compiler/core.py @@ -155,12 +155,21 @@ def fn(node, _, **kwargs): c.ExistsSubquery(x, y) ) + add_order_by_to_window_funcs = p.WindowFunction( + p.PercentRank(x) | p.RankBase(x) | p.CumeDist(x), + ( + p.WindowFrame(..., order_by=()) + >> (lambda op, ctx: op.copy(order_by=(ctx[x],))) + ), + ) + op = op.replace( replace_literals | replace_cumulative_ops | replace_in_column_with_table_array_view | replace_empty_in_values_with_false | replace_notexists_subquery_with_not_exists + | add_order_by_to_window_funcs ) # apply translate rules in topological order results = op.map(fn, filter=(ops.TableNode, ops.Value)) From fcfaad8d8389a15762cf4b38f8ee2a47e1509d55 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:44:07 -0400 Subject: [PATCH 221/222] chore(clickhouse): regen sql --- .../tests/snapshots/test_functions/test_timestamp_now/out.sql | 2 +- .../test_fine_grained_timestamp_literals/micros/out.sql | 2 +- .../test_fine_grained_timestamp_literals/micros_tz/out.sql | 2 +- .../test_fine_grained_timestamp_literals/millis/out.sql | 2 +- .../test_fine_grained_timestamp_literals/millis_tz/out.sql | 2 +- .../test_string_numeric_boolean_literals/false/out.sql | 2 +- .../test_string_numeric_boolean_literals/float/out.sql | 2 +- .../test_string_numeric_boolean_literals/int/out.sql | 2 +- .../test_string_numeric_boolean_literals/nested_quote/out.sql | 2 +- .../test_string_numeric_boolean_literals/nested_token/out.sql | 2 +- .../test_string_numeric_boolean_literals/simple/out.sql | 2 +- .../test_string_numeric_boolean_literals/true/out.sql | 2 +- .../test_literals/test_timestamp_literals/expr0/out.sql | 2 +- .../test_literals/test_timestamp_literals/expr1/out.sql | 2 +- .../test_literals/test_timestamp_literals/expr2/out.sql | 2 +- 15 files changed, 15 insertions(+), 15 deletions(-) diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_now/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_now/out.sql index 057c9d542a56..8a90bb0bb098 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_now/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_timestamp_now/out.sql @@ -1,2 +1,2 @@ SELECT - now() \ No newline at end of file + now() AS "TimestampNow()" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/micros/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/micros/out.sql index 74515d167df0..ce21e32a643d 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/micros/out.sql +++ 
b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/micros/out.sql @@ -1,2 +1,2 @@ SELECT - toDateTime64('2015-01-01T12:34:56.789321', 6) \ No newline at end of file + toDateTime64('2015-01-01T12:34:56.789321', 6) AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789321)" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/micros_tz/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/micros_tz/out.sql index f8731e4d6a08..dca34dda04fe 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/micros_tz/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/micros_tz/out.sql @@ -1,2 +1,2 @@ SELECT - toDateTime64('2015-01-01T12:34:56.789321', 6, 'UTC') \ No newline at end of file + toDateTime64('2015-01-01T12:34:56.789321', 6, 'UTC') AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789321, tzinfo=tzutc())" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/millis/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/millis/out.sql index eef2cad901fa..94e5de1704d9 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/millis/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/millis/out.sql @@ -1,2 +1,2 @@ SELECT - toDateTime64('2015-01-01T12:34:56.789000', 3) \ No newline at end of file + toDateTime64('2015-01-01T12:34:56.789000', 3) AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789000)" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/millis_tz/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/millis_tz/out.sql index 367b82beee78..948f20c31431 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/millis_tz/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_fine_grained_timestamp_literals/millis_tz/out.sql @@ -1,2 +1,2 @@ SELECT - toDateTime64('2015-01-01T12:34:56.789000', 3, 'UTC') \ No newline at end of file + toDateTime64('2015-01-01T12:34:56.789000', 3, 'UTC') AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789000, tzinfo=tzutc())" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/false/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/false/out.sql index 52db1a268bf3..ce5812b68ce6 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/false/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/false/out.sql @@ -1,2 +1,2 @@ SELECT - False \ No newline at end of file + False AS False \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/float/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/float/out.sql index 0d0f28dfb907..756ffaa226f7 100644 --- 
a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/float/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/float/out.sql @@ -1,2 +1,2 @@ SELECT - 1.5 \ No newline at end of file + 1.5 AS "1.5" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/int/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/int/out.sql index 9500182be91d..5582adbae15b 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/int/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/int/out.sql @@ -1,2 +1,2 @@ SELECT - 5 \ No newline at end of file + 5 AS "5" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/nested_quote/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/nested_quote/out.sql index e99b43c131e8..51ffc764c325 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/nested_quote/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/nested_quote/out.sql @@ -1,2 +1,2 @@ SELECT - 'I can''t' \ No newline at end of file + 'I can''t' AS """I can't""" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/nested_token/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/nested_token/out.sql index 233fd6ddcba5..7753a461037d 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/nested_token/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/nested_token/out.sql @@ -1,2 +1,2 @@ SELECT - 'An "escape"' \ No newline at end of file + 'An "escape"' AS "'An ""escape""'" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/simple/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/simple/out.sql index 08c1427b8a79..6acbee281ec8 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/simple/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/simple/out.sql @@ -1,2 +1,2 @@ SELECT - 'simple' \ No newline at end of file + 'simple' AS "'simple'" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/true/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/true/out.sql index 55e1da583d09..79f20b6933bf 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/true/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_string_numeric_boolean_literals/true/out.sql @@ -1,2 +1,2 @@ SELECT - True \ No newline at end of file + True AS True \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr0/out.sql 
b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr0/out.sql index 290d3ec1a3ec..2ba743636e07 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr0/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr0/out.sql @@ -1,2 +1,2 @@ SELECT - toDateTime('2015-01-01T12:34:56') \ No newline at end of file + toDateTime('2015-01-01T12:34:56') AS "datetime.datetime(2015, 1, 1, 12, 34, 56)" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr1/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr1/out.sql index 290d3ec1a3ec..2ba743636e07 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr1/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr1/out.sql @@ -1,2 +1,2 @@ SELECT - toDateTime('2015-01-01T12:34:56') \ No newline at end of file + toDateTime('2015-01-01T12:34:56') AS "datetime.datetime(2015, 1, 1, 12, 34, 56)" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr2/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr2/out.sql index 290d3ec1a3ec..2ba743636e07 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr2/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_literals/test_timestamp_literals/expr2/out.sql @@ -1,2 +1,2 @@ SELECT - toDateTime('2015-01-01T12:34:56') \ No newline at end of file + toDateTime('2015-01-01T12:34:56') AS "datetime.datetime(2015, 1, 1, 12, 34, 56)" \ No newline at end of file From dbf34e5a0fe3958d1b1b8dfdd9dd2c1f886cf89c Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:51:36 -0400 Subject: [PATCH 222/222] chore: more lit and array cleanup --- ibis/backends/base/sqlglot/__init__.py | 2 +- ibis/backends/duckdb/compiler/values.py | 43 ++++++++++++------------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index cc6d016bfa81..334124f7286a 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -64,6 +64,6 @@ def make_cast( converter: SqlglotType, ) -> Callable[[sg.exp.Expression, dt.DataType], sg.exp.Cast]: def cast(arg: sg.exp.Expression, to: dt.DataType) -> sg.exp.Cast: - return sg.cast(arg, to=converter.from_ibis(to)) + return sg.cast(_to_sqlglot(arg), to=converter.from_ibis(to)) return cast diff --git a/ibis/backends/duckdb/compiler/values.py b/ibis/backends/duckdb/compiler/values.py index 13b75aaafc1e..572ef298b5c3 100644 --- a/ibis/backends/duckdb/compiler/values.py +++ b/ibis/backends/duckdb/compiler/values.py @@ -82,39 +82,36 @@ def _literal(op, *, value, dtype, **kw): # # float will be upcast to double if necessary by duckdb if not math.isfinite(value): - return cast(lit(str(value)), to=dt.float32 if dtype.is_decimal() else dtype) - return cast(lit(value), dtype) + return cast(str(value), to=dt.float32 if dtype.is_decimal() else dtype) + return cast(value, dtype) elif dtype.is_time(): - return cast(lit(value), dtype) + return cast(value, dtype) elif dtype.is_timestamp(): - year = lit(value.year) - month = lit(value.month) - day = lit(value.day) - hour = lit(value.hour) - minute = 
lit(value.minute) - second = lit(value.second) + year = value.year + month = value.month + day = value.day + hour = value.hour + minute = value.minute + second = value.second if us := value.microsecond: - microsecond = lit(us / 1e6) - second += microsecond + second += us / 1e6 if (tz := dtype.timezone) is not None: - timezone = lit(tz) - return f.make_timestamptz(year, month, day, hour, minute, second, timezone) + return f.make_timestamptz(year, month, day, hour, minute, second, tz) else: return f.make_timestamp(year, month, day, hour, minute, second) elif dtype.is_date(): - year = lit(value.year) - month = lit(value.month) - day = lit(value.day) - return sg.exp.DateFromParts(year=year, month=month, day=day) + return sg.exp.DateFromParts( + year=lit(value.year), month=lit(value.month), day=lit(value.day) + ) elif dtype.is_array(): value_type = dtype.value_type - return sg.exp.Array.from_arg_list( - [ + return f.array( + *( _literal( ops.Literal(v, dtype=value_type), value=v, dtype=value_type, **kw ) for v in value - ] + ) ) elif dtype.is_map(): key_type = dtype.key_type @@ -146,9 +143,9 @@ def _literal(op, *, value, dtype, **kw): [sg.exp.Slice(this=k, expression=v) for k, v in zip(keys, values)] ) elif dtype.is_uuid(): - return cast(lit(str(value)), dtype) + return cast(str(value), dtype) elif dtype.is_binary(): - return cast(lit("".join(map("\\x{:02x}".format, value))), dtype) + return cast("".join(map("\\x{:02x}".format, value)), dtype) else: raise NotImplementedError(f"Unsupported type: {dtype!r}") @@ -974,7 +971,7 @@ def _map_get(op, *, arg, key, default, **_): @translate_val.register(ops.MapContains) def _map_contains(op, *, arg, key, **_): - return f.len(f.element_at(arg, key)).neq(lit(0)) + return f.len(f.element_at(arg, key)).neq(0) def _binary_infix(sg_expr: sg.exp._Expression):
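Note on the literal cleanup in PATCH 222 above: the explicit lit() wrappers can come out because the cast helper now coerces its argument once at the boundary (the _to_sqlglot call added to make_cast; its definition is outside this hunk). A minimal standalone sketch of the same idea, using sqlglot's public exp.convert API rather than the patch's helper; cast_value below is an illustrative name, not part of the patch:

import sqlglot.expressions as exp

def cast_value(value, to: str) -> exp.Cast:
    # exp.convert turns Python scalars (str, int, float, bool, None) into
    # sqlglot literal expressions; an Expression argument would need no
    # wrapping at all, which is why the call sites above can pass raw values.
    return exp.cast(exp.convert(value), to=to)

print(cast_value(0.5, "DOUBLE").sql(dialect="duckdb"))   # CAST(0.5 AS DOUBLE)
print(cast_value("inf", "FLOAT").sql(dialect="duckdb"))  # CAST('inf' AS FLOAT)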
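Similarly, for the rewrite added in PATCH 220: the pattern gives ranking window functions with an empty ordering a default ORDER BY key, namely the function's own argument. A plain-Python sketch of that transformation, using illustrative dataclasses rather than ibis's pattern DSL:

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class WindowFrame:
    order_by: tuple = ()

@dataclass(frozen=True)
class PercentRank:
    arg: str
    frame: WindowFrame

def add_default_order_by(node: PercentRank) -> PercentRank:
    # Mirrors p.WindowFrame(..., order_by=()) >> op.copy(order_by=(ctx[x],)):
    # only frames with no explicit ordering are rewritten, and the ordering
    # key defaults to the ranking function's own argument.
    if not node.frame.order_by:
        return replace(node, frame=WindowFrame(order_by=(node.arg,)))
    return node

expr = PercentRank(arg="col", frame=WindowFrame())
print(add_default_order_by(expr).frame.order_by)  # ('col',)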