From 50f0193dc7ee1e0c8f1cc9b6c0331769042c6e2b Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 23 Jan 2024 06:54:29 -0500 Subject: [PATCH] chore: remove sqlalchemy from the codebase --- .envrc | 1 - .github/workflows/ibis-backends.yml | 3 - .gitignore | 1 + README.md | 8 +- ci/make_geography_db.py | 87 +- ci/schema/snowflake.sql | 30 +- docs/backends/app/backend_info_app.py | 30 +- docs/how-to/extending/elementwise.qmd | 141 --- docs/how-to/extending/reduction.qmd | 165 --- gen_redirects.py | 14 +- ibis/backends/base/__init__.py | 30 +- ibis/backends/base/sql/__init__.py | 42 +- ibis/backends/base/sql/alchemy/__init__.py | 1043 ----------------- ibis/backends/base/sql/alchemy/datatypes.py | 194 --- .../base/sql/alchemy/query_builder.py | 465 -------- ibis/backends/base/sql/alchemy/registry.py | 813 ------------- ibis/backends/base/sql/alchemy/translator.py | 147 --- ibis/backends/base/sqlglot/__init__.py | 110 +- ibis/backends/base/sqlglot/datatypes.py | 2 +- ibis/backends/bigquery/__init__.py | 41 +- ibis/backends/clickhouse/__init__.py | 26 +- ibis/backends/conftest.py | 128 +- ibis/backends/dask/executor.py | 11 +- ibis/backends/datafusion/__init__.py | 119 +- ibis/backends/datafusion/compiler.py | 2 +- ibis/backends/druid/__init__.py | 3 + ibis/backends/druid/tests/conftest.py | 4 +- ibis/backends/duckdb/__init__.py | 92 +- ibis/backends/duckdb/tests/conftest.py | 7 +- ibis/backends/duckdb/tests/test_client.py | 5 +- ibis/backends/exasol/__init__.py | 2 +- ibis/backends/flink/translator.py | 4 +- ibis/backends/impala/__init__.py | 6 +- ibis/backends/impala/client.py | 13 +- ibis/backends/impala/tests/test_partition.py | 5 +- ibis/backends/mssql/compiler.py | 10 +- ibis/backends/mssql/tests/conftest.py | 16 +- ibis/backends/mysql/tests/conftest.py | 16 +- ibis/backends/oracle/tests/conftest.py | 4 +- ibis/backends/polars/__init__.py | 17 +- ibis/backends/postgres/__init__.py | 15 - 
ibis/backends/postgres/tests/conftest.py | 17 +- ibis/backends/pyspark/__init__.py | 45 - ibis/backends/risingwave/__init__.py | 2 - ibis/backends/snowflake/__init__.py | 46 - ibis/backends/snowflake/tests/conftest.py | 6 +- ibis/backends/sqlite/__init__.py | 49 +- ibis/backends/sqlite/tests/test_client.py | 8 +- ibis/backends/tests/base.py | 43 +- ibis/backends/tests/errors.py | 17 +- ibis/backends/tests/test_aggregation.py | 18 +- ibis/backends/tests/test_api.py | 16 +- ibis/backends/tests/test_array.py | 7 +- ibis/backends/tests/test_benchmarks.py | 8 +- ibis/backends/tests/test_client.py | 374 +++--- ibis/backends/tests/test_numeric.py | 89 +- ibis/backends/tests/test_register.py | 6 +- ibis/backends/tests/test_temporal.py | 20 +- ibis/backends/tests/test_udf.py | 4 +- ibis/backends/tests/tpch/test_h08.py | 6 - ibis/backends/tests/tpch/test_h14.py | 6 - ibis/backends/tests/tpch/test_h17.py | 6 - ibis/backends/trino/__init__.py | 10 +- ibis/backends/trino/tests/conftest.py | 12 +- ibis/expr/sql.py | 4 +- ibis/formats/__init__.py | 3 +- poetry.lock | 9 +- pyproject.toml | 17 +- requirements-dev.txt | 2 +- 69 files changed, 713 insertions(+), 4009 deletions(-) delete mode 100644 docs/how-to/extending/elementwise.qmd delete mode 100644 docs/how-to/extending/reduction.qmd delete mode 100644 ibis/backends/base/sql/alchemy/__init__.py delete mode 100644 ibis/backends/base/sql/alchemy/datatypes.py delete mode 100644 ibis/backends/base/sql/alchemy/query_builder.py delete mode 100644 ibis/backends/base/sql/alchemy/registry.py delete mode 100644 ibis/backends/base/sql/alchemy/translator.py diff --git a/.envrc b/.envrc index cb9f6de818107..6a558e6197e78 100644 --- a/.envrc +++ b/.envrc @@ -7,4 +7,3 @@ watch_file poetry-overrides.nix export CLOUDSDK_ACTIVE_CONFIG_NAME=ibis-gbq export GOOGLE_CLOUD_PROJECT="$CLOUDSDK_ACTIVE_CONFIG_NAME" -export SQLALCHEMY_WARN_20=1 diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml index 
3c75efccf953a..22fdd4381913a 100644 --- a/.github/workflows/ibis-backends.yml +++ b/.github/workflows/ibis-backends.yml @@ -39,7 +39,6 @@ concurrency: env: FORCE_COLOR: "1" ODBCSYSINI: "${{ github.workspace }}/ci/odbc" - SQLALCHEMY_WARN_20: "1" HYPOTHESIS_PROFILE: "ci" jobs: @@ -426,8 +425,6 @@ jobs: test_backends_min_version: name: ${{ matrix.backend.title }} Min Version ${{ matrix.os }} python-${{ matrix.python-version }} runs-on: ${{ matrix.os }} - env: - SQLALCHEMY_WARN_20: "1" strategy: fail-fast: false matrix: diff --git a/.gitignore b/.gitignore index 7b9144d93eedc..bb1ea1c2ed59f 100644 --- a/.gitignore +++ b/.gitignore @@ -85,6 +85,7 @@ result-* # tutorial data geography.db +geography.duckdb # build artifacts ci/udf/.ninja_deps diff --git a/README.md b/README.md index dbbc35fd22e3f..ed2f10bd7956e 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,7 @@ Download the SQLite database from the `ibis-tutorial-data` GCS (Google Cloud Storage) bucket, then connect to it using ibis. ```bash -curl -LsS -o geography.db 'https://storage.googleapis.com/ibis-tutorial-data/geography.db' +curl -LsSO 'https://storage.googleapis.com/ibis-tutorial-data/geography.duckdb' ``` Connect to the database and show the available tables @@ -130,7 +130,7 @@ Connect to the database and show the available tables >>> import ibis >>> from ibis import _ >>> ibis.options.interactive = True ->>> con = ibis.sqlite.connect("geography.db") +>>> con = ibis.duckdb.connect("geography.duckdb") >>> con.tables Tables ------ @@ -147,7 +147,7 @@ Choose the `countries` table and preview its first few rows ┏━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┓ ┃ iso_alpha2 ┃ iso_alpha3 ┃ iso_numeric ┃ fips ┃ name ┃ capital ┃ area_km2 ┃ population ┃ continent ┃ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━┩ -│ string │ string │ int32 │ string │ string 
│ string │ float64 │ int32 │ string │ +│ string │ string │ int64 │ string │ string │ string │ float64 │ int64 │ string │ ├────────────┼────────────┼─────────────┼────────┼──────────────────────┼──────────────────┼──────────┼────────────┼───────────┤ │ AD │ AND │ 20 │ AN │ Andorra │ Andorra la Vella │ 468.0 │ 84000 │ EU │ │ AE │ ARE │ 784 │ AE │ United Arab Emirates │ Abu Dhabi │ 82880.0 │ 4975593 │ AS │ @@ -170,7 +170,7 @@ Show the 5 least populous countries in Asia ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃ name ┃ population ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ -│ string │ int32 │ +│ string │ int64 │ ├────────────────────────────────┼────────────┤ │ Cocos [Keeling] Islands │ 628 │ │ British Indian Ocean Territory │ 4000 │ diff --git a/ci/make_geography_db.py b/ci/make_geography_db.py index e98551b143a80..966e5292a4d85 100755 --- a/ci/make_geography_db.py +++ b/ci/make_geography_db.py @@ -16,72 +16,56 @@ from __future__ import annotations import argparse -import datetime +import json import tempfile from pathlib import Path from typing import TYPE_CHECKING, Any import requests -import sqlalchemy as sa -import toolz + +import ibis if TYPE_CHECKING: from collections.abc import Mapping SCHEMAS = { - "countries": [ - ("iso_alpha2", sa.TEXT), - ("iso_alpha3", sa.TEXT), - ("iso_numeric", sa.INT), - ("fips", sa.TEXT), - ("name", sa.TEXT), - ("capital", sa.TEXT), - ("area_km2", sa.REAL), - ("population", sa.INT), - ("continent", sa.TEXT), - ], - "gdp": [ - ("country_code", sa.TEXT), - ("year", sa.INT), - ("value", sa.REAL), - ], - "independence": [ - ("country_code", sa.TEXT), - ("independence_date", sa.DATE), - ("independence_from", sa.TEXT), - ], -} - -POST_PARSE_FUNCTIONS = { - "independence": lambda row: toolz.assoc( - row, - "independence_date", - datetime.datetime.fromisoformat(row["independence_date"]).date(), - ) + "countries": { + "iso_alpha2": "string", + "iso_alpha3": "string", + "iso_numeric": "int", + "fips": "string", + "name": "string", + 
"capital": "string", + "area_km2": "float", + "population": "int", + "continent": "string", + }, + "gdp": { + "country_code": "string", + "year": "int", + "value": "float", + }, + "independence": { + "country_code": "string", + "independence_date": "date", + "independence_from": "string", + }, } def make_geography_db( - data: Mapping[str, Any], - con: sa.engine.Engine, + data: Mapping[str, Any], con: ibis.backends.duckdb.Backend ) -> None: - metadata = sa.MetaData(bind=con) - - with con.begin() as bind: + with tempfile.TemporaryDirectory() as d: for table_name, schema in SCHEMAS.items(): - table = sa.Table( - table_name, - metadata, - *(sa.Column(col_name, col_type) for col_name, col_type in schema), + ibis_schema = ibis.schema(schema) + cols = ibis_schema.names + path = Path(d, f"{table_name}.jsonl") + path.write_text( + "\n".join(json.dumps(dict(zip(cols, row))) for row in data[table_name]) ) - table_columns = table.c.keys() - post_parse = POST_PARSE_FUNCTIONS.get(table_name, toolz.identity) - - table.drop(bind=bind, checkfirst=True) - table.create(bind=bind) - bind.execute( - table.insert().values(), - [post_parse(dict(zip(table_columns, row))) for row in data[table_name]], + con.create_table( + table_name, obj=con.read_json(path), schema=ibis_schema, overwrite=True ) @@ -109,9 +93,8 @@ def main() -> None: response = requests.get(args.input_data_url) response.raise_for_status() input_data = response.json() - db_path = Path(args.output_directory).joinpath("geography.db") - con = sa.create_engine(f"sqlite:///{db_path}") - make_geography_db(input_data, con) + db_path = Path(args.output_directory).joinpath("geography.duckdb") + make_geography_db(input_data, ibis.duckdb.connect(db_path)) print(db_path) # noqa: T201 diff --git a/ci/schema/snowflake.sql b/ci/schema/snowflake.sql index d0fd21ca7e871..97a4fe9950b27 100644 --- a/ci/schema/snowflake.sql +++ b/ci/schema/snowflake.sql @@ -1,4 +1,4 @@ -CREATE OR REPLACE TABLE diamonds ( +CREATE OR REPLACE TABLE "diamonds" ( 
"carat" FLOAT, "cut" TEXT, "color" TEXT, @@ -11,7 +11,7 @@ CREATE OR REPLACE TABLE diamonds ( "z" FLOAT ); -CREATE OR REPLACE TABLE astronauts ( +CREATE OR REPLACE TABLE "astronauts" ( "id" BIGINT, "number" BIGINT, "nationwide_number" BIGINT, @@ -38,7 +38,7 @@ CREATE OR REPLACE TABLE astronauts ( "total_eva_hrs" FLOAT ); -CREATE OR REPLACE TABLE batting ( +CREATE OR REPLACE TABLE "batting" ( "playerID" TEXT, "yearID" BIGINT, "stint" BIGINT, @@ -63,7 +63,7 @@ CREATE OR REPLACE TABLE batting ( "GIDP" BIGINT ); -CREATE OR REPLACE TABLE awards_players ( +CREATE OR REPLACE TABLE "awards_players" ( "playerID" TEXT, "awardID" TEXT, "yearID" BIGINT, @@ -72,7 +72,7 @@ CREATE OR REPLACE TABLE awards_players ( "notes" TEXT ); -CREATE OR REPLACE TABLE functional_alltypes ( +CREATE OR REPLACE TABLE "functional_alltypes" ( "id" INTEGER, "bool_col" BOOLEAN, "tinyint_col" SMALLINT, @@ -88,7 +88,7 @@ CREATE OR REPLACE TABLE functional_alltypes ( "month" INTEGER ); -CREATE OR REPLACE TABLE array_types ( +CREATE OR REPLACE TABLE "array_types" ( "x" ARRAY, "y" ARRAY, "z" ARRAY, @@ -97,7 +97,7 @@ CREATE OR REPLACE TABLE array_types ( "multi_dim" ARRAY ); -INSERT INTO array_types ("x", "y", "z", "grouper", "scalar_column", "multi_dim") +INSERT INTO "array_types" ("x", "y", "z", "grouper", "scalar_column", "multi_dim") SELECT [1, 2, 3], ['a', 'b', 'c'], [1.0, 2.0, 3.0], 'a', 1.0, [[], [1, 2, 3], NULL] UNION SELECT [4, 5], ['d', 'e'], [4.0, 5.0], 'a', 2.0, [] UNION SELECT [6, NULL], ['f', NULL], [6.0, NULL], 'a', 3.0, [NULL, [], NULL] UNION @@ -105,16 +105,16 @@ INSERT INTO array_types ("x", "y", "z", "grouper", "scalar_column", "multi_dim") SELECT [2, NULL, 3], ['b', NULL, 'c'], NULL, 'b', 5.0, NULL UNION SELECT [4, NULL, NULL, 5], ['d', NULL, NULL, 'e'], [4.0, NULL, NULL, 5.0], 'c', 6.0, [[1, 2, 3]]; -CREATE OR REPLACE TABLE map ("idx" BIGINT, "kv" OBJECT); +CREATE OR REPLACE TABLE "map" ("idx" BIGINT, "kv" OBJECT); -INSERT INTO map ("idx", "kv") +INSERT INTO "map" ("idx", "kv") SELECT 
1, object_construct('a', 1, 'b', 2, 'c', 3) UNION SELECT 2, object_construct('d', 4, 'e', 5, 'f', 6); -CREATE OR REPLACE TABLE struct ("abc" OBJECT); +CREATE OR REPLACE TABLE "struct" ("abc" OBJECT); -INSERT INTO struct ("abc") +INSERT INTO "struct" ("abc") SELECT {'a': 1.0, 'b': 'banana', 'c': 2} UNION SELECT {'a': 2.0, 'b': 'apple', 'c': 3} UNION SELECT {'a': 3.0, 'b': 'orange', 'c': 4} UNION @@ -123,9 +123,9 @@ INSERT INTO struct ("abc") SELECT NULL UNION SELECT {'a': 3.0, 'b': 'orange', 'c': NULL}; -CREATE OR REPLACE TABLE json_t ("js" VARIANT); +CREATE OR REPLACE TABLE "json_t" ("js" VARIANT); -INSERT INTO json_t ("js") +INSERT INTO "json_t" ("js") SELECT parse_json('{"a": [1,2,3,4], "b": 1}') UNION SELECT parse_json('{"a":null,"b":2}') UNION SELECT parse_json('{"a":"foo", "c":null}') UNION @@ -133,8 +133,8 @@ INSERT INTO json_t ("js") SELECT parse_json('[42,47,55]') UNION SELECT parse_json('[]'); -CREATE OR REPLACE TABLE win ("g" TEXT, "x" BIGINT NOT NULL, "y" BIGINT); -INSERT INTO win VALUES +CREATE OR REPLACE TABLE "win" ("g" TEXT, "x" BIGINT NOT NULL, "y" BIGINT); +INSERT INTO "win" VALUES ('a', 0, 3), ('a', 1, 2), ('a', 2, 0), diff --git a/docs/backends/app/backend_info_app.py b/docs/backends/app/backend_info_app.py index f4edd7ec6cdbc..1eca15dea0aae 100644 --- a/docs/backends/app/backend_info_app.py +++ b/docs/backends/app/backend_info_app.py @@ -45,25 +45,25 @@ def support_matrix_df(): def backends_info_df(): return pd.DataFrame( { - "bigquery": ["string", "sql"], - "clickhouse": ["string", "sql"], + "bigquery": ["sql"], + "clickhouse": ["sql"], "dask": ["dataframe"], "datafusion": ["sql"], - "druid": ["sqlalchemy", "sql"], - "duckdb": ["sqlalchemy", "sql"], - "exasol": ["sqlalchemy", "sql"], - "flink": ["string", "sql"], - "impala": ["string", "sql"], - "mssql": ["sqlalchemy", "sql"], - "mysql": ["sqlalchemy", "sql"], - "oracle": ["sqlalchemy", "sql"], + "druid": ["sql"], + "duckdb": ["sql"], + "exasol": ["sql"], + "flink": ["sql"], + "impala": 
["sql"], + "mssql": ["sql"], + "mysql": ["sql"], + "oracle": ["sql"], "pandas": ["dataframe"], "polars": ["dataframe"], - "postgres": ["sqlalchemy", "sql"], - "pyspark": ["dataframe"], - "snowflake": ["sqlalchemy", "sql"], - "sqlite": ["sqlalchemy", "sql"], - "trino": ["sqlalchemy", "sql"], + "postgres": ["sql"], + "pyspark": ["sql"], + "snowflake": ["sql"], + "sqlite": ["sql"], + "trino": ["sql"], }.items(), columns=["backend_name", "categories"], ) diff --git a/docs/how-to/extending/elementwise.qmd b/docs/how-to/extending/elementwise.qmd deleted file mode 100644 index 438a035a312a5..0000000000000 --- a/docs/how-to/extending/elementwise.qmd +++ /dev/null @@ -1,141 +0,0 @@ -# Add an elementwise operation - -This notebook will show you how to add a new elementwise operation to an existing backend. - -We are going to add `julianday`, a function supported by the SQLite database, to -the SQLite Ibis backend. - -The Julian day of a date, is the number of days since January 1st, 4713 BC. For -more information check the [Julian -day](https://en.wikipedia.org/wiki/Julian_day) Wikipedia page. - -## Step 1: Define the Operation - -Let's define the `julianday` operation as a function that takes one string input -argument and returns a float. - -```python -def julianday(date: str) -> float: - """Return the Julian day from a date.""" -``` - - -```{python} -import ibis.expr.datatypes as dt -import ibis.expr.rules as rlz -import ibis.expr.datashape as ds - -from ibis.expr.operations import Value - - -class JulianDay(Value): - arg: Value[dt.String, ds.Any] - - dtype = dt.float32 - shape = rlz.shape_like('arg') -``` - -We just defined a `JulianDay` class that takes one argument of type string or -binary, and returns a float. - -## Step 2: Define the API - -Because we know the output type of the operation, to make an expression out of -``JulianDay`` we can construct it and call its `ibis.expr.types.Node.to_expr` -method. 
- -We still need to add a method to `StringValue` (this needs to work on both -scalars and columns). - -When you add a method to any of the expression classes whose name matches -`*Value` both the scalar and column child classes will pick it up, making it -easy to define operations for both scalars and columns in one place. - -We can do this by defining a function and assigning it to the appropriate class -of expressions. - -```{python} -from ibis.expr.types import StringValue - - -def julianday(string_value): - return JulianDay(string_value).to_expr() - - -StringValue.julianday = julianday -``` - -## Interlude: Create some expressions with `julianday` - - -```{python} -import ibis - -t = ibis.table(dict(string_col="string"), name="t") - -t.string_col.julianday() -``` - -## Step 3: Turn the Expression into SQL - - -```{python} -import sqlalchemy as sa - - -@ibis.sqlite.add_operation(JulianDay) -def _julianday(translator, expr): - # pull out the arguments to the expression - (arg,) = expr.args - - # compile the argument - compiled_arg = translator.translate(arg) - - # return a SQLAlchemy expression that calls into the SQLite julianday function - return sa.func.julianday(compiled_arg) -``` - -## Step 4: Putting it all Together - -Download the geography database. - -```{python} -!curl -LsS -o geography.db 'https://storage.googleapis.com/ibis-tutorial-data/geography.db' - -con = ibis.sqlite.connect("geography.db") -``` - -### Create and execute a `julianday` expression - - -```{python} -ind = con.table("independence") -ind -``` - - -```{python} -day = ind.independence_date.cast("string") -day -``` - - -```{python} -jday_expr = day.julianday().name("jday") -jday_expr -``` - - -```{python} -ibis.to_sql(jday_expr) -``` - -Because we've defined our operation on `StringValue`, and not just on -`StringColumn` we get operations on both string scalars *and* string columns for -free. 
- - -```{python} -jday = ibis.literal("2010-03-14").julianday() -con.execute(jday) -``` diff --git a/docs/how-to/extending/reduction.qmd b/docs/how-to/extending/reduction.qmd deleted file mode 100644 index 9675230eb4511..0000000000000 --- a/docs/how-to/extending/reduction.qmd +++ /dev/null @@ -1,165 +0,0 @@ -# Add a reduction operation - -This notebook will show you how to add a new *reduction* operation `last_date` -to the existing backend SQLite. - -A reduction operation is a function that maps $N$ rows to 1 row, for example the -`sum` function. - -## Description - -We're going to add a **`last_date`** function to ibis. `last_date` returns the -latest date of a list of dates. - -## Step 1: Define the Operation - -Let's define the `last_date` operation as a function that takes any date column as input and returns a date: - -```python -from __future__ import annotations - -from datetime import date - - -def last_date(dates: list[date]) -> date: - """Latest date.""" -``` - - -```{python} -from __future__ import annotations - -import ibis.expr.datatypes as dt -import ibis.expr.datashape as ds -import ibis.expr.rules as rlz - -from ibis.expr.operations import Reduction, Value - - -class LastDate(Reduction): - arg: Value[dt.Date, ds.Any] - where: Value[dt.Boolean, ds.Any] | None = None - - dtype = rlz.dtype_like("arg") - shape = ds.scalar -``` - -We just defined a `LastDate` class that takes one date column as input, and -returns a scalar output of the same type as the input. This matches both the -requirements of a reduction and the specifics of the function that we want to -implement. - -**Note**: It is very important that you write the correct argument rules and -output type here. The expression *will not work* otherwise. - -## Step 2: Define the API - -Because every reduction in Ibis has the ability to filter out values during -aggregation, to make an expression out of `LastDate` we need to pass an -additional argument `where` to our `LastDate` constructor. 
- -Additionally, reductions should be defined on `Column` classes because -reductions are not always well-defined for a scalar value. - - -```{python} -from ibis.expr.types import DateColumn - - -def last_date(date_column, where=None): - return LastDate(date_column, where=where).to_expr() - - -DateColumn.last_date = last_date -``` - -## Interlude: Create some expressions using `last_date` - -```{python} -import ibis - - -people = ibis.table( - dict(name="string", country="string", date_of_birth="date"), - name="people", -) -``` - - -```{python} -people.date_of_birth.last_date() -``` - - -```{python} -people.date_of_birth.last_date(people.country == "Indonesia") -``` - -## Step 3: Turn the Expression into SQL - - -```{python} -import sqlalchemy as sa - - -@ibis.sqlite.add_operation(LastDate) -def _last_date(translator, expr): - # pull out the arguments to the expression - op = expr.op() - - arg = op.arg - where = op.where - - # compile the argument - compiled_arg = translator.translate(arg) - - # call the appropriate SQLite function (`max` for the latest date) - agg = sa.func.max(compiled_arg) - - # handle a non-None filter clause - if where is not None: - return agg.filter(translator.translate(where)) - return agg -``` - -## Step 4: Putting it all Together - -Download the geography database. 
- -```{python} -!curl -LsS -o geography.db 'https://storage.googleapis.com/ibis-tutorial-data/geography.db' - -con = ibis.sqlite.connect("geography.db") -``` - -### Create and execute a `bitwise_and` expression - - -```{python} -ind = con.table("independence") -ind -``` - -Last country to gain independence in our database: - - -```{python} -expr = ind.independence_date.last_date() -expr -``` - - -```{python} -ibis.to_sql(expr) -``` - -Show the last country to gain independence from the Spanish Empire, using the -`where` parameter: - - -```{python} -expr = ind.independence_date.last_date( - where=ind.independence_from == "Spanish Empire" -) -expr -``` diff --git a/gen_redirects.py b/gen_redirects.py index 01de7a2a48fb7..fa53031ba95ac 100644 --- a/gen_redirects.py +++ b/gen_redirects.py @@ -127,10 +127,14 @@ "/how_to/chain_expressions/": "/how-to/analytics/chain_expressions/", "/how_to/configuration/": "/how-to/configure/basics", "/how_to/duckdb_register/": "/backends/duckdb#ibis.backends.duckdb.Backend.register", - "/how_to/extending/elementwise/": "/how-to/extending/elementwise", - "/how_to/extending/elementwise/elementwise.ipynb": "/how-to/extending/elementwise", - "/how_to/extending/reduction/": "/how-to/extending/reduction", - "/how_to/extending/reduction/reduction.ipynb": "/how-to/extending/reduction", + "/how_to/extending/elementwise/": "/how-to/extending/builtin", + "/how_to/extending/elementwise/elementwise.ipynb": "/how-to/extending/builtin", + "/how_to/extending/reduction/": "/how-to/extending/builtin", + "/how_to/extending/reduction/reduction.ipynb": "/how-to/extending/builtin", + "/how-to/extending/elementwise/": "/how-to/extending/builtin", + "/how-to/extending/elementwise/elementwise.ipynb": "/how-to/extending/builtin", + "/how-to/extending/reduction/": "/how-to/extending/builtin", + "/how-to/extending/reduction/reduction.ipynb": "/how-to/extending/builtin", "/how_to/ffill_bfill_w_window/": "/posts/ffill-and-bfill-using-ibis", "/how_to/self_joins/": 
"/tutorials/ibis-for-sql-users#self-joins", "/how_to/sessionize": "/how-to/timeseries/sessionize", @@ -159,7 +163,7 @@ "/tutorial/ibis-for-sql-users/ibis-for-sql-users.ipynb": "/tutorials/ibis-for-sql-users/", "/user_guide/configuration/": "/how-to/configure/basics", "/user_guide/design/": "/concepts/internals", - "/user_guide/extending/": "/how-to/extending/elementwise", + "/user_guide/extending/": "/how-to/extending/builtin", "/user_guide/self_joins/": "/tutorials/ibis-for-sql-users#self-joins", "/versioning": "/concepts/versioning", "/why_ibis/": "/why", diff --git a/ibis/backends/base/__init__.py b/ibis/backends/base/__init__.py index 8e3f825321e8e..6714e7a38aeb0 100644 --- a/ibis/backends/base/__init__.py +++ b/ibis/backends/base/__init__.py @@ -33,6 +33,8 @@ __all__ = ("BaseBackend", "Database", "connect") +# TODO(cpcloud): move these to a place that doesn't require importing +# backend-specific dependencies _IBIS_TO_SQLGLOT_DIALECT = { "mssql": "tsql", "impala": "hive", @@ -44,19 +46,6 @@ "risingwave": "postgres", } -_SQLALCHEMY_TO_SQLGLOT_DIALECT = { - # sqlalchemy dialects of backends not listed here match the sqlglot dialect - # name - "mssql": "tsql", - "postgresql": "postgres", - "default": "duckdb", - # druid allows double quotes for identifiers, like postgres: - # https://druid.apache.org/docs/latest/querying/sql#identifiers-and-literals - "druid": "postgres", - # closest match see https://github.com/ibis-project/ibis/pull/7303#discussion_r1350223901 - "exa.websocket": "oracle", -} - class Database: """Generic Database class.""" @@ -878,12 +867,6 @@ def connect(self, *args, **kwargs) -> BaseBackend: new_backend.reconnect() return new_backend - def _from_url(self, url: str, **kwargs) -> BaseBackend: - """Construct an ibis backend from a SQLAlchemy-conforming URL.""" - raise NotImplementedError( - f"`_from_url` not implemented for the {self.name} backend" - ) - @staticmethod def _convert_kwargs(kwargs: MutableMapping) -> None: """Manipulate keyword 
arguments to `.connect` method.""" @@ -1408,17 +1391,14 @@ def connect(resource: Path | str, **kwargs: Any) -> BaseBackend: parsed = parsed._replace(query=query) if scheme in ("postgres", "postgresql"): - # Treat `postgres://` and `postgresql://` the same, just as postgres - # does. We normalize to `postgresql` since that's what SQLAlchemy - # accepts. + # Treat `postgres://` and `postgresql://` the same scheme = "postgres" - parsed = parsed._replace(scheme="postgresql") # Convert all arguments back to a single URL string url = parsed.geturl() if "://" not in url: - # SQLAlchemy requires a `://`, while urllib may roundtrip - # `duckdb://` to `duckdb:`. Here we re-add the missing `//`. + # urllib may roundtrip `duckdb://` to `duckdb:`. Here we re-add the + # missing `//`. url = url.replace(":", "://", 1) try: diff --git a/ibis/backends/base/sql/__init__.py b/ibis/backends/base/sql/__init__.py index b89d265f434a9..ddece41f926de 100644 --- a/ibis/backends/base/sql/__init__.py +++ b/ibis/backends/base/sql/__init__.py @@ -2,11 +2,9 @@ import abc import contextlib -import os from functools import lru_cache from typing import TYPE_CHECKING, Any, Optional - -import toolz +from urllib.parse import parse_qs, urlparse import ibis.common.exceptions as exc import ibis.expr.operations as ops @@ -35,7 +33,7 @@ class BaseSQLBackend(BaseBackend): def _sqlglot_dialect(self) -> str: return self.name - def _from_url(self, url: str, **kwargs: Any) -> BaseBackend: + def _from_url(self, url: str, **kwargs): """Connect to a backend using a URL `url`. Parameters @@ -43,7 +41,7 @@ def _from_url(self, url: str, **kwargs: Any) -> BaseBackend: url URL with which to connect to a backend. kwargs - Additional keyword arguments passed to the `connect` method. 
+ Additional keyword arguments Returns ------- @@ -51,25 +49,25 @@ def _from_url(self, url: str, **kwargs: Any) -> BaseBackend: A backend instance """ - import sqlalchemy as sa - - url = sa.engine.make_url(url) - new_kwargs = kwargs.copy() - kwargs = {} - - for name in ("host", "port", "database", "password"): - if value := ( - getattr(url, name, None) - or os.environ.get(f"{self.name.upper()}_{name.upper()}") - ): + url = urlparse(url) + database = url.path[1:] + query_params = parse_qs(url.query) + kwargs = { + "user": url.username, + "password": url.password or "", + "host": url.hostname, + "database": database or "", + } | kwargs + + for name, value in query_params.items(): + if len(value) > 1: kwargs[name] = value - if username := url.username: - kwargs["user"] = username + elif len(value) == 1: + kwargs[name] = value[0] + else: + raise exc.IbisError(f"Invalid URL parameter: {name}") - kwargs.update(url.query) - new_kwargs = toolz.merge(kwargs, new_kwargs) - self._convert_kwargs(new_kwargs) - return self.connect(**new_kwargs) + return self.connect(**kwargs) def table(self, name: str, database: str | None = None) -> ir.Table: """Construct a table expression. 
diff --git a/ibis/backends/base/sql/alchemy/__init__.py b/ibis/backends/base/sql/alchemy/__init__.py deleted file mode 100644 index 165187ee53f44..0000000000000 --- a/ibis/backends/base/sql/alchemy/__init__.py +++ /dev/null @@ -1,1043 +0,0 @@ -from __future__ import annotations - -import abc -import atexit -import contextlib -import warnings -from operator import methodcaller -from typing import TYPE_CHECKING, Any - -import sqlalchemy as sa -import sqlglot as sg -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql import quoted_name -from sqlalchemy.sql.expression import ClauseElement, Executable - -import ibis -import ibis.common.exceptions as com -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops -import ibis.expr.schema as sch -import ibis.expr.types as ir -from ibis import util -from ibis.backends.base import CanCreateSchema -from ibis.backends.base.sql import BaseSQLBackend -from ibis.backends.base.sql.alchemy.query_builder import AlchemyCompiler -from ibis.backends.base.sql.alchemy.registry import ( - fixed_arity, - get_sqla_table, - reduction, - sqlalchemy_operation_registry, - sqlalchemy_window_functions_registry, - unary, - varargs, - variance_reduction, -) -from ibis.backends.base.sql.alchemy.translator import ( - AlchemyContext, - AlchemyExprTranslator, -) -from ibis.backends.base.sqlglot import STAR -from ibis.formats.pandas import PandasData - -if TYPE_CHECKING: - from collections.abc import Iterable, Mapping - - import pandas as pd - import pyarrow as pa - - -__all__ = ( - "BaseAlchemyBackend", - "AlchemyExprTranslator", - "AlchemyContext", - "AlchemyCompiler", - "sqlalchemy_operation_registry", - "sqlalchemy_window_functions_registry", - "reduction", - "variance_reduction", - "fixed_arity", - "unary", - "infix_op", - "get_sqla_table", - "schema_from_table", - "varargs", -) - - -class CreateTableAs(Executable, ClauseElement): - inherit_cache = True - - def __init__( - self, - name, - query, - temp: bool = False, - 
overwrite: bool = False, - quote: bool | None = None, - ): - self.name = name - self.query = query - self.temp = temp - self.overwrite = overwrite - self.quote = quote - - -@compiles(CreateTableAs) -def _create_table_as(element, compiler, **kw): - stmt = "CREATE " - - if element.overwrite: - stmt += "OR REPLACE " - - if element.temp: - stmt += "TEMPORARY " - - name = compiler.preparer.quote(quoted_name(element.name, quote=element.quote)) - return stmt + f"TABLE {name} AS {compiler.process(element.query, **kw)}" - - -class AlchemyCanCreateSchema(CanCreateSchema): - def list_schemas( - self, like: str | None = None, database: str | None = None - ) -> list[str]: - schema = ".".join(filter(None, (database, "information_schema"))) - sch = sa.table( - "schemata", - sa.column("catalog_name", sa.TEXT()), - sa.column("schema_name", sa.TEXT()), - schema=schema, - ) - - query = sa.select(sch.c.schema_name) - - with self.begin() as con: - schemas = list(con.execute(query).scalars()) - return self._filter_with_like(schemas, like=like) - - -class BaseAlchemyBackend(BaseSQLBackend): - """Backend class for backends that compile to SQLAlchemy expressions.""" - - compiler = AlchemyCompiler - supports_temporary_tables = True - _temporary_prefix = "TEMPORARY" - - def _scalar_query(self, query): - method = "exec_driver_sql" if isinstance(query, str) else "execute" - with self.begin() as con: - return getattr(con, method)(query).scalar() - - def _compile_type(self, dtype) -> str: - dialect = self.con.dialect - return sa.types.to_instance( - self.compiler.translator_class.get_sqla_type(dtype) - ).compile(dialect=dialect) - - def _build_alchemy_url( - self, - url: str | None, - host: str | None, - port: int | None, - user: str | None, - password: str | None, - database: str | None, - driver: str | None, - query: Mapping[str, Any] | None = None, - ) -> sa.engine.URL: - if url is not None: - return sa.engine.url.make_url(url) - - return sa.engine.url.URL.create( - driver, - host=host, - 
port=port, - username=user, - password=password, - database=database, - query=query or {}, - ) - - @property - def _current_schema(self) -> str | None: - return None - - def do_connect(self, con: sa.engine.Engine) -> None: - self.con = con - self._inspector = None - self._schemas: dict[str, sch.Schema] = {} - self._temp_views: set[str] = set() - - @property - def version(self): - if self._inspector is None: - self._inspector = sa.inspect(self.con) - return ".".join(map(str, self.con.dialect.server_version_info)) - - def list_tables(self, like=None, database=None): - tables = self.inspector.get_table_names(schema=database) - views = self.inspector.get_view_names(schema=database) - return self._filter_with_like(tables + views, like) - - @property - def inspector(self): - if self._inspector is None: - self._inspector = sa.inspect(self.con) - else: - self._inspector.info_cache.clear() - return self._inspector - - def _to_sql(self, expr: ir.Expr, **kwargs) -> str: - # For `ibis.to_sql` calls we render with literal binds and qmark params - dialect_class = sa.dialects.registry.load( - self.compiler.translator_class._dialect_name - ) - sql = self.compile(expr, **kwargs).compile( - dialect=dialect_class(paramstyle="qmark"), - compile_kwargs=dict(literal_binds=True), - ) - return str(sql) - - @contextlib.contextmanager - def _safe_raw_sql(self, *args, **kwargs): - with self.begin() as con: - yield con.execute(*args, **kwargs) - - def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: - import pandas as pd - - try: - df = pd.DataFrame.from_records( - cursor, columns=schema.names, coerce_float=True - ) - except Exception: - # clean up the cursor if we fail to create the DataFrame - # - # in the sqlite case failing to close the cursor results in - # artificially locked tables - cursor.close() - raise - df = PandasData.convert_table(df, schema) - return df - - @contextlib.contextmanager - def begin(self): - with self.con.begin() as bind: - yield bind - - def 
_clean_up_tmp_table(self, tmptable: sa.Table) -> None: - with self.begin() as bind: - tmptable.drop(bind=bind) - - def create_table( - self, - name: str, - obj: pd.DataFrame | pa.Table | ir.Table | None = None, - *, - schema: sch.Schema | None = None, - database: str | None = None, - temp: bool = False, - overwrite: bool = False, - ) -> ir.Table: - """Create a table. - - Parameters - ---------- - name - Name of the new table. - obj - An Ibis table expression or pandas table that will be used to - extract the schema and the data of the new table. If not provided, - `schema` must be given. - schema - The schema for the new table. Only one of `schema` or `obj` can be - provided. - database - Name of the database where the table will be created, if not the - default. - temp - Should the table be temporary for the session. - overwrite - Clobber existing data - - Returns - ------- - Table - The table that was created. - - """ - if obj is None and schema is None: - raise com.IbisError("The schema or obj parameter is required") - - import pandas as pd - import pyarrow as pa - import pyarrow_hotfix # noqa: F401 - - if isinstance(obj, (pd.DataFrame, pa.Table)): - obj = ibis.memtable(obj) - - if database == self.current_database: - # avoid fully qualified name - database = None - - if database is not None: - raise NotImplementedError( - "Creating tables from a different database is not yet implemented" - ) - - if obj is not None and schema is not None: - if not obj.schema().equals(ibis.schema(schema)): - raise com.IbisTypeError( - "Expression schema is not equal to passed schema. 
" - "Try passing the expression without the schema" - ) - if schema is None: - schema = obj.schema() - - self._schemas[self._fully_qualified_name(name, database)] = schema - - if has_expr := obj is not None: - # this has to happen outside the `begin` block, so that in-memory - # tables are visible inside the transaction created by it - self._run_pre_execute_hooks(obj) - - table = self._table_from_schema( - name, - schema, - # most databases don't allow temporary tables in a specific - # database so let the backend decide - # - # the ones that do (e.g., snowflake) should implement their own - # `create_table` - database=None if temp else (database or self.current_database), - temp=temp, - ) - - if has_expr: - if self.supports_create_or_replace: - ctas = CreateTableAs( - name, - self.compile(obj), - temp=temp, - overwrite=overwrite, - quote=self.compiler.translator_class._quote_table_names, - ) - with self.begin() as bind: - bind.execute(ctas) - else: - tmptable = self._table_from_schema( - util.gen_name("tmp_table_insert"), - schema, - # some backends don't support temporary tables - temp=self.supports_temporary_tables, - ) - method = self._get_insert_method(obj) - insert = table.insert().from_select(tmptable.columns, tmptable.select()) - - with self.begin() as bind: - # 1. write `obj` to a unique temp table - tmptable.create(bind=bind) - - # try/finally here so that a successfully created tmptable gets - # cleaned up no matter what - try: - with self.begin() as bind: - bind.execute(method(tmptable.insert())) - - # 2. recreate the existing table - if overwrite: - table.drop(bind=bind, checkfirst=True) - table.create(bind=bind) - - # 3. 
insert the temp table's data into the (re)created table - bind.execute(insert) - finally: - self._clean_up_tmp_table(tmptable) - else: - with self.begin() as bind: - if overwrite: - table.drop(bind=bind, checkfirst=True) - table.create(bind=bind) - return self.table(name, database=database) - - def _get_insert_method(self, expr): - compiled = self.compile(expr) - - # if in memory tables aren't cheap then try to pull out their data - # FIXME: queries that *select* from in memory tables are still broken - # for mysql/sqlite/postgres because the generated SQL is wrong - if ( - not self.compiler.cheap_in_memory_tables - and self.compiler.support_values_syntax_in_select - and isinstance(expr.op(), ops.InMemoryTable) - ): - (from_,) = compiled.get_final_froms() - try: - (rows,) = from_._data - except AttributeError: - return methodcaller("from_select", list(expr.columns), from_) - else: - return methodcaller("values", rows) - - return methodcaller("from_select", list(expr.columns), compiled) - - def _columns_from_schema(self, name: str, schema: sch.Schema) -> list[sa.Column]: - return [ - sa.Column( - colname, - self.compiler.translator_class.get_sqla_type(dtype), - nullable=dtype.nullable, - quote=self.compiler.translator_class._quote_column_names, - ) - for colname, dtype in zip(schema.names, schema.types) - ] - - def _table_from_schema( - self, - name: str, - schema: sch.Schema, - temp: bool = False, - database: str | None = None, - **kwargs: Any, - ) -> sa.Table: - columns = self._columns_from_schema(name, schema) - return sa.Table( - name, - sa.MetaData(), - *columns, - prefixes=[self._temporary_prefix] if temp else [], - quote=self.compiler.translator_class._quote_table_names, - **kwargs, - ) - - def drop_table( - self, name: str, *, database: str | None = None, force: bool = False - ) -> None: - """Drop a table. 
- - Parameters - ---------- - name - Table to drop - database - Database to drop table from - force - Check for existence before dropping - - """ - if database == self.current_database: - # avoid fully qualified name - database = None - - if database is not None: - raise com.IbisInputError( - "Dropping tables from a different database is not yet implemented" - ) - - t = self._get_sqla_table( - name, namespace=ops.Namespace(database=database), autoload=False - ) - with self.begin() as bind: - t.drop(bind=bind, checkfirst=force) - - qualified_name = self._fully_qualified_name(name, database) - - with contextlib.suppress(KeyError): - # schemas won't be cached if created with raw_sql - del self._schemas[qualified_name] - - def truncate_table(self, name: str, database: str | None = None) -> None: - t = self._get_sqla_table(name, namespace=ops.Namespace(database=database)) - with self.begin() as con: - con.execute(t.delete()) - - def schema(self, name: str) -> sch.Schema: - """Get an ibis schema from the current database for the table `name`. 
- - Parameters - ---------- - name - Table name - - Returns - ------- - Schema - The ibis schema of `name` - - """ - return self.database().schema(name) - - def _log(self, sql): - try: - query_str = str(sql) - except sa.exc.UnsupportedCompilationError: - pass - else: - util.log(query_str) - - @staticmethod - def _new_sa_metadata(): - return sa.MetaData() - - def _get_sqla_table( - self, - name: str, - *, - namespace: ops.Namespace = ops.Namespace(), # noqa: B008 - autoload: bool = True, - **_: Any, - ) -> sa.Table: - meta = self._new_sa_metadata() - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="Did not recognize type", category=sa.exc.SAWarning - ) - warnings.filterwarnings( - "ignore", message="index key", category=sa.exc.SAWarning - ) - table = sa.Table( - name, - meta, - schema=namespace.schema, - autoload_with=self.con if autoload else None, - quote=self.compiler.translator_class._quote_table_names, - ) - nulltype_cols = frozenset( - col.name for col in table.c if isinstance(col.type, sa.types.NullType) - ) - - if not nulltype_cols: - return table - return self._handle_failed_column_type_inference(table, nulltype_cols) - - # TODO(kszucs): remove the schema parameter - @classmethod - def _schema_from_sqla_table( - cls, - table: sa.sql.TableClause, - schema: sch.Schema | None = None, - ) -> sch.Schema: - """Retrieve an ibis schema from a SQLAlchemy `Table`. - - Parameters - ---------- - table - Table whose schema to infer - schema - Predefined ibis schema to pull types from - dialect - Optional sqlalchemy dialect - - Returns - ------- - schema - An ibis schema corresponding to the types of the columns in `table`. 
- - """ - schema = schema if schema is not None else {} - pairs = [] - for column in table.columns: - name = column.name - if name in schema: - dtype = schema[name] - else: - dtype = cls.compiler.translator_class.get_ibis_type( - column.type, nullable=column.nullable or column.nullable is None - ) - pairs.append((name, dtype)) - return sch.schema(pairs) - - def _handle_failed_column_type_inference( - self, table: sa.Table, nulltype_cols: Iterable[str] - ) -> sa.Table: - """Handle cases where SQLAlchemy cannot infer the column types of `table`.""" - self.inspector.reflect_table(table, table.columns) - - dialect = self.con.dialect - - quoted_name = ".".join( - map( - dialect.identifier_preparer.quote, - filter(None, [table.schema, table.name]), - ) - ) - - for colname, dtype in self._metadata(quoted_name): - if colname in nulltype_cols: - # replace null types discovered by sqlalchemy with non null - # types - table.append_column( - sa.Column( - colname, - self.compiler.translator_class.get_sqla_type(dtype), - nullable=dtype.nullable, - quote=self.compiler.translator_class._quote_column_names, - ), - replace_existing=True, - ) - return table - - def raw_sql(self, query: str | sa.sql.ClauseElement): - """Execute a query and return the cursor used for execution. - - ::: {.callout-tip} - ## Consider using [`.sql`](#ibis.backends.base.sql.BaseSQLBackend.sql) instead - - If your query is a `SELECT` statement you can use the - [backend `.sql`](#ibis.backends.base.sql.BaseSQLBackend.sql) method to avoid - having to manually release the cursor returned from this method. - - ::: {.callout-warning} - ## The cursor returned from this method must be **manually released** - - You **do not** need to call `.close()` on the cursor when running DDL - or DML statements like `CREATE`, `INSERT` or `DROP`, only when using - `SELECT` statements. - - To release a cursor, call the `close` method on the returned cursor - object. 
- - You can close the cursor by explicitly calling its `close` method: - - ```python - cursor = con.raw_sql("SELECT ...") - cursor.close() - ``` - - Or you can use a context manager: - - ```python - with con.raw_sql("SELECT ...") as cursor: - ... - ``` - ::: - - ::: - - Parameters - ---------- - query - SQL query or SQLAlchemy expression to execute - - Examples - -------- - >>> con = ibis.connect("duckdb://") - >>> with con.raw_sql("SELECT 1") as cursor: - ... result = cursor.fetchall() - >>> result - [(1,)] - >>> cursor.closed - True - - """ - return self.con.connect().execute( - sa.text(query) if isinstance(query, str) else query - ) - - def table( - self, - name: str, - database: str | None = None, - schema: str | None = None, - ) -> ir.Table: - """Create a table expression from a table in the database. - - Parameters - ---------- - name - Table name - database - The database the table resides in - schema - The schema inside `database` where the table resides. - - ::: {.callout-warning} - ## `schema` refers to database hierarchy - - The `schema` parameter does **not** refer to the column names and - types of `table`. 
- ::: - - Returns - ------- - Table - Table expression - - """ - namespace = ops.Namespace(schema=schema, database=database) - - sqla_table = self._get_sqla_table(name, namespace=namespace) - - schema = self._schema_from_sqla_table( - sqla_table, schema=self._schemas.get(name) - ) - node = ops.DatabaseTable( - name=name, schema=schema, source=self, namespace=namespace - ) - return node.to_expr() - - def _insert_dataframe( - self, table_name: str, df: pd.DataFrame, overwrite: bool - ) -> None: - namespace = ops.Namespace(schema=self._current_schema) - - t = self._get_sqla_table(table_name, namespace=namespace) - with self.con.begin() as con: - if overwrite: - con.execute(t.delete()) - con.execute(t.insert(), df.to_dict(orient="records")) - - def insert( - self, - table_name: str, - obj: pd.DataFrame | ir.Table | list | dict, - database: str | None = None, - overwrite: bool = False, - ) -> None: - """Insert data into a table. - - Parameters - ---------- - table_name - The name of the table to which data needs will be inserted - obj - The source data or expression to insert - database - Name of the attached database that the table is located in. 
- overwrite - If `True` then replace existing contents of table - - Raises - ------ - NotImplementedError - If inserting data from a different database - ValueError - If the type of `obj` isn't supported - - """ - - import pandas as pd - - if database == self.current_database: - # avoid fully qualified name - database = None - - if database is not None: - raise NotImplementedError( - "Inserting data to a table from a different database is not " - "yet implemented" - ) - - # If we've been passed a `memtable`, pull out the underlying dataframe - if isinstance(obj, ir.Table) and isinstance( - in_mem_table := obj.op(), ops.InMemoryTable - ): - obj = in_mem_table.data.to_frame() - - if isinstance(obj, pd.DataFrame): - self._insert_dataframe(table_name, obj, overwrite=overwrite) - elif isinstance(obj, ir.Table): - to_table_expr = self.table(table_name) - to_table_schema = to_table_expr.schema() - - if overwrite: - self.drop_table(table_name, database=database) - self.create_table(table_name, schema=to_table_schema, database=database) - - to_table = self._get_sqla_table( - table_name, namespace=ops.Namespace(database=database) - ) - - from_table_expr = obj - - if from_table_expr is not None: - compiled = from_table_expr.compile() - columns = [ - self.con.dialect.normalize_name(c) for c in from_table_expr.columns - ] - with self.begin() as bind: - bind.execute(to_table.insert().from_select(columns, compiled)) - elif isinstance(obj, (list, dict)): - to_table = self._get_sqla_table( - table_name, namespace=ops.Namespace(database=database) - ) - - with self.begin() as bind: - if overwrite: - bind.execute(to_table.delete()) - bind.execute(to_table.insert().values(obj)) - - else: - raise ValueError( - "No operation is being performed. Either the obj parameter " - "is not a pandas DataFrame or is not a ibis Table." - f"The given obj is of type {type(obj).__name__} ." 
- ) - - def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str: - if self.supports_python_udfs: - raise NotImplementedError( - f"The {self.name} backend does not support Python scalar UDFs" - ) - - def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: - if self.supports_python_udfs: - raise NotImplementedError( - f"The {self.name} backend does not support Pandas-based vectorized scalar UDFs" - ) - - def _compile_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> str: - if self.supports_python_udfs: - raise NotImplementedError( - f"The {self.name} backend does not support PyArrow-based vectorized scalar UDFs" - ) - - def _compile_builtin_udf(self, udf_node: ops.ScalarUDF) -> str: - """No-op, because the function is assumed builtin.""" - - def _gen_udf_rule(self, op: ops.ScalarUDF): - @self.add_operation(type(op)) - def _(t, op): - generator = sa.func - if (namespace := op.__udf_namespace__) is not None: - generator = getattr(generator, namespace) - func = getattr(generator, op.__func_name__) - return func(*map(t.translate, op.args)) - - def _gen_udaf_rule(self, op: ops.AggUDF): - from ibis import NA - - @self.add_operation(type(op)) - def _(t, op): - args = (arg for name, arg in zip(op.argnames, op.args) if name != "where") - generator = sa.func - if (namespace := op.__udf_namespace__) is not None: - generator = getattr(generator, namespace) - func = getattr(generator, op.__func_name__) - - if (where := op.where) is None: - return func(*map(t.translate, args)) - elif t._has_reduction_filter_syntax: - return func(*map(t.translate, args)).filter(t.translate(where)) - else: - return func(*(t.translate(ops.IfElse(where, arg, NA)) for arg in args)) - - def _register_udfs(self, expr: ir.Expr) -> None: - with self.begin() as con: - for udf_node in expr.op().find(ops.ScalarUDF): - compile_func = getattr( - self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" - ) - if sql := compile_func(udf_node): - con.exec_driver_sql(sql) - - def _quote(self, name: str) 
-> str: - """Quote an identifier.""" - preparer = self.con.dialect.identifier_preparer - if self.compiler.translator_class._quote_table_names: - return preparer.quote_identifier(name) - return preparer.quote(name) - - def _get_temp_view_definition( - self, name: str, definition: sa.sql.compiler.Compiled - ) -> str: - raise NotImplementedError( - f"The {self.name} backend does not implement temporary view creation" - ) - - def _register_temp_view_cleanup(self, name: str, raw_name: str) -> None: - query = f"DROP VIEW IF EXISTS {name}" - - def drop(self, raw_name: str, query: str): - with self.begin() as con: - con.exec_driver_sql(query) - self._temp_views.discard(raw_name) - - atexit.register(drop, self, raw_name, query) - - def _get_compiled_statement( - self, - definition: sa.sql.Selectable, - name: str, - compile_kwargs: Mapping[str, Any] | None = None, - ): - if compile_kwargs is None: - compile_kwargs = {} - compiled = definition.compile( - dialect=self.con.dialect, compile_kwargs=compile_kwargs - ) - create_view = self._get_temp_view_definition(name, definition=compiled) - params = compiled.params - if compiled.positional: - params = tuple(params.values()) - return create_view, params - - def _create_temp_view(self, view: sa.Table, definition: sa.sql.Selectable) -> None: - raw_name = view.name - if raw_name not in self._temp_views and raw_name in self.list_tables(): - raise ValueError(f"{raw_name} already exists as a table or view") - name = self._quote(raw_name) - self._execute_view_creation(name, definition) - self._temp_views.add(raw_name) - self._register_temp_view_cleanup(name, raw_name) - - def _execute_view_creation(self, name, definition): - lines, params = self._get_compiled_statement(definition, name) - with self.begin() as con: - for line in lines: - con.exec_driver_sql(line, parameters=params or ()) - - @abc.abstractmethod - def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: - ... 
- - def _get_schema_using_query(self, query: str) -> sch.Schema: - """Return an ibis Schema from a backend-specific SQL string.""" - return sch.Schema.from_tuples(self._metadata(query)) - - def _load_into_cache(self, name, expr): - self.create_table(name, expr, schema=expr.schema(), temp=True) - - def _clean_up_cached_table(self, op): - self.drop_table(op.name) - - def create_view( - self, - name: str, - obj: ir.Table, - *, - database: str | None = None, - overwrite: bool = False, - ) -> ir.Table: - from sqlalchemy_views import CreateView - - source = self.compile(obj) - view = CreateView( - sa.Table( - name, - sa.MetaData(), - schema=database, - quote=self.compiler.translator_class._quote_table_names, - ), - source, - or_replace=overwrite, - ) - with self.begin() as con: - con.execute(view) - return self.table(name, database=database) - - def drop_view( - self, name: str, *, database: str | None = None, force: bool = False - ) -> None: - from sqlalchemy_views import DropView - - view = DropView( - sa.Table( - name, - sa.MetaData(), - schema=database, - quote=self.compiler.translator_class._quote_table_names, - ), - if_exists=not force, - ) - - with self.begin() as con: - con.execute(view) - - -class AlchemyCrossSchemaBackend(BaseAlchemyBackend): - """A SQLAlchemy backend that supports cross-schema queries. - - This backend differs from the default SQLAlchemy backend in that it - overrides `_get_sqla_table` to potentially switch schemas during table - reflection, if the table requested lives in a different schema than the - currently active one. 
- """ - - def _get_table_identifier(self, *, name, namespace): - database = namespace.database - schema = namespace.schema - - if schema is None: - schema = self.current_schema - - try: - schema = sg.parse_one(schema, into=sg.exp.Identifier) - except sg.ParseError: - # not actually a table, but that's how sqlglot parses - # `CREATE SCHEMA` statements - parsed = sg.parse_one(schema, into=sg.exp.Table) - - # user passed database="foo", schema="bar.baz", which is ambiguous - if database is not None: - raise com.IbisInputError( - "Cannot specify both `database` and a dotted path in `schema`" - ) - - db = parsed.args["db"].this - schema = parsed.args["this"].this - else: - db = database - - table = sg.table( - name, - db=schema, - catalog=db, - quoted=self.compiler.translator_class._quote_table_names, - ).transform( - lambda node: node.__class__( - this=node.this, - quoted=node.quoted or self.compiler.translator_class._quote_table_names, - ) - if isinstance(node, sg.exp.Identifier) - else node - ) - return table - - def _get_sqla_table( - self, name: str, namespace: ops.Namespace, **_: Any - ) -> sa.Table: - table = self._get_table_identifier(name=name, namespace=namespace) - metadata_query = sg.select(STAR).from_(table).limit(0).sql(dialect=self.name) - pairs = self._metadata(metadata_query) - ibis_schema = ibis.schema(pairs) - - columns = self._columns_from_schema(name, ibis_schema) - result = sa.Table( - name, - sa.MetaData(), - *columns, - quote=self.compiler.translator_class._quote_table_names, - ) - result.fullname = table.sql(dialect=self.name) - return result - - def drop_table( - self, name: str, database: str | None = None, force: bool = False - ) -> None: - table = sg.table( - name, db=database, quoted=self.compiler.translator_class._quote_table_names - ) - drop_table = sg.exp.Drop(kind="TABLE", exists=force, this=table) - drop_table_sql = drop_table.sql(dialect=self.name) - with self.begin() as con: - con.exec_driver_sql(drop_table_sql) - - 
-@compiles(sa.Table, "trino", "duckdb") -def compile_trino_table(element, compiler, **kw): - return element.fullname - - -@compiles(sa.Table, "snowflake") -def compile_snowflake_table(element, compiler, **kw): - dialect = compiler.dialect.name - return ( - sg.parse_one(element.fullname, into=sg.exp.Table, read=dialect) - .transform( - lambda node: node.__class__(this=node.this, quoted=True) - if isinstance(node, sg.exp.Identifier) - else node - ) - .sql(dialect) - ) diff --git a/ibis/backends/base/sql/alchemy/datatypes.py b/ibis/backends/base/sql/alchemy/datatypes.py deleted file mode 100644 index 9d7f36ecc0af5..0000000000000 --- a/ibis/backends/base/sql/alchemy/datatypes.py +++ /dev/null @@ -1,194 +0,0 @@ -from __future__ import annotations - -import sqlalchemy as sa -import sqlalchemy.types as sat -from sqlalchemy.ext.compiler import compiles - -import ibis.expr.datatypes as dt -from ibis.backends.base.sqlglot.datatypes import SqlglotType -from ibis.formats import TypeMapper - - -class UInt64(sat.Integer): - pass - - -class UInt32(sat.Integer): - pass - - -class UInt16(sat.Integer): - pass - - -class UInt8(sat.Integer): - pass - - -@compiles(UInt64, "mssql") -@compiles(UInt32, "mssql") -@compiles(UInt16, "mssql") -@compiles(UInt8, "mssql") -@compiles(UInt64, "sqlite") -@compiles(UInt32, "sqlite") -@compiles(UInt16, "sqlite") -@compiles(UInt8, "sqlite") -def compile_uint(element, compiler, **kw): - dialect_name = compiler.dialect.name - raise TypeError( - f"unsigned integers are not supported in the {dialect_name} backend" - ) - - -try: - UUID = sat.UUID -except AttributeError: - - class UUID(sat.String): - pass - -else: - - @compiles(UUID, "default") - def compiles_uuid(element, compiler, **kw): - return "UUID" - - -class Unknown(sa.Text): - pass - - -_from_sqlalchemy_types = { - sat.BOOLEAN: dt.Boolean, - sat.Boolean: dt.Boolean, - sat.BINARY: dt.Binary, - sat.BLOB: dt.Binary, - sat.LargeBinary: dt.Binary, - sat.DATE: dt.Date, - sat.Date: dt.Date, - sat.TEXT: 
dt.String, - sat.Text: dt.String, - sat.TIME: dt.Time, - sat.Time: dt.Time, - sat.VARCHAR: dt.String, - sat.CHAR: dt.String, - sat.String: dt.String, - sat.SMALLINT: dt.Int16, - sat.SmallInteger: dt.Int16, - sat.INTEGER: dt.Int32, - sat.Integer: dt.Int32, - sat.BIGINT: dt.Int64, - sat.BigInteger: dt.Int64, - sat.REAL: dt.Float32, - sat.FLOAT: dt.Float64, - UInt16: dt.UInt16, - UInt32: dt.UInt32, - UInt64: dt.UInt64, - UInt8: dt.UInt8, - Unknown: dt.Unknown, - sat.JSON: dt.JSON, - UUID: dt.UUID, -} - -_to_sqlalchemy_types = { - dt.Null: sat.NullType, - dt.Date: sat.Date, - dt.Time: sat.Time, - dt.Boolean: sat.Boolean, - dt.Binary: sat.LargeBinary, - dt.String: sat.Text, - dt.Decimal: sat.Numeric, - # Mantissa-based - dt.Float16: sat.REAL, - dt.Float32: sat.REAL, - # precision is the number of bits in the mantissa - # without specifying this, some backends interpret the type as FLOAT, which - # means float32 (and precision == 24) - dt.Float64: sat.FLOAT(precision=53), - dt.Int8: sat.SmallInteger, - dt.Int16: sat.SmallInteger, - dt.Int32: sat.Integer, - dt.Int64: sat.BigInteger, - dt.UInt8: UInt8, - dt.UInt16: UInt16, - dt.UInt32: UInt32, - dt.UInt64: UInt64, - dt.JSON: sat.JSON, - dt.Interval: sat.Interval, - dt.Unknown: Unknown, - dt.MACADDR: sat.Text, - dt.INET: sat.Text, - dt.UUID: UUID, -} - -_FLOAT_PREC_TO_TYPE = { - 11: dt.Float16, - 24: dt.Float32, - 53: dt.Float64, -} - - -class AlchemyType(TypeMapper): - @classmethod - def to_string(cls, dtype: dt.DataType): - dialect_class = sa.dialects.registry.load(cls.dialect) - return str( - sa.types.to_instance(cls.from_ibis(dtype)).compile(dialect=dialect_class()) - ) - - @classmethod - def from_string(cls, type_string, nullable=True): - return SqlglotType.from_string(type_string, nullable=nullable) - - @classmethod - def from_ibis(cls, dtype: dt.DataType) -> sat.TypeEngine: - """Convert an Ibis type to a SQLAlchemy type. - - Parameters - ---------- - dtype - Ibis type to convert. 
- - Returns - ------- - SQLAlchemy type. - - """ - if dtype.is_decimal(): - return sat.NUMERIC(dtype.precision, dtype.scale) - elif dtype.is_timestamp(): - return sat.TIMESTAMP(timezone=bool(dtype.timezone)) - else: - return _to_sqlalchemy_types[type(dtype)] - - @classmethod - def to_ibis(cls, typ: sat.TypeEngine, nullable: bool = True) -> dt.DataType: - """Convert a SQLAlchemy type to an Ibis type. - - Parameters - ---------- - typ - SQLAlchemy type to convert. - nullable : bool, optional - Whether the returned type should be nullable. - - Returns - ------- - Ibis type. - - """ - if dtype := _from_sqlalchemy_types.get(type(typ)): - return dtype(nullable=nullable) - elif isinstance(typ, sat.Float): - if (float_typ := _FLOAT_PREC_TO_TYPE.get(typ.precision)) is not None: - return float_typ(nullable=nullable) - return dt.Decimal(typ.precision, typ.scale, nullable=nullable) - elif isinstance(typ, sat.Numeric): - return dt.Decimal(typ.precision, typ.scale, nullable=nullable) - elif isinstance(typ, sa.DateTime): - timezone = "UTC" if typ.timezone else None - return dt.Timestamp(timezone, nullable=nullable) - elif isinstance(typ, sat.String): - return dt.String(nullable=nullable) - else: - raise TypeError(f"Unable to convert type: {typ!r}") diff --git a/ibis/backends/base/sql/alchemy/query_builder.py b/ibis/backends/base/sql/alchemy/query_builder.py deleted file mode 100644 index 9e9e92e11e4d7..0000000000000 --- a/ibis/backends/base/sql/alchemy/query_builder.py +++ /dev/null @@ -1,465 +0,0 @@ -from __future__ import annotations - -import functools - -import sqlalchemy as sa -import sqlglot as sg -import toolz -from sqlalchemy import sql - -import ibis.common.exceptions as com -import ibis.expr.analysis as an -import ibis.expr.operations as ops -from ibis.backends.base import _SQLALCHEMY_TO_SQLGLOT_DIALECT -from ibis.backends.base.sql.alchemy.translator import ( - AlchemyContext, - AlchemyExprTranslator, -) -from ibis.backends.base.sql.compiler import ( - Compiler, - 
Select, - SelectBuilder, - TableSetFormatter, -) -from ibis.backends.base.sql.compiler.base import SetOp - - -class _AlchemyTableSetFormatter(TableSetFormatter): - def get_result(self): - # Got to unravel the join stack; the nesting order could be - # arbitrary, so we do a depth first search and push the join tokens - # and predicates onto a flat list, then format them - op = self.node - - if isinstance(op, ops.Join): - self._walk_join_tree(op) - else: - self.join_tables.append(self._format_table(op)) - - result = self.join_tables[0] - for jtype, table, preds in zip( - self.join_types, self.join_tables[1:], self.join_predicates - ): - if preds: - sqla_preds = [self._translate(pred) for pred in preds] - onclause = functools.reduce(sql.and_, sqla_preds) - else: - onclause = None - - if jtype is ops.InnerJoin: - result = result.join(table, onclause) - elif jtype is ops.CrossJoin: - result = result.join(table, sa.literal(True)) - elif jtype is ops.LeftJoin: - result = result.join(table, onclause, isouter=True) - elif jtype is ops.RightJoin: - result = table.join(result, onclause, isouter=True) - elif jtype is ops.OuterJoin: - result = result.outerjoin(table, onclause, full=True) - elif jtype is ops.LeftSemiJoin: - # subquery is required for semi and anti joins done using - # sqlalchemy, otherwise multiple references to the original - # select are treated as distinct tables - # - # with a subquery, the result is a distinct table and so there's only one - # thing for subsequent expressions to reference - result = ( - result.select() - .where(sa.exists(sa.select(1).where(onclause))) - .subquery() - ) - elif jtype is ops.LeftAntiJoin: - result = ( - result.select() - .where(~sa.exists(sa.select(1).where(onclause))) - .subquery() - ) - else: - raise NotImplementedError(jtype) - - self.context.set_ref(op, result) - return result - - def _get_join_type(self, op): - return type(op) - - def _format_table(self, op): - ctx = self.context - - orig_op = op - if isinstance(op, 
(ops.SelfReference, ops.Sample)): - op = op.table - - alias = ctx.get_ref(orig_op) - - translator = ctx.compiler.translator_class(op, ctx) - - if isinstance(op, ops.DatabaseTable): - namespace = op.namespace - result = op.source._get_sqla_table(op.name, namespace=namespace) - elif isinstance(op, ops.UnboundTable): - # use SQLAlchemy's TableClause for unbound tables - name = op.name - namespace = op.namespace - result = sa.Table( - name, - sa.MetaData(), - *translator._schema_to_sqlalchemy_columns(op.schema), - quote=translator._quote_table_names, - ) - dialect = translator._dialect_name - result.fullname = sg.table( - name, - db=namespace.schema, - catalog=namespace.database, - quoted=translator._quote_table_names, - ).sql(dialect=_SQLALCHEMY_TO_SQLGLOT_DIALECT.get(dialect, dialect)) - elif isinstance(op, ops.SQLQueryResult): - columns = translator._schema_to_sqlalchemy_columns(op.schema) - result = sa.text(op.query).columns(*columns) - elif isinstance(op, ops.SQLStringView): - columns = translator._schema_to_sqlalchemy_columns(op.schema) - result = sa.text(op.query).columns(*columns).cte(op.name) - elif isinstance(op, ops.View): - # TODO(kszucs): avoid converting to expression - child_expr = op.child.to_expr() - definition = child_expr.compile() - result = sa.Table( - op.name, - sa.MetaData(), - *translator._schema_to_sqlalchemy_columns(op.schema), - quote=translator._quote_table_names, - ) - backend = child_expr._find_backend() - backend._create_temp_view(view=result, definition=definition) - elif isinstance(op, ops.InMemoryTable): - result = self._format_in_memory_table(op, translator) - elif isinstance(op, ops.DummyTable): - result = sa.select( - *( - translator.translate(value).label(name) - for name, value in zip(op.schema.names, op.values) - ) - ) - elif ctx.is_extracted(op): - if isinstance(orig_op, ops.SelfReference): - result = ctx.get_ref(op) - elif isinstance(alias, str): - result = sa.table( - alias, - 
*translator._schema_to_sqlalchemy_columns(orig_op.schema), - ) - else: - result = alias - else: - result = ctx.get_compiled_expr(op) - - result = alias if hasattr(alias, "name") else result.alias(alias) - - if isinstance(orig_op, ops.Sample): - result = self._format_sample(orig_op, result) - - ctx.set_ref(orig_op, result) - return result - - def _format_sample(self, op, table): - # Should never be hit in practice, as Sample operations should be rewritten - # before this point for all backends without TABLESAMPLE support - raise com.UnsupportedOperationError("`Table.sample` is not supported") - - def _format_in_memory_table(self, op, translator): - columns = translator._schema_to_sqlalchemy_columns(op.schema) - if self.context.compiler.cheap_in_memory_tables: - result = sa.Table( - op.name, - sa.MetaData(), - *columns, - quote=translator._quote_table_names, - ) - elif not op.data: - result = sa.select( - *( - translator.translate(ops.Literal(None, dtype=type_)).label(name) - for name, type_ in op.schema.items() - ) - ).limit(0) - elif self.context.compiler.support_values_syntax_in_select: - rows = list(op.data.to_frame().itertuples(index=False)) - result = sa.values(*columns, name=op.name).data(rows).select().subquery() - else: - raw_rows = ( - sa.select( - *( - translator.translate(ops.Literal(val, dtype=type_)).label(name) - for val, (name, type_) in zip(row, op.schema.items()) - ) - ) - for row in op.data.to_frame().itertuples(index=False) - ) - result = sa.union_all(*raw_rows).alias(op.name) - return result - - -class AlchemySelect(Select): - def __init__(self, *args, **kwargs): - self.exists = kwargs.pop("exists", False) - super().__init__(*args, **kwargs) - - def compile(self): - # Can't tell if this is a hack or not. 
Revisit later - self.context.set_query(self) - - self._compile_subqueries() - - frag = self._compile_table_set() - steps = [ - self._add_select, - self._add_group_by, - self._add_where, - self._add_order_by, - self._add_limit, - ] - - for step in steps: - frag = step(frag) - - return frag - - def _compile_subqueries(self): - if not self.subqueries: - return - - for expr in self.subqueries: - result = self.context.get_compiled_expr(expr) - alias = self.context.get_ref(expr) - result = result.cte(alias) - self.context.set_ref(expr, result) - - def _compile_table_set(self): - if self.table_set is None: - return None - - return self.table_set_formatter_class(self, self.table_set).get_result() - - def _add_select(self, table_set): - if not self.select_set: - return table_set.element - - to_select = [] - - context = self.context - select_set = self.select_set - - has_select_star = False - for op in select_set: - if isinstance(op, ops.Value): - arg = self._translate(op, named=True) - elif isinstance(op, ops.TableNode): - arg = context.get_ref(op) - if op.equals(self.table_set): - if has_select_star := arg is None: - continue - else: - arg = table_set - elif arg is None: - raise ValueError(op) - else: - raise TypeError(op) - - to_select.append(arg) - - if has_select_star: - if table_set is None: - raise ValueError("table_set cannot be None here") - - clauses = [table_set] + to_select - else: - clauses = to_select - - result_func = sa.exists if self.exists else sa.select - result = result_func(*clauses) - - if self.distinct: - result = result.distinct() - - # only process unnest if the backend doesn't support SELECT UNNEST(...) 
- unnest_children = [] - if not self.translator_class.supports_unnest_in_select: - unnest_children.extend( - map( - context.get_ref, - toolz.unique(an.find_toplevel_unnest_children(select_set)), - ) - ) - - # if we're SELECT *-ing or there's no table_set (e.g., SELECT 1) *and* - # there are no unnest operations then we can return early - if (has_select_star or table_set is None) and not unnest_children: - return result - - if unnest_children: - # get all the unnests plus the current FROM clauses of the result - # selection and build up the cross join - table_set = functools.reduce( - functools.partial(sa.sql.FromClause.join, onclause=sa.true()), - toolz.unique(toolz.concatv(unnest_children, result.get_final_froms())), - ) - - return result.select_from(table_set) - - def _add_group_by(self, fragment): - # GROUP BY and HAVING - nkeys = len(self.group_by) - if not nkeys: - return fragment - - if self.context.compiler.supports_indexed_grouping_keys: - group_keys = map(sa.literal_column, map(str, range(1, nkeys + 1))) - else: - group_keys = map(self._translate, self.group_by) - - fragment = fragment.group_by(*group_keys) - - if self.having: - having_args = [self._translate(arg) for arg in self.having] - having_clause = functools.reduce(sql.and_, having_args) - fragment = fragment.having(having_clause) - - return fragment - - def _add_where(self, fragment): - if not self.where: - return fragment - - args = [ - self._translate(pred, permit_subquery=True, within_where=True) - for pred in self.where - ] - clause = functools.reduce(sql.and_, args) - return fragment.where(clause) - - def _add_order_by(self, fragment): - if not self.order_by: - return fragment - - clauses = [] - for key in self.order_by: - sort_expr = key.expr - arg = self._translate(sort_expr) - fn = sa.asc if key.ascending else sa.desc - - clauses.append(fn(arg)) - - return fragment.order_by(*clauses) - - def _among_select_set(self, expr): - return any(expr.equals(other) for other in self.select_set) - - def 
_add_limit(self, fragment): - if self.limit is None: - return fragment - - frag = fragment - - n = self.limit.n - - if n is None: - n = self.context.compiler.null_limit - elif not isinstance(n, int): - n = ( - sa.select(self._translate(n)) - .select_from(frag.subquery()) - .scalar_subquery() - ) - - if n is not None: - try: - fragment = fragment.limit(n) - except AttributeError: - fragment = fragment.subquery().select().limit(n) - - offset = self.limit.offset - - if not isinstance(offset, int): - offset = ( - sa.select(self._translate(offset)) - .select_from(frag.subquery()) - .scalar_subquery() - ) - - if offset != 0 and n != 0: - fragment = fragment.offset(offset) - return fragment - - -class AlchemySelectBuilder(SelectBuilder): - def _convert_group_by(self, exprs): - return exprs - - def _collect_SQLQueryResult(self, op, toplevel=False): - if toplevel: - self.table_set = op - self.select_set = [] - - -class AlchemySetOp(SetOp): - def compile(self): - context = self.context - distincts = self.distincts - - assert ( - len(set(distincts)) == 1 - ), "more than one distinct found; this shouldn't be possible because all unions are projected" - - func = self.distinct_func if distincts[0] else self.non_distinct_func - return func( - *(context.get_compiled_expr(table).cte().select() for table in self.tables) - ) - - -class AlchemyUnion(AlchemySetOp): - distinct_func = staticmethod(sa.union) - non_distinct_func = staticmethod(sa.union_all) - - -class AlchemyIntersection(AlchemySetOp): - distinct_func = staticmethod(sa.intersect) - non_distinct_func = staticmethod(sa.intersect_all) - - -class AlchemyDifference(AlchemySetOp): - distinct_func = staticmethod(sa.except_) - non_distinct_func = staticmethod(sa.except_all) - - -class AlchemyCompiler(Compiler): - translator_class = AlchemyExprTranslator - context_class = AlchemyContext - table_set_formatter_class = _AlchemyTableSetFormatter - select_builder_class = AlchemySelectBuilder - select_class = AlchemySelect - union_class 
= AlchemyUnion - intersect_class = AlchemyIntersection - difference_class = AlchemyDifference - - supports_indexed_grouping_keys = True - - # Value to use when the user specified `n` from the `limit` API is - # `None`. - # - # For some backends this is: - # * the identifier ALL (sa.literal_column('ALL')) - # * a NULL literal (sa.null()) - # - # and some don't accept an unbounded limit at all: the `LIMIT` - # keyword must simply be left out of the query - null_limit = sa.null() - - @classmethod - def to_sql(cls, expr, context=None, params=None, exists=False): - if context is None: - context = cls.make_context(params=params) - query = cls.to_ast(expr, context).queries[0] - if exists: - query.exists = True - return query.compile() diff --git a/ibis/backends/base/sql/alchemy/registry.py b/ibis/backends/base/sql/alchemy/registry.py deleted file mode 100644 index baa5d5fe287e8..0000000000000 --- a/ibis/backends/base/sql/alchemy/registry.py +++ /dev/null @@ -1,813 +0,0 @@ -from __future__ import annotations - -import contextlib -import functools -import operator -from typing import Any - -import sqlalchemy as sa -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql.elements import RANGE_CURRENT, RANGE_UNBOUNDED -from sqlalchemy.sql.functions import FunctionElement, GenericFunction - -import ibis.common.exceptions as com -import ibis.expr.analysis as an -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops -import ibis.expr.types as ir - - -class substr(GenericFunction): - """A generic substr function, so dialects can customize compilation.""" - - type = sa.types.String() - inherit_cache = True - - -class try_cast(GenericFunction): - pass - - -def variance_reduction(func_name, suffix=None): - suffix = suffix or {"sample": "_samp", "pop": "_pop"} - - def variance_compiler(t, op): - arg = op.arg - - if arg.dtype.is_boolean(): - arg = ops.Cast(op.arg, to=dt.int32) - - func = getattr(sa.func, f"{func_name}{suffix[op.how]}") - - if op.where is not 
None: - arg = ops.IfElse(op.where, arg, None) - - return func(t.translate(arg)) - - return variance_compiler - - -def fixed_arity(sa_func, arity): - def formatter(t, op): - arg_count = len(op.args) - if arity != arg_count: - raise com.IbisError( - f"Incorrect number of args. Expected: {arity}. Current: {arg_count}" - ) - - return _varargs_call(sa_func, t, op.args) - - return formatter - - -def _varargs_call(sa_func, t, args): - trans_args = [] - for raw_arg in args: - arg = t.translate(raw_arg) - with contextlib.suppress(AttributeError): - arg = arg.scalar_subquery() - trans_args.append(arg) - return sa_func(*trans_args) - - -def varargs(sa_func): - def formatter(t, op): - return _varargs_call(sa_func, t, op.arg) - - return formatter - - -def get_sqla_table(ctx, table): - if ctx.has_ref(table, parent_contexts=True): - sa_table = ctx.get_ref(table, search_parents=True) - else: - sa_table = ctx.get_compiled_expr(table) - - return sa_table - - -def get_col(sa_table, op: ops.TableColumn) -> sa.sql.ColumnClause: - """Extract a column from a table.""" - cols = sa_table.exported_columns - colname = op.name - - if (col := cols.get(colname)) is not None: - return col - - # `cols` is a SQLAlchemy column collection that contains columns - # with names that are secretly prefixed by table that contains them - # - # for example, in `t0.join(t1).select(t0.a, t1.b)` t0.a will be named `t0_a` - # and t1.b will be named `t1_b` - # - # unfortunately SQLAlchemy doesn't let you select by the *un*prefixed - # column name despite the uniqueness of `colname` - # - # however, in ibis we have already deduplicated column names so we can - # refer to the name by position - colindex = op.table.schema._name_locs[colname] - return cols[colindex] - - -def _table_column(t, op): - ctx = t.context - table = op.table - - sa_table = get_sqla_table(ctx, table) - - out_expr = get_col(sa_table, op) - out_expr.quote = t._quote_column_names - - # If the column does not originate from the table set in the 
current SELECT - # context, we should format as a subquery - if t.permit_subquery and ctx.is_foreign_expr(table): - try: - subq = sa_table.subquery() - except AttributeError: - subq = sa_table - return sa.select(subq.c[out_expr.name]) - - return out_expr - - -def _table_array_view(t, op): - # the table that the TableArrayView op contains (op.table) has - # one or more input relations that we need to "pin" for sqlalchemy's - # auto correlation functionality -- this is what `.correlate_except` does - # - # every relation that is NOT passed to `correlate_except` is considered an - # outer-query table - ctx = t.context - table = ctx.get_compiled_expr(op.table) - # TODO: handle the case of `op.table` being a join - first, *_ = an.find_immediate_parent_tables(op.table, keep_input=False) - ref = ctx.get_ref(first) - return table.correlate_except(ref) - - -def _exists_subquery(t, op): - ctx = t.context - - # TODO(kszucs): avoid converting the predicates to expressions - # this should be done by the rewrite step before compilation - filtered = ( - op.foreign_table.to_expr() - .filter([pred.to_expr() for pred in op.predicates]) - .select(ir.literal(1).name("")) - ) - - sub_ctx = ctx.subcontext() - clause = ctx.compiler.to_sql(filtered, sub_ctx, exists=True) - - return clause - - -def _cast(t, op): - arg = op.arg - typ = op.to - arg_dtype = arg.dtype - - sa_arg = t.translate(arg) - - # specialize going from an integer type to a timestamp - if arg_dtype.is_integer() and typ.is_timestamp(): - return t.integer_to_timestamp(sa_arg, tz=typ.timezone) - - if arg_dtype.is_binary() and typ.is_string(): - return sa.func.encode(sa_arg, "escape") - - if typ.is_binary(): - # decode yields a column of memoryview which is annoying to deal with - # in pandas. CAST(expr AS BYTEA) is correct and returns byte strings. 
- return sa.cast(sa_arg, sa.LargeBinary()) - - if typ.is_json() and not t.native_json_type: - return sa_arg - - return sa.cast(sa_arg, t.get_sqla_type(typ)) - - -def _contains(func): - def translate(t, op): - left = t.translate(op.value) - - options = op.options - if isinstance(options, tuple): - right = [t.translate(x) for x in op.options] - elif options.shape.is_columnar(): - right = t.translate(ops.TableArrayView(options.to_expr().as_table())) - if not isinstance(right, sa.sql.Selectable): - right = sa.select(right) - else: - right = t.translate(options) - - return func(left, right) - - return translate - - -def _in_values(t, op): - if not op.options: - return sa.literal(False) - value = t.translate(op.value) - options = [t.translate(x) for x in op.options] - return value.in_(options) - - -def _in_column(t, op): - value = t.translate(op.value) - options = t.translate(ops.TableArrayView(op.options.to_expr().as_table())) - if not isinstance(options, sa.sql.Selectable): - options = sa.select(options) - return value.in_(options) - - -def _alias(t, op): - # just compile the underlying argument because the naming is handled - # by the translator for the top level expression - return t.translate(op.arg) - - -def _literal(_, op): - dtype = op.dtype - value = op.value - - if value is None: - return sa.null() - - if dtype.is_array(): - value = list(value) - elif dtype.is_decimal(): - value = value.normalize() - - return sa.literal(value) - - -def _is_null(t, op): - arg = t.translate(op.arg) - return arg.is_(sa.null()) - - -def _not_null(t, op): - arg = t.translate(op.arg) - return arg.is_not(sa.null()) - - -def _round(t, op): - sa_arg = t.translate(op.arg) - - f = sa.func.round - - if op.digits is not None: - sa_digits = t.translate(op.digits) - return f(sa_arg, sa_digits) - else: - return f(sa_arg) - - -def _floor_divide(t, op): - left = t.translate(op.left) - right = t.translate(op.right) - return sa.func.floor(left / right) - - -def _simple_case(t, op): - return 
_translate_case(t, op, value=t.translate(op.base)) - - -def _searched_case(t, op): - return _translate_case(t, op, value=None) - - -def _translate_case(t, op, *, value): - return sa.case( - *zip(map(t.translate, op.cases), map(t.translate, op.results)), - value=value, - else_=t.translate(op.default), - ) - - -def _negate(t, op): - arg = t.translate(op.arg) - return sa.not_(arg) if op.arg.dtype.is_boolean() else -arg - - -def unary(sa_func): - return fixed_arity(sa_func, 1) - - -def _string_like(method_name, t, op): - method = getattr(t.translate(op.arg), method_name) - return method(t.translate(op.pattern), escape=op.escape) - - -def _startswith(t, op): - return t.translate(op.arg).startswith(t.translate(op.start)) - - -def _endswith(t, op): - return t.translate(op.arg).endswith(t.translate(op.end)) - - -def _reinterpret_range_bound(bound): - if bound is None: - return RANGE_UNBOUNDED - - try: - lower = int(bound) - except ValueError as err: - sa.util.raise_( - sa.exc.ArgumentError( - "Integer, None or expression expected for range value" - ), - replace_context=err, - ) - except TypeError: - return bound - else: - return RANGE_CURRENT if lower == 0 else lower - - -def _interpret_range(self, range_): - if not isinstance(range_, tuple) or len(range_) != 2: - raise sa.exc.ArgumentError("2-tuple expected for range/rows") - - lower = _reinterpret_range_bound(range_[0]) - upper = _reinterpret_range_bound(range_[1]) - return lower, upper - - -# monkeypatch to allow expressions in range and rows bounds -sa.sql.elements.Over._interpret_range = _interpret_range - - -def _compile_bounds(processor, left, right) -> str: - if left is RANGE_UNBOUNDED: - left = "UNBOUNDED PRECEDING" - elif left is RANGE_CURRENT: - left = "CURRENT ROW" - else: - left = f"{processor(left)} PRECEDING" - - if right is RANGE_UNBOUNDED: - right = "UNBOUNDED FOLLOWING" - elif right is RANGE_CURRENT: - right = "CURRENT ROW" - else: - right = f"{processor(right)} FOLLOWING" - - return f"BETWEEN {left} AND 
{right}" - - -@compiles(sa.sql.elements.Over) -def compile_over(over, compiler, **kw) -> str: - processor = functools.partial(compiler.process, **kw) - - text = processor(over.element) - - if over.range_: - bounds = _compile_bounds(processor, *over.range_) - range_ = f"RANGE {bounds}" - elif over.rows: - bounds = _compile_bounds(processor, *over.rows) - range_ = f"ROWS {bounds}" - else: - range_ = None - - args = [ - f"{word} BY {processor(clause)}" - for word, clause in ( - ("PARTITION", over.partition_by), - ("ORDER", over.order_by), - ) - if clause is not None and len(clause) - ] - - if range_ is not None: - args.append(range_) - - return f"{text} OVER ({' '.join(args)})" - - -def _window_function(t, window): - func = window.func.__window_op__ - - reduction = t.translate(func) - - # Some analytic functions need to have the expression of interest in - # the ORDER BY part of the window clause - if isinstance(func, t._require_order_by) and not window.frame.order_by: - order_by = t.translate(func.args[0]) - else: - order_by = [t.translate(arg) for arg in window.frame.order_by] - - partition_by = [t.translate(arg) for arg in window.frame.group_by] - - if isinstance(window.frame, ops.RowsWindowFrame): - if window.frame.max_lookback is not None: - raise NotImplementedError( - "Rows with max lookback is not implemented for SQLAlchemy-based " - "backends." 
- ) - how = "rows" - elif isinstance(window.frame, ops.RangeWindowFrame): - how = "range_" - else: - raise NotImplementedError(type(window.frame)) - - additional_params = {} - - # some functions on some backends don't support frame clauses - if not t._forbids_frame_clause or not isinstance(func, t._forbids_frame_clause): - if (start := window.frame.start) is not None: - start = t.translate(start.value) - - if (end := window.frame.end) is not None: - end = t.translate(end.value) - - additional_params[how] = (start, end) - - result = sa.over( - reduction, partition_by=partition_by, order_by=order_by, **additional_params - ) - - if isinstance(func, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)): - result -= 1 - - return result - - -def _lag(t, op): - if op.default is not None: - raise NotImplementedError() - - sa_arg = t.translate(op.arg) - sa_offset = t.translate(op.offset) if op.offset is not None else 1 - return sa.func.lag(sa_arg, sa_offset) - - -def _lead(t, op): - if op.default is not None: - raise NotImplementedError() - sa_arg = t.translate(op.arg) - sa_offset = t.translate(op.offset) if op.offset is not None else 1 - return sa.func.lead(sa_arg, sa_offset) - - -def _ntile(t, op): - return sa.func.ntile(t.translate(op.buckets)) - - -def _sort_key(t, op): - func = sa.asc if op.ascending else sa.desc - return func(t.translate(op.expr)) - - -def _string_join(t, op): - return sa.func.concat_ws(t.translate(op.sep), *map(t.translate, op.arg)) - - -def reduction(sa_func): - def compile_expr(t, expr): - return t._reduction(sa_func, expr) - - return compile_expr - - -def _substring(t, op): - sa_arg = t.translate(op.arg) - sa_start = t.translate(op.start) + 1 - # Start is an expression, need a runtime branch - sa_arg_length = t.translate(ops.StringLength(op.arg)) - if op.length is None: - return sa.case( - ((sa_start >= 1), sa.func.substr(sa_arg, sa_start)), - else_=sa.func.substr(sa_arg, sa_start + sa_arg_length), - ) - else: - sa_length = 
t.translate(op.length) - return sa.case( - ((sa_start >= 1), sa.func.substr(sa_arg, sa_start, sa_length)), - else_=sa.func.substr(sa_arg, sa_start + sa_arg_length, sa_length), - ) - - -def _gen_string_find(func): - def string_find(t, op): - if op.end is not None: - raise NotImplementedError("`end` not yet implemented") - - arg = t.translate(op.arg) - sub_string = t.translate(op.substr) - - if (op_start := op.start) is not None: - start = t.translate(op_start) - arg = sa.func.substr(arg, start + 1) - pos = func(arg, sub_string) - return sa.case((pos > 0, pos - 1 + start), else_=-1) - - return func(arg, sub_string) - 1 - - return string_find - - -def _nth_value(t, op): - return sa.func.nth_value(t.translate(op.arg), t.translate(op.nth) + 1) - - -def _bitwise_op(operator): - def translate(t, op): - left = t.translate(op.left) - right = t.translate(op.right) - return left.op(operator)(right) - - return translate - - -def _bitwise_not(t, op): - arg = t.translate(op.arg) - return sa.sql.elements.UnaryExpression( - arg, - operator=sa.sql.operators.custom_op("~"), - ) - - -def _count_star(t, op): - if (where := op.where) is None: - return sa.func.count() - - if t._has_reduction_filter_syntax: - return sa.func.count().filter(t.translate(where)) - - return sa.func.count(t.translate(ops.IfElse(where, 1, None))) - - -def _count_distinct_star(t, op): - schema = op.arg.schema - cols = [sa.column(col, t.get_sqla_type(typ)) for col, typ in schema.items()] - - if t._supports_tuple_syntax: - func = lambda *cols: sa.func.count(sa.distinct(sa.tuple_(*cols))) - else: - func = count_distinct - - if op.where is None: - return func(*cols) - - if t._has_reduction_filter_syntax: - return func(*cols).filter(t.translate(op.where)) - - if not t._supports_tuple_syntax and len(cols) > 1: - raise com.UnsupportedOperationError( - f"{t._dialect_name} backend doesn't support `COUNT(DISTINCT ...)` with a " - "filter with more than one column" - ) - - return 
sa.func.count(t.translate(ops.IfElse(op.where, sa.distinct(*cols), None))) - - -def _extract(fmt: str): - def translator(t, op: ops.Node): - return sa.cast(sa.extract(fmt, t.translate(op.arg)), sa.SMALLINT) - - return translator - - -class count_distinct(FunctionElement): - inherit_cache = True - - -@compiles(count_distinct) -def compile_count_distinct(element, compiler, **kw): - quote_identifier = compiler.preparer.quote_identifier - clauses = ", ".join( - quote_identifier(compiler.process(clause, **kw)) for clause in element.clauses - ) - return f"COUNT(DISTINCT {clauses})" - - -class array_map(FunctionElement): - pass - - -class array_filter(FunctionElement): - pass - - -sqlalchemy_operation_registry: dict[Any, Any] = { - ops.Alias: _alias, - ops.And: fixed_arity(operator.and_, 2), - ops.Or: fixed_arity(operator.or_, 2), - ops.Xor: fixed_arity(lambda x, y: (x | y) & ~(x & y), 2), - ops.Not: unary(sa.not_), - ops.Abs: unary(sa.func.abs), - ops.Cast: _cast, - ops.Coalesce: varargs(sa.func.coalesce), - ops.NullIf: fixed_arity(sa.func.nullif, 2), - ops.InValues: _in_values, - ops.InSubquery: _in_column, - ops.Count: reduction(sa.func.count), - ops.CountStar: _count_star, - ops.CountDistinctStar: _count_distinct_star, - ops.Sum: reduction(sa.func.sum), - ops.Mean: reduction(sa.func.avg), - ops.Min: reduction(sa.func.min), - ops.Max: reduction(sa.func.max), - ops.Variance: variance_reduction("var"), - ops.StandardDev: variance_reduction("stddev"), - ops.BitAnd: reduction(sa.func.bit_and), - ops.BitOr: reduction(sa.func.bit_or), - ops.BitXor: reduction(sa.func.bit_xor), - ops.CountDistinct: reduction(lambda arg: sa.func.count(arg.distinct())), - ops.ApproxCountDistinct: reduction(lambda arg: sa.func.count(arg.distinct())), - ops.GroupConcat: reduction(sa.func.group_concat), - ops.Between: fixed_arity(sa.between, 3), - ops.IsNull: _is_null, - ops.NotNull: _not_null, - ops.Negate: _negate, - ops.Round: _round, - ops.Literal: _literal, - ops.SimpleCase: _simple_case, - 
ops.SearchedCase: _searched_case, - ops.Field: _table_column, - ops.ExistsSubquery: _exists_subquery, - # miscellaneous varargs - ops.Least: varargs(sa.func.least), - ops.Greatest: varargs(sa.func.greatest), - # string - ops.Capitalize: unary( - lambda arg: sa.func.concat( - sa.func.upper(sa.func.substr(arg, 1, 1)), - sa.func.lower(sa.func.substr(arg, 2)), - ) - ), - ops.LPad: fixed_arity(sa.func.lpad, 3), - ops.RPad: fixed_arity(sa.func.rpad, 3), - ops.Strip: unary(sa.func.trim), - ops.LStrip: unary(sa.func.ltrim), - ops.RStrip: unary(sa.func.rtrim), - ops.Repeat: fixed_arity(sa.func.repeat, 2), - ops.Reverse: unary(sa.func.reverse), - ops.StrRight: fixed_arity(sa.func.right, 2), - ops.Lowercase: unary(sa.func.lower), - ops.Uppercase: unary(sa.func.upper), - ops.StringAscii: unary(sa.func.ascii), - ops.StringFind: _gen_string_find(sa.func.strpos), - ops.StringLength: unary(sa.func.length), - ops.StringJoin: _string_join, - ops.StringReplace: fixed_arity(sa.func.replace, 3), - ops.StringSQLLike: functools.partial(_string_like, "like"), - ops.StringSQLILike: functools.partial(_string_like, "ilike"), - ops.StartsWith: _startswith, - ops.EndsWith: _endswith, - ops.StringConcat: varargs(sa.func.concat), - ops.Substring: _substring, - # math - ops.Ln: unary(sa.func.ln), - ops.Exp: unary(sa.func.exp), - ops.Sign: unary(sa.func.sign), - ops.Sqrt: unary(sa.func.sqrt), - ops.Ceil: unary(sa.func.ceil), - ops.Floor: unary(sa.func.floor), - ops.Power: fixed_arity(sa.func.pow, 2), - ops.FloorDivide: _floor_divide, - ops.Acos: unary(sa.func.acos), - ops.Asin: unary(sa.func.asin), - ops.Atan: unary(sa.func.atan), - ops.Atan2: fixed_arity(sa.func.atan2, 2), - ops.Cos: unary(sa.func.cos), - ops.Sin: unary(sa.func.sin), - ops.Tan: unary(sa.func.tan), - ops.Cot: unary(sa.func.cot), - ops.Pi: fixed_arity(sa.func.pi, 0), - ops.E: fixed_arity(lambda: sa.func.exp(1), 0), - # other - ops.SortKey: _sort_key, - ops.Date: unary(lambda arg: sa.cast(arg, sa.DATE)), - ops.DateFromYMD: 
fixed_arity(sa.func.date, 3), - ops.TimeFromHMS: fixed_arity(sa.func.time, 3), - ops.TimestampFromYMDHMS: lambda t, op: sa.func.make_timestamp( - *map(t.translate, op.args) - ), - ops.Degrees: unary(sa.func.degrees), - ops.Radians: unary(sa.func.radians), - ops.RandomScalar: fixed_arity(sa.func.random, 0), - # Binary arithmetic - ops.Add: fixed_arity(operator.add, 2), - ops.Subtract: fixed_arity(operator.sub, 2), - ops.Multiply: fixed_arity(operator.mul, 2), - # XXX `ops.Divide` is overwritten in `translator.py` with a custom - # function `_true_divide`, but for some reason both are required - ops.Divide: fixed_arity(operator.truediv, 2), - ops.Modulus: fixed_arity(operator.mod, 2), - # Comparisons - ops.Equals: fixed_arity(operator.eq, 2), - ops.NotEquals: fixed_arity(operator.ne, 2), - ops.Less: fixed_arity(operator.lt, 2), - ops.LessEqual: fixed_arity(operator.le, 2), - ops.Greater: fixed_arity(operator.gt, 2), - ops.GreaterEqual: fixed_arity(operator.ge, 2), - ops.IdenticalTo: fixed_arity( - sa.sql.expression.ColumnElement.is_not_distinct_from, 2 - ), - ops.IfElse: fixed_arity( - lambda predicate, value_if_true, value_if_false: sa.case( - (predicate, value_if_true), - else_=value_if_false, - ), - 3, - ), - ops.BitwiseAnd: _bitwise_op("&"), - ops.BitwiseOr: _bitwise_op("|"), - ops.BitwiseXor: _bitwise_op("^"), - ops.BitwiseLeftShift: _bitwise_op("<<"), - ops.BitwiseRightShift: _bitwise_op(">>"), - ops.BitwiseNot: _bitwise_not, - ops.JSONGetItem: fixed_arity(lambda x, y: x.op("->")(y), 2), - ops.ExtractYear: _extract("year"), - ops.ExtractQuarter: _extract("quarter"), - ops.ExtractMonth: _extract("month"), - ops.ExtractDay: _extract("day"), - ops.ExtractHour: _extract("hour"), - ops.ExtractMinute: _extract("minute"), - ops.ExtractSecond: _extract("second"), - ops.Time: fixed_arity(lambda arg: sa.cast(arg, sa.TIME), 1), -} - - -sqlalchemy_window_functions_registry = { - ops.Lag: _lag, - ops.Lead: _lead, - ops.NTile: _ntile, - ops.FirstValue: 
unary(sa.func.first_value), - ops.LastValue: unary(sa.func.last_value), - ops.RowNumber: fixed_arity(sa.func.row_number, 0), - ops.DenseRank: fixed_arity(sa.func.dense_rank, 0), - ops.MinRank: fixed_arity(sa.func.rank, 0), - ops.PercentRank: fixed_arity(sa.func.percent_rank, 0), - ops.CumeDist: fixed_arity(sa.func.cume_dist, 0), - ops.NthValue: _nth_value, - ops.WindowFunction: _window_function, -} - -geospatial_functions = { - ops.GeoArea: unary(sa.func.ST_Area), - ops.GeoAsBinary: unary(sa.func.ST_AsBinary), - ops.GeoAsEWKB: unary(sa.func.ST_AsEWKB), - ops.GeoAsEWKT: unary(sa.func.ST_AsEWKT), - ops.GeoAsText: unary(sa.func.ST_AsText), - ops.GeoAzimuth: fixed_arity(sa.func.ST_Azimuth, 2), - ops.GeoBuffer: fixed_arity(sa.func.ST_Buffer, 2), - ops.GeoCentroid: unary(sa.func.ST_Centroid), - ops.GeoContains: fixed_arity(sa.func.ST_Contains, 2), - ops.GeoContainsProperly: fixed_arity(sa.func.ST_Contains, 2), - ops.GeoCovers: fixed_arity(sa.func.ST_Covers, 2), - ops.GeoCoveredBy: fixed_arity(sa.func.ST_CoveredBy, 2), - ops.GeoCrosses: fixed_arity(sa.func.ST_Crosses, 2), - ops.GeoDFullyWithin: fixed_arity(sa.func.ST_DFullyWithin, 3), - ops.GeoDifference: fixed_arity(sa.func.ST_Difference, 2), - ops.GeoDisjoint: fixed_arity(sa.func.ST_Disjoint, 2), - ops.GeoDistance: fixed_arity(sa.func.ST_Distance, 2), - ops.GeoDWithin: fixed_arity(sa.func.ST_DWithin, 3), - ops.GeoEndPoint: unary(sa.func.ST_EndPoint), - ops.GeoEnvelope: unary(sa.func.ST_Envelope), - ops.GeoEquals: fixed_arity(sa.func.ST_Equals, 2), - ops.GeoGeometryN: fixed_arity(sa.func.ST_GeometryN, 2), - ops.GeoGeometryType: unary(sa.func.ST_GeometryType), - ops.GeoIntersection: fixed_arity(sa.func.ST_Intersection, 2), - ops.GeoIntersects: fixed_arity(sa.func.ST_Intersects, 2), - ops.GeoIsValid: unary(sa.func.ST_IsValid), - ops.GeoLineLocatePoint: fixed_arity(sa.func.ST_LineLocatePoint, 2), - ops.GeoLineMerge: unary(sa.func.ST_LineMerge), - ops.GeoLineSubstring: fixed_arity(sa.func.ST_LineSubstring, 3), - 
ops.GeoLength: unary(sa.func.ST_Length), - ops.GeoNPoints: unary(sa.func.ST_NPoints), - ops.GeoOrderingEquals: fixed_arity(sa.func.ST_OrderingEquals, 2), - ops.GeoOverlaps: fixed_arity(sa.func.ST_Overlaps, 2), - ops.GeoPerimeter: unary(sa.func.ST_Perimeter), - ops.GeoSimplify: fixed_arity(sa.func.ST_Simplify, 3), - ops.GeoSRID: unary(sa.func.ST_SRID), - ops.GeoSetSRID: fixed_arity(sa.func.ST_SetSRID, 2), - ops.GeoStartPoint: unary(sa.func.ST_StartPoint), - ops.GeoTouches: fixed_arity(sa.func.ST_Touches, 2), - ops.GeoTransform: fixed_arity(sa.func.ST_Transform, 2), - ops.GeoUnaryUnion: unary(sa.func.ST_Union), - ops.GeoUnion: fixed_arity(sa.func.ST_Union, 2), - ops.GeoWithin: fixed_arity(sa.func.ST_Within, 2), - ops.GeoX: unary(sa.func.ST_X), - ops.GeoY: unary(sa.func.ST_Y), - # Missing Geospatial ops: - # ST_AsGML - # ST_AsGeoJSON - # ST_AsKML - # ST_AsRaster - # ST_AsSVG - # ST_AsTWKB - # ST_Distance_Sphere - # ST_Dump - # ST_DumpPoints - # ST_GeogFromText - # ST_GeomFromEWKB - # ST_GeomFromEWKT - # ST_GeomFromText -} diff --git a/ibis/backends/base/sql/alchemy/translator.py b/ibis/backends/base/sql/alchemy/translator.py deleted file mode 100644 index 341f1995aff1d..0000000000000 --- a/ibis/backends/base/sql/alchemy/translator.py +++ /dev/null @@ -1,147 +0,0 @@ -from __future__ import annotations - -import functools -import operator - -import sqlalchemy as sa -from sqlalchemy.engine.default import DefaultDialect - -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops -from ibis.backends.base.sql.alchemy.datatypes import AlchemyType -from ibis.backends.base.sql.alchemy.registry import ( - fixed_arity, - sqlalchemy_operation_registry, -) -from ibis.backends.base.sql.compiler import ExprTranslator, QueryContext - -_DEFAULT_DIALECT = DefaultDialect() - - -class AlchemyContext(QueryContext): - def collapse(self, queries): - if isinstance(queries, str): - return queries - - if len(queries) > 1: - raise NotImplementedError( - "Only a single query is 
supported for SQLAlchemy backends" - ) - return queries[0] - - def subcontext(self): - return self.__class__( - compiler=self.compiler, - parent=self, - params=self.params, - ) - - -class AlchemyExprTranslator(ExprTranslator): - _registry = sqlalchemy_operation_registry - _rewrites = ExprTranslator._rewrites.copy() - - type_mapper = AlchemyType - context_class = AlchemyContext - - _bool_aggs_need_cast_to_int32 = True - _has_reduction_filter_syntax = False - _supports_tuple_syntax = False - _integer_to_timestamp = staticmethod(sa.func.to_timestamp) - _timestamp_type = sa.TIMESTAMP - - def integer_to_timestamp(self, arg, tz: str | None = None): - return sa.cast( - self._integer_to_timestamp(arg), - self._timestamp_type(timezone=tz is not None), - ) - - native_json_type = True - _quote_column_names = None # let the dialect decide how to quote - _quote_table_names = None - - _require_order_by = ( - ops.DenseRank, - ops.MinRank, - ops.NTile, - ops.PercentRank, - ops.CumeDist, - ) - - _dialect_name = "default" - - supports_unnest_in_select = True - - @classmethod - def get_sqla_type(cls, ibis_type): - return cls.type_mapper.from_ibis(ibis_type) - - @classmethod - def get_ibis_type(cls, sqla_type, nullable=True): - return cls.type_mapper.to_ibis(sqla_type, nullable=nullable) - - @functools.cached_property - def dialect(self) -> sa.engine.interfaces.Dialect: - if (name := self._dialect_name) == "default": - return _DEFAULT_DIALECT - dialect_cls = sa.dialects.registry.load(name) - return dialect_cls() - - def _schema_to_sqlalchemy_columns(self, schema): - return [ - sa.Column(name, self.get_sqla_type(dtype), quote=self._quote_column_names) - for name, dtype in schema.items() - ] - - def name(self, translated, name, force=False): - return translated.label( - sa.sql.quoted_name(name, quote=force or self._quote_column_names) - ) - - def _maybe_cast_bool(self, op, arg): - if ( - self._bool_aggs_need_cast_to_int32 - and isinstance(op, (ops.Sum, ops.Mean, ops.Min, ops.Max)) - and 
(dtype := arg.dtype).is_boolean() - ): - return ops.Cast(arg, dt.Int32(nullable=dtype.nullable)) - return arg - - def _reduction(self, sa_func, op): - argtuple = ( - self._maybe_cast_bool(op, arg) - for name, arg in zip(op.argnames, op.args) - if isinstance(arg, ops.Node) and name != "where" - ) - if (where := op.where) is not None: - if self._has_reduction_filter_syntax: - sa_args = tuple(map(self.translate, argtuple)) - return sa_func(*sa_args).filter(self.translate(where)) - else: - sa_args = tuple( - self.translate(ops.IfElse(where, arg, None)) for arg in argtuple - ) - else: - sa_args = tuple(map(self.translate, argtuple)) - - return sa_func(*sa_args) - - -rewrites = AlchemyExprTranslator.rewrites - - -# TODO This was previously implemented with the legacy `@compiles` decorator. -# This definition should now be in the registry, but there is some magic going -# on that things fail if it's not defined here (and in the registry -# `operator.truediv` is used. -def _true_divide(t, op): - if all(arg.dtype.is_integer() for arg in op.args): - # TODO(kszucs): this should be done in the rewrite phase - right, left = op.right.to_expr(), op.left.to_expr() - new_expr = left.div(right.cast(dt.double)) - return t.translate(new_expr.op()) - - return fixed_arity(operator.truediv, 2)(t, op) - - -AlchemyExprTranslator._registry[ops.Divide] = _true_divide diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index e919d07dc0d17..af55f4a070310 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -1,14 +1,18 @@ from __future__ import annotations import abc +from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar +from urllib.parse import parse_qs, urlparse import sqlglot as sg import sqlglot.expressions as sge import ibis +import ibis.common.exceptions as exc import ibis.expr.operations as ops import ibis.expr.schema as sch +import ibis.expr.types as ir from ibis import util from 
ibis.backends.base import BaseBackend from ibis.backends.base.sqlglot.compiler import STAR @@ -20,11 +24,47 @@ import pyarrow as pa import ibis.expr.datatypes as dt - import ibis.expr.types as ir from ibis.backends.base.sqlglot.compiler import SQLGlotCompiler from ibis.common.typing import SupportsSchema +class UrlFromPath: + __slots__ = () + + def _from_url(self, url: str, **kwargs) -> BaseBackend: + """Connect to a backend using a URL `url`. + + Parameters + ---------- + url + URL with which to connect to a backend. + kwargs + Additional keyword arguments + + Returns + ------- + BaseBackend + A backend instance + + """ + url = urlparse(url) + netloc = url.netloc + parts = list(filter(None, (netloc, url.path[bool(netloc) :]))) + database = Path(*parts).absolute() if parts else ":memory:" + query_params = parse_qs(url.query) + + for name, value in query_params.items(): + if len(value) > 1: + kwargs[name] = value + elif len(value) == 1: + kwargs[name] = value[0] + else: + raise exc.IbisError(f"Invalid URL parameter: {name}") + + self._convert_kwargs(kwargs) + return self.connect(database=database, **kwargs) + + class SQLGlotBackend(BaseBackend): compiler: ClassVar[SQLGlotCompiler] name: ClassVar[str] @@ -327,3 +367,71 @@ def to_pyarrow_batches( batches = map(pa.RecordBatch.from_struct_array, arrays) return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), batches) + + def insert( + self, + table_name: str, + obj: pd.DataFrame | ir.Table | list | dict, + schema: str | None = None, + database: str | None = None, + overwrite: bool = False, + ) -> None: + """Insert data into a table. + + Parameters + ---------- + table_name + The name of the table to which data needs will be inserted + obj + The source data or expression to insert + schema + The name of the schema that the table is located in + database + Name of the attached database that the table is located in. 
+ overwrite + If `True` then replace existing contents of table + + """ + if overwrite: + self.truncate_table(table_name, schema=schema, database=database) + + if not isinstance(obj, ir.Table): + obj = ibis.memtable(obj) + + self._run_pre_execute_hooks(obj) + + compiler = self.compiler + quoted = compiler.quoted + query = sge.insert( + expression=self.compile(obj), + into=sg.table(table_name, db=schema, catalog=database, quoted=quoted), + columns=[ + sg.to_identifier(col, quoted=quoted) + for col in self.get_schema(table_name).names + ], + dialect=compiler.dialect, + ) + + with self._safe_raw_sql(query): + pass + + def truncate_table( + self, name: str, database: str | None = None, schema: str | None = None + ) -> None: + """Delete all rows from a table. + + Parameters + ---------- + name + Table name + database + Database name + schema + Schema name + + """ + ident = sg.table( + name, db=schema, catalog=database, quoted=self.compiler.quoted + ).sql(self.compiler.dialect) + with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"): + pass diff --git a/ibis/backends/base/sqlglot/datatypes.py b/ibis/backends/base/sqlglot/datatypes.py index 8ac2dbffba278..bb9e6447d093d 100644 --- a/ibis/backends/base/sqlglot/datatypes.py +++ b/ibis/backends/base/sqlglot/datatypes.py @@ -892,7 +892,7 @@ def _from_sqlglot_STRUCT(cls, *cols: sge.ColumnDef) -> NoReturn: class MSSQLType(SqlglotType): - dialect = "tsql" + dialect = "mssql" @classmethod def _from_sqlglot_BIT(cls): diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py index 4c2be367bfc7a..fa68946f1439c 100644 --- a/ibis/backends/bigquery/__init__.py +++ b/ibis/backends/bigquery/__init__.py @@ -3,10 +3,10 @@ from __future__ import annotations import concurrent.futures +import contextlib import glob import os import re -from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Optional from urllib.parse import parse_qs, urlparse @@ -639,11 +639,12 @@ def _to_sqlglot( 
self._define_udf_translation_rules(expr) sql = super()._to_sqlglot(expr, limit=limit, params=params, **kwargs) - return sql.transform( + query = sql.transform( _qualify_memtable, dataset=getattr(self._session_dataset, "dataset_id", None), project=getattr(self._session_dataset, "project", None), ).transform(_remove_null_ordering_from_unsupported_window) + return query def raw_sql(self, query: str, params=None): query_parameters = [ @@ -658,6 +659,8 @@ def raw_sql(self, query: str, params=None): ) for param, value in (params or {}).items() ] + with contextlib.suppress(AttributeError): + query = query.sql(self.compiler.dialect) return self._execute(query, query_parameters=query_parameters) @property @@ -722,7 +725,6 @@ def execute(self, expr, params=None, limit="default", **kwargs): # TODO: upstream needs to pass params to raw_sql, I think. kwargs.pop("timecontext", None) - self._register_in_memory_tables(expr) sql = self.compile(expr, limit=limit, params=params, **kwargs) self._log(sql) cursor = self.raw_sql(sql, params=params, **kwargs) @@ -731,6 +733,37 @@ def execute(self, expr, params=None, limit="default", **kwargs): return expr.__pandas_result__(result) + def insert( + self, + table_name: str, + obj: pd.DataFrame | ir.Table | list | dict, + schema: str | None = None, + database: str | None = None, + overwrite: bool = False, + ): + """Insert data into a table. + + Parameters + ---------- + table_name + The name of the table to which data needs will be inserted + obj + The source data or expression to insert + schema + The name of the schema that the table is located in + database + Name of the attached database that the table is located in. 
+ overwrite + If `True` then replace existing contents of table + """ + return super().insert( + table_name, + obj, + schema=schema if schema is not None else self.current_schema, + database=database if database is not None else self.current_database, + overwrite=overwrite, + ) + def fetch_from_cursor(self, cursor, schema): from ibis.backends.bigquery.converter import BigQueryPandasData @@ -1173,7 +1206,7 @@ def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: def _register_udfs(self, expr: ir.Expr) -> None: """No op because UDFs made with CREATE TEMPORARY FUNCTION must be followed by a query.""" - @contextmanager + @contextlib.contextmanager def _safe_raw_sql(self, *args, **kwargs): yield self.raw_sql(*args, **kwargs) diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index 4913fc2e3f602..8142110397e34 100644 --- a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -389,23 +389,31 @@ def execute( def insert( self, name: str, - obj: pd.DataFrame, + obj: pd.DataFrame | ir.Table, settings: Mapping[str, Any] | None = None, + overwrite: bool = False, **kwargs: Any, ): import pandas as pd import pyarrow as pa - if not isinstance(obj, pd.DataFrame): - raise com.IbisError( - f"Invalid input type {type(obj)}; only pandas DataFrames are accepted as input" + if overwrite: + self.truncate_table(name) + + if isinstance(obj, pa.Table): + return self.con.insert_arrow(name, obj, settings=settings, **kwargs) + elif isinstance(obj, pd.DataFrame): + return self.con.insert_arrow( + name, pa.Table.from_pandas(obj), settings=settings, **kwargs ) + elif not isinstance(obj, ir.Table): + obj = ibis.memtable(obj) - # TODO(cpcloud): add support for arrow tables - # TODO(cpcloud): insert_df doesn't work with pandas 2.1.0, move back to - # that (maybe?) 
when `clickhouse_connect` is fixed - t = pa.Table.from_pandas(obj) - return self.con.insert_arrow(name, t, settings=settings, **kwargs) + query = sge.insert(self.compile(obj), into=name, dialect=self.name) + + external_tables = self._collect_in_memory_tables(obj, {}) + external_data = self._normalize_external_tables(external_tables) + return self.con.command(query.sql(self.name), external_data=external_data) def raw_sql( self, diff --git a/ibis/backends/conftest.py b/ibis/backends/conftest.py index 5fe94de294181..e48fdc5cb5ac4 100644 --- a/ibis/backends/conftest.py +++ b/ibis/backends/conftest.py @@ -13,7 +13,6 @@ import numpy as np import pandas as pd import pytest -import sqlalchemy as sa from packaging.requirements import Requirement from packaging.version import parse as vparse @@ -25,8 +24,6 @@ from ibis.util import promote_tuple if TYPE_CHECKING: - from collections.abc import Iterable - from ibis.backends.tests.base import BackendTest @@ -181,80 +178,6 @@ def data_dir() -> Path: return root / "ci" / "ibis-testing-data" -def recreate_database( - url: sa.engine.url.URL, - database: str, - **kwargs: Any, -) -> None: - """Drop the `database` at `url`, if it exists. - - Create a new, blank database with the same name. - - Parameters - ---------- - url : url.sa.engine.url.URL - Connection url to the database - database : str - Name of the database to be dropped. - """ - engine = sa.create_engine(url.set(database=""), **kwargs) - - if url.database is not None: - with engine.begin() as con: - con.exec_driver_sql(f"DROP DATABASE IF EXISTS {database}") - con.exec_driver_sql(f"CREATE DATABASE {database}") - - -def init_database( - url: sa.engine.url.URL, - database: str, - schema: Iterable[str] | None = None, - recreate: bool = True, - isolation_level: str | None = "AUTOCOMMIT", - **kwargs: Any, -) -> sa.engine.Engine: - """Initialise `database` at `url` with `schema`. - - If `recreate`, drop the `database` at `url`, if it exists. 
- - Parameters - ---------- - url : url.sa.engine.url.URL - Connection url to the database - database : str - Name of the database to be dropped - schema : TextIO - File object containing schema to use - recreate : bool - If true, drop the database if it exists - isolation_level : str - Transaction isolation_level - - Returns - ------- - sa.engine.Engine - SQLAlchemy engine object - """ - if isolation_level is not None: - kwargs["isolation_level"] = isolation_level - - if recreate: - recreate_database(url, database, **kwargs) - - try: - url.database = database - except AttributeError: - url = url.set(database=database) - - engine = sa.create_engine(url, **kwargs) - - if schema: - with engine.begin() as conn: - util.consume(map(conn.exec_driver_sql, schema)) - - return engine - - def _get_backend_conf(backend_str: str): """Convert a backend string to the test class for the backend.""" conftest = importlib.import_module(f"ibis.backends.{backend_str}.tests.conftest") @@ -302,7 +225,7 @@ def pytest_ignore_collect(path, config): return False expr = _pytest.mark.expression.Expression.compile(mark_expr) # we check the "backend" marker as well since if that's passed - # any file matching a backed should be skipped + # any file matching a backend should be skipped keep = expr.evaluate(lambda s: s in (backend, "backend")) return not keep @@ -399,7 +322,7 @@ def pytest_runtest_call(item): backend = [ getattr(backend, "name", lambda backend=backend: backend)() for key, backend in item.funcargs.items() - if key.endswith(("backend", "backend_name")) + if key.endswith(("backend", "backend_name", "backend_no_data")) ] if len(backend) > 1: raise ValueError( @@ -517,6 +440,19 @@ def con(backend): return backend.connection +@pytest.fixture(params=_get_backends_to_test(), scope="session") +def backend_no_data(request, data_dir, tmp_path_factory, worker_id): + """Return an instance of BackendTest, with no data loaded.""" + cls = _get_backend_conf(request.param) + return 
cls(data_dir=data_dir, tmpdir=tmp_path_factory, worker_id=worker_id) + + +@pytest.fixture(scope="session") +def con_no_data(backend_no_data): + """Return an Ibis backend instance, with no data loaded.""" + return backend_no_data.connection + + @pytest.fixture(scope="session") def con_create_database(con): if isinstance(con, CanCreateDatabase): @@ -567,21 +503,6 @@ def ddl_con(ddl_backend): return ddl_backend.connection -@pytest.fixture( - params=_get_backends_to_test(keep=("risingwave",)), - scope="session", -) -def alchemy_backend(request, data_dir, tmp_path_factory, worker_id): - """Set up the SQLAlchemy-based backends.""" - pytest.skip("No SQLAlchemy backends remaining") - - -@pytest.fixture(scope="session") -def alchemy_con(alchemy_backend): - """Instance of Client, already connected to the db (if applies).""" - pytest.skip("No SQLAlchemy backends remaining") - - @pytest.fixture( params=_get_backends_to_test(keep=("dask", "pandas", "pyspark")), scope="session", @@ -675,25 +596,6 @@ def geo_df(geo): return None -@pytest.fixture -def alchemy_temp_table(alchemy_con) -> str: - """Return a temporary table name. - - Parameters - ---------- - alchemy_con : ibis.backends.base.Client - - Yields - ------ - name : string - Random table name for a temporary usage. - """ - name = util.gen_name("alchemy_table") - yield name - with contextlib.suppress(NotImplementedError): - alchemy_con.drop_table(name, force=True) - - @pytest.fixture def temp_table(con) -> str: """Return a temporary table name. 
diff --git a/ibis/backends/dask/executor.py b/ibis/backends/dask/executor.py index cc4890fbdcf0e..d2d384ff04322 100644 --- a/ibis/backends/dask/executor.py +++ b/ibis/backends/dask/executor.py @@ -398,10 +398,15 @@ def fn(node, _, **kwargs): def execute(cls, node, backend, params): original = node node = node.to_expr().as_table().op() - df = cls.compile(node, backend=backend, params=params) - assert isinstance(df, dd.DataFrame) + result = cls.compile(node, backend=backend, params=params) + + # should happen when the result is empty + if isinstance(result, pd.DataFrame): + assert result.empty + else: + assert isinstance(result, dd.DataFrame) + result = result.compute() - result = df.compute() result = PandasData.convert_table(result, node.schema) if isinstance(original, ops.Value): if original.shape.is_scalar(): diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index dbe8f90d989e6..68288f77dcfd2 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -3,6 +3,7 @@ import contextlib import inspect import typing +from collections.abc import Mapping from pathlib import Path from typing import TYPE_CHECKING, Any @@ -38,7 +39,7 @@ SessionConfig = None if TYPE_CHECKING: - from collections.abc import Iterator, Mapping + from collections.abc import Iterator import pandas as pd @@ -77,6 +78,8 @@ def do_connect( if isinstance(config, SessionContext): (self.con, config) = (config, None) else: + if config is not None and not isinstance(config, Mapping): + raise TypeError("Input to ibis.datafusion.connect must be a mapping") if SessionConfig is not None: df_config = SessionConfig( {"datafusion.sql_parser.dialect": "PostgreSQL"} @@ -95,7 +98,7 @@ def do_connect( @contextlib.contextmanager def _safe_raw_sql(self, sql: sge.Statement) -> Any: - yield self.raw_sql(sql) + yield self.raw_sql(sql).collect() def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]: name = 
gen_name("datafusion_metadata_view") @@ -106,7 +109,9 @@ def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]: expression=sg.parse_one(query, read="datafusion"), properties=sge.Properties(expressions=[sge.TemporaryProperty()]), ) - self.raw_sql(src) + + with self._safe_raw_sql(src): + pass try: result = ( @@ -183,7 +188,7 @@ def _compile_elementwise_udf(self, udf_node): name=udf_node.func.__name__, ) - def raw_sql(self, query: str | sg.exp.Expression) -> Any: + def raw_sql(self, query: str | sge.Expression) -> Any: """Execute a SQL string `query` against the database. Parameters @@ -217,9 +222,10 @@ def list_databases(self, like: str | None = None) -> list[str]: return self._filter_with_like(result["table_catalog"], like) def create_database(self, name: str, force: bool = False) -> None: - self.raw_sql( - sg.exp.Create(kind="DATABASE", this=sg.to_identifier(name), exists=force) - ) + with self._safe_raw_sql( + sge.Create(kind="DATABASE", this=sg.to_identifier(name), exists=force) + ): + pass def drop_database(self, name: str, force: bool = False) -> None: raise com.UnsupportedOperationError( @@ -241,13 +247,19 @@ def create_schema( ) -> None: # not actually a table, but this is how sqlglot represents schema names schema_name = sg.table(name, db=database) - self.raw_sql(sg.exp.Create(kind="SCHEMA", this=schema_name, exists=force)) + with self._safe_raw_sql( + sge.Create(kind="SCHEMA", this=schema_name, exists=force) + ): + pass def drop_schema( self, name: str, database: str | None = None, force: bool = False ) -> None: schema_name = sg.table(name, db=database) - self.raw_sql(sg.exp.Drop(kind="SCHEMA", this=schema_name, exists=force)) + with self._safe_raw_sql( + sge.Drop(kind="SCHEMA", this=schema_name, exists=force) + ): + pass def list_tables( self, @@ -532,8 +544,8 @@ def create_table( database: str | None = None, temp: bool = False, overwrite: bool = False, - ) -> ir.Table: - """Create a table in DataFusion. 
+ ): + """Create a table in Datafusion. Parameters ---------- @@ -558,30 +570,12 @@ def create_table( if obj is None and schema is None: raise ValueError("Either `obj` or `schema` must be specified") - column_defs = [ - sg.exp.ColumnDef( - this=sg.to_identifier(name, quoted=self.compiler.quoted), - kind=self.compiler.type_mapper.from_ibis(typ), - constraints=( - None - if typ.nullable - else [ - sg.exp.ColumnConstraint(kind=sg.exp.NotNullColumnConstraint()) - ] - ), - ) - for name, typ in (schema or {}).items() - ] - - target = sg.table(name, db=database, quoted=self.compiler.quoted) - - if column_defs: - target = sg.exp.Schema(this=target, expressions=column_defs) - properties = [] if temp: - properties.append(sg.exp.TemporaryProperty()) + properties.append(sge.TemporaryProperty()) + + quoted = self.compiler.quoted if obj is not None: if not isinstance(obj, ir.Expr): @@ -591,19 +585,72 @@ def create_table( self._run_pre_execute_hooks(table) - query = self._to_sqlglot(table) + relname = "_" + query = sg.select( + *( + self.compiler.cast( + sg.column(col, table=relname, quoted=quoted), dtype + ).as_(col, quoted=quoted) + for col, dtype in table.schema().items() + ) + ).from_( + self._to_sqlglot(table).subquery( + sg.to_identifier(relname, quoted=quoted) + ) + ) else: query = None - create_stmt = sg.exp.Create( + table_ident = sg.to_identifier(name, quoted=quoted) + + if query is None: + column_defs = [ + sge.ColumnDef( + this=sg.to_identifier(colname, quoted=quoted), + kind=self.compiler.type_mapper.from_ibis(typ), + constraints=( + None + if typ.nullable + else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())] + ), + ) + for colname, typ in (schema or table.schema()).items() + ] + + target = sge.Schema(this=table_ident, expressions=column_defs) + else: + target = table_ident + + create_stmt = sge.Create( kind="TABLE", this=target, - replace=overwrite, - properties=sg.exp.Properties(expressions=properties), + 
properties=sge.Properties(expressions=properties), expression=query, + replace=overwrite, ) with self._safe_raw_sql(create_stmt): pass return self.table(name, schema=database) + + def truncate_table( + self, name: str, database: str | None = None, schema: str | None = None + ) -> None: + """Delete all rows from a table. + + Parameters + ---------- + name + Table name + database + Database name + schema + Schema name + """ + # datafusion doesn't support `TRUNCATE TABLE` so we use `DELETE FROM` + # + # however datafusion as of 34.0.0 doesn't implement DELETE DML yet + ident = sg.table(name, db=schema, catalog=database).sql(self.name) + with self._safe_raw_sql(sge.delete(ident)): + pass diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index f090819e29761..755f0c3667476 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -505,7 +505,7 @@ def visit_Aggregate(self, op, *, parent, groups, metrics): # datafusion lower cases all column names internally unless quoted so # quoted=True is required here for correctness by_names_quoted = tuple( - sg.column(key, table=getattr(value, "table", None), quoted=True) + sg.column(key, table=getattr(value, "table", None), quoted=quoted) for key, value in groups.items() ) selections = by_names_quoted + metrics diff --git a/ibis/backends/druid/__init__.py b/ibis/backends/druid/__init__.py index 16cf0002eb5ee..fe68977c8a41a 100644 --- a/ibis/backends/druid/__init__.py +++ b/ibis/backends/druid/__init__.py @@ -148,6 +148,9 @@ def create_table( ) -> ir.Table: raise NotImplementedError() + def drop_table(self, *args, **kwargs): + raise NotImplementedError() + def list_tables( self, like: str | None = None, database: str | None = None ) -> list[str]: diff --git a/ibis/backends/druid/tests/conftest.py b/ibis/backends/druid/tests/conftest.py index c3c216a79a88f..f37a9e5a92794 100644 --- a/ibis/backends/druid/tests/conftest.py +++ 
b/ibis/backends/druid/tests/conftest.py @@ -107,9 +107,7 @@ class TestConf(ServiceBackendTest): @property def functional_alltypes(self) -> ir.Table: - t = self.connection.table( - self.default_identifier_case_fn("functional_alltypes") - ) + t = self.connection.table("functional_alltypes") # The parquet loading for booleans appears to be broken in Druid, so # I'm using this as a workaround to make the data match what's on disk. return t.mutate(bool_col=1 - t.id % 2) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index d0bcb68a0386f..bf1f0be7212e0 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -9,10 +9,8 @@ from operator import itemgetter from pathlib import Path from typing import TYPE_CHECKING, Any -from urllib.parse import parse_qs, urlparse import duckdb -import pandas as pd import pyarrow as pa import pyarrow_hotfix # noqa: F401 import sqlglot as sg @@ -26,7 +24,7 @@ import ibis.expr.types as ir from ibis import util from ibis.backends.base import CanCreateSchema -from ibis.backends.base.sqlglot import SQLGlotBackend +from ibis.backends.base.sqlglot import SQLGlotBackend, UrlFromPath from ibis.backends.base.sqlglot.compiler import STAR, C, F from ibis.backends.duckdb.compiler import DuckDBCompiler from ibis.backends.duckdb.datatypes import DuckDBPandasData @@ -35,11 +33,10 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence + import pandas as pd import torch from fsspec import AbstractFileSystem - from ibis.backends.base.sql import BaseBackend - def normalize_filenames(source_list): # Promote to list @@ -77,7 +74,7 @@ def __repr__(self): return repr(dict(zip(kv["key"], kv["value"]))) -class Backend(SQLGlotBackend, CanCreateSchema): +class Backend(SQLGlotBackend, CanCreateSchema, UrlFromPath): name = "duckdb" compiler = DuckDBCompiler() @@ -215,9 +212,7 @@ def create_table( final_table = sg.table(name, catalog=database, 
quoted=self.compiler.quoted) with self._safe_raw_sql(create_stmt) as cur: if query is not None: - insert_stmt = sge.Insert(this=initial_table, expression=query).sql( - self.name - ) + insert_stmt = sge.insert(query, into=initial_table).sql(self.name) cur.execute(insert_stmt).fetchall() if overwrite: @@ -454,7 +449,6 @@ def do_connect( pass self._record_batch_readers_consumed = {} - self._temp_views: set[str] = set() def _load_extensions( self, extensions: list[str], force_install: bool = False @@ -473,38 +467,6 @@ def _load_extensions( cur.install_extension(extension, force_install=force_install) cur.load_extension(extension) - # TODO(kszucs): should be a classmethod - def _from_url(self, url: str, **kwargs) -> BaseBackend: - """Connect to a backend using a URL `url`. - - Parameters - ---------- - url - URL with which to connect to a backend. - kwargs - Additional keyword arguments - - Returns - ------- - BaseBackend - A backend instance - - """ - url = urlparse(url) - database = url.path or ":memory:" - query_params = parse_qs(url.query) - - for name, value in query_params.items(): - if len(value) > 1: - kwargs[name] = value - elif len(value) == 1: - kwargs[name] = value[0] - else: - raise exc.IbisError(f"Invalid URL parameter: {name}") - - self._convert_kwargs(kwargs) - return self.connect(database=database, **kwargs) - def load_extension(self, extension: str, force_install: bool = False) -> None: """Install and load a duckdb extension by name or path. @@ -1514,52 +1476,6 @@ def _compile_builtin_udf(self, udf_node: ops.ScalarUDF) -> None: def _compile_pandas_udf(self, _: ops.ScalarUDF) -> None: raise NotImplementedError("duckdb doesn't support pandas UDFs") - def insert( - self, - table_name: str, - obj: pd.DataFrame | ir.Table | list | dict, - database: str | None = None, - overwrite: bool = False, - ) -> None: - """Insert data into a table. 
- - Parameters - ---------- - table_name - The name of the table to which data needs will be inserted - obj - The source data or expression to insert - database - Name of the attached database that the table is located in. - overwrite - If `True` then replace existing contents of table - - Raises - ------ - NotImplementedError - If inserting data from a different database - ValueError - If the type of `obj` isn't supported - - """ - table = sg.table(table_name, db=database) - if overwrite: - with self._safe_raw_sql(f"TRUNCATE TABLE {table.sql('duckdb')}"): - pass - - if isinstance(obj, ir.Table): - self._run_pre_execute_hooks(obj) - query = sge.insert( - expression=self.compile(obj), into=table, dialect="duckdb" - ) - with self._safe_raw_sql(query): - pass - else: - self.con.append( - table_name, - obj if isinstance(obj, pd.DataFrame) else pd.DataFrame(obj), - ) - def _get_temp_view_definition(self, name: str, definition: str) -> str: return sge.Create( this=sg.to_identifier(name, quoted=self.compiler.quoted), diff --git a/ibis/backends/duckdb/tests/conftest.py b/ibis/backends/duckdb/tests/conftest.py index cc949ab9f9b57..7f785c6f3c5ca 100644 --- a/ibis/backends/duckdb/tests/conftest.py +++ b/ibis/backends/duckdb/tests/conftest.py @@ -11,7 +11,6 @@ if TYPE_CHECKING: from collections.abc import Iterator - from typing import Any from ibis.backends.base import BaseBackend @@ -40,6 +39,7 @@ class TestConf(BackendTest): deps = ("duckdb",) stateful = False supports_tpch = True + driver_supports_multiple_statements = True def preload(self): if not SANDBOXED: @@ -95,11 +95,6 @@ def load_tpch(self) -> None: with self.connection._safe_raw_sql("CALL dbgen(sf=0.17)"): pass - def _load_data(self, **_: Any) -> None: - """Load test data into a backend.""" - with self.connection._safe_raw_sql(";\n".join(self.ddl_script)): - pass - @pytest.fixture(scope="session") def con(data_dir, tmp_path_factory, worker_id): diff --git a/ibis/backends/duckdb/tests/test_client.py 
b/ibis/backends/duckdb/tests/test_client.py index f2cfa56195688..f94241f58c006 100644 --- a/ibis/backends/duckdb/tests/test_client.py +++ b/ibis/backends/duckdb/tests/test_client.py @@ -243,9 +243,8 @@ def test_default_backend(): param(lambda p: f"duckdb://{p}", id="absolute-path"), param( lambda p: f"duckdb://{os.path.relpath(p)}", - marks=[ - not_windows - ], # hard to test in CI since tmpdir & cwd are on different drives + # hard to test in CI since tmpdir & cwd are on different drives + marks=[not_windows], id="relative-path", ), param(lambda _: "duckdb://", id="in-memory-empty"), diff --git a/ibis/backends/exasol/__init__.py b/ibis/backends/exasol/__init__.py index 11e88170493c6..39ad83ab8c188 100644 --- a/ibis/backends/exasol/__init__.py +++ b/ibis/backends/exasol/__init__.py @@ -93,7 +93,7 @@ def do_connect( ) def _from_url(self, url: str, **kwargs) -> BaseBackend: - """Construct an ibis backend from a SQLAlchemy-conforming URL.""" + """Construct an ibis backend from a URL.""" url = urlparse(url) query_params = parse_qs(url.query) kwargs = { diff --git a/ibis/backends/flink/translator.py b/ibis/backends/flink/translator.py index eea9edba182b7..37bbcc0170ab7 100644 --- a/ibis/backends/flink/translator.py +++ b/ibis/backends/flink/translator.py @@ -6,9 +6,7 @@ class FlinkExprTranslator(ExprTranslator): - _dialect_name = ( - "hive" # TODO: neither sqlglot nor sqlalchemy supports flink dialect - ) + _dialect_name = "hive" # TODO: make a custom sqlglot dialect for Flink _registry = operation_registry _bool_aggs_need_cast_to_int32 = True diff --git a/ibis/backends/impala/__init__.py b/ibis/backends/impala/__init__.py index ac5e10c71176c..1c8833d208978 100644 --- a/ibis/backends/impala/__init__.py +++ b/ibis/backends/impala/__init__.py @@ -10,7 +10,6 @@ from urllib.parse import parse_qs, urlparse import impala.dbapi as impyla -import pandas as pd import sqlglot as sg import sqlglot.expressions as sge from impala.error import Error as ImpylaError @@ -48,6 +47,7 @@ 
from collections.abc import Iterator, Mapping from pathlib import Path + import pandas as pd import pyarrow as pa import ibis.expr.operations as ops @@ -513,8 +513,8 @@ def create_table( raise NotImplementedError if obj is not None: - if isinstance(obj, pd.DataFrame): - raise NotImplementedError("Pandas DataFrames not yet supported") + if not isinstance(obj, ir.Table): + obj = ibis.memtable(obj) self._run_pre_execute_hooks(obj) diff --git a/ibis/backends/impala/client.py b/ibis/backends/impala/client.py index ab27b8f7ee184..66f54be14debf 100644 --- a/ibis/backends/impala/client.py +++ b/ibis/backends/impala/client.py @@ -1,14 +1,19 @@ from __future__ import annotations -import pandas as pd +from typing import TYPE_CHECKING + import sqlglot as sg +import ibis import ibis.common.exceptions as com import ibis.expr.schema as sch import ibis.expr.types as ir from ibis.backends.base.sql.ddl import AlterTable, InsertSelect from ibis.backends.impala import ddl +if TYPE_CHECKING: + import pandas as pd + class ImpalaTable(ir.Table): """A physical table in the Impala-Hive metastore.""" @@ -101,8 +106,10 @@ def insert( if values is not None: raise NotImplementedError - if isinstance(obj, pd.DataFrame): - raise NotImplementedError("Pandas DataFrames not yet supported") + if not isinstance(obj, ir.Table): + obj = ibis.memtable(obj) + + self._client._run_pre_execute_hooks(obj) expr = obj if validate: diff --git a/ibis/backends/impala/tests/test_partition.py b/ibis/backends/impala/tests/test_partition.py index 4549ae4652a28..f11e82c5d9a32 100644 --- a/ibis/backends/impala/tests/test_partition.py +++ b/ibis/backends/impala/tests/test_partition.py @@ -104,7 +104,6 @@ def test_unpartitioned_table_get_schema(con): con.table(tname).partition_schema() -@pytest.mark.xfail(raises=NotImplementedError) def test_insert_select_partitioned_table(con, df, temp_table, unpart_t): part_keys = ["year", "month"] @@ -202,7 +201,9 @@ def test_add_drop_partition_hive_bug(con, temp_table): assert 
len(table.partitions()) == 1 -@pytest.mark.xfail(raises=NotImplementedError) +@pytest.mark.xfail( + raises=AttributeError, reason="test is bogus and needs to be rewritten" +) def test_load_data_partition(con, tmp_dir, unpart_t, df, temp_table): part_keys = ["year", "month"] diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index 5f19d076ed36e..cf7371746cbe8 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -29,7 +29,13 @@ from ibis.common.patterns import replace from ibis.expr.rewrites import p, rewrite_sample -TSQL.Generator.TRANSFORMS |= { + +class MSSQL(TSQL): + class Generator(TSQL.Generator): + pass + + +MSSQL.Generator.TRANSFORMS |= { sge.ApproxDistinct: rename_func("approx_count_distinct"), sge.Stddev: rename_func("stdevp"), sge.StddevPop: rename_func("stdevp"), @@ -73,7 +79,7 @@ def exclude_unsupported_window_frame_from_ops(_, y): class MSSQLCompiler(SQLGlotCompiler): __slots__ = () - dialect = "tsql" + dialect = "mssql" type_mapper = MSSQLType rewrites = ( rewrite_sample, diff --git a/ibis/backends/mssql/tests/conftest.py b/ibis/backends/mssql/tests/conftest.py index 414c34a47ecfa..36db711908765 100644 --- a/ibis/backends/mssql/tests/conftest.py +++ b/ibis/backends/mssql/tests/conftest.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pytest @@ -36,20 +36,6 @@ class TestConf(ServiceBackendTest): def test_files(self) -> Iterable[Path]: return self.data_dir.joinpath("csv").glob("*.csv") - def _load_data(self, **_: Any) -> None: - """Load test data into a MSSQL backend instance. 
- - Parameters - ---------- - data_dir - Location of testdata - script_dir - Location of scripts defining schemas - """ - with self.connection.begin() as cur: - for stmt in self.ddl_script: - cur.execute(stmt) - @staticmethod def connect(*, tmpdir, worker_id, **kw): return ibis.mssql.connect( diff --git a/ibis/backends/mysql/tests/conftest.py b/ibis/backends/mysql/tests/conftest.py index c7cadc448fd3c..9491b8328bd38 100644 --- a/ibis/backends/mysql/tests/conftest.py +++ b/ibis/backends/mysql/tests/conftest.py @@ -36,16 +36,7 @@ class TestConf(ServiceBackendTest): def test_files(self) -> Iterable[Path]: return self.data_dir.joinpath("csv").glob("*.csv") - def _load_data( - self, - *, - user: str = MYSQL_USER, - password: str = MYSQL_PASS, - host: str = MYSQL_HOST, - port: int = MYSQL_PORT, - database: str = IBIS_TEST_MYSQL_DB, - **_: Any, - ) -> None: + def _load_data(self, **kwargs: Any) -> None: """Load test data into a MySql backend instance. Parameters @@ -55,10 +46,9 @@ def _load_data( script_dir Location of scripts defining schemas """ - with self.connection.begin() as cur: - for stmt in self.ddl_script: - cur.execute(stmt) + super()._load_data(**kwargs) + with self.connection.begin() as cur: for table in TEST_TABLES: csv_path = self.data_dir / "csv" / f"{table}.csv" lines = [ diff --git a/ibis/backends/oracle/tests/conftest.py b/ibis/backends/oracle/tests/conftest.py index ee27b4cc0f1f7..9210217304f10 100644 --- a/ibis/backends/oracle/tests/conftest.py +++ b/ibis/backends/oracle/tests/conftest.py @@ -82,7 +82,7 @@ def _load_data( "docker", "compose", "exec", - "oracle", + self.service_name, "./createAppUser", user, password, @@ -107,7 +107,7 @@ def _load_data( "docker", "compose", "exec", - "oracle", + self.service_name, "sqlldr", f"{user}/{password}@{host}:{port:d}/{database}", f"control=data/{ctl_file.name}", diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index f4c7268d0416e..6a524a4081015 100644 --- 
a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Mapping from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Any @@ -22,7 +23,7 @@ from ibis.util import gen_name, normalize_filename if TYPE_CHECKING: - from collections.abc import Iterable, Mapping, MutableMapping + from collections.abc import Iterable import pandas as pd import pyarrow as pa @@ -40,7 +41,7 @@ def __init__(self, *args, **kwargs): self._context = pl.SQLContext() def do_connect( - self, tables: MutableMapping[str, pl.LazyFrame | pl.DataFrame] | None = None + self, tables: Mapping[str, pl.LazyFrame | pl.DataFrame] | None = None ) -> None: """Construct a client from a dictionary of polars `LazyFrame`s and/or `DataFrame`s. @@ -50,6 +51,12 @@ def do_connect( An optional mapping of string table names to polars LazyFrames. """ + if tables is not None and not isinstance(tables, Mapping): + raise TypeError("Input to ibis.polars.connect must be a mapping") + + # tables are ephemeral + self._tables.clear() + for name, table in (tables or {}).items(): self._add_table(name, table) @@ -347,10 +354,10 @@ def create_table( "effect: Polars cannot set a database." ) - if temp: + if temp is False: raise com.IbisError( - "Passing `temp=True` to the Polars backend create_table method has no " - "effect: all tables are in memory and temporary. " + "Passing `temp=False` to the Polars backend create_table method is not " + "supported: all tables are in memory and temporary." 
) if not overwrite and name in self._tables: diff --git a/ibis/backends/postgres/__init__.py b/ibis/backends/postgres/__init__.py index 10365658518b2..ae1ba51ac10d8 100644 --- a/ibis/backends/postgres/__init__.py +++ b/ibis/backends/postgres/__init__.py @@ -766,18 +766,3 @@ def _to_sqlglot( if conversions: table_expr = table_expr.mutate(**conversions) return super()._to_sqlglot(table_expr, limit=limit, params=params) - - def truncate_table(self, name: str, database: str | None = None) -> None: - """Delete all rows from a table. - - Parameters - ---------- - name - Table name - database - Schema name - - """ - ident = sg.table(name, db=database).sql(self.dialect) - with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"): - pass diff --git a/ibis/backends/postgres/tests/conftest.py b/ibis/backends/postgres/tests/conftest.py index 7fd19bcabdd83..0ee32b99bc265 100644 --- a/ibis/backends/postgres/tests/conftest.py +++ b/ibis/backends/postgres/tests/conftest.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pytest @@ -50,23 +50,12 @@ class TestConf(ServiceBackendTest): service_name = "postgres" deps = ("psycopg2",) + driver_supports_multiple_statements = True + @property def test_files(self) -> Iterable[Path]: return self.data_dir.joinpath("csv").glob("*.csv") - def _load_data(self, **_: Any) -> None: - """Load test data into a PostgreSQL backend instance. 
- - Parameters - ---------- - data_dir - Location of test data - script_dir - Location of scripts defining schemas - """ - with self.connection._safe_raw_sql(";".join(self.ddl_script)): - pass - @staticmethod def connect(*, tmpdir, worker_id, **kw): return ibis.postgres.connect( diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index ac07e43da9402..b186502f4cfed 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -397,23 +397,6 @@ def create_table( return self.table(name, database=database) - # TODO(kszucs): should have this implementation in the base sqlglot backend - def truncate_table(self, name: str, database: str | None = None) -> None: - """Delete all rows from an existing table. - - Parameters - ---------- - name - Table name - database - Database name - - """ - table = sg.table(name, db=database) - query = f"TRUNCATE TABLE {table}" - with self._safe_raw_sql(query): - pass - def create_view( self, name: str, @@ -478,34 +461,6 @@ def rename_table(self, old_name: str, new_name: str) -> None: with self._safe_raw_sql(query): pass - def insert( - self, - table_name: str, - obj: ir.Table | pd.DataFrame | None = None, - database: str | None = None, - overwrite: bool = False, - ) -> Any: - """Insert data into an existing table. 
- - Examples - -------- - >>> table = "my_table" - >>> con.insert(table, table_expr) # quartodoc: +SKIP # doctest: +SKIP - - # Completely overwrite contents - >>> con.insert(table, table_expr, overwrite=True) # quartodoc: +SKIP # doctest: +SKIP - - """ - - if isinstance(obj, ir.Expr): - df = self._session.sql(self.compile(obj)) - else: - table = ibis.memtable(obj) - df = self._session.createDataFrame(table.op().data.to_frame()) - - with self._active_database(database): - df.write.insertInto(table_name, overwrite=overwrite) - def compute_stats( self, name: str, diff --git a/ibis/backends/risingwave/__init__.py b/ibis/backends/risingwave/__init__.py index 996f776fd12e9..43191503b90ef 100644 --- a/ibis/backends/risingwave/__init__.py +++ b/ibis/backends/risingwave/__init__.py @@ -109,8 +109,6 @@ def do_connect( with self.begin() as cur: cur.execute("SET TIMEZONE = UTC") - self._temp_views = set() - def create_table( self, name: str, diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index 357056a1ac961..a98127771cb89 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -274,7 +274,6 @@ def do_connect(self, create_object_udfs: bool = True, **kwargs: Any): f"Unable to create Ibis UDFs, some functionality will not work: {e}" ) self.con = con - self._temp_views: set[str] = set() def _get_udf_source(self, udf_node: ops.ScalarUDF): name = type(udf_node).__name__ @@ -1035,48 +1034,3 @@ def read_parquet( cur.execute(f"COPY INTO {qtable} FROM (SELECT {cols} FROM @{stage})") return self.table(table) - - def insert( - self, - table_name: str, - obj: pd.DataFrame | ir.Table | list | dict, - schema: str | None = None, - database: str | None = None, - overwrite: bool = False, - ) -> None: - """Insert data into a table. 
- - Parameters - ---------- - table_name - The name of the table to which data needs will be inserted - obj - The source data or expression to insert - schema - The name of the schema that the table is located in - database - Name of the attached database that the table is located in. - overwrite - If `True` then replace existing contents of table - - """ - if not isinstance(obj, ir.Table): - obj = ibis.memtable(obj) - - table = sg.table(table_name, db=schema, catalog=database, quoted=True) - self._run_pre_execute_hooks(obj) - query = sg.exp.insert( - expression=self.compile(obj), - into=table, - columns=[sg.column(col, quoted=True) for col in obj.columns], - dialect=self.name, - ) - - statements = [] - if overwrite: - statements.append(f"TRUNCATE TABLE {table.sql(self.name)}") - statements.append(query.sql(self.name)) - - statement = ";".join(statements) - with self._safe_raw_sql(statement): - pass diff --git a/ibis/backends/snowflake/tests/conftest.py b/ibis/backends/snowflake/tests/conftest.py index 6d041a3c5a8b7..924cdf7b73eae 100644 --- a/ibis/backends/snowflake/tests/conftest.py +++ b/ibis/backends/snowflake/tests/conftest.py @@ -68,7 +68,6 @@ def copy_into(con, data_dir: Path, table: str) -> None: class TestConf(BackendTest): supports_map = True - default_identifier_case_fn = staticmethod(str.upper) deps = ("snowflake.connector",) supports_tpch = True @@ -76,10 +75,9 @@ def load_tpch(self) -> None: """No-op, snowflake already defines these in `SNOWFLAKE_SAMPLE_DATA`.""" def _tpch_table(self, name: str): + name = name.upper() t = self.connection.table( - self.default_identifier_case_fn(name), - database="SNOWFLAKE_SAMPLE_DATA", - schema="TPCH_SF1", + name, database="SNOWFLAKE_SAMPLE_DATA", schema="TPCH_SF1" ) return t.rename("snake_case") diff --git a/ibis/backends/sqlite/__init__.py b/ibis/backends/sqlite/__init__.py index 906137627f49d..2ca2acf6977e5 100644 --- a/ibis/backends/sqlite/__init__.py +++ b/ibis/backends/sqlite/__init__.py @@ -4,7 +4,6 @@ import 
functools import sqlite3 from typing import TYPE_CHECKING, Any, NoReturn -from urllib.parse import urlparse import sqlglot as sg import sqlglot.expressions as sge @@ -16,7 +15,7 @@ import ibis.expr.schema as sch import ibis.expr.types as ir from ibis import util -from ibis.backends.base.sqlglot import SQLGlotBackend +from ibis.backends.base.sqlglot import SQLGlotBackend, UrlFromPath from ibis.backends.base.sqlglot.compiler import C, F from ibis.backends.sqlite.compiler import SQLiteCompiler from ibis.backends.sqlite.converter import SQLitePandasData @@ -34,15 +33,15 @@ def _init_sqlite3(): import pandas as pd - # TODO: can we remove this? - sqlite3.register_adapter(pd.Timestamp, lambda value: value.isoformat()) + # required to support pandas Timestamp's from user input + sqlite3.register_adapter(pd.Timestamp, pd.Timestamp.isoformat) def _quote(name: str) -> str: return sg.to_identifier(name, quoted=True).sql("sqlite") -class Backend(SQLGlotBackend): +class Backend(SQLGlotBackend, UrlFromPath): name = "sqlite" compiler = SQLiteCompiler() supports_python_udfs = True @@ -89,31 +88,10 @@ def do_connect( self._type_map = {} self.con = sqlite3.connect(":memory:" if database is None else database) - self._temp_views = set() register_all(self.con) self.con.execute("PRAGMA case_sensitive_like=ON") - def _from_url(self, url: str, **kwargs): - """Connect to a backend using a URL `url`. - - Parameters - ---------- - url - URL with which to connect to a backend. 
- kwargs - Additional keyword arguments - - Returns - ------- - BaseBackend - A backend instance - - """ - url = urlparse(url) - database = url.path or ":memory:" - return self.connect(database=database, **kwargs) - def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: if not isinstance(query, str): query = query.sql(dialect=self.name) @@ -513,25 +491,6 @@ def drop_table( with self._safe_raw_sql(drop_stmt): pass - def _create_temp_view(self, table_name, source): - if table_name not in self._temp_views and table_name in self.list_tables(): - raise ValueError( - f"{table_name} already exists as a non-temporary table or view" - ) - - view = sg.table(table_name, catalog="temp", quoted=self.compiler.quoted) - drop = sge.Drop(kind="VIEW", exists=True, this=view).sql(self.name) - create = sge.Create( - kind="VIEW", this=view, expression=source, replace=False - ).sql(self.name) - - with self.begin() as cur: - cur.execute(drop) - cur.execute(create) - - self._temp_views.add(table_name) - self._register_temp_view_cleanup(table_name) - def create_view( self, name: str, diff --git a/ibis/backends/sqlite/tests/test_client.py b/ibis/backends/sqlite/tests/test_client.py index 3d4553eaa98ec..7e57bcb431454 100644 --- a/ibis/backends/sqlite/tests/test_client.py +++ b/ibis/backends/sqlite/tests/test_client.py @@ -51,7 +51,6 @@ def total(x) -> float: assert result == 0.0 -@pytest.mark.sqlite @pytest.mark.parametrize( "url, ext", [ @@ -61,16 +60,15 @@ def total(x) -> float: param( lambda p: f"sqlite://{os.path.relpath(p)}", "db", - marks=[ - not_windows - ], # hard to test in CI since tmpdir & cwd are on different drives + # hard to test in CI since tmpdir & cwd are on different drives + marks=[not_windows], id="relative-path", ), param(lambda _: "sqlite://", "db", id="in-memory-empty"), param(lambda _: "sqlite://:memory:", "db", id="in-memory-explicit"), ], ) -def test_connect_sqlite(url, ext, tmp_path): +def test_connect(url, ext, tmp_path): path = 
os.path.abspath(tmp_path / f"test.{ext}") with sqlite3.connect(path): pass diff --git a/ibis/backends/tests/base.py b/ibis/backends/tests/base.py index e727d9af312a5..eb28ed46eba01 100644 --- a/ibis/backends/tests/base.py +++ b/ibis/backends/tests/base.py @@ -11,7 +11,6 @@ import pandas as pd import pandas.testing as tm import pytest -import toolz from filelock import FileLock if TYPE_CHECKING: @@ -53,8 +52,6 @@ class BackendTest(abc.ABC): "Whether backend supports mappings (currently DuckDB, Snowflake, and Trino)" reduction_tolerance = 1e-7 "Used for a single test in `test_aggregation.py`. You should not need to touch this." - default_identifier_case_fn = staticmethod(toolz.identity) - "Function applied to all identifier names to change case as necessary (e.g. Snowflake ALL_CAPS)" stateful = True "Whether special handling is needed for running a multi-process pytest run." supports_tpch: bool = False @@ -63,6 +60,8 @@ class BackendTest(abc.ABC): "Sort results before comparing against reference computation." rounding_method: Literal["away_from_zero", "half_to_even"] = "away_from_zero" "Name of round method to use for rounding test comparisons." + driver_supports_multiple_statements: bool = False + "Whether the driver supports executing multiple statements in a single call." 
@property @abc.abstractmethod @@ -120,9 +119,13 @@ def _transform_tpch_sql(self, parsed): def _load_data(self, **_: Any) -> None: """Load test data into a backend.""" - with self.connection.begin() as con: - for stmt in self.ddl_script: - con.exec_driver_sql(stmt) + if self.driver_supports_multiple_statements: + with self.connection._safe_raw_sql(";".join(self.ddl_script)): + pass + else: + with self.connection.begin() as con: + for stmt in self.ddl_script: + con.execute(stmt) def stateless_load(self, **kw): self.preload() @@ -228,32 +231,30 @@ def default_series_rename(series: pd.Series, name: str = "tmp") -> pd.Series: @property def functional_alltypes(self) -> ir.Table: - t = self.connection.table( - self.default_identifier_case_fn("functional_alltypes") - ) + t = self.connection.table("functional_alltypes") if not self.native_bool: return t.mutate(bool_col=t.bool_col == 1) return t @property def batting(self) -> ir.Table: - return self.connection.table(self.default_identifier_case_fn("batting")) + return self.connection.table("batting") @property def awards_players(self) -> ir.Table: - return self.connection.table(self.default_identifier_case_fn("awards_players")) + return self.connection.table("awards_players") @property def diamonds(self) -> ir.Table: - return self.connection.table(self.default_identifier_case_fn("diamonds")) + return self.connection.table("diamonds") @property def astronauts(self) -> ir.Table: - return self.connection.table(self.default_identifier_case_fn("astronauts")) + return self.connection.table("astronauts") @property def geo(self) -> ir.Table | None: - name = self.default_identifier_case_fn("geo") + name = "geo" if name in self.connection.list_tables(): return self.connection.table(name) return None @@ -261,14 +262,14 @@ def geo(self) -> ir.Table | None: @property def struct(self) -> ir.Table | None: if self.supports_structs: - return self.connection.table(self.default_identifier_case_fn("struct")) + return 
self.connection.table("struct") else: pytest.xfail(f"{self.name()} backend does not support struct types") @property def array_types(self) -> ir.Table | None: if self.supports_arrays: - return self.connection.table(self.default_identifier_case_fn("array_types")) + return self.connection.table("array_types") else: pytest.xfail(f"{self.name()} backend does not support array types") @@ -277,22 +278,20 @@ def json_t(self) -> ir.Table | None: from ibis import _ if self.supports_json: - return self.connection.table( - self.default_identifier_case_fn("json_t") - ).mutate(js=_.js.cast("json")) + return self.connection.table("json_t").mutate(js=_.js.cast("json")) else: pytest.xfail(f"{self.name()} backend does not support json types") @property def map(self) -> ir.Table | None: if self.supports_map: - return self.connection.table(self.default_identifier_case_fn("map")) + return self.connection.table("map") else: pytest.xfail(f"{self.name()} backend does not support map types") @property def win(self) -> ir.Table | None: - return self.connection.table(self.default_identifier_case_fn("win")) + return self.connection.table("win") @property def api(self): @@ -336,7 +335,7 @@ def supplier(self): def _tpch_table(self, name: str): if not self.supports_tpch: pytest.skip(f"{self.name()} backend does not support testing TPC-H") - return self.connection.table(self.default_identifier_case_fn(name)) + return self.connection.table(name) class ServiceBackendTest(BackendTest): diff --git a/ibis/backends/tests/errors.py b/ibis/backends/tests/errors.py index e9a8347ab094a..8aec1a1912ed2 100644 --- a/ibis/backends/tests/errors.py +++ b/ibis/backends/tests/errors.py @@ -17,8 +17,13 @@ from clickhouse_connect.driver.exceptions import ( InternalError as ClickHouseInternalError, ) + from clickhouse_connect.driver.exceptions import ( + OperationalError as ClickHouseOperationalError, + ) except ImportError: - ClickHouseDatabaseError = ClickHouseInternalError = None + ClickHouseDatabaseError = ( + 
ClickHouseInternalError + ) = ClickHouseOperationalError = None try: @@ -96,14 +101,22 @@ from psycopg2.errors import ( InvalidTextRepresentation as PsycoPg2InvalidTextRepresentation, ) + from psycopg2.errors import OperationalError as PsycoPg2OperationalError from psycopg2.errors import ProgrammingError as PsycoPg2ProgrammingError from psycopg2.errors import SyntaxError as PsycoPg2SyntaxError + from psycopg2.errors import UndefinedObject as PsycoPg2UndefinedObject except ImportError: PsycoPg2SyntaxError = ( PsycoPg2IndeterminateDatatype ) = ( PsycoPg2InvalidTextRepresentation - ) = PsycoPg2DivisionByZero = PsycoPg2InternalError = PsycoPg2ProgrammingError = None + ) = ( + PsycoPg2DivisionByZero + ) = ( + PsycoPg2InternalError + ) = ( + PsycoPg2ProgrammingError + ) = PsycoPg2OperationalError = PsycoPg2UndefinedObject = None try: from pymysql.err import NotSupportedError as MySQLNotSupportedError diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 76c4056f173ea..66988e15bb633 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd import pytest -import sqlalchemy as sa from pytest import param import ibis @@ -306,9 +305,7 @@ def mean_and_std(v): raises=OracleDatabaseError, reason="ORA-02000: missing AS keyword", ), - pytest.mark.notimpl( - ["exasol"], raises=(sa.exc.DBAPIError, ExaQueryError) - ), + pytest.mark.notimpl(["exasol"], raises=ExaQueryError), pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError), ], ), @@ -327,9 +324,7 @@ def mean_and_std(v): raises=OracleDatabaseError, reason="ORA-02000: missing AS keyword", ), - pytest.mark.notimpl( - ["exasol"], raises=(sa.exc.DBAPIError, ExaQueryError) - ), + pytest.mark.notimpl(["exasol"], raises=ExaQueryError), pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError), ], ), @@ -361,9 +356,7 @@ def mean_and_std(v): raises=OracleDatabaseError, 
reason="ORA-02000: missing AS keyword", ), - pytest.mark.notimpl( - ["exasol"], raises=(sa.exc.DBAPIError, ExaQueryError) - ), + pytest.mark.notimpl(["exasol"], raises=ExaQueryError), pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError), ], ), @@ -382,9 +375,7 @@ def mean_and_std(v): raises=OracleDatabaseError, reason="ORA-02000: missing AS keyword", ), - pytest.mark.notimpl( - ["exasol"], raises=(sa.exc.DBAPIError, ExaQueryError) - ), + pytest.mark.notimpl(["exasol"], raises=ExaQueryError), pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError), ], ), @@ -822,7 +813,6 @@ def test_reduction_ops( @pytest.mark.notyet( ["bigquery", "druid", "mssql", "oracle", "sqlite", "flink"], raises=( - sa.exc.OperationalError, OracleDatabaseError, com.UnsupportedOperationError, com.OperationNotDefinedError, diff --git a/ibis/backends/tests/test_api.py b/ibis/backends/tests/test_api.py index 5687eccd64e4d..c41ed2f96767d 100644 --- a/ibis/backends/tests/test_api.py +++ b/ibis/backends/tests/test_api.py @@ -54,11 +54,11 @@ def test_list_tables(con): assert all(isinstance(table, str) for table in tables) -def test_tables_accessor_mapping(backend, con): +def test_tables_accessor_mapping(con): if con.name == "snowflake": pytest.skip("snowflake sometimes counts more tables than are around") - name = backend.default_identifier_case_fn("functional_alltypes") + name = "functional_alltypes" assert isinstance(con.tables[name], ir.Table) @@ -72,8 +72,8 @@ def test_tables_accessor_mapping(backend, con): assert TEST_TABLES.keys() & set(map(str.lower, con.tables)) -def test_tables_accessor_getattr(backend, con): - name = backend.default_identifier_case_fn("functional_alltypes") +def test_tables_accessor_getattr(con): + name = "functional_alltypes" assert isinstance(getattr(con.tables, name), ir.Table) with pytest.raises(AttributeError, match="doesnt_exist"): @@ -85,8 +85,8 @@ def test_tables_accessor_getattr(backend, con): con.tables._private_attr # noqa: B018 -def 
test_tables_accessor_tab_completion(backend, con): - name = backend.default_identifier_case_fn("functional_alltypes") +def test_tables_accessor_tab_completion(con): + name = "functional_alltypes" attrs = dir(con.tables) assert name in attrs assert "keys" in attrs # type methods also present @@ -95,8 +95,8 @@ def test_tables_accessor_tab_completion(backend, con): assert name in keys -def test_tables_accessor_repr(backend, con): - name = backend.default_identifier_case_fn("functional_alltypes") +def test_tables_accessor_repr(con): + name = "functional_alltypes" result = repr(con.tables) assert f"- {name}" in result diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index a5fbaa58ff96e..48d906189e063 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -368,7 +368,12 @@ def test_unnest_no_nulls(backend): reason="all the input arrays must have same number of dimensions", ) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) -@pytest.mark.broken(["risingwave"], raises=AssertionError) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + strict=False, + reason="row ordering is not guaranteed", +) def test_unnest_default_name(backend): array_types = backend.array_types df = array_types.execute() diff --git a/ibis/backends/tests/test_benchmarks.py b/ibis/backends/tests/test_benchmarks.py index fbfd0977887e5..94c680ab9c17b 100644 --- a/ibis/backends/tests/test_benchmarks.py +++ b/ibis/backends/tests/test_benchmarks.py @@ -10,7 +10,6 @@ import numpy as np import pandas as pd import pytest -import sqlalchemy as sa from packaging.version import parse as vparse import ibis @@ -194,7 +193,7 @@ def test_compile(benchmark, module, expr_fn, t, base, large_expr): expr = expr_fn(t, base, large_expr) try: benchmark(mod.compile, expr) - except (sa.exc.NoSuchModuleError, ImportError) as e: # delayed imports + except ImportError as e: # delayed imports pytest.skip(str(e)) @@ -696,10 
+695,7 @@ def test_compile_with_drops( except (AttributeError, ImportError) as e: pytest.skip(str(e)) else: - try: - benchmark(mod.compile, expr) - except sa.exc.NoSuchModuleError as e: - pytest.skip(str(e)) + benchmark(mod.compile, expr) def test_repr_join(benchmark, customers, orders, orders_items, products): diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 6fabd77ec840a..3339a2c250505 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -16,7 +16,7 @@ import pyarrow as pa import pytest import rich.console -import sqlalchemy as sa +import toolz from packaging.version import parse as vparse from pytest import mark, param @@ -26,11 +26,16 @@ import ibis.expr.operations as ops from ibis.backends.conftest import ALL_BACKENDS from ibis.backends.tests.errors import ( + ExaQueryError, + OracleDatabaseError, PsycoPg2InternalError, + PsycoPg2UndefinedObject, Py4JJavaError, - PyDruidProgrammingError, + PyODBCProgrammingError, + SnowflakeProgrammingError, + TrinoUserError, ) -from ibis.util import gen_name, guid +from ibis.util import gen_name if TYPE_CHECKING: from ibis.backends.base import BaseBackend @@ -42,6 +47,8 @@ def new_schema(): def _create_temp_table_with_schema(backend, con, temp_table_name, schema, data=None): + if con.name == "druid": + pytest.xfail("druid doesn't implement create_table") temporary = con.create_table(temp_table_name, schema=schema) assert temporary.to_pandas().empty @@ -60,38 +67,29 @@ def _create_temp_table_with_schema(backend, con, temp_table_name, schema, data=N @pytest.mark.parametrize( - "lamduh", - [ - (lambda df: df), - param( - lambda df: pa.Table.from_pandas(df), marks=pytest.mark.notimpl(["impala"]) - ), - ], - ids=["dataframe", "pyarrow table"], + "func", [toolz.identity, pa.Table.from_pandas], ids=["dataframe", "pyarrow_table"] ) @pytest.mark.parametrize( "sch", [ - param(None, id="no schema"), - param( - ibis.schema( - [ - ("first_name", "string"), - 
("last_name", "string"), - ("department_name", "string"), - ("salary", "float64"), - ] - ), - id="schema", + None, + ibis.schema( + dict( + first_name="string", + last_name="string", + department_name="string", + salary="float64", + ) ), ], + ids=["no_schema", "schema"], ) -@pytest.mark.notimpl(["druid", "impala"]) +@pytest.mark.notimpl(["druid"]) @pytest.mark.notimpl( ["flink"], reason="Flink backend supports creating only TEMPORARY VIEW for in-memory data.", ) -def test_create_table(backend, con, temp_table, lamduh, sch): +def test_create_table(backend, con, temp_table, func, sch): df = pd.DataFrame( { "first_name": ["A", "B", "C"], @@ -101,7 +99,7 @@ def test_create_table(backend, con, temp_table, lamduh, sch): } ) - con.create_table(temp_table, lamduh(df), schema=sch) + con.create_table(temp_table, func(df), schema=sch) result = ( con.table(temp_table).execute().sort_values("first_name").reset_index(drop=True) ) @@ -122,7 +120,6 @@ def test_create_table(backend, con, temp_table, lamduh, sch): ["pyspark", "trino", "exasol", "risingwave"], reason="No support for temp tables", ), - pytest.mark.never(["polars"], reason="Everything in-memory is temp"), pytest.mark.broken(["mssql"], reason="Incorrect temp table syntax"), pytest.mark.broken( ["bigquery"], @@ -130,7 +127,16 @@ def test_create_table(backend, con, temp_table, lamduh, sch): ), ], ), - param(False, True, id="no temp, overwrite"), + param( + False, + True, + marks=[ + pytest.mark.notyet( + ["polars"], raises=com.IbisError, reason="all tables are ephemeral" + ) + ], + id="no temp, overwrite", + ), param( True, False, @@ -140,7 +146,6 @@ def test_create_table(backend, con, temp_table, lamduh, sch): ["pyspark", "trino", "exasol", "risingwave"], reason="No support for temp tables", ), - pytest.mark.never(["polars"], reason="Everything in-memory is temp"), pytest.mark.broken(["mssql"], reason="Incorrect temp table syntax"), pytest.mark.broken( ["bigquery"], @@ -176,7 +181,8 @@ def 
test_create_table_overwrite_temp(backend, con, temp_table, temp, overwrite): [(lambda df: df), (lambda df: pa.Table.from_pandas(df))], ids=["dataframe", "pyarrow table"], ) -def test_load_data_sqlalchemy(alchemy_backend, alchemy_con, alchemy_temp_table, lamduh): +@pytest.mark.notyet(["druid"], raises=NotImplementedError) +def test_load_data(backend, con, temp_table, lamduh): sch = ibis.schema( [ ("first_name", "string"), @@ -196,15 +202,12 @@ def test_load_data_sqlalchemy(alchemy_backend, alchemy_con, alchemy_temp_table, ) obj = lamduh(df) - alchemy_con.create_table(alchemy_temp_table, obj, schema=sch, overwrite=True) + con.create_table(temp_table, obj, schema=sch, overwrite=True) result = ( - alchemy_con.table(alchemy_temp_table) - .execute() - .sort_values("first_name") - .reset_index(drop=True) + con.table(temp_table).execute().sort_values("first_name").reset_index(drop=True) ) - alchemy_backend.assert_frame_equal(df, result) + backend.assert_frame_equal(df, result) @mark.parametrize( @@ -291,16 +294,6 @@ def test_create_table_from_schema(con, new_schema, temp_table): assert result == new_table.schema() -@pytest.fixture(scope="session") -def tmpcon(alchemy_con): - """A fixture to scope the connection for temp table testing. - - This prevents resetting the connection for subsequent tests that may depend - on connection state persisting across tests. 
- """ - return alchemy_con._from_url(alchemy_con.con.url) - - @mark.broken( ["oracle"], reason="oracle temp tables aren't cleaned up on reconnect -- they need to " @@ -310,28 +303,34 @@ def tmpcon(alchemy_con): @mark.never( ["mssql"], reason="mssql supports support temporary tables through naming conventions", + raises=PyODBCProgrammingError, ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") -@pytest.mark.never( +@pytest.mark.notimpl( + ["impala", "pyspark"], + reason="temporary tables not implemented", + raises=NotImplementedError, +) +@pytest.mark.notyet( ["risingwave"], raises=PsycoPg2InternalError, - reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", + reason="truncate not supported upstream", ) -def test_create_temporary_table_from_schema(tmpcon, new_schema): - temp_table = f"_{guid()}" - table = tmpcon.create_table(temp_table, schema=new_schema, temp=True) +def test_create_temporary_table_from_schema(con_no_data, new_schema): + temp_table = gen_name(f"test_{con_no_data.name}_tmp") + table = con_no_data.create_table(temp_table, schema=new_schema, temp=True) # verify table exist in the current session - backend_mapping = backend_type_mapping.get(tmpcon.name, dict()) + backend_mapping = backend_type_mapping.get(con_no_data.name, dict()) for column_name, column_type in table.schema().items(): assert ( backend_mapping.get(new_schema[column_name], new_schema[column_name]) == column_type ) - tmpcon.reconnect() + con_no_data.reconnect() # verify table no longer exist after reconnect - assert temp_table not in tmpcon.tables.keys() + assert temp_table not in con_no_data.tables.keys() @mark.notimpl( @@ -447,137 +446,127 @@ def test_separate_database(ddl_con, alternate_current_database): @pytest.fixture -def employee_empty_temp_table(alchemy_backend, alchemy_con, test_employee_schema): - temp_table_name = f"temp_employee_empty_table_{guid()[:6]}" - _create_temp_table_with_schema( - alchemy_backend, - alchemy_con, - 
temp_table_name, - test_employee_schema, - ) +def employee_empty_temp_table(backend, con, test_employee_schema): + temp_table_name = gen_name("temp_employee_empty_table") + _create_temp_table_with_schema(backend, con, temp_table_name, test_employee_schema) yield temp_table_name - alchemy_con.drop_table(temp_table_name, force=True) + con.drop_table(temp_table_name, force=True) @pytest.fixture def employee_data_1_temp_table( - alchemy_backend, - alchemy_con, - test_employee_schema, - test_employee_data_1, + backend, con, test_employee_schema, test_employee_data_1 ): - temp_table_name = f"temp_employee_data_1_{guid()[:6]}" + temp_table_name = gen_name("temp_employee_data_1") _create_temp_table_with_schema( - alchemy_backend, - alchemy_con, - temp_table_name, - test_employee_schema, - data=test_employee_data_1, + backend, con, temp_table_name, test_employee_schema, data=test_employee_data_1 ) - assert temp_table_name in alchemy_con.list_tables() + assert temp_table_name in con.list_tables() yield temp_table_name - alchemy_con.drop_table(temp_table_name, force=True) + con.drop_table(temp_table_name, force=True) @pytest.fixture def employee_data_2_temp_table( - alchemy_backend, - alchemy_con, - test_employee_schema, - test_employee_data_2, + backend, con, test_employee_schema, test_employee_data_2 ): - temp_table_name = f"temp_employee_data_2_{guid()[:6]}" + temp_table_name = gen_name("temp_employee_data_2") _create_temp_table_with_schema( - alchemy_backend, - alchemy_con, - temp_table_name, - test_employee_schema, - data=test_employee_data_2, + backend, con, temp_table_name, test_employee_schema, data=test_employee_data_2 ) yield temp_table_name - alchemy_con.drop_table(temp_table_name, force=True) + con.drop_table(temp_table_name, force=True) +@pytest.mark.notimpl( + ["polars", "pandas", "dask"], reason="`insert` method not implemented" +) def test_insert_no_overwrite_from_dataframe( - alchemy_backend, - alchemy_con, - test_employee_data_2, - employee_empty_temp_table, 
+ backend, con, test_employee_data_2, employee_empty_temp_table ): - temporary = alchemy_con.table(employee_empty_temp_table) - alchemy_con.insert( - employee_empty_temp_table, - obj=test_employee_data_2, - overwrite=False, - ) + temporary = con.table(employee_empty_temp_table) + con.insert(employee_empty_temp_table, obj=test_employee_data_2, overwrite=False) result = temporary.execute() assert len(result) == 3 - alchemy_backend.assert_frame_equal( + backend.assert_frame_equal( result.sort_values("first_name").reset_index(drop=True), test_employee_data_2.sort_values("first_name").reset_index(drop=True), ) +@pytest.mark.notimpl( + ["polars", "pandas", "dask"], reason="`insert` method not implemented" +) +@pytest.mark.notyet( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="truncate not supported upstream", +) +@pytest.mark.notyet( + ["datafusion"], raises=Exception, reason="DELETE DML not implemented upstream" +) +@pytest.mark.notyet( + ["trino"], raises=TrinoUserError, reason="requires a non-memory connector" +) +@pytest.mark.notyet(["druid"], raises=NotImplementedError) def test_insert_overwrite_from_dataframe( - alchemy_backend, - alchemy_con, - employee_data_1_temp_table, - test_employee_data_2, + backend, con, employee_data_1_temp_table, test_employee_data_2 ): - temporary = alchemy_con.table(employee_data_1_temp_table) + temporary = con.table(employee_data_1_temp_table) - alchemy_con.insert( - employee_data_1_temp_table, - obj=test_employee_data_2, - overwrite=True, - ) + con.insert(employee_data_1_temp_table, obj=test_employee_data_2, overwrite=True) result = temporary.execute() assert len(result) == 3 - alchemy_backend.assert_frame_equal( + backend.assert_frame_equal( result.sort_values("first_name").reset_index(drop=True), test_employee_data_2.sort_values("first_name").reset_index(drop=True), ) +@pytest.mark.notimpl( + ["polars", "pandas", "dask"], reason="`insert` method not implemented" +) def test_insert_no_overwrite_from_expr( - 
alchemy_backend, - alchemy_con, - employee_empty_temp_table, - employee_data_2_temp_table, + backend, con, employee_empty_temp_table, employee_data_2_temp_table ): - temporary = alchemy_con.table(employee_empty_temp_table) - from_table = alchemy_con.table(employee_data_2_temp_table) + temporary = con.table(employee_empty_temp_table) + from_table = con.table(employee_data_2_temp_table) - alchemy_con.insert( - employee_empty_temp_table, - obj=from_table, - overwrite=False, - ) + con.insert(employee_empty_temp_table, obj=from_table, overwrite=False) result = temporary.execute() assert len(result) == 3 - alchemy_backend.assert_frame_equal( + backend.assert_frame_equal( result.sort_values("first_name").reset_index(drop=True), from_table.execute().sort_values("first_name").reset_index(drop=True), ) +@pytest.mark.notimpl( + ["polars", "pandas", "dask"], reason="`insert` method not implemented" +) +@pytest.mark.notyet( + ["datafusion"], raises=Exception, reason="DELETE DML not implemented upstream" +) +@pytest.mark.notyet( + ["trino"], + raises=TrinoUserError, + reason="requires a non-memory connector for truncation", +) +@pytest.mark.notyet( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="truncate not supported upstream", +) def test_insert_overwrite_from_expr( - alchemy_backend, - alchemy_con, - employee_data_1_temp_table, - employee_data_2_temp_table, + backend, con, employee_data_1_temp_table, employee_data_2_temp_table ): - temporary = alchemy_con.table(employee_data_1_temp_table) - from_table = alchemy_con.table(employee_data_2_temp_table) + temporary = con.table(employee_data_1_temp_table) + from_table = con.table(employee_data_2_temp_table) - alchemy_con.insert( - employee_data_1_temp_table, - obj=from_table, - overwrite=True, - ) + con.insert(employee_data_1_temp_table, obj=from_table, overwrite=True) result = temporary.execute() assert len(result) == 3 - alchemy_backend.assert_frame_equal( + backend.assert_frame_equal( 
result.sort_values("first_name").reset_index(drop=True), from_table.execute().sort_values("first_name").reset_index(drop=True), ) @@ -586,19 +575,22 @@ def test_insert_overwrite_from_expr( @pytest.mark.notyet( ["trino"], reason="memory connector doesn't allow writing to tables" ) +@pytest.mark.notimpl( + ["polars", "pandas", "dask"], reason="`insert` method not implemented" +) @pytest.mark.notyet( - ["oracle", "exasol"], - reason="No support for in-place multirow inserts", - raises=sa.exc.CompileError, + ["datafusion"], raises=Exception, reason="DELETE DML not implemented upstream" ) -def test_insert_overwrite_from_list( - alchemy_con, - employee_data_1_temp_table, -): +@pytest.mark.notyet( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="truncate not supported upstream", +) +def test_insert_overwrite_from_list(con, employee_data_1_temp_table): def _emp(a, b, c, d): return dict(first_name=a, last_name=b, department_name=c, salary=d) - alchemy_con.insert( + con.insert( employee_data_1_temp_table, [ _emp("Adam", "Smith", "Accounting", 50000.0), @@ -608,77 +600,85 @@ def _emp(a, b, c, d): overwrite=True, ) - assert len(alchemy_con.table(employee_data_1_temp_table).execute()) == 3 + assert len(con.table(employee_data_1_temp_table).execute()) == 3 -def test_insert_from_memtable(alchemy_con, alchemy_temp_table): +@pytest.mark.notimpl( + ["polars", "dask", "pandas"], + raises=AttributeError, + reason="`insert` method not implemented", +) +@pytest.mark.notyet(["druid"], raises=NotImplementedError) +def test_insert_from_memtable(con, temp_table): df = pd.DataFrame({"x": range(3)}) - table_name = alchemy_temp_table + table_name = temp_table mt = ibis.memtable(df) - alchemy_con.create_table(table_name, schema=mt.schema()) - alchemy_con.insert(table_name, mt) - alchemy_con.insert(table_name, mt) + con.create_table(table_name, schema=mt.schema()) + con.insert(table_name, mt) + con.insert(table_name, mt) - table = alchemy_con.tables[table_name] + table = 
con.tables[table_name] assert len(table.execute()) == 6 - assert alchemy_con.tables[table_name].schema() == ibis.schema({"x": "int64"}) + assert con.tables[table_name].schema() == ibis.schema({"x": "int64"}) @pytest.mark.notyet( - ["oracle"], - raises=AttributeError, - reason="oracle doesn't support the common notion of a database", -) -@pytest.mark.notyet( - ["exasol"], + ["bigquery", "oracle", "dask", "exasol", "polars", "pandas", "druid"], raises=AttributeError, - reason="exasol doesn't support the common notion of a database", + reason="doesn't support the common notion of a database", ) -def test_list_databases(alchemy_con): +def test_list_databases(con): # Every backend has its own databases test_databases = { - "sqlite": {"main"}, - "postgres": {"postgres", "ibis_testing"}, - "risingwave": {"dev"}, + "clickhouse": {"system", "default", "ibis_testing"}, + "datafusion": {"datafusion"}, + "duckdb": {"memory"}, + "exasol": set(), + "impala": set(), "mssql": {"ibis_testing"}, "mysql": {"ibis_testing", "information_schema"}, - "duckdb": {"memory"}, + "oracle": set(), + "postgres": {"postgres", "ibis_testing"}, + "risingwave": {"dev"}, "snowflake": {"IBIS_TESTING"}, + "pyspark": set(), + "sqlite": {"main"}, "trino": {"memory"}, - "oracle": set(), - "exasol": set(), } - assert test_databases[alchemy_con.name] <= set(alchemy_con.list_databases()) + result = set(con.list_databases()) + assert test_databases[con.name] <= result -@pytest.mark.never( - ["bigquery", "postgres", "risingwave", "mssql", "mysql", "oracle"], - reason="backend does not support client-side in-memory tables", - raises=(sa.exc.OperationalError, TypeError, sa.exc.InterfaceError), +@pytest.mark.notyet( + ["postgres", "snowflake"], + raises=TypeError, + reason="backend does not support unsigned integer types", ) +@pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError) +@pytest.mark.notyet(["pyspark"], raises=com.IbisTypeError) +@pytest.mark.notyet(["bigquery", "impala"], 
raises=com.UnsupportedBackendType) @pytest.mark.notyet( - ["trino"], reason="memory connector doesn't allow writing to tables" + ["postgres"], raises=PsycoPg2UndefinedObject, reason="no unsigned int types" ) -@pytest.mark.notimpl(["exasol"]) -def test_in_memory(alchemy_backend, alchemy_temp_table): - con = getattr(ibis, alchemy_backend.name()).connect(":memory:") - with con.begin() as c: - c.exec_driver_sql(f"CREATE TABLE {alchemy_temp_table} (x int)") - assert alchemy_temp_table in con.list_tables() - - @pytest.mark.notyet( - ["mssql", "mysql", "postgres", "snowflake", "sqlite", "trino"], - raises=TypeError, - reason="backend does not support unsigned integer types", + ["oracle"], raises=OracleDatabaseError, reason="no unsigned int types" +) +@pytest.mark.notyet(["exasol"], raises=ExaQueryError, reason="no unsigned int types") +@pytest.mark.notyet(["datafusion"], raises=Exception, reason="no unsigned int types") +@pytest.mark.notyet(["druid"], raises=NotImplementedError) +@pytest.mark.notyet(["snowflake"], raises=SnowflakeProgrammingError) +@pytest.mark.notyet( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="unsigned integers are not supported", ) -def test_unsigned_integer_type(alchemy_con, alchemy_temp_table): - alchemy_con.create_table( - alchemy_temp_table, +def test_unsigned_integer_type(con, temp_table): + con.create_table( + temp_table, schema=ibis.schema(dict(a="uint8", b="uint16", c="uint32", d="uint64")), overwrite=True, ) - assert alchemy_temp_table in alchemy_con.list_tables() + assert temp_table in con.list_tables() @pytest.mark.backend @@ -1155,8 +1155,10 @@ def test_set_backend_url(url, monkeypatch): @pytest.mark.notimpl( ["snowflake"], reason="scale not implemented in ibis's snowflake backend" ) -@pytest.mark.broken(["oracle"], reason="oracle doesn't like `DESCRIBE` from sqlalchemy") -@pytest.mark.broken(["druid"], reason="sqlalchemy dialect is broken") +@pytest.mark.broken( + ["oracle"], reason="oracle doesn't allow DESCRIBE outside of 
its CLI" +) +@pytest.mark.broken(["druid"], reason="dialect is broken") @pytest.mark.notimpl( ["flink"], raises=com.IbisError, @@ -1334,7 +1336,7 @@ def test_persist_expression_repeated_cache(alltypes): @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @pytest.mark.never( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", ) @mark.notimpl( @@ -1379,7 +1381,7 @@ def gen_test_name(con: BaseBackend) -> str: reason="overwriting not implemented in ibis for this backend", ) @mark.broken( - ["druid"], raises=PyDruidProgrammingError, reason="generated SQL fails to parse" + ["druid"], raises=NotImplementedError, reason="generated SQL fails to parse" ) @mark.notimpl(["impala"], reason="impala doesn't support memtable") @mark.notimpl(["pyspark"]) diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index b92abea1470cf..11e4e9f2e9a48 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -1270,91 +1270,6 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): backend.assert_series_equal(result.astype("float64"), expected) -@pytest.mark.parametrize( - ("default_precisions", "default_scales"), - [ - ( - { - "postgres": None, - "risingwave": None, - "mysql": 10, - "snowflake": 38, - "trino": 18, - "sqlite": None, - "mssql": None, - "oracle": 38, - }, - { - "postgres": None, - "risingwave": None, - "mysql": 0, - "snowflake": 0, - "trino": 3, - "sqlite": None, - "mssql": None, - "oracle": 0, - }, - ) - ], -) -@pytest.mark.never( - [ - "bigquery", - "clickhouse", - "dask", - "datafusion", - "duckdb", - "impala", - "oracle", - "pandas", - "pyspark", - "polars", - "flink", - "sqlite", - "snowflake", - "trino", - "postgres", - "risingwave", - "mysql", - "druid", - "mssql", - "exasol", - ], - reason="Not SQLAlchemy backends", -) -def test_sa_default_numeric_precision_and_scale( - 
con, backend, default_precisions, default_scales, temp_table -): - sa = pytest.importorskip("sqlalchemy") - - default_precision = default_precisions[backend.name()] - default_scale = default_scales[backend.name()] - - typespec = [ - # name, sqlalchemy type, ibis type - ("n1", sa.NUMERIC, dt.Decimal(default_precision, default_scale)), - ("n2", sa.NUMERIC(5), dt.Decimal(5, default_scale)), - ("n3", sa.NUMERIC(None, 4), dt.Decimal(default_precision, 4)), - ("n4", sa.NUMERIC(10, 2), dt.Decimal(10, 2)), - ] - - sqla_types = [] - ibis_types = [] - for name, t, ibis_type in typespec: - sqla_types.append(sa.Column(name, t, nullable=True)) - ibis_types.append((name, ibis_type(nullable=True))) - - table = sa.Table(temp_table, sa.MetaData(), *sqla_types, quote=True) - with con.begin() as bind: - table.create(bind=bind, checkfirst=True) - - # Check that we can correctly recover the default precision and scale. - schema = con._schema_from_sqla_table(table) - expected = ibis.schema(ibis_types) - - assert_equal(schema, expected) - - @pytest.mark.notimpl(["dask", "pandas", "polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) @pytest.mark.notimpl( @@ -1472,7 +1387,7 @@ def test_constants(con, const): param(lambda t: t.int_col, lambda _: 3, id="col_scalar"), ], ) -@pytest.mark.notimpl(["exasol"], raises=(ExaQueryError)) +@pytest.mark.notimpl(["exasol"], raises=ExaQueryError) @flink_no_bitwise def test_bitwise_columns(backend, con, alltypes, df, op, left_fn, right_fn): expr = op(left_fn(alltypes), right_fn(alltypes)).name("tmp") @@ -1509,7 +1424,7 @@ def test_bitwise_columns(backend, con, alltypes, df, op, left_fn, right_fn): ], ) @pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError) -@pytest.mark.notimpl(["exasol"], raises=(ExaQueryError)) +@pytest.mark.notimpl(["exasol"], raises=ExaQueryError) @flink_no_bitwise def test_bitwise_shift(backend, alltypes, df, op, left_fn, right_fn): expr = op(left_fn(alltypes), 
right_fn(alltypes)).name("tmp") diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 64824b6124629..8564a4d4a2820 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -354,8 +354,6 @@ def test_csv_reregister_schema(con, tmp_path): ) # For a full file scan, expect correct schema based on final row - # We also use the same `table_name` for both tests to ensure that - # the table is re-reflected in sqlalchemy foo_table = con.register(foo, table_name="same") result_schema = foo_table.schema() @@ -389,11 +387,9 @@ def test_register_garbage(con, monkeypatch): # monkeypatch to avoid downloading extensions in tests monkeypatch.setattr(con, "_load_extensions", lambda x: True) - sa = pytest.importorskip("sqlalchemy") duckdb = pytest.importorskip("duckdb") with pytest.raises( - (sa.exc.OperationalError, duckdb.IOException), - match="No files found that match the pattern", + duckdb.IOException, match="No files found that match the pattern" ): con.read_csv("garbage_notafile") diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 8d0b06e85cb03..e1bca4bb1ff2e 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -1941,11 +1941,7 @@ def test_time_literal(con, backend): raises=com.OperationNotDefinedError, reason="backend doesn't have a time datatype", ) -@pytest.mark.notyet( - ["druid"], - raises=PyDruidProgrammingError, - reason="druid sqlalchemy dialect fails to compile datetime types", -) +@pytest.mark.notyet(["druid"], raises=PyDruidProgrammingError) @pytest.mark.broken( ["sqlite"], raises=AssertionError, reason="SQLite returns Timedelta from execution" ) @@ -1962,11 +1958,7 @@ def test_time_literal(con, backend): raises=AssertionError, reason="doesn't have enough precision to capture microseconds", ), - pytest.mark.notyet( - ["trino"], - raises=AssertionError, - reason="has enough precision, but sqlalchemy 
dialect drops them", - ), + pytest.mark.notyet(["trino"], raises=AssertionError), pytest.mark.notimpl( ["flink"], raises=AssertionError, @@ -2017,13 +2009,7 @@ def test_extract_time_from_timestamp(con, microsecond): raises=ImpalaHiveServer2Error, ) @pytest.mark.broken( - ["mysql"], - "The backend implementation is broken. " - "If SQLAlchemy < 2 is installed, test fails with the following exception:" - "AttributeError: 'TextClause' object has no attribute 'label'" - "If SQLAlchemy >=2 is installed, test fails with the following exception:" - "NotImplementedError", - raises=MySQLProgrammingError, + ["mysql"], "The backend implementation is broken. ", raises=MySQLProgrammingError ) @pytest.mark.broken( ["bigquery", "duckdb"], diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py index 75021f577133f..4aaa309ffab85 100644 --- a/ibis/backends/tests/test_udf.py +++ b/ibis/backends/tests/test_udf.py @@ -57,9 +57,7 @@ def num_vowels(s: str, include_y: bool = False) -> int: @mark.notimpl(["polars"]) @mark.notyet(["datafusion"], raises=NotImplementedError) @mark.notyet( - ["sqlite"], - raises=com.IbisTypeError, - reason="sqlite doesn't support map types", + ["sqlite"], raises=com.IbisTypeError, reason="sqlite doesn't support map types" ) def test_map_udf(batting): @udf.scalar.python diff --git a/ibis/backends/tests/tpch/test_h08.py b/ibis/backends/tests/tpch/test_h08.py index 971a83c4c3526..651a725b1fc8a 100644 --- a/ibis/backends/tests/tpch/test_h08.py +++ b/ibis/backends/tests/tpch/test_h08.py @@ -1,17 +1,11 @@ from __future__ import annotations -import pytest - import ibis from .conftest import add_date, tpch_test @tpch_test -@pytest.mark.xfail_version( - trino=["sqlalchemy>=2"], - reason="slightly different code is generated for sqlalchemy 2 for aggregations", -) def test_tpc_h08(part, supplier, region, lineitem, orders, customer, nation): """National Market Share Query (Q8)""" NATION = "BRAZIL" diff --git a/ibis/backends/tests/tpch/test_h14.py 
b/ibis/backends/tests/tpch/test_h14.py index f72bbcaf6c2b6..efc84d5b3aba7 100644 --- a/ibis/backends/tests/tpch/test_h14.py +++ b/ibis/backends/tests/tpch/test_h14.py @@ -1,17 +1,11 @@ from __future__ import annotations -import pytest - import ibis from .conftest import add_date, tpch_test @tpch_test -@pytest.mark.xfail_version( - trino=["sqlalchemy>=2"], - reason="slightly different code is generated for sqlalchemy 2 for aggregations", -) def test_tpc_h14(part, lineitem): """Promotion Effect Query (Q14) diff --git a/ibis/backends/tests/tpch/test_h17.py b/ibis/backends/tests/tpch/test_h17.py index 0d112d048c910..b451ea322d219 100644 --- a/ibis/backends/tests/tpch/test_h17.py +++ b/ibis/backends/tests/tpch/test_h17.py @@ -1,15 +1,9 @@ from __future__ import annotations -import pytest - from .conftest import tpch_test @tpch_test -@pytest.mark.xfail_version( - trino=["sqlalchemy>=2"], - reason="slightly different code is generated for sqlalchemy 2 for aggregations", -) def test_tpc_h17(lineitem, part): """Small-Quantity-Order Revenue Query (Q17) diff --git a/ibis/backends/trino/__init__.py b/ibis/backends/trino/__init__.py index 619a4af23717d..d9b15b4c22344 100644 --- a/ibis/backends/trino/__init__.py +++ b/ibis/backends/trino/__init__.py @@ -242,7 +242,7 @@ def do_connect( schema: str | None = None, source: str | None = None, timezone: str = "UTC", - **connect_args, + **kwargs, ) -> None: """Connect to Trino. @@ -264,9 +264,9 @@ def do_connect( Application name passed to Trino timezone Timezone to use for the connection - connect_args - Additional keyword arguments passed directly to SQLAlchemy's - `create_engine` + kwargs + Additional keyword arguments passed directly to the + `trino.dbapi.connect` API. 
Examples -------- @@ -296,7 +296,7 @@ def do_connect( schema=schema, source=source or "ibis", timezone=timezone, - **connect_args, + **kwargs, ) @contextlib.contextmanager diff --git a/ibis/backends/trino/tests/conftest.py b/ibis/backends/trino/tests/conftest.py index c8f0839255e73..7b1d256e80102 100644 --- a/ibis/backends/trino/tests/conftest.py +++ b/ibis/backends/trino/tests/conftest.py @@ -2,7 +2,7 @@ import os import subprocess -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pytest import sqlglot as sg @@ -125,18 +125,10 @@ def load_tpch(self) -> None: c.execute(sql) - def _load_data(self, **_: Any) -> None: - """Load test data into a backend.""" - with self.connection.begin() as cur: - for stmt in self.ddl_script: - cur.execute(stmt) - def _tpch_table(self, name: str): from ibis import _ - table = self.connection.table( - self.default_identifier_case_fn(name), schema="ibis_sf1", database="hive" - ) + table = self.connection.table(name, schema="ibis_sf1", database="hive") table = table.mutate(s.across(s.of_type("double"), _.cast("decimal(15, 2)"))) return table diff --git a/ibis/expr/sql.py b/ibis/expr/sql.py index f6b1b0cdda2d8..d3b1c7e833425 100644 --- a/ibis/expr/sql.py +++ b/ibis/expr/sql.py @@ -364,8 +364,8 @@ def to_sql(expr: ir.Expr, dialect: str | None = None, **kwargs) -> SQLString: try: backend = expr._find_backend() except com.IbisError: - # default to duckdb for sqlalchemy compilation because it supports - # the widest array of ibis features for SQL backends + # default to duckdb for SQL compilation because it supports the + # widest array of ibis features for SQL backends backend = ibis.duckdb read = "duckdb" write = ibis.options.sql.default_dialect diff --git a/ibis/formats/__init__.py b/ibis/formats/__init__.py index 54467a738ae05..b089ffb6e5479 100644 --- a/ibis/formats/__init__.py +++ b/ibis/formats/__init__.py @@ -18,8 +18,7 @@ class TypeMapper(Generic[T]): - # `T` is the format-specific type object, e.g. 
pyarrow.DataType or - # sqlalchemy.types.TypeEngine + # `T` is the format-specific type object, e.g. pyarrow.DataType @classmethod def from_ibis(cls, dtype: DataType) -> T: diff --git a/poetry.lock b/poetry.lock index 028062bce5a94..a4cc454cd3db7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2220,7 +2220,7 @@ test = ["coverage", "mock (>=4)", "pytest (>=7)", "pytest-cov", "pytest-mock (>= name = "greenlet" version = "3.0.3" description = "Lightweight in-process concurrent programming" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "greenlet-3.0.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9da2bd29ed9e4f15955dd1595ad7bc9320308a3b766ef7f837e23ad4b4aac31a"}, @@ -4682,6 +4682,7 @@ files = [ [package.dependencies] requests = "*" +sqlalchemy = {version = "*", optional = true, markers = "extra == \"sqlalchemy\""} [package.extras] async = ["tornado"] @@ -6289,7 +6290,7 @@ jsonschema = ">=3.0" name = "sqlalchemy" version = "1.4.51" description = "Database Abstraction Library" -optional = false +optional = true python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ {file = "SQLAlchemy-1.4.51-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:1a09d5bd1a40d76ad90e5570530e082ddc000e1d92de495746f6257dc08f166b"}, @@ -7296,7 +7297,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["black", "clickhouse-connect", "dask", "datafusion", "db-dtypes", "deltalake", "duckdb", "geopandas", "google-cloud-bigquery", "google-cloud-bigquery-storage", "graphviz", "impyla", "oracledb", "packaging", "pins", "polars", "psycopg2", "pydata-google-auth", "pydruid", "pyexasol", "pymysql", "pyodbc", "pyspark", "regex", "shapely", "snowflake-connector-python", "sqlalchemy", "trino"] +all = ["black", "clickhouse-connect", "dask", "datafusion", "db-dtypes", "deltalake", "duckdb", "geopandas", "google-cloud-bigquery", "google-cloud-bigquery-storage", 
"graphviz", "impyla", "oracledb", "packaging", "pins", "polars", "psycopg2", "pydata-google-auth", "pydruid", "pyexasol", "pymysql", "pyodbc", "pyspark", "regex", "shapely", "snowflake-connector-python", "trino"] bigquery = ["db-dtypes", "google-cloud-bigquery", "google-cloud-bigquery-storage", "pydata-google-auth"] clickhouse = ["clickhouse-connect"] dask = ["dask", "regex"] @@ -7326,4 +7327,4 @@ visualization = ["graphviz"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "3c0e25e6963a7b7c69470015c6f0ba4616fdc98188bfeaa16583b15dd8260ecf" +content-hash = "a669c8ae57211aeaf0a14696134f4a708a18cfd373f12e671ce6fe4eddb53948" diff --git a/pyproject.toml b/pyproject.toml index 5d1b06db63f5d..63e29aeccf7c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,9 @@ pins = { version = ">=0.8.3,<1", extras = ["gcs"], optional = true } polars = { version = ">=0.19.3,<1", optional = true } psycopg2 = { version = ">=2.8.4,<3", optional = true } pydata-google-auth = { version = ">=1.4.0,<2", optional = true } -pydruid = { version = ">=0.6.5,<1", optional = true } +# we don't use sqlalchemy, but pydruid inadvertently requires it, +# see https://github.com/druid-io/pydruid/issues/313 +pydruid = { version = ">=0.6.5,<1", optional = true, extras = ["sqlalchemy"] } pyexasol = { version = ">=0.25.2,<1", optional = true, extras = ["pandas"] } pymysql = { version = ">=1,<2", optional = true } pyodbc = { version = ">=4.0.39,<6", optional = true } @@ -87,7 +89,6 @@ shapely = { version = ">=2,<3", optional = true } # we don't support arbitrarily old versions of this library due to security # issues with versions <3.0.2 snowflake-connector-python = { version = ">=3.0.2,<4,!=3.3.0b1", optional = true } -sqlalchemy = { version = ">=1.4,<3", optional = true } trino = { version = ">=0.321,<1", optional = true } [tool.poetry.group.dev.dependencies] @@ -120,7 +121,6 @@ pytest-repeat = ">=0.9.1,<0.10" pytest-snapshot = ">=0.9.0,<1" pytest-xdist = ">=2.3.0,<4" 
requests = ">=2,<3" -sqlalchemy = ">=1.4,<3" [tool.poetry.group.docs.dependencies] altair = { version = ">=5.0.1,<6", python = ">=3.10,<3.13" } @@ -145,8 +145,8 @@ all = [ "dask", "datafusion", "db-dtypes", - "duckdb", "deltalake", + "duckdb", "geopandas", "google-cloud-bigquery", "google-cloud-bigquery-storage", @@ -166,7 +166,6 @@ all = [ "regex", "shapely", "snowflake-connector-python", - "sqlalchemy", "trino", ] bigquery = [ @@ -298,9 +297,6 @@ filterwarnings = [ "ignore:`np.object` is a deprecated alias for the builtin `object`:DeprecationWarning", # windows "ignore:getargs.* The 'u' format is deprecated:DeprecationWarning", - # sqlalchemy - "ignore:Class ST_.+ will not make use of SQL compilation caching:", - "ignore:UserDefinedType Geometry:", # google "ignore:Deprecated call to `pkg_resources\\.declare_namespace\\('.*'\\):DeprecationWarning", # pyspark on python 3.11 @@ -309,10 +305,6 @@ filterwarnings = [ "ignore:'cgi' is deprecated and slated for removal in Python 3\\.13:DeprecationWarning", # warnings from google's use of pkg_resources "ignore:pkg_resources is deprecated as an API:DeprecationWarning", - # sqlalchemy warns about mysql's inability to cast to bool; - # this has no effect on ibis's output because we convert types after - # execution - "ignore:Datatype BOOL does not support CAST on MySQL/MariaDB; the cast will be skipped:sqlalchemy.exc.SAWarning", # snowflake vendors an older version requests "ignore:'urllib3\\.contrib\\.pyopenssl' module is deprecated and will be removed in a future release of urllib3:DeprecationWarning", # apache-beam @@ -340,7 +332,6 @@ markers = [ "notyet: for functionality that isn't implemented in a backend", "never: tests for functionality that a backend is likely to never implement", "broken: test has exposed existing broken functionality", - "sqlalchemy_only: tests for SQLAlchemy based backends", "bigquery: BigQuery tests", "clickhouse: ClickHouse tests", "dask: Dask tests", diff --git a/requirements-dev.txt 
b/requirements-dev.txt index 1cb1ba6d5418b..6e5d2567a1048 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -176,7 +176,7 @@ pydantic-core==2.16.1 ; python_version >= "3.10" and python_version < "3.13" pydantic==2.6.0 ; python_version >= "3.10" and python_version < "3.13" pydata-google-auth==1.8.2 ; python_version >= "3.9" and python_version < "4.0" pydeps==1.12.18 ; python_version >= "3.9" and python_version < "4.0" -pydruid==0.6.6 ; python_version >= "3.9" and python_version < "4.0" +pydruid[sqlalchemy]==0.6.6 ; python_version >= "3.9" and python_version < "4.0" pyexasol[pandas]==0.25.2 ; python_version >= "3.9" and python_version < "4.0" pygments==2.17.2 ; python_version >= "3.9" and python_version < "4.0" pyinstrument==4.6.2 ; python_version >= "3.9" and python_version < "4.0"