From 20c560271940edf37a16d011d0b0a3b010bc0060 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Thu, 1 Feb 2024 14:29:43 -0500 Subject: [PATCH] refactor(risingwave): port to sqlglot (#8171) Co-authored-by: Kexiang Wang Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Co-authored-by: Jim Crist-Harif --- .github/workflows/ibis-backends.yml | 14 + ci/schema/risingwave.sql | 128 +-- ibis/backends/base/__init__.py | 44 + ibis/backends/base/sqlglot/compiler.py | 7 +- ibis/backends/base/sqlglot/datatypes.py | 18 + ibis/backends/clickhouse/compiler.py | 34 +- ibis/backends/conftest.py | 1 + ibis/backends/duckdb/compiler.py | 11 +- ibis/backends/mssql/compiler.py | 30 + ibis/backends/postgres/__init__.py | 42 +- ibis/backends/pyspark/compiler.py | 11 + ibis/backends/risingwave/__init__.py | 380 ++++---- ibis/backends/risingwave/compiler.py | 112 ++- ibis/backends/risingwave/dialect.py | 35 + ibis/backends/risingwave/tests/conftest.py | 51 +- .../test_client/test_compile_toplevel/out.sql | 5 +- .../test_union_cte/False/out.sql | 2 +- .../test_union_cte/True/out.sql | 2 +- ibis/backends/risingwave/tests/test_client.py | 66 +- .../risingwave/tests/test_functions.py | 180 +--- ibis/backends/tests/errors.py | 6 +- .../test_dot_sql/test_cte/risingwave/out.sql | 8 + .../test_default_limit/risingwave/out.sql | 5 + .../risingwave/out.sql | 5 + .../risingwave/out.sql | 3 + .../test_respect_set_limit/risingwave/out.sql | 10 + .../risingwave/out.sql | 22 + .../test_sql/test_isin_bug/risingwave/out.sql | 9 + .../test_union_aliasing/risingwave/out.sql | 60 ++ ibis/backends/tests/test_aggregation.py | 62 +- ibis/backends/tests/test_array.py | 149 ++- ibis/backends/tests/test_asof_join.py | 2 + ibis/backends/tests/test_benchmarks.py | 900 ------------------ ibis/backends/tests/test_client.py | 45 +- ibis/backends/tests/test_dot_sql.py | 17 +- ibis/backends/tests/test_export.py | 34 +- ibis/backends/tests/test_generic.py | 88 +- ibis/backends/tests/test_join.py | 2 +- ibis/backends/tests/test_map.py | 23 +- ibis/backends/tests/test_numeric.py | 87 +- ibis/backends/tests/test_param.py | 19 +- ibis/backends/tests/test_register.py | 48 +- ibis/backends/tests/test_set_ops.py | 9 +- ibis/backends/tests/test_sql.py | 8 +- ibis/backends/tests/test_string.py | 53 +- ibis/backends/tests/test_struct.py | 2 +- ibis/backends/tests/test_temporal.py | 83 +- ibis/backends/tests/test_uuid.py | 7 +- ibis/backends/tests/test_window.py | 76 +- pyproject.toml | 4 +- 50 files changed, 1026 insertions(+), 1993 deletions(-) create mode 100644 ibis/backends/risingwave/dialect.py create mode 100644 ibis/backends/tests/snapshots/test_dot_sql/test_cte/risingwave/out.sql create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_default_limit/risingwave/out.sql create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/risingwave/out.sql create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/risingwave/out.sql create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/risingwave/out.sql create mode 100644 ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/risingwave/out.sql create mode 100644 ibis/backends/tests/snapshots/test_sql/test_isin_bug/risingwave/out.sql create mode 100644 ibis/backends/tests/snapshots/test_sql/test_union_aliasing/risingwave/out.sql delete mode 100644 ibis/backends/tests/test_benchmarks.py diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml index 9214381bd548..a4208954b7a5 100644 --- a/.github/workflows/ibis-backends.yml +++ b/.github/workflows/ibis-backends.yml @@ -125,6 +125,12 @@ jobs: - postgres sys-deps: - libgeos-dev + - name: risingwave + title: Risingwave + services: + - risingwave + extras: + - risingwave - name: impala title: Impala serial: true @@ -220,6 +226,14 @@ jobs: - postgres sys-deps: - libgeos-dev + - os: windows-latest + backend: + name: risingwave + title: Risingwave + services: + - risingwave + extras: + - risingwave - os: windows-latest backend: name: postgres diff --git a/ci/schema/risingwave.sql b/ci/schema/risingwave.sql index cedfa8449d60..251b689ada0d 100644 --- a/ci/schema/risingwave.sql +++ b/ci/schema/risingwave.sql @@ -1,27 +1,27 @@ SET RW_IMPLICIT_FLUSH=true; -DROP TABLE IF EXISTS diamonds CASCADE; - -CREATE TABLE diamonds ( - carat FLOAT, - cut TEXT, - color TEXT, - clarity TEXT, - depth FLOAT, +DROP TABLE IF EXISTS "diamonds" CASCADE; + +CREATE TABLE "diamonds" ( + "carat" FLOAT, + "cut" TEXT, + "color" TEXT, + "clarity" TEXT, + "depth" FLOAT, "table" FLOAT, - price BIGINT, - x FLOAT, - y FLOAT, - z FLOAT + "price" BIGINT, + "x" FLOAT, + "y" FLOAT, + "z" FLOAT ) WITH ( connector = 'posix_fs', match_pattern = 'diamonds.csv', posix_fs.root = '/data', ) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); -DROP TABLE IF EXISTS astronauts CASCADE; +DROP TABLE IF EXISTS "astronauts" CASCADE; -CREATE TABLE astronauts ( +CREATE TABLE "astronauts" ( "id" BIGINT, "number" BIGINT, "nationwide_number" BIGINT, @@ -52,12 +52,12 @@ CREATE TABLE astronauts ( posix_fs.root = '/data', ) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); -DROP TABLE IF EXISTS batting CASCADE; +DROP TABLE IF EXISTS "batting" CASCADE; -CREATE TABLE batting ( +CREATE TABLE "batting" ( "playerID" TEXT, "yearID" BIGINT, - stint BIGINT, + "stint" BIGINT, "teamID" TEXT, "lgID" TEXT, "G" BIGINT, @@ -83,71 +83,71 @@ CREATE TABLE batting ( posix_fs.root = '/data', ) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); -DROP TABLE IF EXISTS awards_players CASCADE; +DROP TABLE IF EXISTS "awards_players" CASCADE; -CREATE TABLE awards_players ( +CREATE TABLE "awards_players" ( "playerID" TEXT, "awardID" TEXT, "yearID" BIGINT, "lgID" TEXT, - tie TEXT, - notes TEXT + "tie" TEXT, + "notes" TEXT ) WITH ( connector = 'posix_fs', match_pattern = 'awards_players.csv', posix_fs.root = '/data', ) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); -DROP TABLE IF EXISTS functional_alltypes CASCADE; - -CREATE TABLE functional_alltypes ( - id INTEGER, - bool_col BOOLEAN, - tinyint_col SMALLINT, - smallint_col SMALLINT, - int_col INTEGER, - bigint_col BIGINT, - float_col REAL, - double_col DOUBLE PRECISION, - date_string_col TEXT, - string_col TEXT, - timestamp_col TIMESTAMP WITHOUT TIME ZONE, - year INTEGER, - month INTEGER +DROP TABLE IF EXISTS "functional_alltypes" CASCADE; + +CREATE TABLE "functional_alltypes" ( + "id" INTEGER, + "bool_col" BOOLEAN, + "tinyint_col" SMALLINT, + "smallint_col" SMALLINT, + "int_col" INTEGER, + "bigint_col" BIGINT, + "float_col" REAL, + "double_col" DOUBLE PRECISION, + "date_string_col" TEXT, + "string_col" TEXT, + "timestamp_col" TIMESTAMP WITHOUT TIME ZONE, + "year" INTEGER, + "month" INTEGER ) WITH ( connector = 'posix_fs', match_pattern = 'functional_alltypes.csv', posix_fs.root = '/data', ) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); -DROP TABLE IF EXISTS tzone CASCADE; +DROP TABLE IF EXISTS "tzone" CASCADE; -CREATE TABLE tzone ( - ts TIMESTAMP WITH TIME ZONE, - key TEXT, - value DOUBLE PRECISION +CREATE TABLE "tzone" ( + "ts" TIMESTAMP WITH TIME ZONE, + "key" TEXT, + "value" DOUBLE PRECISION ); -INSERT INTO tzone +INSERT INTO "tzone" SELECT CAST('2017-05-28 11:01:31.000400' AS TIMESTAMP WITH TIME ZONE) + - t * INTERVAL '1 day 1 second' AS ts, - CHR(97 + t) AS key, - t + t / 10.0 AS value - FROM generate_series(0, 9) AS t; - -DROP TABLE IF EXISTS array_types CASCADE; - -CREATE TABLE IF NOT EXISTS array_types ( - x BIGINT[], - y TEXT[], - z DOUBLE PRECISION[], - grouper TEXT, - scalar_column DOUBLE PRECISION, - multi_dim BIGINT[][] + t * INTERVAL '1 day 1 second' AS "ts", + CHR(97 + t) AS "key", + t + t / 10.0 AS "value" + FROM generate_series(0, 9) AS "t"; + +DROP TABLE IF EXISTS "array_types" CASCADE; + +CREATE TABLE IF NOT EXISTS "array_types" ( + "x" BIGINT[], + "y" TEXT[], + "z" DOUBLE PRECISION[], + "grouper" TEXT, + "scalar_column" DOUBLE PRECISION, + "multi_dim" BIGINT[][] ); -INSERT INTO array_types VALUES +INSERT INTO "array_types" VALUES (ARRAY[1, 2, 3], ARRAY['a', 'b', 'c'], ARRAY[1.0, 2.0, 3.0], 'a', 1.0, ARRAY[ARRAY[NULL::BIGINT, NULL, NULL], ARRAY[1, 2, 3]]), (ARRAY[4, 5], ARRAY['d', 'e'], ARRAY[4.0, 5.0], 'a', 2.0, ARRAY[]::BIGINT[][]), (ARRAY[6, NULL], ARRAY['f', NULL], ARRAY[6.0, NULL], 'a', 3.0, ARRAY[NULL, ARRAY[]::BIGINT[], NULL]), @@ -155,11 +155,11 @@ INSERT INTO array_types VALUES (ARRAY[2, NULL, 3], ARRAY['b', NULL, 'c'], NULL, 'b', 5.0, NULL), (ARRAY[4, NULL, NULL, 5], ARRAY['d', NULL, NULL, 'e'], ARRAY[4.0, NULL, NULL, 5.0], 'c', 6.0, ARRAY[ARRAY[1, 2, 3]]); -DROP TABLE IF EXISTS json_t CASCADE; +DROP TABLE IF EXISTS "json_t" CASCADE; -CREATE TABLE IF NOT EXISTS json_t (js JSONB); +CREATE TABLE IF NOT EXISTS "json_t" ("js" JSONB); -INSERT INTO json_t VALUES +INSERT INTO "json_t" VALUES ('{"a": [1,2,3,4], "b": 1}'), ('{"a":null,"b":2}'), ('{"a":"foo", "c":null}'), @@ -167,9 +167,9 @@ INSERT INTO json_t VALUES ('[42,47,55]'), ('[]'); -DROP TABLE IF EXISTS win CASCADE; -CREATE TABLE win (g TEXT, x BIGINT, y BIGINT); -INSERT INTO win VALUES +DROP TABLE IF EXISTS "win" CASCADE; +CREATE TABLE "win" ("g" TEXT, "x" BIGINT, "y" BIGINT); +INSERT INTO "win" VALUES ('a', 0, 3), ('a', 1, 2), ('a', 2, 0), diff --git a/ibis/backends/base/__init__.py b/ibis/backends/base/__init__.py index 7411077514d1..445997da3b09 100644 --- a/ibis/backends/base/__init__.py +++ b/ibis/backends/base/__init__.py @@ -41,6 +41,7 @@ "datafusion": "postgres", # closest match see https://github.com/ibis-project/ibis/pull/7303#discussion_r1350223901 "exasol": "oracle", + "risingwave": "postgres", } _SQLALCHEMY_TO_SQLGLOT_DIALECT = { @@ -75,6 +76,7 @@ def __dir__(self) -> list[str]: ------- list[str] A list of the attributes and tables available in the database. + """ attrs = dir(type(self)) unqualified_tables = [self._unqualify(x) for x in self.tables] @@ -92,6 +94,7 @@ def __contains__(self, table: str) -> bool: ------- bool True if the given table is available in the current database. + """ return table in self.tables @@ -103,6 +106,7 @@ def tables(self) -> list[str]: ------- list[str] The list of tables in the database + """ return self.list_tables() @@ -118,6 +122,7 @@ def __getitem__(self, table: str) -> ir.Table: ------- Table Table expression + """ return self.table(table) @@ -133,6 +138,7 @@ def __getattr__(self, table: str) -> ir.Table: ------- Table Table expression + """ return self.table(table) @@ -150,6 +156,7 @@ def drop(self, force: bool = False) -> None: force If `True`, drop any objects that exist, and do not fail if the database does not exist. + """ self.client.drop_database(self.name, force=force) @@ -165,6 +172,7 @@ def table(self, name: str) -> ir.Table: ------- Table Table expression + """ qualified_name = self._qualify(name) return self.client.table(qualified_name, self.name) @@ -178,6 +186,7 @@ def list_tables(self, like=None, database=None): A pattern to use for listing tables. database The database to perform the list against + """ return self.client.list_tables(like, database=database or self.name) @@ -192,6 +201,7 @@ class TablesAccessor(collections.abc.Mapping): >>> con = ibis.sqlite.connect("example.db") >>> people = con.tables["people"] # access via index >>> people = con.tables.people # access via attribute + """ def __init__(self, backend: BaseBackend): @@ -276,6 +286,7 @@ def to_pandas( "no limit". The default is in `ibis/config.py`. kwargs Keyword arguments + """ return self.execute(expr, params=params, limit=limit, **kwargs) @@ -309,6 +320,7 @@ def to_pandas_batches( ------- Iterator[pd.DataFrame] An iterator of pandas `DataFrame`s. + """ from ibis.formats.pandas import PandasData @@ -354,6 +366,7 @@ def to_pyarrow( ------- Table A pyarrow table holding the results of the executed expression. + """ pa = self._import_pyarrow() self._run_pre_execute_hooks(expr) @@ -403,6 +416,7 @@ def to_pyarrow_batches( ------- results RecordBatchReader + """ raise NotImplementedError @@ -432,6 +446,7 @@ def to_torch( ------- dict[str, torch.Tensor] A dictionary of torch tensors, keyed by column name. + """ import torch @@ -463,6 +478,7 @@ def read_parquet( ------- ir.Table The just-registered table + """ raise NotImplementedError( f"{self.name} does not support direct registration of parquet data." @@ -487,6 +503,7 @@ def read_csv( ------- ir.Table The just-registered table + """ raise NotImplementedError( f"{self.name} does not support direct registration of CSV data." @@ -511,6 +528,7 @@ def read_json( ------- ir.Table The just-registered table + """ raise NotImplementedError( f"{self.name} does not support direct registration of JSON data." @@ -536,6 +554,7 @@ def read_delta( ------- ir.Table The just-registered table. + """ raise NotImplementedError( f"{self.name} does not support direct registration of DeltaLake tables." @@ -567,6 +586,7 @@ def to_parquet( Additional keyword arguments passed to pyarrow.parquet.ParquetWriter https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html + """ self._import_pyarrow() import pyarrow.parquet as pq @@ -602,6 +622,7 @@ def to_csv( Additional keyword arguments passed to pyarrow.csv.CSVWriter https://arrow.apache.org/docs/python/generated/pyarrow.csv.CSVWriter.html + """ self._import_pyarrow() import pyarrow.csv as pcsv @@ -666,6 +687,7 @@ def list_databases(self, like: str | None = None) -> list[str]: list[str] The database names that exist in the current connection, that match the `like` pattern if provided. + """ @property @@ -685,6 +707,7 @@ def create_database(self, name: str, force: bool = False) -> None: Name of the new database. force If `False`, an exception is raised if the database already exists. + """ @abc.abstractmethod @@ -697,6 +720,7 @@ def drop_database(self, name: str, force: bool = False) -> None: Database to drop. force If `False`, an exception is raised if the database does not exist. + """ @@ -716,6 +740,7 @@ def create_schema( current database is used. force If `False`, an exception is raised if the schema exists. + """ @abc.abstractmethod @@ -733,6 +758,7 @@ def drop_schema( current database is used. force If `False`, an exception is raised if the schema does not exist. + """ @abc.abstractmethod @@ -755,6 +781,7 @@ def list_schemas( list[str] The schema names that exist in the current connection, that match the `like` pattern if provided. + """ @property @@ -814,6 +841,7 @@ def db_identity(self) -> str: ------- Hashable Database identity + """ parts = [self.__class__] parts.extend(self._con_args) @@ -844,6 +872,7 @@ def connect(self, *args, **kwargs) -> BaseBackend: ------- BaseBackend An instance of the backend + """ new_backend = self.__class__(*args, **kwargs) new_backend.reconnect() @@ -880,6 +909,7 @@ def database(self, name: str | None = None) -> Database: ------- Database A database object for the specified database. + """ return Database(name=name or self.current_database, client=self) @@ -905,6 +935,7 @@ def _filter_with_like(values: Iterable[str], like: str | None = None) -> list[st ------- list[str] Names filtered by the `like` pattern. + """ if like is None: return sorted(values) @@ -933,6 +964,7 @@ def list_tables( ------- list[str] The list of the table names that match the pattern `like`. + """ @abc.abstractmethod @@ -950,6 +982,7 @@ def table(self, name: str, database: str | None = None) -> ir.Table: ------- Table Table expression + """ @functools.cached_property @@ -963,6 +996,7 @@ def tables(self): >>> con = ibis.sqlite.connect("example.db") >>> people = con.tables["people"] # access via index >>> people = con.tables.people # access via attribute + """ return TablesAccessor(self) @@ -980,6 +1014,7 @@ def version(self) -> str: ------- str The backend version + """ @classmethod @@ -1088,6 +1123,7 @@ def create_table( ------- Table The table that was created. + """ @abc.abstractmethod @@ -1108,6 +1144,7 @@ def drop_table( Name of the database where the table exists, if not the default. force If `False`, an exception is raised if the table does not exist. + """ raise NotImplementedError( f'Backend "{self.name}" does not implement "drop_table"' @@ -1122,6 +1159,7 @@ def rename_table(self, old_name: str, new_name: str) -> None: The old name of the table. new_name The new name of the table. + """ raise NotImplementedError( f'Backend "{self.name}" does not implement "rename_table"' @@ -1154,6 +1192,7 @@ def create_view( ------- Table The view that was created. + """ @abc.abstractmethod @@ -1170,6 +1209,7 @@ def drop_view( Name of the database where the view exists, if not the default. force If `False`, an exception is raised if the view does not exist. + """ @classmethod @@ -1194,6 +1234,7 @@ def has_operation(cls, operation: type[ops.Value]) -> bool: False >>> ibis.postgres.has_operation(ops.ArrayIndex) True + """ raise NotImplementedError( f"{cls.name} backend has not implemented `has_operation` API" @@ -1228,6 +1269,7 @@ def _release_cached(self, expr: ir.CachedTable) -> None: ---------- expr Cached expression to release + """ del self._query_cache[expr.op()] @@ -1268,6 +1310,7 @@ def _get_backend_names() -> frozenset[str]: If a `set` is used, then any in-place modifications to the set are visible to every caller of this function. + """ if sys.version_info < (3, 10): @@ -1325,6 +1368,7 @@ def connect(resource: Path | str, **kwargs: Any) -> BaseBackend: >>> con = ibis.connect( ... "bigquery://my-project/my-dataset" ... ) # quartodoc: +SKIP # doctest: +SKIP + """ url = resource = str(resource) diff --git a/ibis/backends/base/sqlglot/compiler.py b/ibis/backends/base/sqlglot/compiler.py index 0f5b3e738b6f..db90f5ea8061 100644 --- a/ibis/backends/base/sqlglot/compiler.py +++ b/ibis/backends/base/sqlglot/compiler.py @@ -141,9 +141,9 @@ def parenthesize(op, arg): C = ColGen() F = FuncGen() -NULL = sge.NULL -FALSE = sge.FALSE -TRUE = sge.TRUE +NULL = sge.Null() +FALSE = sge.false() +TRUE = sge.true() STAR = sge.Star() @@ -251,6 +251,7 @@ def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression: ------- sqlglot.expressions.Expression A sqlglot expression + """ # substitute parameters immediately to avoid having to define a # ScalarParameter translation rule diff --git a/ibis/backends/base/sqlglot/datatypes.py b/ibis/backends/base/sqlglot/datatypes.py index 370afc7a06b9..db1bae762c9a 100644 --- a/ibis/backends/base/sqlglot/datatypes.py +++ b/ibis/backends/base/sqlglot/datatypes.py @@ -394,6 +394,24 @@ def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType: return sge.DataType(this=typecode.HSTORE) +class RisingWaveType(PostgresType): + dialect = "risingwave" + + @classmethod + def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType: + if dtype.timezone is not None: + return sge.DataType(this=typecode.TIMESTAMPTZ) + return sge.DataType(this=typecode.TIMESTAMP) + + @classmethod + def _from_ibis_Decimal(cls, dtype: dt.Decimal) -> sge.DataType: + return sge.DataType(this=typecode.DECIMAL) + + @classmethod + def _from_ibis_UUID(cls, dtype: dt.UUID) -> sge.DataType: + return sge.DataType(this=typecode.VARCHAR) + + class DataFusionType(PostgresType): unknown_type_strings = { "utf8": dt.string, diff --git a/ibis/backends/clickhouse/compiler.py b/ibis/backends/clickhouse/compiler.py index e4121a2ef9b2..6439fb99a5e2 100644 --- a/ibis/backends/clickhouse/compiler.py +++ b/ibis/backends/clickhouse/compiler.py @@ -209,24 +209,26 @@ def visit_Hash(self, op, *, arg): @visit_node.register(ops.HashBytes) def visit_HashBytes(self, op, *, arg, how): - supported_algorithms = frozenset( - ( - "MD5", - "halfMD5", - "SHA1", - "SHA224", - "SHA256", - "intHash32", - "intHash64", - "cityHash64", - "sipHash64", - "sipHash128", - ) - ) - if how not in supported_algorithms: + supported_algorithms = { + "md5": "MD5", + "MD5": "MD5", + "halfMD5": "halfMD5", + "SHA1": "SHA1", + "sha1": "SHA1", + "SHA224": "SHA224", + "sha224": "SHA224", + "SHA256": "SHA256", + "sha256": "SHA256", + "intHash32": "intHash32", + "intHash64": "intHash64", + "cityHash64": "cityHash64", + "sipHash64": "sipHash64", + "sipHash128": "sipHash128", + } + if (funcname := supported_algorithms.get(how)) is None: raise com.UnsupportedOperationError(f"Unsupported hash algorithm {how}") - return self.f[how](arg) + return self.f[funcname](arg) @visit_node.register(ops.IntervalFromInteger) def visit_IntervalFromInteger(self, op, *, arg, unit): diff --git a/ibis/backends/conftest.py b/ibis/backends/conftest.py index 949ae7752b94..5fe94de29418 100644 --- a/ibis/backends/conftest.py +++ b/ibis/backends/conftest.py @@ -26,6 +26,7 @@ if TYPE_CHECKING: from collections.abc import Iterable + from ibis.backends.tests.base import BackendTest diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index e42bb5e733b0..a21735474c6a 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -329,10 +329,6 @@ def visit_GeoConvert(self, op, *, arg, source, target): # matches the behavior of the equivalent geopandas functionality return self.f.st_transform(arg, source, target, True) - @visit_node.register(ops.HexDigest) - def visit_HexDigest(self, op, *, arg, how): - return self.f[how](arg) - @visit_node.register(ops.TimestampNow) def visit_TimestampNow(self, op): """DuckDB current timestamp defaults to timestamp + tz.""" @@ -349,6 +345,13 @@ def visit_Quantile(self, op, *, arg, quantile, where): funcname = f"percentile_{suffix}" return self.agg[funcname](arg, quantile, where=where) + @visit_node.register(ops.HexDigest) + def visit_HexDigest(self, op, *, arg, how): + if how in ("md5", "sha256"): + return getattr(self.f, how)(arg) + else: + raise NotImplementedError(f"No available hashing function for {how}") + _SIMPLE_OPS = { ops.ArrayPosition: "list_indexof", diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index 4609d241610d..d0c4470d7489 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -376,6 +376,36 @@ def visit_Not(self, op, *, arg): return sge.FALSE if arg == sge.TRUE else sge.TRUE return self.if_(arg, 1, 0).eq(0) + @visit_node.register(ops.HashBytes) + def visit_HashBytes(self, op, *, arg, how): + if how in ("md5", "sha1"): + return self.f.hashbytes(how, arg) + elif how == "sha256": + return self.f.hashbytes("sha2_256", arg) + elif how == "sha512": + return self.f.hashbytes("sha2_512", arg) + else: + raise NotImplementedError(how) + + @visit_node.register(ops.HexDigest) + def visit_HexDigest(self, op, *, arg, how): + if how in ("md5", "sha1"): + hashbinary = self.f.hashbytes(how, arg) + elif how == "sha256": + hashbinary = self.f.hashbytes("sha2_256", arg) + elif how == "sha512": + hashbinary = self.f.hashbytes("sha2_512", arg) + else: + raise NotImplementedError(how) + + # mssql uppercases the hexdigest which is inconsistent with several other + # implementations and inconsistent with Python, so lowercase it. + return self.f.lower( + self.f.convert( + sge.Literal(this="VARCHAR(MAX)", is_string=False), hashbinary, 2 + ) + ) + @visit_node.register(ops.Any) @visit_node.register(ops.All) @visit_node.register(ops.ApproxMedian) diff --git a/ibis/backends/postgres/__init__.py b/ibis/backends/postgres/__init__.py index 0f51218c7f7c..10365658518b 100644 --- a/ibis/backends/postgres/__init__.py +++ b/ibis/backends/postgres/__init__.py @@ -44,6 +44,7 @@ def _verify_source_line(func_name: str, line: str): class Backend(SQLGlotBackend): name = "postgres" + dialect = "postgres" compiler = PostgresCompiler() supports_python_udfs = True @@ -61,6 +62,7 @@ def _from_url(self, url: str, **kwargs): ------- BaseBackend A backend instance + """ url = urlparse(url) @@ -106,7 +108,7 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: schema = op.schema if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]: raise exc.IbisTypeError( - "Postgres cannot yet reliably handle `null` typed columns; " + f"{self.name} cannot yet reliably handle `null` typed columns; " f"got null typed columns: {null_columns}" ) @@ -137,18 +139,18 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: ), properties=sg.exp.Properties(expressions=[sge.TemporaryProperty()]), ) - create_stmt_sql = create_stmt.sql(self.name) + create_stmt_sql = create_stmt.sql(self.dialect) columns = schema.keys() df = op.data.to_frame() data = df.itertuples(index=False) cols = ", ".join( - ident.sql(self.name) + ident.sql(self.dialect) for ident in map(partial(sg.to_identifier, quoted=quoted), columns) ) specs = ", ".join(repeat("%s", len(columns))) table = sg.table(name, quoted=quoted) - sql = f"INSERT INTO {table.sql(self.name)} ({cols}) VALUES ({specs})" + sql = f"INSERT INTO {table.sql(self.dialect)} ({cols}) VALUES ({specs})" with self.begin() as cur: cur.execute(create_stmt_sql) extras.execute_batch(cur, sql, data, 128) @@ -254,6 +256,7 @@ def do_connect( timestamp_col : timestamp year : int32 month : int32 + """ self.con = psycopg2.connect( @@ -291,6 +294,7 @@ def list_tables( The `schema` parameter does **not** refer to the column names and types of `table`. ::: + """ if database is not None: util.warn_deprecated( @@ -314,7 +318,7 @@ def list_tables( .from_(sg.table("tables", db="information_schema")) .distinct() .where(*conditions) - .sql(self.name) + .sql(self.dialect) ) with self._safe_raw_sql(sql) as cur: @@ -447,10 +451,10 @@ def _compile_builtin_udf(self, udf_node: ops.ScalarUDF) -> None: """No op.""" def _compile_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> None: - raise NotImplementedError("pyarrow UDFs are not supported in Postgres") + raise NotImplementedError(f"pyarrow UDFs are not supported in {self.name}") def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: - raise NotImplementedError("pandas UDFs are not supported in Postgres") + raise NotImplementedError(f"pandas UDFs are not supported in {self.name}") def _define_udf_translation_rules(self, expr: ir.Expr) -> None: """No-op, these are defined in the compiler.""" @@ -535,11 +539,11 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: create_stmt = sge.Create( kind="VIEW", this=sg.table(name), - expression=sg.parse_one(query, read=self.name), + expression=sg.parse_one(query, read=self.dialect), properties=sge.Properties(expressions=[sge.TemporaryProperty()]), ) drop_stmt = sge.Drop(kind="VIEW", this=sg.table(name), exists=True).sql( - self.name + self.dialect ) with self._safe_raw_sql(create_stmt): @@ -555,7 +559,7 @@ def create_schema( ) -> None: if database is not None and database != self.current_database: raise exc.UnsupportedOperationError( - "Postgres does not support creating a schema in a different database" + f"{self.name} does not support creating a schema in a different database" ) sql = sge.Create( kind="SCHEMA", this=sg.table(name, catalog=database), exists=force @@ -572,7 +576,7 @@ def drop_schema( ) -> None: if database is not None and database != self.current_database: raise exc.UnsupportedOperationError( - "Postgres does not support dropping a schema in a different database" + f"{self.name} does not support dropping a schema in a different database" ) sql = sge.Drop( @@ -614,13 +618,14 @@ def create_table( overwrite If `True`, replace the table if it already exists, otherwise fail if the table exists + """ if obj is None and schema is None: raise ValueError("Either `obj` or `schema` must be specified") if database is not None and database != self.current_database: raise com.UnsupportedOperationError( - "Creating tables in other databases is not supported by Postgres" + f"Creating tables in other databases is not supported by {self.name}" ) else: database = None @@ -672,15 +677,15 @@ def create_table( this = sg.table(name, catalog=database, quoted=self.compiler.quoted) with self._safe_raw_sql(create_stmt) as cur: if query is not None: - insert_stmt = sge.Insert(this=table, expression=query).sql(self.name) + insert_stmt = sge.Insert(this=table, expression=query).sql(self.dialect) cur.execute(insert_stmt) if overwrite: cur.execute( - sge.Drop(kind="TABLE", this=this, exists=True).sql(self.name) + sge.Drop(kind="TABLE", this=this, exists=True).sql(self.dialect) ) cur.execute( - f"ALTER TABLE IF EXISTS {table.sql(self.name)} RENAME TO {this.sql(self.name)}" + f"ALTER TABLE IF EXISTS {table.sql(self.dialect)} RENAME TO {this.sql(self.dialect)}" ) if schema is None: @@ -700,7 +705,7 @@ def drop_table( ) -> None: if database is not None and database != self.current_database: raise com.UnsupportedOperationError( - "Droppping tables in other databases is not supported by Postgres" + f"Droppping tables in other databases is not supported by {self.name}" ) else: database = None @@ -721,7 +726,7 @@ def _safe_raw_sql(self, *args, **kwargs): def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: with contextlib.suppress(AttributeError): - query = query.sql(dialect=self.name) + query = query.sql(dialect=self.dialect) con = self.con cursor = con.cursor() @@ -771,7 +776,8 @@ def truncate_table(self, name: str, database: str | None = None) -> None: Table name database Schema name + """ - ident = sg.table(name, db=database).sql(self.name) + ident = sg.table(name, db=database).sql(self.dialect) with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"): pass diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index bc2bbf2b7584..b4e75c959735 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -457,6 +457,17 @@ def visit_JoinLink(self, op, **kwargs): def visit_Undefined(self, op, **_): raise com.OperationNotDefinedError(type(op).__name__) + @visit_node.register(ops.HexDigest) + def visit_HexDigest(self, op, *, arg, how): + if how == "md5": + return self.f.md5(arg) + elif how == "sha1": + return self.f.sha1(arg) + elif how in ("sha256", "sha512"): + return self.f.sha2(arg, int(how[-3:])) + else: + raise NotImplementedError(f"No available hashing function for {how}") + _SIMPLE_OPS = { ops.ArrayDistinct: "array_distinct", diff --git a/ibis/backends/risingwave/__init__.py b/ibis/backends/risingwave/__init__.py index 04de491f6dfe..996f776fd12e 100644 --- a/ibis/backends/risingwave/__init__.py +++ b/ibis/backends/risingwave/__init__.py @@ -2,36 +2,40 @@ from __future__ import annotations -import inspect -from typing import TYPE_CHECKING, Callable, Literal - -import sqlalchemy as sa - -import ibis.common.exceptions as exc +import atexit +from functools import partial +from itertools import repeat +from typing import TYPE_CHECKING + +import psycopg2 +import sqlglot as sg +import sqlglot.expressions as sge +from psycopg2 import extras + +import ibis +import ibis.common.exceptions as com import ibis.expr.operations as ops +import ibis.expr.types as ir from ibis import util -from ibis.backends.base.sql.alchemy import AlchemyCanCreateSchema, BaseAlchemyBackend +from ibis.backends.postgres import Backend as PostgresBackend from ibis.backends.risingwave.compiler import RisingwaveCompiler -from ibis.backends.risingwave.datatypes import RisingwaveType -from ibis.common.exceptions import InvalidDecoratorError +from ibis.backends.risingwave.dialect import RisingWave as RisingWaveDialect if TYPE_CHECKING: - from collections.abc import Iterable - - import ibis.expr.datatypes as dt + import pandas as pd + import pyarrow as pa def _verify_source_line(func_name: str, line: str): if line.startswith("@"): - raise InvalidDecoratorError(func_name, line) + raise com.InvalidDecoratorError(func_name, line) return line -class Backend(BaseAlchemyBackend, AlchemyCanCreateSchema): +class Backend(PostgresBackend): name = "risingwave" - compiler = RisingwaveCompiler - supports_temporary_tables = False - supports_create_or_replace = False + dialect = RisingWaveDialect + compiler = RisingwaveCompiler() supports_python_udfs = False def do_connect( @@ -42,10 +46,8 @@ def do_connect( port: int = 5432, database: str | None = None, schema: str | None = None, - url: str | None = None, - driver: Literal["psycopg2"] = "psycopg2", ) -> None: - """Create an Ibis client connected to Risingwave database. + """Create an Ibis client connected to RisingWave database. Parameters ---------- @@ -60,13 +62,7 @@ def do_connect( database Database to connect to schema - Risingwave schema to use. If `None`, use the default `search_path`. - url - SQLAlchemy connection string. - - If passed, the other connection arguments are ignored. - driver - Database driver + RisingWave schema to use. If `None`, use the default `search_path`. Examples -------- @@ -98,185 +94,199 @@ def do_connect( timestamp_col : timestamp year : int32 month : int32 + """ - if driver != "psycopg2": - raise NotImplementedError("psycopg2 is currently the only supported driver") - alchemy_url = self._build_alchemy_url( - url=url, + self.con = psycopg2.connect( host=host, port=port, user=user, password=password, database=database, - driver=f"risingwave+{driver}", - ) - - connect_args = {} - if schema is not None: - connect_args["options"] = f"-csearch_path={schema}" - - engine = sa.create_engine( - alchemy_url, connect_args=connect_args, poolclass=sa.pool.StaticPool + options=(f"-csearch_path={schema}" * (schema is not None)) or None, ) - @sa.event.listens_for(engine, "connect") - def connect(dbapi_connection, connection_record): - with dbapi_connection.cursor() as cur: - cur.execute("SET TIMEZONE = UTC") + with self.begin() as cur: + cur.execute("SET TIMEZONE = UTC") - super().do_connect(engine) + self._temp_views = set() - def list_tables(self, like=None, schema=None): - """List the tables in the database. + def create_table( + self, + name: str, + obj: pd.DataFrame | pa.Table | ir.Table | None = None, + *, + schema: ibis.Schema | None = None, + database: str | None = None, + temp: bool = False, + overwrite: bool = False, + ): + """Create a table in Risingwave. Parameters ---------- - like - A pattern to use for listing tables. + name + Name of the table to create + obj + The data with which to populate the table; optional, but at least + one of `obj` or `schema` must be specified schema - The schema to perform the list against. - - ::: {.callout-warning} - ## `schema` refers to database hierarchy + The schema of the table to create; optional, but at least one of + `obj` or `schema` must be specified + database + The name of the database in which to create the table; if not + passed, the current database is used. + temp + Create a temporary table + overwrite + If `True`, replace the table if it already exists, otherwise fail + if the table exists - The `schema` parameter does **not** refer to the column names and - types of `table`. - ::: """ - tables = self.inspector.get_table_names(schema=schema) - views = self.inspector.get_view_names(schema=schema) - return self._filter_with_like(tables + views, like) - - def list_databases(self, like=None) -> list[str]: - # http://dba.stackexchange.com/a/1304/58517 - dbs = sa.table( - "pg_database", - sa.column("datname", sa.TEXT()), - sa.column("datistemplate", sa.BOOLEAN()), - schema="pg_catalog", + if obj is None and schema is None: + raise ValueError("Either `obj` or `schema` must be specified") + + if database is not None and database != self.current_database: + raise com.UnsupportedOperationError( + f"Creating tables in other databases is not supported by {self.name}" + ) + else: + database = None + + properties = [] + + if temp: + properties.append(sge.TemporaryProperty()) + + if obj is not None: + if not isinstance(obj, ir.Expr): + table = ibis.memtable(obj) + else: + table = obj + + self._run_pre_execute_hooks(table) + + query = self._to_sqlglot(table) + else: + query = None + + column_defs = [ + sge.ColumnDef( + this=sg.to_identifier(colname, quoted=self.compiler.quoted), + kind=self.compiler.type_mapper.from_ibis(typ), + constraints=( + None + if typ.nullable + else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())] + ), + ) + for colname, typ in (schema or table.schema()).items() + ] + + if overwrite: + temp_name = util.gen_name(f"{self.name}_table") + else: + temp_name = name + + table = sg.table(temp_name, catalog=database, quoted=self.compiler.quoted) + target = sge.Schema(this=table, expressions=column_defs) + + create_stmt = sge.Create( + kind="TABLE", + this=target, + properties=sge.Properties(expressions=properties), ) - query = sa.select(dbs.c.datname).where(sa.not_(dbs.c.datistemplate)) - with self.begin() as con: - databases = list(con.execute(query).scalars()) - - return self._filter_with_like(databases, like) - - @property - def current_database(self) -> str: - return self._scalar_query(sa.select(sa.func.current_database())) - - @property - def current_schema(self) -> str: - return self._scalar_query(sa.select(sa.func.current_schema())) - - def function(self, name: str, *, schema: str | None = None) -> Callable: - query = sa.text( - """ -SELECT - n.nspname as schema, - pg_catalog.pg_get_function_result(p.oid) as return_type, - string_to_array(pg_catalog.pg_get_function_arguments(p.oid), ', ') as signature, - CASE p.prokind - WHEN 'a' THEN 'agg' - WHEN 'w' THEN 'window' - WHEN 'p' THEN 'proc' - ELSE 'func' - END as "Type" -FROM pg_catalog.pg_proc p -LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace -WHERE p.proname = :name -""" - + "AND n.nspname OPERATOR(pg_catalog.~) :schema COLLATE pg_catalog.default" - * (schema is not None) - ).bindparams(name=name, schema=f"^({schema})$") - - def split_name_type(arg: str) -> tuple[str, dt.DataType]: - name, typ = arg.split(" ", 1) - return name, RisingwaveType.from_string(typ) - - with self.begin() as con: - rows = con.execute(query).mappings().fetchall() - - if not rows: - name = f"{schema}.{name}" if schema else name - raise exc.MissingUDFError(name) - elif len(rows) > 1: - raise exc.AmbiguousUDFError(name) - - [row] = rows - return_type = RisingwaveType.from_string(row["return_type"]) - signature = list(map(split_name_type, row["signature"])) - - # dummy callable - def fake_func(*args, **kwargs): - ... - - fake_func.__name__ = name - fake_func.__signature__ = inspect.Signature( - [ - inspect.Parameter( - name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=typ + + this = sg.table(name, catalog=database, quoted=self.compiler.quoted) + with self._safe_raw_sql(create_stmt) as cur: + if query is not None: + insert_stmt = sge.Insert(this=table, expression=query).sql(self.dialect) + cur.execute(insert_stmt) + + if overwrite: + cur.execute( + sge.Drop(kind="TABLE", this=this, exists=True).sql(self.dialect) ) - for name, typ in signature - ], - return_annotation=return_type, - ) - fake_func.__annotations__ = {"return": return_type, **dict(signature)} - op = ops.udf.scalar.builtin(fake_func, schema=schema) - return op - - def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: - name = util.gen_name("risingwave_metadata") - type_info_sql = """\ - SELECT - attname, - format_type(atttypid, atttypmod) AS type - FROM pg_attribute - WHERE attrelid = CAST(:name AS regclass) - AND attnum > 0 - AND NOT attisdropped - ORDER BY attnum""" - if self.inspector.has_table(query): - query = f"TABLE {query}" - - text = sa.text(type_info_sql).bindparams(name=name) - with self.begin() as con: - con.exec_driver_sql(f"CREATE VIEW IF NOT EXISTS {name} AS {query}") - try: - yield from ( - (col, RisingwaveType.from_string(typestr)) - for col, typestr in con.execute(text) + cur.execute( + f"ALTER TABLE {table.sql(self.dialect)} RENAME TO {this.sql(self.dialect)}" ) - finally: - con.exec_driver_sql(f"DROP VIEW IF EXISTS {name}") - def _get_temp_view_definition( - self, name: str, definition: sa.sql.compiler.Compiled - ) -> str: - yield f"DROP VIEW IF EXISTS {name}" - yield f"CREATE TEMPORARY VIEW {name} AS {definition}" - - def create_schema( - self, name: str, database: str | None = None, force: bool = False - ) -> None: - if database is not None and database != self.current_database: - raise exc.UnsupportedOperationError( - "Risingwave does not support creating a schema in a different database" + if schema is None: + return self.table(name, schema=database) + + # preserve the input schema if it was provided + return ops.DatabaseTable( + name, schema=schema, source=self, namespace=ops.Namespace(database=database) + ).to_expr() + + def _get_temp_view_definition(self, name: str, definition): + drop = sge.Drop( + kind="VIEW", exists=True, this=sg.table(name), cascade=True + ).sql(self.dialect) + + create = sge.Create( + this=sg.to_identifier(name, quoted=self.compiler.quoted), + kind="VIEW", + expression=definition, + replace=False, + ).sql(self.dialect) + + atexit.register(self._clean_up_tmp_view, name) + return f"{drop}; {create}" + + def _clean_up_tmp_view(self, name: str) -> None: + drop = sge.Drop( + kind="VIEW", exists=True, this=sg.table(name), cascade=True + ).sql(self.dialect) + with self.begin() as bind: + bind.execute(drop) + + def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: + schema = op.schema + if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]: + raise com.IbisTypeError( + f"{self.name} cannot yet reliably handle `null` typed columns; " + f"got null typed columns: {null_columns}" ) - if_not_exists = "IF NOT EXISTS " * force - name = self._quote(name) - with self.begin() as con: - con.exec_driver_sql(f"CREATE SCHEMA {if_not_exists}{name}") - def drop_schema( - self, name: str, database: str | None = None, force: bool = False - ) -> None: - if database is not None and database != self.current_database: - raise exc.UnsupportedOperationError( - "Risingwave does not support dropping a schema in a different database" + # only register if we haven't already done so + if (name := op.name) not in self.list_tables(): + quoted = self.compiler.quoted + column_defs = [ + sg.exp.ColumnDef( + this=sg.to_identifier(colname, quoted=quoted), + kind=self.compiler.type_mapper.from_ibis(typ), + constraints=( + None + if typ.nullable + else [ + sg.exp.ColumnConstraint( + kind=sg.exp.NotNullColumnConstraint() + ) + ] + ), + ) + for colname, typ in schema.items() + ] + + create_stmt = sg.exp.Create( + kind="TABLE", + this=sg.exp.Schema( + this=sg.to_identifier(name, quoted=quoted), expressions=column_defs + ), + ) + create_stmt_sql = create_stmt.sql(self.dialect) + + columns = schema.keys() + df = op.data.to_frame() + data = df.itertuples(index=False) + cols = ", ".join( + ident.sql(self.dialect) + for ident in map(partial(sg.to_identifier, quoted=quoted), columns) ) - name = self._quote(name) - if_exists = "IF EXISTS " * force - with self.begin() as con: - con.exec_driver_sql(f"DROP SCHEMA {if_exists}{name}") + specs = ", ".join(repeat("%s", len(columns))) + table = sg.table(name, quoted=quoted) + sql = f"INSERT INTO {table.sql(self.dialect)} ({cols}) VALUES ({specs})" + with self.begin() as cur: + cur.execute(create_stmt_sql) + extras.execute_batch(cur, sql, data, 128) diff --git a/ibis/backends/risingwave/compiler.py b/ibis/backends/risingwave/compiler.py index b4bcd9c0b9d5..5bc7bfef2f5b 100644 --- a/ibis/backends/risingwave/compiler.py +++ b/ibis/backends/risingwave/compiler.py @@ -1,34 +1,104 @@ from __future__ import annotations +from functools import singledispatchmethod + +import sqlglot.expressions as sge +from public import public + +import ibis.common.exceptions as com +import ibis.expr.datashape as ds +import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator -from ibis.backends.risingwave.datatypes import RisingwaveType -from ibis.backends.risingwave.registry import operation_registry -from ibis.expr.rewrites import rewrite_sample +from ibis.backends.base.sqlglot.datatypes import RisingWaveType +from ibis.backends.postgres.compiler import PostgresCompiler +from ibis.backends.risingwave.dialect import RisingWave # noqa: F401 + + +@public +class RisingwaveCompiler(PostgresCompiler): + __slots__ = () + + dialect = "risingwave" + name = "risingwave" + type_mapper = RisingWaveType + + @singledispatchmethod + def visit_node(self, op, **kwargs): + return super().visit_node(op, **kwargs) + + @visit_node.register(ops.Correlation) + def visit_Correlation(self, op, *, left, right, how, where): + if how == "sample": + raise com.UnsupportedOperationError( + f"{self.name} only implements `pop` correlation coefficient" + ) + return super().visit_Correlation( + op, left=left, right=right, how=how, where=where + ) + + @visit_node.register(ops.TimestampTruncate) + @visit_node.register(ops.DateTruncate) + @visit_node.register(ops.TimeTruncate) + def visit_TimestampTruncate(self, op, *, arg, unit): + unit_mapping = { + "Y": "year", + "Q": "quarter", + "M": "month", + "W": "week", + "D": "day", + "h": "hour", + "m": "minute", + "s": "second", + "ms": "milliseconds", + "us": "microseconds", + } + + if (unit := unit_mapping.get(unit.short)) is None: + raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}") + + return self.f.date_trunc(unit, arg) + @visit_node.register(ops.IntervalFromInteger) + def visit_IntervalFromInteger(self, op, *, arg, unit): + if op.arg.shape == ds.scalar: + return sge.Interval(this=arg, unit=self.v[unit.name]) + elif op.arg.shape == ds.columnar: + return arg * sge.Interval(this=sge.convert(1), unit=self.v[unit.name]) + else: + raise ValueError("Invalid shape for converting to interval") -class RisingwaveExprTranslator(AlchemyExprTranslator): - _registry = operation_registry.copy() - _rewrites = AlchemyExprTranslator._rewrites.copy() - _has_reduction_filter_syntax = True - _supports_tuple_syntax = True - _dialect_name = "risingwave" + def visit_NonNullLiteral(self, op, *, value, dtype): + if dtype.is_binary(): + return self.cast("".join(map(r"\x{:0>2x}".format, value)), dt.binary) + elif dtype.is_date(): + return self.cast(value.isoformat(), dtype) + elif dtype.is_json(): + return sge.convert(str(value)) + return None - # it does support it, but we can't use it because of support for pivot - supports_unnest_in_select = False + @visit_node.register(ops.DateFromYMD) + @visit_node.register(ops.Mode) + def visit_Undefined(self, op, **_): + raise com.OperationNotDefinedError(type(op).__name__) - type_mapper = RisingwaveType +_SIMPLE_OPS = { + ops.First: "first_value", + ops.Last: "last_value", +} -rewrites = RisingwaveExprTranslator.rewrites +for _op, _name in _SIMPLE_OPS.items(): + assert isinstance(type(_op), type), type(_op) + if issubclass(_op, ops.Reduction): + @RisingwaveCompiler.visit_node.register(_op) + def _fmt(self, op, *, _name: str = _name, where, **kw): + return self.agg[_name](*kw.values(), where=where) -@rewrites(ops.Any) -@rewrites(ops.All) -def _any_all_no_op(expr): - return expr + else: + @RisingwaveCompiler.visit_node.register(_op) + def _fmt(self, op, *, _name: str = _name, **kw): + return self.f[_name](*kw.values()) -class RisingwaveCompiler(AlchemyCompiler): - translator_class = RisingwaveExprTranslator - rewrites = AlchemyCompiler.rewrites | rewrite_sample + setattr(RisingwaveCompiler, f"visit_{_op.__name__}", _fmt) diff --git a/ibis/backends/risingwave/dialect.py b/ibis/backends/risingwave/dialect.py new file mode 100644 index 000000000000..2237c2a4d188 --- /dev/null +++ b/ibis/backends/risingwave/dialect.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import sqlglot.expressions as sge +from sqlglot import generator +from sqlglot.dialects import Postgres + + +class RisingWave(Postgres): + # Need to disable timestamp precision + # No "or replace" allowed in create statements + # no "not null" clause for column constraints + + class Generator(generator.Generator): + SINGLE_STRING_INTERVAL = True + RENAME_TABLE_WITH_DB = False + LOCKING_READS_SUPPORTED = True + JOIN_HINTS = False + TABLE_HINTS = False + QUERY_HINTS = False + NVL2_SUPPORTED = False + PARAMETER_TOKEN = "$" + TABLESAMPLE_SIZE_IS_ROWS = False + TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" + SUPPORTS_SELECT_INTO = True + JSON_TYPE_REQUIRED_FOR_EXTRACTION = True + SUPPORTS_UNLOGGED_TABLES = True + + TYPE_MAPPING = { + **Postgres.Generator.TYPE_MAPPING, + sge.DataType.Type.TIMESTAMPTZ: "TIMESTAMPTZ", + } + + TRANSFORMS = { + **Postgres.Generator.TRANSFORMS, + } diff --git a/ibis/backends/risingwave/tests/conftest.py b/ibis/backends/risingwave/tests/conftest.py index 35cfe6b8e1db..4ffb2ab85722 100644 --- a/ibis/backends/risingwave/tests/conftest.py +++ b/ibis/backends/risingwave/tests/conftest.py @@ -4,10 +4,8 @@ from typing import TYPE_CHECKING, Any import pytest -import sqlalchemy as sa import ibis -from ibis.backends.conftest import init_database from ibis.backends.tests.base import ServiceBackendTest if TYPE_CHECKING: @@ -35,23 +33,14 @@ class TestConf(ServiceBackendTest): supports_structs = False rounding_method = "half_to_even" service_name = "risingwave" - deps = "psycopg2", "sqlalchemy" + deps = ("psycopg2",) @property def test_files(self) -> Iterable[Path]: return self.data_dir.joinpath("csv").glob("*.csv") - def _load_data( - self, - *, - user: str = PG_USER, - password: str = PG_PASS, - host: str = PG_HOST, - port: int = PG_PORT, - database: str = IBIS_TEST_RISINGWAVE_DB, - **_: Any, - ) -> None: - """Load test data into a Risingwave backend instance. + def _load_data(self, **_: Any) -> None: + """Load test data into a PostgreSQL backend instance. Parameters ---------- @@ -60,15 +49,8 @@ def _load_data( script_dir Location of scripts defining schemas """ - init_database( - url=sa.engine.make_url( - f"risingwave://{user}:{password}@{host}:{port:d}/{database}" - ), - database=database, - schema=self.ddl_script, - isolation_level="AUTOCOMMIT", - recreate=False, - ) + with self.connection._safe_raw_sql(";".join(self.ddl_script)): + pass @staticmethod def connect(*, tmpdir, worker_id, port: int | None = None, **kw): @@ -91,13 +73,8 @@ def con(tmp_path_factory, data_dir, worker_id): @pytest.fixture(scope="module") -def db(con): - return con.database() - - -@pytest.fixture(scope="module") -def alltypes(db): - return db.functional_alltypes +def alltypes(con): + return con.tables.functional_alltypes @pytest.fixture(scope="module") @@ -105,20 +82,6 @@ def df(alltypes): return alltypes.execute() -@pytest.fixture(scope="module") -def alltypes_sqla(con, alltypes): - name = alltypes.op().name - return con._get_sqla_table(name) - - @pytest.fixture(scope="module") def intervals(con): return con.table("intervals") - - -@pytest.fixture -def translate(): - from ibis.backends.risingwave import Backend - - context = Backend.compiler.make_context() - return lambda expr: Backend.compiler.translator_class(expr, context).get_result() diff --git a/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql b/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql index cfbcf133a863..c0b4a0b83304 100644 --- a/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql +++ b/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql @@ -1,2 +1,3 @@ -SELECT sum(t0.foo) AS "Sum(foo)" -FROM t0 AS t0 \ No newline at end of file +SELECT + SUM("t0"."foo") AS "Sum(foo)" +FROM "t0" AS "t0" \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql index 34761d9a76e0..f0366d83444d 100644 --- a/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql +++ b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql @@ -1 +1 @@ -WITH anon_2 AS (SELECT t2.string_col AS string_col, sum(t2.double_col) AS metric FROM functional_alltypes AS t2 GROUP BY 1), anon_3 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1), anon_1 AS (SELECT t2.string_col AS string_col, t2.metric AS metric FROM (SELECT anon_2.string_col AS string_col, anon_2.metric AS metric FROM anon_2 UNION ALL SELECT anon_3.string_col AS string_col, anon_3.metric AS metric FROM anon_3) AS t2), anon_4 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1) SELECT t1.string_col, t1.metric FROM (SELECT anon_1.string_col AS string_col, anon_1.metric AS metric FROM anon_1 UNION ALL SELECT anon_4.string_col AS string_col, anon_4.metric AS metric FROM anon_4) AS t1 \ No newline at end of file +WITH "t1" AS ( SELECT "t0"."string_col", SUM("t0"."double_col") AS "metric" FROM "functional_alltypes" AS "t0" GROUP BY 1 ) SELECT "t7"."string_col", "t7"."metric" FROM ( SELECT "t5"."string_col", "t5"."metric" FROM ( SELECT * FROM "t1" AS "t2" UNION ALL SELECT * FROM "t1" AS "t4" ) AS "t5" UNION ALL SELECT * FROM "t1" AS "t3" ) AS "t7" \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql index 6ce31e7468bb..5a873785e92b 100644 --- a/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql +++ b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql @@ -1 +1 @@ -WITH anon_2 AS (SELECT t2.string_col AS string_col, sum(t2.double_col) AS metric FROM functional_alltypes AS t2 GROUP BY 1), anon_3 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1), anon_1 AS (SELECT t2.string_col AS string_col, t2.metric AS metric FROM (SELECT anon_2.string_col AS string_col, anon_2.metric AS metric FROM anon_2 UNION SELECT anon_3.string_col AS string_col, anon_3.metric AS metric FROM anon_3) AS t2), anon_4 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1) SELECT t1.string_col, t1.metric FROM (SELECT anon_1.string_col AS string_col, anon_1.metric AS metric FROM anon_1 UNION SELECT anon_4.string_col AS string_col, anon_4.metric AS metric FROM anon_4) AS t1 \ No newline at end of file +WITH "t1" AS ( SELECT "t0"."string_col", SUM("t0"."double_col") AS "metric" FROM "functional_alltypes" AS "t0" GROUP BY 1 ) SELECT "t7"."string_col", "t7"."metric" FROM ( SELECT "t5"."string_col", "t5"."metric" FROM ( SELECT * FROM "t1" AS "t2" UNION SELECT * FROM "t1" AS "t4" ) AS "t5" UNION SELECT * FROM "t1" AS "t3" ) AS "t7" \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/test_client.py b/ibis/backends/risingwave/tests/test_client.py index b5c7cfa98560..918b648b7bc8 100644 --- a/ibis/backends/risingwave/tests/test_client.py +++ b/ibis/backends/risingwave/tests/test_client.py @@ -4,17 +4,15 @@ import pandas as pd import pytest +import sqlglot as sg from pytest import param import ibis import ibis.expr.datatypes as dt import ibis.expr.types as ir -from ibis.tests.util import assert_equal +from ibis.util import gen_name pytest.importorskip("psycopg2") -sa = pytest.importorskip("sqlalchemy") - -from sqlalchemy.dialects import postgresql # noqa: E402 RISINGWAVE_TEST_DB = os.environ.get("IBIS_TEST_RISINGWAVE_DATABASE", "dev") IBIS_RISINGWAVE_HOST = os.environ.get("IBIS_TEST_RISINGWAVE_HOST", "localhost") @@ -64,47 +62,15 @@ def test_list_databases(con): assert RISINGWAVE_TEST_DB in con.list_databases() -def test_schema_type_conversion(con): - typespec = [ - # name, type, nullable - ("jsonb", postgresql.JSONB, True, dt.JSON), - ] - - sqla_types = [] - ibis_types = [] - for name, t, nullable, ibis_type in typespec: - sqla_types.append(sa.Column(name, t, nullable=nullable)) - ibis_types.append((name, ibis_type(nullable=nullable))) - - # Create a table with placeholder stubs for JSON, JSONB, and UUID. - table = sa.Table("tname", sa.MetaData(), *sqla_types) - - # Check that we can correctly create a schema with dt.any for the - # missing types. - schema = con._schema_from_sqla_table(table) - expected = ibis.schema(ibis_types) - - assert_equal(schema, expected) +def test_create_and_drop_table(con, temp_table): + sch = ibis.schema([("first_name", "string")]) + con.create_table(temp_table, schema=sch) + assert con.table(temp_table) is not None -@pytest.mark.parametrize("params", [{}, {"database": RISINGWAVE_TEST_DB}]) -def test_create_and_drop_table(con, temp_table, params): - sch = ibis.schema( - [ - ("first_name", "string"), - ("last_name", "string"), - ("department_name", "string"), - ("salary", "float64"), - ] - ) - - con.create_table(temp_table, schema=sch, **params) - assert con.table(temp_table, **params) is not None - - con.drop_table(temp_table, **params) + con.drop_table(temp_table) - with pytest.raises(sa.exc.NoSuchTableError): - con.table(temp_table, **params) + assert temp_table not in con.list_tables() @pytest.mark.parametrize( @@ -124,8 +90,8 @@ def test_create_and_drop_table(con, temp_table, params): ("date", dt.date), ("time", dt.time), ("time without time zone", dt.time), - ("timestamp without time zone", dt.timestamp), - ("timestamp with time zone", dt.Timestamp("UTC")), + ("timestamp without time zone", dt.Timestamp(scale=6)), + ("timestamp with time zone", dt.Timestamp("UTC", scale=6)), ("interval", dt.Interval("s")), ("numeric", dt.decimal), ("jsonb", dt.json), @@ -133,17 +99,16 @@ def test_create_and_drop_table(con, temp_table, params): ], ) def test_get_schema_from_query(con, pg_type, expected_type): - name = con._quote(ibis.util.guid()) + name = sg.table(gen_name("risingwave_temp_table"), quoted=True) with con.begin() as c: - c.exec_driver_sql(f"CREATE TABLE {name} (x {pg_type}, y {pg_type}[])") + c.execute(f"CREATE TABLE {name} (x {pg_type}, y {pg_type}[])") expected_schema = ibis.schema(dict(x=expected_type, y=dt.Array(expected_type))) result_schema = con._get_schema_using_query(f"SELECT x, y FROM {name}") assert result_schema == expected_schema with con.begin() as c: - c.exec_driver_sql(f"DROP TABLE {name}") + c.execute(f"DROP TABLE {name}") -@pytest.mark.xfail(reason="unsupported insert with CTEs") def test_insert_with_cte(con): X = con.create_table("X", schema=ibis.schema(dict(id="int")), temp=False) expr = X.join(X.mutate(a=X["id"] + 1), ["id"]) @@ -151,8 +116,3 @@ def test_insert_with_cte(con): assert Y.execute().empty con.drop_table("Y") con.drop_table("X") - - -def test_connect_url_with_empty_host(): - con = ibis.connect("risingwave:///dev") - assert con.con.url.host is None diff --git a/ibis/backends/risingwave/tests/test_functions.py b/ibis/backends/risingwave/tests/test_functions.py index c8874e390c60..d680fb3190f9 100644 --- a/ibis/backends/risingwave/tests/test_functions.py +++ b/ibis/backends/risingwave/tests/test_functions.py @@ -1,7 +1,6 @@ from __future__ import annotations import operator -import string import warnings from datetime import datetime @@ -13,104 +12,9 @@ import ibis import ibis.expr.datatypes as dt -import ibis.expr.types as ir -from ibis import config from ibis import literal as L pytest.importorskip("psycopg2") -sa = pytest.importorskip("sqlalchemy") - -from sqlalchemy.dialects import postgresql # noqa: E402 - - -@pytest.mark.parametrize( - ("left_func", "right_func"), - [ - param( - lambda t: t.double_col.cast("int8"), - lambda at: sa.cast(at.c.double_col, sa.SMALLINT), - id="double_to_int8", - ), - param( - lambda t: t.double_col.cast("int16"), - lambda at: sa.cast(at.c.double_col, sa.SMALLINT), - id="double_to_int16", - ), - param( - lambda t: t.string_col.cast("double"), - lambda at: sa.cast(at.c.string_col, postgresql.DOUBLE_PRECISION), - id="string_to_double", - ), - param( - lambda t: t.string_col.cast("float32"), - lambda at: sa.cast(at.c.string_col, postgresql.REAL), - id="string_to_float", - ), - param( - lambda t: t.string_col.cast("decimal"), - lambda at: sa.cast(at.c.string_col, sa.NUMERIC()), - id="string_to_decimal_no_params", - ), - param( - lambda t: t.string_col.cast("decimal(9, 3)"), - lambda at: sa.cast(at.c.string_col, sa.NUMERIC(9, 3)), - id="string_to_decimal_params", - ), - ], -) -def test_cast(alltypes, alltypes_sqla, translate, left_func, right_func): - left = left_func(alltypes) - right = right_func(alltypes_sqla.alias("t0")) - assert str(translate(left.op()).compile()) == str(right.compile()) - - -def test_date_cast(alltypes, alltypes_sqla, translate): - result = alltypes.date_string_col.cast("date") - expected = sa.cast(alltypes_sqla.alias("t0").c.date_string_col, sa.DATE) - assert str(translate(result.op())) == str(expected) - - -@pytest.mark.parametrize( - "column", - [ - "id", - "bool_col", - "tinyint_col", - "smallint_col", - "int_col", - "bigint_col", - "float_col", - "double_col", - "date_string_col", - "string_col", - "timestamp_col", - "year", - "month", - ], -) -def test_noop_cast(alltypes, alltypes_sqla, translate, column): - col = alltypes[column] - result = col.cast(col.type()) - expected = alltypes_sqla.alias("t0").c[column] - assert result.equals(col) - assert str(translate(result.op())) == str(expected) - - -def test_timestamp_cast_noop(alltypes, alltypes_sqla, translate): - # See GH #592 - result1 = alltypes.timestamp_col.cast("timestamp") - result2 = alltypes.int_col.cast("timestamp") - - assert isinstance(result1, ir.TimestampColumn) - assert isinstance(result2, ir.TimestampColumn) - - expected1 = alltypes_sqla.alias("t0").c.timestamp_col - expected2 = sa.cast( - sa.func.to_timestamp(alltypes_sqla.alias("t0").c.int_col), sa.TIMESTAMP() - ) - - assert str(translate(result1.op())) == str(expected1) - assert str(translate(result2.op())) == str(expected2) @pytest.mark.parametrize(("value", "expected"), [(0, None), (5.5, 5.5)]) @@ -427,12 +331,7 @@ def test_union_cte(alltypes, distinct, snapshot): expr2 = expr1.view() expr3 = expr1.view() expr = expr1.union(expr2, distinct=distinct).union(expr3, distinct=distinct) - result = " ".join( - line.strip() - for line in str( - expr.compile().compile(compile_kwargs={"literal_binds": True}) - ).splitlines() - ) + result = " ".join(line.strip() for line in expr.compile().splitlines()) snapshot.assert_match(result, "out.sql") @@ -568,18 +467,6 @@ def test_not_exists(alltypes, df): tm.assert_frame_equal(result, expected, check_index_type=False, check_dtype=False) -def test_interactive_repr_shows_error(alltypes): - # #591. Doing this in Postgres because so many built-in functions are - # not available - - expr = alltypes.int_col.convert_base(10, 2) - - with config.option_context("interactive", True): - result = repr(expr) - - assert "no translation rule" in result.lower() - - def test_subquery(alltypes, df): t = alltypes @@ -758,9 +645,6 @@ def array_types(con): return con.table("array_types") -@pytest.mark.xfail( - reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype" -) def test_array_length(array_types): expr = array_types.select( array_types.x.length().name("x_length"), @@ -861,60 +745,6 @@ def test_timestamp_with_timezone(con): assert str(result.dtype.tz) -@pytest.fixture( - params=[ - None, - "UTC", - "America/New_York", - "America/Los_Angeles", - "Europe/Paris", - "Chile/Continental", - "Asia/Tel_Aviv", - "Asia/Tokyo", - "Africa/Nairobi", - "Australia/Sydney", - ] -) -def tz(request): - return request.param - - -@pytest.fixture -def tzone_compute(con, temp_table, tz): - schema = ibis.schema([("ts", dt.Timestamp(tz)), ("b", "double"), ("c", "string")]) - con.create_table(temp_table, schema=schema, temp=False) - t = con.table(temp_table) - - n = 10 - df = pd.DataFrame( - { - "ts": pd.date_range("2017-04-01", periods=n, tz=tz).values, - "b": np.arange(n).astype("float64"), - "c": list(string.ascii_lowercase[:n]), - } - ) - - df.to_sql( - temp_table, - con.con, - index=False, - if_exists="append", - dtype={"ts": sa.TIMESTAMP(timezone=True), "b": sa.FLOAT, "c": sa.TEXT}, - ) - - yield t - con.drop_table(temp_table) - - -def test_ts_timezone_is_preserved(tzone_compute, tz): - assert dt.Timestamp(tz).equals(tzone_compute.ts.type()) - - -def test_timestamp_with_timezone_select(tzone_compute, tz): - ts = tzone_compute.ts.execute() - assert str(getattr(ts.dtype, "tz", None)) == str(tz) - - @pytest.mark.parametrize( ("left", "right", "type"), [ @@ -1010,8 +840,8 @@ def test_string_to_binary_cast(con): "FROM functional_alltypes LIMIT 10" ) with con.begin() as c: - cur = c.exec_driver_sql(sql_string) - raw_data = [row[0][0] for row in cur] + c.execute(sql_string) + raw_data = [row[0][0] for row in c.fetchall()] expected = pd.Series(raw_data, name=name) tm.assert_series_equal(result, expected) @@ -1027,6 +857,6 @@ def test_string_to_binary_round_trip(con): "FROM functional_alltypes LIMIT 10" ) with con.begin() as c: - cur = c.exec_driver_sql(sql_string) - expected = pd.Series([row[0][0] for row in cur], name=name) + c.execute(sql_string) + expected = pd.Series([row[0][0] for row in c.fetchall()], name=name) tm.assert_series_equal(result, expected) diff --git a/ibis/backends/tests/errors.py b/ibis/backends/tests/errors.py index a314b4f7543c..e9a8347ab094 100644 --- a/ibis/backends/tests/errors.py +++ b/ibis/backends/tests/errors.py @@ -92,14 +92,18 @@ try: from psycopg2.errors import DivisionByZero as PsycoPg2DivisionByZero from psycopg2.errors import IndeterminateDatatype as PsycoPg2IndeterminateDatatype + from psycopg2.errors import InternalError_ as PsycoPg2InternalError from psycopg2.errors import ( InvalidTextRepresentation as PsycoPg2InvalidTextRepresentation, ) + from psycopg2.errors import ProgrammingError as PsycoPg2ProgrammingError from psycopg2.errors import SyntaxError as PsycoPg2SyntaxError except ImportError: PsycoPg2SyntaxError = ( PsycoPg2IndeterminateDatatype - ) = PsycoPg2InvalidTextRepresentation = PsycoPg2DivisionByZero = None + ) = ( + PsycoPg2InvalidTextRepresentation + ) = PsycoPg2DivisionByZero = PsycoPg2InternalError = PsycoPg2ProgrammingError = None try: from pymysql.err import NotSupportedError as MySQLNotSupportedError diff --git a/ibis/backends/tests/snapshots/test_dot_sql/test_cte/risingwave/out.sql b/ibis/backends/tests/snapshots/test_dot_sql/test_cte/risingwave/out.sql new file mode 100644 index 000000000000..efc0daaef0d6 --- /dev/null +++ b/ibis/backends/tests/snapshots/test_dot_sql/test_cte/risingwave/out.sql @@ -0,0 +1,8 @@ +WITH "foo" AS ( + SELECT + * + FROM "test_risingwave_temp_mem_t_for_cte" AS "t0" +) +SELECT + COUNT(*) AS "x" +FROM "foo" \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_default_limit/risingwave/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_default_limit/risingwave/out.sql new file mode 100644 index 000000000000..b309cd65374d --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_default_limit/risingwave/out.sql @@ -0,0 +1,5 @@ +SELECT + "t0"."id", + "t0"."bool_col" +FROM "functional_alltypes" AS "t0" +LIMIT 11 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/risingwave/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/risingwave/out.sql new file mode 100644 index 000000000000..b309cd65374d --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/risingwave/out.sql @@ -0,0 +1,5 @@ +SELECT + "t0"."id", + "t0"."bool_col" +FROM "functional_alltypes" AS "t0" +LIMIT 11 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/risingwave/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/risingwave/out.sql new file mode 100644 index 000000000000..6bd0ba8c995d --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/risingwave/out.sql @@ -0,0 +1,3 @@ +SELECT + SUM("t0"."bigint_col") AS "Sum(bigint_col)" +FROM "functional_alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/risingwave/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/risingwave/out.sql new file mode 100644 index 000000000000..97338646649f --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/risingwave/out.sql @@ -0,0 +1,10 @@ +SELECT + * +FROM ( + SELECT + "t0"."id", + "t0"."bool_col" + FROM "functional_alltypes" AS "t0" + LIMIT 10 +) AS "t2" +LIMIT 11 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/risingwave/out.sql b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/risingwave/out.sql new file mode 100644 index 000000000000..d3969647c9ea --- /dev/null +++ b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/risingwave/out.sql @@ -0,0 +1,22 @@ +SELECT + CASE "t0"."continent" + WHEN 'NA' + THEN 'North America' + WHEN 'SA' + THEN 'South America' + WHEN 'EU' + THEN 'Europe' + WHEN 'AF' + THEN 'Africa' + WHEN 'AS' + THEN 'Asia' + WHEN 'OC' + THEN 'Oceania' + WHEN 'AN' + THEN 'Antarctica' + ELSE 'Unknown continent' + END AS "cont", + SUM("t0"."population") AS "total_pop" +FROM "countries" AS "t0" +GROUP BY + 1 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/risingwave/out.sql b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/risingwave/out.sql new file mode 100644 index 000000000000..c1611d8cecc3 --- /dev/null +++ b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/risingwave/out.sql @@ -0,0 +1,9 @@ +SELECT + "t0"."x" IN ( + SELECT + "t0"."x" + FROM "t" AS "t0" + WHERE + "t0"."x" > 2 + ) AS "InSubquery(x)" +FROM "t" AS "t0" \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/risingwave/out.sql b/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/risingwave/out.sql new file mode 100644 index 000000000000..b7508b9ef535 --- /dev/null +++ b/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/risingwave/out.sql @@ -0,0 +1,60 @@ +WITH "t5" AS ( + SELECT + "t4"."field_of_study", + FIRST("t4"."diff") AS "diff" + FROM ( + SELECT + "t3"."field_of_study", + "t3"."years", + "t3"."degrees", + "t3"."earliest_degrees", + "t3"."latest_degrees", + "t3"."latest_degrees" - "t3"."earliest_degrees" AS "diff" + FROM ( + SELECT + "t2"."field_of_study", + "t2"."years", + "t2"."degrees", + FIRST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "earliest_degrees", + LAST_VALUE("t2"."degrees") OVER (PARTITION BY "t2"."field_of_study" ORDER BY "t2"."years" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "latest_degrees" + FROM ( + SELECT + "t1"."field_of_study", + CAST(TO_JSONB("t1"."__pivoted__") -> 'f1' AS VARCHAR) AS "years", + CAST(TO_JSONB("t1"."__pivoted__") -> 'f2' AS BIGINT) AS "degrees" + FROM ( + SELECT + "t0"."field_of_study", + UNNEST( + ARRAY[ROW(CAST('1970-71' AS VARCHAR), CAST("t0"."1970-71" AS BIGINT)), ROW(CAST('1975-76' AS VARCHAR), CAST("t0"."1975-76" AS BIGINT)), ROW(CAST('1980-81' AS VARCHAR), CAST("t0"."1980-81" AS BIGINT)), ROW(CAST('1985-86' AS VARCHAR), CAST("t0"."1985-86" AS BIGINT)), ROW(CAST('1990-91' AS VARCHAR), CAST("t0"."1990-91" AS BIGINT)), ROW(CAST('1995-96' AS VARCHAR), CAST("t0"."1995-96" AS BIGINT)), ROW(CAST('2000-01' AS VARCHAR), CAST("t0"."2000-01" AS BIGINT)), ROW(CAST('2005-06' AS VARCHAR), CAST("t0"."2005-06" AS BIGINT)), ROW(CAST('2010-11' AS VARCHAR), CAST("t0"."2010-11" AS BIGINT)), ROW(CAST('2011-12' AS VARCHAR), CAST("t0"."2011-12" AS BIGINT)), ROW(CAST('2012-13' AS VARCHAR), CAST("t0"."2012-13" AS BIGINT)), ROW(CAST('2013-14' AS VARCHAR), CAST("t0"."2013-14" AS BIGINT)), ROW(CAST('2014-15' AS VARCHAR), CAST("t0"."2014-15" AS BIGINT)), ROW(CAST('2015-16' AS VARCHAR), CAST("t0"."2015-16" AS BIGINT)), ROW(CAST('2016-17' AS VARCHAR), CAST("t0"."2016-17" AS BIGINT)), ROW(CAST('2017-18' AS VARCHAR), CAST("t0"."2017-18" AS BIGINT)), ROW(CAST('2018-19' AS VARCHAR), CAST("t0"."2018-19" AS BIGINT)), ROW(CAST('2019-20' AS VARCHAR), CAST("t0"."2019-20" AS BIGINT))] + ) AS "__pivoted__" + FROM "humanities" AS "t0" + ) AS "t1" + ) AS "t2" + ) AS "t3" + ) AS "t4" + GROUP BY + 1 +) +SELECT + "t11"."field_of_study", + "t11"."diff" +FROM ( + SELECT + "t6"."field_of_study", + "t6"."diff" + FROM "t5" AS "t6" + ORDER BY + "t6"."diff" DESC NULLS LAST + LIMIT 10 + UNION ALL + SELECT + "t6"."field_of_study", + "t6"."diff" + FROM "t5" AS "t6" + WHERE + "t6"."diff" < 0 + ORDER BY + "t6"."diff" ASC + LIMIT 10 +) AS "t11" \ No newline at end of file diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 2ae1aa8b4be4..63c40a7a24a7 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -22,6 +22,7 @@ MySQLNotSupportedError, OracleDatabaseError, PolarsInvalidOperationError, + PsycoPg2InternalError, Py4JError, PyDruidProgrammingError, PyODBCProgrammingError, @@ -92,6 +93,7 @@ def mean_udf(s): "druid", "oracle", "flink", + "risingwave", "exasol", ], raises=com.OperationNotDefinedError, @@ -439,6 +441,7 @@ def mean_and_std(v): "oracle", "exasol", "flink", + "risingwave", ], raises=com.OperationNotDefinedError, ), @@ -537,7 +540,7 @@ def mean_and_std(v): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, ), ], ), @@ -562,7 +565,7 @@ def mean_and_std(v): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, ), ], ), @@ -592,7 +595,7 @@ def mean_and_std(v): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, ), ], ), @@ -649,7 +652,7 @@ def mean_and_std(v): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, ), ], ), @@ -664,7 +667,7 @@ def mean_and_std(v): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, ), ], ), @@ -771,21 +774,25 @@ def mean_and_std(v): param( lambda t: t.string_col.isin(["1", "7"]), lambda t: t.string_col.isin(["1", "7"]), - marks=pytest.mark.notimpl( - ["exasol"], - raises=(com.OperationNotDefinedError, ExaQueryError), - strict=False, - ), + marks=[ + pytest.mark.notimpl( + ["exasol"], + raises=(com.OperationNotDefinedError, ExaQueryError), + strict=False, + ), + ], id="is_in", ), param( lambda _: ibis._.string_col.isin(["1", "7"]), lambda t: t.string_col.isin(["1", "7"]), - marks=pytest.mark.notimpl( - ["exasol"], - raises=(com.OperationNotDefinedError, ExaQueryError), - strict=False, - ), + marks=[ + pytest.mark.notimpl( + ["exasol"], + raises=(com.OperationNotDefinedError, ExaQueryError), + strict=False, + ), + ], id="is_in_deferred", ), ], @@ -939,7 +946,7 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond): pytest.mark.broken( ["risingwave"], reason="Invalid input syntax: direct arg in `percentile_cont` must be castable to float64", - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, ), ], ), @@ -954,7 +961,14 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond): lambda t: t.string_col.isin(["1", "7"]), id="is_in", marks=[ - pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) + pytest.mark.notimpl( + ["datafusion"], raises=com.OperationNotDefinedError + ), + pytest.mark.notimpl( + "risingwave", + raises=PsycoPg2InternalError, + reason="probably incorrect filter syntax but not sure", + ), ], ), ], @@ -991,7 +1005,7 @@ def test_quantile( ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function covar_pop(integer, integer) does not exist", ), ], @@ -1011,7 +1025,7 @@ def test_quantile( ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function covar_pop(integer, integer) does not exist", ), ], @@ -1036,7 +1050,7 @@ def test_quantile( ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function covar_pop(integer, integer) does not exist", ), ], @@ -1051,7 +1065,7 @@ def test_quantile( raises=com.OperationNotDefinedError, ), pytest.mark.notyet( - ["postgres", "duckdb", "snowflake"], + ["postgres", "duckdb", "snowflake", "risingwave"], raises=com.UnsupportedOperationError, reason="backend only implements population correlation coefficient", ), @@ -1095,7 +1109,7 @@ def test_quantile( ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function covar_pop(integer, integer) does not exist", ), ], @@ -1124,7 +1138,7 @@ def test_quantile( ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function covar_pop(integer, integer) does not exist", ), ], @@ -1608,7 +1622,9 @@ def test_grouped_case(backend, con): @pytest.mark.notyet(["oracle"], raises=OracleDatabaseError) @pytest.mark.notyet(["pyspark"], raises=PySparkAnalysisException) @pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError) +@pytest.mark.notyet(["risingwave"], raises=AssertionError, strict=False) def test_group_concat_over_window(backend, con): + # TODO: this test is flaky on risingwave and I DO NOT LIKE IT input_df = pd.DataFrame( { "s": ["a|b|c", "b|a|c", "b|b|b|c|a"], diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 24f116d12a76..d8e776f54d31 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -8,7 +8,6 @@ import pandas.testing as tm import pytest import pytz -import sqlalchemy as sa import toolz from pytest import param @@ -24,6 +23,8 @@ MySQLOperationalError, PolarsComputeError, PsycoPg2IndeterminateDatatype, + PsycoPg2InternalError, + PsycoPg2ProgrammingError, PsycoPg2SyntaxError, Py4JJavaError, PySparkAnalysisException, @@ -83,7 +84,19 @@ def test_array_column(backend, alltypes, df): backend.assert_series_equal(result, expected, check_names=False) -def test_array_scalar(con): +ARRAY_BACKEND_TYPES = { + "clickhouse": "Array(Float64)", + "snowflake": "ARRAY", + "trino": "array(double)", + "bigquery": "ARRAY", + "duckdb": "DOUBLE[]", + "postgres": "numeric[]", + "risingwave": "numeric[]", + "flink": "ARRAY NOT NULL", +} + + +def test_array_scalar(con, backend): expr = ibis.array([1.0, 2.0, 3.0]) assert isinstance(expr, ir.ArrayScalar) @@ -126,11 +139,6 @@ def test_array_concat_variadic(con): # Issues #2370 @pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError) -@pytest.mark.notyet( - ["risingwave"], - raises=sa.exc.InternalError, - reason="Bind error: cannot determine type of empty array", -) @pytest.mark.notyet(["trino"], raises=TrinoUserError) def test_array_concat_some_empty(con): left = ibis.literal([]) @@ -210,7 +218,7 @@ def test_array_index(con, idx): ) @pytest.mark.notimpl( ["risingwave"], - raises=ValueError, + raises=AssertionError, reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", ) @pytest.mark.never( @@ -243,10 +251,11 @@ def test_array_discovery(backend): raises=GoogleBadRequest, ) @pytest.mark.notimpl(["dask"], raises=ValueError) -@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["risingwave"], - raises=ValueError, + # TODO: valueerror -> assertion error + raises=AssertionError, reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", ) def test_unnest_simple(backend): @@ -266,11 +275,6 @@ def test_unnest_simple(backend): @builtin_array @pytest.mark.notimpl("dask", raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl( - ["risingwave"], - raises=ValueError, - reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", -) def test_unnest_complex(backend): array_types = backend.array_types df = array_types.execute() @@ -309,11 +313,6 @@ def test_unnest_complex(backend): ) @pytest.mark.notimpl(["dask"], raises=ValueError) @pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl( - ["risingwave"], - raises=ValueError, - reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", -) def test_unnest_idempotent(backend): array_types = backend.array_types df = array_types.execute() @@ -335,11 +334,6 @@ def test_unnest_idempotent(backend): @builtin_array @pytest.mark.notimpl("dask", raises=ValueError) @pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl( - ["risingwave"], - raises=ValueError, - reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", -) def test_unnest_no_nulls(backend): array_types = backend.array_types df = array_types.execute() @@ -366,17 +360,8 @@ def test_unnest_no_nulls(backend): @builtin_array @pytest.mark.notimpl("dask", raises=ValueError) -@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl( - ["risingwave"], - raises=ValueError, - reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", -) -@pytest.mark.broken( - ["pandas"], - raises=ValueError, - reason="all the input arrays must have same number of dimensions", -) +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.broken(["risingwave"], raises=AssertionError) def test_unnest_default_name(backend): array_types = backend.array_types df = array_types.execute() @@ -426,10 +411,11 @@ def test_unnest_default_name(backend): ["datafusion"], raises=Exception, reason="array_types table isn't defined" ) @pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl( +@pytest.mark.broken( ["risingwave"], - raises=ValueError, - reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", + raises=AssertionError, + reason="not broken; row ordering is not guaranteed and sometimes this test will pass", + strict=False, ) def test_array_slice(backend, start, stop): array_types = backend.array_types @@ -452,6 +438,11 @@ def test_array_slice(backend, start, stop): ], raises=com.OperationNotDefinedError, ) +@pytest.mark.broken( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="TODO(Kexiang): seems a bug", +) @pytest.mark.notimpl( ["dask", "pandas"], raises=com.OperationNotDefinedError, @@ -488,7 +479,7 @@ def test_array_slice(backend, start, stop): ) @pytest.mark.broken( ["risingwave"], - raises=AssertionError, + raises=PsycoPg2InternalError, reason="TODO(Kexiang): seems a bug", ) def test_array_map(con, input, output, func): @@ -541,6 +532,11 @@ def test_array_map(con, input, output, func): param({"a": [[1, 2], [4]]}, {"a": [[2], [4]]}, id="no_nulls"), ], ) +@pytest.mark.notyet( + "risingwave", + raises=PsycoPg2InternalError, + reason="no support for not null column constraint", +) @pytest.mark.parametrize( "predicate", [ @@ -563,15 +559,11 @@ def test_array_filter(con, input, output, predicate): @builtin_array @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl( - ["risingwave"], - raises=ValueError, - reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", -) @pytest.mark.broken( - ["flink"], - raises=Py4JJavaError, - reason="Caused by: java.lang.NullPointerException", + ["risingwave"], + raises=AssertionError, + reason="not broken; row ordering is not guaranteed and sometimes this test will pass", + strict=False, ) def test_array_contains(backend, con): t = backend.array_types @@ -621,11 +613,6 @@ def test_array_position(backend, con, a, expected_array): @builtin_array @pytest.mark.notimpl(["dask", "polars"], raises=com.OperationNotDefinedError) -@pytest.mark.broken( - ["risingwave"], - raises=AssertionError, - reason="TODO(Kexiang): seems a bug", -) @pytest.mark.parametrize( ("a"), [ @@ -716,13 +703,13 @@ def test_array_unique(con, input, expected): raises=AssertionError, reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14735", ) -def test_array_sort(con): - t = ibis.memtable({"a": [[3, 2], [], [42, 42], []]}) - expr = t.a.sort() +def test_array_sort(backend, con): + t = ibis.memtable({"a": [[3, 2], [], [42, 42], []], "id": range(4)}) + expr = t.mutate(a=t.a.sort()).order_by("id") result = con.execute(expr) expected = pd.Series([[2, 3], [], [42, 42], []], dtype="object") - assert frozenset(map(tuple, result.values)) == frozenset( + assert frozenset(map(tuple, result["a"].values)) == frozenset( map(tuple, expected.values) ) @@ -822,9 +809,9 @@ def test_array_intersect(con, data): raises=ClickHouseDatabaseError, reason="ClickHouse won't accept dicts for struct type values", ) -@pytest.mark.notimpl(["risingwave"], raises=sa.exc.ProgrammingError) @pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError) -@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError) +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["trino"], reason="inserting maps into structs doesn't work", raises=TrinoUserError ) @@ -845,7 +832,6 @@ def test_unnest_struct(con): "dask", "datafusion", "druid", - "flink", "oracle", "pandas", "polars", @@ -856,7 +842,7 @@ def test_unnest_struct(con): ) @pytest.mark.notimpl( ["risingwave"], - raises=ValueError, + raises=com.OperationNotDefinedError, reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", ) def test_zip(backend): @@ -883,9 +869,9 @@ def test_zip(backend): raises=ClickHouseDatabaseError, reason="https://github.com/ClickHouse/ClickHouse/issues/41112", ) -@pytest.mark.notimpl(["risingwave"], raises=sa.exc.ProgrammingError) @pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError) -@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2ProgrammingError) +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["polars"], raises=com.OperationNotDefinedError, @@ -944,7 +930,11 @@ def flatten_data(): @pytest.mark.notyet( ["postgres", "risingwave"], reason="Postgres doesn't truly support arrays of arrays", - raises=(com.OperationNotDefinedError, PsycoPg2IndeterminateDatatype), + raises=( + com.OperationNotDefinedError, + PsycoPg2IndeterminateDatatype, + PsycoPg2InternalError, + ), ) @pytest.mark.parametrize( ("column", "expected"), @@ -1061,7 +1051,7 @@ def test_range_start_stop_step(con, start, stop, step): @pytest.mark.notimpl(["flink", "dask"], raises=com.OperationNotDefinedError) @pytest.mark.never( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Invalid parameter step: step size cannot equal zero", ) def test_range_start_stop_step_zero(con, start, stop): @@ -1100,6 +1090,11 @@ def test_unnest_empty_array(con): raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl(["sqlite"], raises=com.UnsupportedBackendType) +@pytest.mark.notyet( + "risingwave", + raises=PsycoPg2InternalError, + reason="no support for not null column constraint", +) def test_array_map_with_conflicting_names(backend, con): t = ibis.memtable({"x": [[1, 2]]}, schema=ibis.schema(dict(x="!array"))) expr = t.select(a=t.x.map(lambda x: x + 1)).select( @@ -1188,7 +1183,7 @@ def swap(token): id="pos", marks=pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function make_interval() does not exist", ), ), @@ -1204,7 +1199,7 @@ def swap(token): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function neg(interval) does not exist", ), ], @@ -1224,7 +1219,7 @@ def swap(token): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function neg(interval) does not exist", ), ], @@ -1256,7 +1251,7 @@ def test_timestamp_range(con, start, stop, step, freq, tzinfo): pytest.mark.notyet(["polars"], raises=PolarsComputeError), pytest.mark.notyet( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function make_interval() does not exist", ), ], @@ -1275,7 +1270,7 @@ def test_timestamp_range(con, start, stop, step, freq, tzinfo): ), pytest.mark.notyet( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function neg(interval) does not exist", ), ], @@ -1305,24 +1300,14 @@ def test_repr_timestamp_array(con, monkeypatch): assert ibis.options.default_backend is con expr = ibis.array(pd.date_range("2010-01-01", "2010-01-03", freq="D").tolist()) - assert "No translation rule" not in repr(expr) - assert "OperationNotDefinedError" not in repr(expr) + assert "Translation to backend failed" not in repr(expr) @pytest.mark.notyet( ["dask", "datafusion", "flink", "polars"], raises=com.OperationNotDefinedError, ) -@pytest.mark.broken( - ["risingwave"], - raises=sa.exc.OperationalError, - reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14734", -) -@pytest.mark.broken( - ["pandas"], - raises=ValueError, - reason="cannot reindex on an axis with duplicate labels", -) +@pytest.mark.broken(["pandas"], raises=ValueError, reason="reindex on duplicate values") def test_unnest_range(con): expr = ibis.range(2).unnest().name("x").as_table().mutate({"y": 1.0}) result = con.execute(expr) diff --git a/ibis/backends/tests/test_asof_join.py b/ibis/backends/tests/test_asof_join.py index ffe86146b65f..2a1901efc520 100644 --- a/ibis/backends/tests/test_asof_join.py +++ b/ibis/backends/tests/test_asof_join.py @@ -95,6 +95,7 @@ def time_keyed_right(time_keyed_df2): "oracle", "mssql", "sqlite", + "risingwave", ] ) def test_asof_join(con, time_left, time_right, time_df1, time_df2, direction, op): @@ -135,6 +136,7 @@ def test_asof_join(con, time_left, time_right, time_df1, time_df2, direction, op "oracle", "mssql", "sqlite", + "risingwave", ] ) def test_keyed_asof_join_with_tolerance( diff --git a/ibis/backends/tests/test_benchmarks.py b/ibis/backends/tests/test_benchmarks.py deleted file mode 100644 index 3234d3c8693f..000000000000 --- a/ibis/backends/tests/test_benchmarks.py +++ /dev/null @@ -1,900 +0,0 @@ -from __future__ import annotations - -import copy -import functools -import inspect -import itertools -import os -import string - -import numpy as np -import pandas as pd -import pytest -import sqlalchemy as sa -from packaging.version import parse as vparse - -import ibis -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops -import ibis.expr.types as ir -from ibis.backends.base import _get_backend_names - -# from ibis.backends.pandas.udf import udf - -# FIXME(kszucs): pytestmark = pytest.mark.benchmark -pytestmark = pytest.mark.skip(reason="the backends must be rewritten first") - - -def make_t(): - return ibis.table( - [ - ("_timestamp", "int32"), - ("dim1", "int32"), - ("dim2", "int32"), - ("valid_seconds", "int32"), - ("meas1", "int32"), - ("meas2", "int32"), - ("year", "int32"), - ("month", "int32"), - ("day", "int32"), - ("hour", "int32"), - ("minute", "int32"), - ], - name="t", - ) - - -@pytest.fixture(scope="module") -def t(): - return make_t() - - -def make_base(t): - return t[ - ( - (t.year > 2016) - | ((t.year == 2016) & (t.month > 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day > 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day == 6) & (t.hour > 6)) - | ( - (t.year == 2016) - & (t.month == 6) - & (t.day == 6) - & (t.hour == 6) - & (t.minute >= 5) - ) - ) - & ( - (t.year < 2016) - | ((t.year == 2016) & (t.month < 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day < 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day == 6) & (t.hour < 6)) - | ( - (t.year == 2016) - & (t.month == 6) - & (t.day == 6) - & (t.hour == 6) - & (t.minute <= 5) - ) - ) - ] - - -@pytest.fixture(scope="module") -def base(t): - return make_base(t) - - -def make_large_expr(base): - src_table = base - src_table = src_table.mutate( - _timestamp=(src_table["_timestamp"] - src_table["_timestamp"] % 3600) - .cast("int32") - .name("_timestamp"), - valid_seconds=300, - ) - - aggs = [] - for meas in ["meas1", "meas2"]: - aggs.append(src_table[meas].sum().cast("float").name(meas)) - src_table = src_table.aggregate( - aggs, by=["_timestamp", "dim1", "dim2", "valid_seconds"] - ) - - part_keys = ["year", "month", "day", "hour", "minute"] - ts_col = src_table["_timestamp"].cast("timestamp") - new_cols = {} - for part_key in part_keys: - part_col = getattr(ts_col, part_key)() - new_cols[part_key] = part_col - src_table = src_table.mutate(**new_cols) - return src_table[ - [ - "_timestamp", - "dim1", - "dim2", - "meas1", - "meas2", - "year", - "month", - "day", - "hour", - "minute", - ] - ] - - -@pytest.fixture(scope="module") -def large_expr(base): - return make_large_expr(base) - - -@pytest.mark.benchmark(group="construction") -@pytest.mark.parametrize( - "construction_fn", - [ - pytest.param(lambda *_: make_t(), id="small"), - pytest.param(lambda t, *_: make_base(t), id="medium"), - pytest.param(lambda _, base: make_large_expr(base), id="large"), - ], -) -def test_construction(benchmark, construction_fn, t, base): - benchmark(construction_fn, t, base) - - -@pytest.mark.benchmark(group="builtins") -@pytest.mark.parametrize( - "expr_fn", - [ - pytest.param(lambda t, _base, _large_expr: t, id="small"), - pytest.param(lambda _t, base, _large_expr: base, id="medium"), - pytest.param(lambda _t, _base, large_expr: large_expr, id="large"), - ], -) -@pytest.mark.parametrize("builtin", [hash, str]) -def test_builtins(benchmark, expr_fn, builtin, t, base, large_expr): - expr = expr_fn(t, base, large_expr) - benchmark(builtin, expr) - - -_backends = set(_get_backend_names()) -# compile is a no-op -_backends.remove("pandas") - -_XFAIL_COMPILE_BACKENDS = {"dask", "pyspark", "polars", "risingwave"} - - -@pytest.mark.benchmark(group="compilation") -@pytest.mark.parametrize( - "module", - [ - pytest.param( - mod, - marks=pytest.mark.xfail( - condition=mod in _XFAIL_COMPILE_BACKENDS, - reason=f"{mod} backend doesn't support compiling UnboundTable", - ), - ) - for mod in _backends - ], -) -@pytest.mark.parametrize( - "expr_fn", - [ - pytest.param(lambda t, _base, _large_expr: t, id="small"), - pytest.param(lambda _t, base, _large_expr: base, id="medium"), - pytest.param(lambda _t, _base, large_expr: large_expr, id="large"), - ], -) -def test_compile(benchmark, module, expr_fn, t, base, large_expr): - try: - mod = getattr(ibis, module) - except (AttributeError, ImportError) as e: - pytest.skip(str(e)) - else: - expr = expr_fn(t, base, large_expr) - try: - benchmark(mod.compile, expr) - except (sa.exc.NoSuchModuleError, ImportError) as e: # delayed imports - pytest.skip(str(e)) - - -@pytest.fixture(scope="module") -def pt(): - n = 60_000 - data = pd.DataFrame( - { - "key": np.random.choice(16000, size=n), - "low_card_key": np.random.choice(30, size=n), - "value": np.random.rand(n), - "timestamps": pd.date_range( - start="2023-05-05 16:37:57", periods=n, freq="s" - ).values, - "timestamp_strings": pd.date_range( - start="2023-05-05 16:37:39", periods=n, freq="s" - ).values.astype(str), - "repeated_timestamps": pd.date_range(start="2018-09-01", periods=30).repeat( - int(n / 30) - ), - } - ) - - return ibis.pandas.connect(dict(df=data)).table("df") - - -def high_card_group_by(t): - return t.group_by(t.key).aggregate(avg_value=t.value.mean()) - - -def cast_to_dates(t): - return t.timestamps.cast(dt.date) - - -def cast_to_dates_from_strings(t): - return t.timestamp_strings.cast(dt.date) - - -def multikey_group_by_with_mutate(t): - return ( - t.mutate(dates=t.timestamps.cast("date")) - .group_by(["low_card_key", "dates"]) - .aggregate(avg_value=lambda t: t.value.mean()) - ) - - -def simple_sort(t): - return t.order_by([t.key]) - - -def simple_sort_projection(t): - return t[["key", "value"]].order_by(["key"]) - - -def multikey_sort(t): - return t.order_by(["low_card_key", "key"]) - - -def multikey_sort_projection(t): - return t[["low_card_key", "key", "value"]].order_by(["low_card_key", "key"]) - - -def low_card_rolling_window(t): - return ibis.trailing_range_window( - ibis.interval(days=2), - order_by=t.repeated_timestamps, - group_by=t.low_card_key, - ) - - -def low_card_grouped_rolling(t): - return t.value.mean().over(low_card_rolling_window(t)) - - -def high_card_rolling_window(t): - return ibis.trailing_range_window( - ibis.interval(days=2), - order_by=t.repeated_timestamps, - group_by=t.key, - ) - - -def high_card_grouped_rolling(t): - return t.value.mean().over(high_card_rolling_window(t)) - - -# @udf.reduction(["double"], "double") -# def my_mean(series): -# return series.mean() - - -def low_card_grouped_rolling_udf_mean(t): - return my_mean(t.value).over(low_card_rolling_window(t)) - - -def high_card_grouped_rolling_udf_mean(t): - return my_mean(t.value).over(high_card_rolling_window(t)) - - -# @udf.analytic(["double"], "double") -# def my_zscore(series): -# return (series - series.mean()) / series.std() - - -def low_card_window(t): - return ibis.window(group_by=t.low_card_key) - - -def high_card_window(t): - return ibis.window(group_by=t.key) - - -def low_card_window_analytics_udf(t): - return my_zscore(t.value).over(low_card_window(t)) - - -def high_card_window_analytics_udf(t): - return my_zscore(t.value).over(high_card_window(t)) - - -# @udf.reduction(["double", "double"], "double") -# def my_wm(v, w): -# return np.average(v, weights=w) - - -def low_card_grouped_rolling_udf_wm(t): - return my_wm(t.value, t.value).over(low_card_rolling_window(t)) - - -def high_card_grouped_rolling_udf_wm(t): - return my_wm(t.value, t.value).over(low_card_rolling_window(t)) - - -broken_pandas_grouped_rolling = pytest.mark.xfail( - condition=vparse("1.4") <= vparse(pd.__version__) < vparse("1.4.2"), - raises=ValueError, - reason="https://github.com/pandas-dev/pandas/pull/44068", -) - - -@pytest.mark.benchmark(group="execution") -@pytest.mark.parametrize( - "expression_fn", - [ - pytest.param(high_card_group_by, id="high_card_group_by"), - pytest.param(cast_to_dates, id="cast_to_dates"), - pytest.param(cast_to_dates_from_strings, id="cast_to_dates_from_strings"), - pytest.param(multikey_group_by_with_mutate, id="multikey_group_by_with_mutate"), - pytest.param(simple_sort, id="simple_sort"), - pytest.param(simple_sort_projection, id="simple_sort_projection"), - pytest.param(multikey_sort, id="multikey_sort"), - pytest.param(multikey_sort_projection, id="multikey_sort_projection"), - pytest.param( - low_card_grouped_rolling, - id="low_card_grouped_rolling", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - high_card_grouped_rolling, - id="high_card_grouped_rolling", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - low_card_grouped_rolling_udf_mean, - id="low_card_grouped_rolling_udf_mean", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - high_card_grouped_rolling_udf_mean, - id="high_card_grouped_rolling_udf_mean", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param(low_card_window_analytics_udf, id="low_card_window_analytics_udf"), - pytest.param( - high_card_window_analytics_udf, id="high_card_window_analytics_udf" - ), - pytest.param( - low_card_grouped_rolling_udf_wm, - id="low_card_grouped_rolling_udf_wm", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - high_card_grouped_rolling_udf_wm, - id="high_card_grouped_rolling_udf_wm", - marks=[broken_pandas_grouped_rolling], - ), - ], -) -def test_execute(benchmark, expression_fn, pt): - expr = expression_fn(pt) - benchmark(expr.execute) - - -@pytest.fixture(scope="module") -def part(): - return ibis.table( - dict( - p_partkey="int64", - p_size="int64", - p_type="string", - p_mfgr="string", - ), - name="part", - ) - - -@pytest.fixture(scope="module") -def supplier(): - return ibis.table( - dict( - s_suppkey="int64", - s_nationkey="int64", - s_name="string", - s_acctbal="decimal(15, 3)", - s_address="string", - s_phone="string", - s_comment="string", - ), - name="supplier", - ) - - -@pytest.fixture(scope="module") -def partsupp(): - return ibis.table( - dict( - ps_partkey="int64", - ps_suppkey="int64", - ps_supplycost="decimal(15, 3)", - ), - name="partsupp", - ) - - -@pytest.fixture(scope="module") -def nation(): - return ibis.table( - dict(n_nationkey="int64", n_regionkey="int64", n_name="string"), - name="nation", - ) - - -@pytest.fixture(scope="module") -def region(): - return ibis.table(dict(r_regionkey="int64", r_name="string"), name="region") - - -@pytest.fixture(scope="module") -def tpc_h02(part, supplier, partsupp, nation, region): - REGION = "EUROPE" - SIZE = 25 - TYPE = "BRASS" - - expr = ( - part.join(partsupp, part.p_partkey == partsupp.ps_partkey) - .join(supplier, supplier.s_suppkey == partsupp.ps_suppkey) - .join(nation, supplier.s_nationkey == nation.n_nationkey) - .join(region, nation.n_regionkey == region.r_regionkey) - ) - - subexpr = ( - partsupp.join(supplier, supplier.s_suppkey == partsupp.ps_suppkey) - .join(nation, supplier.s_nationkey == nation.n_nationkey) - .join(region, nation.n_regionkey == region.r_regionkey) - ) - - subexpr = subexpr[ - (subexpr.r_name == REGION) & (expr.p_partkey == subexpr.ps_partkey) - ] - - filters = [ - expr.p_size == SIZE, - expr.p_type.like(f"%{TYPE}"), - expr.r_name == REGION, - expr.ps_supplycost == subexpr.ps_supplycost.min(), - ] - q = expr.filter(filters) - - q = q.select( - [ - q.s_acctbal, - q.s_name, - q.n_name, - q.p_partkey, - q.p_mfgr, - q.s_address, - q.s_phone, - q.s_comment, - ] - ) - - return q.order_by( - [ - ibis.desc(q.s_acctbal), - q.n_name, - q.s_name, - q.p_partkey, - ] - ).limit(100) - - -@pytest.mark.benchmark(group="repr") -def test_repr_tpc_h02(benchmark, tpc_h02): - benchmark(repr, tpc_h02) - - -@pytest.mark.benchmark(group="repr") -def test_repr_huge_union(benchmark): - n = 10 - raw_types = [ - "int64", - "float64", - "string", - "array, b: map>>>", - ] - tables = [ - ibis.table( - list(zip(string.ascii_letters, itertools.cycle(raw_types))), - name=f"t{i:d}", - ) - for i in range(n) - ] - expr = functools.reduce(ir.Table.union, tables) - benchmark(repr, expr) - - -@pytest.mark.benchmark(group="node_args") -def test_op_argnames(benchmark): - t = ibis.table([("a", "int64")]) - expr = t[["a"]] - benchmark(lambda op: op.argnames, expr.op()) - - -@pytest.mark.benchmark(group="node_args") -def test_op_args(benchmark): - t = ibis.table([("a", "int64")]) - expr = t[["a"]] - benchmark(lambda op: op.args, expr.op()) - - -@pytest.mark.benchmark(group="datatype") -def test_complex_datatype_parse(benchmark): - type_str = "array, b: map>>>" - expected = dt.Array( - dt.Struct(dict(a=dt.Array(dt.string), b=dt.Map(dt.string, dt.Array(dt.int64)))) - ) - assert dt.parse(type_str) == expected - benchmark(dt.parse, type_str) - - -@pytest.mark.benchmark(group="datatype") -@pytest.mark.parametrize("func", [str, hash]) -def test_complex_datatype_builtins(benchmark, func): - datatype = dt.Array( - dt.Struct(dict(a=dt.Array(dt.string), b=dt.Map(dt.string, dt.Array(dt.int64)))) - ) - benchmark(func, datatype) - - -@pytest.mark.benchmark(group="equality") -def test_large_expr_equals(benchmark, tpc_h02): - benchmark(ir.Expr.equals, tpc_h02, copy.deepcopy(tpc_h02)) - - -@pytest.mark.benchmark(group="datatype") -@pytest.mark.parametrize( - "dtypes", - [ - pytest.param( - [ - obj - for _, obj in inspect.getmembers( - dt, - lambda obj: isinstance(obj, dt.DataType), - ) - ], - id="singletons", - ), - pytest.param( - dt.Array( - dt.Struct( - dict( - a=dt.Array(dt.string), - b=dt.Map(dt.string, dt.Array(dt.int64)), - ) - ) - ), - id="complex", - ), - ], -) -def test_eq_datatypes(benchmark, dtypes): - def eq(a, b): - assert a == b - - benchmark(eq, dtypes, copy.deepcopy(dtypes)) - - -def multiple_joins(table, num_joins): - for _ in range(num_joins): - table = table.mutate(dummy=ibis.literal("")) - table = table.left_join(table, ["dummy"])[[table]] - - -@pytest.mark.parametrize("num_joins", [1, 10]) -@pytest.mark.parametrize("num_columns", [1, 10, 100]) -def test_multiple_joins(benchmark, num_joins, num_columns): - table = ibis.table( - {f"col_{i:d}": "string" for i in range(num_columns)}, - name="t", - ) - benchmark(multiple_joins, table, num_joins) - - -@pytest.fixture -def customers(): - return ibis.table( - dict( - customerid="int32", - name="string", - address="string", - citystatezip="string", - birthdate="date", - phone="string", - timezone="string", - lat="float64", - long="float64", - ), - name="customers", - ) - - -@pytest.fixture -def orders(): - return ibis.table( - dict( - orderid="int32", - customerid="int32", - ordered="timestamp", - shipped="timestamp", - items="string", - total="float64", - ), - name="orders", - ) - - -@pytest.fixture -def orders_items(): - return ibis.table( - dict(orderid="int32", sku="string", qty="int32", unit_price="float64"), - name="orders_items", - ) - - -@pytest.fixture -def products(): - return ibis.table( - dict( - sku="string", - desc="string", - weight_kg="float64", - cost="float64", - dims_cm="string", - ), - name="products", - ) - - -@pytest.mark.benchmark(group="compilation") -@pytest.mark.parametrize( - "module", - [ - pytest.param( - mod, - marks=pytest.mark.xfail( - condition=mod in _XFAIL_COMPILE_BACKENDS, - reason=f"{mod} backend doesn't support compiling UnboundTable", - ), - ) - for mod in _backends - ], -) -def test_compile_with_drops( - benchmark, module, customers, orders, orders_items, products -): - expr = ( - customers.join(orders, "customerid") - .join(orders_items, "orderid") - .join(products, "sku") - .drop("customerid", "qty", "total", "items") - .drop("dims_cm", "cost") - .mutate(o_date=lambda t: t.shipped.date()) - .filter(lambda t: t.ordered == t.shipped) - ) - - try: - mod = getattr(ibis, module) - except (AttributeError, ImportError) as e: - pytest.skip(str(e)) - else: - try: - benchmark(mod.compile, expr) - except sa.exc.NoSuchModuleError as e: - pytest.skip(str(e)) - - -def test_repr_join(benchmark, customers, orders, orders_items, products): - expr = ( - customers.join(orders, "customerid") - .join(orders_items, "orderid") - .join(products, "sku") - .drop("customerid", "qty", "total", "items") - ) - op = expr.op() - benchmark(repr, op) - - -@pytest.mark.parametrize("overwrite", [True, False], ids=["overwrite", "no_overwrite"]) -def test_insert_duckdb(benchmark, overwrite, tmp_path): - pytest.importorskip("duckdb") - - n_rows = int(1e4) - table_name = "t" - schema = ibis.schema(dict(a="int64", b="int64", c="int64")) - t = ibis.memtable(dict.fromkeys(list("abc"), range(n_rows)), schema=schema) - - con = ibis.duckdb.connect(tmp_path / "test_insert.ddb") - con.create_table(table_name, schema=schema) - benchmark(con.insert, table_name, t, overwrite=overwrite) - - -def test_snowflake_medium_sized_to_pandas(benchmark): - pytest.importorskip("snowflake.connector") - - if (url := os.environ.get("SNOWFLAKE_URL")) is None: - pytest.skip("SNOWFLAKE_URL environment variable not set") - - con = ibis.connect(url) - - # LINEITEM at scale factor 1 is around 6MM rows, but we limit to 1,000,000 - # to make the benchmark fast enough for development, yet large enough to show a - # difference if there's a performance hit - lineitem = con.table("LINEITEM", schema="SNOWFLAKE_SAMPLE_DATA.TPCH_SF1").limit( - 1_000_000 - ) - - benchmark.pedantic(lineitem.to_pandas, rounds=5, iterations=1, warmup_rounds=1) - - -def test_parse_many_duckdb_types(benchmark): - parse = pytest.importorskip("ibis.backends.duckdb.datatypes").DuckDBType.from_string - - def parse_many(types): - list(map(parse, types)) - - types = ["VARCHAR", "INTEGER", "DOUBLE", "BIGINT"] * 1000 - benchmark(parse_many, types) - - -@pytest.fixture(scope="session") -def sql() -> str: - return """ - SELECT t1.id as t1_id, x, t2.id as t2_id, y - FROM t1 INNER JOIN t2 - ON t1.id = t2.id - """ - - -@pytest.fixture(scope="session") -def ddb(tmp_path_factory): - duckdb = pytest.importorskip("duckdb") - - N = 20_000_000 - - con = duckdb.connect() - - path = str(tmp_path_factory.mktemp("duckdb") / "data.ddb") - sql = ( - lambda var, table, n=N: f""" - CREATE TABLE {table} AS - SELECT ROW_NUMBER() OVER () AS id, {var} - FROM ( - SELECT {var} - FROM RANGE({n}) _ ({var}) - ORDER BY RANDOM() - ) - """ - ) - - with duckdb.connect(path) as con: - con.execute(sql("x", table="t1")) - con.execute(sql("y", table="t2")) - return path - - -def test_duckdb_to_pyarrow(benchmark, sql, ddb) -> None: - # yes, we're benchmarking duckdb here, not ibis - # - # we do this to get a baseline for comparison - duckdb = pytest.importorskip("duckdb") - con = duckdb.connect(ddb, read_only=True) - - benchmark(lambda sql: con.sql(sql).to_arrow_table(), sql) - - -def test_ibis_duckdb_to_pyarrow(benchmark, sql, ddb) -> None: - pytest.importorskip("duckdb") - - con = ibis.duckdb.connect(ddb, read_only=True) - - expr = con.sql(sql) - benchmark(expr.to_pyarrow) - - -@pytest.fixture -def diffs(): - return ibis.table( - { - "id": "int64", - "validation_name": "string", - "difference": "float64", - "pct_difference": "float64", - "pct_threshold": "float64", - "validation_status": "string", - }, - name="diffs", - ) - - -@pytest.fixture -def srcs(): - return ibis.table( - { - "id": "int64", - "validation_name": "string", - "validation_type": "string", - "aggregation_type": "string", - "table_name": "string", - "column_name": "string", - "primary_keys": "string", - "num_random_rows": "string", - "agg_value": "float64", - }, - name="srcs", - ) - - -@pytest.fixture -def nrels(): - return 300 - - -def make_big_union(t, nrels): - return ibis.union(*[t] * nrels) - - -@pytest.fixture -def src(srcs, nrels): - return make_big_union(srcs, nrels) - - -@pytest.fixture -def diff(diffs, nrels): - return make_big_union(diffs, nrels) - - -def test_big_eq_expr(benchmark, src, diff): - benchmark(ops.core.Node.equals, src.op(), diff.op()) - - -def test_big_join_expr(benchmark, src, diff): - benchmark(ir.Table.join, src, diff, ["validation_name"], how="outer") - - -def test_big_join_execute(benchmark, nrels): - pytest.importorskip("duckdb") - - con = ibis.duckdb.connect() - - # cache to avoid a request-per-union operand - src = make_big_union( - con.read_csv( - "https://github.com/ibis-project/ibis/files/12580336/source_pivot.csv" - ) - .rename(id="column0") - .cache(), - nrels, - ) - - diff = make_big_union( - con.read_csv( - "https://github.com/ibis-project/ibis/files/12580340/differences_pivot.csv" - ) - .rename(id="column0") - .cache(), - nrels, - ) - - expr = src.join(diff, ["validation_name"], how="outer") - t = benchmark.pedantic(expr.to_pyarrow, rounds=1, iterations=1, warmup_rounds=1) - assert len(t) diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 19868964eb27..8bad125da763 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -25,7 +25,11 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.conftest import ALL_BACKENDS -from ibis.backends.tests.errors import Py4JJavaError, PyDruidProgrammingError +from ibis.backends.tests.errors import ( + PsycoPg2InternalError, + Py4JJavaError, + PyDruidProgrammingError, +) from ibis.util import gen_name, guid if TYPE_CHECKING: @@ -115,7 +119,8 @@ def test_create_table(backend, con, temp_table, lamduh, sch): marks=[ pytest.mark.notyet(["clickhouse"], reason="Can't specify both"), pytest.mark.notyet( - ["pyspark", "trino", "exasol"], reason="No support for temp tables" + ["pyspark", "trino", "exasol", "risingwave"], + reason="No support for temp tables", ), pytest.mark.never(["polars"], reason="Everything in-memory is temp"), pytest.mark.broken(["mssql"], reason="Incorrect temp table syntax"), @@ -132,7 +137,8 @@ def test_create_table(backend, con, temp_table, lamduh, sch): id="temp, no overwrite", marks=[ pytest.mark.notyet( - ["pyspark", "trino", "exasol"], reason="No support for temp tables" + ["pyspark", "trino", "exasol", "risingwave"], + reason="No support for temp tables", ), pytest.mark.never(["polars"], reason="Everything in-memory is temp"), pytest.mark.broken(["mssql"], reason="Incorrect temp table syntax"), @@ -308,7 +314,7 @@ def tmpcon(alchemy_con): @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @pytest.mark.never( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", ) def test_create_temporary_table_from_schema(tmpcon, new_schema): @@ -375,7 +381,7 @@ def test_rename_table(con, temp_table, temp_table_orig): ) @pytest.mark.never( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason='Feature is not yet implemented: column constraints "NOT NULL"', ) def test_nullable_input_output(con, temp_table): @@ -719,11 +725,6 @@ def test_unsigned_integer_type(alchemy_con, alchemy_temp_table): marks=mark.postgres, id="postgresql", ), - param( - "postgresql://root:@localhost:4566/dev", - marks=mark.risingwave, - id="risingwave", - ), param( "pyspark://?spark.app.name=test-pyspark", marks=[ @@ -1120,11 +1121,6 @@ def test_set_backend_name(name, monkeypatch): marks=mark.postgres, id="postgres", ), - param( - "postgres://root:@localhost:4566/dev", - marks=mark.risingwave, - id="risingwave", - ), ], ) def test_set_backend_url(url, monkeypatch): @@ -1188,7 +1184,7 @@ def test_create_table_timestamp(con, temp_table): @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @pytest.mark.never( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", ) def test_persist_expression_ref_count(backend, con, alltypes): @@ -1213,7 +1209,7 @@ def test_persist_expression_ref_count(backend, con, alltypes): @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @pytest.mark.never( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", ) def test_persist_expression(backend, alltypes): @@ -1232,7 +1228,7 @@ def test_persist_expression(backend, alltypes): @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @pytest.mark.never( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", ) def test_persist_expression_contextmanager(backend, alltypes): @@ -1253,7 +1249,7 @@ def test_persist_expression_contextmanager(backend, alltypes): @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @pytest.mark.never( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", ) def test_persist_expression_contextmanager_ref_count(backend, con, alltypes): @@ -1276,7 +1272,7 @@ def test_persist_expression_contextmanager_ref_count(backend, con, alltypes): ) @pytest.mark.never( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @@ -1318,7 +1314,7 @@ def test_persist_expression_multiple_refs(backend, con, alltypes): @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @pytest.mark.never( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", ) def test_persist_expression_repeated_cache(alltypes): @@ -1345,6 +1341,11 @@ def test_persist_expression_repeated_cache(alltypes): ["oracle"], reason="Oracle error message for a missing table/view doesn't include the name of the table", ) +@pytest.mark.never( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_release(con, alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc 3" @@ -1431,7 +1432,7 @@ def test_create_schema(con_create_schema): @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: information_schema.schemata is not supported,", ) def test_list_schemas(con_create_schema): diff --git a/ibis/backends/tests/test_dot_sql.py b/ibis/backends/tests/test_dot_sql.py index 9866e6ae07ce..30e51cb4e297 100644 --- a/ibis/backends/tests/test_dot_sql.py +++ b/ibis/backends/tests/test_dot_sql.py @@ -231,7 +231,7 @@ def test_dot_sql_reuse_alias_with_different_types(backend, alltypes, df): backend.assert_series_equal(foo2.x.execute(), expected2) -_NO_SQLGLOT_DIALECT = {"pandas", "dask", "druid", "flink", "risingwave"} +_NO_SQLGLOT_DIALECT = {"pandas", "dask", "druid", "flink"} no_sqlglot_dialect = sorted( # TODO(cpcloud): remove the strict=False hack once backends are ported to # sqlglot @@ -244,11 +244,6 @@ def test_dot_sql_reuse_alias_with_different_types(backend, alltypes, df): "dialect", [*sorted(_get_backend_names() - _NO_SQLGLOT_DIALECT), *no_sqlglot_dialect], ) -@pytest.mark.notyet( - ["risingwave"], - raises=ValueError, - reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", -) @pytest.mark.notyet(["polars"], raises=PolarsComputeError) @dot_sql_notimpl @dot_sql_never @@ -276,11 +271,6 @@ def test_table_dot_sql_transpile(backend, alltypes, dialect, df): ["druid"], raises=AttributeError, reason="druid doesn't respect column names" ) @pytest.mark.notyet(["snowflake", "bigquery"]) -@pytest.mark.notyet( - ["risingwave"], - raises=ValueError, - reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", -) @dot_sql_notimpl @dot_sql_never def test_con_dot_sql_transpile(backend, con, dialect, df): @@ -299,11 +289,6 @@ def test_con_dot_sql_transpile(backend, con, dialect, df): @dot_sql_notimpl @dot_sql_never @pytest.mark.notimpl(["druid", "flink", "polars", "exasol"]) -@pytest.mark.notyet( - ["risingwave"], - raises=ValueError, - reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", -) def test_order_by_no_projection(backend): con = backend.connection expr = ( diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index b3ce4d4cfaf0..02eefb296c3a 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -4,7 +4,6 @@ import pyarrow as pa import pyarrow.csv as pcsv import pytest -import sqlalchemy as sa from pytest import param import ibis @@ -342,11 +341,6 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): id="decimal128", marks=[ pytest.mark.notyet(["flink"], raises=NotImplementedError), - pytest.mark.notyet( - ["risingwave"], - raises=sa.exc.DBAPIError, - reason="Feature is not yet implemented: unsupported data type: NUMERIC(38,9)", - ), pytest.mark.notyet(["exasol"], raises=ExaQueryError), ], ), @@ -367,11 +361,6 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): raises=(PySparkParseException, PySparkArithmeticException), reason="precision is out of range", ), - pytest.mark.notyet( - ["risingwave"], - raises=sa.exc.DBAPIError, - reason="Feature is not yet implemented: unsupported data type: NUMERIC(76,38)", - ), pytest.mark.notyet(["flink"], raises=NotImplementedError), pytest.mark.notyet(["exasol"], raises=ExaQueryError), ], @@ -495,16 +484,7 @@ def test_to_pandas_batches_empty_table(backend, con): @pytest.mark.parametrize( "n", [ - param( - None, - marks=[ - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="risingwave doesn't support limit null", - ), - ], - ), + None, 1, ], ) @@ -516,19 +496,11 @@ def test_to_pandas_batches_nonempty_table(backend, con, n): assert sum(map(len, t.to_pandas_batches())) == n +@pytest.mark.notimpl(["flink"]) @pytest.mark.parametrize( "n", [ - param( - None, - marks=[ - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="risingwave doesn't support limit null", - ), - ], - ), + None, 0, 1, 2, diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index f4065eb9058f..8ffd569a71a9 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -9,7 +9,6 @@ import numpy as np import pandas as pd import pytest -import sqlalchemy as sa import toolz from pytest import param @@ -26,6 +25,7 @@ ImpalaHiveServer2Error, MySQLProgrammingError, OracleDatabaseError, + PsycoPg2InternalError, Py4JJavaError, PyDruidProgrammingError, PyODBCDataError, @@ -548,7 +548,7 @@ def test_order_by(backend, alltypes, df, key, df_kwargs): @pytest.mark.notimpl(["dask", "pandas", "polars", "mssql", "druid"]) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function random() does not exist", ) def test_order_by_random(alltypes): @@ -852,12 +852,12 @@ def test_typeof(con): @pytest.mark.notimpl(["datafusion", "druid"]) @pytest.mark.notimpl(["pyspark"], condition=is_older_than("pyspark", "3.5.0")) @pytest.mark.notyet(["dask"], reason="not supported by the backend") +@pytest.mark.notyet(["exasol"], raises=ExaQueryError, reason="not supported by exasol") @pytest.mark.broken( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="https://github.com/risingwavelabs/risingwave/issues/1343", ) -@pytest.mark.notyet(["exasol"], raises=ExaQueryError, reason="not supported by exasol") def test_isin_uncorrelated( backend, batting, awards_players, batting_df, awards_players_df ): @@ -1037,11 +1037,6 @@ def query(t, group_cols): reason="backend doesn't support arrays and we don't implement pivot_longer with unions yet", raises=com.OperationNotDefinedError, ) -@pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason='sql parser error: Expected ), found: TEXT at line:3, column:219 Near "))]) AS anon_1(f1"', -) @pytest.mark.broken( ["trino"], reason="invalid code generated for unnesting a struct", @@ -1163,7 +1158,7 @@ def test_pivot_wider(backend): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function last(double precision) does not exist, do you mean left or least", ) def test_distinct_on_keep(backend, on, keep): @@ -1233,7 +1228,7 @@ def test_distinct_on_keep(backend, on, keep): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function first(double precision) does not exist", ) def test_distinct_on_keep_is_none(backend, on): @@ -1287,8 +1282,6 @@ def test_hash_consistent(backend, alltypes): "pyspark", "risingwave", "sqlite", - "clickhouse", - "mssql", ] ) def test_hashbytes(backend, alltypes): @@ -1322,8 +1315,6 @@ def hash_256(col): "risingwave", "snowflake", "trino", - "pyspark", - "mssql", ] ) @pytest.mark.notyet( @@ -1352,7 +1343,6 @@ def hash_256(col): "pandas", "dask", "oracle", - "risingwave", "snowflake", "sqlite", ] @@ -1513,26 +1503,12 @@ def test_try_cast_func(con, from_val, to_type, func): param( slice(None, None), lambda t: t.count().to_pandas(), - marks=[ - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="risingwave doesn't support limit/offset", - ), - ], id="[:]", ), param(slice(0, 0), lambda _: 0, id="[0:0]"), param( slice(0, None), lambda t: t.count().to_pandas(), - marks=[ - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="risingwave doesn't support limit/offset", - ), - ], id="[0:]", ), # positive stop @@ -1588,11 +1564,6 @@ def test_try_cast_func(con, from_val, to_type, func): raises=ImpalaHiveServer2Error, reason="impala doesn't support OFFSET without ORDER BY", ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="risingwave doesn't support limit/offset", - ), pytest.mark.notyet(["oracle"], raises=com.UnsupportedArgumentError), ], ), @@ -1680,16 +1651,16 @@ def test_static_table_slice(backend, slc, expected_count_fn): raises=com.UnsupportedArgumentError, reason="Removed half-baked dynamic offset functionality for now", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="risingwave doesn't support limit/offset", +) @pytest.mark.notyet( ["trino"], raises=TrinoUserError, reason="backend doesn't support dynamic limit/offset", ) -@pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="risingwave doesn't support limit/offset", -) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) @pytest.mark.notyet( ["clickhouse"], @@ -1770,16 +1741,16 @@ def test_dynamic_table_slice(backend, slc, expected_count_fn): ) @pytest.mark.notyet(["pyspark"], reason="pyspark doesn't support dynamic limit/offset") @pytest.mark.notyet(["flink"], reason="flink doesn't support dynamic limit/offset") -@pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="risingwave doesn't support limit/offset", -) @pytest.mark.notyet( ["mssql"], reason="doesn't support dynamic limit/offset; compiles incorrectly in sqlglot", raises=AssertionError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="risingwave doesn't support limit/offset", +) def test_dynamic_table_slice_with_computed_offset(backend): t = backend.functional_alltypes @@ -1798,17 +1769,10 @@ def test_dynamic_table_slice_with_computed_offset(backend): backend.assert_frame_equal(result, expected) -@pytest.mark.notimpl( - [ - "druid", - "flink", - "polars", - "snowflake", - ] -) +@pytest.mark.notimpl(["druid", "flink", "polars", "snowflake"]) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function random() does not exist", ) def test_sample(backend): @@ -1826,17 +1790,10 @@ def test_sample(backend): backend.assert_frame_equal(empty, df.iloc[:0]) -@pytest.mark.notimpl( - [ - "druid", - "flink", - "polars", - "snowflake", - ] -) +@pytest.mark.notimpl(["druid", "flink", "polars", "snowflake"]) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function random() does not exist", ) def test_sample_memtable(con, backend): @@ -1895,11 +1852,6 @@ def test_substitute(backend): ["dask", "pandas", "polars"], raises=NotImplementedError, reason="not a SQL backend" ) @pytest.mark.notimpl(["flink"], reason="no sqlglot dialect", raises=ValueError) -@pytest.mark.notimpl( - ["risingwave"], - raises=ValueError, - reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", -) def test_simple_memtable_construct(con): t = ibis.memtable({"a": [1, 2]}) expr = t.a diff --git a/ibis/backends/tests/test_join.py b/ibis/backends/tests/test_join.py index a20cd61a98d3..887c10547b6d 100644 --- a/ibis/backends/tests/test_join.py +++ b/ibis/backends/tests/test_join.py @@ -198,7 +198,7 @@ def test_semi_join_topk(con, batting, awards_players, func): @pytest.mark.notimpl(["dask", "druid", "exasol", "oracle"]) @pytest.mark.notimpl( - ["postgres", "mssql"], + ["postgres", "mssql", "risingwave"], raises=com.IbisTypeError, reason="postgres can't handle null types columns", ) diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index 68f8a88796d4..441332cc8f6d 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -3,13 +3,12 @@ import numpy as np import pandas as pd import pytest -import sqlalchemy as sa from pytest import param import ibis import ibis.common.exceptions as exc import ibis.expr.datatypes as dt -from ibis.backends.tests.errors import Py4JJavaError +from ibis.backends.tests.errors import PsycoPg2InternalError, Py4JJavaError pytestmark = [ pytest.mark.never( @@ -38,7 +37,7 @@ def test_map_table(backend): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function hstore(character varying[], character varying[]) does not exist", ) def test_column_map_values(backend): @@ -73,7 +72,7 @@ def test_column_map_merge(backend): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function hstore(character varying[], character varying[]) does not exist", ) def test_literal_map_keys(con): @@ -93,7 +92,7 @@ def test_literal_map_keys(con): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function hstore(character varying[], character varying[]) does not exist", ) def test_literal_map_values(con): @@ -145,7 +144,7 @@ def test_map_scalar_contains_key_scalar(con): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function hstore(character varying[], character varying[]) does not exist", ) def test_map_scalar_contains_key_column(backend, alltypes, df): @@ -215,7 +214,7 @@ def test_literal_map_merge(con): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function hstore(character varying[], character varying[]) does not exist", ) def test_literal_map_getitem_broadcast(backend, alltypes, df): @@ -237,7 +236,7 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function hstore(character varying[], character varying[]) does not exist", ) def test_literal_map_get_broadcast(backend, alltypes, df): @@ -269,7 +268,7 @@ def test_literal_map_get_broadcast(backend, alltypes, df): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function hstore(character varying[], character varying[]) does not exist", ) def test_map_construct_dict(con, keys, values): @@ -361,7 +360,7 @@ def test_map_get_with_null_on_not_nullable(con, null_value): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function hstore(character varying[], character varying[]) does not exist", ) def test_map_get_with_null_on_null_type_with_null(con, null_value): @@ -392,7 +391,7 @@ def test_map_get_with_null_on_null_type_with_non_null(con): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function hstore(character varying[], character varying[]) does not exist", ) def test_map_create_table(con, temp_table): @@ -410,7 +409,7 @@ def test_map_create_table(con, temp_table): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function hstore(character varying[], character varying[]) does not exist", ) def test_map_length(con): diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index 76ffe94aae62..b92abea1470c 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -9,7 +9,6 @@ import numpy as np import pandas as pd import pytest -import sqlalchemy as sa from pytest import param import ibis @@ -25,6 +24,7 @@ MySQLOperationalError, OracleDatabaseError, PsycoPg2DivisionByZero, + PsycoPg2InternalError, Py4JError, PyDruidProgrammingError, PyODBCDataError, @@ -254,9 +254,9 @@ def test_numeric_literal(con, backend, expr, expected_types): "dask": decimal.Decimal("1.1"), "exasol": decimal.Decimal("1"), "duckdb": decimal.Decimal("1.1"), - "risingwave": 1.1, "impala": decimal.Decimal("1"), "postgres": decimal.Decimal("1.1"), + "risingwave": decimal.Decimal("1.1"), "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "mysql": decimal.Decimal("1"), @@ -296,9 +296,9 @@ def test_numeric_literal(con, backend, expr, expected_types): "sqlite": decimal.Decimal("1.1"), "trino": decimal.Decimal("1.1"), "duckdb": decimal.Decimal("1.100000000"), - "risingwave": 1.1, "impala": decimal.Decimal("1.1"), "postgres": decimal.Decimal("1.1"), + "risingwave": decimal.Decimal("1.1"), "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "mysql": decimal.Decimal("1.1"), @@ -332,8 +332,8 @@ def test_numeric_literal(con, backend, expr, expected_types): "bigquery": decimal.Decimal("1.1"), "sqlite": decimal.Decimal("1.1"), "dask": decimal.Decimal("1.1"), - "risingwave": 1.1, "postgres": decimal.Decimal("1.1"), + "risingwave": decimal.Decimal("1.1"), "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "clickhouse": decimal.Decimal( @@ -384,10 +384,10 @@ def test_numeric_literal(con, backend, expr, expected_types): ibis.literal(decimal.Decimal("Infinity"), type=dt.decimal), # TODO(krzysztof-kwitt): Should we unify it? { - "risingwave": float("nan"), "bigquery": float("inf"), "sqlite": decimal.Decimal("Infinity"), "postgres": decimal.Decimal("Infinity"), + "risingwave": decimal.Decimal("Infinity"), "pandas": decimal.Decimal("Infinity"), "dask": decimal.Decimal("Infinity"), "pyspark": decimal.Decimal("Infinity"), @@ -406,13 +406,6 @@ def test_numeric_literal(con, backend, expr, expected_types): "Unsupported precision. Supported values: [1 : 76]. Current value: None", raises=NotImplementedError, ), - pytest.mark.broken( - ["trino"], - "(trino.exceptions.TrinoUserError) TrinoUserError(type=USER_ERROR, name=INVALID_LITERAL, " - "message=\"line 1:51: 'Infinity' is not a valid decimal literal\", " - "query_id=20230128_024107_01084_y8zm3)", - raises=sa.exc.ProgrammingError, - ), pytest.mark.notyet( ["mysql", "impala"], raises=com.UnsupportedOperationError ), @@ -455,10 +448,10 @@ def test_numeric_literal(con, backend, expr, expected_types): ibis.literal(decimal.Decimal("-Infinity"), type=dt.decimal), # TODO(krzysztof-kwitt): Should we unify it? { - "risingwave": float("nan"), "bigquery": float("-inf"), "sqlite": decimal.Decimal("-Infinity"), "postgres": decimal.Decimal("-Infinity"), + "risingwave": decimal.Decimal("-Infinity"), "pandas": decimal.Decimal("-Infinity"), "dask": decimal.Decimal("-Infinity"), "pyspark": decimal.Decimal("-Infinity"), @@ -477,13 +470,6 @@ def test_numeric_literal(con, backend, expr, expected_types): "Unsupported precision. Supported values: [1 : 76]. Current value: None", raises=NotImplementedError, ), - pytest.mark.broken( - ["trino"], - "(trino.exceptions.TrinoUserError) TrinoUserError(type=USER_ERROR, name=INVALID_LITERAL, " - "message=\"line 1:51: '-Infinity' is not a valid decimal literal\", " - "query_id=20230128_024107_01084_y8zm3)", - raises=sa.exc.ProgrammingError, - ), pytest.mark.notyet( ["mysql", "impala"], raises=com.UnsupportedOperationError ), @@ -551,13 +537,6 @@ def test_numeric_literal(con, backend, expr, expected_types): "Unsupported precision. Supported values: [1 : 76]. Current value: None", raises=NotImplementedError, ), - pytest.mark.broken( - ["trino"], - "(trino.exceptions.TrinoUserError) TrinoUserError(type=USER_ERROR, name=INVALID_LITERAL, " - "message=\"line 1:51: 'NaN' is not a valid decimal literal\", " - "query_id=20230128_024107_01084_y8zm3)", - raises=sa.exc.ProgrammingError, - ), pytest.mark.notyet( ["mysql", "impala"], raises=com.UnsupportedOperationError ), @@ -754,12 +733,12 @@ def test_isnan_isinf( math.log(5.556, 2), id="log-base", marks=[ + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function log10(numeric, numeric) does not exist", ), - pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), ], ), param( @@ -773,14 +752,34 @@ def test_isnan_isinf( math.log(5.556, 2), id="log2", marks=[ + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function log10(numeric, numeric) does not exist", ), - pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), ], ), + param( + L(5.556).log10(), + math.log10(5.556), + id="log10", + ), + param( + L(5.556).radians(), + math.radians(5.556), + id="radians", + ), + param( + L(5.556).degrees(), + math.degrees(5.556), + id="degrees", + ), + param( + L(11) % 3, + 11 % 3, + id="mod", + ), param(L(5.556).log10(), math.log10(5.556), id="log10"), param( L(5.556).radians(), @@ -929,12 +928,12 @@ def test_simple_math_functions_columns( lambda t: t.double_col.add(1).log(2), lambda t: np.log2(t.double_col + 1), marks=[ + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function log10(numeric, numeric) does not exist", ), - pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), ], id="log2", ), @@ -971,7 +970,7 @@ def test_simple_math_functions_columns( pytest.mark.notimpl(["polars"], raises=com.UnsupportedArgumentError), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function log10(numeric, numeric) does not exist", ), ], @@ -1197,7 +1196,6 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), - pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1210,7 +1208,6 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), - pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1223,7 +1220,6 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), - pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1236,7 +1232,6 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), - pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1319,6 +1314,7 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): "snowflake", "trino", "postgres", + "risingwave", "mysql", "druid", "mssql", @@ -1326,11 +1322,6 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): ], reason="Not SQLAlchemy backends", ) -@pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="Feature is not yet implemented: unsupported data type: NUMERIC(5)", -) def test_sa_default_numeric_precision_and_scale( con, backend, default_precisions, default_scales, temp_table ): @@ -1364,13 +1355,13 @@ def test_sa_default_numeric_precision_and_scale( assert_equal(schema, expected) +@pytest.mark.notimpl(["dask", "pandas", "polars"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function random() does not exist", ) -@pytest.mark.notimpl(["dask", "pandas", "polars"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) def test_random(con): expr = ibis.random() result = con.execute(expr) @@ -1481,7 +1472,7 @@ def test_constants(con, const): param(lambda t: t.int_col, lambda _: 3, id="col_scalar"), ], ) -@pytest.mark.notimpl(["exasol"], raises=(sa.exc.DBAPIError, ExaQueryError)) +@pytest.mark.notimpl(["exasol"], raises=(ExaQueryError)) @flink_no_bitwise def test_bitwise_columns(backend, con, alltypes, df, op, left_fn, right_fn): expr = op(left_fn(alltypes), right_fn(alltypes)).name("tmp") @@ -1518,7 +1509,7 @@ def test_bitwise_columns(backend, con, alltypes, df, op, left_fn, right_fn): ], ) @pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError) -@pytest.mark.notimpl(["exasol"], raises=(sa.exc.DBAPIError, ExaQueryError)) +@pytest.mark.notimpl(["exasol"], raises=(ExaQueryError)) @flink_no_bitwise def test_bitwise_shift(backend, alltypes, df, op, left_fn, right_fn): expr = op(left_fn(alltypes), right_fn(alltypes)).name("tmp") diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py index a72b7c140b22..67c7b5123281 100644 --- a/ibis/backends/tests/test_param.py +++ b/ibis/backends/tests/test_param.py @@ -6,13 +6,16 @@ import numpy as np import pandas as pd import pytest -import sqlalchemy as sa from pytest import param import ibis import ibis.expr.datatypes as dt from ibis import _ -from ibis.backends.tests.errors import OracleDatabaseError, Py4JJavaError +from ibis.backends.tests.errors import ( + OracleDatabaseError, + PsycoPg2InternalError, + Py4JJavaError, +) @pytest.mark.parametrize( @@ -38,11 +41,6 @@ def test_floating_scalar_parameter(backend, alltypes, df, column, raw_value): [("2009-03-01", "2010-07-03"), ("2014-12-01", "2017-01-05")], ) @pytest.mark.notimpl(["trino", "druid"]) -@pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function make_date(integer, integer, integer) does not exist", -) @pytest.mark.broken(["oracle"], raises=OracleDatabaseError) def test_date_scalar_parameter(backend, alltypes, start_string, end_string): start, end = ibis.param(dt.date), ibis.param(dt.date) @@ -116,7 +114,7 @@ def test_scalar_param_struct(con): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function make_date(integer, integer, integer) does not exist", ) def test_scalar_param_map(con): @@ -179,11 +177,6 @@ def test_scalar_param(backend, alltypes, df, value, dtype, col): ids=["string", "date", "datetime"], ) @pytest.mark.notimpl(["druid", "oracle"]) -@pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function make_date(integer, integer, integer) does not exist", -) def test_scalar_param_date(backend, alltypes, value): param = ibis.param("date") ds_col = alltypes.date_string_col diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index f1ecdd6bdab6..64824b612462 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -409,12 +409,7 @@ def test_register_garbage(con, monkeypatch): ], ) @pytest.mark.notyet( - ["impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] -) -@pytest.mark.notimpl( - ["flink"], - raises=ValueError, - reason="read_parquet() missing required argument: 'schema'", + ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] ) def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name): pq = pytest.importorskip("pyarrow.parquet") @@ -445,12 +440,17 @@ def ft_data(data_dir): @pytest.mark.notyet( - ["impala", "mssql", "mysql", "pandas", "postgres", "risingwave", "sqlite", "trino"] -) -@pytest.mark.notimpl( - ["flink"], - raises=ValueError, - reason="read_parquet() missing required argument: 'schema'", + [ + "flink", + "impala", + "mssql", + "mysql", + "pandas", + "postgres", + "risingwave", + "sqlite", + "trino", + ] ) def test_read_parquet_glob(con, tmp_path, ft_data): pq = pytest.importorskip("pyarrow.parquet") @@ -469,12 +469,17 @@ def test_read_parquet_glob(con, tmp_path, ft_data): @pytest.mark.notyet( - ["impala", "mssql", "mysql", "pandas", "postgres", "risingwave", "sqlite", "trino"] -) -@pytest.mark.notimpl( - ["flink"], - raises=ValueError, - reason="read_csv() missing required argument: 'schema'", + [ + "flink", + "impala", + "mssql", + "mysql", + "pandas", + "postgres", + "risingwave", + "sqlite", + "trino", + ] ) def test_read_csv_glob(con, tmp_path, ft_data): pc = pytest.importorskip("pyarrow.csv") @@ -556,12 +561,7 @@ def num_diamonds(data_dir): [param(None, id="default"), param("fancy_stones", id="file_name")], ) @pytest.mark.notyet( - ["impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] -) -@pytest.mark.notimpl( - ["flink"], - raises=ValueError, - reason="read_csv() missing required argument: 'schema'", + ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] ) def test_read_csv(con, data_dir, in_table_name, num_diamonds): fname = "diamonds.csv" diff --git a/ibis/backends/tests/test_set_ops.py b/ibis/backends/tests/test_set_ops.py index 41102559ad9c..4df076da7f97 100644 --- a/ibis/backends/tests/test_set_ops.py +++ b/ibis/backends/tests/test_set_ops.py @@ -4,14 +4,13 @@ import pandas as pd import pytest -import sqlalchemy as sa from pytest import param import ibis import ibis.common.exceptions as com import ibis.expr.types as ir from ibis import _ -from ibis.backends.tests.errors import PyDruidProgrammingError +from ibis.backends.tests.errors import PsycoPg2InternalError, PyDruidProgrammingError @pytest.fixture @@ -84,7 +83,7 @@ def test_union_mixed_distinct(backend, union_subsets): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: INTERSECT all", ), ], @@ -138,7 +137,7 @@ def test_intersect(backend, alltypes, df, distinct): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: EXCEPT all", ), ], @@ -223,7 +222,7 @@ def test_top_level_union(backend, con, alltypes, distinct): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: INTERSECT all", ), ], diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index 4cd7d0d8ffa2..0db31eb5662d 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -8,10 +8,9 @@ from ibis import _ from ibis.backends.conftest import _get_backends_to_test -sa = pytest.importorskip("sqlalchemy") sg = pytest.importorskip("sqlglot") -pytestmark = pytest.mark.notimpl(["flink", "risingwave"]) +pytestmark = pytest.mark.notimpl(["flink"]) simple_literal = param(ibis.literal(1), id="simple_literal") array_literal = param( @@ -27,7 +26,7 @@ ) no_structs = pytest.mark.never( ["impala", "mysql", "sqlite", "mssql", "exasol"], - raises=(NotImplementedError, sa.exc.CompileError, exc.UnsupportedBackendType), + raises=(NotImplementedError, exc.UnsupportedBackendType), reason="structs not supported in the backend", ) no_struct_literals = pytest.mark.notimpl( @@ -62,9 +61,6 @@ def test_literal(backend, expr): @pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL") -@pytest.mark.xfail_version( - mssql=["sqlalchemy>=2"], reason="sqlalchemy 2 prefixes literals with `N`" -) def test_group_by_has_index(backend, snapshot): countries = ibis.table( dict(continent="string", population="int64"), name="countries" diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 05c6c268749f..d6bd099508b6 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd import pytest -import sqlalchemy as sa from pytest import param import ibis @@ -14,6 +13,7 @@ from ibis.backends.tests.errors import ( ClickHouseDatabaseError, OracleDatabaseError, + PsycoPg2InternalError, PyDruidProgrammingError, PyODBCProgrammingError, ) @@ -62,7 +62,7 @@ ), pytest.mark.broken( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason='sql parser error: Expected end of statement, found: "NG\'" at line:1, column:31 Near "SELECT \'STRI"NG\' AS "\'STRI""', ), ], @@ -90,7 +90,7 @@ ), pytest.mark.broken( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason='sql parser error: Expected end of statement, found: "NG\'" at line:1, column:31 Near "SELECT \'STRI"NG\' AS "\'STRI""', ), ], @@ -233,11 +233,6 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function textregexeq(character varying, character varying) does not exist", - ), ], ), param( @@ -248,11 +243,6 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function textregexeq(character varying, character varying) does not exist", - ), ], ), param( @@ -268,11 +258,6 @@ def uses_java_re(t): pytest.mark.notimpl( ["druid"], reason="No posix support", raises=AssertionError ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function textregexeq(character varying, character varying) does not exist", - ), ], ), param( @@ -283,11 +268,6 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function textregexeq(character varying, character varying) does not exist", - ), ], ), param( @@ -300,11 +280,6 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function textregexeq(character varying, character varying) does not exist", - ), ], ), param( @@ -317,11 +292,6 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function textregexeq(character varying, character varying) does not exist", - ), ], ), param( @@ -334,11 +304,6 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function textregexeq(character varying, character varying) does not exist", - ), ], ), param( @@ -349,11 +314,6 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function textregexeq(character varying, character varying) does not exist", - ), ], ), param( @@ -364,11 +324,6 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function textregexeq(character varying, character varying) does not exist", - ), ], ), param( @@ -991,7 +946,7 @@ def test_multiple_subs(con): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function levenshtein(character varying, character varying) does not exist", ) @pytest.mark.parametrize( diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index 95f7df9f4ea5..f2b8c99fc73b 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -79,7 +79,7 @@ def test_literal(backend, con, field): backend.assert_series_equal(result, expected.astype(dtype)) -@pytest.mark.notimpl(["postgres", "risingwave"]) +@pytest.mark.notimpl(["postgres"]) @pytest.mark.parametrize("field", ["a", "b", "c"]) @pytest.mark.notyet( ["clickhouse"], reason="clickhouse doesn't support nullable nested types" diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 68a9306c26d5..3573a598ddf6 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -9,7 +9,6 @@ import numpy as np import pandas as pd import pytest -import sqlalchemy as sa import sqlglot as sg from pytest import param @@ -30,6 +29,7 @@ OracleDatabaseError, PolarsComputeError, PolarsPanicException, + PsycoPg2InternalError, Py4JJavaError, PyDruidProgrammingError, PyODBCProgrammingError, @@ -152,6 +152,11 @@ def test_timestamp_extract(backend, alltypes, df, attr): raises=AssertionError, reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", ), + pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", + ), ], ), ], @@ -634,7 +639,7 @@ def test_date_truncate(backend, alltypes, df, unit): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Bind error: Invalid unit: week", ), ], @@ -657,7 +662,7 @@ def test_date_truncate(backend, alltypes, df, unit): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Bind error: Invalid unit: millisecond", ), ], @@ -681,7 +686,7 @@ def test_date_truncate(backend, alltypes, df, unit): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Bind error: Invalid unit: microsecond", ), ], @@ -738,7 +743,7 @@ def convert_to_offset(offset, displacement_type=displacement_type): pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Bind error: Invalid unit: week", ), ], @@ -831,7 +836,7 @@ def convert_to_offset(x): id="timestamp-add-interval-binop", marks=[ pytest.mark.notimpl( - ["dask", "risingwave", "snowflake", "sqlite", "bigquery", "exasol"], + ["dask", "snowflake", "sqlite", "bigquery", "exasol"], raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["impala"], raises=com.UnsupportedOperationError), @@ -851,14 +856,7 @@ def convert_to_offset(x): id="timestamp-add-interval-binop-different-units", marks=[ pytest.mark.notimpl( - [ - "sqlite", - "risingwave", - "polars", - "snowflake", - "bigquery", - "exasol", - ], + ["sqlite", "polars", "snowflake", "bigquery", "exasol"], raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["impala"], raises=com.UnsupportedOperationError), @@ -982,11 +980,6 @@ def convert_to_offset(x): raises=AttributeError, reason="'StringColumn' object has no attribute 'date'", ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function make_date(integer, integer, integer) does not exist", - ), pytest.mark.broken( ["flink"], raises=com.UnsupportedOperationError, @@ -1428,13 +1421,6 @@ def test_interval_add_cast_column(backend, alltypes, df): ), "%Y%m%d", marks=[ - pytest.mark.notimpl( - [ - "risingwave", - ], - raises=AttributeError, - reason="Neither 'concat' object nor 'Comparator' object has an attribute 'value'", - ), pytest.mark.notimpl( [ "polars", @@ -1617,7 +1603,6 @@ def test_integer_to_timestamp(backend, con, unit): [ "dask", "pandas", - "risingwave", "clickhouse", "sqlite", "datafusion", @@ -1723,6 +1708,11 @@ def test_day_of_week_column(backend, alltypes, df): "Ref: https://nightlies.apache.org/flink/flink-docs-release-1.13/docs/dev/table/functions/systemfunctions/#temporal-functions" ), ), + pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", + ), ], ), ], @@ -1803,7 +1793,7 @@ def test_now_from_projection(alltypes): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=com.OperationNotDefinedError, reason="function make_date(integer, integer, integer) does not exist", ) def test_date_literal(con, backend): @@ -1837,7 +1827,7 @@ def test_date_literal(con, backend): @pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function make_timestamp(integer, integer, integer, integer, integer, integer) does not exist", ) def test_timestamp_literal(con, backend): @@ -1895,7 +1885,7 @@ def test_timestamp_literal(con, backend): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function make_timestamp(integer, integer, integer, integer, integer, integer) does not exist", ) def test_timestamp_with_timezone_literal(con, timezone, expected): @@ -1928,7 +1918,7 @@ def test_timestamp_with_timezone_literal(con, timezone, expected): @pytest.mark.notimpl(["druid"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function make_time(integer, integer, integer) does not exist", ) def test_time_literal(con, backend): @@ -2078,7 +2068,7 @@ def test_interval_literal(con, backend): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=com.OperationNotDefinedError, reason="function make_date(integer, integer, integer) does not exist", ) def test_date_column_from_ymd(backend, con, alltypes, df): @@ -2100,12 +2090,12 @@ def test_date_column_from_ymd(backend, con, alltypes, df): raises=AttributeError, reason="StringColumn' object has no attribute 'year'", ) +@pytest.mark.notyet(["impala", "oracle"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function make_timestamp(smallint, smallint, smallint, smallint, smallint, smallint) does not exist", ) -@pytest.mark.notyet(["impala", "oracle"], raises=com.OperationNotDefinedError) def test_timestamp_column_from_ymdhms(backend, con, alltypes, df): c = alltypes.timestamp_col expr = ibis.timestamp( @@ -2234,11 +2224,6 @@ def build_date_col(t): param(lambda _: DATE, build_date_col, id="date_column"), ], ) -@pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function make_date(integer, integer, integer) does not exist", -) def test_timestamp_date_comparison(backend, alltypes, df, left_fn, right_fn): left = left_fn(alltypes) right = right_fn(alltypes) @@ -2360,12 +2345,12 @@ def test_large_timestamp(con): reason="assert Timestamp('2023-01-07 13:20:05.561000') == Timestamp('2023-01-07 13:20:05.561000231')", raises=AssertionError, ), + pytest.mark.notimpl(["exasol"], raises=AssertionError), pytest.mark.notyet( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Parse error: timestamp without time zone Can't cast string to timestamp (expected format is YYYY-MM-DD HH:MM:SS[.D+{up to 6 digits}] or YYYY-MM-DD HH:MM or YYYY-MM-DD or ISO 8601 format)", ), - pytest.mark.notimpl(["exasol"], raises=AssertionError), ], ), ], @@ -2395,11 +2380,6 @@ def test_timestamp_precision_output(con, ts, scale, unit): ], raises=com.OperationNotDefinedError, ) -@pytest.mark.notyet( - ["risingwave"], - reason="risingwave doesn't have any easy way to accurately compute the delta in specific units", - raises=com.OperationNotDefinedError, -) @pytest.mark.parametrize( ("start", "end", "unit", "expected"), [ @@ -2416,7 +2396,7 @@ def test_timestamp_precision_output(con, ts, scale, unit): reason="time types not yet implemented in ibis for the clickhouse backend", ), pytest.mark.notyet( - ["postgres"], + ["postgres", "risingwave"], reason="postgres doesn't have any easy way to accurately compute the delta in specific units", raises=com.OperationNotDefinedError, ), @@ -2565,7 +2545,7 @@ def test_delta(con, start, end, unit, expected): @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function date_bin(interval, timestamp without time zone, timestamp without time zone) does not exist", ) def test_timestamp_bucket(backend, kws, pd_freq): @@ -2604,7 +2584,7 @@ def test_timestamp_bucket(backend, kws, pd_freq): @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="function date_bin(interval, timestamp without time zone, timestamp without time zone) does not exist", ) def test_timestamp_bucket_offset(backend, offset_mins): @@ -2717,11 +2697,6 @@ def test_time_literal_sql(dialect, snapshot, micros): param(datetime.date.fromisoformat, id="fromstring"), ], ) -@pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="function make_date(integer, integer, integer) does not exist", -) def test_date_scalar(con, value, func): expr = ibis.date(func(value)).name("tmp") diff --git a/ibis/backends/tests/test_uuid.py b/ibis/backends/tests/test_uuid.py index 5802727f205d..7b427bea9173 100644 --- a/ibis/backends/tests/test_uuid.py +++ b/ibis/backends/tests/test_uuid.py @@ -4,7 +4,6 @@ import uuid import pytest -import sqlalchemy.exc import ibis import ibis.common.exceptions as com @@ -21,6 +20,7 @@ "flink": "CHAR(36) NOT NULL", "impala": "STRING", "postgres": "uuid", + "risingwave": "character varying", "snowflake": "VARCHAR", "sqlite": "text", "trino": "uuid", @@ -28,11 +28,6 @@ @pytest.mark.notimpl(["datafusion", "polars"], raises=NotImplementedError) -@pytest.mark.notimpl( - ["risingwave"], - raises=sqlalchemy.exc.InternalError, - reason="Feature is not yet implemented: unsupported data type: UUID", -) @pytest.mark.notimpl(["polars"], raises=NotImplementedError) @pytest.mark.notimpl(["datafusion"], raises=Exception) def test_uuid_literal(con, backend): diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index 0059e5df0955..fdfdada99871 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd import pytest -import sqlalchemy as sa from pytest import param import ibis @@ -19,6 +18,7 @@ ImpalaHiveServer2Error, MySQLOperationalError, OracleDatabaseError, + PsycoPg2InternalError, Py4JJavaError, PyDruidProgrammingError, PyODBCProgrammingError, @@ -148,7 +148,7 @@ def calc_zscore(s): pytest.mark.notimpl(["dask"], raises=NotImplementedError), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Unrecognized window function: percent_rank", ), ], @@ -165,7 +165,7 @@ def calc_zscore(s): pytest.mark.notimpl(["dask"], raises=NotImplementedError), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Unrecognized window function: cume_dist", ), ], @@ -196,7 +196,7 @@ def calc_zscore(s): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Unrecognized window function: ntile", ), ], @@ -236,12 +236,8 @@ def calc_zscore(s): ["impala", "mssql"], raises=com.OperationNotDefinedError ), pytest.mark.notimpl(["dask"], raises=NotImplementedError), - pytest.mark.notimpl( - ["flink"], - raises=com.OperationNotDefinedError, - reason="No translation rule for ", - ), - pytest.mark.notimpl(["risingwave"], raises=sa.exc.InternalError), + pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError), + pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError), ], ), param( @@ -407,7 +403,7 @@ def test_grouped_bounded_expanding_window( pytest.mark.notimpl(["dask"], raises=NotImplementedError), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", ), ], @@ -667,14 +663,10 @@ def test_grouped_unbounded_window( @pytest.mark.broken(["dask"], raises=AssertionError) @pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl( - ["flink"], - raises=com.UnsupportedOperationError, - reason="OVER RANGE FOLLOWING windows are not supported in Flink yet", -) +@pytest.mark.notimpl(["flink"], raises=com.UnsupportedOperationError) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", ) def test_simple_ungrouped_unbound_following_window( @@ -706,7 +698,7 @@ def test_simple_ungrouped_unbound_following_window( @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", ) @pytest.mark.xfail_version(datafusion=["datafusion==35"]) @@ -740,7 +732,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", ), ], @@ -772,16 +764,16 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): pytest.mark.notimpl( ["pandas", "dask"], raises=com.OperationNotDefinedError ), - pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="Feature is not yet implemented: Unrecognized window function: ntile", - ), pytest.mark.notimpl( ["flink"], raises=Py4JJavaError, reason="CalciteContextException: Argument to function 'NTILE' must be a literal", ), + pytest.mark.notimpl( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="Feature is not yet implemented: Unrecognized window function: ntile", + ), ], ), param( @@ -858,7 +850,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): marks=[ pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", ), ], @@ -893,7 +885,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", ), ], @@ -906,7 +898,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): marks=[ pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", ), ], @@ -944,7 +936,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): ), pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", ), ], @@ -1061,17 +1053,17 @@ def test_ungrouped_unbounded_window( reason="RANGE OFFSET frame for 'DB::ColumnNullable' ORDER BY column is not implemented", raises=ClickHouseDatabaseError, ) -@pytest.mark.notimpl( - ["risingwave"], - raises=sa.exc.InternalError, - reason="Feature is not yet implemented: window frame in `RANGE` mode is not supported yet", -) @pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError) @pytest.mark.broken( ["mysql"], raises=MySQLOperationalError, reason="https://github.com/tobymao/sqlglot/issues/2779", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="Feature is not yet implemented: window frame in `RANGE` mode is not supported yet", +) def test_grouped_bounded_range_window(backend, alltypes, df): # Explanation of the range window spec below: # @@ -1129,7 +1121,7 @@ def gb_fn(df): ) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Unrecognized window function: percent_rank", ) def test_percent_rank_whole_table_no_order_by(backend, alltypes, df): @@ -1180,7 +1172,7 @@ def agg(df): @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", ) def test_mutate_window_filter(backend, alltypes): @@ -1257,16 +1249,16 @@ def test_first_last(backend): ["mssql"], raises=PyODBCProgrammingError, reason="not support by the backend" ) @pytest.mark.broken(["flink"], raises=Py4JJavaError, reason="bug in Flink") -@pytest.mark.broken( - ["risingwave"], - raises=sa.exc.InternalError, - reason="sql parser error: Expected literal int, found: INTERVAL at line:1, column:99", -) @pytest.mark.broken( ["exasol"], raises=ExaQueryError, reason="database can't handle UTC timestamps in DataFrames", ) +@pytest.mark.broken( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="sql parser error: Expected literal int, found: INTERVAL at line:1, column:99", +) def test_range_expression_bounds(backend): t = ibis.memtable( { @@ -1313,7 +1305,7 @@ def test_range_expression_bounds(backend): ) @pytest.mark.broken( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Unrecognized window function: percent_rank", ) def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df): @@ -1348,7 +1340,7 @@ def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df): @pytest.mark.notyet(["flink"], raises=com.UnsupportedOperationError) @pytest.mark.notimpl( ["risingwave"], - raises=sa.exc.InternalError, + raises=PsycoPg2InternalError, reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", ) def test_windowed_order_by_sequence_is_preserved(con): diff --git a/pyproject.toml b/pyproject.toml index 1fdb7609db48..a57364788779 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ atpublic = ">=2.3,<5" bidict = ">=0.22.1,<1" multipledispatch = ">=0.6,<2" numpy = ">=1,<2" -pandas = ">=1.2.5,<3" +pandas = ">=1.2.5,<2.1" parsy = ">=2,<3" pyarrow = ">=2,<16" pyarrow-hotfix = ">=0.4,<1" @@ -192,11 +192,11 @@ mysql = ["pymysql"] oracle = ["oracledb", "packaging"] pandas = ["regex"] polars = ["polars", "packaging"] -risingwave = ["psycopg2"] postgres = ["psycopg2"] pyspark = ["pyspark", "packaging"] snowflake = ["snowflake-connector-python", "packaging"] sqlite = ["regex"] +risingwave = ["psycopg2"] trino = ["trino"] # non-backend extras visualization = ["graphviz"]