From 37825e0701332610d27c1285491a6b13d504aff6 Mon Sep 17 00:00:00 2001 From: Kexiang Wang Date: Wed, 27 Dec 2023 22:55:50 -0500 Subject: [PATCH] feat(risingwave): init impl for Risingwave --- .github/workflows/ibis-backends.yml | 16 +- ci/schema/risingwave.sql | 177 +++ compose.yaml | 85 ++ docker/risingwave/risingwave.toml | 2 + ibis/backends/conftest.py | 1 + ibis/backends/risingwave/__init__.py | 282 +++++ ibis/backends/risingwave/compiler.py | 34 + ibis/backends/risingwave/datatypes.py | 83 ++ ibis/backends/risingwave/registry.py | 861 ++++++++++++++ ibis/backends/risingwave/tests/__init__.py | 0 ibis/backends/risingwave/tests/conftest.py | 124 ++ .../test_client/test_compile_toplevel/out.sql | 2 + .../test_analytic_functions/out.sql | 7 + .../test_union_cte/False/out.sql | 1 + .../test_union_cte/True/out.sql | 1 + ibis/backends/risingwave/tests/test_client.py | 158 +++ .../risingwave/tests/test_functions.py | 1032 +++++++++++++++++ ibis/backends/risingwave/tests/test_json.py | 17 + ibis/backends/tests/test_aggregation.py | 152 ++- ibis/backends/tests/test_array.py | 135 ++- ibis/backends/tests/test_binary.py | 1 + ibis/backends/tests/test_client.py | 67 +- ibis/backends/tests/test_column.py | 1 + ibis/backends/tests/test_dot_sql.py | 17 +- ibis/backends/tests/test_examples.py | 3 +- ibis/backends/tests/test_export.py | 39 +- ibis/backends/tests/test_generic.py | 106 +- ibis/backends/tests/test_json.py | 4 +- ibis/backends/tests/test_map.py | 96 +- ibis/backends/tests/test_network.py | 5 + ibis/backends/tests/test_numeric.py | 63 +- ibis/backends/tests/test_param.py | 17 + ibis/backends/tests/test_register.py | 22 +- ibis/backends/tests/test_set_ops.py | 97 +- ibis/backends/tests/test_sql.py | 2 +- ibis/backends/tests/test_string.py | 88 +- ibis/backends/tests/test_struct.py | 8 +- ibis/backends/tests/test_temporal.py | 117 +- ibis/backends/tests/test_timecontext.py | 1 + ibis/backends/tests/test_udf.py | 1 + ibis/backends/tests/test_uuid.py | 5 + ibis/backends/tests/test_vectorized_udf.py | 2 +- ibis/backends/tests/test_window.py | 102 +- ibis/tests/benchmarks/test_benchmarks.py | 2 +- poetry.lock | 19 +- pyproject.toml | 10 + requirements-dev.txt | 1 + 47 files changed, 3927 insertions(+), 139 deletions(-) create mode 100644 ci/schema/risingwave.sql create mode 100644 docker/risingwave/risingwave.toml create mode 100644 ibis/backends/risingwave/__init__.py create mode 100644 ibis/backends/risingwave/compiler.py create mode 100644 ibis/backends/risingwave/datatypes.py create mode 100644 ibis/backends/risingwave/registry.py create mode 100644 ibis/backends/risingwave/tests/__init__.py create mode 100644 ibis/backends/risingwave/tests/conftest.py create mode 100644 ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql create mode 100644 ibis/backends/risingwave/tests/snapshots/test_functions/test_analytic_functions/out.sql create mode 100644 ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql create mode 100644 ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql create mode 100644 ibis/backends/risingwave/tests/test_client.py create mode 100644 ibis/backends/risingwave/tests/test_functions.py create mode 100644 ibis/backends/risingwave/tests/test_json.py diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml index f83a36245a16..5bed3c303d09 100644 --- a/.github/workflows/ibis-backends.yml +++ b/.github/workflows/ibis-backends.yml @@ -121,6 +121,12 @@ jobs: - 
postgres sys-deps: - libgeos-dev + - name: risingwave + title: Risingwave + services: + - risingwave + extras: + - risingwave - name: impala title: Impala extras: @@ -209,6 +215,14 @@ jobs: - postgres sys-deps: - libgeos-dev + - os: windows-latest + backend: + name: risingwave + title: Risingwave + services: + - risingwave + extras: + - risingwave - os: windows-latest backend: name: postgres @@ -677,7 +691,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: remove deps that are not compatible with sqlalchemy 2 - run: poetry remove snowflake-sqlalchemy sqlalchemy-exasol + run: poetry remove snowflake-sqlalchemy sqlalchemy-exasol sqlalchemy-risingwave - name: add sqlalchemy 2 run: poetry update sqlalchemy diff --git a/ci/schema/risingwave.sql b/ci/schema/risingwave.sql new file mode 100644 index 000000000000..cedfa8449d60 --- /dev/null +++ b/ci/schema/risingwave.sql @@ -0,0 +1,177 @@ +SET RW_IMPLICIT_FLUSH=true; + +DROP TABLE IF EXISTS diamonds CASCADE; + +CREATE TABLE diamonds ( + carat FLOAT, + cut TEXT, + color TEXT, + clarity TEXT, + depth FLOAT, + "table" FLOAT, + price BIGINT, + x FLOAT, + y FLOAT, + z FLOAT +) WITH ( + connector = 'posix_fs', + match_pattern = 'diamonds.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS astronauts CASCADE; + +CREATE TABLE astronauts ( + "id" BIGINT, + "number" BIGINT, + "nationwide_number" BIGINT, + "name" VARCHAR, + "original_name" VARCHAR, + "sex" VARCHAR, + "year_of_birth" BIGINT, + "nationality" VARCHAR, + "military_civilian" VARCHAR, + "selection" VARCHAR, + "year_of_selection" BIGINT, + "mission_number" BIGINT, + "total_number_of_missions" BIGINT, + "occupation" VARCHAR, + "year_of_mission" BIGINT, + "mission_title" VARCHAR, + "ascend_shuttle" VARCHAR, + "in_orbit" VARCHAR, + "descend_shuttle" VARCHAR, + "hours_mission" DOUBLE PRECISION, + "total_hrs_sum" DOUBLE PRECISION, + "field21" BIGINT, + "eva_hrs_mission" DOUBLE PRECISION, + "total_eva_hrs" DOUBLE PRECISION +) WITH ( + connector = 'posix_fs', + match_pattern = 'astronauts.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS batting CASCADE; + +CREATE TABLE batting ( + "playerID" TEXT, + "yearID" BIGINT, + stint BIGINT, + "teamID" TEXT, + "lgID" TEXT, + "G" BIGINT, + "AB" BIGINT, + "R" BIGINT, + "H" BIGINT, + "X2B" BIGINT, + "X3B" BIGINT, + "HR" BIGINT, + "RBI" BIGINT, + "SB" BIGINT, + "CS" BIGINT, + "BB" BIGINT, + "SO" BIGINT, + "IBB" BIGINT, + "HBP" BIGINT, + "SH" BIGINT, + "SF" BIGINT, + "GIDP" BIGINT +) WITH ( + connector = 'posix_fs', + match_pattern = 'batting.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS awards_players CASCADE; + +CREATE TABLE awards_players ( + "playerID" TEXT, + "awardID" TEXT, + "yearID" BIGINT, + "lgID" TEXT, + tie TEXT, + notes TEXT +) WITH ( + connector = 'posix_fs', + match_pattern = 'awards_players.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS functional_alltypes CASCADE; + +CREATE TABLE functional_alltypes ( + id INTEGER, + bool_col BOOLEAN, + tinyint_col SMALLINT, + smallint_col SMALLINT, + int_col INTEGER, + bigint_col BIGINT, + float_col REAL, + double_col DOUBLE PRECISION, + date_string_col TEXT, + string_col TEXT, + timestamp_col TIMESTAMP WITHOUT TIME ZONE, + year INTEGER, + month INTEGER +) WITH ( + connector = 
'posix_fs', + match_pattern = 'functional_alltypes.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS tzone CASCADE; + +CREATE TABLE tzone ( + ts TIMESTAMP WITH TIME ZONE, + key TEXT, + value DOUBLE PRECISION +); + +INSERT INTO tzone + SELECT + CAST('2017-05-28 11:01:31.000400' AS TIMESTAMP WITH TIME ZONE) + + t * INTERVAL '1 day 1 second' AS ts, + CHR(97 + t) AS key, + t + t / 10.0 AS value + FROM generate_series(0, 9) AS t; + +DROP TABLE IF EXISTS array_types CASCADE; + +CREATE TABLE IF NOT EXISTS array_types ( + x BIGINT[], + y TEXT[], + z DOUBLE PRECISION[], + grouper TEXT, + scalar_column DOUBLE PRECISION, + multi_dim BIGINT[][] +); + +INSERT INTO array_types VALUES + (ARRAY[1, 2, 3], ARRAY['a', 'b', 'c'], ARRAY[1.0, 2.0, 3.0], 'a', 1.0, ARRAY[ARRAY[NULL::BIGINT, NULL, NULL], ARRAY[1, 2, 3]]), + (ARRAY[4, 5], ARRAY['d', 'e'], ARRAY[4.0, 5.0], 'a', 2.0, ARRAY[]::BIGINT[][]), + (ARRAY[6, NULL], ARRAY['f', NULL], ARRAY[6.0, NULL], 'a', 3.0, ARRAY[NULL, ARRAY[]::BIGINT[], NULL]), + (ARRAY[NULL, 1, NULL], ARRAY[NULL, 'a', NULL], ARRAY[]::DOUBLE PRECISION[], 'b', 4.0, ARRAY[ARRAY[1], ARRAY[2], ARRAY[NULL::BIGINT], ARRAY[3]]), + (ARRAY[2, NULL, 3], ARRAY['b', NULL, 'c'], NULL, 'b', 5.0, NULL), + (ARRAY[4, NULL, NULL, 5], ARRAY['d', NULL, NULL, 'e'], ARRAY[4.0, NULL, NULL, 5.0], 'c', 6.0, ARRAY[ARRAY[1, 2, 3]]); + +DROP TABLE IF EXISTS json_t CASCADE; + +CREATE TABLE IF NOT EXISTS json_t (js JSONB); + +INSERT INTO json_t VALUES + ('{"a": [1,2,3,4], "b": 1}'), + ('{"a":null,"b":2}'), + ('{"a":"foo", "c":null}'), + ('null'), + ('[42,47,55]'), + ('[]'); + +DROP TABLE IF EXISTS win CASCADE; +CREATE TABLE win (g TEXT, x BIGINT, y BIGINT); +INSERT INTO win VALUES + ('a', 0, 3), + ('a', 1, 2), + ('a', 2, 0), + ('a', 3, 1), + ('a', 4, 1); diff --git a/compose.yaml b/compose.yaml index 141f51a3eaff..1a561b895919 100644 --- a/compose.yaml +++ b/compose.yaml @@ -537,6 +537,88 @@ services: networks: - impala + risingwave-minio: + image: "quay.io/minio/minio:latest" + command: + - server + - "--address" + - "0.0.0.0:9301" + - "--console-address" + - "0.0.0.0:9400" + - /data + expose: + - "9301" + - "9400" + ports: + - "9301:9301" + - "9400:9400" + depends_on: [] + volumes: + - "risingwave-minio:/data" + entrypoint: /bin/sh -c "set -e; mkdir -p \"/data/hummock001\"; /usr/bin/docker-entrypoint.sh \"$$0\" \"$$@\" " + environment: + MINIO_CI_CD: "1" + MINIO_ROOT_PASSWORD: hummockadmin + MINIO_ROOT_USER: hummockadmin + MINIO_DOMAIN: "risingwave-minio" + container_name: risingwave-minio + healthcheck: + test: + - CMD-SHELL + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/9301; exit $$?;' + interval: 5s + timeout: 5s + retries: 20 + restart: always + networks: + - risingwave + + risingwave: + image: ghcr.io/risingwavelabs/risingwave:nightly-20240122 + command: "standalone --meta-opts=\" \ + --advertise-addr 0.0.0.0:5690 \ + --backend mem \ + --state-store hummock+minio://hummockadmin:hummockadmin@risingwave-minio:9301/hummock001 \ + --data-directory hummock_001 \ + --config-path /risingwave.toml\" \ + --compute-opts=\" \ + --config-path /risingwave.toml \ + --advertise-addr 0.0.0.0:5688 \ + --role both \" \ + --frontend-opts=\" \ + --config-path /risingwave.toml \ + --listen-addr 0.0.0.0:4566 \ + --advertise-addr 0.0.0.0:4566 \" \ + --compactor-opts=\" \ + --advertise-addr 0.0.0.0:6660 \"" + expose: + - "4566" + ports: + - "4566:4566" + depends_on: + - risingwave-minio + volumes: + - 
"./docker/risingwave/risingwave.toml:/risingwave.toml" + - risingwave:/data + environment: + RUST_BACKTRACE: "1" + # If ENABLE_TELEMETRY is not set, telemetry will start by default + ENABLE_TELEMETRY: ${ENABLE_TELEMETRY:-true} + container_name: risingwave + healthcheck: + test: + - CMD-SHELL + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/6660; exit $$?;' + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/5688; exit $$?;' + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/4566; exit $$?;' + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/5690; exit $$?;' + interval: 5s + timeout: 5s + retries: 20 + restart: always + networks: + - risingwave + networks: impala: # docker defaults to naming networks "$PROJECT_$NETWORK" but the Java Hive @@ -553,6 +635,7 @@ networks: oracle: exasol: flink: + risingwave: volumes: clickhouse: @@ -563,3 +646,5 @@ volumes: postgres: exasol: impala: + risingwave-minio: + risingwave: diff --git a/docker/risingwave/risingwave.toml b/docker/risingwave/risingwave.toml new file mode 100644 index 000000000000..43d57926ed16 --- /dev/null +++ b/docker/risingwave/risingwave.toml @@ -0,0 +1,2 @@ +# RisingWave config file to be mounted into the Docker containers. +# See https://github.com/risingwavelabs/risingwave/blob/main/src/config/example.toml for example diff --git a/ibis/backends/conftest.py b/ibis/backends/conftest.py index 404d173606ed..1ec7a09ba057 100644 --- a/ibis/backends/conftest.py +++ b/ibis/backends/conftest.py @@ -539,6 +539,7 @@ def ddl_con(ddl_backend): "mysql", "oracle", "postgres", + "risingwave", "snowflake", "sqlite", "trino", diff --git a/ibis/backends/risingwave/__init__.py b/ibis/backends/risingwave/__init__.py new file mode 100644 index 000000000000..04de491f6dfe --- /dev/null +++ b/ibis/backends/risingwave/__init__.py @@ -0,0 +1,282 @@ +"""Risingwave backend.""" + +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Callable, Literal + +import sqlalchemy as sa + +import ibis.common.exceptions as exc +import ibis.expr.operations as ops +from ibis import util +from ibis.backends.base.sql.alchemy import AlchemyCanCreateSchema, BaseAlchemyBackend +from ibis.backends.risingwave.compiler import RisingwaveCompiler +from ibis.backends.risingwave.datatypes import RisingwaveType +from ibis.common.exceptions import InvalidDecoratorError + +if TYPE_CHECKING: + from collections.abc import Iterable + + import ibis.expr.datatypes as dt + + +def _verify_source_line(func_name: str, line: str): + if line.startswith("@"): + raise InvalidDecoratorError(func_name, line) + return line + + +class Backend(BaseAlchemyBackend, AlchemyCanCreateSchema): + name = "risingwave" + compiler = RisingwaveCompiler + supports_temporary_tables = False + supports_create_or_replace = False + supports_python_udfs = False + + def do_connect( + self, + host: str | None = None, + user: str | None = None, + password: str | None = None, + port: int = 5432, + database: str | None = None, + schema: str | None = None, + url: str | None = None, + driver: Literal["psycopg2"] = "psycopg2", + ) -> None: + """Create an Ibis client connected to Risingwave database. + + Parameters + ---------- + host + Hostname + user + Username + password + Password + port + Port number + database + Database to connect to + schema + Risingwave schema to use. If `None`, use the default `search_path`. + url + SQLAlchemy connection string. + + If passed, the other connection arguments are ignored. 
+ driver + Database driver + + Examples + -------- + >>> import os + >>> import getpass + >>> import ibis + >>> host = os.environ.get("IBIS_TEST_RISINGWAVE_HOST", "localhost") + >>> user = os.environ.get("IBIS_TEST_RISINGWAVE_USER", getpass.getuser()) + >>> password = os.environ.get("IBIS_TEST_RISINGWAVE_PASSWORD") + >>> database = os.environ.get("IBIS_TEST_RISINGWAVE_DATABASE", "dev") + >>> con = connect(database=database, host=host, user=user, password=password) + >>> con.list_tables() # doctest: +ELLIPSIS + [...] + >>> t = con.table("functional_alltypes") + >>> t + RisingwaveTable[table] + name: functional_alltypes + schema: + id : int32 + bool_col : boolean + tinyint_col : int16 + smallint_col : int16 + int_col : int32 + bigint_col : int64 + float_col : float32 + double_col : float64 + date_string_col : string + string_col : string + timestamp_col : timestamp + year : int32 + month : int32 + """ + if driver != "psycopg2": + raise NotImplementedError("psycopg2 is currently the only supported driver") + + alchemy_url = self._build_alchemy_url( + url=url, + host=host, + port=port, + user=user, + password=password, + database=database, + driver=f"risingwave+{driver}", + ) + + connect_args = {} + if schema is not None: + connect_args["options"] = f"-csearch_path={schema}" + + engine = sa.create_engine( + alchemy_url, connect_args=connect_args, poolclass=sa.pool.StaticPool + ) + + @sa.event.listens_for(engine, "connect") + def connect(dbapi_connection, connection_record): + with dbapi_connection.cursor() as cur: + cur.execute("SET TIMEZONE = UTC") + + super().do_connect(engine) + + def list_tables(self, like=None, schema=None): + """List the tables in the database. + + Parameters + ---------- + like + A pattern to use for listing tables. + schema + The schema to perform the list against. + + ::: {.callout-warning} + ## `schema` refers to database hierarchy + + The `schema` parameter does **not** refer to the column names and + types of `table`. 
+ ::: + """ + tables = self.inspector.get_table_names(schema=schema) + views = self.inspector.get_view_names(schema=schema) + return self._filter_with_like(tables + views, like) + + def list_databases(self, like=None) -> list[str]: + # http://dba.stackexchange.com/a/1304/58517 + dbs = sa.table( + "pg_database", + sa.column("datname", sa.TEXT()), + sa.column("datistemplate", sa.BOOLEAN()), + schema="pg_catalog", + ) + query = sa.select(dbs.c.datname).where(sa.not_(dbs.c.datistemplate)) + with self.begin() as con: + databases = list(con.execute(query).scalars()) + + return self._filter_with_like(databases, like) + + @property + def current_database(self) -> str: + return self._scalar_query(sa.select(sa.func.current_database())) + + @property + def current_schema(self) -> str: + return self._scalar_query(sa.select(sa.func.current_schema())) + + def function(self, name: str, *, schema: str | None = None) -> Callable: + query = sa.text( + """ +SELECT + n.nspname as schema, + pg_catalog.pg_get_function_result(p.oid) as return_type, + string_to_array(pg_catalog.pg_get_function_arguments(p.oid), ', ') as signature, + CASE p.prokind + WHEN 'a' THEN 'agg' + WHEN 'w' THEN 'window' + WHEN 'p' THEN 'proc' + ELSE 'func' + END as "Type" +FROM pg_catalog.pg_proc p +LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace +WHERE p.proname = :name +""" + + "AND n.nspname OPERATOR(pg_catalog.~) :schema COLLATE pg_catalog.default" + * (schema is not None) + ).bindparams(name=name, schema=f"^({schema})$") + + def split_name_type(arg: str) -> tuple[str, dt.DataType]: + name, typ = arg.split(" ", 1) + return name, RisingwaveType.from_string(typ) + + with self.begin() as con: + rows = con.execute(query).mappings().fetchall() + + if not rows: + name = f"{schema}.{name}" if schema else name + raise exc.MissingUDFError(name) + elif len(rows) > 1: + raise exc.AmbiguousUDFError(name) + + [row] = rows + return_type = RisingwaveType.from_string(row["return_type"]) + signature = list(map(split_name_type, row["signature"])) + + # dummy callable + def fake_func(*args, **kwargs): + ... 
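For orientation, a minimal sketch of using the connection API defined above once a RisingWave instance is running. The host, port, user, password, and database values below are only the defaults used by this patch's test suite (localhost:4566, user "root", empty password, database "dev"), not requirements of the backend.

    import ibis

    con = ibis.risingwave.connect(
        host="localhost",
        port=4566,
        user="root",
        password="",
        database="dev",
    )
    print(con.list_tables())       # tables and views visible in the current search path
    print(con.list_databases())    # database names read from pg_catalog.pg_database
    print(con.current_database, con.current_schema)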
+ + fake_func.__name__ = name + fake_func.__signature__ = inspect.Signature( + [ + inspect.Parameter( + name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=typ + ) + for name, typ in signature + ], + return_annotation=return_type, + ) + fake_func.__annotations__ = {"return": return_type, **dict(signature)} + op = ops.udf.scalar.builtin(fake_func, schema=schema) + return op + + def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: + name = util.gen_name("risingwave_metadata") + type_info_sql = """\ + SELECT + attname, + format_type(atttypid, atttypmod) AS type + FROM pg_attribute + WHERE attrelid = CAST(:name AS regclass) + AND attnum > 0 + AND NOT attisdropped + ORDER BY attnum""" + if self.inspector.has_table(query): + query = f"TABLE {query}" + + text = sa.text(type_info_sql).bindparams(name=name) + with self.begin() as con: + con.exec_driver_sql(f"CREATE VIEW IF NOT EXISTS {name} AS {query}") + try: + yield from ( + (col, RisingwaveType.from_string(typestr)) + for col, typestr in con.execute(text) + ) + finally: + con.exec_driver_sql(f"DROP VIEW IF EXISTS {name}") + + def _get_temp_view_definition( + self, name: str, definition: sa.sql.compiler.Compiled + ) -> str: + yield f"DROP VIEW IF EXISTS {name}" + yield f"CREATE TEMPORARY VIEW {name} AS {definition}" + + def create_schema( + self, name: str, database: str | None = None, force: bool = False + ) -> None: + if database is not None and database != self.current_database: + raise exc.UnsupportedOperationError( + "Risingwave does not support creating a schema in a different database" + ) + if_not_exists = "IF NOT EXISTS " * force + name = self._quote(name) + with self.begin() as con: + con.exec_driver_sql(f"CREATE SCHEMA {if_not_exists}{name}") + + def drop_schema( + self, name: str, database: str | None = None, force: bool = False + ) -> None: + if database is not None and database != self.current_database: + raise exc.UnsupportedOperationError( + "Risingwave does not support dropping a schema in a different database" + ) + name = self._quote(name) + if_exists = "IF EXISTS " * force + with self.begin() as con: + con.exec_driver_sql(f"DROP SCHEMA {if_exists}{name}") diff --git a/ibis/backends/risingwave/compiler.py b/ibis/backends/risingwave/compiler.py new file mode 100644 index 000000000000..b4bcd9c0b9d5 --- /dev/null +++ b/ibis/backends/risingwave/compiler.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import ibis.expr.operations as ops +from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator +from ibis.backends.risingwave.datatypes import RisingwaveType +from ibis.backends.risingwave.registry import operation_registry +from ibis.expr.rewrites import rewrite_sample + + +class RisingwaveExprTranslator(AlchemyExprTranslator): + _registry = operation_registry.copy() + _rewrites = AlchemyExprTranslator._rewrites.copy() + _has_reduction_filter_syntax = True + _supports_tuple_syntax = True + _dialect_name = "risingwave" + + # it does support it, but we can't use it because of support for pivot + supports_unnest_in_select = False + + type_mapper = RisingwaveType + + +rewrites = RisingwaveExprTranslator.rewrites + + +@rewrites(ops.Any) +@rewrites(ops.All) +def _any_all_no_op(expr): + return expr + + +class RisingwaveCompiler(AlchemyCompiler): + translator_class = RisingwaveExprTranslator + rewrites = AlchemyCompiler.rewrites | rewrite_sample diff --git a/ibis/backends/risingwave/datatypes.py b/ibis/backends/risingwave/datatypes.py new file mode 100644 index 
000000000000..389210486a6f --- /dev/null +++ b/ibis/backends/risingwave/datatypes.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as psql +import sqlalchemy.types as sat + +import ibis.expr.datatypes as dt +from ibis.backends.base.sql.alchemy.datatypes import AlchemyType +from ibis.backends.base.sqlglot.datatypes import PostgresType as SqlglotPostgresType + +_from_postgres_types = { + psql.DOUBLE_PRECISION: dt.Float64, + psql.JSONB: dt.JSON, + psql.JSON: dt.JSON, + psql.BYTEA: dt.Binary, +} + + +_postgres_interval_fields = { + "YEAR": "Y", + "MONTH": "M", + "DAY": "D", + "HOUR": "h", + "MINUTE": "m", + "SECOND": "s", + "YEAR TO MONTH": "M", + "DAY TO HOUR": "h", + "DAY TO MINUTE": "m", + "DAY TO SECOND": "s", + "HOUR TO MINUTE": "m", + "HOUR TO SECOND": "s", + "MINUTE TO SECOND": "s", +} + + +class RisingwaveType(AlchemyType): + dialect = "risingwave" + + @classmethod + def from_ibis(cls, dtype: dt.DataType) -> sat.TypeEngine: + if dtype.is_floating(): + if isinstance(dtype, dt.Float64): + return psql.DOUBLE_PRECISION + else: + return psql.REAL + elif dtype.is_array(): + # Unwrap the array element type because sqlalchemy doesn't allow arrays of + # arrays. This doesn't affect the underlying data. + while dtype.is_array(): + dtype = dtype.value_type + return sa.ARRAY(cls.from_ibis(dtype)) + elif dtype.is_map(): + if not (dtype.key_type.is_string() and dtype.value_type.is_string()): + raise TypeError( + f"Risingwave only supports map, got: {dtype}" + ) + return psql.HSTORE() + elif dtype.is_uuid(): + return psql.UUID() + else: + return super().from_ibis(dtype) + + @classmethod + def to_ibis(cls, typ: sat.TypeEngine, nullable: bool = True) -> dt.DataType: + if dtype := _from_postgres_types.get(type(typ)): + return dtype(nullable=nullable) + elif isinstance(typ, psql.HSTORE): + return dt.Map(dt.string, dt.string, nullable=nullable) + elif isinstance(typ, psql.INTERVAL): + field = typ.fields.upper() + if (unit := _postgres_interval_fields.get(field, None)) is None: + raise ValueError(f"Unknown Risingwave interval field {field!r}") + elif unit in {"Y", "M"}: + raise ValueError( + "Variable length intervals are not yet supported with Risingwave" + ) + return dt.Interval(unit=unit, nullable=nullable) + else: + return super().to_ibis(typ, nullable=nullable) + + @classmethod + def from_string(cls, type_string: str) -> RisingwaveType: + return SqlglotPostgresType.from_string(type_string) diff --git a/ibis/backends/risingwave/registry.py b/ibis/backends/risingwave/registry.py new file mode 100644 index 000000000000..ee9b3d600dfd --- /dev/null +++ b/ibis/backends/risingwave/registry.py @@ -0,0 +1,861 @@ +from __future__ import annotations + +import functools +import itertools +import locale +import operator +import platform +import re +import string + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as pg +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.functions import GenericFunction + +import ibis.common.exceptions as com +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops + +# used for literal translate +from ibis.backends.base.sql.alchemy import ( + fixed_arity, + get_sqla_table, + reduction, + sqlalchemy_operation_registry, + sqlalchemy_window_functions_registry, + unary, + varargs, +) +from ibis.backends.base.sql.alchemy.registry import ( + _bitwise_op, + _extract, + get_col, +) + +operation_registry = sqlalchemy_operation_registry.copy() 
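Each entry in the registry assembled below maps an Ibis operation class to a callable that receives the expression translator and the operation node and returns a SQLAlchemy expression; helpers such as unary and fixed_arity build those callables from plain functions of the already-translated arguments. A hedged, self-contained sketch of that pattern follows (the ops.Repeat rule is purely illustrative and is not part of this patch):

    import sqlalchemy as sa

    import ibis.expr.operations as ops
    from ibis.backends.base.sql.alchemy import fixed_arity, unary

    illustrative_rules = {
        # unary wraps a one-argument function: Capitalize(arg) -> initcap(arg)
        ops.Capitalize: unary(sa.func.initcap),
        # fixed_arity wraps an N-argument function: Repeat(arg, times) -> repeat(arg, times)
        ops.Repeat: fixed_arity(lambda arg, times: sa.func.repeat(arg, times), 2),
    }
    # In this module such rules are merged into the backend's registry via
    # operation_registry.update(...), as done for the real rules below.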
+operation_registry.update(sqlalchemy_window_functions_registry) + +_truncate_precisions = { + "us": "microseconds", + "ms": "milliseconds", + "s": "second", + "m": "minute", + "h": "hour", + "D": "day", + "W": "week", + "M": "month", + "Q": "quarter", + "Y": "year", +} + + +def _timestamp_truncate(t, op): + sa_arg = t.translate(op.arg) + try: + precision = _truncate_precisions[op.unit.short] + except KeyError: + raise com.UnsupportedOperationError(f"Unsupported truncate unit {op.unit!r}") + return sa.func.date_trunc(precision, sa_arg) + + +def _timestamp_bucket(t, op): + arg = t.translate(op.arg) + interval = t.translate(op.interval) + + origin = sa.literal_column("timestamp '1970-01-01 00:00:00'") + + if op.offset is not None: + origin = origin + t.translate(op.offset) + return sa.func.date_bin(interval, arg, origin) + + +def _typeof(t, op): + sa_arg = t.translate(op.arg) + typ = sa.cast(sa.func.pg_typeof(sa_arg), sa.TEXT) + + # select pg_typeof('thing') returns unknown so we have to check the child's + # type for nullness + return sa.case( + ((typ == "unknown") & (op.arg.dtype != dt.null), "text"), + ((typ == "unknown") & (op.arg.dtype == dt.null), "null"), + else_=typ, + ) + + +_strftime_to_postgresql_rules = { + "%a": "TMDy", # TM does it in a locale dependent way + "%A": "TMDay", + "%w": "D", # 1-based day of week, see below for how we make this 0-based + "%d": "DD", # day of month + "%-d": "FMDD", # - is no leading zero for Python same for FM in postgres + "%b": "TMMon", # Sep + "%B": "TMMonth", # September + "%m": "MM", # 01 + "%-m": "FMMM", # 1 + "%y": "YY", # 15 + "%Y": "YYYY", # 2015 + "%H": "HH24", # 09 + "%-H": "FMHH24", # 9 + "%I": "HH12", # 09 + "%-I": "FMHH12", # 9 + "%p": "AM", # AM or PM + "%M": "MI", # zero padded minute + "%-M": "FMMI", # Minute + "%S": "SS", # zero padded second + "%-S": "FMSS", # Second + "%f": "US", # zero padded microsecond + "%z": "OF", # utf offset + "%Z": "TZ", # uppercase timezone name + "%j": "DDD", # zero padded day of year + "%-j": "FMDDD", # day of year + "%U": "WW", # 1-based week of year + # 'W': ?, # meh +} + +try: + _strftime_to_postgresql_rules.update( + { + "%c": locale.nl_langinfo(locale.D_T_FMT), # locale date and time + "%x": locale.nl_langinfo(locale.D_FMT), # locale date + "%X": locale.nl_langinfo(locale.T_FMT), # locale time + } + ) +except AttributeError: + HAS_LANGINFO = False +else: + HAS_LANGINFO = True + + +# translate strftime spec into mostly equivalent Risingwave spec +_scanner = re.Scanner( # type: ignore # re does have a Scanner attribute + # double quotes need to be escaped + [('"', lambda *_: r"\"")] + + [ + ( + "|".join( + map( + "(?:{})".format, + itertools.chain( + _strftime_to_postgresql_rules.keys(), + [ + # "%e" is in the C standard and Python actually + # generates this if your spec contains "%c" but we + # don't officially support it as a specifier so we + # need to special case it in the scanner + "%e", + r"\s+", + rf"[{re.escape(string.punctuation)}]", + rf"[^{re.escape(string.punctuation)}\s]+", + ], + ), + ) + ), + lambda _, token: token, + ) + ] +) + + +_lexicon_values = frozenset(_strftime_to_postgresql_rules.values()) + +_locale_specific_formats = frozenset(["%c", "%x", "%X"]) +_strftime_blacklist = frozenset(["%w", "%U", "%e"]) | _locale_specific_formats + + +def _reduce_tokens(tokens, arg): + # current list of tokens + curtokens = [] + + # reduced list of tokens that accounts for blacklisted values + reduced = [] + + non_special_tokens = frozenset(_strftime_to_postgresql_rules) - _strftime_blacklist + 
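+    # Example: for the format "%Y-%m-%d" every token is either a plain
+    # translation rule ("%Y" -> "YYYY", "%m" -> "MM", "%d" -> "DD") or literal
+    # punctuation ("-"), so the loop below reduces the whole pattern to a
+    # single to_char(arg, 'YYYY-MM-DD') call with no blacklisted tokens involved.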
+ # TODO: how much of a hack is this? + for token in tokens: + if token in _locale_specific_formats and not HAS_LANGINFO: + raise com.UnsupportedOperationError( + f"Format string component {token!r} is not supported on {platform.system()}" + ) + # we are a non-special token %A, %d, etc. + if token in non_special_tokens: + curtokens.append(_strftime_to_postgresql_rules[token]) + + # we have a string like DD, to escape this we + # surround it with double quotes + elif token in _lexicon_values: + curtokens.append(f'"{token}"') + + # we have a token that needs special treatment + elif token in _strftime_blacklist: + if token == "%w": + value = sa.extract("dow", arg) # 0 based day of week + elif token == "%U": + value = sa.cast(sa.func.to_char(arg, "WW"), sa.SMALLINT) - 1 + elif token in ("%c", "%x", "%X"): + # re scan and tokenize this pattern + try: + new_pattern = _strftime_to_postgresql_rules[token] + except KeyError: + raise ValueError( + "locale specific date formats (%%c, %%x, %%X) are " + "not yet implemented for %s" % platform.system() + ) + + new_tokens, _ = _scanner.scan(new_pattern) + value = functools.reduce( + sa.sql.ColumnElement.concat, + _reduce_tokens(new_tokens, arg), + ) + elif token == "%e": + # pad with spaces instead of zeros + value = sa.func.replace(sa.func.to_char(arg, "DD"), "0", " ") + + reduced += [ + sa.func.to_char(arg, "".join(curtokens)), + sa.cast(value, sa.TEXT), + ] + + # empty current token list in case there are more tokens + del curtokens[:] + + # uninteresting text + else: + curtokens.append(token) + # append result to r if we had more tokens or if we have no + # blacklisted tokens + if curtokens: + reduced.append(sa.func.to_char(arg, "".join(curtokens))) + return reduced + + +def _strftime(arg, pattern): + tokens, _ = _scanner.scan(pattern.value) + reduced = _reduce_tokens(tokens, arg) + return functools.reduce(sa.sql.ColumnElement.concat, reduced) + + +def _find_in_set(t, op): + # TODO + # this operation works with any type, not just strings. should the + # operation itself also have this property? + return ( + sa.func.coalesce( + sa.func.array_position( + pg.array(list(map(t.translate, op.values))), + t.translate(op.needle), + ), + 0, + ) + - 1 + ) + + +def _log(t, op): + arg, base = op.args + sa_arg = t.translate(arg) + if base is not None: + sa_base = t.translate(base) + return sa.cast( + sa.func.log(sa.cast(sa_base, sa.NUMERIC), sa.cast(sa_arg, sa.NUMERIC)), + t.get_sqla_type(op.dtype), + ) + return sa.func.ln(sa_arg) + + +def _regex_extract(arg, pattern, index): + # wrap in parens to support 0th group being the whole string + pattern = "(" + pattern + ")" + # arrays are 1-based in postgres + index = index + 1 + does_match = sa.func.textregexeq(arg, pattern) + matches = sa.func.regexp_match(arg, pattern, type_=pg.ARRAY(sa.TEXT)) + return sa.case((does_match, matches[index]), else_=None) + + +def _array_repeat(t, op): + """Repeat an array.""" + arg = t.translate(op.arg) + times = t.translate(op.times) + + array_length = sa.func.cardinality(arg) + array = sa.sql.elements.Grouping(arg) if isinstance(op.arg, ops.Literal) else arg + + # sequence from 1 to the total number of elements desired in steps of 1. 
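+    # For example, with arg = ARRAY[10, 20] and times = 3, generate_series
+    # below yields 1..6, and the modulo/coalesce that follows maps those values
+    # to indices 1, 2, 1, 2, 1, 2, i.e. the original array repeated three times.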
+ series = sa.func.generate_series(1, times * array_length).table_valued() + + # if our current index modulo the array's length is a multiple of the + # array's length, then the index is the array's length + index = sa.func.coalesce( + sa.func.nullif(series.column % array_length, 0), array_length + ) + + # tie it all together in a scalar subquery and collapse that into an ARRAY + return sa.func.array(sa.select(array[index]).scalar_subquery()) + + +def _table_column(t, op): + ctx = t.context + table = op.table + + sa_table = get_sqla_table(ctx, table) + out_expr = get_col(sa_table, op) + + if op.dtype.is_timestamp(): + timezone = op.dtype.timezone + if timezone is not None: + out_expr = out_expr.op("AT TIME ZONE")(timezone).label(op.name) + + # If the column does not originate from the table set in the current SELECT + # context, we should format as a subquery + if t.permit_subquery and ctx.is_foreign_expr(table): + return sa.select(out_expr) + + return out_expr + + +def _round(t, op): + arg, digits = op.args + sa_arg = t.translate(arg) + + if digits is None: + return sa.func.round(sa_arg) + + # postgres doesn't allow rounding of double precision values to a specific + # number of digits (though simple truncation on doubles is allowed) so + # we cast to numeric and then cast back if necessary + result = sa.func.round(sa.cast(sa_arg, sa.NUMERIC), t.translate(digits)) + if digits is not None and arg.dtype.is_decimal(): + return result + result = sa.cast(result, pg.DOUBLE_PRECISION()) + return result + + +def _mod(t, op): + left, right = map(t.translate, op.args) + + # postgres doesn't allow modulus of double precision values, so upcast and + # then downcast later if necessary + if not op.dtype.is_integer(): + left = sa.cast(left, sa.NUMERIC) + right = sa.cast(right, sa.NUMERIC) + + result = left % right + if op.dtype.is_float64(): + return sa.cast(result, pg.DOUBLE_PRECISION()) + else: + return result + + +def _neg_idx_to_pos(array, idx): + return sa.case((idx < 0, sa.func.cardinality(array) + idx), else_=idx) + + +def _array_slice(*, index_converter, array_length, func): + def translate(t, op): + arg = t.translate(op.arg) + + arg_length = array_length(arg) + + if (start := op.start) is None: + start = 0 + else: + start = t.translate(start) + start = sa.func.least(arg_length, index_converter(arg, start)) + + if (stop := op.stop) is None: + stop = arg_length + else: + stop = index_converter(arg, t.translate(stop)) + + return func(arg, start + 1, stop) + + return translate + + +def _array_index(*, index_converter, func): + def translate(t, op): + sa_array = t.translate(op.arg) + sa_index = t.translate(op.index) + if isinstance(op.arg, ops.Literal): + sa_array = sa.sql.elements.Grouping(sa_array) + return func(sa_array, index_converter(sa_array, sa_index) + 1) + + return translate + + +def _literal(t, op): + dtype = op.dtype + value = op.value + + if value is None: + return ( + sa.null() if dtype.is_null() else sa.cast(sa.null(), t.get_sqla_type(dtype)) + ) + if dtype.is_interval(): + return sa.literal_column(f"INTERVAL '{value} {dtype.resolution}'") + elif dtype.is_array(): + return pg.array(value) + elif dtype.is_map(): + return pg.hstore(list(value.keys()), list(value.values())) + elif dtype.is_time(): + return sa.func.make_time( + value.hour, value.minute, value.second + value.microsecond / 1e6 + ) + elif dtype.is_date(): + return sa.func.make_date(value.year, value.month, value.day) + elif dtype.is_timestamp(): + if (tz := dtype.timezone) is not None: + return 
sa.func.to_timestamp(value.timestamp()).op("AT TIME ZONE")(tz) + return sa.cast(sa.literal(value.isoformat()), sa.TIMESTAMP()) + else: + return sa.literal(value) + + +def _string_agg(t, op): + agg = sa.func.string_agg(t.translate(op.arg), t.translate(op.sep)) + if (where := op.where) is not None: + return agg.filter(t.translate(where)) + return agg + + +def _corr(t, op): + if op.how == "sample": + raise ValueError( + f"{t.__class__.__name__} only implements population correlation " + "coefficient" + ) + return _binary_variance_reduction(sa.func.corr)(t, op) + + +def _covar(t, op): + suffix = {"sample": "samp", "pop": "pop"} + how = suffix.get(op.how, "samp") + func = getattr(sa.func, f"covar_{how}") + return _binary_variance_reduction(func)(t, op) + + +def _mode(t, op): + arg = op.arg + if (where := op.where) is not None: + arg = ops.IfElse(where, arg, None) + return sa.func.mode().within_group(t.translate(arg)) + + +def _quantile(t, op): + arg = op.arg + if (where := op.where) is not None: + arg = ops.IfElse(where, arg, None) + if arg.dtype.is_numeric(): + func = sa.func.percentile_cont + else: + func = sa.func.percentile_disc + return func(t.translate(op.quantile)).within_group(t.translate(arg)) + + +def _median(t, op): + arg = op.arg + if (where := op.where) is not None: + arg = ops.IfElse(where, arg, None) + + if arg.dtype.is_numeric(): + func = sa.func.percentile_cont + else: + func = sa.func.percentile_disc + return func(0.5).within_group(t.translate(arg)) + + +def _binary_variance_reduction(func): + def variance_compiler(t, op): + x = op.left + if (x_type := x.dtype).is_boolean(): + x = ops.Cast(x, dt.Int32(nullable=x_type.nullable)) + + y = op.right + if (y_type := y.dtype).is_boolean(): + y = ops.Cast(y, dt.Int32(nullable=y_type.nullable)) + + if t._has_reduction_filter_syntax: + result = func(t.translate(x), t.translate(y)) + + if (where := op.where) is not None: + return result.filter(t.translate(where)) + return result + else: + if (where := op.where) is not None: + x = ops.IfElse(where, x, None) + y = ops.IfElse(where, y, None) + return func(t.translate(x), t.translate(y)) + + return variance_compiler + + +def _arg_min_max(sort_func): + def translate(t, op: ops.ArgMin | ops.ArgMax) -> str: + arg = t.translate(op.arg) + key = t.translate(op.key) + + conditions = [arg != sa.null(), key != sa.null()] + + agg = sa.func.array_agg(pg.aggregate_order_by(arg, sort_func(key))) + + if (where := op.where) is not None: + conditions.append(t.translate(where)) + return agg.filter(sa.and_(*conditions))[1] + + return translate + + +def _arbitrary(t, op): + if (how := op.how) == "heavy": + raise com.UnsupportedOperationError( + f"risingwave backend doesn't support how={how!r} for the arbitrary() aggregate" + ) + func = getattr(sa.func, op.how) + return t._reduction(func, op) + + +class rw_struct_field(GenericFunction): + inherit_cache = True + + +@compiles(rw_struct_field) +def compile_struct_field_postgresql(element, compiler, **kw): + arg, field = element.clauses + return f"({compiler.process(arg, **kw)}).{field.name}" + + +def _struct_field(t, op): + arg = op.arg + idx = arg.dtype.names.index(op.field) + 1 + field_name = sa.literal_column(f"f{idx:d}") + return rw_struct_field( + t.translate(arg), field_name, type_=t.get_sqla_type(op.dtype) + ) + + +def _struct_column(t, op): + types = op.dtype.types + return sa.func.row( + # we have to cast here, otherwise risingwave refuses to allow the statement + *map(t.translate, map(ops.Cast, op.values, types)), + type_=t.get_sqla_type( + 
dt.Struct({f"f{i:d}": typ for i, typ in enumerate(types, start=1)}) + ), + ) + + +def _unnest(t, op): + arg = op.arg + row_type = arg.dtype.value_type + + types = getattr(row_type, "types", (row_type,)) + + is_struct = row_type.is_struct() + derived = ( + sa.func.unnest(t.translate(arg)) + .table_valued( + *( + sa.column(f"f{i:d}", stype) + for i, stype in enumerate(map(t.get_sqla_type, types), start=1) + ) + ) + .render_derived(with_types=is_struct) + ) + + # wrap in a row column so that we can return a single column from this rule + if not is_struct: + return derived.c[0] + return sa.func.row(*derived.c) + + +def _array_sort(arg): + flat = sa.func.unnest(arg).column_valued() + return sa.func.array(sa.select(flat).order_by(flat).scalar_subquery()) + + +def _array_position(haystack, needle): + t = ( + sa.func.unnest(haystack) + .table_valued("value", with_ordinality="idx", name="haystack") + .render_derived() + ) + idx = t.c.idx - 1 + return sa.func.coalesce( + sa.select(idx).where(t.c.value == needle).limit(1).scalar_subquery(), -1 + ) + + +def _array_map(t, op): + return sa.func.array( + # this translates to the function call, with column names the same as + # the parameter names in the lambda + sa.select(t.translate(op.body)) + .select_from( + # unnest the input array + sa.func.unnest(t.translate(op.arg)) + # name the columns of the result the same as the lambda parameter + # so that we can reference them as such in the outer query + .table_valued(op.param) + .render_derived() + ) + .scalar_subquery() + ) + + +def _array_filter(t, op): + param = op.param + return sa.func.array( + sa.select(sa.column(param, type_=t.get_sqla_type(op.arg.dtype.value_type))) + .select_from( + sa.func.unnest(t.translate(op.arg)).table_valued(param).render_derived() + ) + .where(t.translate(op.body)) + .scalar_subquery() + ) + + +def zero_value(dtype): + if dtype.is_interval(): + return sa.func.make_interval() + return 0 + + +def interval_sign(v): + zero = sa.func.make_interval() + return sa.case((v == zero, 0), (v < zero, -1), (v > zero, 1)) + + +def _sign(value, dtype): + if dtype.is_interval(): + return interval_sign(value) + return sa.func.sign(value) + + +def _range(t, op): + start = t.translate(op.start) + stop = t.translate(op.stop) + step = t.translate(op.step) + satype = t.get_sqla_type(op.dtype) + seq = sa.func.generate_series(start, stop, step, type_=satype) + zero = zero_value(op.step.dtype) + return sa.case( + ( + sa.and_( + sa.func.nullif(step, zero).is_not(None), + _sign(step, op.step.dtype) == _sign(stop - start, op.step.dtype), + ), + sa.func.array_remove( + sa.func.array(sa.select(seq).scalar_subquery()), stop, type_=satype + ), + ), + else_=sa.cast(pg.array([]), satype), + ) + + +operation_registry.update( + { + ops.Literal: _literal, + # We override this here to support time zones + ops.TableColumn: _table_column, + ops.Argument: lambda t, op: sa.column( + op.param, type_=t.get_sqla_type(op.dtype) + ), + # types + ops.TypeOf: _typeof, + # Floating + ops.IsNan: fixed_arity(lambda arg: arg == float("nan"), 1), + ops.IsInf: fixed_arity( + lambda arg: sa.or_(arg == float("inf"), arg == float("-inf")), 1 + ), + # boolean reductions + ops.Any: reduction(sa.func.bool_or), + ops.All: reduction(sa.func.bool_and), + # strings + ops.GroupConcat: _string_agg, + ops.Capitalize: unary(sa.func.initcap), + ops.RegexSearch: fixed_arity(lambda x, y: x.op("~")(y), 2), + # postgres defaults to replacing only the first occurrence + ops.RegexReplace: fixed_arity( + lambda string, pattern, replacement: 
sa.func.regexp_replace( + string, pattern, replacement, "g" + ), + 3, + ), + ops.Translate: fixed_arity(sa.func.translate, 3), + ops.RegexExtract: fixed_arity(_regex_extract, 3), + ops.StringSplit: fixed_arity( + lambda col, sep: sa.func.string_to_array( + col, sep, type_=sa.ARRAY(col.type) + ), + 2, + ), + ops.FindInSet: _find_in_set, + # math + ops.Log: _log, + ops.Log2: unary(lambda x: sa.func.log(2, x)), + ops.Log10: unary(sa.func.log), + ops.Round: _round, + ops.Modulus: _mod, + # dates and times + ops.DateFromYMD: fixed_arity(sa.func.make_date, 3), + ops.DateTruncate: _timestamp_truncate, + ops.TimestampTruncate: _timestamp_truncate, + ops.TimestampBucket: _timestamp_bucket, + ops.IntervalFromInteger: ( + lambda t, op: t.translate(op.arg) + * sa.text(f"INTERVAL '1 {op.dtype.resolution}'") + ), + ops.DateAdd: fixed_arity(operator.add, 2), + ops.DateSub: fixed_arity(operator.sub, 2), + ops.DateDiff: fixed_arity(operator.sub, 2), + ops.TimestampAdd: fixed_arity(operator.add, 2), + ops.TimestampSub: fixed_arity(operator.sub, 2), + ops.TimestampDiff: fixed_arity(operator.sub, 2), + ops.Strftime: fixed_arity(_strftime, 2), + ops.ExtractEpochSeconds: fixed_arity( + lambda arg: sa.cast(sa.extract("epoch", arg), sa.INTEGER), 1 + ), + ops.ExtractDayOfYear: _extract("doy"), + ops.ExtractWeekOfYear: _extract("week"), + # extracting the second gives us the fractional part as well, so smash that + # with a cast to SMALLINT + ops.ExtractSecond: fixed_arity( + lambda arg: sa.cast(sa.func.floor(sa.extract("second", arg)), sa.SMALLINT), + 1, + ), + # we get total number of milliseconds including seconds with extract so we + # mod 1000 + ops.ExtractMillisecond: fixed_arity( + lambda arg: sa.cast( + sa.func.floor(sa.extract("millisecond", arg)) % 1000, + sa.SMALLINT, + ), + 1, + ), + ops.DayOfWeekIndex: fixed_arity( + lambda arg: sa.cast( + sa.cast(sa.extract("dow", arg) + 6, sa.SMALLINT) % 7, sa.SMALLINT + ), + 1, + ), + ops.DayOfWeekName: fixed_arity( + lambda arg: sa.func.trim(sa.func.to_char(arg, "Day")), 1 + ), + ops.TimeFromHMS: fixed_arity(sa.func.make_time, 3), + # array operations + ops.ArrayLength: unary(sa.func.cardinality), + ops.ArrayCollect: reduction(sa.func.array_agg), + ops.Array: (lambda t, op: pg.array(list(map(t.translate, op.exprs)))), + ops.ArraySlice: _array_slice( + index_converter=_neg_idx_to_pos, + array_length=sa.func.cardinality, + func=lambda arg, start, stop: arg[start:stop], + ), + ops.ArrayIndex: _array_index( + index_converter=_neg_idx_to_pos, func=lambda arg, index: arg[index] + ), + ops.ArrayConcat: varargs(lambda *args: functools.reduce(operator.add, args)), + ops.ArrayRepeat: _array_repeat, + ops.Unnest: _unnest, + ops.Covariance: _covar, + ops.Correlation: _corr, + ops.BitwiseXor: _bitwise_op("#"), + ops.Mode: _mode, + ops.ApproxMedian: _median, + ops.Median: _median, + ops.Quantile: _quantile, + ops.MultiQuantile: _quantile, + ops.TimestampNow: lambda t, op: sa.literal_column( + "CURRENT_TIMESTAMP", type_=t.get_sqla_type(op.dtype) + ), + ops.MapGet: fixed_arity( + lambda arg, key, default: sa.case( + (arg.has_key(key), arg[key]), else_=default + ), + 3, + ), + ops.MapContains: fixed_arity(pg.HSTORE.Comparator.has_key, 2), + ops.MapKeys: unary(pg.HSTORE.Comparator.keys), + ops.MapValues: unary(pg.HSTORE.Comparator.vals), + ops.MapMerge: fixed_arity(operator.add, 2), + ops.MapLength: unary(lambda arg: sa.func.cardinality(arg.keys())), + ops.Map: fixed_arity(pg.hstore, 2), + ops.ArgMin: _arg_min_max(sa.asc), + ops.ArgMax: _arg_min_max(sa.desc), + ops.ToJSONArray: 
unary( + lambda arg: sa.case( + ( + sa.func.json_typeof(arg) == "array", + sa.func.array( + sa.select( + sa.func.json_array_elements(arg).column_valued() + ).scalar_subquery() + ), + ), + else_=sa.null(), + ) + ), + ops.ArrayStringJoin: fixed_arity( + lambda sep, arr: sa.func.array_to_string(arr, sep), 2 + ), + ops.Strip: unary(lambda arg: sa.func.trim(arg, string.whitespace)), + ops.LStrip: unary(lambda arg: sa.func.ltrim(arg, string.whitespace)), + ops.RStrip: unary(lambda arg: sa.func.rtrim(arg, string.whitespace)), + ops.StartsWith: fixed_arity(lambda arg, prefix: arg.op("^@")(prefix), 2), + ops.Arbitrary: _arbitrary, + ops.StructColumn: _struct_column, + ops.StructField: _struct_field, + ops.First: reduction(sa.func.first), + ops.Last: reduction(sa.func.last), + ops.ExtractMicrosecond: fixed_arity( + lambda arg: sa.extract("microsecond", arg) % 1_000_000, 1 + ), + ops.Levenshtein: fixed_arity(sa.func.levenshtein, 2), + ops.ArraySort: fixed_arity(_array_sort, 1), + ops.ArrayIntersect: fixed_arity( + lambda left, right: sa.func.array( + sa.intersect( + sa.select(sa.func.unnest(left).column_valued()), + sa.select(sa.func.unnest(right).column_valued()), + ).scalar_subquery() + ), + 2, + ), + ops.ArrayRemove: fixed_arity( + lambda left, right: sa.func.array( + sa.except_( + sa.select(sa.func.unnest(left).column_valued()), sa.select(right) + ).scalar_subquery() + ), + 2, + ), + ops.ArrayUnion: fixed_arity( + lambda left, right: sa.func.array( + sa.union( + sa.select(sa.func.unnest(left).column_valued()), + sa.select(sa.func.unnest(right).column_valued()), + ).scalar_subquery() + ), + 2, + ), + ops.ArrayDistinct: fixed_arity( + lambda arg: sa.case( + (arg.is_(sa.null()), sa.null()), + else_=sa.func.array( + sa.select( + sa.distinct(sa.func.unnest(arg).column_valued()) + ).scalar_subquery() + ), + ), + 1, + ), + ops.ArrayPosition: fixed_arity(_array_position, 2), + ops.ArrayMap: _array_map, + ops.ArrayFilter: _array_filter, + ops.IntegerRange: _range, + ops.TimestampRange: _range, + ops.RegexSplit: fixed_arity(sa.func.regexp_split_to_array, 2), + } +) diff --git a/ibis/backends/risingwave/tests/__init__.py b/ibis/backends/risingwave/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/ibis/backends/risingwave/tests/conftest.py b/ibis/backends/risingwave/tests/conftest.py new file mode 100644 index 000000000000..35cfe6b8e1db --- /dev/null +++ b/ibis/backends/risingwave/tests/conftest.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any + +import pytest +import sqlalchemy as sa + +import ibis +from ibis.backends.conftest import init_database +from ibis.backends.tests.base import ServiceBackendTest + +if TYPE_CHECKING: + from collections.abc import Iterable + from pathlib import Path + +PG_USER = os.environ.get("IBIS_TEST_RISINGWAVE_USER", os.environ.get("PGUSER", "root")) +PG_PASS = os.environ.get( + "IBIS_TEST_RISINGWAVE_PASSWORD", os.environ.get("PGPASSWORD", "") +) +PG_HOST = os.environ.get( + "IBIS_TEST_RISINGWAVE_HOST", os.environ.get("PGHOST", "localhost") +) +PG_PORT = os.environ.get("IBIS_TEST_RISINGWAVE_PORT", os.environ.get("PGPORT", 4566)) +IBIS_TEST_RISINGWAVE_DB = os.environ.get( + "IBIS_TEST_RISINGWAVE_DATABASE", os.environ.get("PGDATABASE", "dev") +) + + +class TestConf(ServiceBackendTest): + # postgres rounds half to even for double precision and half away from zero + # for numeric and decimal + + returned_timestamp_unit = "s" + supports_structs = False + rounding_method = "half_to_even" + 
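As a usage-level illustration of the array rules registered above, a hedged sketch against the array_types table created by ci/schema/risingwave.sql; the connection values are the test-suite defaults and the resulting dataframe is not reproduced here:

    import ibis

    con = ibis.risingwave.connect(
        host="localhost", port=4566, user="root", password="", database="dev"
    )
    t = con.table("array_types")
    # ArrayLength compiles to cardinality(); ArrayIndex shifts to 1-based indexing
    expr = t.mutate(n=t.x.length(), first=t.x[0])
    print(expr.execute())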
service_name = "risingwave" + deps = "psycopg2", "sqlalchemy" + + @property + def test_files(self) -> Iterable[Path]: + return self.data_dir.joinpath("csv").glob("*.csv") + + def _load_data( + self, + *, + user: str = PG_USER, + password: str = PG_PASS, + host: str = PG_HOST, + port: int = PG_PORT, + database: str = IBIS_TEST_RISINGWAVE_DB, + **_: Any, + ) -> None: + """Load test data into a Risingwave backend instance. + + Parameters + ---------- + data_dir + Location of test data + script_dir + Location of scripts defining schemas + """ + init_database( + url=sa.engine.make_url( + f"risingwave://{user}:{password}@{host}:{port:d}/{database}" + ), + database=database, + schema=self.ddl_script, + isolation_level="AUTOCOMMIT", + recreate=False, + ) + + @staticmethod + def connect(*, tmpdir, worker_id, port: int | None = None, **kw): + con = ibis.risingwave.connect( + host=PG_HOST, + port=port or PG_PORT, + user=PG_USER, + password=PG_PASS, + database=IBIS_TEST_RISINGWAVE_DB, + **kw, + ) + cursor = con.raw_sql("SET RW_IMPLICIT_FLUSH TO true;") + cursor.close() + return con + + +@pytest.fixture(scope="session") +def con(tmp_path_factory, data_dir, worker_id): + return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection + + +@pytest.fixture(scope="module") +def db(con): + return con.database() + + +@pytest.fixture(scope="module") +def alltypes(db): + return db.functional_alltypes + + +@pytest.fixture(scope="module") +def df(alltypes): + return alltypes.execute() + + +@pytest.fixture(scope="module") +def alltypes_sqla(con, alltypes): + name = alltypes.op().name + return con._get_sqla_table(name) + + +@pytest.fixture(scope="module") +def intervals(con): + return con.table("intervals") + + +@pytest.fixture +def translate(): + from ibis.backends.risingwave import Backend + + context = Backend.compiler.make_context() + return lambda expr: Backend.compiler.translator_class(expr, context).get_result() diff --git a/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql b/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql new file mode 100644 index 000000000000..cfbcf133a863 --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql @@ -0,0 +1,2 @@ +SELECT sum(t0.foo) AS "Sum(foo)" +FROM t0 AS t0 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/snapshots/test_functions/test_analytic_functions/out.sql b/ibis/backends/risingwave/tests/snapshots/test_functions/test_analytic_functions/out.sql new file mode 100644 index 000000000000..c00dec1bed25 --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_functions/test_analytic_functions/out.sql @@ -0,0 +1,7 @@ +SELECT + RANK() OVER (ORDER BY t0.double_col ASC) - 1 AS rank, + DENSE_RANK() OVER (ORDER BY t0.double_col ASC) - 1 AS dense_rank, + CUME_DIST() OVER (ORDER BY t0.double_col ASC) AS cume_dist, + NTILE(7) OVER (ORDER BY t0.double_col ASC) - 1 AS ntile, + PERCENT_RANK() OVER (ORDER BY t0.double_col ASC) AS percent_rank +FROM functional_alltypes AS t0 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql new file mode 100644 index 000000000000..34761d9a76e0 --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql @@ -0,0 +1 @@ +WITH anon_2 AS (SELECT t2.string_col AS string_col, sum(t2.double_col) AS 
metric FROM functional_alltypes AS t2 GROUP BY 1), anon_3 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1), anon_1 AS (SELECT t2.string_col AS string_col, t2.metric AS metric FROM (SELECT anon_2.string_col AS string_col, anon_2.metric AS metric FROM anon_2 UNION ALL SELECT anon_3.string_col AS string_col, anon_3.metric AS metric FROM anon_3) AS t2), anon_4 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1) SELECT t1.string_col, t1.metric FROM (SELECT anon_1.string_col AS string_col, anon_1.metric AS metric FROM anon_1 UNION ALL SELECT anon_4.string_col AS string_col, anon_4.metric AS metric FROM anon_4) AS t1 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql new file mode 100644 index 000000000000..6ce31e7468bb --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql @@ -0,0 +1 @@ +WITH anon_2 AS (SELECT t2.string_col AS string_col, sum(t2.double_col) AS metric FROM functional_alltypes AS t2 GROUP BY 1), anon_3 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1), anon_1 AS (SELECT t2.string_col AS string_col, t2.metric AS metric FROM (SELECT anon_2.string_col AS string_col, anon_2.metric AS metric FROM anon_2 UNION SELECT anon_3.string_col AS string_col, anon_3.metric AS metric FROM anon_3) AS t2), anon_4 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1) SELECT t1.string_col, t1.metric FROM (SELECT anon_1.string_col AS string_col, anon_1.metric AS metric FROM anon_1 UNION SELECT anon_4.string_col AS string_col, anon_4.metric AS metric FROM anon_4) AS t1 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/test_client.py b/ibis/backends/risingwave/tests/test_client.py new file mode 100644 index 000000000000..b5c7cfa98560 --- /dev/null +++ b/ibis/backends/risingwave/tests/test_client.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import os + +import pandas as pd +import pytest +from pytest import param + +import ibis +import ibis.expr.datatypes as dt +import ibis.expr.types as ir +from ibis.tests.util import assert_equal + +pytest.importorskip("psycopg2") +sa = pytest.importorskip("sqlalchemy") + +from sqlalchemy.dialects import postgresql # noqa: E402 + +RISINGWAVE_TEST_DB = os.environ.get("IBIS_TEST_RISINGWAVE_DATABASE", "dev") +IBIS_RISINGWAVE_HOST = os.environ.get("IBIS_TEST_RISINGWAVE_HOST", "localhost") +IBIS_RISINGWAVE_PORT = os.environ.get("IBIS_TEST_RISINGWAVE_PORT", "4566") +IBIS_RISINGWAVE_USER = os.environ.get("IBIS_TEST_RISINGWAVE_USER", "root") +IBIS_RISINGWAVE_PASS = os.environ.get("IBIS_TEST_RISINGWAVE_PASSWORD", "") + + +def test_table(alltypes): + assert isinstance(alltypes, ir.Table) + + +def test_array_execute(alltypes): + d = alltypes.limit(10).double_col + s = d.execute() + assert isinstance(s, pd.Series) + assert len(s) == 10 + + +def test_literal_execute(con): + expr = ibis.literal("1234") + result = con.execute(expr) + assert result == "1234" + + +def test_simple_aggregate_execute(alltypes): + d = alltypes.double_col.sum() + v = d.execute() + assert isinstance(v, float) + + +def test_list_tables(con): + assert con.list_tables() + assert len(con.list_tables(like="functional")) == 1 + + +def 
test_compile_toplevel(snapshot): + t = ibis.table([("foo", "double")], name="t0") + + expr = t.foo.sum() + result = ibis.postgres.compile(expr) + snapshot.assert_match(str(result), "out.sql") + + +def test_list_databases(con): + assert RISINGWAVE_TEST_DB is not None + assert RISINGWAVE_TEST_DB in con.list_databases() + + +def test_schema_type_conversion(con): + typespec = [ + # name, type, nullable + ("jsonb", postgresql.JSONB, True, dt.JSON), + ] + + sqla_types = [] + ibis_types = [] + for name, t, nullable, ibis_type in typespec: + sqla_types.append(sa.Column(name, t, nullable=nullable)) + ibis_types.append((name, ibis_type(nullable=nullable))) + + # Create a table with placeholder stubs for JSON, JSONB, and UUID. + table = sa.Table("tname", sa.MetaData(), *sqla_types) + + # Check that we can correctly create a schema with dt.any for the + # missing types. + schema = con._schema_from_sqla_table(table) + expected = ibis.schema(ibis_types) + + assert_equal(schema, expected) + + +@pytest.mark.parametrize("params", [{}, {"database": RISINGWAVE_TEST_DB}]) +def test_create_and_drop_table(con, temp_table, params): + sch = ibis.schema( + [ + ("first_name", "string"), + ("last_name", "string"), + ("department_name", "string"), + ("salary", "float64"), + ] + ) + + con.create_table(temp_table, schema=sch, **params) + assert con.table(temp_table, **params) is not None + + con.drop_table(temp_table, **params) + + with pytest.raises(sa.exc.NoSuchTableError): + con.table(temp_table, **params) + + +@pytest.mark.parametrize( + ("pg_type", "expected_type"), + [ + param(pg_type, ibis_type, id=pg_type.lower()) + for (pg_type, ibis_type) in [ + ("boolean", dt.boolean), + ("bytea", dt.binary), + ("bigint", dt.int64), + ("smallint", dt.int16), + ("integer", dt.int32), + ("text", dt.string), + ("real", dt.float32), + ("double precision", dt.float64), + ("character varying", dt.string), + ("date", dt.date), + ("time", dt.time), + ("time without time zone", dt.time), + ("timestamp without time zone", dt.timestamp), + ("timestamp with time zone", dt.Timestamp("UTC")), + ("interval", dt.Interval("s")), + ("numeric", dt.decimal), + ("jsonb", dt.json), + ] + ], +) +def test_get_schema_from_query(con, pg_type, expected_type): + name = con._quote(ibis.util.guid()) + with con.begin() as c: + c.exec_driver_sql(f"CREATE TABLE {name} (x {pg_type}, y {pg_type}[])") + expected_schema = ibis.schema(dict(x=expected_type, y=dt.Array(expected_type))) + result_schema = con._get_schema_using_query(f"SELECT x, y FROM {name}") + assert result_schema == expected_schema + with con.begin() as c: + c.exec_driver_sql(f"DROP TABLE {name}") + + +@pytest.mark.xfail(reason="unsupported insert with CTEs") +def test_insert_with_cte(con): + X = con.create_table("X", schema=ibis.schema(dict(id="int")), temp=False) + expr = X.join(X.mutate(a=X["id"] + 1), ["id"]) + Y = con.create_table("Y", expr, temp=False) + assert Y.execute().empty + con.drop_table("Y") + con.drop_table("X") + + +def test_connect_url_with_empty_host(): + con = ibis.connect("risingwave:///dev") + assert con.con.url.host is None diff --git a/ibis/backends/risingwave/tests/test_functions.py b/ibis/backends/risingwave/tests/test_functions.py new file mode 100644 index 000000000000..c8874e390c60 --- /dev/null +++ b/ibis/backends/risingwave/tests/test_functions.py @@ -0,0 +1,1032 @@ +from __future__ import annotations + +import operator +import string +import warnings +from datetime import datetime + +import numpy as np +import pandas as pd +import pandas.testing as tm +import 
pytest +from pytest import param + +import ibis +import ibis.expr.datatypes as dt +import ibis.expr.types as ir +from ibis import config +from ibis import literal as L + +pytest.importorskip("psycopg2") +sa = pytest.importorskip("sqlalchemy") + +from sqlalchemy.dialects import postgresql # noqa: E402 + + +@pytest.mark.parametrize( + ("left_func", "right_func"), + [ + param( + lambda t: t.double_col.cast("int8"), + lambda at: sa.cast(at.c.double_col, sa.SMALLINT), + id="double_to_int8", + ), + param( + lambda t: t.double_col.cast("int16"), + lambda at: sa.cast(at.c.double_col, sa.SMALLINT), + id="double_to_int16", + ), + param( + lambda t: t.string_col.cast("double"), + lambda at: sa.cast(at.c.string_col, postgresql.DOUBLE_PRECISION), + id="string_to_double", + ), + param( + lambda t: t.string_col.cast("float32"), + lambda at: sa.cast(at.c.string_col, postgresql.REAL), + id="string_to_float", + ), + param( + lambda t: t.string_col.cast("decimal"), + lambda at: sa.cast(at.c.string_col, sa.NUMERIC()), + id="string_to_decimal_no_params", + ), + param( + lambda t: t.string_col.cast("decimal(9, 3)"), + lambda at: sa.cast(at.c.string_col, sa.NUMERIC(9, 3)), + id="string_to_decimal_params", + ), + ], +) +def test_cast(alltypes, alltypes_sqla, translate, left_func, right_func): + left = left_func(alltypes) + right = right_func(alltypes_sqla.alias("t0")) + assert str(translate(left.op()).compile()) == str(right.compile()) + + +def test_date_cast(alltypes, alltypes_sqla, translate): + result = alltypes.date_string_col.cast("date") + expected = sa.cast(alltypes_sqla.alias("t0").c.date_string_col, sa.DATE) + assert str(translate(result.op())) == str(expected) + + +@pytest.mark.parametrize( + "column", + [ + "id", + "bool_col", + "tinyint_col", + "smallint_col", + "int_col", + "bigint_col", + "float_col", + "double_col", + "date_string_col", + "string_col", + "timestamp_col", + "year", + "month", + ], +) +def test_noop_cast(alltypes, alltypes_sqla, translate, column): + col = alltypes[column] + result = col.cast(col.type()) + expected = alltypes_sqla.alias("t0").c[column] + assert result.equals(col) + assert str(translate(result.op())) == str(expected) + + +def test_timestamp_cast_noop(alltypes, alltypes_sqla, translate): + # See GH #592 + result1 = alltypes.timestamp_col.cast("timestamp") + result2 = alltypes.int_col.cast("timestamp") + + assert isinstance(result1, ir.TimestampColumn) + assert isinstance(result2, ir.TimestampColumn) + + expected1 = alltypes_sqla.alias("t0").c.timestamp_col + expected2 = sa.cast( + sa.func.to_timestamp(alltypes_sqla.alias("t0").c.int_col), sa.TIMESTAMP() + ) + + assert str(translate(result1.op())) == str(expected1) + assert str(translate(result2.op())) == str(expected2) + + +@pytest.mark.parametrize(("value", "expected"), [(0, None), (5.5, 5.5)]) +def test_nullif_zero(con, value, expected): + assert con.execute(L(value).nullif(0)) == expected + + +@pytest.mark.parametrize(("value", "expected"), [("foo_bar", 7), ("", 0)]) +def test_string_length(con, value, expected): + assert con.execute(L(value).length()) == expected + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + param(operator.methodcaller("left", 3), "foo", id="left"), + param(operator.methodcaller("right", 3), "bar", id="right"), + param(operator.methodcaller("substr", 0, 3), "foo", id="substr_0_3"), + param(operator.methodcaller("substr", 4, 3), "bar", id="substr_4, 3"), + param(operator.methodcaller("substr", 1), "oo_bar", id="substr_1"), + ], +) +def test_string_substring(con, op, expected): + value = 
L("foo_bar") + assert con.execute(op(value)) == expected + + +@pytest.mark.parametrize( + ("opname", "expected"), + [("lstrip", "foo "), ("rstrip", " foo"), ("strip", "foo")], +) +def test_string_strip(con, opname, expected): + op = operator.methodcaller(opname) + value = L(" foo ") + assert con.execute(op(value)) == expected + + +@pytest.mark.parametrize( + ("opname", "count", "char", "expected"), + [("lpad", 6, " ", " foo"), ("rpad", 6, " ", "foo ")], +) +def test_string_pad(con, opname, count, char, expected): + op = operator.methodcaller(opname, count, char) + value = L("foo") + assert con.execute(op(value)) == expected + + +def test_string_reverse(con): + assert con.execute(L("foo").reverse()) == "oof" + + +def test_string_upper(con): + assert con.execute(L("foo").upper()) == "FOO" + + +def test_string_lower(con): + assert con.execute(L("FOO").lower()) == "foo" + + +@pytest.mark.parametrize( + ("haystack", "needle", "expected"), + [ + ("foobar", "bar", True), + ("foobar", "foo", True), + ("foobar", "baz", False), + ("100%", "%", True), + ("a_b_c", "_", True), + ], +) +def test_string_contains(con, haystack, needle, expected): + value = L(haystack) + expr = value.contains(needle) + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("value", "expected"), + [("foo bar foo", "Foo Bar Foo"), ("foobar Foo", "Foobar Foo")], +) +def test_capitalize(con, value, expected): + assert con.execute(L(value).capitalize()) == expected + + +def test_repeat(con): + expr = L("bar ").repeat(3) + assert con.execute(expr) == "bar bar bar " + + +def test_re_replace(con): + expr = L("fudge|||chocolate||candy").re_replace("\\|{2,3}", ", ") + assert con.execute(expr) == "fudge, chocolate, candy" + + +def test_translate(con): + expr = L("faab").translate("a", "b") + assert con.execute(expr) == "fbbb" + + +@pytest.mark.parametrize( + ("raw_value", "expected"), [("a", 0), ("b", 1), ("d", -1), (None, 3)] +) +def test_find_in_set(con, raw_value, expected): + value = L(raw_value, dt.string) + haystack = ["a", "b", "c", None] + expr = value.find_in_set(haystack) + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("raw_value", "opname", "expected"), + [ + (None, "isnull", True), + (1, "isnull", False), + (None, "notnull", False), + (1, "notnull", True), + ], +) +def test_isnull_notnull(con, raw_value, opname, expected): + lit = L(raw_value) + op = operator.methodcaller(opname) + expr = op(lit) + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(L("foobar").find("bar"), 3, id="find_pos"), + param(L("foobar").find("baz"), -1, id="find_neg"), + param(L("foobar").like("%bar"), True, id="like_left_pattern"), + param(L("foobar").like("foo%"), True, id="like_right_pattern"), + param(L("foobar").like("%baz%"), False, id="like_both_sides_pattern"), + param(L("foobar").like(["%bar"]), True, id="like_list_left_side"), + param(L("foobar").like(["foo%"]), True, id="like_list_right_side"), + param(L("foobar").like(["%baz%"]), False, id="like_list_both_sides"), + param(L("foobar").like(["%bar", "foo%"]), True, id="like_list_multiple"), + param(L("foobarfoo").replace("foo", "H"), "HbarH", id="replace"), + param(L("a").ascii_str(), ord("a"), id="ascii_str"), + ], +) +def test_string_functions(con, expr, expected): + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(L("abcd").re_search("[a-z]"), True, id="re_search_match"), + param(L("abcd").re_search(r"[\d]+"), False, 
id="re_search_no_match"), + param(L("1222").re_search(r"[\d]+"), True, id="re_search_match_number"), + ], +) +def test_regexp(con, expr, expected): + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(ibis.NA.fillna(5), 5, id="filled"), + param(L(5).fillna(10), 5, id="not_filled"), + param(L(5).nullif(5), None, id="nullif_null"), + param(L(10).nullif(5), 10, id="nullif_not_null"), + ], +) +def test_fillna_nullif(con, expr, expected): + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(ibis.coalesce(5, None, 4), 5, id="first"), + param(ibis.coalesce(ibis.NA, 4, ibis.NA), 4, id="second"), + param(ibis.coalesce(ibis.NA, ibis.NA, 3.14), 3.14, id="third"), + ], +) +def test_coalesce(con, expr, expected): + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(ibis.coalesce(ibis.NA, ibis.NA), None, id="all_null"), + param( + ibis.coalesce( + ibis.NA.cast("int8"), + ibis.NA.cast("int8"), + ibis.NA.cast("int8"), + ), + None, + id="all_nulls_with_all_cast", + ), + ], +) +def test_coalesce_all_na(con, expr, expected): + assert con.execute(expr) is None + + +def test_coalesce_all_na_double(con): + expr = ibis.coalesce(ibis.NA, ibis.NA, ibis.NA.cast("double")) + assert np.isnan(con.execute(expr)) + + +def test_numeric_builtins_work(alltypes, df): + expr = alltypes.double_col.fillna(0) + result = expr.execute() + expected = df.double_col.fillna(0) + expected.name = "Coalesce()" + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("op", "pandas_op"), + [ + param( + lambda t: (t.double_col > 20).ifelse(10, -20), + lambda df: pd.Series(np.where(df.double_col > 20, 10, -20), dtype="int8"), + id="simple", + ), + param( + lambda t: (t.double_col > 20).ifelse(10, -20).abs(), + lambda df: pd.Series( + np.where(df.double_col > 20, 10, -20), dtype="int8" + ).abs(), + id="abs", + ), + ], +) +def test_ifelse(alltypes, df, op, pandas_op): + expr = op(alltypes) + result = expr.execute() + result.name = None + expected = pandas_op(df) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("func", "pandas_func"), + [ + # tier and histogram + param( + lambda d: d.bucket([0, 10, 25, 50, 100]), + lambda s: pd.cut(s, [0, 10, 25, 50, 100], right=False, labels=False).astype( + "int8" + ), + id="include_over_false", + ), + param( + lambda d: d.bucket([0, 10, 25, 50], include_over=True), + lambda s: pd.cut( + s, [0, 10, 25, 50, np.inf], right=False, labels=False + ).astype("int8"), + id="include_over_true", + ), + param( + lambda d: d.bucket([0, 10, 25, 50], close_extreme=False), + lambda s: pd.cut(s, [0, 10, 25, 50], right=False, labels=False), + id="close_extreme_false", + ), + param( + lambda d: d.bucket([0, 10, 25, 50], closed="right", close_extreme=False), + lambda s: pd.cut( + s, + [0, 10, 25, 50], + include_lowest=False, + right=True, + labels=False, + ), + id="closed_right", + ), + param( + lambda d: d.bucket([10, 25, 50, 100], include_under=True), + lambda s: pd.cut(s, [0, 10, 25, 50, 100], right=False, labels=False).astype( + "int8" + ), + id="include_under_true", + ), + ], +) +def test_bucket(alltypes, df, func, pandas_func): + expr = func(alltypes.double_col) + result = expr.execute() + expected = pandas_func(df.double_col) + tm.assert_series_equal(result, expected, check_names=False) + + +def test_category_label(alltypes, df): + t = alltypes + d = t.double_col + + bins = [0, 10, 25, 50, 100] + labels = ["a", "b", 
"c", "d"] + bucket = d.bucket(bins) + expr = bucket.label(labels) + result = expr.execute() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + result = pd.Series(pd.Categorical(result, ordered=True)) + + result.name = "double_col" + + expected = pd.cut(df.double_col, bins, labels=labels, right=False) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("distinct", [True, False]) +def test_union_cte(alltypes, distinct, snapshot): + t = alltypes + expr1 = t.group_by(t.string_col).aggregate(metric=t.double_col.sum()) + expr2 = expr1.view() + expr3 = expr1.view() + expr = expr1.union(expr2, distinct=distinct).union(expr3, distinct=distinct) + result = " ".join( + line.strip() + for line in str( + expr.compile().compile(compile_kwargs={"literal_binds": True}) + ).splitlines() + ) + snapshot.assert_match(result, "out.sql") + + +@pytest.mark.parametrize( + ("func", "pandas_func"), + [ + param( + lambda t, cond: t.bool_col.count(), + lambda df, cond: df.bool_col.count(), + id="count", + ), + param( + lambda t, cond: t.double_col.mean(), + lambda df, cond: df.double_col.mean(), + id="mean", + ), + param( + lambda t, cond: t.double_col.min(), + lambda df, cond: df.double_col.min(), + id="min", + ), + param( + lambda t, cond: t.double_col.max(), + lambda df, cond: df.double_col.max(), + id="max", + ), + param( + lambda t, cond: t.double_col.var(), + lambda df, cond: df.double_col.var(), + id="var", + ), + param( + lambda t, cond: t.double_col.std(), + lambda df, cond: df.double_col.std(), + id="std", + ), + param( + lambda t, cond: t.double_col.var(how="sample"), + lambda df, cond: df.double_col.var(ddof=1), + id="samp_var", + ), + param( + lambda t, cond: t.double_col.std(how="pop"), + lambda df, cond: df.double_col.std(ddof=0), + id="pop_std", + ), + param( + lambda t, cond: t.bool_col.count(where=cond), + lambda df, cond: df.bool_col[cond].count(), + id="count_where", + ), + param( + lambda t, cond: t.double_col.mean(where=cond), + lambda df, cond: df.double_col[cond].mean(), + id="mean_where", + ), + param( + lambda t, cond: t.double_col.min(where=cond), + lambda df, cond: df.double_col[cond].min(), + id="min_where", + ), + param( + lambda t, cond: t.double_col.max(where=cond), + lambda df, cond: df.double_col[cond].max(), + id="max_where", + ), + param( + lambda t, cond: t.double_col.var(where=cond), + lambda df, cond: df.double_col[cond].var(), + id="var_where", + ), + param( + lambda t, cond: t.double_col.std(where=cond), + lambda df, cond: df.double_col[cond].std(), + id="std_where", + ), + param( + lambda t, cond: t.double_col.var(where=cond, how="sample"), + lambda df, cond: df.double_col[cond].var(), + id="samp_var_where", + ), + param( + lambda t, cond: t.double_col.std(where=cond, how="pop"), + lambda df, cond: df.double_col[cond].std(ddof=0), + id="pop_std_where", + ), + ], +) +def test_aggregations(alltypes, df, func, pandas_func): + table = alltypes.limit(100) + df = df.head(table.count().execute()) + + cond = table.string_col.isin(["1", "7"]) + expr = func(table, cond) + result = expr.execute() + expected = pandas_func(df, cond.execute()) + + np.testing.assert_allclose(result, expected) + + +def test_not_contains(alltypes, df): + n = 100 + table = alltypes.limit(n) + expr = table.string_col.notin(["1", "7"]) + result = expr.execute() + expected = ~df.head(n).string_col.isin(["1", "7"]) + tm.assert_series_equal(result, expected, check_names=False) + + +def test_group_concat(alltypes, df): + expr = alltypes.string_col.group_concat() + result = 
expr.execute() + expected = ",".join(df.string_col.dropna()) + assert result == expected + + +def test_distinct_aggregates(alltypes, df): + expr = alltypes.limit(100).double_col.nunique() + result = expr.execute() + assert result == df.head(100).double_col.nunique() + + +def test_not_exists(alltypes, df): + t = alltypes + t2 = t.view() + + expr = t[~((t.string_col == t2.string_col).any())] + result = expr.execute() + + left, right = df, t2.execute() + expected = left[left.string_col != right.string_col] + + tm.assert_frame_equal(result, expected, check_index_type=False, check_dtype=False) + + +def test_interactive_repr_shows_error(alltypes): + # #591. Doing this in Postgres because so many built-in functions are + # not available + + expr = alltypes.int_col.convert_base(10, 2) + + with config.option_context("interactive", True): + result = repr(expr) + + assert "no translation rule" in result.lower() + + +def test_subquery(alltypes, df): + t = alltypes + + expr = t.mutate(d=t.double_col.fillna(0)).limit(1000).group_by("string_col").size() + result = expr.execute().sort_values("string_col").reset_index(drop=True) + expected = ( + df.assign(d=df.double_col.fillna(0)) + .head(1000) + .groupby("string_col") + .string_col.count() + .rename("CountStar()") + .reset_index() + .sort_values("string_col") + .reset_index(drop=True) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["mean", "sum", "min", "max"]) +def test_simple_window(alltypes, func, df): + t = alltypes + f = getattr(t.double_col, func) + df_f = getattr(df.double_col, func) + result = t.select((t.double_col - f()).name("double_col")).execute().double_col + expected = df.double_col - df_f() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["mean", "sum", "min", "max"]) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_rolling_window(alltypes, func, df): + t = alltypes + df = ( + df[["double_col", "timestamp_col"]] + .sort_values("timestamp_col") + .reset_index(drop=True) + ) + window = ibis.window(order_by=t.timestamp_col, preceding=6, following=0) + f = getattr(t.double_col, func) + df_f = getattr(df.double_col.rolling(7, min_periods=0), func) + result = t.select(f().over(window).name("double_col")).execute().double_col + expected = df_f() + tm.assert_series_equal(result, expected) + + +def test_rolling_window_with_mlb(alltypes): + t = alltypes + window = ibis.trailing_window( + preceding=ibis.rows_with_max_lookback(3, ibis.interval(days=5)), + order_by=t.timestamp_col, + ) + expr = t["double_col"].sum().over(window) + with pytest.raises(NotImplementedError): + expr.execute() + + +@pytest.mark.parametrize("func", ["mean", "sum", "min", "max"]) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_partitioned_window(alltypes, func, df): + t = alltypes + window = ibis.window( + group_by=t.string_col, + order_by=t.timestamp_col, + preceding=6, + following=0, + ) + + def roller(func): + def rolled(df): + torder = df.sort_values("timestamp_col") + rolling = torder.double_col.rolling(7, min_periods=0) + return getattr(rolling, func)() + + return rolled + + f = getattr(t.double_col, func) + expr = f().over(window).name("double_col") + result = t.select(expr).execute().double_col + expected = df.groupby("string_col").apply(roller(func)).reset_index(drop=True) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["sum", "min", "max"]) 
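+# The xfail on this test (and on the other window-function tests in this module) reflects
+# that RisingWave, at the time of this patch, rejects window calls whose OVER clause has no
+# PARTITION BY ("Window function with empty PARTITION BY is not supported yet"). An
+# illustrative, unverified example of a query that would hit this limitation:
+#   SELECT sum(double_col) OVER (ORDER BY timestamp_col) FROM functional_alltypes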
+@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_cumulative_simple_window(alltypes, func, df): + t = alltypes + f = getattr(t.double_col, func) + col = t.double_col - f().over(ibis.cumulative_window()) + expr = t.select(col.name("double_col")) + result = expr.execute().double_col + expected = df.double_col - getattr(df.double_col, "cum%s" % func)() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["sum", "min", "max"]) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_cumulative_ordered_window(alltypes, func, df): + t = alltypes + df = df.sort_values("timestamp_col").reset_index(drop=True) + window = ibis.cumulative_window(order_by=t.timestamp_col) + f = getattr(t.double_col, func) + expr = t.select((t.double_col - f().over(window)).name("double_col")) + result = expr.execute().double_col + expected = df.double_col - getattr(df.double_col, "cum%s" % func)() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("func", "shift_amount"), [("lead", -1), ("lag", 1)], ids=["lead", "lag"] +) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_analytic_shift_functions(alltypes, df, func, shift_amount): + method = getattr(alltypes.double_col, func) + expr = method(1) + result = expr.execute().rename("double_col") + expected = df.double_col.shift(shift_amount) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("func", "expected_index"), [("first", -1), ("last", 0)], ids=["first", "last"] +) +@pytest.mark.xfail(reason="Unsupported expr: (first(t0.double_col) + 1) - 1") +def test_first_last_value(alltypes, df, func, expected_index): + col = alltypes.order_by(ibis.desc(alltypes.string_col)).double_col + method = getattr(col, func) + # test that we traverse into expression trees + expr = (1 + method()) - 1 + result = expr.execute() + expected = df.double_col.iloc[expected_index] + assert result == expected + + +def test_null_column(alltypes): + t = alltypes + nrows = t.count().execute() + expr = t.mutate(na_column=ibis.NA).na_column + result = expr.execute() + tm.assert_series_equal(result, pd.Series([None] * nrows, name="na_column")) + + +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_window_with_arithmetic(alltypes, df): + t = alltypes + w = ibis.window(order_by=t.timestamp_col) + expr = t.mutate(new_col=ibis.row_number().over(w) / 2) + + df = df[["timestamp_col"]].sort_values("timestamp_col").reset_index(drop=True) + expected = df.assign(new_col=[x / 2.0 for x in range(len(df))]) + result = expr["timestamp_col", "new_col"].execute() + tm.assert_frame_equal(result, expected) + + +def test_anonymous_aggregate(alltypes, df): + t = alltypes + expr = t[t.double_col > t.double_col.mean()] + result = expr.execute() + expected = df[df.double_col > df.double_col.mean()].reset_index(drop=True) + tm.assert_frame_equal(result, expected) + + +@pytest.fixture +def array_types(con): + return con.table("array_types") + + +@pytest.mark.xfail( + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype" +) +def test_array_length(array_types): + expr = array_types.select( + array_types.x.length().name("x_length"), + array_types.y.length().name("y_length"), + array_types.z.length().name("z_length"), + ) + result = expr.execute() + expected = pd.DataFrame( + { + "x_length": [3, 2, 2, 
3, 3, 4], + "y_length": [3, 2, 2, 3, 3, 4], + "z_length": [3, 2, 2, 0, None, 4], + } + ) + result_sorted = result.sort_values( + by=["x_length", "y_length", "z_length"], na_position="first" + ).reset_index(drop=True) + expected_sorted = expected.sort_values( + by=["x_length", "y_length", "z_length"], na_position="first" + ).reset_index(drop=True) + tm.assert_frame_equal(result_sorted, expected_sorted) + + +def custom_sort_none_first(arr): + return sorted(arr, key=lambda x: (x is not None, x)) + + +def test_head(con): + t = con.table("functional_alltypes") + result = t.head().execute() + expected = t.limit(5).execute() + tm.assert_frame_equal(result, expected) + + +def test_identical_to(con, df): + # TODO: abstract this testing logic out into parameterized fixtures + t = con.table("functional_alltypes") + dt = df[["tinyint_col", "double_col"]] + expr = t.tinyint_col.identical_to(t.double_col) + result = expr.execute() + expected = (dt.tinyint_col.isnull() & dt.double_col.isnull()) | ( + dt.tinyint_col == dt.double_col + ) + expected.name = result.name + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("opname", ["invert", "neg"]) +def test_not_and_negate_bool(con, opname, df): + op = getattr(operator, opname) + t = con.table("functional_alltypes").limit(10) + expr = t.select(op(t.bool_col).name("bool_col")) + result = expr.execute().bool_col + expected = op(df.head(10).bool_col) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "field", + [ + "tinyint_col", + "smallint_col", + "int_col", + "bigint_col", + "float_col", + "double_col", + "year", + "month", + ], +) +def test_negate_non_boolean(con, field, df): + t = con.table("functional_alltypes").limit(10) + expr = t.select((-t[field]).name(field)) + result = expr.execute()[field] + expected = -df.head(10)[field] + tm.assert_series_equal(result, expected) + + +def test_negate_boolean(con, df): + t = con.table("functional_alltypes").limit(10) + expr = t.select((-t.bool_col).name("bool_col")) + result = expr.execute().bool_col + expected = -df.head(10).bool_col + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("opname", ["sum", "mean", "min", "max", "std", "var"]) +def test_boolean_reduction(alltypes, opname, df): + op = operator.methodcaller(opname) + expr = op(alltypes.bool_col) + result = expr.execute() + assert result == op(df.bool_col) + + +def test_timestamp_with_timezone(con): + t = con.table("tzone") + result = t.ts.execute() + assert str(result.dtype.tz) + + +@pytest.fixture( + params=[ + None, + "UTC", + "America/New_York", + "America/Los_Angeles", + "Europe/Paris", + "Chile/Continental", + "Asia/Tel_Aviv", + "Asia/Tokyo", + "Africa/Nairobi", + "Australia/Sydney", + ] +) +def tz(request): + return request.param + + +@pytest.fixture +def tzone_compute(con, temp_table, tz): + schema = ibis.schema([("ts", dt.Timestamp(tz)), ("b", "double"), ("c", "string")]) + con.create_table(temp_table, schema=schema, temp=False) + t = con.table(temp_table) + + n = 10 + df = pd.DataFrame( + { + "ts": pd.date_range("2017-04-01", periods=n, tz=tz).values, + "b": np.arange(n).astype("float64"), + "c": list(string.ascii_lowercase[:n]), + } + ) + + df.to_sql( + temp_table, + con.con, + index=False, + if_exists="append", + dtype={"ts": sa.TIMESTAMP(timezone=True), "b": sa.FLOAT, "c": sa.TEXT}, + ) + + yield t + con.drop_table(temp_table) + + +def test_ts_timezone_is_preserved(tzone_compute, tz): + assert dt.Timestamp(tz).equals(tzone_compute.ts.type()) + + +def 
test_timestamp_with_timezone_select(tzone_compute, tz): + ts = tzone_compute.ts.execute() + assert str(getattr(ts.dtype, "tz", None)) == str(tz) + + +@pytest.mark.parametrize( + ("left", "right", "type"), + [ + param( + L("2017-04-01 01:02:33"), + datetime(2017, 4, 1, 1, 3, 34), + dt.timestamp, + id="ibis_timestamp", + ), + param( + datetime(2017, 4, 1, 1, 3, 34), + L("2017-04-01 01:02:33"), + dt.timestamp, + id="python_datetime", + ), + ], +) +@pytest.mark.parametrize("opname", ["eq", "ne", "lt", "le", "gt", "ge"]) +def test_string_temporal_compare(con, opname, left, right, type): + op = getattr(operator, opname) + expr = op(left, right) + result = con.execute(expr) + left_raw = con.execute(L(left).cast(type)) + right_raw = con.execute(L(right).cast(type)) + expected = op(left_raw, right_raw) + assert result == expected + + +@pytest.mark.parametrize( + ("left", "right"), + [ + param( + L("2017-03-31 00:02:33").cast(dt.timestamp), + datetime(2017, 4, 1, 1, 3, 34), + id="ibis_timestamp", + ), + param( + datetime(2017, 3, 31, 0, 2, 33), + L("2017-04-01 01:03:34").cast(dt.timestamp), + id="python_datetime", + ), + ], +) +@pytest.mark.parametrize( + "op", + [ + param( + lambda left, right: ibis.timestamp("2017-04-01 00:02:34").between( + left, right + ), + id="timestamp", + ), + param( + lambda left, right: ( + ibis.timestamp("2017-04-01").cast(dt.date).between(left, right) + ), + id="date", + ), + ], +) +def test_string_temporal_compare_between(con, op, left, right): + expr = op(left, right) + result = con.execute(expr) + assert isinstance(result, (bool, np.bool_)) + assert result + + +@pytest.mark.xfail( + reason="function make_date(integer, integer, integer) does not exist" +) +def test_scalar_parameter(con): + start_string, end_string = "2009-03-01", "2010-07-03" + + start = ibis.param(dt.date) + end = ibis.param(dt.date) + t = con.table("functional_alltypes") + col = t.date_string_col.cast("date") + expr = col.between(start, end).name("res") + expected_expr = col.between(start_string, end_string).name("res") + + result = expr.execute(params={start: start_string, end: end_string}) + expected = expected_expr.execute() + tm.assert_series_equal(result, expected) + + +def test_string_to_binary_cast(con): + t = con.table("functional_alltypes").limit(10) + expr = t.string_col.cast("binary") + result = expr.execute() + name = expr.get_name() + sql_string = ( + f"SELECT decode(string_col, 'escape') AS \"{name}\" " + "FROM functional_alltypes LIMIT 10" + ) + with con.begin() as c: + cur = c.exec_driver_sql(sql_string) + raw_data = [row[0][0] for row in cur] + expected = pd.Series(raw_data, name=name) + tm.assert_series_equal(result, expected) + + +def test_string_to_binary_round_trip(con): + t = con.table("functional_alltypes").limit(10) + expr = t.string_col.cast("binary").cast("string") + result = expr.execute() + name = expr.get_name() + sql_string = ( + "SELECT encode(decode(string_col, 'escape'), 'escape') AS " + f'"{name}"' + "FROM functional_alltypes LIMIT 10" + ) + with con.begin() as c: + cur = c.exec_driver_sql(sql_string) + expected = pd.Series([row[0][0] for row in cur], name=name) + tm.assert_series_equal(result, expected) diff --git a/ibis/backends/risingwave/tests/test_json.py b/ibis/backends/risingwave/tests/test_json.py new file mode 100644 index 000000000000..18edda8e3741 --- /dev/null +++ b/ibis/backends/risingwave/tests/test_json.py @@ -0,0 +1,17 @@ +"""Tests for json data types.""" +from __future__ import annotations + +import json + +import pytest +from pytest import param + 
+import ibis + + +@pytest.mark.parametrize("data", [param({"status": True}, id="status")]) +def test_json(data, alltypes): + lit = ibis.literal(json.dumps(data), type="json").name("tmp") + expr = alltypes[[alltypes.id, lit]].head(1) + df = expr.execute() + assert df["tmp"].iloc[0] == data diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 6a44fe6f0bf2..b3a6b197b988 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -42,6 +42,7 @@ def mean_udf(s): "bigquery", "datafusion", "postgres", + "risingwave", "clickhouse", "impala", "duckdb", @@ -200,6 +201,7 @@ def test_aggregate_grouped(backend, alltypes, df, result_fn, expected_fn): "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", @@ -539,39 +541,51 @@ def mean_and_std(v): lambda t, where: t.double_col.arbitrary(where=where), lambda t, where: t.double_col[where].iloc[0], id="arbitrary_default", - marks=pytest.mark.notimpl( - [ - "impala", - "mysql", - "polars", - "datafusion", - "mssql", - "druid", - "oracle", - "exasol", - "flink", - ], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + [ + "impala", + "mysql", + "polars", + "datafusion", + "mssql", + "druid", + "oracle", + "exasol", + "flink", + ], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), + ], ), param( lambda t, where: t.double_col.arbitrary(how="first", where=where), lambda t, where: t.double_col[where].iloc[0], id="arbitrary_first", - marks=pytest.mark.notimpl( - [ - "impala", - "mysql", - "polars", - "datafusion", - "mssql", - "druid", - "oracle", - "exasol", - "flink", - ], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + [ + "impala", + "mysql", + "polars", + "datafusion", + "mssql", + "druid", + "oracle", + "exasol", + "flink", + ], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), + ], ), param( lambda t, where: t.double_col.arbitrary(how="last", where=where), @@ -597,6 +611,10 @@ def mean_and_std(v): raises=com.UnsupportedOperationError, reason="backend only supports the `first` option for `.arbitrary()", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), ], ), param( @@ -623,7 +641,14 @@ def mean_and_std(v): raises=com.OperationNotDefinedError, ), pytest.mark.notimpl( - ["bigquery", "duckdb", "postgres", "pyspark", "trino"], + [ + "bigquery", + "duckdb", + "postgres", + "risingwave", + "pyspark", + "trino", + ], raises=com.UnsupportedOperationError, reason="how='heavy' not supported in the backend", ), @@ -638,19 +663,31 @@ def mean_and_std(v): lambda t, where: t.double_col.first(where=where), lambda t, where: t.double_col[where].iloc[0], id="first", - marks=pytest.mark.notimpl( - ["dask", "druid", "impala", "mssql", "mysql", "oracle", "flink"], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + ["dask", "druid", "impala", "mssql", "mysql", "oracle", "flink"], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), + ], ), param( lambda t, where: t.double_col.last(where=where), lambda t, where: t.double_col[where].iloc[-1], id="last", - marks=pytest.mark.notimpl( - ["dask", "druid", "impala", "mssql", "mysql", "oracle", "flink"], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + ["dask", "druid", 
"impala", "mssql", "mysql", "oracle", "flink"], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), + ], ), param( lambda t, where: t.bigint_col.bit_and(where=where), @@ -947,6 +984,11 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond): reason="backend doesn't implement approximate quantiles yet", raises=com.OperationNotDefinedError, ), + pytest.mark.broken( + ["risingwave"], + reason="Invalid input syntax: direct arg in `percentile_cont` must be castable to float64", + raises=sa.exc.InternalError, + ), ], ), ], @@ -995,6 +1037,11 @@ def test_quantile( ["mysql", "impala", "sqlite", "flink"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -1010,6 +1057,11 @@ def test_quantile( ["mysql", "impala", "sqlite", "flink"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -1035,6 +1087,11 @@ def test_quantile( raises=ValueError, reason="PySpark only implements sample correlation", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -1056,7 +1113,14 @@ def test_quantile( reason="Correlation with how='sample' is not supported.", ), pytest.mark.notyet( - ["trino", "postgres", "duckdb", "snowflake", "oracle"], + [ + "trino", + "postgres", + "risingwave", + "duckdb", + "snowflake", + "oracle", + ], raises=ValueError, reason="XXXXSQLExprTranslator only implements population correlation coefficient", ), @@ -1079,6 +1143,11 @@ def test_quantile( ["mysql", "impala", "sqlite", "flink"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -1108,6 +1177,11 @@ def test_quantile( raises=ValueError, reason="PySpark only implements sample correlation", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), ], @@ -1468,6 +1542,7 @@ def test_topk_filter_op(alltypes, df, result_fn, expected_fn): "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", @@ -1508,6 +1583,7 @@ def test_aggregate_list_like(backend, alltypes, df, agg_fn): "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 985d0911371b..28ec720f0237 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -104,6 +104,11 @@ def test_array_concat_variadic(con): raises=sa.exc.ProgrammingError, reason="backend can't infer the type of an empty array", ) +@pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: cannot determine type of empty array", +) def test_array_concat_some_empty(con): left = ibis.literal([]) right = ibis.literal([2, 1]) @@ -176,6 +181,11 @@ def test_array_index(con, idx): reason="backend does not support nullable nested types", raises=AssertionError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) 
@pytest.mark.never( ["bigquery"], reason="doesn't support arrays of arrays", raises=AssertionError ) @@ -207,6 +217,11 @@ def test_array_discovery(backend): ) @pytest.mark.notimpl(["dask"], raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_simple(backend): array_types = backend.array_types expected = ( @@ -224,6 +239,11 @@ def test_unnest_simple(backend): @builtin_array @pytest.mark.notimpl("dask", raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_complex(backend): array_types = backend.array_types df = array_types.execute() @@ -262,6 +282,11 @@ def test_unnest_complex(backend): ) @pytest.mark.notimpl(["dask"], raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_idempotent(backend): array_types = backend.array_types df = array_types.execute() @@ -283,6 +308,11 @@ def test_unnest_idempotent(backend): @builtin_array @pytest.mark.notimpl("dask", raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_no_nulls(backend): array_types = backend.array_types df = array_types.execute() @@ -310,6 +340,11 @@ def test_unnest_no_nulls(backend): @builtin_array @pytest.mark.notimpl("dask", raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_default_name(backend): array_types = backend.array_types df = array_types.execute() @@ -345,6 +380,11 @@ def test_unnest_default_name(backend): ["datafusion", "flink"], raises=Exception, reason="array_types table isn't defined" ) @pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_array_slice(backend, start, stop): array_types = backend.array_types expr = array_types.select(sliced=array_types.y[start:stop]) @@ -388,7 +428,13 @@ def test_array_slice(backend, start, stop): param({"a": [[1, 2], [4]]}, {"a": [[2, 3], [5]]}, id="no_nulls"), ], ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) def test_array_map(backend, con, input, output): + t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) expected = pd.DataFrame(output) @@ -452,6 +498,11 @@ def test_array_filter(backend, con, input, output): ) @pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError) @pytest.mark.never(["impala"], reason="array_types 
table isn't defined") +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_array_contains(backend, con): t = backend.array_types expr = t.x.contains(1) @@ -481,6 +532,11 @@ def test_array_position(backend, con): ["dask", "impala", "mssql", "pandas", "polars"], raises=com.OperationNotDefinedError, ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) def test_array_remove(backend, con): t = ibis.memtable({"a": [[3, 2], [], [42, 2], [2, 2], []]}) expr = t.a.remove(2) @@ -512,6 +568,11 @@ def test_array_remove(backend, con): raises=(AssertionError, GoogleBadRequest), reason="bigquery doesn't support null elements in arrays", ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) @pytest.mark.parametrize( ("input", "expected"), [ @@ -540,6 +601,11 @@ def test_array_unique(backend, con, input, expected): ["dask", "datafusion", "impala", "mssql", "pandas", "polars"], raises=com.OperationNotDefinedError, ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14735", +) def test_array_sort(backend, con): t = ibis.memtable({"a": [[3, 2], [], [42, 42], []]}) expr = t.a.sort() @@ -576,6 +642,11 @@ def test_array_union(con): @pytest.mark.notimpl( ["sqlite"], raises=NotImplementedError, reason="Unsupported type: Array..." ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) @pytest.mark.parametrize( "data", [ @@ -613,7 +684,7 @@ def test_array_intersect(con, data): raises=ClickHouseDatabaseError, reason="ClickHouse won't accept dicts for struct type values", ) -@pytest.mark.notimpl(["postgres"], raises=sa.exc.ProgrammingError) +@pytest.mark.notimpl(["postgres", "risingwave"], raises=sa.exc.ProgrammingError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_unnest_struct(con): data = {"value": [[{"a": 1}, {"a": 2}], [{"a": 3}, {"a": 4}]]} @@ -631,9 +702,23 @@ def test_unnest_struct(con): ["impala", "mssql"], raises=com.OperationNotDefinedError, reason="no array support" ) @pytest.mark.notimpl( - ["dask", "datafusion", "druid", "oracle", "pandas", "polars", "postgres"], + [ + "dask", + "datafusion", + "druid", + "oracle", + "pandas", + "polars", + "postgres", + "risingwave", + ], raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_zip(backend): t = backend.array_types @@ -658,7 +743,7 @@ def test_zip(backend): raises=ClickHouseDatabaseError, reason="https://github.com/ClickHouse/ClickHouse/issues/41112", ) -@pytest.mark.notimpl(["postgres"], raises=sa.exc.ProgrammingError) +@pytest.mark.notimpl(["postgres", "risingwave"], raises=sa.exc.ProgrammingError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["polars"], @@ -713,7 +798,7 @@ def flatten_data(): ["bigquery"], reason="BigQuery doesn't support arrays of arrays", raises=TypeError ) @pytest.mark.notyet( - ["postgres"], + ["postgres", "risingwave"], reason="Postgres doesn't truly support arrays of arrays", raises=com.OperationNotDefinedError, ) @@ -784,6 +869,7 @@ def test_range_single_argument(con, n): @pytest.mark.notimpl( ["polars", "flink", 
"pandas", "dask"], raises=com.OperationNotDefinedError ) +@pytest.mark.skip("risingwave") def test_range_single_argument_unnest(backend, con, n): expr = ibis.range(n).unnest() result = con.execute(expr) @@ -830,6 +916,11 @@ def test_range_start_stop_step(con, start, stop, step): ["datafusion"], raises=com.OperationNotDefinedError, reason="not supported upstream" ) @pytest.mark.notimpl(["flink", "pandas", "dask"], raises=com.OperationNotDefinedError) +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Invalid parameter step: step size cannot equal zero", +) def test_range_start_stop_step_zero(con, start, stop): expr = ibis.range(start, stop, 0) result = con.execute(expr) @@ -956,6 +1047,11 @@ def swap(token): ibis.interval(hours=1), "1H", id="pos", + marks=pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_interval() does not exist", + ), ), param( datetime(2017, 1, 2), @@ -966,7 +1062,12 @@ def swap(token): marks=[ pytest.mark.broken( ["polars"], raises=AssertionError, reason="returns an empty array" - ) + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function neg(interval) does not exist", + ), ], ), param( @@ -982,6 +1083,11 @@ def swap(token): ["clickhouse", "pyspark", "snowflake"], raises=com.UnsupportedOperationError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function neg(interval) does not exist", + ), ], ), ], @@ -1007,7 +1113,14 @@ def test_timestamp_range(con, start, stop, step, freq, tzinfo): datetime(2017, 1, 2, tzinfo=pytz.UTC), ibis.interval(hours=0), id="pos", - marks=[pytest.mark.notyet(["polars"], raises=PolarsComputeError)], + marks=[ + pytest.mark.notyet(["polars"], raises=PolarsComputeError), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_interval() does not exist", + ), + ], ), param( datetime(2017, 1, 1, tzinfo=pytz.UTC), @@ -1021,6 +1134,11 @@ def test_timestamp_range(con, start, stop, step, freq, tzinfo): ["clickhouse", "pyspark", "snowflake"], raises=com.UnsupportedOperationError, ), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function neg(interval) does not exist", + ), ], ), ], @@ -1053,6 +1171,11 @@ def test_repr_timestamp_array(con, monkeypatch): ["dask", "datafusion", "flink", "pandas", "polars"], raises=com.OperationNotDefinedError, ) +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.OperationalError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14734", +) def test_unnest_range(con): expr = ibis.range(2).unnest().name("x").as_table().mutate({"y": 1.0}) result = con.execute(expr) diff --git a/ibis/backends/tests/test_binary.py b/ibis/backends/tests/test_binary.py index 310c80bcca00..02445e9c9ff7 100644 --- a/ibis/backends/tests/test_binary.py +++ b/ibis/backends/tests/test_binary.py @@ -16,6 +16,7 @@ "sqlite": "blob", "trino": "STRING", "postgres": "bytea", + "risingwave": "bytea", "flink": "BINARY(1) NOT NULL", } diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 467a47f6cf09..e4402e94b2ad 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -256,6 +256,11 @@ def tmpcon(alchemy_con): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + 
reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_create_temporary_table_from_schema(tmpcon, new_schema): temp_table = f"_{guid()}" table = tmpcon.create_table(temp_table, schema=new_schema, temp=True) @@ -288,6 +293,7 @@ def test_create_temporary_table_from_schema(tmpcon, new_schema): "pandas", "polars", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -318,6 +324,11 @@ def test_rename_table(con, temp_table, temp_table_orig): raises=com.IbisError, reason="`tbl_properties` is required when creating table with schema", ) +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason='Feature is not yet implemented: column constraints "NOT NULL"', +) def test_nullable_input_output(con, temp_table): sch = ibis.schema( [("foo", "int64"), ("bar", dt.int64(nullable=False)), ("baz", "boolean")] @@ -363,7 +374,7 @@ def test_create_drop_view(ddl_con, temp_view): assert set(t_expr.schema().names) == set(v_expr.schema().names) -@mark.notimpl(["postgres", "polars"]) +@mark.notimpl(["postgres", "risingwave", "polars"]) @mark.notimpl( ["datafusion"], raises=NotImplementedError, @@ -578,6 +589,7 @@ def test_list_databases(alchemy_con): test_databases = { "sqlite": {"main"}, "postgres": {"postgres", "ibis_testing"}, + "risingwave": {"dev"}, "mssql": {"ibis_testing"}, "mysql": {"ibis_testing", "information_schema"}, "duckdb": {"memory"}, @@ -590,7 +602,7 @@ def test_list_databases(alchemy_con): @pytest.mark.never( - ["bigquery", "postgres", "mssql", "mysql", "snowflake", "oracle"], + ["bigquery", "postgres", "risingwave", "mssql", "mysql", "snowflake", "oracle"], reason="backend does not support client-side in-memory tables", raises=(sa.exc.OperationalError, TypeError, sa.exc.InterfaceError), ) @@ -663,6 +675,11 @@ def test_unsigned_integer_type(alchemy_con, alchemy_temp_table): marks=mark.postgres, id="postgresql", ), + param( + "postgresql://root:@localhost:4566/dev", + marks=mark.risingwave, + id="risingwave", + ), param( "pyspark://?spark.app.name=test-pyspark", marks=[ @@ -1193,6 +1210,11 @@ def test_set_backend_name(name, monkeypatch): marks=mark.postgres, id="postgres", ), + param( + "postgres://root:@localhost:4566/dev", + marks=mark.risingwave, + id="risingwave", + ), ], ) def test_set_backend_url(url, monkeypatch): @@ -1217,6 +1239,7 @@ def test_set_backend_url(url, monkeypatch): "pandas", "polars", "postgres", + "risingwave", "pyspark", "sqlite", ], @@ -1253,6 +1276,11 @@ def test_create_table_timestamp(con, temp_table): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_ref_count(backend, con, alltypes): non_persisted_table = alltypes.mutate(test_column="calculation") persisted_table = non_persisted_table.cache() @@ -1273,6 +1301,11 @@ def test_persist_expression_ref_count(backend, con, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression(backend, alltypes): non_persisted_table = alltypes.mutate(test_column="calculation", other_calc="xyz") persisted_table = non_persisted_table.cache() @@ 
-1287,6 +1320,11 @@ def test_persist_expression(backend, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_contextmanager(backend, alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc" @@ -1303,6 +1341,11 @@ def test_persist_expression_contextmanager(backend, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_contextmanager_ref_count(backend, con, alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc 2" @@ -1321,6 +1364,11 @@ def test_persist_expression_contextmanager_ref_count(backend, con, alltypes): ["mssql"], reason="mssql supports support temporary tables through naming conventions", ) +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") def test_persist_expression_multiple_refs(backend, con, alltypes): non_cached_table = alltypes.mutate( @@ -1358,6 +1406,11 @@ def test_persist_expression_multiple_refs(backend, con, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_repeated_cache(alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc 2" @@ -1373,6 +1426,11 @@ def test_persist_expression_repeated_cache(alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_release(con, alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc 3" @@ -1453,6 +1511,11 @@ def test_create_schema(con_create_schema): con_create_schema.drop_schema(schema) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: information_schema.schemata is not supported,", +) def test_list_schemas(con_create_schema): schemas = con_create_schema.list_schemas() assert len(schemas) == len(set(schemas)) diff --git a/ibis/backends/tests/test_column.py b/ibis/backends/tests/test_column.py index f26b2a876ded..f6b4bd8ee0f4 100644 --- a/ibis/backends/tests/test_column.py +++ b/ibis/backends/tests/test_column.py @@ -19,6 +19,7 @@ "pandas", "polars", "postgres", + "risingwave", "pyspark", "snowflake", "trino", diff --git a/ibis/backends/tests/test_dot_sql.py b/ibis/backends/tests/test_dot_sql.py index b5aec3ebcc95..263759691aa5 100644 --- a/ibis/backends/tests/test_dot_sql.py +++ b/ibis/backends/tests/test_dot_sql.py @@ -217,7 +217,7 @@ def 
test_dot_sql_reuse_alias_with_different_types(backend, alltypes, df): backend.assert_series_equal(foo2.x.execute(), expected2) -_NO_SQLGLOT_DIALECT = {"pandas", "dask", "druid", "flink"} +_NO_SQLGLOT_DIALECT = {"pandas", "dask", "druid", "flink", "risingwave"} no_sqlglot_dialect = sorted( param(backend, marks=pytest.mark.xfail) for backend in _NO_SQLGLOT_DIALECT ) @@ -227,6 +227,11 @@ def test_dot_sql_reuse_alias_with_different_types(backend, alltypes, df): "dialect", [*sorted(_get_backend_names() - _NO_SQLGLOT_DIALECT), *no_sqlglot_dialect], ) +@pytest.mark.notyet( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) @table_dot_sql_notimpl @dot_sql_notimpl @dot_sql_notyet @@ -255,6 +260,11 @@ def test_table_dot_sql_transpile(backend, alltypes, dialect, df): @pytest.mark.notyet( ["oracle"], strict=False, reason="only works with backends that quote everything" ) +@pytest.mark.notyet( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) @dot_sql_notimpl @dot_sql_never def test_con_dot_sql_transpile(backend, con, dialect, df): @@ -272,6 +282,11 @@ def test_con_dot_sql_transpile(backend, con, dialect, df): @dot_sql_never @pytest.mark.notimpl(["druid", "flink", "impala", "polars", "pyspark"]) @pytest.mark.notyet(["snowflake"], reason="snowflake column names are case insensitive") +@pytest.mark.notyet( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) def test_order_by_no_projection(backend): con = backend.connection astronauts = con.table("astronauts") diff --git a/ibis/backends/tests/test_examples.py b/ibis/backends/tests/test_examples.py index f46d7ed6df19..d4f6505e36d9 100644 --- a/ibis/backends/tests/test_examples.py +++ b/ibis/backends/tests/test_examples.py @@ -16,7 +16,7 @@ reason="nix on linux cannot download duckdb extensions or data due to sandboxing", ) @pytest.mark.notimpl(["dask", "datafusion", "exasol", "pyspark"]) -@pytest.mark.notyet(["clickhouse", "druid", "impala", "mssql", "trino"]) +@pytest.mark.notyet(["clickhouse", "druid", "impala", "mssql", "trino", "risingwave"]) @pytest.mark.parametrize( ("example", "columns"), [ @@ -72,6 +72,7 @@ def test_load_examples(con, example, columns): "pandas", "polars", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index 7d244e842c94..71a3f97afc69 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -237,6 +237,7 @@ def test_table_to_parquet_writer_kwargs(version, tmp_path, backend, awards_playe "pandas", "polars", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -330,6 +331,11 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): marks=[ pytest.mark.notyet(["druid"], raises=sa.exc.ProgrammingError), pytest.mark.notyet(["exasol"], raises=sa.exc.DBAPIError), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.DBAPIError, + reason="Feature is not yet implemented: unsupported data type: NUMERIC(38,9)", + ), ], ), param( @@ -351,6 +357,11 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): reason="precision is out of range", ), pytest.mark.notyet(["exasol"], raises=sa.exc.DBAPIError), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.DBAPIError, + reason="Feature is not yet implemented: unsupported data type: NUMERIC(76,38)", + ), ], ), ], @@ -373,6 
+384,7 @@ def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype): "mysql", "oracle", "postgres", + "risingwave", "snowflake", "sqlite", "bigquery", @@ -472,7 +484,20 @@ def test_to_pandas_batches_empty_table(backend, con): @pytest.mark.notimpl(["druid"]) @pytest.mark.parametrize( "n", - [param(None, marks=pytest.mark.notimpl(["exasol"], raises=sa.exc.CompileError)), 1], + [ + param( + None, + marks=[ + pytest.mark.notimpl(["exasol"], raises=sa.exc.CompileError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit null", + ), + ], + ), + 1, + ], ) def test_to_pandas_batches_nonempty_table(backend, con, n): t = backend.functional_alltypes.limit(n) @@ -485,7 +510,17 @@ def test_to_pandas_batches_nonempty_table(backend, con, n): @pytest.mark.parametrize( "n", [ - param(None, marks=pytest.mark.notimpl(["exasol"], raises=sa.exc.CompileError)), + param( + None, + marks=[ + pytest.mark.notimpl(["exasol"], raises=sa.exc.CompileError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit null", + ), + ], + ), 0, 1, 2, diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 69634a2153b7..58a549023e5b 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -36,6 +36,7 @@ "sqlite": "null", "trino": "unknown", "postgres": "null", + "risingwave": "null", } @@ -60,6 +61,7 @@ def test_null_literal(con, backend): "trino": "boolean", "duckdb": "BOOLEAN", "postgres": "boolean", + "risingwave": "boolean", "flink": "BOOLEAN NOT NULL", } @@ -143,6 +145,7 @@ def test_isna(backend, alltypes, col, filt): "duckdb", "impala", "postgres", + "risingwave", "mysql", "snowflake", "polars", @@ -301,6 +304,7 @@ def test_filter(backend, alltypes, sorted_df, predicate_fn, expected_fn): "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", @@ -540,6 +544,11 @@ def test_order_by(backend, alltypes, df, key, df_kwargs): @pytest.mark.notimpl(["dask", "pandas", "polars", "mssql", "druid"]) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_order_by_random(alltypes): expr = alltypes.filter(_.id < 100).order_by(ibis.random()).limit(5) r1 = expr.execute() @@ -783,6 +792,11 @@ def test_correlated_subquery(alltypes): @pytest.mark.notimpl(["polars", "pyspark"]) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason='DataFrame.iloc[:, 0] (column name="playerID") are different', +) def test_uncorrelated_subquery(backend, batting, batting_df): subset_batting = batting[batting.yearID <= 2000] expr = batting[_.yearID == subset_batting.yearID.max()]["playerID", "yearID"] @@ -855,6 +869,11 @@ def test_typeof(con): @pytest.mark.notimpl(["datafusion", "pyspark", "druid"]) @pytest.mark.notyet(["dask", "mssql"], reason="not supported by the backend") @pytest.mark.notimpl(["exasol"], raises=sa.exc.DBAPIError) +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason="https://github.com/risingwavelabs/risingwave/issues/1343", +) def test_isin_uncorrelated( backend, batting, awards_players, batting_df, awards_players_df ): @@ -997,6 +1016,11 @@ def test_memtable_column_naming_mismatch(backend, con, monkeypatch, df, columns) ) @pytest.mark.notimpl(["druid", "flink"], reason="no sqlglot dialect", raises=ValueError) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + 
["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) def test_many_subqueries(con, snapshot): def query(t, group_cols): t2 = t.mutate(key=ibis.row_number().over(ibis.window(order_by=group_cols))) @@ -1019,6 +1043,11 @@ def query(t, group_cols): reason="backend doesn't support arrays and we don't implement pivot_longer with unions yet", raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason='sql parser error: Expected ), found: TEXT at line:3, column:219 Near "))]) AS anon_1(f1"', +) def test_pivot_longer(backend): diamonds = backend.diamonds df = diamonds.execute() @@ -1129,6 +1158,11 @@ def test_pivot_wider(backend): ["exasol"], raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function last(double precision) does not exist, do you mean left or least", +) def test_distinct_on_keep(backend, on, keep): from ibis import _ @@ -1203,6 +1237,11 @@ def test_distinct_on_keep(backend, on, keep): raises=com.OperationNotDefinedError, reason="backend doesn't implement deduplication", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function first(double precision) does not exist", +) def test_distinct_on_keep_is_none(backend, on): from ibis import _ @@ -1225,7 +1264,7 @@ def test_distinct_on_keep_is_none(backend, on): assert len(result) == len(expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "flink", "exasol"]) +@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave", "flink", "exasol"]) @pytest.mark.notyet( [ "sqlite", @@ -1254,6 +1293,7 @@ def test_hash_consistent(backend, alltypes): "mysql", "oracle", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -1300,6 +1340,7 @@ def test_try_cast_expected(con, from_val, to_type, expected): "mysql", "oracle", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -1340,6 +1381,7 @@ def test_try_cast_expected_null(con, from_val, to_type): "mysql", "oracle", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -1370,6 +1412,7 @@ def test_try_cast_table(backend, con): "mysql", "oracle", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -1412,20 +1455,34 @@ def test_try_cast_func(con, from_val, to_type, func): param( slice(None, None), lambda t: t.count().to_pandas(), - marks=pytest.mark.notyet( - ["exasol"], - raises=sa.exc.CompileError, - ), + marks=[ + pytest.mark.notyet( + ["exasol"], + raises=sa.exc.CompileError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", + ), + ], id="[:]", ), param(slice(0, 0), lambda _: 0, id="[0:0]"), param( slice(0, None), lambda t: t.count().to_pandas(), - marks=pytest.mark.notyet( - ["exasol"], - raises=sa.exc.CompileError, - ), + marks=[ + pytest.mark.notyet( + ["exasol"], + raises=sa.exc.CompileError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", + ), + ], id="[0:]", ), # positive stop @@ -1472,6 +1529,11 @@ def test_try_cast_func(con, from_val, to_type, func): raises=com.UnsupportedArgumentError, reason="pyspark doesn't support non-zero offset until version 3.4", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", + ), ], ), # positive stop @@ -1548,6 +1610,11 @@ def 
test_static_table_slice(backend, slc, expected_count_fn): raises=sa.exc.CompileError, reason="mssql doesn't support dynamic limit/offset without an ORDER BY", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", +) @pytest.mark.notimpl( ["exasol"], raises=sa.exc.CompileError, @@ -1624,6 +1691,11 @@ def test_dynamic_table_slice(backend, slc, expected_count_fn): reason="https://github.com/duckdb/duckdb/issues/8412", ) @pytest.mark.notyet(["flink"], reason="flink doesn't support dynamic limit/offset") +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", +) def test_dynamic_table_slice_with_computed_offset(backend): t = backend.functional_alltypes @@ -1652,6 +1724,11 @@ def test_dynamic_table_slice_with_computed_offset(backend): "exasol", ] ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_sample(backend): t = backend.functional_alltypes.filter(_.int_col >= 2) @@ -1677,6 +1754,11 @@ def test_sample(backend): "exasol", ] ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_sample_memtable(con, backend): df = pd.DataFrame({"x": [1, 2, 3, 4]}) res = con.execute(ibis.memtable(df).sample(0.5)) @@ -1697,6 +1779,7 @@ def test_sample_memtable(con, backend): "oracle", "polars", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -1737,6 +1820,11 @@ def test_substitute(backend): ) @pytest.mark.notimpl(["druid", "flink"], reason="no sqlglot dialect", raises=ValueError) @pytest.mark.notimpl(["exasol"], raises=ValueError, reason="unknown dialect") +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) def test_simple_memtable_construct(con): t = ibis.memtable({"a": [1, 2]}) expr = t.a diff --git a/ibis/backends/tests/test_json.py b/ibis/backends/tests/test_json.py index 78d379ae0bde..9b8a2865e8d4 100644 --- a/ibis/backends/tests/test_json.py +++ b/ibis/backends/tests/test_json.py @@ -12,7 +12,9 @@ pytestmark = [ pytest.mark.never(["impala"], reason="doesn't support JSON and never will"), pytest.mark.notyet(["clickhouse"], reason="upstream is broken"), - pytest.mark.notimpl(["datafusion", "exasol", "mssql", "druid", "oracle"]), + pytest.mark.notimpl( + ["datafusion", "exasol", "mssql", "druid", "oracle", "risingwave"] + ), ] diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index 8e9e97e2528d..01562ad20d80 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -34,6 +35,11 @@ def test_map_table(backend): @pytest.mark.xfail_version( duckdb=["duckdb<0.8.0"], raises=exc.UnsupportedOperationError ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_column_map_values(backend): table = backend.map expr = table.select("idx", vals=table.kv.values()).order_by("idx") @@ -64,6 +70,11 @@ def test_column_map_merge(backend): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character 
varying[]) does not exist", +) def test_literal_map_keys(con): mapping = ibis.literal({"1": "a", "2": "b"}) expr = mapping.keys().name("tmp") @@ -79,6 +90,11 @@ def test_literal_map_keys(con): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_literal_map_values(con): mapping = ibis.literal({"1": "a", "2": "b"}) expr = mapping.values().name("tmp") @@ -87,7 +103,7 @@ def test_literal_map_values(con): assert np.array_equal(result, ["a", "b"]) -@pytest.mark.notimpl(["postgres"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -103,7 +119,9 @@ def test_scalar_isin_literal_map_keys(con): assert con.execute(false) == False # noqa: E712 -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -124,6 +142,11 @@ def test_map_scalar_contains_key_scalar(con): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_scalar_contains_key_column(backend, alltypes, df): value = {"1": "a", "3": "c"} mapping = ibis.literal(value) @@ -133,7 +156,9 @@ def test_map_scalar_contains_key_column(backend, alltypes, df): backend.assert_series_equal(result, expected) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -149,7 +174,9 @@ def test_map_column_contains_key_scalar(backend, alltypes, df): backend.assert_series_equal(result, series) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -164,7 +191,9 @@ def test_map_column_contains_key_column(alltypes): assert result.all() -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -183,6 +212,11 @@ def test_literal_map_merge(con): raises=NotImplementedError, reason="No translation rule for map", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_literal_map_getitem_broadcast(backend, alltypes, df): value = {"1": "a", "2": "b"} @@ -200,6 +234,11 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df): raises=NotImplementedError, reason="No translation rule for map", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_literal_map_get_broadcast(backend, alltypes, df): value = {"1": "a", "2": "b"} @@ -220,19 +259,27 @@ def test_literal_map_get_broadcast(backend, alltypes, df): 
[1, 2], id="string", marks=pytest.mark.notyet( - ["postgres"], reason="only support maps of string -> string" + ["postgres", "risingwave"], + reason="only support maps of string -> string", ), ), param(["a", "b"], ["1", "2"], id="int"), ], ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_construct_dict(con, keys, values): expr = ibis.map(keys, values) result = con.execute(expr.name("tmp")) assert result == dict(zip(keys, values)) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -246,7 +293,9 @@ def test_map_construct_array_column(con, alltypes, df): assert result.to_list() == expected.to_list() -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -258,7 +307,9 @@ def test_map_get_with_compatible_value_smaller(con): assert con.execute(expr) == 3 -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -270,7 +321,9 @@ def test_map_get_with_compatible_value_bigger(con): assert con.execute(expr) == 3000 -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -283,7 +336,9 @@ def test_map_get_with_incompatible_value_different_kind(con): @pytest.mark.parametrize("null_value", [None, ibis.NA]) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -303,6 +358,11 @@ def test_map_get_with_null_on_not_nullable(con, null_value): raises=NotImplementedError, reason="No translation rule for map", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_get_with_null_on_null_type_with_null(con, null_value): value = ibis.literal({"A": None, "B": None}) expr = value.get("C", null_value) @@ -310,7 +370,9 @@ def test_map_get_with_null_on_null_type_with_null(con, null_value): assert pd.isna(result) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -327,6 +389,11 @@ def test_map_get_with_null_on_null_type_with_non_null(con): raises=exc.IbisError, reason="`tbl_properties` is required when creating table with schema", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_create_table(con, temp_table): t = con.create_table( temp_table, @@ -340,6 +407,11 @@ def test_map_create_table(con, 
temp_table): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_length(con): expr = ibis.literal(dict(a="A", b="B")).length() assert con.execute(expr) == 2 diff --git a/ibis/backends/tests/test_network.py b/ibis/backends/tests/test_network.py index e6048ee907b2..dca5815c6855 100644 --- a/ibis/backends/tests/test_network.py +++ b/ibis/backends/tests/test_network.py @@ -20,6 +20,7 @@ "trino": "varchar(17)", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(17) NOT NULL", } @@ -50,6 +51,7 @@ def test_macaddr_literal(con, backend): "trino": "127.0.0.1", "impala": "127.0.0.1", "postgres": "127.0.0.1", + "risingwave": "127.0.0.1", "pandas": "127.0.0.1", "pyspark": "127.0.0.1", "mysql": "127.0.0.1", @@ -67,6 +69,7 @@ def test_macaddr_literal(con, backend): "trino": "varchar(9)", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(9) NOT NULL", }, id="ipv4", @@ -82,6 +85,7 @@ def test_macaddr_literal(con, backend): "trino": "2001:db8::1", "impala": "2001:db8::1", "postgres": "2001:db8::1", + "risingwave": "2001:db8::1", "pandas": "2001:db8::1", "pyspark": "2001:db8::1", "mysql": "2001:db8::1", @@ -99,6 +103,7 @@ def test_macaddr_literal(con, backend): "trino": "varchar(11)", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(11) NOT NULL", }, id="ipv6", diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index 93fdb3225814..202d68e73764 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -42,6 +42,7 @@ "trino": "integer", "duckdb": "TINYINT", "postgres": "integer", + "risingwave": "integer", "flink": "TINYINT NOT NULL", }, id="int8", @@ -57,6 +58,7 @@ "trino": "integer", "duckdb": "SMALLINT", "postgres": "integer", + "risingwave": "integer", "flink": "SMALLINT NOT NULL", }, id="int16", @@ -72,6 +74,7 @@ "trino": "integer", "duckdb": "INTEGER", "postgres": "integer", + "risingwave": "integer", "flink": "INT NOT NULL", }, id="int32", @@ -87,6 +90,7 @@ "trino": "integer", "duckdb": "BIGINT", "postgres": "integer", + "risingwave": "integer", "flink": "BIGINT NOT NULL", }, id="int64", @@ -102,6 +106,7 @@ "trino": "integer", "duckdb": "UTINYINT", "postgres": "integer", + "risingwave": "integer", "flink": "TINYINT NOT NULL", }, id="uint8", @@ -117,6 +122,7 @@ "trino": "integer", "duckdb": "USMALLINT", "postgres": "integer", + "risingwave": "integer", "flink": "SMALLINT NOT NULL", }, id="uint16", @@ -132,6 +138,7 @@ "trino": "integer", "duckdb": "UINTEGER", "postgres": "integer", + "risingwave": "integer", "flink": "INT NOT NULL", }, id="uint32", @@ -147,6 +154,7 @@ "trino": "integer", "duckdb": "UBIGINT", "postgres": "integer", + "risingwave": "integer", "flink": "BIGINT NOT NULL", }, id="uint64", @@ -162,6 +170,7 @@ "trino": "double", "duckdb": "FLOAT", "postgres": "numeric", + "risingwave": "numeric", "flink": "FLOAT NOT NULL", }, marks=[ @@ -193,6 +202,7 @@ "trino": "double", "duckdb": "FLOAT", "postgres": "numeric", + "risingwave": "numeric", "flink": "FLOAT NOT NULL", }, marks=[ @@ -214,6 +224,7 @@ "trino": "double", "duckdb": "DOUBLE", "postgres": "numeric", + "risingwave": "numeric", "flink": "DOUBLE NOT NULL", }, marks=[ @@ -249,6 +260,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "dask": decimal.Decimal("1.1"), "duckdb": 
decimal.Decimal("1.1"), "postgres": 1.1, + "risingwave": 1.1, "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "mysql": 1.1, @@ -265,6 +277,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(2,1)", "duckdb": "DECIMAL(18,3)", "postgres": "numeric", + "risingwave": "numeric", "flink": "DECIMAL(38, 18) NOT NULL", }, marks=[ @@ -300,6 +313,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": 1.1, "duckdb": decimal.Decimal("1.100000000"), "postgres": 1.1, + "risingwave": 1.1, "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "mysql": 1.1, @@ -319,6 +333,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(2,1)", "duckdb": "DECIMAL(38,9)", "postgres": "numeric", + "risingwave": "numeric", "flink": "DECIMAL(38, 9) NOT NULL", }, marks=[ @@ -349,6 +364,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": 1.1, "dask": decimal.Decimal("1.1"), "postgres": 1.1, + "risingwave": 1.1, "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "mysql": 1.1, @@ -367,6 +383,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(2,1)", "duckdb": "DECIMAL(18,3)", "postgres": "numeric", + "risingwave": "numeric", }, marks=[ pytest.mark.notimpl( @@ -405,6 +422,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "snowflake": "Infinity", "sqlite": float("inf"), "postgres": float("nan"), + "risingwave": float("nan"), "pandas": decimal.Decimal("Infinity"), "dask": decimal.Decimal("Infinity"), "impala": float("inf"), @@ -417,6 +435,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(2,1)", "duckdb": "DECIMAL(18,3)", "postgres": "numeric", + "risingwave": "numeric", "impala": "DOUBLE", }, marks=[ @@ -486,6 +505,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "snowflake": "-Infinity", "sqlite": float("-inf"), "postgres": float("nan"), + "risingwave": float("nan"), "pandas": decimal.Decimal("-Infinity"), "dask": decimal.Decimal("-Infinity"), "impala": float("-inf"), @@ -498,6 +518,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(2,1)", "duckdb": "DECIMAL(18,3)", "postgres": "numeric", + "risingwave": "numeric", "impala": "DOUBLE", }, marks=[ @@ -567,6 +588,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "snowflake": "NaN", "sqlite": None, "postgres": float("nan"), + "risingwave": float("nan"), "pandas": decimal.Decimal("NaN"), "dask": decimal.Decimal("NaN"), "impala": float("nan"), @@ -579,6 +601,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(2,1)", "duckdb": "DECIMAL(18,3)", "postgres": "numeric", + "risingwave": "numeric", "impala": "DOUBLE", }, marks=[ @@ -892,6 +915,11 @@ def test_isnan_isinf( raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["druid"], raises=com.OperationNotDefinedError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), ], ), param( @@ -909,6 +937,11 @@ def test_isnan_isinf( raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["druid"], raises=com.OperationNotDefinedError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), ], ), param( @@ -1064,7 +1097,14 @@ def test_simple_math_functions_columns( param( lambda t: t.double_col.add(1).log(2), lambda t: 
np.log2(t.double_col + 1), - marks=pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError), + marks=[ + pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), + ], id="log2", ), param( @@ -1100,6 +1140,11 @@ def test_simple_math_functions_columns( reason="Base greatest(9000, t0.bigint_col) for logarithm not supported!", ), pytest.mark.notimpl(["polars"], raises=com.UnsupportedArgumentError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), ], ), ], @@ -1346,6 +1391,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1358,6 +1404,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1370,6 +1417,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1382,6 +1430,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1428,6 +1477,7 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): ( { "postgres": None, + "risingwave": None, "mysql": 10, "snowflake": 38, "trino": 18, @@ -1438,6 +1488,7 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): }, { "postgres": None, + "risingwave": None, "mysql": 0, "snowflake": 0, "trino": 3, @@ -1464,6 +1515,11 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): reason="Not SQLAlchemy backends", ) @pytest.mark.notimpl(["druid", "exasol"], raises=KeyError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: unsupported data type: NUMERIC(5)", +) def test_sa_default_numeric_precision_and_scale( con, backend, default_precisions, default_scales, temp_table ): @@ -1500,6 +1556,11 @@ def test_sa_default_numeric_precision_and_scale( @pytest.mark.notimpl( ["dask", "pandas", "polars", "druid"], raises=com.OperationNotDefinedError ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_random(con): expr = ibis.random() result = con.execute(expr) diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py index 03b3cfddff05..b04d0f762c7c 100644 --- a/ibis/backends/tests/test_param.py +++ b/ibis/backends/tests/test_param.py @@ -39,6 +39,11 @@ def test_floating_scalar_parameter(backend, alltypes, df, column, raw_value): ) @pytest.mark.notimpl(["datafusion", "mssql", "trino", "druid"]) @pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def 
test_date_scalar_parameter(backend, alltypes, start_string, end_string): start, end = ibis.param(dt.date), ibis.param(dt.date) @@ -76,6 +81,7 @@ def test_scalar_param_array(con): "datafusion", "impala", "postgres", + "risingwave", "pyspark", "druid", "oracle", @@ -111,6 +117,11 @@ def test_scalar_param_struct(con): "sql= SELECT MAP_FROM_ARRAYS(ARRAY['a', 'b', 'c'], ARRAY['ghi', 'def', 'abc']) '[' 'b' ']' AS `MapGet(param_0, 'b', None)`" ), ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_scalar_param_map(con): value = {"a": "ghi", "b": "def", "c": "abc"} param = ibis.param(dt.Map(dt.string, dt.string)) @@ -192,6 +203,11 @@ def test_scalar_param(backend, alltypes, df, value, dtype, col): ids=["string", "date", "datetime"], ) @pytest.mark.notimpl(["druid", "oracle"]) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_scalar_param_date(backend, alltypes, value): param = ibis.param("date") ds_col = alltypes.date_string_col @@ -227,6 +243,7 @@ def test_scalar_param_date(backend, alltypes, value): @pytest.mark.notimpl( [ "postgres", + "risingwave", "datafusion", "clickhouse", "polars", diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index e9b2035741e5..a612eed9cf45 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -91,6 +91,7 @@ def gzip_csv(data_dir, tmp_path): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -117,6 +118,7 @@ def test_register_csv(con, data_dir, fname, in_table_name, out_table_name): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -140,6 +142,7 @@ def test_register_csv_gz(con, data_dir, gzip_csv): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -196,6 +199,7 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]: "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -232,6 +236,7 @@ def test_register_parquet( "mysql", "pandas", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -271,6 +276,7 @@ def test_register_iterator_parquet( "mysql", "pandas", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -301,6 +307,7 @@ def test_register_pandas(con): "mysql", "pandas", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -326,6 +333,7 @@ def test_register_pyarrow_tables(con): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -368,6 +376,7 @@ def test_csv_reregister_schema(con, tmp_path): "pandas", "polars", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -395,7 +404,9 @@ def test_register_garbage(con, monkeypatch): ("functional_alltypes.parquet", "funk_all"), ], ) -@pytest.mark.notyet(["impala", "mssql", "mysql", "postgres", "sqlite", "trino"]) +@pytest.mark.notyet( + ["impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] +) @pytest.mark.notimpl( ["flink"], raises=ValueError, @@ -430,7 +441,7 @@ def ft_data(data_dir): @pytest.mark.notyet( - ["impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"] + ["impala", "mssql", "mysql", "pandas", "postgres", "risingwave", "sqlite", "trino"] ) @pytest.mark.notimpl( ["flink"], @@ -454,7 +465,7 @@ def test_read_parquet_glob(con, tmp_path, ft_data): @pytest.mark.notyet( - ["impala", "mssql", 
"mysql", "pandas", "postgres", "sqlite", "trino"] + ["impala", "mssql", "mysql", "pandas", "postgres", "risingwave", "sqlite", "trino"] ) @pytest.mark.notimpl( ["flink"], @@ -487,6 +498,7 @@ def test_read_csv_glob(con, tmp_path, ft_data): "mysql", "pandas", "postgres", + "risingwave", "sqlite", "trino", ] @@ -539,7 +551,9 @@ def num_diamonds(data_dir): "in_table_name", [param(None, id="default"), param("fancy_stones", id="file_name")], ) -@pytest.mark.notyet(["impala", "mssql", "mysql", "postgres", "sqlite", "trino"]) +@pytest.mark.notyet( + ["impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] +) @pytest.mark.notimpl( ["flink"], raises=ValueError, diff --git a/ibis/backends/tests/test_set_ops.py b/ibis/backends/tests/test_set_ops.py index 970330c3c9cf..c1931e205175 100644 --- a/ibis/backends/tests/test_set_ops.py +++ b/ibis/backends/tests/test_set_ops.py @@ -67,19 +67,26 @@ def test_union_mixed_distinct(backend, union_subsets): [ param( False, - marks=pytest.mark.notyet( - [ - "impala", - "bigquery", - "dask", - "pandas", - "sqlite", - "snowflake", - "mssql", - "exasol", - ], - reason="backend doesn't support INTERSECT ALL", - ), + marks=[ + pytest.mark.notyet( + [ + "impala", + "bigquery", + "dask", + "pandas", + "sqlite", + "snowflake", + "mssql", + "exasol", + ], + reason="backend doesn't support INTERSECT ALL", + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: INTERSECT all", + ), + ], id="all", ), param(True, id="distinct"), @@ -114,19 +121,26 @@ def test_intersect(backend, alltypes, df, distinct): [ param( False, - marks=pytest.mark.notyet( - [ - "impala", - "bigquery", - "dask", - "pandas", - "sqlite", - "snowflake", - "mssql", - "exasol", - ], - reason="backend doesn't support EXCEPT ALL", - ), + marks=[ + pytest.mark.notyet( + [ + "impala", + "bigquery", + "dask", + "pandas", + "sqlite", + "snowflake", + "mssql", + "exasol", + ], + reason="backend doesn't support EXCEPT ALL", + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: EXCEPT all", + ), + ], id="all", ), param(True, id="distinct"), @@ -193,18 +207,25 @@ def test_top_level_union(backend, con, alltypes, distinct): True, param( False, - marks=pytest.mark.notimpl( - [ - "impala", - "bigquery", - "dask", - "mssql", - "pandas", - "snowflake", - "sqlite", - "exasol", - ] - ), + marks=[ + pytest.mark.notimpl( + [ + "impala", + "bigquery", + "dask", + "mssql", + "pandas", + "snowflake", + "sqlite", + "exasol", + ] + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: INTERSECT all", + ), + ], ), ], ) diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index b635ca8e1d09..9db76cfa9c27 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -11,7 +11,7 @@ sa = pytest.importorskip("sqlalchemy") sg = pytest.importorskip("sqlglot") -pytestmark = pytest.mark.notimpl(["druid", "flink", "exasol"]) +pytestmark = pytest.mark.notimpl(["druid", "flink", "exasol", "risingwave"]) simple_literal = param(ibis.literal(1), id="simple_literal") array_literal = param( diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 997a8dea3e41..7c36da0d7b1f 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -30,6 +30,7 @@ "duckdb": "VARCHAR", "impala": "STRING", "postgres": "text", + "risingwave": "text", 
"flink": "CHAR(6) NOT NULL", }, id="string", @@ -45,14 +46,22 @@ "duckdb": "VARCHAR", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(7) NOT NULL", }, id="string-quote1", - marks=pytest.mark.broken( - ["oracle"], - raises=sa.exc.DatabaseError, - reason="ORA-01741: illegal zero length identifier", - ), + marks=[ + pytest.mark.broken( + ["oracle"], + raises=sa.exc.DatabaseError, + reason="ORA-01741: illegal zero length identifier", + ), + pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason='sql parser error: Expected end of statement, found: "NG\'" at line:1, column:31 Near "SELECT \'STRI"NG\' AS "\'STRI""', + ), + ], ), param( 'STRI"NG', @@ -65,14 +74,22 @@ "duckdb": "VARCHAR", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(7) NOT NULL", }, id="string-quote2", - marks=pytest.mark.broken( - ["oracle"], - raises=sa.exc.DatabaseError, - reason="ORA-25716", - ), + marks=[ + pytest.mark.broken( + ["oracle"], + raises=sa.exc.DatabaseError, + reason="ORA-25716", + ), + pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason='sql parser error: Expected end of statement, found: "NG\'" at line:1, column:31 Near "SELECT \'STRI"NG\' AS "\'STRI""', + ), + ], ), ], ) @@ -247,6 +264,11 @@ def uses_java_re(t): ["mssql", "druid", "oracle", "exasol"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -258,6 +280,11 @@ def uses_java_re(t): ["mssql", "druid", "oracle", "exasol"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -271,6 +298,11 @@ def uses_java_re(t): ["mssql", "druid", "oracle", "exasol"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -282,6 +314,11 @@ def uses_java_re(t): ["mssql", "druid", "oracle", "exasol"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -295,6 +332,11 @@ def uses_java_re(t): ["mssql", "druid", "oracle", "exasol"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -308,6 +350,11 @@ def uses_java_re(t): ["mssql", "druid", "oracle", "exasol"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -321,6 +368,11 @@ def uses_java_re(t): ["mssql", "druid", "oracle", "exasol"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -332,6 +384,11 @@ def uses_java_re(t): ["mssql", "druid", "oracle", "exasol"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + 
raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -343,6 +400,11 @@ def uses_java_re(t): ["mssql", "druid", "oracle", "exasol"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -935,6 +997,7 @@ def test_substr_with_null_values(backend, alltypes, df): "mysql", "polars", "postgres", + "risingwave", "pyspark", "druid", "oracle", @@ -1008,6 +1071,11 @@ def test_multiple_subs(con): ], raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function levenshtein(character varying, character varying) does not exist", +) @pytest.mark.parametrize( "right", ["sitting", ibis.literal("sitting")], ids=["python", "ibis"] ) diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index d59459f858c2..aa37880f72e2 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -53,7 +53,7 @@ def test_all_fields(struct, struct_df): _NULL_STRUCT_LITERAL = ibis.NA.cast("struct") -@pytest.mark.notimpl(["postgres"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.parametrize("field", ["a", "b", "c"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from literals" @@ -67,7 +67,7 @@ def test_literal(backend, con, field): backend.assert_series_equal(result, expected.astype(dtype)) -@pytest.mark.notimpl(["postgres"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.parametrize("field", ["a", "b", "c"]) @pytest.mark.notyet( ["clickhouse"], reason="clickhouse doesn't support nullable nested types" @@ -83,7 +83,7 @@ def test_null_literal(backend, con, field): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres"]) +@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from literals" ) @@ -99,7 +99,7 @@ def test_struct_column(backend, alltypes, df): tm.assert_series_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "polars"]) +@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave", "polars"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from collect" ) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 323b81f2a0e1..93799226abb2 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -221,6 +221,11 @@ def test_timestamp_extract(backend, alltypes, df, attr): ["mssql", "druid", "oracle", "exasol"], raises=com.OperationNotDefinedError, ), + pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", + ), ], ), ], @@ -525,6 +530,7 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): "impala", "mysql", "postgres", + "risingwave", "pyspark", "sqlite", "snowflake", @@ -761,6 +767,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=Py4JJavaError, reason="ParseException: Encountered 'WEEK'. 
Was expecting one of: DAY, DAYS, HOUR", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: week", + ), ], ), param( @@ -824,6 +835,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=Py4JJavaError, reason="ParseException: Encountered 'MILLISECOND'. Was expecting one of: DAY, DAYS, HOUR, ...", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: millisecond", + ), ], ), param( @@ -848,6 +864,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=Py4JJavaError, reason="ParseException: Encountered 'MICROSECOND'. Was expecting one of: DAY, DAYS, HOUR, ...", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: microsecond", + ), ], ), ], @@ -898,7 +919,14 @@ def convert_to_offset(offset, displacement_type=displacement_type): ), param( "W", - marks=pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + marks=[ + pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: week", + ), + ], ), "D", ], @@ -1003,6 +1031,7 @@ def convert_to_offset(x): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "bigquery", @@ -1028,6 +1057,7 @@ def convert_to_offset(x): "clickhouse", "sqlite", "postgres", + "risingwave", "polars", "mysql", "impala", @@ -1158,6 +1188,11 @@ def convert_to_offset(x): raises=AttributeError, reason="'StringColumn' object has no attribute 'date'", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", + ), pytest.mark.broken( ["flink"], raises=com.UnsupportedOperationError, @@ -1633,6 +1668,7 @@ def test_interval_add_cast_column(backend, alltypes, df): pytest.mark.notimpl( [ "postgres", + "risingwave", "snowflake", ], raises=AttributeError, @@ -1760,7 +1796,7 @@ def test_strftime(backend, alltypes, df, expr_fn, pandas_pattern): ], ) @pytest.mark.notimpl( - ["mysql", "postgres", "sqlite", "druid", "oracle"], + ["mysql", "postgres", "risingwave", "sqlite", "druid", "oracle"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) @@ -1845,6 +1881,7 @@ def test_integer_to_timestamp(backend, con, unit): "dask", "pandas", "postgres", + "risingwave", "clickhouse", "sqlite", "impala", @@ -1880,6 +1917,11 @@ def test_string_to_timestamp(alltypes, fmt): ) @pytest.mark.notimpl(["mssql", "druid", "oracle"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", +) def test_day_of_week_scalar(con, date, expected_index, expected_day): expr = ibis.literal(date).cast(dt.date) result_index = con.execute(expr.day_of_week.index().name("tmp")) @@ -1896,6 +1938,11 @@ def test_day_of_week_scalar(con, date, expected_index, expected_day): reason="StringColumn' object has no attribute 'day_of_week'", ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", +) def test_day_of_week_column(backend, alltypes, df): expr = alltypes.timestamp_col.day_of_week @@ -1927,6 
+1974,11 @@ def test_day_of_week_column(backend, alltypes, df): ["mssql"], raises=com.OperationNotDefinedError, ), + pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", + ), ], ), ], @@ -1992,6 +2044,7 @@ def test_now_from_projection(alltypes): "trino": "date", "duckdb": "DATE", "postgres": "date", + "risingwave": "date", "flink": "DATE NOT NULL", } @@ -2018,6 +2071,11 @@ def test_now_from_projection(alltypes): ) @pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_date_literal(con, backend): expr = ibis.date(2022, 2, 4) result = con.execute(expr) @@ -2037,6 +2095,7 @@ def test_date_literal(con, backend): "trino": "timestamp(3)", "duckdb": "TIMESTAMP", "postgres": "timestamp without time zone", + "risingwave": "timestamp without time zone", "flink": "TIMESTAMP(6) NOT NULL", } @@ -2065,6 +2124,11 @@ def test_date_literal(con, backend): ) @pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_timestamp(integer, integer, integer, integer, integer, integer) does not exist", +) def test_timestamp_literal(con, backend): expr = ibis.timestamp(2022, 2, 4, 16, 20, 0) result = con.execute(expr) @@ -2123,6 +2187,11 @@ def test_timestamp_literal(con, backend): ), ) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_timestamp(integer, integer, integer, integer, integer, integer) does not exist", +) def test_timestamp_with_timezone_literal(con, timezone, expected): expr = ibis.timestamp(2022, 2, 4, 16, 20, 0).cast(dt.Timestamp(timezone=timezone)) result = con.execute(expr) @@ -2139,6 +2208,7 @@ def test_timestamp_with_timezone_literal(con, timezone, expected): "trino": "time(3)", "duckdb": "TIME", "postgres": "time without time zone", + "risingwave": "time without time zone", } @@ -2170,6 +2240,11 @@ def test_timestamp_with_timezone_literal(con, timezone, expected): ["druid"], raises=sa.exc.ProgrammingError, reason="SQL parse failed" ) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_time(integer, integer, integer) does not exist", +) def test_time_literal(con, backend): expr = ibis.time(16, 20, 0) result = con.execute(expr) @@ -2248,6 +2323,7 @@ def test_extract_time_from_timestamp(con, microsecond): "trino": "interval day to second", "duckdb": "INTERVAL", "postgres": "interval", + "risingwave": "interval", } @@ -2342,6 +2418,11 @@ def test_interval_literal(con, backend): ) @pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["exasol"], raises=sa.exc.DBAPIError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_date_column_from_ymd(backend, con, alltypes, df): c = alltypes.timestamp_col expr = ibis.date(c.year(), c.month(), c.day()) @@ -2371,6 +2452,11 @@ def test_date_column_from_ymd(backend, con, alltypes, df): ) @pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError) 
@pytest.mark.notimpl(["exasol"], raises=sa.exc.DBAPIError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_timestamp(smallint, smallint, smallint, smallint, smallint, smallint) does not exist", +) def test_timestamp_column_from_ymdhms(backend, con, alltypes, df): c = alltypes.timestamp_col expr = ibis.timestamp( @@ -2543,6 +2629,11 @@ def build_date_col(t): param(lambda _: DATE, build_date_col, id="date_column"), ], ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_timestamp_date_comparison(backend, alltypes, df, left_fn, right_fn): left = left_fn(alltypes) right = right_fn(alltypes) @@ -2654,6 +2745,11 @@ def test_large_timestamp(con): reason="assert Timestamp('2023-01-07 13:20:05.561000') == Timestamp('2023-01-07 13:20:05.561000231')", raises=AssertionError, ), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Parse error: timestamp without time zone Can't cast string to timestamp (expected format is YYYY-MM-DD HH:MM:SS[.D+{up to 6 digits}] or YYYY-MM-DD HH:MM or YYYY-MM-DD or ISO 8601 format)", + ), ], ), ], @@ -2696,7 +2792,7 @@ def test_timestamp_precision_output(con, ts, scale, unit): raises=com.OperationNotDefinedError, ) @pytest.mark.notyet( - ["postgres"], + ["postgres", "risingwave"], reason="postgres doesn't have any easy way to accurately compute the delta in specific units", raises=com.OperationNotDefinedError, ) @@ -2850,6 +2946,11 @@ def test_delta(con, start, end, unit, expected): ], ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function date_bin(interval, timestamp without time zone, timestamp without time zone) does not exist", +) def test_timestamp_bucket(backend, kws, pd_freq): ts = backend.functional_alltypes.timestamp_col.name("ts").execute() res = backend.functional_alltypes.timestamp_col.bucket(**kws).name("ts").execute() @@ -2884,6 +2985,11 @@ def test_timestamp_bucket(backend, kws, pd_freq): ) @pytest.mark.parametrize("offset_mins", [2, -2], ids=["pos", "neg"]) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function date_bin(interval, timestamp without time zone, timestamp without time zone) does not exist", +) def test_timestamp_bucket_offset(backend, offset_mins): ts = backend.functional_alltypes.timestamp_col.name("ts") expr = ts.bucket(minutes=5, offset=ibis.interval(minutes=offset_mins)).name("ts") @@ -2991,6 +3097,11 @@ def test_time_literal_sql(dialect, snapshot, micros): param(datetime.date.fromisoformat, id="fromstring"), ], ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_date_scalar(con, value, func): expr = ibis.date(func(value)).name("tmp") diff --git a/ibis/backends/tests/test_timecontext.py b/ibis/backends/tests/test_timecontext.py index 1d4a7495fe95..72e78065640e 100644 --- a/ibis/backends/tests/test_timecontext.py +++ b/ibis/backends/tests/test_timecontext.py @@ -19,6 +19,7 @@ "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py index 2c2fac00246c..15be61cf1723 100644 --- a/ibis/backends/tests/test_udf.py +++ b/ibis/backends/tests/test_udf.py 
@@ -20,6 +20,7 @@ "oracle", "pandas", "trino", + "risingwave", ] ) diff --git a/ibis/backends/tests/test_uuid.py b/ibis/backends/tests/test_uuid.py index eac109b68a89..ea9064dd0d74 100644 --- a/ibis/backends/tests/test_uuid.py +++ b/ibis/backends/tests/test_uuid.py @@ -66,6 +66,11 @@ @pytest.mark.notimpl( ["impala", "datafusion", "polars", "clickhouse"], raises=NotImplementedError ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sqlalchemy.exc.InternalError, + reason="Feature is not yet implemented: unsupported data type: UUID", +) def test_uuid_literal(con, backend): backend_name = backend.name() diff --git a/ibis/backends/tests/test_vectorized_udf.py b/ibis/backends/tests/test_vectorized_udf.py index 56920a1959b5..f130b5b60154 100644 --- a/ibis/backends/tests/test_vectorized_udf.py +++ b/ibis/backends/tests/test_vectorized_udf.py @@ -10,7 +10,7 @@ import ibis.expr.datatypes as dt from ibis.legacy.udf.vectorized import analytic, elementwise, reduction -pytestmark = pytest.mark.notimpl(["druid", "oracle"]) +pytestmark = pytest.mark.notimpl(["druid", "oracle", "risingwave"]) def _format_udf_return_type(func, result_formatter): diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index 84004e5482a7..1b43317a52c9 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -154,6 +154,11 @@ def calc_zscore(s): raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["dask"], raises=NotImplementedError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: percent_rank", + ), ], ), param( @@ -164,6 +169,12 @@ def calc_zscore(s): pytest.mark.notimpl(["pyspark"], raises=com.UnsupportedOperationError), pytest.mark.notyet(["clickhouse"], raises=com.OperationNotDefinedError), pytest.mark.notimpl(["dask"], raises=NotImplementedError), + pytest.mark.notimpl(["dask"], raises=NotImplementedError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: cume_dist", + ), ], ), param( @@ -195,6 +206,11 @@ def calc_zscore(s): raises=com.UnsupportedOperationError, reason="Windows in Flink can only be ordered by a single time column", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: ntile", + ), ], ), param( @@ -231,6 +247,7 @@ def calc_zscore(s): raises=com.OperationNotDefinedError, reason="No translation rule for ", ), + pytest.mark.notimpl(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -391,7 +408,14 @@ def test_grouped_bounded_expanding_window( lambda t, win: t.double_col.mean().over(win), lambda df: (df.double_col.expanding().mean()), id="mean", - marks=[pytest.mark.notimpl(["dask"], raises=NotImplementedError)], + marks=[ + pytest.mark.notimpl(["dask"], raises=NotImplementedError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), + ], ), param( # Disabled on PySpark and Spark backends because in pyspark<3.0.0, @@ -411,6 +435,7 @@ def test_grouped_bounded_expanding_window( "mysql", "oracle", "postgres", + "risingwave", "sqlite", "snowflake", "datafusion", @@ -570,6 +595,7 @@ def test_grouped_bounded_preceding_window(backend, alltypes, df, window_fn): "mysql", "oracle", "postgres", + "risingwave", "sqlite", 
"snowflake", "trino", @@ -642,6 +668,11 @@ def test_grouped_unbounded_window( raises=com.UnsupportedOperationError, reason="OVER RANGE FOLLOWING windows are not supported in Flink yet", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_simple_ungrouped_unbound_following_window( backend, alltypes, ibis_method, pandas_fn ): @@ -674,6 +705,11 @@ def test_simple_ungrouped_unbound_following_window( ["mssql"], raises=Exception, reason="order by constant is not supported" ) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_simple_ungrouped_window_with_scalar_order_by(alltypes): t = alltypes[alltypes.double_col < 50].order_by("id") w = ibis.window(rows=(0, None), order_by=ibis.NA) @@ -703,6 +739,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): reason="default window semantics are different", raises=AssertionError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), ], ), param( @@ -732,6 +773,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): raises=PySparkAnalysisException, reason="pyspark requires CURRENT ROW", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: ntile", + ), ], ), param( @@ -751,6 +797,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "mysql", "oracle", "postgres", + "risingwave", "sqlite", "snowflake", "trino", @@ -782,6 +829,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "mysql", "oracle", "postgres", + "risingwave", "sqlite", "snowflake", "trino", @@ -808,6 +856,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): raises=Exception, reason="Exception: Internal error: Expects default value to have Int64 type.", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), ], ), param( @@ -844,6 +897,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): reason="backend requires ordering", raises=sa.exc.ProgrammingError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), ], ), param( @@ -857,6 +915,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): raises=Exception, reason="Exception: Internal error: Expects default value to have Int64 type.", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), ], ), param( @@ -896,6 +959,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): reason="backend requires ordering", raises=sa.exc.ProgrammingError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), ], ), param( @@ -915,6 +983,7 @@ def 
test_simple_ungrouped_window_with_scalar_order_by(alltypes): "mysql", "oracle", "postgres", + "risingwave", "pyspark", "sqlite", "snowflake", @@ -951,6 +1020,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "mysql", "oracle", "postgres", + "risingwave", "pyspark", "sqlite", "snowflake", @@ -1016,6 +1086,11 @@ def test_ungrouped_unbounded_window( reason="RANGE OFFSET frame for 'DB::ColumnNullable' ORDER BY column is not implemented", raises=ClickHouseDatabaseError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: window frame in `RANGE` mode is not supported yet", +) @pytest.mark.notyet(["mssql"], raises=sa.exc.ProgrammingError) def test_grouped_bounded_range_window(backend, alltypes, df): # Explanation of the range window spec below: @@ -1073,6 +1148,11 @@ def gb_fn(df): reason="clickhouse doesn't implement percent_rank", raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: percent_rank", +) def test_percent_rank_whole_table_no_order_by(backend, alltypes, df): expr = alltypes.mutate(val=lambda t: t.id.percent_rank()) @@ -1132,6 +1212,11 @@ def agg(df): raises=Exception, reason="Exception: Internal error: Expects default value to have Int64 type.", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_mutate_window_filter(backend, alltypes): t = alltypes win = ibis.window(order_by=[t.id]) @@ -1203,6 +1288,11 @@ def test_first_last(backend): reason="not support by the backend", ) @pytest.mark.broken(["flink"], raises=Py4JJavaError, reason="bug in Flink") +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason="sql parser error: Expected literal int, found: INTERVAL at line:1, column:99", +) def test_range_expression_bounds(backend): t = ibis.memtable( { @@ -1253,6 +1343,11 @@ def test_range_expression_bounds(backend): @pytest.mark.broken( ["pyspark"], reason="pyspark requires CURRENT ROW", raises=PySparkAnalysisException ) +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: percent_rank", +) def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df): # GH #7631 t = alltypes @@ -1286,6 +1381,11 @@ def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df): @pytest.mark.broken( ["pyspark"], reason="pyspark requires CURRENT ROW", raises=PySparkAnalysisException ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_ordering_order(con): table = ibis.memtable({"bool_col": [True, False, False, None, True]}) window = ibis.window( diff --git a/ibis/tests/benchmarks/test_benchmarks.py b/ibis/tests/benchmarks/test_benchmarks.py index 50b5d478c396..4977e1bc18f2 100644 --- a/ibis/tests/benchmarks/test_benchmarks.py +++ b/ibis/tests/benchmarks/test_benchmarks.py @@ -159,7 +159,7 @@ def test_builtins(benchmark, expr_fn, builtin, t, base, large_expr): # compile is a no-op _backends.remove("pandas") -_XFAIL_COMPILE_BACKENDS = {"dask", "pyspark", "polars"} +_XFAIL_COMPILE_BACKENDS = {"dask", "pyspark", "polars", "risingwave"} @pytest.mark.benchmark(group="compilation") diff --git 
a/poetry.lock b/poetry.lock index 2f9cd82476cb..52e316067558 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6435,6 +6435,20 @@ sqlalchemy = ">=1.4,<2" [package.extras] turbodbc = ["turbodbc (==4.5.4)"] +[[package]] +name = "sqlalchemy-risingwave" +version = "1.0.0" +description = "RisingWave dialect for SQLAlchemy" +optional = true +python-versions = "*" +files = [ + {file = "sqlalchemy-risingwave-1.0.0.tar.gz", hash = "sha256:856a3c44b98cba34d399c3cc9785a74896caca152b3685d87553e4210e3e07a4"}, + {file = "sqlalchemy_risingwave-1.0.0-py3-none-any.whl", hash = "sha256:c733365abc38e88f4d23d83713cfc3f21c0b0d3c81210cbc2f569b49a912ba08"}, +] + +[package.dependencies] +SQLAlchemy = ">=1.4,<2" + [[package]] name = "sqlalchemy-views" version = "0.3.2" @@ -7382,7 +7396,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["black", "clickhouse-connect", "dask", "datafusion", "db-dtypes", "deltalake", "duckdb", "duckdb-engine", "geoalchemy2", "geopandas", "google-cloud-bigquery", "google-cloud-bigquery-storage", "graphviz", "impyla", "oracledb", "packaging", "pins", "polars", "psycopg2", "pydata-google-auth", "pydruid", "pymysql", "pyodbc", "pyspark", "regex", "shapely", "snowflake-connector-python", "snowflake-sqlalchemy", "sqlalchemy", "sqlalchemy-exasol", "sqlalchemy-views", "trino"] +all = ["black", "clickhouse-connect", "dask", "datafusion", "db-dtypes", "deltalake", "duckdb", "duckdb-engine", "geoalchemy2", "geopandas", "google-cloud-bigquery", "google-cloud-bigquery-storage", "graphviz", "impyla", "oracledb", "packaging", "pins", "polars", "psycopg2", "pydata-google-auth", "pydruid", "pymysql", "pyodbc", "pyspark", "regex", "shapely", "snowflake-connector-python", "snowflake-sqlalchemy", "sqlalchemy", "sqlalchemy-exasol", "sqlalchemy-risingwave", "sqlalchemy-views", "trino"] bigquery = ["db-dtypes", "google-cloud-bigquery", "google-cloud-bigquery-storage", "pydata-google-auth"] clickhouse = ["clickhouse-connect", "sqlalchemy"] dask = ["dask", "regex"] @@ -7403,6 +7417,7 @@ pandas = ["regex"] polars = ["packaging", "polars"] postgres = ["psycopg2", "sqlalchemy", "sqlalchemy-views"] pyspark = ["packaging", "pyspark", "sqlalchemy"] +risingwave = ["psycopg2", "sqlalchemy", "sqlalchemy-risingwave", "sqlalchemy-views"] snowflake = ["packaging", "snowflake-connector-python", "snowflake-sqlalchemy", "sqlalchemy-views"] sqlite = ["regex", "sqlalchemy", "sqlalchemy-views"] trino = ["sqlalchemy", "sqlalchemy-views", "trino"] @@ -7411,4 +7426,4 @@ visualization = ["graphviz"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "1fad7113a8c6bcf6e661bd27894fd64c370f15b0fd13ac53b21001025e4658e5" +content-hash = "3c1dfc652d2d025e6ea434033966154b44ba3a4452cbe3c7439ea4754c6ec420" diff --git a/pyproject.toml b/pyproject.toml index 1c90337747c5..a6d43f6d04a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ snowflake-sqlalchemy = { version = ">=1.4.1,<2", optional = true } sqlalchemy = { version = ">=1.4,<3", optional = true } sqlalchemy-exasol = { version = ">=4.6.0", optional = true } sqlalchemy-views = { version = ">=0.3.1,<1", optional = true } +sqlalchemy-risingwave = { version = ">=1.0.0,<2", optional = true } trino = { version = ">=0.321,<1", optional = true, extras = ["sqlalchemy"] } [tool.poetry.group.dev.dependencies] @@ -178,6 +179,7 @@ all = [ "sqlalchemy", "sqlalchemy-exasol", "sqlalchemy-views", + "sqlalchemy-risingwave", "trino", ] bigquery = [ @@ -201,6 +203,12 @@ oracle = ["sqlalchemy", 
"oracledb", "packaging", "sqlalchemy-views"] pandas = ["regex"] polars = ["polars", "packaging"] postgres = ["psycopg2", "sqlalchemy", "sqlalchemy-views"] +risingwave = [ + "psycopg2", + "sqlalchemy", + "sqlalchemy-views", + "sqlalchemy-risingwave", +] pyspark = ["pyspark", "sqlalchemy", "packaging"] snowflake = [ "snowflake-connector-python", @@ -232,6 +240,7 @@ oracle = "ibis.backends.oracle" pandas = "ibis.backends.pandas" polars = "ibis.backends.polars" postgres = "ibis.backends.postgres" +risingwave = "ibis.backends.risingwave" pyspark = "ibis.backends.pyspark" snowflake = "ibis.backends.snowflake" sqlite = "ibis.backends.sqlite" @@ -367,6 +376,7 @@ markers = [ "pandas: Pandas tests", "polars: Polars tests", "postgres: PostgreSQL tests", + "risingwave: Risingwave tests", "pyspark: PySpark tests", "snowflake: Snowflake tests", "sqlite: SQLite tests", diff --git a/requirements-dev.txt b/requirements-dev.txt index f4bf453a108b..db4f960a9436 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -238,6 +238,7 @@ sortedcontainers==2.4.0 ; python_version >= "3.9" and python_version < "4.0" soupsieve==2.5 ; python_version >= "3.10" and python_version < "3.13" sphobjinv==2.3.1 ; python_version >= "3.10" and python_version < "3.13" sqlalchemy-exasol==4.6.3 ; python_version >= "3.9" and python_version < "4.0" +sqlalchemy-risingwave==1.0.0 ; python_version >= "3.9" and python_version < "4.0" sqlalchemy-views==0.3.2 ; python_version >= "3.9" and python_version < "4.0" sqlalchemy==1.4.51 ; python_version >= "3.9" and python_version < "4.0" sqlglot==20.8.0 ; python_version >= "3.9" and python_version < "4.0"