diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml index 26b889e0356cf..878208d2ef7f0 100644 --- a/.github/workflows/ibis-backends.yml +++ b/.github/workflows/ibis-backends.yml @@ -122,6 +122,12 @@ jobs: - postgres sys-deps: - libgeos-dev + - name: risingwave + title: Risingwave + services: + - risingwave + extras: + - risingwave - name: impala title: Impala serial: true @@ -211,6 +217,14 @@ jobs: - postgres sys-deps: - libgeos-dev + - os: windows-latest + backend: + name: risingwave + title: Risingwave + services: + - risingwave + extras: + - risingwave - os: windows-latest backend: name: postgres diff --git a/ci/schema/risingwave.sql b/ci/schema/risingwave.sql new file mode 100644 index 0000000000000..cedfa8449d60f --- /dev/null +++ b/ci/schema/risingwave.sql @@ -0,0 +1,177 @@ +SET RW_IMPLICIT_FLUSH=true; + +DROP TABLE IF EXISTS diamonds CASCADE; + +CREATE TABLE diamonds ( + carat FLOAT, + cut TEXT, + color TEXT, + clarity TEXT, + depth FLOAT, + "table" FLOAT, + price BIGINT, + x FLOAT, + y FLOAT, + z FLOAT +) WITH ( + connector = 'posix_fs', + match_pattern = 'diamonds.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS astronauts CASCADE; + +CREATE TABLE astronauts ( + "id" BIGINT, + "number" BIGINT, + "nationwide_number" BIGINT, + "name" VARCHAR, + "original_name" VARCHAR, + "sex" VARCHAR, + "year_of_birth" BIGINT, + "nationality" VARCHAR, + "military_civilian" VARCHAR, + "selection" VARCHAR, + "year_of_selection" BIGINT, + "mission_number" BIGINT, + "total_number_of_missions" BIGINT, + "occupation" VARCHAR, + "year_of_mission" BIGINT, + "mission_title" VARCHAR, + "ascend_shuttle" VARCHAR, + "in_orbit" VARCHAR, + "descend_shuttle" VARCHAR, + "hours_mission" DOUBLE PRECISION, + "total_hrs_sum" DOUBLE PRECISION, + "field21" BIGINT, + "eva_hrs_mission" DOUBLE PRECISION, + "total_eva_hrs" DOUBLE PRECISION +) WITH ( + connector = 'posix_fs', + match_pattern = 'astronauts.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS batting CASCADE; + +CREATE TABLE batting ( + "playerID" TEXT, + "yearID" BIGINT, + stint BIGINT, + "teamID" TEXT, + "lgID" TEXT, + "G" BIGINT, + "AB" BIGINT, + "R" BIGINT, + "H" BIGINT, + "X2B" BIGINT, + "X3B" BIGINT, + "HR" BIGINT, + "RBI" BIGINT, + "SB" BIGINT, + "CS" BIGINT, + "BB" BIGINT, + "SO" BIGINT, + "IBB" BIGINT, + "HBP" BIGINT, + "SH" BIGINT, + "SF" BIGINT, + "GIDP" BIGINT +) WITH ( + connector = 'posix_fs', + match_pattern = 'batting.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS awards_players CASCADE; + +CREATE TABLE awards_players ( + "playerID" TEXT, + "awardID" TEXT, + "yearID" BIGINT, + "lgID" TEXT, + tie TEXT, + notes TEXT +) WITH ( + connector = 'posix_fs', + match_pattern = 'awards_players.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS functional_alltypes CASCADE; + +CREATE TABLE functional_alltypes ( + id INTEGER, + bool_col BOOLEAN, + tinyint_col SMALLINT, + smallint_col SMALLINT, + int_col INTEGER, + bigint_col BIGINT, + float_col REAL, + double_col DOUBLE PRECISION, + date_string_col TEXT, + string_col TEXT, + timestamp_col TIMESTAMP WITHOUT TIME ZONE, + year INTEGER, + month INTEGER +) WITH ( + connector = 'posix_fs', + match_pattern = 'functional_alltypes.csv', + posix_fs.root = '/data', +) 
FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS tzone CASCADE; + +CREATE TABLE tzone ( + ts TIMESTAMP WITH TIME ZONE, + key TEXT, + value DOUBLE PRECISION +); + +INSERT INTO tzone + SELECT + CAST('2017-05-28 11:01:31.000400' AS TIMESTAMP WITH TIME ZONE) + + t * INTERVAL '1 day 1 second' AS ts, + CHR(97 + t) AS key, + t + t / 10.0 AS value + FROM generate_series(0, 9) AS t; + +DROP TABLE IF EXISTS array_types CASCADE; + +CREATE TABLE IF NOT EXISTS array_types ( + x BIGINT[], + y TEXT[], + z DOUBLE PRECISION[], + grouper TEXT, + scalar_column DOUBLE PRECISION, + multi_dim BIGINT[][] +); + +INSERT INTO array_types VALUES + (ARRAY[1, 2, 3], ARRAY['a', 'b', 'c'], ARRAY[1.0, 2.0, 3.0], 'a', 1.0, ARRAY[ARRAY[NULL::BIGINT, NULL, NULL], ARRAY[1, 2, 3]]), + (ARRAY[4, 5], ARRAY['d', 'e'], ARRAY[4.0, 5.0], 'a', 2.0, ARRAY[]::BIGINT[][]), + (ARRAY[6, NULL], ARRAY['f', NULL], ARRAY[6.0, NULL], 'a', 3.0, ARRAY[NULL, ARRAY[]::BIGINT[], NULL]), + (ARRAY[NULL, 1, NULL], ARRAY[NULL, 'a', NULL], ARRAY[]::DOUBLE PRECISION[], 'b', 4.0, ARRAY[ARRAY[1], ARRAY[2], ARRAY[NULL::BIGINT], ARRAY[3]]), + (ARRAY[2, NULL, 3], ARRAY['b', NULL, 'c'], NULL, 'b', 5.0, NULL), + (ARRAY[4, NULL, NULL, 5], ARRAY['d', NULL, NULL, 'e'], ARRAY[4.0, NULL, NULL, 5.0], 'c', 6.0, ARRAY[ARRAY[1, 2, 3]]); + +DROP TABLE IF EXISTS json_t CASCADE; + +CREATE TABLE IF NOT EXISTS json_t (js JSONB); + +INSERT INTO json_t VALUES + ('{"a": [1,2,3,4], "b": 1}'), + ('{"a":null,"b":2}'), + ('{"a":"foo", "c":null}'), + ('null'), + ('[42,47,55]'), + ('[]'); + +DROP TABLE IF EXISTS win CASCADE; +CREATE TABLE win (g TEXT, x BIGINT, y BIGINT); +INSERT INTO win VALUES + ('a', 0, 3), + ('a', 1, 2), + ('a', 2, 0), + ('a', 3, 1), + ('a', 4, 1); diff --git a/compose.yaml b/compose.yaml index 9d110ccbc7b67..388ee3e9f1159 100644 --- a/compose.yaml +++ b/compose.yaml @@ -538,6 +538,88 @@ services: networks: - impala + risingwave-minio: + image: "quay.io/minio/minio:latest" + command: + - server + - "--address" + - "0.0.0.0:9301" + - "--console-address" + - "0.0.0.0:9400" + - /data + expose: + - "9301" + - "9400" + ports: + - "9301:9301" + - "9400:9400" + depends_on: [] + volumes: + - "risingwave-minio:/data" + entrypoint: /bin/sh -c "set -e; mkdir -p \"/data/hummock001\"; /usr/bin/docker-entrypoint.sh \"$$0\" \"$$@\" " + environment: + MINIO_CI_CD: "1" + MINIO_ROOT_PASSWORD: hummockadmin + MINIO_ROOT_USER: hummockadmin + MINIO_DOMAIN: "risingwave-minio" + container_name: risingwave-minio + healthcheck: + test: + - CMD-SHELL + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/9301; exit $$?;' + interval: 5s + timeout: 5s + retries: 20 + restart: always + networks: + - risingwave + + risingwave: + image: ghcr.io/risingwavelabs/risingwave:nightly-20240122 + command: "standalone --meta-opts=\" \ + --advertise-addr 0.0.0.0:5690 \ + --backend mem \ + --state-store hummock+minio://hummockadmin:hummockadmin@risingwave-minio:9301/hummock001 \ + --data-directory hummock_001 \ + --config-path /risingwave.toml\" \ + --compute-opts=\" \ + --config-path /risingwave.toml \ + --advertise-addr 0.0.0.0:5688 \ + --role both \" \ + --frontend-opts=\" \ + --config-path /risingwave.toml \ + --listen-addr 0.0.0.0:4566 \ + --advertise-addr 0.0.0.0:4566 \" \ + --compactor-opts=\" \ + --advertise-addr 0.0.0.0:6660 \"" + expose: + - "4566" + ports: + - "4566:4566" + depends_on: + - risingwave-minio + volumes: + - "./docker/risingwave/risingwave.toml:/risingwave.toml" + - risingwave:/data + environment: + RUST_BACKTRACE: "1" + # If 
ENABLE_TELEMETRY is not set, telemetry will start by default + ENABLE_TELEMETRY: ${ENABLE_TELEMETRY:-true} + container_name: risingwave + healthcheck: + test: + - CMD-SHELL + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/6660; exit $$?;' + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/5688; exit $$?;' + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/4566; exit $$?;' + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/5690; exit $$?;' + interval: 5s + timeout: 5s + retries: 20 + restart: always + networks: + - risingwave + networks: impala: # docker defaults to naming networks "$PROJECT_$NETWORK" but the Java Hive @@ -554,6 +636,7 @@ networks: oracle: exasol: flink: + risingwave: volumes: broker_var: @@ -572,3 +655,5 @@ volumes: minio: exasol: impala: + risingwave-minio: + risingwave: diff --git a/docker/risingwave/risingwave.toml b/docker/risingwave/risingwave.toml new file mode 100644 index 0000000000000..43d57926ed16f --- /dev/null +++ b/docker/risingwave/risingwave.toml @@ -0,0 +1,2 @@ +# RisingWave config file to be mounted into the Docker containers. +# See https://github.com/risingwavelabs/risingwave/blob/main/src/config/example.toml for example diff --git a/ibis/backends/risingwave/__init__.py b/ibis/backends/risingwave/__init__.py new file mode 100644 index 0000000000000..04de491f6dfef --- /dev/null +++ b/ibis/backends/risingwave/__init__.py @@ -0,0 +1,282 @@ +"""Risingwave backend.""" + +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Callable, Literal + +import sqlalchemy as sa + +import ibis.common.exceptions as exc +import ibis.expr.operations as ops +from ibis import util +from ibis.backends.base.sql.alchemy import AlchemyCanCreateSchema, BaseAlchemyBackend +from ibis.backends.risingwave.compiler import RisingwaveCompiler +from ibis.backends.risingwave.datatypes import RisingwaveType +from ibis.common.exceptions import InvalidDecoratorError + +if TYPE_CHECKING: + from collections.abc import Iterable + + import ibis.expr.datatypes as dt + + +def _verify_source_line(func_name: str, line: str): + if line.startswith("@"): + raise InvalidDecoratorError(func_name, line) + return line + + +class Backend(BaseAlchemyBackend, AlchemyCanCreateSchema): + name = "risingwave" + compiler = RisingwaveCompiler + supports_temporary_tables = False + supports_create_or_replace = False + supports_python_udfs = False + + def do_connect( + self, + host: str | None = None, + user: str | None = None, + password: str | None = None, + port: int = 5432, + database: str | None = None, + schema: str | None = None, + url: str | None = None, + driver: Literal["psycopg2"] = "psycopg2", + ) -> None: + """Create an Ibis client connected to Risingwave database. + + Parameters + ---------- + host + Hostname + user + Username + password + Password + port + Port number + database + Database to connect to + schema + Risingwave schema to use. If `None`, use the default `search_path`. + url + SQLAlchemy connection string. + + If passed, the other connection arguments are ignored. 
+ driver + Database driver + + Examples + -------- + >>> import os + >>> import getpass + >>> import ibis + >>> host = os.environ.get("IBIS_TEST_RISINGWAVE_HOST", "localhost") + >>> user = os.environ.get("IBIS_TEST_RISINGWAVE_USER", getpass.getuser()) + >>> password = os.environ.get("IBIS_TEST_RISINGWAVE_PASSWORD") + >>> database = os.environ.get("IBIS_TEST_RISINGWAVE_DATABASE", "dev") + >>> con = connect(database=database, host=host, user=user, password=password) + >>> con.list_tables() # doctest: +ELLIPSIS + [...] + >>> t = con.table("functional_alltypes") + >>> t + RisingwaveTable[table] + name: functional_alltypes + schema: + id : int32 + bool_col : boolean + tinyint_col : int16 + smallint_col : int16 + int_col : int32 + bigint_col : int64 + float_col : float32 + double_col : float64 + date_string_col : string + string_col : string + timestamp_col : timestamp + year : int32 + month : int32 + """ + if driver != "psycopg2": + raise NotImplementedError("psycopg2 is currently the only supported driver") + + alchemy_url = self._build_alchemy_url( + url=url, + host=host, + port=port, + user=user, + password=password, + database=database, + driver=f"risingwave+{driver}", + ) + + connect_args = {} + if schema is not None: + connect_args["options"] = f"-csearch_path={schema}" + + engine = sa.create_engine( + alchemy_url, connect_args=connect_args, poolclass=sa.pool.StaticPool + ) + + @sa.event.listens_for(engine, "connect") + def connect(dbapi_connection, connection_record): + with dbapi_connection.cursor() as cur: + cur.execute("SET TIMEZONE = UTC") + + super().do_connect(engine) + + def list_tables(self, like=None, schema=None): + """List the tables in the database. + + Parameters + ---------- + like + A pattern to use for listing tables. + schema + The schema to perform the list against. + + ::: {.callout-warning} + ## `schema` refers to database hierarchy + + The `schema` parameter does **not** refer to the column names and + types of `table`. 
+ ::: + """ + tables = self.inspector.get_table_names(schema=schema) + views = self.inspector.get_view_names(schema=schema) + return self._filter_with_like(tables + views, like) + + def list_databases(self, like=None) -> list[str]: + # http://dba.stackexchange.com/a/1304/58517 + dbs = sa.table( + "pg_database", + sa.column("datname", sa.TEXT()), + sa.column("datistemplate", sa.BOOLEAN()), + schema="pg_catalog", + ) + query = sa.select(dbs.c.datname).where(sa.not_(dbs.c.datistemplate)) + with self.begin() as con: + databases = list(con.execute(query).scalars()) + + return self._filter_with_like(databases, like) + + @property + def current_database(self) -> str: + return self._scalar_query(sa.select(sa.func.current_database())) + + @property + def current_schema(self) -> str: + return self._scalar_query(sa.select(sa.func.current_schema())) + + def function(self, name: str, *, schema: str | None = None) -> Callable: + query = sa.text( + """ +SELECT + n.nspname as schema, + pg_catalog.pg_get_function_result(p.oid) as return_type, + string_to_array(pg_catalog.pg_get_function_arguments(p.oid), ', ') as signature, + CASE p.prokind + WHEN 'a' THEN 'agg' + WHEN 'w' THEN 'window' + WHEN 'p' THEN 'proc' + ELSE 'func' + END as "Type" +FROM pg_catalog.pg_proc p +LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace +WHERE p.proname = :name +""" + + "AND n.nspname OPERATOR(pg_catalog.~) :schema COLLATE pg_catalog.default" + * (schema is not None) + ).bindparams(name=name, schema=f"^({schema})$") + + def split_name_type(arg: str) -> tuple[str, dt.DataType]: + name, typ = arg.split(" ", 1) + return name, RisingwaveType.from_string(typ) + + with self.begin() as con: + rows = con.execute(query).mappings().fetchall() + + if not rows: + name = f"{schema}.{name}" if schema else name + raise exc.MissingUDFError(name) + elif len(rows) > 1: + raise exc.AmbiguousUDFError(name) + + [row] = rows + return_type = RisingwaveType.from_string(row["return_type"]) + signature = list(map(split_name_type, row["signature"])) + + # dummy callable + def fake_func(*args, **kwargs): + ... 
+ + fake_func.__name__ = name + fake_func.__signature__ = inspect.Signature( + [ + inspect.Parameter( + name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=typ + ) + for name, typ in signature + ], + return_annotation=return_type, + ) + fake_func.__annotations__ = {"return": return_type, **dict(signature)} + op = ops.udf.scalar.builtin(fake_func, schema=schema) + return op + + def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: + name = util.gen_name("risingwave_metadata") + type_info_sql = """\ + SELECT + attname, + format_type(atttypid, atttypmod) AS type + FROM pg_attribute + WHERE attrelid = CAST(:name AS regclass) + AND attnum > 0 + AND NOT attisdropped + ORDER BY attnum""" + if self.inspector.has_table(query): + query = f"TABLE {query}" + + text = sa.text(type_info_sql).bindparams(name=name) + with self.begin() as con: + con.exec_driver_sql(f"CREATE VIEW IF NOT EXISTS {name} AS {query}") + try: + yield from ( + (col, RisingwaveType.from_string(typestr)) + for col, typestr in con.execute(text) + ) + finally: + con.exec_driver_sql(f"DROP VIEW IF EXISTS {name}") + + def _get_temp_view_definition( + self, name: str, definition: sa.sql.compiler.Compiled + ) -> str: + yield f"DROP VIEW IF EXISTS {name}" + yield f"CREATE TEMPORARY VIEW {name} AS {definition}" + + def create_schema( + self, name: str, database: str | None = None, force: bool = False + ) -> None: + if database is not None and database != self.current_database: + raise exc.UnsupportedOperationError( + "Risingwave does not support creating a schema in a different database" + ) + if_not_exists = "IF NOT EXISTS " * force + name = self._quote(name) + with self.begin() as con: + con.exec_driver_sql(f"CREATE SCHEMA {if_not_exists}{name}") + + def drop_schema( + self, name: str, database: str | None = None, force: bool = False + ) -> None: + if database is not None and database != self.current_database: + raise exc.UnsupportedOperationError( + "Risingwave does not support dropping a schema in a different database" + ) + name = self._quote(name) + if_exists = "IF EXISTS " * force + with self.begin() as con: + con.exec_driver_sql(f"DROP SCHEMA {if_exists}{name}") diff --git a/ibis/backends/risingwave/compiler.py b/ibis/backends/risingwave/compiler.py new file mode 100644 index 0000000000000..b4bcd9c0b9d5c --- /dev/null +++ b/ibis/backends/risingwave/compiler.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import ibis.expr.operations as ops +from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator +from ibis.backends.risingwave.datatypes import RisingwaveType +from ibis.backends.risingwave.registry import operation_registry +from ibis.expr.rewrites import rewrite_sample + + +class RisingwaveExprTranslator(AlchemyExprTranslator): + _registry = operation_registry.copy() + _rewrites = AlchemyExprTranslator._rewrites.copy() + _has_reduction_filter_syntax = True + _supports_tuple_syntax = True + _dialect_name = "risingwave" + + # it does support it, but we can't use it because of support for pivot + supports_unnest_in_select = False + + type_mapper = RisingwaveType + + +rewrites = RisingwaveExprTranslator.rewrites + + +@rewrites(ops.Any) +@rewrites(ops.All) +def _any_all_no_op(expr): + return expr + + +class RisingwaveCompiler(AlchemyCompiler): + translator_class = RisingwaveExprTranslator + rewrites = AlchemyCompiler.rewrites | rewrite_sample diff --git a/ibis/backends/risingwave/datatypes.py b/ibis/backends/risingwave/datatypes.py new file mode 100644 index 
0000000000000..389210486a6f8 --- /dev/null +++ b/ibis/backends/risingwave/datatypes.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as psql +import sqlalchemy.types as sat + +import ibis.expr.datatypes as dt +from ibis.backends.base.sql.alchemy.datatypes import AlchemyType +from ibis.backends.base.sqlglot.datatypes import PostgresType as SqlglotPostgresType + +_from_postgres_types = { + psql.DOUBLE_PRECISION: dt.Float64, + psql.JSONB: dt.JSON, + psql.JSON: dt.JSON, + psql.BYTEA: dt.Binary, +} + + +_postgres_interval_fields = { + "YEAR": "Y", + "MONTH": "M", + "DAY": "D", + "HOUR": "h", + "MINUTE": "m", + "SECOND": "s", + "YEAR TO MONTH": "M", + "DAY TO HOUR": "h", + "DAY TO MINUTE": "m", + "DAY TO SECOND": "s", + "HOUR TO MINUTE": "m", + "HOUR TO SECOND": "s", + "MINUTE TO SECOND": "s", +} + + +class RisingwaveType(AlchemyType): + dialect = "risingwave" + + @classmethod + def from_ibis(cls, dtype: dt.DataType) -> sat.TypeEngine: + if dtype.is_floating(): + if isinstance(dtype, dt.Float64): + return psql.DOUBLE_PRECISION + else: + return psql.REAL + elif dtype.is_array(): + # Unwrap the array element type because sqlalchemy doesn't allow arrays of + # arrays. This doesn't affect the underlying data. + while dtype.is_array(): + dtype = dtype.value_type + return sa.ARRAY(cls.from_ibis(dtype)) + elif dtype.is_map(): + if not (dtype.key_type.is_string() and dtype.value_type.is_string()): + raise TypeError( + f"Risingwave only supports map<string, string>, got: {dtype}" + ) + return psql.HSTORE() + elif dtype.is_uuid(): + return psql.UUID() + else: + return super().from_ibis(dtype) + + @classmethod + def to_ibis(cls, typ: sat.TypeEngine, nullable: bool = True) -> dt.DataType: + if dtype := _from_postgres_types.get(type(typ)): + return dtype(nullable=nullable) + elif isinstance(typ, psql.HSTORE): + return dt.Map(dt.string, dt.string, nullable=nullable) + elif isinstance(typ, psql.INTERVAL): + field = typ.fields.upper() + if (unit := _postgres_interval_fields.get(field, None)) is None: + raise ValueError(f"Unknown Risingwave interval field {field!r}") + elif unit in {"Y", "M"}: + raise ValueError( + "Variable length intervals are not yet supported with Risingwave" + ) + return dt.Interval(unit=unit, nullable=nullable) + else: + return super().to_ibis(typ, nullable=nullable) + + @classmethod + def from_string(cls, type_string: str) -> RisingwaveType: + return SqlglotPostgresType.from_string(type_string) diff --git a/ibis/backends/risingwave/registry.py b/ibis/backends/risingwave/registry.py new file mode 100644 index 0000000000000..a6cb67ca83a8f --- /dev/null +++ b/ibis/backends/risingwave/registry.py @@ -0,0 +1,848 @@ +from __future__ import annotations + +import functools +import itertools +import locale +import operator +import platform +import re +import string + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as pg +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.functions import GenericFunction + +import ibis.common.exceptions as com +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops + +# used for literal translate +from ibis.backends.base.sql.alchemy import ( + fixed_arity, + get_sqla_table, + reduction, + sqlalchemy_operation_registry, + sqlalchemy_window_functions_registry, + unary, + varargs, +) +from ibis.backends.base.sql.alchemy.registry import ( + _bitwise_op, + _extract, + get_col, +) + +operation_registry = sqlalchemy_operation_registry.copy()
+operation_registry.update(sqlalchemy_window_functions_registry) + +_truncate_precisions = { + "us": "microseconds", + "ms": "milliseconds", + "s": "second", + "m": "minute", + "h": "hour", + "D": "day", + "W": "week", + "M": "month", + "Q": "quarter", + "Y": "year", +} + + +def _timestamp_truncate(t, op): + sa_arg = t.translate(op.arg) + try: + precision = _truncate_precisions[op.unit.short] + except KeyError: + raise com.UnsupportedOperationError(f"Unsupported truncate unit {op.unit!r}") + return sa.func.date_trunc(precision, sa_arg) + + +def _timestamp_bucket(t, op): + arg = t.translate(op.arg) + interval = t.translate(op.interval) + + origin = sa.literal_column("timestamp '1970-01-01 00:00:00'") + + if op.offset is not None: + origin = origin + t.translate(op.offset) + return sa.func.date_bin(interval, arg, origin) + + +def _typeof(t, op): + sa_arg = t.translate(op.arg) + typ = sa.cast(sa.func.pg_typeof(sa_arg), sa.TEXT) + + # select pg_typeof('thing') returns unknown so we have to check the child's + # type for nullness + return sa.case( + ((typ == "unknown") & (op.arg.dtype != dt.null), "text"), + ((typ == "unknown") & (op.arg.dtype == dt.null), "null"), + else_=typ, + ) + + +_strftime_to_postgresql_rules = { + "%a": "TMDy", # TM does it in a locale dependent way + "%A": "TMDay", + "%w": "D", # 1-based day of week, see below for how we make this 0-based + "%d": "DD", # day of month + "%-d": "FMDD", # - is no leading zero for Python same for FM in postgres + "%b": "TMMon", # Sep + "%B": "TMMonth", # September + "%m": "MM", # 01 + "%-m": "FMMM", # 1 + "%y": "YY", # 15 + "%Y": "YYYY", # 2015 + "%H": "HH24", # 09 + "%-H": "FMHH24", # 9 + "%I": "HH12", # 09 + "%-I": "FMHH12", # 9 + "%p": "AM", # AM or PM + "%M": "MI", # zero padded minute + "%-M": "FMMI", # Minute + "%S": "SS", # zero padded second + "%-S": "FMSS", # Second + "%f": "US", # zero padded microsecond + "%z": "OF", # utf offset + "%Z": "TZ", # uppercase timezone name + "%j": "DDD", # zero padded day of year + "%-j": "FMDDD", # day of year + "%U": "WW", # 1-based week of year + # 'W': ?, # meh +} + +try: + _strftime_to_postgresql_rules.update( + { + "%c": locale.nl_langinfo(locale.D_T_FMT), # locale date and time + "%x": locale.nl_langinfo(locale.D_FMT), # locale date + "%X": locale.nl_langinfo(locale.T_FMT), # locale time + } + ) +except AttributeError: + HAS_LANGINFO = False +else: + HAS_LANGINFO = True + + +# translate strftime spec into mostly equivalent Risingwave spec +_scanner = re.Scanner( # type: ignore # re does have a Scanner attribute + # double quotes need to be escaped + [('"', lambda *_: r"\"")] + + [ + ( + "|".join( + map( + "(?:{})".format, + itertools.chain( + _strftime_to_postgresql_rules.keys(), + [ + # "%e" is in the C standard and Python actually + # generates this if your spec contains "%c" but we + # don't officially support it as a specifier so we + # need to special case it in the scanner + "%e", + r"\s+", + rf"[{re.escape(string.punctuation)}]", + rf"[^{re.escape(string.punctuation)}\s]+", + ], + ), + ) + ), + lambda _, token: token, + ) + ] +) + + +_lexicon_values = frozenset(_strftime_to_postgresql_rules.values()) + +_locale_specific_formats = frozenset(["%c", "%x", "%X"]) +_strftime_blacklist = frozenset(["%w", "%U", "%e"]) | _locale_specific_formats + + +def _reduce_tokens(tokens, arg): + # current list of tokens + curtokens = [] + + # reduced list of tokens that accounts for blacklisted values + reduced = [] + + non_special_tokens = frozenset(_strftime_to_postgresql_rules) - _strftime_blacklist + 
+ # TODO: how much of a hack is this? + for token in tokens: + if token in _locale_specific_formats and not HAS_LANGINFO: + raise com.UnsupportedOperationError( + f"Format string component {token!r} is not supported on {platform.system()}" + ) + # we are a non-special token %A, %d, etc. + if token in non_special_tokens: + curtokens.append(_strftime_to_postgresql_rules[token]) + + # we have a string like DD, to escape this we + # surround it with double quotes + elif token in _lexicon_values: + curtokens.append(f'"{token}"') + + # we have a token that needs special treatment + elif token in _strftime_blacklist: + if token == "%w": + value = sa.extract("dow", arg) # 0 based day of week + elif token == "%U": + value = sa.cast(sa.func.to_char(arg, "WW"), sa.SMALLINT) - 1 + elif token in ("%c", "%x", "%X"): + # re scan and tokenize this pattern + try: + new_pattern = _strftime_to_postgresql_rules[token] + except KeyError: + raise ValueError( + "locale specific date formats (%%c, %%x, %%X) are " + "not yet implemented for %s" % platform.system() + ) + + new_tokens, _ = _scanner.scan(new_pattern) + value = functools.reduce( + sa.sql.ColumnElement.concat, + _reduce_tokens(new_tokens, arg), + ) + elif token == "%e": + # pad with spaces instead of zeros + value = sa.func.replace(sa.func.to_char(arg, "DD"), "0", " ") + + reduced += [ + sa.func.to_char(arg, "".join(curtokens)), + sa.cast(value, sa.TEXT), + ] + + # empty current token list in case there are more tokens + del curtokens[:] + + # uninteresting text + else: + curtokens.append(token) + # append result to r if we had more tokens or if we have no + # blacklisted tokens + if curtokens: + reduced.append(sa.func.to_char(arg, "".join(curtokens))) + return reduced + + +def _strftime(arg, pattern): + tokens, _ = _scanner.scan(pattern.value) + reduced = _reduce_tokens(tokens, arg) + return functools.reduce(sa.sql.ColumnElement.concat, reduced) + + +def _find_in_set(t, op): + # TODO + # this operation works with any type, not just strings. should the + # operation itself also have this property? + return ( + sa.func.coalesce( + sa.func.array_position( + pg.array(list(map(t.translate, op.values))), + t.translate(op.needle), + ), + 0, + ) + - 1 + ) + + +def _log(t, op): + arg, base = op.args + sa_arg = t.translate(arg) + if base is not None: + sa_base = t.translate(base) + return sa.cast( + sa.func.log(sa.cast(sa_base, sa.NUMERIC), sa.cast(sa_arg, sa.NUMERIC)), + t.get_sqla_type(op.dtype), + ) + return sa.func.ln(sa_arg) + + +def _regex_extract(arg, pattern, index): + # wrap in parens to support 0th group being the whole string + pattern = "(" + pattern + ")" + # arrays are 1-based in postgres + index = index + 1 + does_match = sa.func.textregexeq(arg, pattern) + matches = sa.func.regexp_match(arg, pattern, type_=pg.ARRAY(sa.TEXT)) + return sa.case((does_match, matches[index]), else_=None) + + +def _array_repeat(t, op): + """Repeat an array.""" + arg = t.translate(op.arg) + times = t.translate(op.times) + + array_length = sa.func.cardinality(arg) + array = sa.sql.elements.Grouping(arg) if isinstance(op.arg, ops.Literal) else arg + + # sequence from 1 to the total number of elements desired in steps of 1. 
+ series = sa.func.generate_series(1, times * array_length).table_valued() + + # if our current index modulo the array's length is a multiple of the + # array's length, then the index is the array's length + index = sa.func.coalesce( + sa.func.nullif(series.column % array_length, 0), array_length + ) + + # tie it all together in a scalar subquery and collapse that into an ARRAY + return sa.func.array(sa.select(array[index]).scalar_subquery()) + + +def _table_column(t, op): + ctx = t.context + table = op.table + + sa_table = get_sqla_table(ctx, table) + out_expr = get_col(sa_table, op) + + if op.dtype.is_timestamp(): + timezone = op.dtype.timezone + if timezone is not None: + out_expr = out_expr.op("AT TIME ZONE")(timezone).label(op.name) + + # If the column does not originate from the table set in the current SELECT + # context, we should format as a subquery + if t.permit_subquery and ctx.is_foreign_expr(table): + return sa.select(out_expr) + + return out_expr + + +def _round(t, op): + arg, digits = op.args + sa_arg = t.translate(arg) + + if digits is None: + return sa.func.round(sa_arg) + + # postgres doesn't allow rounding of double precision values to a specific + # number of digits (though simple truncation on doubles is allowed) so + # we cast to numeric and then cast back if necessary + result = sa.func.round(sa.cast(sa_arg, sa.NUMERIC), t.translate(digits)) + if digits is not None and arg.dtype.is_decimal(): + return result + result = sa.cast(result, pg.DOUBLE_PRECISION()) + return result + + +def _mod(t, op): + left, right = map(t.translate, op.args) + + # postgres doesn't allow modulus of double precision values, so upcast and + # then downcast later if necessary + if not op.dtype.is_integer(): + left = sa.cast(left, sa.NUMERIC) + right = sa.cast(right, sa.NUMERIC) + + result = left % right + if op.dtype.is_float64(): + return sa.cast(result, pg.DOUBLE_PRECISION()) + else: + return result + + +def _neg_idx_to_pos(array, idx): + return sa.case((idx < 0, sa.func.cardinality(array) + idx), else_=idx) + + +def _array_slice(*, index_converter, array_length, func): + def translate(t, op): + arg = t.translate(op.arg) + + arg_length = array_length(arg) + + if (start := op.start) is None: + start = 0 + else: + start = t.translate(start) + start = sa.func.least(arg_length, index_converter(arg, start)) + + if (stop := op.stop) is None: + stop = arg_length + else: + stop = index_converter(arg, t.translate(stop)) + + return func(arg, start + 1, stop) + + return translate + + +def _array_index(*, index_converter, func): + def translate(t, op): + sa_array = t.translate(op.arg) + sa_index = t.translate(op.index) + if isinstance(op.arg, ops.Literal): + sa_array = sa.sql.elements.Grouping(sa_array) + return func(sa_array, index_converter(sa_array, sa_index) + 1) + + return translate + + +def _literal(t, op): + dtype = op.dtype + value = op.value + + if value is None: + return ( + sa.null() if dtype.is_null() else sa.cast(sa.null(), t.get_sqla_type(dtype)) + ) + if dtype.is_interval(): + return sa.literal_column(f"INTERVAL '{value} {dtype.resolution}'") + elif dtype.is_array(): + return pg.array(value) + elif dtype.is_map(): + return pg.hstore(list(value.keys()), list(value.values())) + elif dtype.is_time(): + return sa.func.make_time( + value.hour, value.minute, value.second + value.microsecond / 1e6 + ) + elif dtype.is_date(): + return sa.func.make_date(value.year, value.month, value.day) + elif dtype.is_timestamp(): + if (tz := dtype.timezone) is not None: + return 
sa.func.to_timestamp(value.timestamp()).op("AT TIME ZONE")(tz) + return sa.cast(sa.literal(value.isoformat()), sa.TIMESTAMP()) + else: + return sa.literal(value) + + +def _string_agg(t, op): + agg = sa.func.string_agg(t.translate(op.arg), t.translate(op.sep)) + if (where := op.where) is not None: + return agg.filter(t.translate(where)) + return agg + + +def _corr(t, op): + if op.how == "sample": + raise ValueError( + f"{t.__class__.__name__} only implements population correlation " + "coefficient" + ) + return _binary_variance_reduction(sa.func.corr)(t, op) + + +def _covar(t, op): + suffix = {"sample": "samp", "pop": "pop"} + how = suffix.get(op.how, "samp") + func = getattr(sa.func, f"covar_{how}") + return _binary_variance_reduction(func)(t, op) + + +def _mode(t, op): + arg = op.arg + if (where := op.where) is not None: + arg = ops.IfElse(where, arg, None) + return sa.func.mode().within_group(t.translate(arg)) + + +def _quantile(t, op): + arg = op.arg + if (where := op.where) is not None: + arg = ops.IfElse(where, arg, None) + if arg.dtype.is_numeric(): + func = sa.func.percentile_cont + else: + func = sa.func.percentile_disc + return func(t.translate(op.quantile)).within_group(t.translate(arg)) + + +def _median(t, op): + arg = op.arg + if (where := op.where) is not None: + arg = ops.IfElse(where, arg, None) + + if arg.dtype.is_numeric(): + func = sa.func.percentile_cont + else: + func = sa.func.percentile_disc + return func(0.5).within_group(t.translate(arg)) + + +def _binary_variance_reduction(func): + def variance_compiler(t, op): + x = op.left + if (x_type := x.dtype).is_boolean(): + x = ops.Cast(x, dt.Int32(nullable=x_type.nullable)) + + y = op.right + if (y_type := y.dtype).is_boolean(): + y = ops.Cast(y, dt.Int32(nullable=y_type.nullable)) + + if t._has_reduction_filter_syntax: + result = func(t.translate(x), t.translate(y)) + + if (where := op.where) is not None: + return result.filter(t.translate(where)) + return result + else: + if (where := op.where) is not None: + x = ops.IfElse(where, x, None) + y = ops.IfElse(where, y, None) + return func(t.translate(x), t.translate(y)) + + return variance_compiler + + +def _arg_min_max(sort_func): + def translate(t, op: ops.ArgMin | ops.ArgMax) -> str: + arg = t.translate(op.arg) + key = t.translate(op.key) + + conditions = [arg != sa.null(), key != sa.null()] + + agg = sa.func.array_agg(pg.aggregate_order_by(arg, sort_func(key))) + + if (where := op.where) is not None: + conditions.append(t.translate(where)) + return agg.filter(sa.and_(*conditions))[1] + + return translate + + +def _arbitrary(t, op): + if (how := op.how) == "heavy": + raise com.UnsupportedOperationError( + f"risingwave backend doesn't support how={how!r} for the arbitrary() aggregate" + ) + func = getattr(sa.func, op.how) + return t._reduction(func, op) + + +class rw_struct_field(GenericFunction): + inherit_cache = True + + +@compiles(rw_struct_field) +def compile_struct_field_postgresql(element, compiler, **kw): + arg, field = element.clauses + return f"({compiler.process(arg, **kw)}).{field.name}" + + +def _struct_field(t, op): + arg = op.arg + idx = arg.dtype.names.index(op.field) + 1 + field_name = sa.literal_column(f"f{idx:d}") + return rw_struct_field( + t.translate(arg), field_name, type_=t.get_sqla_type(op.dtype) + ) + + +def _struct_column(t, op): + types = op.dtype.types + return sa.func.row( + # we have to cast here, otherwise risingwave refuses to allow the statement + *map(t.translate, map(ops.Cast, op.values, types)), + type_=t.get_sqla_type( + 
dt.Struct({f"f{i:d}": typ for i, typ in enumerate(types, start=1)}) + ), + ) + + +def _unnest(t, op): + arg = op.arg + row_type = arg.dtype.value_type + + types = getattr(row_type, "types", (row_type,)) + + is_struct = row_type.is_struct() + derived = ( + sa.func.unnest(t.translate(arg)) + .table_valued( + *( + sa.column(f"f{i:d}", stype) + for i, stype in enumerate(map(t.get_sqla_type, types), start=1) + ) + ) + .render_derived(with_types=is_struct) + ) + + # wrap in a row column so that we can return a single column from this rule + if not is_struct: + return derived.c[0] + return sa.func.row(*derived.c) + + +def _array_sort(arg): + flat = sa.func.unnest(arg).column_valued() + return sa.func.array(sa.select(flat).order_by(flat).scalar_subquery()) + + +def _array_position(haystack, needle): + t = ( + sa.func.unnest(haystack) + .table_valued("value", with_ordinality="idx", name="haystack") + .render_derived() + ) + idx = t.c.idx - 1 + return sa.func.coalesce( + sa.select(idx).where(t.c.value == needle).limit(1).scalar_subquery(), -1 + ) + + +def _array_map(t, op): + return sa.func.array( + # this translates to the function call, with column names the same as + # the parameter names in the lambda + sa.select(t.translate(op.body)) + .select_from( + # unnest the input array + sa.func.unnest(t.translate(op.arg)) + # name the columns of the result the same as the lambda parameter + # so that we can reference them as such in the outer query + .table_valued(op.param) + .render_derived() + ) + .scalar_subquery() + ) + + +def _array_filter(t, op): + param = op.param + return sa.func.array( + sa.select(sa.column(param, type_=t.get_sqla_type(op.arg.dtype.value_type))) + .select_from( + sa.func.unnest(t.translate(op.arg)).table_valued(param).render_derived() + ) + .where(t.translate(op.body)) + .scalar_subquery() + ) + + +def zero_value(dtype): + if dtype.is_interval(): + return sa.func.make_interval() + return 0 + + +def interval_sign(v): + zero = sa.func.make_interval() + return sa.case((v == zero, 0), (v < zero, -1), (v > zero, 1)) + + +def _sign(value, dtype): + if dtype.is_interval(): + return interval_sign(value) + return sa.func.sign(value) + + +def _range(t, op): + start = t.translate(op.start) + stop = t.translate(op.stop) + step = t.translate(op.step) + satype = t.get_sqla_type(op.dtype) + seq = sa.func.generate_series(start, stop, step, type_=satype) + zero = zero_value(op.step.dtype) + return sa.case( + ( + sa.and_( + sa.func.nullif(step, zero).is_not(None), + _sign(step, op.step.dtype) == _sign(stop - start, op.step.dtype), + ), + sa.func.array_remove( + sa.func.array(sa.select(seq).scalar_subquery()), stop, type_=satype + ), + ), + else_=sa.cast(pg.array([]), satype), + ) + + +operation_registry.update( + { + ops.Literal: _literal, + # We override this here to support time zones + ops.TableColumn: _table_column, + ops.Argument: lambda t, op: sa.column( + op.param, type_=t.get_sqla_type(op.dtype) + ), + # types + ops.TypeOf: _typeof, + # Floating + ops.IsNan: fixed_arity(lambda arg: arg == float("nan"), 1), + ops.IsInf: fixed_arity( + lambda arg: sa.or_(arg == float("inf"), arg == float("-inf")), 1 + ), + # boolean reductions + ops.Any: reduction(sa.func.bool_or), + ops.All: reduction(sa.func.bool_and), + # strings + ops.GroupConcat: _string_agg, + ops.Capitalize: unary(sa.func.initcap), + ops.RegexSearch: fixed_arity(lambda x, y: x.op("~")(y), 2), + # postgres defaults to replacing only the first occurrence + ops.RegexReplace: fixed_arity( + lambda string, pattern, replacement: 
sa.func.regexp_replace( + string, pattern, replacement, "g" + ), + 3, + ), + ops.Translate: fixed_arity(sa.func.translate, 3), + ops.RegexExtract: fixed_arity(_regex_extract, 3), + ops.StringSplit: fixed_arity( + lambda col, sep: sa.func.string_to_array( + col, sep, type_=sa.ARRAY(col.type) + ), + 2, + ), + ops.FindInSet: _find_in_set, + # math + ops.Log: _log, + ops.Log2: unary(lambda x: sa.func.log(2, x)), + ops.Log10: unary(sa.func.log), + ops.Round: _round, + ops.Modulus: _mod, + # dates and times + ops.DateFromYMD: fixed_arity(sa.func.make_date, 3), + ops.DateTruncate: _timestamp_truncate, + ops.TimestampTruncate: _timestamp_truncate, + ops.TimestampBucket: _timestamp_bucket, + ops.IntervalFromInteger: ( + lambda t, op: t.translate(op.arg) + * sa.text(f"INTERVAL '1 {op.dtype.resolution}'") + ), + ops.DateAdd: fixed_arity(operator.add, 2), + ops.DateSub: fixed_arity(operator.sub, 2), + ops.DateDiff: fixed_arity(operator.sub, 2), + ops.TimestampAdd: fixed_arity(operator.add, 2), + ops.TimestampSub: fixed_arity(operator.sub, 2), + ops.TimestampDiff: fixed_arity(operator.sub, 2), + ops.Strftime: fixed_arity(_strftime, 2), + ops.ExtractEpochSeconds: fixed_arity( + lambda arg: sa.cast(sa.extract("epoch", arg), sa.INTEGER), 1 + ), + ops.ExtractDayOfYear: _extract("doy"), + ops.ExtractWeekOfYear: _extract("week"), + # extracting the second gives us the fractional part as well, so smash that + # with a cast to SMALLINT + ops.ExtractSecond: fixed_arity( + lambda arg: sa.cast(sa.func.floor(sa.extract("second", arg)), sa.SMALLINT), + 1, + ), + # we get total number of milliseconds including seconds with extract so we + # mod 1000 + ops.ExtractMillisecond: fixed_arity( + lambda arg: sa.cast( + sa.func.floor(sa.extract("millisecond", arg)) % 1000, + sa.SMALLINT, + ), + 1, + ), + ops.DayOfWeekIndex: fixed_arity( + lambda arg: sa.cast( + sa.cast(sa.extract("dow", arg) + 6, sa.SMALLINT) % 7, sa.SMALLINT + ), + 1, + ), + ops.DayOfWeekName: fixed_arity( + lambda arg: sa.func.trim(sa.func.to_char(arg, "Day")), 1 + ), + ops.TimeFromHMS: fixed_arity(sa.func.make_time, 3), + # array operations + ops.ArrayLength: unary(sa.func.cardinality), + ops.ArrayCollect: reduction(sa.func.array_agg), + ops.Array: (lambda t, op: pg.array(list(map(t.translate, op.exprs)))), + ops.ArraySlice: _array_slice( + index_converter=_neg_idx_to_pos, + array_length=sa.func.cardinality, + func=lambda arg, start, stop: arg[start:stop], + ), + ops.ArrayIndex: _array_index( + index_converter=_neg_idx_to_pos, func=lambda arg, index: arg[index] + ), + ops.ArrayConcat: varargs(lambda *args: functools.reduce(operator.add, args)), + ops.ArrayRepeat: _array_repeat, + ops.Unnest: _unnest, + ops.Covariance: _covar, + ops.Correlation: _corr, + ops.BitwiseXor: _bitwise_op("#"), + ops.Mode: _mode, + ops.ApproxMedian: _median, + ops.Median: _median, + ops.Quantile: _quantile, + ops.MultiQuantile: _quantile, + ops.TimestampNow: lambda t, op: sa.literal_column( + "CURRENT_TIMESTAMP", type_=t.get_sqla_type(op.dtype) + ), + ops.MapGet: fixed_arity( + lambda arg, key, default: sa.case( + (arg.has_key(key), arg[key]), else_=default + ), + 3, + ), + ops.MapContains: fixed_arity(pg.HSTORE.Comparator.has_key, 2), + ops.MapKeys: unary(pg.HSTORE.Comparator.keys), + ops.MapValues: unary(pg.HSTORE.Comparator.vals), + ops.MapMerge: fixed_arity(operator.add, 2), + ops.MapLength: unary(lambda arg: sa.func.cardinality(arg.keys())), + ops.Map: fixed_arity(pg.hstore, 2), + ops.ArgMin: _arg_min_max(sa.asc), + ops.ArgMax: _arg_min_max(sa.desc), + ops.ArrayStringJoin: 
fixed_arity( + lambda sep, arr: sa.func.array_to_string(arr, sep), 2 + ), + ops.Strip: unary(lambda arg: sa.func.trim(arg, string.whitespace)), + ops.LStrip: unary(lambda arg: sa.func.ltrim(arg, string.whitespace)), + ops.RStrip: unary(lambda arg: sa.func.rtrim(arg, string.whitespace)), + ops.StartsWith: fixed_arity(lambda arg, prefix: arg.op("^@")(prefix), 2), + ops.Arbitrary: _arbitrary, + ops.StructColumn: _struct_column, + ops.StructField: _struct_field, + ops.First: reduction(sa.func.first), + ops.Last: reduction(sa.func.last), + ops.ExtractMicrosecond: fixed_arity( + lambda arg: sa.extract("microsecond", arg) % 1_000_000, 1 + ), + ops.Levenshtein: fixed_arity(sa.func.levenshtein, 2), + ops.ArraySort: fixed_arity(_array_sort, 1), + ops.ArrayIntersect: fixed_arity( + lambda left, right: sa.func.array( + sa.intersect( + sa.select(sa.func.unnest(left).column_valued()), + sa.select(sa.func.unnest(right).column_valued()), + ).scalar_subquery() + ), + 2, + ), + ops.ArrayRemove: fixed_arity( + lambda left, right: sa.func.array( + sa.except_( + sa.select(sa.func.unnest(left).column_valued()), sa.select(right) + ).scalar_subquery() + ), + 2, + ), + ops.ArrayUnion: fixed_arity( + lambda left, right: sa.func.array( + sa.union( + sa.select(sa.func.unnest(left).column_valued()), + sa.select(sa.func.unnest(right).column_valued()), + ).scalar_subquery() + ), + 2, + ), + ops.ArrayDistinct: fixed_arity( + lambda arg: sa.case( + (arg.is_(sa.null()), sa.null()), + else_=sa.func.array( + sa.select( + sa.distinct(sa.func.unnest(arg).column_valued()) + ).scalar_subquery() + ), + ), + 1, + ), + ops.ArrayPosition: fixed_arity(_array_position, 2), + ops.ArrayMap: _array_map, + ops.ArrayFilter: _array_filter, + ops.IntegerRange: _range, + ops.TimestampRange: _range, + ops.RegexSplit: fixed_arity(sa.func.regexp_split_to_array, 2), + } +) diff --git a/ibis/backends/risingwave/tests/__init__.py b/ibis/backends/risingwave/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ibis/backends/risingwave/tests/conftest.py b/ibis/backends/risingwave/tests/conftest.py new file mode 100644 index 0000000000000..35cfe6b8e1dbd --- /dev/null +++ b/ibis/backends/risingwave/tests/conftest.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any + +import pytest +import sqlalchemy as sa + +import ibis +from ibis.backends.conftest import init_database +from ibis.backends.tests.base import ServiceBackendTest + +if TYPE_CHECKING: + from collections.abc import Iterable + from pathlib import Path + +PG_USER = os.environ.get("IBIS_TEST_RISINGWAVE_USER", os.environ.get("PGUSER", "root")) +PG_PASS = os.environ.get( + "IBIS_TEST_RISINGWAVE_PASSWORD", os.environ.get("PGPASSWORD", "") +) +PG_HOST = os.environ.get( + "IBIS_TEST_RISINGWAVE_HOST", os.environ.get("PGHOST", "localhost") +) +PG_PORT = os.environ.get("IBIS_TEST_RISINGWAVE_PORT", os.environ.get("PGPORT", 4566)) +IBIS_TEST_RISINGWAVE_DB = os.environ.get( + "IBIS_TEST_RISINGWAVE_DATABASE", os.environ.get("PGDATABASE", "dev") +) + + +class TestConf(ServiceBackendTest): + # postgres rounds half to even for double precision and half away from zero + # for numeric and decimal + + returned_timestamp_unit = "s" + supports_structs = False + rounding_method = "half_to_even" + service_name = "risingwave" + deps = "psycopg2", "sqlalchemy" + + @property + def test_files(self) -> Iterable[Path]: + return self.data_dir.joinpath("csv").glob("*.csv") + + def _load_data( + self, + *, + user: str = PG_USER, + 
password: str = PG_PASS, + host: str = PG_HOST, + port: int = PG_PORT, + database: str = IBIS_TEST_RISINGWAVE_DB, + **_: Any, + ) -> None: + """Load test data into a Risingwave backend instance. + + Parameters + ---------- + data_dir + Location of test data + script_dir + Location of scripts defining schemas + """ + init_database( + url=sa.engine.make_url( + f"risingwave://{user}:{password}@{host}:{port:d}/{database}" + ), + database=database, + schema=self.ddl_script, + isolation_level="AUTOCOMMIT", + recreate=False, + ) + + @staticmethod + def connect(*, tmpdir, worker_id, port: int | None = None, **kw): + con = ibis.risingwave.connect( + host=PG_HOST, + port=port or PG_PORT, + user=PG_USER, + password=PG_PASS, + database=IBIS_TEST_RISINGWAVE_DB, + **kw, + ) + cursor = con.raw_sql("SET RW_IMPLICIT_FLUSH TO true;") + cursor.close() + return con + + +@pytest.fixture(scope="session") +def con(tmp_path_factory, data_dir, worker_id): + return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection + + +@pytest.fixture(scope="module") +def db(con): + return con.database() + + +@pytest.fixture(scope="module") +def alltypes(db): + return db.functional_alltypes + + +@pytest.fixture(scope="module") +def df(alltypes): + return alltypes.execute() + + +@pytest.fixture(scope="module") +def alltypes_sqla(con, alltypes): + name = alltypes.op().name + return con._get_sqla_table(name) + + +@pytest.fixture(scope="module") +def intervals(con): + return con.table("intervals") + + +@pytest.fixture +def translate(): + from ibis.backends.risingwave import Backend + + context = Backend.compiler.make_context() + return lambda expr: Backend.compiler.translator_class(expr, context).get_result() diff --git a/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql b/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql new file mode 100644 index 0000000000000..cfbcf133a863b --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql @@ -0,0 +1,2 @@ +SELECT sum(t0.foo) AS "Sum(foo)" +FROM t0 AS t0 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/snapshots/test_functions/test_analytic_functions/out.sql b/ibis/backends/risingwave/tests/snapshots/test_functions/test_analytic_functions/out.sql new file mode 100644 index 0000000000000..c00dec1bed252 --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_functions/test_analytic_functions/out.sql @@ -0,0 +1,7 @@ +SELECT + RANK() OVER (ORDER BY t0.double_col ASC) - 1 AS rank, + DENSE_RANK() OVER (ORDER BY t0.double_col ASC) - 1 AS dense_rank, + CUME_DIST() OVER (ORDER BY t0.double_col ASC) AS cume_dist, + NTILE(7) OVER (ORDER BY t0.double_col ASC) - 1 AS ntile, + PERCENT_RANK() OVER (ORDER BY t0.double_col ASC) AS percent_rank +FROM functional_alltypes AS t0 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql new file mode 100644 index 0000000000000..34761d9a76e0d --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql @@ -0,0 +1 @@ +WITH anon_2 AS (SELECT t2.string_col AS string_col, sum(t2.double_col) AS metric FROM functional_alltypes AS t2 GROUP BY 1), anon_3 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1), anon_1 AS (SELECT t2.string_col AS string_col, 
t2.metric AS metric FROM (SELECT anon_2.string_col AS string_col, anon_2.metric AS metric FROM anon_2 UNION ALL SELECT anon_3.string_col AS string_col, anon_3.metric AS metric FROM anon_3) AS t2), anon_4 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1) SELECT t1.string_col, t1.metric FROM (SELECT anon_1.string_col AS string_col, anon_1.metric AS metric FROM anon_1 UNION ALL SELECT anon_4.string_col AS string_col, anon_4.metric AS metric FROM anon_4) AS t1 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql new file mode 100644 index 0000000000000..6ce31e7468bb8 --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql @@ -0,0 +1 @@ +WITH anon_2 AS (SELECT t2.string_col AS string_col, sum(t2.double_col) AS metric FROM functional_alltypes AS t2 GROUP BY 1), anon_3 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1), anon_1 AS (SELECT t2.string_col AS string_col, t2.metric AS metric FROM (SELECT anon_2.string_col AS string_col, anon_2.metric AS metric FROM anon_2 UNION SELECT anon_3.string_col AS string_col, anon_3.metric AS metric FROM anon_3) AS t2), anon_4 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1) SELECT t1.string_col, t1.metric FROM (SELECT anon_1.string_col AS string_col, anon_1.metric AS metric FROM anon_1 UNION SELECT anon_4.string_col AS string_col, anon_4.metric AS metric FROM anon_4) AS t1 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/test_client.py b/ibis/backends/risingwave/tests/test_client.py new file mode 100644 index 0000000000000..b5c7cfa985609 --- /dev/null +++ b/ibis/backends/risingwave/tests/test_client.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import os + +import pandas as pd +import pytest +from pytest import param + +import ibis +import ibis.expr.datatypes as dt +import ibis.expr.types as ir +from ibis.tests.util import assert_equal + +pytest.importorskip("psycopg2") +sa = pytest.importorskip("sqlalchemy") + +from sqlalchemy.dialects import postgresql # noqa: E402 + +RISINGWAVE_TEST_DB = os.environ.get("IBIS_TEST_RISINGWAVE_DATABASE", "dev") +IBIS_RISINGWAVE_HOST = os.environ.get("IBIS_TEST_RISINGWAVE_HOST", "localhost") +IBIS_RISINGWAVE_PORT = os.environ.get("IBIS_TEST_RISINGWAVE_PORT", "4566") +IBIS_RISINGWAVE_USER = os.environ.get("IBIS_TEST_RISINGWAVE_USER", "root") +IBIS_RISINGWAVE_PASS = os.environ.get("IBIS_TEST_RISINGWAVE_PASSWORD", "") + + +def test_table(alltypes): + assert isinstance(alltypes, ir.Table) + + +def test_array_execute(alltypes): + d = alltypes.limit(10).double_col + s = d.execute() + assert isinstance(s, pd.Series) + assert len(s) == 10 + + +def test_literal_execute(con): + expr = ibis.literal("1234") + result = con.execute(expr) + assert result == "1234" + + +def test_simple_aggregate_execute(alltypes): + d = alltypes.double_col.sum() + v = d.execute() + assert isinstance(v, float) + + +def test_list_tables(con): + assert con.list_tables() + assert len(con.list_tables(like="functional")) == 1 + + +def test_compile_toplevel(snapshot): + t = ibis.table([("foo", "double")], name="t0") + + expr = t.foo.sum() + result = ibis.postgres.compile(expr) + snapshot.assert_match(str(result), "out.sql") + + +def 
test_list_databases(con): + assert RISINGWAVE_TEST_DB is not None + assert RISINGWAVE_TEST_DB in con.list_databases() + + +def test_schema_type_conversion(con): + typespec = [ + # name, type, nullable + ("jsonb", postgresql.JSONB, True, dt.JSON), + ] + + sqla_types = [] + ibis_types = [] + for name, t, nullable, ibis_type in typespec: + sqla_types.append(sa.Column(name, t, nullable=nullable)) + ibis_types.append((name, ibis_type(nullable=nullable))) + + # Create a table with placeholder stubs for JSON, JSONB, and UUID. + table = sa.Table("tname", sa.MetaData(), *sqla_types) + + # Check that we can correctly create a schema with dt.any for the + # missing types. + schema = con._schema_from_sqla_table(table) + expected = ibis.schema(ibis_types) + + assert_equal(schema, expected) + + +@pytest.mark.parametrize("params", [{}, {"database": RISINGWAVE_TEST_DB}]) +def test_create_and_drop_table(con, temp_table, params): + sch = ibis.schema( + [ + ("first_name", "string"), + ("last_name", "string"), + ("department_name", "string"), + ("salary", "float64"), + ] + ) + + con.create_table(temp_table, schema=sch, **params) + assert con.table(temp_table, **params) is not None + + con.drop_table(temp_table, **params) + + with pytest.raises(sa.exc.NoSuchTableError): + con.table(temp_table, **params) + + +@pytest.mark.parametrize( + ("pg_type", "expected_type"), + [ + param(pg_type, ibis_type, id=pg_type.lower()) + for (pg_type, ibis_type) in [ + ("boolean", dt.boolean), + ("bytea", dt.binary), + ("bigint", dt.int64), + ("smallint", dt.int16), + ("integer", dt.int32), + ("text", dt.string), + ("real", dt.float32), + ("double precision", dt.float64), + ("character varying", dt.string), + ("date", dt.date), + ("time", dt.time), + ("time without time zone", dt.time), + ("timestamp without time zone", dt.timestamp), + ("timestamp with time zone", dt.Timestamp("UTC")), + ("interval", dt.Interval("s")), + ("numeric", dt.decimal), + ("jsonb", dt.json), + ] + ], +) +def test_get_schema_from_query(con, pg_type, expected_type): + name = con._quote(ibis.util.guid()) + with con.begin() as c: + c.exec_driver_sql(f"CREATE TABLE {name} (x {pg_type}, y {pg_type}[])") + expected_schema = ibis.schema(dict(x=expected_type, y=dt.Array(expected_type))) + result_schema = con._get_schema_using_query(f"SELECT x, y FROM {name}") + assert result_schema == expected_schema + with con.begin() as c: + c.exec_driver_sql(f"DROP TABLE {name}") + + +@pytest.mark.xfail(reason="unsupported insert with CTEs") +def test_insert_with_cte(con): + X = con.create_table("X", schema=ibis.schema(dict(id="int")), temp=False) + expr = X.join(X.mutate(a=X["id"] + 1), ["id"]) + Y = con.create_table("Y", expr, temp=False) + assert Y.execute().empty + con.drop_table("Y") + con.drop_table("X") + + +def test_connect_url_with_empty_host(): + con = ibis.connect("risingwave:///dev") + assert con.con.url.host is None diff --git a/ibis/backends/risingwave/tests/test_functions.py b/ibis/backends/risingwave/tests/test_functions.py new file mode 100644 index 0000000000000..c8874e390c608 --- /dev/null +++ b/ibis/backends/risingwave/tests/test_functions.py @@ -0,0 +1,1032 @@ +from __future__ import annotations + +import operator +import string +import warnings +from datetime import datetime + +import numpy as np +import pandas as pd +import pandas.testing as tm +import pytest +from pytest import param + +import ibis +import ibis.expr.datatypes as dt +import ibis.expr.types as ir +from ibis import config +from ibis import literal as L + +pytest.importorskip("psycopg2") 
+sa = pytest.importorskip("sqlalchemy") + +from sqlalchemy.dialects import postgresql # noqa: E402 + + +@pytest.mark.parametrize( + ("left_func", "right_func"), + [ + param( + lambda t: t.double_col.cast("int8"), + lambda at: sa.cast(at.c.double_col, sa.SMALLINT), + id="double_to_int8", + ), + param( + lambda t: t.double_col.cast("int16"), + lambda at: sa.cast(at.c.double_col, sa.SMALLINT), + id="double_to_int16", + ), + param( + lambda t: t.string_col.cast("double"), + lambda at: sa.cast(at.c.string_col, postgresql.DOUBLE_PRECISION), + id="string_to_double", + ), + param( + lambda t: t.string_col.cast("float32"), + lambda at: sa.cast(at.c.string_col, postgresql.REAL), + id="string_to_float", + ), + param( + lambda t: t.string_col.cast("decimal"), + lambda at: sa.cast(at.c.string_col, sa.NUMERIC()), + id="string_to_decimal_no_params", + ), + param( + lambda t: t.string_col.cast("decimal(9, 3)"), + lambda at: sa.cast(at.c.string_col, sa.NUMERIC(9, 3)), + id="string_to_decimal_params", + ), + ], +) +def test_cast(alltypes, alltypes_sqla, translate, left_func, right_func): + left = left_func(alltypes) + right = right_func(alltypes_sqla.alias("t0")) + assert str(translate(left.op()).compile()) == str(right.compile()) + + +def test_date_cast(alltypes, alltypes_sqla, translate): + result = alltypes.date_string_col.cast("date") + expected = sa.cast(alltypes_sqla.alias("t0").c.date_string_col, sa.DATE) + assert str(translate(result.op())) == str(expected) + + +@pytest.mark.parametrize( + "column", + [ + "id", + "bool_col", + "tinyint_col", + "smallint_col", + "int_col", + "bigint_col", + "float_col", + "double_col", + "date_string_col", + "string_col", + "timestamp_col", + "year", + "month", + ], +) +def test_noop_cast(alltypes, alltypes_sqla, translate, column): + col = alltypes[column] + result = col.cast(col.type()) + expected = alltypes_sqla.alias("t0").c[column] + assert result.equals(col) + assert str(translate(result.op())) == str(expected) + + +def test_timestamp_cast_noop(alltypes, alltypes_sqla, translate): + # See GH #592 + result1 = alltypes.timestamp_col.cast("timestamp") + result2 = alltypes.int_col.cast("timestamp") + + assert isinstance(result1, ir.TimestampColumn) + assert isinstance(result2, ir.TimestampColumn) + + expected1 = alltypes_sqla.alias("t0").c.timestamp_col + expected2 = sa.cast( + sa.func.to_timestamp(alltypes_sqla.alias("t0").c.int_col), sa.TIMESTAMP() + ) + + assert str(translate(result1.op())) == str(expected1) + assert str(translate(result2.op())) == str(expected2) + + +@pytest.mark.parametrize(("value", "expected"), [(0, None), (5.5, 5.5)]) +def test_nullif_zero(con, value, expected): + assert con.execute(L(value).nullif(0)) == expected + + +@pytest.mark.parametrize(("value", "expected"), [("foo_bar", 7), ("", 0)]) +def test_string_length(con, value, expected): + assert con.execute(L(value).length()) == expected + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + param(operator.methodcaller("left", 3), "foo", id="left"), + param(operator.methodcaller("right", 3), "bar", id="right"), + param(operator.methodcaller("substr", 0, 3), "foo", id="substr_0_3"), + param(operator.methodcaller("substr", 4, 3), "bar", id="substr_4, 3"), + param(operator.methodcaller("substr", 1), "oo_bar", id="substr_1"), + ], +) +def test_string_substring(con, op, expected): + value = L("foo_bar") + assert con.execute(op(value)) == expected + + +@pytest.mark.parametrize( + ("opname", "expected"), + [("lstrip", "foo "), ("rstrip", " foo"), ("strip", "foo")], +) +def 
test_string_strip(con, opname, expected): + op = operator.methodcaller(opname) + value = L(" foo ") + assert con.execute(op(value)) == expected + + +@pytest.mark.parametrize( + ("opname", "count", "char", "expected"), + [("lpad", 6, " ", " foo"), ("rpad", 6, " ", "foo ")], +) +def test_string_pad(con, opname, count, char, expected): + op = operator.methodcaller(opname, count, char) + value = L("foo") + assert con.execute(op(value)) == expected + + +def test_string_reverse(con): + assert con.execute(L("foo").reverse()) == "oof" + + +def test_string_upper(con): + assert con.execute(L("foo").upper()) == "FOO" + + +def test_string_lower(con): + assert con.execute(L("FOO").lower()) == "foo" + + +@pytest.mark.parametrize( + ("haystack", "needle", "expected"), + [ + ("foobar", "bar", True), + ("foobar", "foo", True), + ("foobar", "baz", False), + ("100%", "%", True), + ("a_b_c", "_", True), + ], +) +def test_string_contains(con, haystack, needle, expected): + value = L(haystack) + expr = value.contains(needle) + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("value", "expected"), + [("foo bar foo", "Foo Bar Foo"), ("foobar Foo", "Foobar Foo")], +) +def test_capitalize(con, value, expected): + assert con.execute(L(value).capitalize()) == expected + + +def test_repeat(con): + expr = L("bar ").repeat(3) + assert con.execute(expr) == "bar bar bar " + + +def test_re_replace(con): + expr = L("fudge|||chocolate||candy").re_replace("\\|{2,3}", ", ") + assert con.execute(expr) == "fudge, chocolate, candy" + + +def test_translate(con): + expr = L("faab").translate("a", "b") + assert con.execute(expr) == "fbbb" + + +@pytest.mark.parametrize( + ("raw_value", "expected"), [("a", 0), ("b", 1), ("d", -1), (None, 3)] +) +def test_find_in_set(con, raw_value, expected): + value = L(raw_value, dt.string) + haystack = ["a", "b", "c", None] + expr = value.find_in_set(haystack) + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("raw_value", "opname", "expected"), + [ + (None, "isnull", True), + (1, "isnull", False), + (None, "notnull", False), + (1, "notnull", True), + ], +) +def test_isnull_notnull(con, raw_value, opname, expected): + lit = L(raw_value) + op = operator.methodcaller(opname) + expr = op(lit) + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(L("foobar").find("bar"), 3, id="find_pos"), + param(L("foobar").find("baz"), -1, id="find_neg"), + param(L("foobar").like("%bar"), True, id="like_left_pattern"), + param(L("foobar").like("foo%"), True, id="like_right_pattern"), + param(L("foobar").like("%baz%"), False, id="like_both_sides_pattern"), + param(L("foobar").like(["%bar"]), True, id="like_list_left_side"), + param(L("foobar").like(["foo%"]), True, id="like_list_right_side"), + param(L("foobar").like(["%baz%"]), False, id="like_list_both_sides"), + param(L("foobar").like(["%bar", "foo%"]), True, id="like_list_multiple"), + param(L("foobarfoo").replace("foo", "H"), "HbarH", id="replace"), + param(L("a").ascii_str(), ord("a"), id="ascii_str"), + ], +) +def test_string_functions(con, expr, expected): + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(L("abcd").re_search("[a-z]"), True, id="re_search_match"), + param(L("abcd").re_search(r"[\d]+"), False, id="re_search_no_match"), + param(L("1222").re_search(r"[\d]+"), True, id="re_search_match_number"), + ], +) +def test_regexp(con, expr, expected): + assert con.execute(expr) == expected + + 
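The string and regex tests above all share one execution pattern: build a scalar expression from an ibis literal and assert on the Python value returned by con.execute. A minimal standalone sketch of that pattern follows; the connection URL is an assumption pieced together from the risingwave:///dev and postgresql://root:@localhost:4566/dev URLs used in the client tests elsewhere in this diff, not something defined in this file.

import ibis
from ibis import literal as L

# Assumed connection details: Risingwave's default port 4566, the root user,
# and the "dev" database, mirroring the CI service configuration in this PR.
con = ibis.connect("risingwave://root:@localhost:4566/dev")

# Each test builds a scalar expression from a literal, executes it against the
# backend, and asserts on the resulting Python value.
assert con.execute(L("foo").upper()) == "FOO"
assert con.execute(L("foo_bar").length()) == 7
assert con.execute(L("abcd").re_search("[a-z]"))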
+@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(ibis.NA.fillna(5), 5, id="filled"), + param(L(5).fillna(10), 5, id="not_filled"), + param(L(5).nullif(5), None, id="nullif_null"), + param(L(10).nullif(5), 10, id="nullif_not_null"), + ], +) +def test_fillna_nullif(con, expr, expected): + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(ibis.coalesce(5, None, 4), 5, id="first"), + param(ibis.coalesce(ibis.NA, 4, ibis.NA), 4, id="second"), + param(ibis.coalesce(ibis.NA, ibis.NA, 3.14), 3.14, id="third"), + ], +) +def test_coalesce(con, expr, expected): + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(ibis.coalesce(ibis.NA, ibis.NA), None, id="all_null"), + param( + ibis.coalesce( + ibis.NA.cast("int8"), + ibis.NA.cast("int8"), + ibis.NA.cast("int8"), + ), + None, + id="all_nulls_with_all_cast", + ), + ], +) +def test_coalesce_all_na(con, expr, expected): + assert con.execute(expr) is None + + +def test_coalesce_all_na_double(con): + expr = ibis.coalesce(ibis.NA, ibis.NA, ibis.NA.cast("double")) + assert np.isnan(con.execute(expr)) + + +def test_numeric_builtins_work(alltypes, df): + expr = alltypes.double_col.fillna(0) + result = expr.execute() + expected = df.double_col.fillna(0) + expected.name = "Coalesce()" + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("op", "pandas_op"), + [ + param( + lambda t: (t.double_col > 20).ifelse(10, -20), + lambda df: pd.Series(np.where(df.double_col > 20, 10, -20), dtype="int8"), + id="simple", + ), + param( + lambda t: (t.double_col > 20).ifelse(10, -20).abs(), + lambda df: pd.Series( + np.where(df.double_col > 20, 10, -20), dtype="int8" + ).abs(), + id="abs", + ), + ], +) +def test_ifelse(alltypes, df, op, pandas_op): + expr = op(alltypes) + result = expr.execute() + result.name = None + expected = pandas_op(df) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("func", "pandas_func"), + [ + # tier and histogram + param( + lambda d: d.bucket([0, 10, 25, 50, 100]), + lambda s: pd.cut(s, [0, 10, 25, 50, 100], right=False, labels=False).astype( + "int8" + ), + id="include_over_false", + ), + param( + lambda d: d.bucket([0, 10, 25, 50], include_over=True), + lambda s: pd.cut( + s, [0, 10, 25, 50, np.inf], right=False, labels=False + ).astype("int8"), + id="include_over_true", + ), + param( + lambda d: d.bucket([0, 10, 25, 50], close_extreme=False), + lambda s: pd.cut(s, [0, 10, 25, 50], right=False, labels=False), + id="close_extreme_false", + ), + param( + lambda d: d.bucket([0, 10, 25, 50], closed="right", close_extreme=False), + lambda s: pd.cut( + s, + [0, 10, 25, 50], + include_lowest=False, + right=True, + labels=False, + ), + id="closed_right", + ), + param( + lambda d: d.bucket([10, 25, 50, 100], include_under=True), + lambda s: pd.cut(s, [0, 10, 25, 50, 100], right=False, labels=False).astype( + "int8" + ), + id="include_under_true", + ), + ], +) +def test_bucket(alltypes, df, func, pandas_func): + expr = func(alltypes.double_col) + result = expr.execute() + expected = pandas_func(df.double_col) + tm.assert_series_equal(result, expected, check_names=False) + + +def test_category_label(alltypes, df): + t = alltypes + d = t.double_col + + bins = [0, 10, 25, 50, 100] + labels = ["a", "b", "c", "d"] + bucket = d.bucket(bins) + expr = bucket.label(labels) + result = expr.execute() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + result = 
pd.Series(pd.Categorical(result, ordered=True)) + + result.name = "double_col" + + expected = pd.cut(df.double_col, bins, labels=labels, right=False) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("distinct", [True, False]) +def test_union_cte(alltypes, distinct, snapshot): + t = alltypes + expr1 = t.group_by(t.string_col).aggregate(metric=t.double_col.sum()) + expr2 = expr1.view() + expr3 = expr1.view() + expr = expr1.union(expr2, distinct=distinct).union(expr3, distinct=distinct) + result = " ".join( + line.strip() + for line in str( + expr.compile().compile(compile_kwargs={"literal_binds": True}) + ).splitlines() + ) + snapshot.assert_match(result, "out.sql") + + +@pytest.mark.parametrize( + ("func", "pandas_func"), + [ + param( + lambda t, cond: t.bool_col.count(), + lambda df, cond: df.bool_col.count(), + id="count", + ), + param( + lambda t, cond: t.double_col.mean(), + lambda df, cond: df.double_col.mean(), + id="mean", + ), + param( + lambda t, cond: t.double_col.min(), + lambda df, cond: df.double_col.min(), + id="min", + ), + param( + lambda t, cond: t.double_col.max(), + lambda df, cond: df.double_col.max(), + id="max", + ), + param( + lambda t, cond: t.double_col.var(), + lambda df, cond: df.double_col.var(), + id="var", + ), + param( + lambda t, cond: t.double_col.std(), + lambda df, cond: df.double_col.std(), + id="std", + ), + param( + lambda t, cond: t.double_col.var(how="sample"), + lambda df, cond: df.double_col.var(ddof=1), + id="samp_var", + ), + param( + lambda t, cond: t.double_col.std(how="pop"), + lambda df, cond: df.double_col.std(ddof=0), + id="pop_std", + ), + param( + lambda t, cond: t.bool_col.count(where=cond), + lambda df, cond: df.bool_col[cond].count(), + id="count_where", + ), + param( + lambda t, cond: t.double_col.mean(where=cond), + lambda df, cond: df.double_col[cond].mean(), + id="mean_where", + ), + param( + lambda t, cond: t.double_col.min(where=cond), + lambda df, cond: df.double_col[cond].min(), + id="min_where", + ), + param( + lambda t, cond: t.double_col.max(where=cond), + lambda df, cond: df.double_col[cond].max(), + id="max_where", + ), + param( + lambda t, cond: t.double_col.var(where=cond), + lambda df, cond: df.double_col[cond].var(), + id="var_where", + ), + param( + lambda t, cond: t.double_col.std(where=cond), + lambda df, cond: df.double_col[cond].std(), + id="std_where", + ), + param( + lambda t, cond: t.double_col.var(where=cond, how="sample"), + lambda df, cond: df.double_col[cond].var(), + id="samp_var_where", + ), + param( + lambda t, cond: t.double_col.std(where=cond, how="pop"), + lambda df, cond: df.double_col[cond].std(ddof=0), + id="pop_std_where", + ), + ], +) +def test_aggregations(alltypes, df, func, pandas_func): + table = alltypes.limit(100) + df = df.head(table.count().execute()) + + cond = table.string_col.isin(["1", "7"]) + expr = func(table, cond) + result = expr.execute() + expected = pandas_func(df, cond.execute()) + + np.testing.assert_allclose(result, expected) + + +def test_not_contains(alltypes, df): + n = 100 + table = alltypes.limit(n) + expr = table.string_col.notin(["1", "7"]) + result = expr.execute() + expected = ~df.head(n).string_col.isin(["1", "7"]) + tm.assert_series_equal(result, expected, check_names=False) + + +def test_group_concat(alltypes, df): + expr = alltypes.string_col.group_concat() + result = expr.execute() + expected = ",".join(df.string_col.dropna()) + assert result == expected + + +def test_distinct_aggregates(alltypes, df): + expr = 
alltypes.limit(100).double_col.nunique() + result = expr.execute() + assert result == df.head(100).double_col.nunique() + + +def test_not_exists(alltypes, df): + t = alltypes + t2 = t.view() + + expr = t[~((t.string_col == t2.string_col).any())] + result = expr.execute() + + left, right = df, t2.execute() + expected = left[left.string_col != right.string_col] + + tm.assert_frame_equal(result, expected, check_index_type=False, check_dtype=False) + + +def test_interactive_repr_shows_error(alltypes): + # #591. Doing this in Postgres because so many built-in functions are + # not available + + expr = alltypes.int_col.convert_base(10, 2) + + with config.option_context("interactive", True): + result = repr(expr) + + assert "no translation rule" in result.lower() + + +def test_subquery(alltypes, df): + t = alltypes + + expr = t.mutate(d=t.double_col.fillna(0)).limit(1000).group_by("string_col").size() + result = expr.execute().sort_values("string_col").reset_index(drop=True) + expected = ( + df.assign(d=df.double_col.fillna(0)) + .head(1000) + .groupby("string_col") + .string_col.count() + .rename("CountStar()") + .reset_index() + .sort_values("string_col") + .reset_index(drop=True) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["mean", "sum", "min", "max"]) +def test_simple_window(alltypes, func, df): + t = alltypes + f = getattr(t.double_col, func) + df_f = getattr(df.double_col, func) + result = t.select((t.double_col - f()).name("double_col")).execute().double_col + expected = df.double_col - df_f() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["mean", "sum", "min", "max"]) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_rolling_window(alltypes, func, df): + t = alltypes + df = ( + df[["double_col", "timestamp_col"]] + .sort_values("timestamp_col") + .reset_index(drop=True) + ) + window = ibis.window(order_by=t.timestamp_col, preceding=6, following=0) + f = getattr(t.double_col, func) + df_f = getattr(df.double_col.rolling(7, min_periods=0), func) + result = t.select(f().over(window).name("double_col")).execute().double_col + expected = df_f() + tm.assert_series_equal(result, expected) + + +def test_rolling_window_with_mlb(alltypes): + t = alltypes + window = ibis.trailing_window( + preceding=ibis.rows_with_max_lookback(3, ibis.interval(days=5)), + order_by=t.timestamp_col, + ) + expr = t["double_col"].sum().over(window) + with pytest.raises(NotImplementedError): + expr.execute() + + +@pytest.mark.parametrize("func", ["mean", "sum", "min", "max"]) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_partitioned_window(alltypes, func, df): + t = alltypes + window = ibis.window( + group_by=t.string_col, + order_by=t.timestamp_col, + preceding=6, + following=0, + ) + + def roller(func): + def rolled(df): + torder = df.sort_values("timestamp_col") + rolling = torder.double_col.rolling(7, min_periods=0) + return getattr(rolling, func)() + + return rolled + + f = getattr(t.double_col, func) + expr = f().over(window).name("double_col") + result = t.select(expr).execute().double_col + expected = df.groupby("string_col").apply(roller(func)).reset_index(drop=True) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["sum", "min", "max"]) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_cumulative_simple_window(alltypes, func, df): + t = 
alltypes + f = getattr(t.double_col, func) + col = t.double_col - f().over(ibis.cumulative_window()) + expr = t.select(col.name("double_col")) + result = expr.execute().double_col + expected = df.double_col - getattr(df.double_col, "cum%s" % func)() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["sum", "min", "max"]) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_cumulative_ordered_window(alltypes, func, df): + t = alltypes + df = df.sort_values("timestamp_col").reset_index(drop=True) + window = ibis.cumulative_window(order_by=t.timestamp_col) + f = getattr(t.double_col, func) + expr = t.select((t.double_col - f().over(window)).name("double_col")) + result = expr.execute().double_col + expected = df.double_col - getattr(df.double_col, "cum%s" % func)() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("func", "shift_amount"), [("lead", -1), ("lag", 1)], ids=["lead", "lag"] +) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_analytic_shift_functions(alltypes, df, func, shift_amount): + method = getattr(alltypes.double_col, func) + expr = method(1) + result = expr.execute().rename("double_col") + expected = df.double_col.shift(shift_amount) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("func", "expected_index"), [("first", -1), ("last", 0)], ids=["first", "last"] +) +@pytest.mark.xfail(reason="Unsupported expr: (first(t0.double_col) + 1) - 1") +def test_first_last_value(alltypes, df, func, expected_index): + col = alltypes.order_by(ibis.desc(alltypes.string_col)).double_col + method = getattr(col, func) + # test that we traverse into expression trees + expr = (1 + method()) - 1 + result = expr.execute() + expected = df.double_col.iloc[expected_index] + assert result == expected + + +def test_null_column(alltypes): + t = alltypes + nrows = t.count().execute() + expr = t.mutate(na_column=ibis.NA).na_column + result = expr.execute() + tm.assert_series_equal(result, pd.Series([None] * nrows, name="na_column")) + + +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_window_with_arithmetic(alltypes, df): + t = alltypes + w = ibis.window(order_by=t.timestamp_col) + expr = t.mutate(new_col=ibis.row_number().over(w) / 2) + + df = df[["timestamp_col"]].sort_values("timestamp_col").reset_index(drop=True) + expected = df.assign(new_col=[x / 2.0 for x in range(len(df))]) + result = expr["timestamp_col", "new_col"].execute() + tm.assert_frame_equal(result, expected) + + +def test_anonymous_aggregate(alltypes, df): + t = alltypes + expr = t[t.double_col > t.double_col.mean()] + result = expr.execute() + expected = df[df.double_col > df.double_col.mean()].reset_index(drop=True) + tm.assert_frame_equal(result, expected) + + +@pytest.fixture +def array_types(con): + return con.table("array_types") + + +@pytest.mark.xfail( + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype" +) +def test_array_length(array_types): + expr = array_types.select( + array_types.x.length().name("x_length"), + array_types.y.length().name("y_length"), + array_types.z.length().name("z_length"), + ) + result = expr.execute() + expected = pd.DataFrame( + { + "x_length": [3, 2, 2, 3, 3, 4], + "y_length": [3, 2, 2, 3, 3, 4], + "z_length": [3, 2, 2, 0, None, 4], + } + ) + result_sorted = result.sort_values( + by=["x_length", "y_length", 
"z_length"], na_position="first" + ).reset_index(drop=True) + expected_sorted = expected.sort_values( + by=["x_length", "y_length", "z_length"], na_position="first" + ).reset_index(drop=True) + tm.assert_frame_equal(result_sorted, expected_sorted) + + +def custom_sort_none_first(arr): + return sorted(arr, key=lambda x: (x is not None, x)) + + +def test_head(con): + t = con.table("functional_alltypes") + result = t.head().execute() + expected = t.limit(5).execute() + tm.assert_frame_equal(result, expected) + + +def test_identical_to(con, df): + # TODO: abstract this testing logic out into parameterized fixtures + t = con.table("functional_alltypes") + dt = df[["tinyint_col", "double_col"]] + expr = t.tinyint_col.identical_to(t.double_col) + result = expr.execute() + expected = (dt.tinyint_col.isnull() & dt.double_col.isnull()) | ( + dt.tinyint_col == dt.double_col + ) + expected.name = result.name + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("opname", ["invert", "neg"]) +def test_not_and_negate_bool(con, opname, df): + op = getattr(operator, opname) + t = con.table("functional_alltypes").limit(10) + expr = t.select(op(t.bool_col).name("bool_col")) + result = expr.execute().bool_col + expected = op(df.head(10).bool_col) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "field", + [ + "tinyint_col", + "smallint_col", + "int_col", + "bigint_col", + "float_col", + "double_col", + "year", + "month", + ], +) +def test_negate_non_boolean(con, field, df): + t = con.table("functional_alltypes").limit(10) + expr = t.select((-t[field]).name(field)) + result = expr.execute()[field] + expected = -df.head(10)[field] + tm.assert_series_equal(result, expected) + + +def test_negate_boolean(con, df): + t = con.table("functional_alltypes").limit(10) + expr = t.select((-t.bool_col).name("bool_col")) + result = expr.execute().bool_col + expected = -df.head(10).bool_col + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("opname", ["sum", "mean", "min", "max", "std", "var"]) +def test_boolean_reduction(alltypes, opname, df): + op = operator.methodcaller(opname) + expr = op(alltypes.bool_col) + result = expr.execute() + assert result == op(df.bool_col) + + +def test_timestamp_with_timezone(con): + t = con.table("tzone") + result = t.ts.execute() + assert str(result.dtype.tz) + + +@pytest.fixture( + params=[ + None, + "UTC", + "America/New_York", + "America/Los_Angeles", + "Europe/Paris", + "Chile/Continental", + "Asia/Tel_Aviv", + "Asia/Tokyo", + "Africa/Nairobi", + "Australia/Sydney", + ] +) +def tz(request): + return request.param + + +@pytest.fixture +def tzone_compute(con, temp_table, tz): + schema = ibis.schema([("ts", dt.Timestamp(tz)), ("b", "double"), ("c", "string")]) + con.create_table(temp_table, schema=schema, temp=False) + t = con.table(temp_table) + + n = 10 + df = pd.DataFrame( + { + "ts": pd.date_range("2017-04-01", periods=n, tz=tz).values, + "b": np.arange(n).astype("float64"), + "c": list(string.ascii_lowercase[:n]), + } + ) + + df.to_sql( + temp_table, + con.con, + index=False, + if_exists="append", + dtype={"ts": sa.TIMESTAMP(timezone=True), "b": sa.FLOAT, "c": sa.TEXT}, + ) + + yield t + con.drop_table(temp_table) + + +def test_ts_timezone_is_preserved(tzone_compute, tz): + assert dt.Timestamp(tz).equals(tzone_compute.ts.type()) + + +def test_timestamp_with_timezone_select(tzone_compute, tz): + ts = tzone_compute.ts.execute() + assert str(getattr(ts.dtype, "tz", None)) == str(tz) + + +@pytest.mark.parametrize( + 
("left", "right", "type"), + [ + param( + L("2017-04-01 01:02:33"), + datetime(2017, 4, 1, 1, 3, 34), + dt.timestamp, + id="ibis_timestamp", + ), + param( + datetime(2017, 4, 1, 1, 3, 34), + L("2017-04-01 01:02:33"), + dt.timestamp, + id="python_datetime", + ), + ], +) +@pytest.mark.parametrize("opname", ["eq", "ne", "lt", "le", "gt", "ge"]) +def test_string_temporal_compare(con, opname, left, right, type): + op = getattr(operator, opname) + expr = op(left, right) + result = con.execute(expr) + left_raw = con.execute(L(left).cast(type)) + right_raw = con.execute(L(right).cast(type)) + expected = op(left_raw, right_raw) + assert result == expected + + +@pytest.mark.parametrize( + ("left", "right"), + [ + param( + L("2017-03-31 00:02:33").cast(dt.timestamp), + datetime(2017, 4, 1, 1, 3, 34), + id="ibis_timestamp", + ), + param( + datetime(2017, 3, 31, 0, 2, 33), + L("2017-04-01 01:03:34").cast(dt.timestamp), + id="python_datetime", + ), + ], +) +@pytest.mark.parametrize( + "op", + [ + param( + lambda left, right: ibis.timestamp("2017-04-01 00:02:34").between( + left, right + ), + id="timestamp", + ), + param( + lambda left, right: ( + ibis.timestamp("2017-04-01").cast(dt.date).between(left, right) + ), + id="date", + ), + ], +) +def test_string_temporal_compare_between(con, op, left, right): + expr = op(left, right) + result = con.execute(expr) + assert isinstance(result, (bool, np.bool_)) + assert result + + +@pytest.mark.xfail( + reason="function make_date(integer, integer, integer) does not exist" +) +def test_scalar_parameter(con): + start_string, end_string = "2009-03-01", "2010-07-03" + + start = ibis.param(dt.date) + end = ibis.param(dt.date) + t = con.table("functional_alltypes") + col = t.date_string_col.cast("date") + expr = col.between(start, end).name("res") + expected_expr = col.between(start_string, end_string).name("res") + + result = expr.execute(params={start: start_string, end: end_string}) + expected = expected_expr.execute() + tm.assert_series_equal(result, expected) + + +def test_string_to_binary_cast(con): + t = con.table("functional_alltypes").limit(10) + expr = t.string_col.cast("binary") + result = expr.execute() + name = expr.get_name() + sql_string = ( + f"SELECT decode(string_col, 'escape') AS \"{name}\" " + "FROM functional_alltypes LIMIT 10" + ) + with con.begin() as c: + cur = c.exec_driver_sql(sql_string) + raw_data = [row[0][0] for row in cur] + expected = pd.Series(raw_data, name=name) + tm.assert_series_equal(result, expected) + + +def test_string_to_binary_round_trip(con): + t = con.table("functional_alltypes").limit(10) + expr = t.string_col.cast("binary").cast("string") + result = expr.execute() + name = expr.get_name() + sql_string = ( + "SELECT encode(decode(string_col, 'escape'), 'escape') AS " + f'"{name}"' + "FROM functional_alltypes LIMIT 10" + ) + with con.begin() as c: + cur = c.exec_driver_sql(sql_string) + expected = pd.Series([row[0][0] for row in cur], name=name) + tm.assert_series_equal(result, expected) diff --git a/ibis/backends/risingwave/tests/test_json.py b/ibis/backends/risingwave/tests/test_json.py new file mode 100644 index 0000000000000..18edda8e37412 --- /dev/null +++ b/ibis/backends/risingwave/tests/test_json.py @@ -0,0 +1,17 @@ +"""Tests for json data types.""" +from __future__ import annotations + +import json + +import pytest +from pytest import param + +import ibis + + +@pytest.mark.parametrize("data", [param({"status": True}, id="status")]) +def test_json(data, alltypes): + lit = ibis.literal(json.dumps(data), 
type="json").name("tmp") + expr = alltypes[[alltypes.id, lit]].head(1) + df = expr.execute() + assert df["tmp"].iloc[0] == data diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 81c6790b07b9e..6f3c4f869c4cf 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -49,6 +49,7 @@ def mean_udf(s): "bigquery", "datafusion", "postgres", + "risingwave", "clickhouse", "impala", "duckdb", @@ -205,6 +206,7 @@ def test_aggregate_grouped(backend, alltypes, df, result_fn, expected_fn): "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", @@ -518,39 +520,51 @@ def mean_and_std(v): lambda t, where: t.double_col.arbitrary(where=where), lambda t, where: t.double_col[where].iloc[0], id="arbitrary_default", - marks=pytest.mark.notimpl( - [ - "impala", - "mysql", - "polars", - "datafusion", - "mssql", - "druid", - "oracle", - "exasol", - "flink", - ], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + [ + "impala", + "mysql", + "polars", + "datafusion", + "mssql", + "druid", + "oracle", + "exasol", + "flink", + ], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), + ], ), param( lambda t, where: t.double_col.arbitrary(how="first", where=where), lambda t, where: t.double_col[where].iloc[0], id="arbitrary_first", - marks=pytest.mark.notimpl( - [ - "impala", - "mysql", - "polars", - "datafusion", - "mssql", - "druid", - "oracle", - "exasol", - "flink", - ], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + [ + "impala", + "mysql", + "polars", + "datafusion", + "mssql", + "druid", + "oracle", + "exasol", + "flink", + ], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), + ], ), param( lambda t, where: t.double_col.arbitrary(how="last", where=where), @@ -576,6 +590,10 @@ def mean_and_std(v): raises=com.UnsupportedOperationError, reason="backend only supports the `first` option for `.arbitrary()", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), ], ), param( @@ -602,7 +620,14 @@ def mean_and_std(v): raises=com.OperationNotDefinedError, ), pytest.mark.notimpl( - ["bigquery", "duckdb", "postgres", "pyspark", "trino"], + [ + "bigquery", + "duckdb", + "postgres", + "risingwave", + "pyspark", + "trino", + ], raises=com.UnsupportedOperationError, reason="how='heavy' not supported in the backend", ), @@ -617,19 +642,31 @@ def mean_and_std(v): lambda t, where: t.double_col.first(where=where), lambda t, where: t.double_col[where].iloc[0], id="first", - marks=pytest.mark.notimpl( - ["dask", "druid", "impala", "mssql", "mysql", "oracle", "flink"], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + ["dask", "druid", "impala", "mssql", "mysql", "oracle", "flink"], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), + ], ), param( lambda t, where: t.double_col.last(where=where), lambda t, where: t.double_col[where].iloc[-1], id="last", - marks=pytest.mark.notimpl( - ["dask", "druid", "impala", "mssql", "mysql", "oracle", "flink"], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + ["dask", "druid", "impala", "mssql", "mysql", "oracle", "flink"], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + 
), + ], ), param( lambda t, where: t.bigint_col.bit_and(where=where), @@ -900,6 +937,11 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond): reason="backend doesn't implement approximate quantiles yet", raises=com.OperationNotDefinedError, ), + pytest.mark.broken( + ["risingwave"], + reason="Invalid input syntax: direct arg in `percentile_cont` must be castable to float64", + raises=sa.exc.InternalError, + ), ], ), ], @@ -948,6 +990,11 @@ def test_quantile( ["mysql", "impala", "sqlite", "flink"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -963,6 +1010,11 @@ def test_quantile( ["mysql", "impala", "sqlite", "flink"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -983,6 +1035,16 @@ def test_quantile( raises=(ValueError, AttributeError), reason="ClickHouse only implements `sample` correlation coefficient", ), + pytest.mark.notyet( + ["pyspark"], + raises=ValueError, + reason="PySpark only implements sample correlation", + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -1009,7 +1071,7 @@ def test_quantile( reason="Correlation with how='sample' is not supported.", ), pytest.mark.notyet( - ["oracle"], + ["oracle", "risingwave"], raises=ValueError, reason="XXXXSQLExprTranslator only implements population correlation coefficient", ), @@ -1033,6 +1095,11 @@ def test_quantile( ["mysql", "impala", "sqlite", "flink"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -1057,6 +1124,16 @@ def test_quantile( raises=ValueError, reason="ClickHouse only implements `sample` correlation coefficient", ), + pytest.mark.notyet( + ["pyspark"], + raises=ValueError, + reason="PySpark only implements sample correlation", + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), ], @@ -1372,6 +1449,7 @@ def test_topk_filter_op(alltypes, df, result_fn, expected_fn): "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", @@ -1412,6 +1490,7 @@ def test_aggregate_list_like(backend, alltypes, df, agg_fn): "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 140454942d6e7..7989cd7b0d74b 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -9,6 +9,7 @@ import pandas.testing as tm import pytest import pytz +import sqlalchemy as sa import toolz from pytest import param @@ -131,6 +132,11 @@ def test_array_concat_variadic(con): # Issues #2370 @pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError) @pytest.mark.notyet(["trino"], raises=TrinoUserError) +@pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: cannot determine type of empty array", +) def test_array_concat_some_empty(con): left = ibis.literal([]) right = ibis.literal([2, 1]) @@ -207,6 +213,11 @@ def test_array_index(con, idx): reason="backend does not support nullable nested 
types", raises=AssertionError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) @pytest.mark.never( ["bigquery"], reason="doesn't support arrays of arrays", raises=AssertionError ) @@ -238,6 +249,11 @@ def test_array_discovery(backend): ) @pytest.mark.notimpl(["dask"], raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_simple(backend): array_types = backend.array_types expected = ( @@ -255,6 +271,11 @@ def test_unnest_simple(backend): @builtin_array @pytest.mark.notimpl("dask", raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_complex(backend): array_types = backend.array_types df = array_types.execute() @@ -293,6 +314,11 @@ def test_unnest_complex(backend): ) @pytest.mark.notimpl(["dask"], raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_idempotent(backend): array_types = backend.array_types df = array_types.execute() @@ -314,6 +340,11 @@ def test_unnest_idempotent(backend): @builtin_array @pytest.mark.notimpl("dask", raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_no_nulls(backend): array_types = backend.array_types df = array_types.execute() @@ -341,6 +372,11 @@ def test_unnest_no_nulls(backend): @builtin_array @pytest.mark.notimpl("dask", raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_default_name(backend): array_types = backend.array_types df = array_types.execute() @@ -376,6 +412,11 @@ def test_unnest_default_name(backend): ["datafusion", "flink"], raises=Exception, reason="array_types table isn't defined" ) @pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_array_slice(backend, start, stop): array_types = backend.array_types expr = array_types.select(sliced=array_types.y[start:stop]) @@ -390,6 +431,11 @@ def test_array_slice(backend, start, stop): @pytest.mark.notimpl( ["datafusion", "polars", "snowflake", "sqlite"], raises=com.OperationNotDefinedError ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) @pytest.mark.notimpl( ["dask", "pandas"], raises=com.OperationNotDefinedError, @@ -417,6 +463,7 @@ def test_array_slice(backend, start, stop): ], ) def 
test_array_map(con, input, output): + t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) expected = pd.DataFrame(output) @@ -476,6 +523,11 @@ def test_array_filter(con, input, output): @builtin_array @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_array_contains(backend, con): t = backend.array_types expr = t.x.contains(1) @@ -499,6 +551,11 @@ def test_array_position(backend, con): @builtin_array @pytest.mark.notimpl(["dask", "polars"], raises=com.OperationNotDefinedError) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) def test_array_remove(con): t = ibis.memtable({"a": [[3, 2], [], [42, 2], [2, 2], []]}) expr = t.a.remove(2) @@ -529,6 +586,11 @@ def test_array_remove(con): raises=(AssertionError, GoogleBadRequest), reason="bigquery doesn't support null elements in arrays", ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) @pytest.mark.parametrize( ("input", "expected"), [ @@ -556,6 +618,11 @@ def test_array_unique(con, input, expected): @pytest.mark.notimpl( ["dask", "datafusion", "polars"], raises=com.OperationNotDefinedError ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14735", +) def test_array_sort(backend, con): t = ibis.memtable({"a": [[3, 2], [], [42, 42], []], "id": range(4)}) expr = t.mutate(a=t.a.sort()).order_by("id") @@ -591,6 +658,11 @@ def test_array_union(con): @pytest.mark.notimpl( ["sqlite"], raises=NotImplementedError, reason="Unsupported type: Array..." 
) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) @pytest.mark.parametrize( "data", [ @@ -629,6 +701,7 @@ def test_array_intersect(con, data): reason="ClickHouse won't accept dicts for struct type values", ) @pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError) +@pytest.mark.notimpl(["risingwave"], raises=sa.exc.ProgrammingError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["trino"], reason="inserting maps into structs doesn't work", raises=TrinoUserError @@ -646,9 +719,23 @@ def test_unnest_struct(con): @builtin_array @pytest.mark.notimpl( - ["dask", "datafusion", "druid", "oracle", "pandas", "polars", "postgres"], + [ + "dask", + "datafusion", + "druid", + "oracle", + "pandas", + "polars", + "postgres", + "risingwave", + ], raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_zip(backend): t = backend.array_types @@ -674,6 +761,7 @@ def test_zip(backend): reason="https://github.com/ClickHouse/ClickHouse/issues/41112", ) @pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError) +@pytest.mark.notimpl(["risingwave"], raises=sa.exc.ProgrammingError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["polars"], @@ -731,7 +819,7 @@ def flatten_data(): ["bigquery"], reason="BigQuery doesn't support arrays of arrays", raises=TypeError ) @pytest.mark.notyet( - ["postgres"], + ["postgres", "risingwave"], reason="Postgres doesn't truly support arrays of arrays", raises=(com.OperationNotDefinedError, PsycoPg2IndeterminateDatatype), ) @@ -802,6 +890,7 @@ def test_range_single_argument(con, n): ) @pytest.mark.parametrize("n", [-2, 0, 2]) @pytest.mark.notimpl(["polars", "flink", "dask"], raises=com.OperationNotDefinedError) +@pytest.mark.skip("risingwave") def test_range_single_argument_unnest(backend, con, n): expr = ibis.range(n).unnest() result = con.execute(expr) @@ -848,6 +937,11 @@ def test_range_start_stop_step(con, start, stop, step): ["datafusion"], raises=com.OperationNotDefinedError, reason="not supported upstream" ) @pytest.mark.notimpl(["flink", "dask"], raises=com.OperationNotDefinedError) +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Invalid parameter step: step size cannot equal zero", +) def test_range_start_stop_step_zero(con, start, stop): expr = ibis.range(start, stop, 0) result = con.execute(expr) @@ -954,6 +1048,11 @@ def swap(token): ibis.interval(hours=1), "1H", id="pos", + marks=pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_interval() does not exist", + ), ), param( datetime(2017, 1, 2), @@ -964,7 +1063,12 @@ def swap(token): marks=[ pytest.mark.broken( ["polars"], raises=AssertionError, reason="returns an empty array" - ) + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function neg(interval) does not exist", + ), ], ), param( @@ -980,6 +1084,11 @@ def swap(token): ["clickhouse", "snowflake"], raises=com.UnsupportedOperationError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function neg(interval) does not exist", + ), ], ), ], @@ -1005,7 +1114,14 @@ def test_timestamp_range(con, start, stop, step, freq, tzinfo): datetime(2017, 1, 2, tzinfo=pytz.UTC), ibis.interval(hours=0), id="pos", - 
marks=[pytest.mark.notyet(["polars"], raises=PolarsComputeError)], + marks=[ + pytest.mark.notyet(["polars"], raises=PolarsComputeError), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_interval() does not exist", + ), + ], ), param( datetime(2017, 1, 1, tzinfo=pytz.UTC), @@ -1019,6 +1135,11 @@ def test_timestamp_range(con, start, stop, step, freq, tzinfo): ["clickhouse", "snowflake"], raises=com.UnsupportedOperationError, ), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function neg(interval) does not exist", + ), ], ), ], @@ -1048,3 +1169,19 @@ def test_repr_timestamp_array(con, monkeypatch): expr = ibis.array(pd.date_range("2010-01-01", "2010-01-03", freq="D").tolist()) assert "Translation to backend failed" not in repr(expr) + + +@pytest.mark.notyet( + ["dask", "datafusion", "flink", "pandas", "polars"], + raises=com.OperationNotDefinedError, +) +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.OperationalError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14734", +) +def test_unnest_range(con): + expr = ibis.range(2).unnest().name("x").as_table().mutate({"y": 1.0}) + result = con.execute(expr) + expected = pd.DataFrame({"x": np.array([0, 1], dtype="int8"), "y": [1.0, 1.0]}) + tm.assert_frame_equal(result, expected) diff --git a/ibis/backends/tests/test_benchmarks.py b/ibis/backends/tests/test_benchmarks.py deleted file mode 100644 index 4805a5ccf5bc1..0000000000000 --- a/ibis/backends/tests/test_benchmarks.py +++ /dev/null @@ -1,900 +0,0 @@ -from __future__ import annotations - -import copy -import functools -import inspect -import itertools -import os -import string - -import numpy as np -import pandas as pd -import pytest -import sqlalchemy as sa -from packaging.version import parse as vparse - -import ibis -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops -import ibis.expr.types as ir -from ibis.backends.base import _get_backend_names - -# from ibis.backends.pandas.udf import udf - -# FIXME(kszucs): pytestmark = pytest.mark.benchmark -pytestmark = pytest.mark.skip(reason="the backends must be rewritten first") - - -def make_t(): - return ibis.table( - [ - ("_timestamp", "int32"), - ("dim1", "int32"), - ("dim2", "int32"), - ("valid_seconds", "int32"), - ("meas1", "int32"), - ("meas2", "int32"), - ("year", "int32"), - ("month", "int32"), - ("day", "int32"), - ("hour", "int32"), - ("minute", "int32"), - ], - name="t", - ) - - -@pytest.fixture(scope="module") -def t(): - return make_t() - - -def make_base(t): - return t[ - ( - (t.year > 2016) - | ((t.year == 2016) & (t.month > 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day > 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day == 6) & (t.hour > 6)) - | ( - (t.year == 2016) - & (t.month == 6) - & (t.day == 6) - & (t.hour == 6) - & (t.minute >= 5) - ) - ) - & ( - (t.year < 2016) - | ((t.year == 2016) & (t.month < 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day < 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day == 6) & (t.hour < 6)) - | ( - (t.year == 2016) - & (t.month == 6) - & (t.day == 6) - & (t.hour == 6) - & (t.minute <= 5) - ) - ) - ] - - -@pytest.fixture(scope="module") -def base(t): - return make_base(t) - - -def make_large_expr(base): - src_table = base - src_table = src_table.mutate( - _timestamp=(src_table["_timestamp"] - src_table["_timestamp"] % 3600) - .cast("int32") - .name("_timestamp"), - valid_seconds=300, - ) - - aggs = [] - for meas in ["meas1", "meas2"]: - 
aggs.append(src_table[meas].sum().cast("float").name(meas)) - src_table = src_table.aggregate( - aggs, by=["_timestamp", "dim1", "dim2", "valid_seconds"] - ) - - part_keys = ["year", "month", "day", "hour", "minute"] - ts_col = src_table["_timestamp"].cast("timestamp") - new_cols = {} - for part_key in part_keys: - part_col = getattr(ts_col, part_key)() - new_cols[part_key] = part_col - src_table = src_table.mutate(**new_cols) - return src_table[ - [ - "_timestamp", - "dim1", - "dim2", - "meas1", - "meas2", - "year", - "month", - "day", - "hour", - "minute", - ] - ] - - -@pytest.fixture(scope="module") -def large_expr(base): - return make_large_expr(base) - - -@pytest.mark.benchmark(group="construction") -@pytest.mark.parametrize( - "construction_fn", - [ - pytest.param(lambda *_: make_t(), id="small"), - pytest.param(lambda t, *_: make_base(t), id="medium"), - pytest.param(lambda _, base: make_large_expr(base), id="large"), - ], -) -def test_construction(benchmark, construction_fn, t, base): - benchmark(construction_fn, t, base) - - -@pytest.mark.benchmark(group="builtins") -@pytest.mark.parametrize( - "expr_fn", - [ - pytest.param(lambda t, _base, _large_expr: t, id="small"), - pytest.param(lambda _t, base, _large_expr: base, id="medium"), - pytest.param(lambda _t, _base, large_expr: large_expr, id="large"), - ], -) -@pytest.mark.parametrize("builtin", [hash, str]) -def test_builtins(benchmark, expr_fn, builtin, t, base, large_expr): - expr = expr_fn(t, base, large_expr) - benchmark(builtin, expr) - - -_backends = set(_get_backend_names()) -# compile is a no-op -_backends.remove("pandas") - -_XFAIL_COMPILE_BACKENDS = {"dask", "pyspark", "polars"} - - -@pytest.mark.benchmark(group="compilation") -@pytest.mark.parametrize( - "module", - [ - pytest.param( - mod, - marks=pytest.mark.xfail( - condition=mod in _XFAIL_COMPILE_BACKENDS, - reason=f"{mod} backend doesn't support compiling UnboundTable", - ), - ) - for mod in _backends - ], -) -@pytest.mark.parametrize( - "expr_fn", - [ - pytest.param(lambda t, _base, _large_expr: t, id="small"), - pytest.param(lambda _t, base, _large_expr: base, id="medium"), - pytest.param(lambda _t, _base, large_expr: large_expr, id="large"), - ], -) -def test_compile(benchmark, module, expr_fn, t, base, large_expr): - try: - mod = getattr(ibis, module) - except (AttributeError, ImportError) as e: - pytest.skip(str(e)) - else: - expr = expr_fn(t, base, large_expr) - try: - benchmark(mod.compile, expr) - except (sa.exc.NoSuchModuleError, ImportError) as e: # delayed imports - pytest.skip(str(e)) - - -@pytest.fixture(scope="module") -def pt(): - n = 60_000 - data = pd.DataFrame( - { - "key": np.random.choice(16000, size=n), - "low_card_key": np.random.choice(30, size=n), - "value": np.random.rand(n), - "timestamps": pd.date_range( - start="2023-05-05 16:37:57", periods=n, freq="s" - ).values, - "timestamp_strings": pd.date_range( - start="2023-05-05 16:37:39", periods=n, freq="s" - ).values.astype(str), - "repeated_timestamps": pd.date_range(start="2018-09-01", periods=30).repeat( - int(n / 30) - ), - } - ) - - return ibis.pandas.connect(dict(df=data)).table("df") - - -def high_card_group_by(t): - return t.group_by(t.key).aggregate(avg_value=t.value.mean()) - - -def cast_to_dates(t): - return t.timestamps.cast(dt.date) - - -def cast_to_dates_from_strings(t): - return t.timestamp_strings.cast(dt.date) - - -def multikey_group_by_with_mutate(t): - return ( - t.mutate(dates=t.timestamps.cast("date")) - .group_by(["low_card_key", "dates"]) - .aggregate(avg_value=lambda 
t: t.value.mean()) - ) - - -def simple_sort(t): - return t.order_by([t.key]) - - -def simple_sort_projection(t): - return t[["key", "value"]].order_by(["key"]) - - -def multikey_sort(t): - return t.order_by(["low_card_key", "key"]) - - -def multikey_sort_projection(t): - return t[["low_card_key", "key", "value"]].order_by(["low_card_key", "key"]) - - -def low_card_rolling_window(t): - return ibis.trailing_range_window( - ibis.interval(days=2), - order_by=t.repeated_timestamps, - group_by=t.low_card_key, - ) - - -def low_card_grouped_rolling(t): - return t.value.mean().over(low_card_rolling_window(t)) - - -def high_card_rolling_window(t): - return ibis.trailing_range_window( - ibis.interval(days=2), - order_by=t.repeated_timestamps, - group_by=t.key, - ) - - -def high_card_grouped_rolling(t): - return t.value.mean().over(high_card_rolling_window(t)) - - -# @udf.reduction(["double"], "double") -# def my_mean(series): -# return series.mean() - - -def low_card_grouped_rolling_udf_mean(t): - return my_mean(t.value).over(low_card_rolling_window(t)) - - -def high_card_grouped_rolling_udf_mean(t): - return my_mean(t.value).over(high_card_rolling_window(t)) - - -# @udf.analytic(["double"], "double") -# def my_zscore(series): -# return (series - series.mean()) / series.std() - - -def low_card_window(t): - return ibis.window(group_by=t.low_card_key) - - -def high_card_window(t): - return ibis.window(group_by=t.key) - - -def low_card_window_analytics_udf(t): - return my_zscore(t.value).over(low_card_window(t)) - - -def high_card_window_analytics_udf(t): - return my_zscore(t.value).over(high_card_window(t)) - - -# @udf.reduction(["double", "double"], "double") -# def my_wm(v, w): -# return np.average(v, weights=w) - - -def low_card_grouped_rolling_udf_wm(t): - return my_wm(t.value, t.value).over(low_card_rolling_window(t)) - - -def high_card_grouped_rolling_udf_wm(t): - return my_wm(t.value, t.value).over(low_card_rolling_window(t)) - - -broken_pandas_grouped_rolling = pytest.mark.xfail( - condition=vparse("1.4") <= vparse(pd.__version__) < vparse("1.4.2"), - raises=ValueError, - reason="https://github.com/pandas-dev/pandas/pull/44068", -) - - -@pytest.mark.benchmark(group="execution") -@pytest.mark.parametrize( - "expression_fn", - [ - pytest.param(high_card_group_by, id="high_card_group_by"), - pytest.param(cast_to_dates, id="cast_to_dates"), - pytest.param(cast_to_dates_from_strings, id="cast_to_dates_from_strings"), - pytest.param(multikey_group_by_with_mutate, id="multikey_group_by_with_mutate"), - pytest.param(simple_sort, id="simple_sort"), - pytest.param(simple_sort_projection, id="simple_sort_projection"), - pytest.param(multikey_sort, id="multikey_sort"), - pytest.param(multikey_sort_projection, id="multikey_sort_projection"), - pytest.param( - low_card_grouped_rolling, - id="low_card_grouped_rolling", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - high_card_grouped_rolling, - id="high_card_grouped_rolling", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - low_card_grouped_rolling_udf_mean, - id="low_card_grouped_rolling_udf_mean", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - high_card_grouped_rolling_udf_mean, - id="high_card_grouped_rolling_udf_mean", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param(low_card_window_analytics_udf, id="low_card_window_analytics_udf"), - pytest.param( - high_card_window_analytics_udf, id="high_card_window_analytics_udf" - ), - pytest.param( - low_card_grouped_rolling_udf_wm, - 
id="low_card_grouped_rolling_udf_wm", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - high_card_grouped_rolling_udf_wm, - id="high_card_grouped_rolling_udf_wm", - marks=[broken_pandas_grouped_rolling], - ), - ], -) -def test_execute(benchmark, expression_fn, pt): - expr = expression_fn(pt) - benchmark(expr.execute) - - -@pytest.fixture(scope="module") -def part(): - return ibis.table( - dict( - p_partkey="int64", - p_size="int64", - p_type="string", - p_mfgr="string", - ), - name="part", - ) - - -@pytest.fixture(scope="module") -def supplier(): - return ibis.table( - dict( - s_suppkey="int64", - s_nationkey="int64", - s_name="string", - s_acctbal="decimal(15, 3)", - s_address="string", - s_phone="string", - s_comment="string", - ), - name="supplier", - ) - - -@pytest.fixture(scope="module") -def partsupp(): - return ibis.table( - dict( - ps_partkey="int64", - ps_suppkey="int64", - ps_supplycost="decimal(15, 3)", - ), - name="partsupp", - ) - - -@pytest.fixture(scope="module") -def nation(): - return ibis.table( - dict(n_nationkey="int64", n_regionkey="int64", n_name="string"), - name="nation", - ) - - -@pytest.fixture(scope="module") -def region(): - return ibis.table(dict(r_regionkey="int64", r_name="string"), name="region") - - -@pytest.fixture(scope="module") -def tpc_h02(part, supplier, partsupp, nation, region): - REGION = "EUROPE" - SIZE = 25 - TYPE = "BRASS" - - expr = ( - part.join(partsupp, part.p_partkey == partsupp.ps_partkey) - .join(supplier, supplier.s_suppkey == partsupp.ps_suppkey) - .join(nation, supplier.s_nationkey == nation.n_nationkey) - .join(region, nation.n_regionkey == region.r_regionkey) - ) - - subexpr = ( - partsupp.join(supplier, supplier.s_suppkey == partsupp.ps_suppkey) - .join(nation, supplier.s_nationkey == nation.n_nationkey) - .join(region, nation.n_regionkey == region.r_regionkey) - ) - - subexpr = subexpr[ - (subexpr.r_name == REGION) & (expr.p_partkey == subexpr.ps_partkey) - ] - - filters = [ - expr.p_size == SIZE, - expr.p_type.like(f"%{TYPE}"), - expr.r_name == REGION, - expr.ps_supplycost == subexpr.ps_supplycost.min(), - ] - q = expr.filter(filters) - - q = q.select( - [ - q.s_acctbal, - q.s_name, - q.n_name, - q.p_partkey, - q.p_mfgr, - q.s_address, - q.s_phone, - q.s_comment, - ] - ) - - return q.order_by( - [ - ibis.desc(q.s_acctbal), - q.n_name, - q.s_name, - q.p_partkey, - ] - ).limit(100) - - -@pytest.mark.benchmark(group="repr") -def test_repr_tpc_h02(benchmark, tpc_h02): - benchmark(repr, tpc_h02) - - -@pytest.mark.benchmark(group="repr") -def test_repr_huge_union(benchmark): - n = 10 - raw_types = [ - "int64", - "float64", - "string", - "array, b: map>>>", - ] - tables = [ - ibis.table( - list(zip(string.ascii_letters, itertools.cycle(raw_types))), - name=f"t{i:d}", - ) - for i in range(n) - ] - expr = functools.reduce(ir.Table.union, tables) - benchmark(repr, expr) - - -@pytest.mark.benchmark(group="node_args") -def test_op_argnames(benchmark): - t = ibis.table([("a", "int64")]) - expr = t[["a"]] - benchmark(lambda op: op.argnames, expr.op()) - - -@pytest.mark.benchmark(group="node_args") -def test_op_args(benchmark): - t = ibis.table([("a", "int64")]) - expr = t[["a"]] - benchmark(lambda op: op.args, expr.op()) - - -@pytest.mark.benchmark(group="datatype") -def test_complex_datatype_parse(benchmark): - type_str = "array, b: map>>>" - expected = dt.Array( - dt.Struct(dict(a=dt.Array(dt.string), b=dt.Map(dt.string, dt.Array(dt.int64)))) - ) - assert dt.parse(type_str) == expected - benchmark(dt.parse, type_str) - - 
-@pytest.mark.benchmark(group="datatype") -@pytest.mark.parametrize("func", [str, hash]) -def test_complex_datatype_builtins(benchmark, func): - datatype = dt.Array( - dt.Struct(dict(a=dt.Array(dt.string), b=dt.Map(dt.string, dt.Array(dt.int64)))) - ) - benchmark(func, datatype) - - -@pytest.mark.benchmark(group="equality") -def test_large_expr_equals(benchmark, tpc_h02): - benchmark(ir.Expr.equals, tpc_h02, copy.deepcopy(tpc_h02)) - - -@pytest.mark.benchmark(group="datatype") -@pytest.mark.parametrize( - "dtypes", - [ - pytest.param( - [ - obj - for _, obj in inspect.getmembers( - dt, - lambda obj: isinstance(obj, dt.DataType), - ) - ], - id="singletons", - ), - pytest.param( - dt.Array( - dt.Struct( - dict( - a=dt.Array(dt.string), - b=dt.Map(dt.string, dt.Array(dt.int64)), - ) - ) - ), - id="complex", - ), - ], -) -def test_eq_datatypes(benchmark, dtypes): - def eq(a, b): - assert a == b - - benchmark(eq, dtypes, copy.deepcopy(dtypes)) - - -def multiple_joins(table, num_joins): - for _ in range(num_joins): - table = table.mutate(dummy=ibis.literal("")) - table = table.left_join(table, ["dummy"])[[table]] - - -@pytest.mark.parametrize("num_joins", [1, 10]) -@pytest.mark.parametrize("num_columns", [1, 10, 100]) -def test_multiple_joins(benchmark, num_joins, num_columns): - table = ibis.table( - {f"col_{i:d}": "string" for i in range(num_columns)}, - name="t", - ) - benchmark(multiple_joins, table, num_joins) - - -@pytest.fixture -def customers(): - return ibis.table( - dict( - customerid="int32", - name="string", - address="string", - citystatezip="string", - birthdate="date", - phone="string", - timezone="string", - lat="float64", - long="float64", - ), - name="customers", - ) - - -@pytest.fixture -def orders(): - return ibis.table( - dict( - orderid="int32", - customerid="int32", - ordered="timestamp", - shipped="timestamp", - items="string", - total="float64", - ), - name="orders", - ) - - -@pytest.fixture -def orders_items(): - return ibis.table( - dict(orderid="int32", sku="string", qty="int32", unit_price="float64"), - name="orders_items", - ) - - -@pytest.fixture -def products(): - return ibis.table( - dict( - sku="string", - desc="string", - weight_kg="float64", - cost="float64", - dims_cm="string", - ), - name="products", - ) - - -@pytest.mark.benchmark(group="compilation") -@pytest.mark.parametrize( - "module", - [ - pytest.param( - mod, - marks=pytest.mark.xfail( - condition=mod in _XFAIL_COMPILE_BACKENDS, - reason=f"{mod} backend doesn't support compiling UnboundTable", - ), - ) - for mod in _backends - ], -) -def test_compile_with_drops( - benchmark, module, customers, orders, orders_items, products -): - expr = ( - customers.join(orders, "customerid") - .join(orders_items, "orderid") - .join(products, "sku") - .drop("customerid", "qty", "total", "items") - .drop("dims_cm", "cost") - .mutate(o_date=lambda t: t.shipped.date()) - .filter(lambda t: t.ordered == t.shipped) - ) - - try: - mod = getattr(ibis, module) - except (AttributeError, ImportError) as e: - pytest.skip(str(e)) - else: - try: - benchmark(mod.compile, expr) - except sa.exc.NoSuchModuleError as e: - pytest.skip(str(e)) - - -def test_repr_join(benchmark, customers, orders, orders_items, products): - expr = ( - customers.join(orders, "customerid") - .join(orders_items, "orderid") - .join(products, "sku") - .drop("customerid", "qty", "total", "items") - ) - op = expr.op() - benchmark(repr, op) - - -@pytest.mark.parametrize("overwrite", [True, False], ids=["overwrite", "no_overwrite"]) -def 
test_insert_duckdb(benchmark, overwrite, tmp_path): - pytest.importorskip("duckdb") - - n_rows = int(1e4) - table_name = "t" - schema = ibis.schema(dict(a="int64", b="int64", c="int64")) - t = ibis.memtable(dict.fromkeys(list("abc"), range(n_rows)), schema=schema) - - con = ibis.duckdb.connect(tmp_path / "test_insert.ddb") - con.create_table(table_name, schema=schema) - benchmark(con.insert, table_name, t, overwrite=overwrite) - - -def test_snowflake_medium_sized_to_pandas(benchmark): - pytest.importorskip("snowflake.connector") - - if (url := os.environ.get("SNOWFLAKE_URL")) is None: - pytest.skip("SNOWFLAKE_URL environment variable not set") - - con = ibis.connect(url) - - # LINEITEM at scale factor 1 is around 6MM rows, but we limit to 1,000,000 - # to make the benchmark fast enough for development, yet large enough to show a - # difference if there's a performance hit - lineitem = con.table("LINEITEM", schema="SNOWFLAKE_SAMPLE_DATA.TPCH_SF1").limit( - 1_000_000 - ) - - benchmark.pedantic(lineitem.to_pandas, rounds=5, iterations=1, warmup_rounds=1) - - -def test_parse_many_duckdb_types(benchmark): - parse = pytest.importorskip("ibis.backends.duckdb.datatypes").DuckDBType.from_string - - def parse_many(types): - list(map(parse, types)) - - types = ["VARCHAR", "INTEGER", "DOUBLE", "BIGINT"] * 1000 - benchmark(parse_many, types) - - -@pytest.fixture(scope="session") -def sql() -> str: - return """ - SELECT t1.id as t1_id, x, t2.id as t2_id, y - FROM t1 INNER JOIN t2 - ON t1.id = t2.id - """ - - -@pytest.fixture(scope="session") -def ddb(tmp_path_factory): - duckdb = pytest.importorskip("duckdb") - - N = 20_000_000 - - con = duckdb.connect() - - path = str(tmp_path_factory.mktemp("duckdb") / "data.ddb") - sql = ( - lambda var, table, n=N: f""" - CREATE TABLE {table} AS - SELECT ROW_NUMBER() OVER () AS id, {var} - FROM ( - SELECT {var} - FROM RANGE({n}) _ ({var}) - ORDER BY RANDOM() - ) - """ - ) - - with duckdb.connect(path) as con: - con.execute(sql("x", table="t1")) - con.execute(sql("y", table="t2")) - return path - - -def test_duckdb_to_pyarrow(benchmark, sql, ddb) -> None: - # yes, we're benchmarking duckdb here, not ibis - # - # we do this to get a baseline for comparison - duckdb = pytest.importorskip("duckdb") - con = duckdb.connect(ddb, read_only=True) - - benchmark(lambda sql: con.sql(sql).to_arrow_table(), sql) - - -def test_ibis_duckdb_to_pyarrow(benchmark, sql, ddb) -> None: - pytest.importorskip("duckdb") - - con = ibis.duckdb.connect(ddb, read_only=True) - - expr = con.sql(sql) - benchmark(expr.to_pyarrow) - - -@pytest.fixture -def diffs(): - return ibis.table( - { - "id": "int64", - "validation_name": "string", - "difference": "float64", - "pct_difference": "float64", - "pct_threshold": "float64", - "validation_status": "string", - }, - name="diffs", - ) - - -@pytest.fixture -def srcs(): - return ibis.table( - { - "id": "int64", - "validation_name": "string", - "validation_type": "string", - "aggregation_type": "string", - "table_name": "string", - "column_name": "string", - "primary_keys": "string", - "num_random_rows": "string", - "agg_value": "float64", - }, - name="srcs", - ) - - -@pytest.fixture -def nrels(): - return 300 - - -def make_big_union(t, nrels): - return ibis.union(*[t] * nrels) - - -@pytest.fixture -def src(srcs, nrels): - return make_big_union(srcs, nrels) - - -@pytest.fixture -def diff(diffs, nrels): - return make_big_union(diffs, nrels) - - -def test_big_eq_expr(benchmark, src, diff): - benchmark(ops.core.Node.equals, src.op(), diff.op()) - - -def 
test_big_join_expr(benchmark, src, diff): - benchmark(ir.Table.join, src, diff, ["validation_name"], how="outer") - - -def test_big_join_execute(benchmark, nrels): - pytest.importorskip("duckdb") - - con = ibis.duckdb.connect() - - # cache to avoid a request-per-union operand - src = make_big_union( - con.read_csv( - "https://github.com/ibis-project/ibis/files/12580336/source_pivot.csv" - ) - .rename(id="column0") - .cache(), - nrels, - ) - - diff = make_big_union( - con.read_csv( - "https://github.com/ibis-project/ibis/files/12580340/differences_pivot.csv" - ) - .rename(id="column0") - .cache(), - nrels, - ) - - expr = src.join(diff, ["validation_name"], how="outer") - t = benchmark.pedantic(expr.to_pyarrow, rounds=1, iterations=1, warmup_rounds=1) - assert len(t) diff --git a/ibis/backends/tests/test_binary.py b/ibis/backends/tests/test_binary.py index 0a5790c646316..1d9f7cfa0516f 100644 --- a/ibis/backends/tests/test_binary.py +++ b/ibis/backends/tests/test_binary.py @@ -15,6 +15,7 @@ "sqlite": "blob", "trino": "varbinary", "postgres": "bytea", + "risingwave": "bytea", "flink": "BINARY(1) NOT NULL", } diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 3f2336cafe9b6..2d8674eb83843 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -306,6 +306,11 @@ def tmpcon(alchemy_con): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_create_temporary_table_from_schema(tmpcon, new_schema): temp_table = f"_{guid()}" table = tmpcon.create_table(temp_table, schema=new_schema, temp=True) @@ -338,6 +343,7 @@ def test_create_temporary_table_from_schema(tmpcon, new_schema): "pandas", "polars", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -367,6 +373,11 @@ def test_rename_table(con, temp_table, temp_table_orig): raises=com.IbisError, reason="`tbl_properties` is required when creating table with schema", ) +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason='Feature is not yet implemented: column constraints "NOT NULL"', +) def test_nullable_input_output(con, temp_table): sch = ibis.schema( [("foo", "int64"), ("bar", dt.int64(nullable=False)), ("baz", "boolean")] @@ -412,7 +423,7 @@ def test_create_drop_view(ddl_con, temp_view): assert set(t_expr.schema().names) == set(v_expr.schema().names) -@mark.notimpl(["postgres", "polars"]) +@mark.notimpl(["postgres", "risingwave", "polars"]) @mark.notimpl( ["datafusion"], raises=NotImplementedError, @@ -622,6 +633,7 @@ def test_list_databases(alchemy_con): test_databases = { "sqlite": {"main"}, "postgres": {"postgres", "ibis_testing"}, + "risingwave": {"dev"}, "mssql": {"ibis_testing"}, "mysql": {"ibis_testing", "information_schema"}, "duckdb": {"memory"}, @@ -634,7 +646,7 @@ def test_list_databases(alchemy_con): @pytest.mark.never( - ["bigquery", "postgres", "mssql", "mysql", "oracle"], + ["bigquery", "postgres", "risingwave", "mssql", "mysql", "oracle"], reason="backend does not support client-side in-memory tables", raises=(sa.exc.OperationalError, TypeError, sa.exc.InterfaceError), ) @@ -707,6 +719,11 @@ def test_unsigned_integer_type(alchemy_con, alchemy_temp_table): marks=mark.postgres, id="postgresql", ), + param( + "postgresql://root:@localhost:4566/dev", + marks=mark.risingwave, + 
id="risingwave", + ), param( "pyspark://?spark.app.name=test-pyspark", marks=[ @@ -1123,6 +1140,11 @@ def test_set_backend_name(name, monkeypatch): marks=mark.postgres, id="postgres", ), + param( + "postgres://root:@localhost:4566/dev", + marks=mark.risingwave, + id="risingwave", + ), ], ) def test_set_backend_url(url, monkeypatch): @@ -1147,6 +1169,7 @@ def test_set_backend_url(url, monkeypatch): "pandas", "polars", "postgres", + "risingwave", "pyspark", "sqlite", ], @@ -1183,6 +1206,11 @@ def test_create_table_timestamp(con, temp_table): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_ref_count(backend, con, alltypes): non_persisted_table = alltypes.mutate(test_column="calculation") persisted_table = non_persisted_table.cache() @@ -1203,6 +1231,11 @@ def test_persist_expression_ref_count(backend, con, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression(backend, alltypes): non_persisted_table = alltypes.mutate(test_column="calculation", other_calc="xyz") persisted_table = non_persisted_table.cache() @@ -1217,6 +1250,11 @@ def test_persist_expression(backend, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_contextmanager(backend, alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc" @@ -1233,6 +1271,11 @@ def test_persist_expression_contextmanager(backend, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_contextmanager_ref_count(backend, con, alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc 2" @@ -1251,6 +1294,11 @@ def test_persist_expression_contextmanager_ref_count(backend, con, alltypes): ["mssql"], reason="mssql supports support temporary tables through naming conventions", ) +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") def test_persist_expression_multiple_refs(backend, con, alltypes): non_cached_table = alltypes.mutate( @@ -1288,6 +1336,11 @@ def test_persist_expression_multiple_refs(backend, con, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def 
test_persist_expression_repeated_cache(alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc 2" @@ -1307,6 +1360,11 @@ def test_persist_expression_repeated_cache(alltypes): ["oracle"], reason="Oracle error message for a missing table/view doesn't include the name of the table", ) +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_release(con, alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc 3" @@ -1391,6 +1449,11 @@ def test_create_schema(con_create_schema): assert schema not in con_create_schema.list_schemas() +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: information_schema.schemata is not supported,", +) def test_list_schemas(con_create_schema): schemas = con_create_schema.list_schemas() assert len(schemas) == len(set(schemas)) diff --git a/ibis/backends/tests/test_column.py b/ibis/backends/tests/test_column.py index f26b2a876ded0..f6b4bd8ee0f41 100644 --- a/ibis/backends/tests/test_column.py +++ b/ibis/backends/tests/test_column.py @@ -19,6 +19,7 @@ "pandas", "polars", "postgres", + "risingwave", "pyspark", "snowflake", "trino", diff --git a/ibis/backends/tests/test_dot_sql.py b/ibis/backends/tests/test_dot_sql.py index df67a506769cf..4a020ea989496 100644 --- a/ibis/backends/tests/test_dot_sql.py +++ b/ibis/backends/tests/test_dot_sql.py @@ -231,7 +231,7 @@ def test_dot_sql_reuse_alias_with_different_types(backend, alltypes, df): backend.assert_series_equal(foo2.x.execute(), expected2) -_NO_SQLGLOT_DIALECT = {"pandas", "dask", "druid", "flink"} +_NO_SQLGLOT_DIALECT = {"pandas", "dask", "druid", "flink", "risingwave"} no_sqlglot_dialect = sorted( # TODO(cpcloud): remove the strict=False hack once backends are ported to # sqlglot @@ -248,6 +248,11 @@ def test_dot_sql_reuse_alias_with_different_types(backend, alltypes, df): ], ) @pytest.mark.notyet(["polars"], raises=PolarsComputeError) +@pytest.mark.notyet( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) @table_dot_sql_notimpl @dot_sql_notimpl @dot_sql_notyet @@ -279,6 +284,11 @@ def test_table_dot_sql_transpile(backend, alltypes, dialect, df): @pytest.mark.notyet( ["oracle"], strict=False, reason="only works with backends that quote everything" ) +@pytest.mark.notyet( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) @dot_sql_notimpl @dot_sql_never def test_con_dot_sql_transpile(backend, con, dialect, df): @@ -296,6 +306,11 @@ def test_con_dot_sql_transpile(backend, con, dialect, df): @dot_sql_never @pytest.mark.notimpl(["druid", "flink", "polars"]) @pytest.mark.notyet(["snowflake"], reason="snowflake column names are case insensitive") +@pytest.mark.notyet( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) def test_order_by_no_projection(backend): con = backend.connection astronauts = con.table("astronauts") diff --git a/ibis/backends/tests/test_examples.py b/ibis/backends/tests/test_examples.py index b73ea8e2f3da8..113fe70102a45 100644 --- a/ibis/backends/tests/test_examples.py +++ b/ibis/backends/tests/test_examples.py @@ -16,7 +16,7 @@ reason="nix on linux cannot download duckdb extensions or data due to sandboxing", ) @pytest.mark.notimpl(["dask", "exasol", "pyspark"]) 
-@pytest.mark.notyet(["clickhouse", "druid", "impala", "mssql", "trino"]) +@pytest.mark.notyet(["clickhouse", "druid", "impala", "mssql", "trino", "risingwave"]) @pytest.mark.parametrize( ("example", "columns"), [ diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index 658850c6096e6..c6346b645a97d 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -4,6 +4,7 @@ import pyarrow as pa import pyarrow.csv as pcsv import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -254,6 +255,7 @@ def test_table_to_parquet_writer_kwargs(version, tmp_path, backend, awards_playe "pandas", "polars", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -348,6 +350,11 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): marks=[ pytest.mark.notyet(["flink"], raises=NotImplementedError), pytest.mark.notyet(["exasol"], raises=ExaQueryError), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.DBAPIError, + reason="Feature is not yet implemented: unsupported data type: NUMERIC(38,9)", + ), ], ), param( @@ -369,6 +376,11 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): ), pytest.mark.notyet(["flink"], raises=NotImplementedError), pytest.mark.notyet(["exasol"], raises=ExaQueryError), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.DBAPIError, + reason="Feature is not yet implemented: unsupported data type: NUMERIC(76,38)", + ), ], ), ], @@ -390,6 +402,7 @@ def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype): "mysql", "oracle", "postgres", + "risingwave", "snowflake", "sqlite", "bigquery", @@ -488,7 +501,22 @@ def test_to_pandas_batches_empty_table(backend, con): @pytest.mark.notimpl(["flink"]) -@pytest.mark.parametrize("n", [None, 1]) +@pytest.mark.parametrize( + "n", + [ + param( + None, + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit null", + ), + ], + ), + 1, + ], +) def test_to_pandas_batches_nonempty_table(backend, con, n): t = backend.functional_alltypes.limit(n) n = t.count().execute() @@ -498,7 +526,24 @@ def test_to_pandas_batches_nonempty_table(backend, con, n): @pytest.mark.notimpl(["flink"]) -@pytest.mark.parametrize("n", [None, 0, 1, 2]) +@pytest.mark.parametrize( + "n", + [ + param( + None, + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit null", + ), + ], + ), + 0, + 1, + 2, + ], +) def test_to_pandas_batches_column(backend, con, n): t = backend.functional_alltypes.limit(n).timestamp_col n = t.count().execute() diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 241c57ce88705..1ff8122967f12 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa import toolz from pytest import param @@ -43,6 +44,7 @@ "sqlite": "null", "trino": "unknown", "postgres": "null", + "risingwave": "null", } @@ -66,6 +68,7 @@ def test_null_literal(con, backend): "trino": "boolean", "duckdb": "BOOLEAN", "postgres": "boolean", + "risingwave": "boolean", "flink": "BOOLEAN NOT NULL", } @@ -150,6 +153,7 @@ def test_isna(backend, alltypes, col, value, filt): "duckdb", "impala", "postgres", + "risingwave", "mysql", "snowflake", "polars", @@ -307,6 +311,7 @@ def test_filter(backend, alltypes, sorted_df, predicate_fn, expected_fn): "impala", 
"mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", @@ -540,6 +545,11 @@ def test_order_by(backend, alltypes, df, key, df_kwargs): @pytest.mark.notimpl(["dask", "pandas", "polars", "mssql", "druid"]) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_order_by_random(alltypes): expr = alltypes.filter(_.id < 100).order_by(ibis.random()).limit(5) r1 = expr.execute() @@ -761,6 +771,11 @@ def test_correlated_subquery(alltypes): @pytest.mark.notimpl(["polars", "pyspark"]) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason='DataFrame.iloc[:, 0] (column name="playerID") are different', +) def test_uncorrelated_subquery(backend, batting, batting_df): subset_batting = batting[batting.yearID <= 2000] expr = batting[_.yearID == subset_batting.yearID.max()]["playerID", "yearID"] @@ -833,6 +848,11 @@ def test_typeof(con): @pytest.mark.notimpl(["pyspark"], condition=is_older_than("pyspark", "3.5.0")) @pytest.mark.notyet(["dask"], reason="not supported by the backend") @pytest.mark.notyet(["exasol"], raises=ExaQueryError, reason="not supported by exasol") +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason="https://github.com/risingwavelabs/risingwave/issues/1343", +) def test_isin_uncorrelated( backend, batting, awards_players, batting_df, awards_players_df ): @@ -985,6 +1005,11 @@ def test_memtable_column_naming_mismatch(backend, con, monkeypatch, df, columns) ) @pytest.mark.notimpl(["druid", "flink"], reason="no sqlglot dialect", raises=ValueError) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) def test_many_subqueries(con, snapshot): def query(t, group_cols): t2 = t.mutate(key=ibis.row_number().over(ibis.window(order_by=group_cols))) @@ -1012,6 +1037,11 @@ def query(t, group_cols): reason="invalid code generated for unnesting a struct", raises=TrinoUserError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason='sql parser error: Expected ), found: TEXT at line:3, column:219 Near "))]) AS anon_1(f1"', +) def test_pivot_longer(backend): diamonds = backend.diamonds df = diamonds.execute() @@ -1126,6 +1156,11 @@ def test_pivot_wider(backend): ["exasol"], raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function last(double precision) does not exist, do you mean left or least", +) def test_distinct_on_keep(backend, on, keep): from ibis import _ @@ -1191,6 +1226,11 @@ def test_distinct_on_keep(backend, on, keep): raises=com.OperationNotDefinedError, reason="backend doesn't implement deduplication", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function first(double precision) does not exist", +) def test_distinct_on_keep_is_none(backend, on): from ibis import _ @@ -1209,7 +1249,7 @@ def test_distinct_on_keep_is_none(backend, on): assert len(result) == len(expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "flink", "exasol"]) +@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave", "flink", "exasol"]) @pytest.mark.notyet( [ "sqlite", @@ -1226,7 +1266,7 @@ def test_hash_consistent(backend, alltypes): assert h1.dtype in ("i8", "uint64") # polars likes returning uint64 for this -@pytest.mark.notimpl(["pandas", "dask", "oracle", "snowflake", 
"sqlite"]) +@pytest.mark.notimpl(["pandas", "dask", "oracle", "risingwave", "snowflake", "sqlite"]) @pytest.mark.parametrize( ("from_val", "to_type", "expected"), [ @@ -1275,6 +1315,7 @@ def test_try_cast(con, from_val, to_type, expected): "oracle", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", ] @@ -1312,6 +1353,7 @@ def test_try_cast_null(con, from_val, to_type): "mysql", "oracle", "postgres", + "risingwave", "snowflake", "sqlite", "exasol", @@ -1337,6 +1379,7 @@ def test_try_cast_table(backend, con): "mysql", "oracle", "postgres", + "risingwave", "snowflake", "sqlite", "exasol", @@ -1377,9 +1420,31 @@ def test_try_cast_func(con, from_val, to_type, func): ### NONE/ZERO start # no stop param(slice(None, 0), lambda _: 0, id="[:0]"), - param(slice(None, None), lambda t: t.count().to_pandas(), id="[:]"), + param( + slice(None, None), + lambda t: t.count().to_pandas(), + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", + ), + ], + id="[:]", + ), param(slice(0, 0), lambda _: 0, id="[0:0]"), - param(slice(0, None), lambda t: t.count().to_pandas(), id="[0:]"), + param( + slice(0, None), + lambda t: t.count().to_pandas(), + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", + ), + ], + id="[0:]", + ), # positive stop param(slice(None, 2), lambda _: 2, id="[:2]"), param(slice(0, 2), lambda _: 2, id="[0:2]"), @@ -1434,6 +1499,11 @@ def test_try_cast_func(con, from_val, to_type, func): reason="impala doesn't support OFFSET without ORDER BY", ), pytest.mark.notyet(["oracle"], raises=com.UnsupportedArgumentError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", + ), ], ), # positive stop @@ -1520,6 +1590,11 @@ def test_static_table_slice(backend, slc, expected_count_fn): raises=com.UnsupportedArgumentError, reason="Removed half-baked dynamic offset functionality for now", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", +) @pytest.mark.notyet( ["trino"], raises=TrinoUserError, @@ -1610,6 +1685,11 @@ def test_dynamic_table_slice(backend, slc, expected_count_fn): reason="doesn't support dynamic limit/offset; compiles incorrectly in sqlglot", raises=AssertionError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", +) def test_dynamic_table_slice_with_computed_offset(backend): t = backend.functional_alltypes @@ -1629,6 +1709,11 @@ def test_dynamic_table_slice_with_computed_offset(backend): @pytest.mark.notimpl(["druid", "flink", "polars", "snowflake"]) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_sample(backend): t = backend.functional_alltypes.filter(_.int_col >= 2) @@ -1645,6 +1730,11 @@ def test_sample(backend): @pytest.mark.notimpl(["druid", "flink", "polars", "snowflake"]) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_sample_memtable(con, backend): df = pd.DataFrame({"x": [1, 2, 3, 4]}) res = con.execute(ibis.memtable(df).sample(0.5)) @@ -1665,6 +1755,7 @@ def test_sample_memtable(con, backend): "oracle", "polars", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -1700,6 +1791,11 @@ def test_substitute(backend): ["dask", 
"pandas", "polars"], raises=NotImplementedError, reason="not a SQL backend" ) @pytest.mark.notimpl(["flink"], reason="no sqlglot dialect", raises=ValueError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) def test_simple_memtable_construct(con): t = ibis.memtable({"a": [1, 2]}) expr = t.a diff --git a/ibis/backends/tests/test_json.py b/ibis/backends/tests/test_json.py index 78d379ae0bde4..98c72e6934d54 100644 --- a/ibis/backends/tests/test_json.py +++ b/ibis/backends/tests/test_json.py @@ -40,13 +40,17 @@ ["flink"], reason="https://github.com/ibis-project/ibis/pull/6920#discussion_r1373212503", ) +@pytest.mark.broken( + ["risingwave"], + reason="TODO(Kexiang): order mismatch in array", +) def test_json_getitem(json_t, expr_fn, expected): expr = expr_fn(json_t) result = expr.execute() tm.assert_series_equal(result.fillna(pd.NA), expected.fillna(pd.NA)) -@pytest.mark.notimpl(["dask", "mysql", "pandas"]) +@pytest.mark.notimpl(["dask", "mysql", "pandas", "risingwave"]) @pytest.mark.notyet(["bigquery", "sqlite"], reason="doesn't support maps") @pytest.mark.notyet(["postgres"], reason="only supports map") @pytest.mark.notyet( @@ -70,7 +74,7 @@ def test_json_map(backend, json_t): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl(["dask", "mysql", "pandas"]) +@pytest.mark.notimpl(["dask", "mysql", "pandas", "risingwave"]) @pytest.mark.notyet(["sqlite"], reason="doesn't support arrays") @pytest.mark.notyet( ["pyspark", "trino", "flink"], reason="should work but doesn't deserialize JSON" diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index cb0386dc0ff83..6c7a5e717f12a 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -34,6 +35,11 @@ def test_map_table(backend): @pytest.mark.xfail_version( duckdb=["duckdb<0.8.0"], raises=exc.UnsupportedOperationError ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_column_map_values(backend): table = backend.map expr = table.select("idx", vals=table.kv.values()).order_by("idx") @@ -64,6 +70,11 @@ def test_column_map_merge(backend): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_literal_map_keys(con): mapping = ibis.literal({"1": "a", "2": "b"}) expr = mapping.keys().name("tmp") @@ -79,6 +90,11 @@ def test_literal_map_keys(con): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_literal_map_values(con): mapping = ibis.literal({"1": "a", "2": "b"}) expr = mapping.values().name("tmp") @@ -87,7 +103,7 @@ def test_literal_map_values(con): assert np.array_equal(result, ["a", "b"]) -@pytest.mark.notimpl(["postgres"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -103,7 +119,9 @@ def test_scalar_isin_literal_map_keys(con): assert con.execute(false) == False # noqa: E712 
-@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -124,6 +142,11 @@ def test_map_scalar_contains_key_scalar(con): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_scalar_contains_key_column(backend, alltypes, df): value = {"1": "a", "3": "c"} mapping = ibis.literal(value) @@ -133,7 +156,9 @@ def test_map_scalar_contains_key_column(backend, alltypes, df): backend.assert_series_equal(result, expected) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -149,7 +174,9 @@ def test_map_column_contains_key_scalar(backend, alltypes, df): backend.assert_series_equal(result, series) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -164,7 +191,9 @@ def test_map_column_contains_key_column(alltypes): assert result.all() -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -183,6 +212,11 @@ def test_literal_map_merge(con): raises=NotImplementedError, reason="No translation rule for map", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_literal_map_getitem_broadcast(backend, alltypes, df): value = {"1": "a", "2": "b"} @@ -200,6 +234,11 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df): raises=NotImplementedError, reason="No translation rule for map", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_literal_map_get_broadcast(backend, alltypes, df): value = {"1": "a", "2": "b"} @@ -220,19 +259,27 @@ def test_literal_map_get_broadcast(backend, alltypes, df): [1, 2], id="string", marks=pytest.mark.notyet( - ["postgres"], reason="only support maps of string -> string" + ["postgres", "risingwave"], + reason="only support maps of string -> string", ), ), param(["a", "b"], ["1", "2"], id="int"), ], ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_construct_dict(con, keys, values): expr = ibis.map(keys, values) result = con.execute(expr.name("tmp")) assert result == dict(zip(keys, values)) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -246,7 +293,9 @@ def test_map_construct_array_column(con, alltypes, df): assert 
result.to_list() == expected.to_list() -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -258,7 +307,9 @@ def test_map_get_with_compatible_value_smaller(con): assert con.execute(expr) == 3 -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -270,7 +321,9 @@ def test_map_get_with_compatible_value_bigger(con): assert con.execute(expr) == 3000 -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -283,7 +336,9 @@ def test_map_get_with_incompatible_value_different_kind(con): @pytest.mark.parametrize("null_value", [None, ibis.NA]) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -303,6 +358,11 @@ def test_map_get_with_null_on_not_nullable(con, null_value): raises=NotImplementedError, reason="No translation rule for map", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_get_with_null_on_null_type_with_null(con, null_value): value = ibis.literal({"A": None, "B": None}) expr = value.get("C", null_value) @@ -310,7 +370,9 @@ def test_map_get_with_null_on_null_type_with_null(con, null_value): assert pd.isna(result) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -327,6 +389,11 @@ def test_map_get_with_null_on_null_type_with_non_null(con): raises=exc.IbisError, reason="`tbl_properties` is required when creating table with schema", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_create_table(con, temp_table): t = con.create_table( temp_table, @@ -340,6 +407,11 @@ def test_map_create_table(con, temp_table): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_length(con): expr = ibis.literal(dict(a="A", b="B")).length() assert con.execute(expr) == 2 diff --git a/ibis/backends/tests/test_network.py b/ibis/backends/tests/test_network.py index e6048ee907b2d..dca5815c68554 100644 --- a/ibis/backends/tests/test_network.py +++ b/ibis/backends/tests/test_network.py @@ -20,6 +20,7 @@ "trino": "varchar(17)", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(17) NOT NULL", } @@ -50,6 +51,7 @@ def test_macaddr_literal(con, backend): "trino": "127.0.0.1", "impala": "127.0.0.1", "postgres": "127.0.0.1", + "risingwave": "127.0.0.1", "pandas": "127.0.0.1", "pyspark": 
"127.0.0.1", "mysql": "127.0.0.1", @@ -67,6 +69,7 @@ def test_macaddr_literal(con, backend): "trino": "varchar(9)", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(9) NOT NULL", }, id="ipv4", @@ -82,6 +85,7 @@ def test_macaddr_literal(con, backend): "trino": "2001:db8::1", "impala": "2001:db8::1", "postgres": "2001:db8::1", + "risingwave": "2001:db8::1", "pandas": "2001:db8::1", "pyspark": "2001:db8::1", "mysql": "2001:db8::1", @@ -99,6 +103,7 @@ def test_macaddr_literal(con, backend): "trino": "varchar(11)", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(11) NOT NULL", }, id="ipv6", diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index 23544c4b1445a..4816d6cf62423 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -52,6 +52,7 @@ "trino": "integer", "duckdb": "TINYINT", "postgres": "integer", + "risingwave": "integer", "flink": "TINYINT NOT NULL", }, id="int8", @@ -67,6 +68,7 @@ "trino": "integer", "duckdb": "SMALLINT", "postgres": "integer", + "risingwave": "integer", "flink": "SMALLINT NOT NULL", }, id="int16", @@ -82,6 +84,7 @@ "trino": "integer", "duckdb": "INTEGER", "postgres": "integer", + "risingwave": "integer", "flink": "INT NOT NULL", }, id="int32", @@ -97,6 +100,7 @@ "trino": "integer", "duckdb": "BIGINT", "postgres": "integer", + "risingwave": "integer", "flink": "BIGINT NOT NULL", }, id="int64", @@ -112,6 +116,7 @@ "trino": "integer", "duckdb": "UTINYINT", "postgres": "integer", + "risingwave": "integer", "flink": "TINYINT NOT NULL", }, id="uint8", @@ -127,6 +132,7 @@ "trino": "integer", "duckdb": "USMALLINT", "postgres": "integer", + "risingwave": "integer", "flink": "SMALLINT NOT NULL", }, id="uint16", @@ -142,6 +148,7 @@ "trino": "integer", "duckdb": "UINTEGER", "postgres": "integer", + "risingwave": "integer", "flink": "INT NOT NULL", }, id="uint32", @@ -157,6 +164,7 @@ "trino": "integer", "duckdb": "UBIGINT", "postgres": "integer", + "risingwave": "integer", "flink": "BIGINT NOT NULL", }, id="uint64", @@ -172,6 +180,7 @@ "trino": "real", "duckdb": "FLOAT", "postgres": "numeric", + "risingwave": "numeric", "flink": "FLOAT NOT NULL", }, marks=[ @@ -199,6 +208,7 @@ "trino": "real", "duckdb": "FLOAT", "postgres": "numeric", + "risingwave": "numeric", "flink": "FLOAT NOT NULL", }, id="float32", @@ -214,6 +224,7 @@ "trino": "double", "duckdb": "DOUBLE", "postgres": "numeric", + "risingwave": "numeric", "flink": "DOUBLE NOT NULL", }, id="float64", @@ -245,6 +256,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "duckdb": decimal.Decimal("1.1"), "impala": decimal.Decimal("1"), "postgres": decimal.Decimal("1.1"), + "risingwave": 1.1, "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "mysql": decimal.Decimal("1"), @@ -263,6 +275,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(18,3)", "duckdb": "DECIMAL(18,3)", "postgres": "numeric", + "risingwave": "numeric", "flink": "DECIMAL(38, 18) NOT NULL", }, marks=[ @@ -285,6 +298,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "duckdb": decimal.Decimal("1.100000000"), "impala": decimal.Decimal("1.1"), "postgres": decimal.Decimal("1.1"), + "risingwave": 1.1, "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "mysql": decimal.Decimal("1.1"), @@ -305,6 +319,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(38,9)", "duckdb": "DECIMAL(38,9)", "postgres": 
"numeric", + "risingwave": "numeric", "flink": "DECIMAL(38, 9) NOT NULL", }, marks=[pytest.mark.notimpl(["exasol"], raises=ExaQueryError)], @@ -318,6 +333,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "sqlite": 1.1, "dask": decimal.Decimal("1.1"), "postgres": decimal.Decimal("1.1"), + "risingwave": 1.1, "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "clickhouse": decimal.Decimal( @@ -333,6 +349,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(2,1)", "duckdb": "DECIMAL(18,3)", "postgres": "numeric", + "risingwave": "numeric", }, marks=[ pytest.mark.notimpl(["exasol"], raises=ExaQueryError), @@ -369,6 +386,7 @@ def test_numeric_literal(con, backend, expr, expected_types): { "sqlite": float("inf"), "postgres": decimal.Decimal("Infinity"), + "risingwave": float("nan"), "pandas": decimal.Decimal("Infinity"), "dask": decimal.Decimal("Infinity"), "pyspark": decimal.Decimal("Infinity"), @@ -379,6 +397,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "sqlite": "real", "postgres": "numeric", "duckdb": "FLOAT", + "risingwave": "numeric", }, marks=[ pytest.mark.broken( @@ -437,6 +456,7 @@ def test_numeric_literal(con, backend, expr, expected_types): { "sqlite": float("-inf"), "postgres": decimal.Decimal("-Infinity"), + "risingwave": float("nan"), "pandas": decimal.Decimal("-Infinity"), "dask": decimal.Decimal("-Infinity"), "pyspark": decimal.Decimal("-Infinity"), @@ -447,6 +467,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "sqlite": "real", "postgres": "numeric", "duckdb": "FLOAT", + "risingwave": "numeric", }, marks=[ pytest.mark.broken( @@ -507,6 +528,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "snowflake": float("nan"), "sqlite": None, "postgres": float("nan"), + "risingwave": float("nan"), "pandas": decimal.Decimal("NaN"), "dask": decimal.Decimal("NaN"), "pyspark": decimal.Decimal("NaN"), @@ -519,6 +541,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "sqlite": "null", "postgres": "numeric", "duckdb": "FLOAT", + "risingwave": "numeric", }, marks=[ pytest.mark.broken( @@ -729,14 +752,55 @@ def test_isnan_isinf( L(5.556).log(2), math.log(5.556, 2), id="log-base", - marks=[pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)], + marks=[ + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), + ], + ), + param( + L(5.556).ln(), + math.log(5.556), + id="ln", ), param(L(5.556).ln(), math.log(5.556), id="ln"), param( L(5.556).log2(), math.log(5.556, 2), id="log2", - marks=[pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)], + marks=[ + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), + ], + ), + param( + L(5.556).log10(), + math.log10(5.556), + marks=pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError), + id="log10", + ), + param( + L(5.556).radians(), + math.radians(5.556), + id="radians", + ), + param( + L(5.556).degrees(), + math.degrees(5.556), + id="degrees", + ), + param( + L(11) % 3, + 11 % 3, + marks=pytest.mark.notimpl(["exasol"], raises=ExaQueryError), + id="mod", ), param(L(5.556).log10(), math.log10(5.556), id="log10"), param(L(5.556).radians(), math.radians(5.556), id="radians"), @@ -873,7 +937,14 @@ def 
test_simple_math_functions_columns( param( lambda t: t.double_col.add(1).log(2), lambda t: np.log2(t.double_col + 1), - marks=[pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)], + marks=[ + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), + ], id="log2", ), param( @@ -907,6 +978,11 @@ def test_simple_math_functions_columns( reason="Base greatest(9000, t0.bigint_col) for logarithm not supported!", ), pytest.mark.notimpl(["polars"], raises=com.UnsupportedArgumentError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), ], ), ], @@ -1130,6 +1206,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1142,6 +1219,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1154,6 +1232,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1166,6 +1245,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1210,6 +1290,7 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): ( { "postgres": None, + "risingwave": None, "mysql": 10, "snowflake": 38, "trino": 18, @@ -1219,6 +1300,7 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): }, { "postgres": None, + "risingwave": None, "mysql": 0, "snowflake": 0, "trino": 3, @@ -1252,6 +1334,11 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): ], reason="Not SQLAlchemy backends", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: unsupported data type: NUMERIC(5)", +) def test_sa_default_numeric_precision_and_scale( con, backend, default_precisions, default_scales, temp_table ): @@ -1287,6 +1374,11 @@ def test_sa_default_numeric_precision_and_scale( @pytest.mark.notimpl(["dask", "pandas", "polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_random(con): expr = ibis.random() result = con.execute(expr) diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py index 3fea928999e21..c0f3a98b3a047 100644 --- a/ibis/backends/tests/test_param.py +++ b/ibis/backends/tests/test_param.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -38,6 +39,11 @@ def test_floating_scalar_parameter(backend, alltypes, df, column, raw_value): ) @pytest.mark.notimpl(["trino", "druid"]) @pytest.mark.broken(["oracle"], 
raises=OracleDatabaseError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_date_scalar_parameter(backend, alltypes, start_string, end_string): start, end = ibis.param(dt.date), ibis.param(dt.date) @@ -76,6 +82,7 @@ def test_scalar_param_array(con): "impala", "flink", "postgres", + "risingwave", "druid", "oracle", "exasol", @@ -108,6 +115,11 @@ def test_scalar_param_struct(con): "sql= SELECT MAP_FROM_ARRAYS(ARRAY['a', 'b', 'c'], ARRAY['ghi', 'def', 'abc']) '[' 'b' ']' AS `MapGet(param_0, 'b', None)`" ), ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_scalar_param_map(con): value = {"a": "ghi", "b": "def", "c": "abc"} param = ibis.param(dt.Map(dt.string, dt.string)) @@ -168,6 +180,11 @@ def test_scalar_param(backend, alltypes, df, value, dtype, col): ids=["string", "date", "datetime"], ) @pytest.mark.notimpl(["druid", "oracle"]) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_scalar_param_date(backend, alltypes, value): param = ibis.param("date") ds_col = alltypes.date_string_col @@ -203,6 +220,7 @@ def test_scalar_param_date(backend, alltypes, value): @pytest.mark.notimpl( [ "postgres", + "risingwave", "datafusion", "clickhouse", "polars", diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 4e1739e30cf05..d967e45d80c7f 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -93,6 +93,7 @@ def gzip_csv(data_dir, tmp_path): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -119,6 +120,7 @@ def test_register_csv(con, data_dir, fname, in_table_name, out_table_name): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -142,6 +144,7 @@ def test_register_csv_gz(con, data_dir, gzip_csv): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -198,6 +201,7 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]: "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -234,6 +238,7 @@ def test_register_parquet( "mysql", "pandas", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -273,6 +278,7 @@ def test_register_iterator_parquet( "mysql", "pandas", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -303,6 +309,7 @@ def test_register_pandas(con): "mysql", "pandas", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -328,6 +335,7 @@ def test_register_pyarrow_tables(con): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -370,6 +378,7 @@ def test_csv_reregister_schema(con, tmp_path): "pandas", "polars", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -400,7 +409,7 @@ def test_register_garbage(con, monkeypatch): ], ) @pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "postgres", "sqlite", "trino"] + ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] ) def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name): pq = pytest.importorskip("pyarrow.parquet") @@ -431,7 +440,17 @@ def ft_data(data_dir): @pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"] + [ + "flink", + "impala", + "mssql", + 
"mysql", + "pandas", + "postgres", + "risingwave", + "sqlite", + "trino", + ] ) def test_read_parquet_glob(con, tmp_path, ft_data): pq = pytest.importorskip("pyarrow.parquet") @@ -450,7 +469,17 @@ def test_read_parquet_glob(con, tmp_path, ft_data): @pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"] + [ + "flink", + "impala", + "mssql", + "mysql", + "pandas", + "postgres", + "risingwave", + "sqlite", + "trino", + ] ) def test_read_csv_glob(con, tmp_path, ft_data): pc = pytest.importorskip("pyarrow.csv") @@ -479,6 +508,7 @@ def test_read_csv_glob(con, tmp_path, ft_data): "mysql", "pandas", "postgres", + "risingwave", "sqlite", "trino", ] @@ -527,7 +557,7 @@ def num_diamonds(data_dir): [param(None, id="default"), param("fancy_stones", id="file_name")], ) @pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "postgres", "sqlite", "trino"] + ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] ) def test_read_csv(con, data_dir, in_table_name, num_diamonds): fname = "diamonds.csv" diff --git a/ibis/backends/tests/test_set_ops.py b/ibis/backends/tests/test_set_ops.py index 0fb52fa10f0fd..41102559ad9c6 100644 --- a/ibis/backends/tests/test_set_ops.py +++ b/ibis/backends/tests/test_set_ops.py @@ -4,6 +4,7 @@ import pandas as pd import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -67,19 +68,26 @@ def test_union_mixed_distinct(backend, union_subsets): [ param( False, - marks=pytest.mark.notyet( - [ - "impala", - "bigquery", - "dask", - "pandas", - "sqlite", - "snowflake", - "mssql", - "exasol", - ], - reason="backend doesn't support INTERSECT ALL", - ), + marks=[ + pytest.mark.notyet( + [ + "impala", + "bigquery", + "dask", + "pandas", + "sqlite", + "snowflake", + "mssql", + "exasol", + ], + reason="backend doesn't support INTERSECT ALL", + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: INTERSECT all", + ), + ], id="all", ), param(True, id="distinct"), @@ -114,19 +122,26 @@ def test_intersect(backend, alltypes, df, distinct): [ param( False, - marks=pytest.mark.notyet( - [ - "impala", - "bigquery", - "dask", - "pandas", - "sqlite", - "snowflake", - "mssql", - "exasol", - ], - reason="backend doesn't support EXCEPT ALL", - ), + marks=[ + pytest.mark.notyet( + [ + "impala", + "bigquery", + "dask", + "pandas", + "sqlite", + "snowflake", + "mssql", + "exasol", + ], + reason="backend doesn't support EXCEPT ALL", + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: EXCEPT all", + ), + ], id="all", ), param(True, id="distinct"), @@ -193,18 +208,25 @@ def test_top_level_union(backend, con, alltypes, distinct): True, param( False, - marks=pytest.mark.notimpl( - [ - "impala", - "bigquery", - "dask", - "mssql", - "pandas", - "snowflake", - "sqlite", - "exasol", - ] - ), + marks=[ + pytest.mark.notimpl( + [ + "impala", + "bigquery", + "dask", + "mssql", + "pandas", + "snowflake", + "sqlite", + "exasol", + ] + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: INTERSECT all", + ), + ], ), ], ) diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index 18ab061bdef52..082ef3870927b 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -11,7 +11,7 @@ sa = pytest.importorskip("sqlalchemy") sg = pytest.importorskip("sqlglot") -pytestmark = 
pytest.mark.notimpl(["flink"]) +pytestmark = pytest.mark.notimpl(["flink", "risingwave"]) simple_literal = param(ibis.literal(1), id="simple_literal") array_literal = param( diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 0505cb1c39bba..98ea425c1b89b 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -33,6 +34,7 @@ "duckdb": "VARCHAR", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(6) NOT NULL", }, id="string", @@ -48,14 +50,22 @@ "duckdb": "VARCHAR", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(7) NOT NULL", }, id="string-quote1", - marks=pytest.mark.broken( - ["oracle"], - raises=OracleDatabaseError, - reason="ORA-01741: illegal zero length identifier", - ), + marks=[ + pytest.mark.broken( + ["oracle"], + raises=OracleDatabaseError, + reason="ORA-01741: illegal zero length identifier", + ), + pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason='sql parser error: Expected end of statement, found: "NG\'" at line:1, column:31 Near "SELECT \'STRI"NG\' AS "\'STRI""', + ), + ], ), param( 'STRI"NG', @@ -68,14 +78,22 @@ "duckdb": "VARCHAR", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(7) NOT NULL", }, id="string-quote2", - marks=pytest.mark.broken( - ["oracle"], - raises=OracleDatabaseError, - reason="ORA-25716", - ), + marks=[ + pytest.mark.broken( + ["oracle"], + raises=OracleDatabaseError, + reason="ORA-25716", + ), + pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason='sql parser error: Expected end of statement, found: "NG\'" at line:1, column:31 Near "SELECT \'STRI"NG\' AS "\'STRI""', + ), + ], ), ], ) @@ -215,6 +233,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -225,6 +248,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -240,6 +268,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["druid"], reason="No posix support", raises=AssertionError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -250,6 +283,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -262,6 +300,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -274,6 +317,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + 
["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -286,6 +334,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -296,6 +349,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -306,6 +364,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -846,6 +909,7 @@ def test_substr_with_null_values(backend, alltypes, df): "mysql", "polars", "postgres", + "risingwave", "pyspark", "druid", "oracle", @@ -918,6 +982,11 @@ def test_multiple_subs(con): ], raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function levenshtein(character varying, character varying) does not exist", +) @pytest.mark.parametrize( "right", ["sitting", ibis.literal("sitting")], ids=["python", "ibis"] ) diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index d368c6b163058..95f7df9f4ea50 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -65,7 +65,7 @@ def test_all_fields(struct, struct_df): _NULL_STRUCT_LITERAL = ibis.NA.cast("struct") -@pytest.mark.notimpl(["postgres"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.parametrize("field", ["a", "b", "c"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from literals" @@ -79,7 +79,7 @@ def test_literal(backend, con, field): backend.assert_series_equal(result, expected.astype(dtype)) -@pytest.mark.notimpl(["postgres"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.parametrize("field", ["a", "b", "c"]) @pytest.mark.notyet( ["clickhouse"], reason="clickhouse doesn't support nullable nested types" @@ -95,7 +95,7 @@ def test_null_literal(backend, con, field): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres"]) +@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from literals" ) @@ -111,7 +111,7 @@ def test_struct_column(backend, alltypes, df): tm.assert_series_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "polars"]) +@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave", "polars"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from collect" ) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 99a5b464ae89d..2ccc7184c73ce 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa import sqlglot as sg from pytest import param @@ -155,6 +156,11 @@ def test_timestamp_extract(backend, 
alltypes, df, attr): "Ref: https://nightlies.apache.org/flink/flink-docs-release-1.13/docs/dev/table/functions/systemfunctions/#temporal-functions" ), ), + pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", + ), ], ), ], @@ -434,6 +440,7 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): "impala", "mysql", "postgres", + "risingwave", "pyspark", "sqlite", "snowflake", @@ -634,6 +641,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=Py4JJavaError, reason="ParseException: Encountered 'WEEK'. Was expecting one of: DAY, DAYS, HOUR", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: week", + ), ], ), param("D", pd.offsets.DateOffset), @@ -652,6 +664,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=Py4JJavaError, reason="ParseException: Encountered 'MILLISECOND'. Was expecting one of: DAY, DAYS, HOUR, ...", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: millisecond", + ), ], ), param( @@ -671,6 +688,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=Py4JJavaError, reason="ParseException: Encountered 'MICROSECOND'. Was expecting one of: DAY, DAYS, HOUR, ...", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: microsecond", + ), ], ), ], @@ -721,7 +743,14 @@ def convert_to_offset(offset, displacement_type=displacement_type): ), param( "W", - marks=pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + marks=[ + pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: week", + ), + ], ), "D", ], @@ -819,7 +848,7 @@ def convert_to_offset(x): id="timestamp-add-interval-binop", marks=[ pytest.mark.notimpl( - ["dask", "snowflake", "sqlite", "bigquery", "exasol"], + ["dask", "snowflake", "sqlite", "bigquery", "exasol", "risingwave"], raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["impala"], raises=com.UnsupportedOperationError), @@ -839,7 +868,14 @@ def convert_to_offset(x): id="timestamp-add-interval-binop-different-units", marks=[ pytest.mark.notimpl( - ["sqlite", "polars", "snowflake", "bigquery", "exasol"], + [ + "sqlite", + "polars", + "snowflake", + "bigquery", + "exasol", + "risingwave", + ], raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["impala"], raises=com.UnsupportedOperationError), @@ -963,6 +999,11 @@ def convert_to_offset(x): raises=AttributeError, reason="'StringColumn' object has no attribute 'date'", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", + ), pytest.mark.broken( ["flink"], raises=com.UnsupportedOperationError, @@ -1409,6 +1450,13 @@ def test_interval_add_cast_column(backend, alltypes, df): raises=com.UnsupportedArgumentError, reason="Polars does not support columnar argument StringConcat()", ), + pytest.mark.notimpl( + [ + "risingwave", + ], + raises=AttributeError, + reason="Neither 'concat' object nor 'Comparator' object has an attribute 'value'", + ), pytest.mark.notyet(["dask"], raises=com.OperationNotDefinedError), pytest.mark.notyet(["impala"], raises=com.UnsupportedOperationError), pytest.mark.notimpl(["druid", "flink"], raises=AttributeError), @@ -1511,7 +1559,7 @@ def 
test_strftime(backend, alltypes, df, expr_fn, pandas_pattern): ], ) @pytest.mark.notimpl( - ["mysql", "postgres", "sqlite", "druid", "oracle"], + ["mysql", "postgres", "risingwave", "sqlite", "druid", "oracle"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) @@ -1593,6 +1641,7 @@ def test_integer_to_timestamp(backend, con, unit): [ "dask", "pandas", + "risingwave", "clickhouse", "sqlite", "datafusion", @@ -1631,6 +1680,11 @@ def test_string_to_timestamp(alltypes, fmt): reason="DayOfWeekName is not supported in Flink", ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", +) def test_day_of_week_scalar(con, date, expected_index, expected_day): expr = ibis.literal(date).cast(dt.date) result_index = con.execute(expr.day_of_week.index().name("tmp")) @@ -1656,6 +1710,11 @@ def test_day_of_week_scalar(con, date, expected_index, expected_day): ), ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", +) def test_day_of_week_column(backend, alltypes, df): expr = alltypes.timestamp_col.day_of_week @@ -1692,6 +1751,11 @@ def test_day_of_week_column(backend, alltypes, df): "Ref: https://nightlies.apache.org/flink/flink-docs-release-1.13/docs/dev/table/functions/systemfunctions/#temporal-functions" ), ), + pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", + ), ], ), ], @@ -1759,6 +1823,7 @@ def test_now_from_projection(alltypes): "snowflake": "DATE", "sqlite": "text", "trino": "date", + "risingwave": "date", } @@ -1769,6 +1834,11 @@ def test_now_from_projection(alltypes): @pytest.mark.notimpl( ["oracle"], raises=OracleDatabaseError, reason="ORA-00936 missing expression" ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_date_literal(con, backend): expr = ibis.date(2022, 2, 4) result = con.execute(expr) @@ -1788,6 +1858,7 @@ def test_date_literal(con, backend): "trino": "timestamp(3)", "duckdb": "TIMESTAMP", "postgres": "timestamp without time zone", + "risingwave": "timestamp without time zone", "flink": "TIMESTAMP(6) NOT NULL", } @@ -1797,6 +1868,11 @@ def test_date_literal(con, backend): raises=com.OperationNotDefinedError, ) @pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_timestamp(integer, integer, integer, integer, integer, integer) does not exist", +) def test_timestamp_literal(con, backend): expr = ibis.timestamp(2022, 2, 4, 16, 20, 0) result = con.execute(expr) @@ -1853,6 +1929,11 @@ def test_timestamp_literal(con, backend): ", , , )" ), ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_timestamp(integer, integer, integer, integer, integer, integer) does not exist", +) def test_timestamp_with_timezone_literal(con, timezone, expected): expr = ibis.timestamp(2022, 2, 4, 16, 20, 0).cast(dt.Timestamp(timezone=timezone)) result = con.execute(expr) @@ -1869,6 +1950,7 @@ def test_timestamp_with_timezone_literal(con, timezone, expected): "trino": "time(3)", 
"duckdb": "TIME", "postgres": "time without time zone", + "risingwave": "time without time zone", } @@ -1880,6 +1962,11 @@ def test_timestamp_with_timezone_literal(con, timezone, expected): ["clickhouse", "impala", "exasol"], raises=com.OperationNotDefinedError ) @pytest.mark.notimpl(["druid"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_time(integer, integer, integer) does not exist", +) def test_time_literal(con, backend): expr = ibis.time(16, 20, 0) result = con.execute(expr) @@ -1956,6 +2043,7 @@ def test_extract_time_from_timestamp(con, microsecond): "trino": "interval day to second", "duckdb": "INTERVAL", "postgres": "interval", + "risingwave": "interval", } @@ -2024,6 +2112,11 @@ def test_interval_literal(con, backend): @pytest.mark.broken( ["oracle"], raises=OracleDatabaseError, reason="ORA-00936: missing expression" ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_date_column_from_ymd(backend, con, alltypes, df): c = alltypes.timestamp_col expr = ibis.date(c.year(), c.month(), c.day()) @@ -2044,6 +2137,11 @@ def test_date_column_from_ymd(backend, con, alltypes, df): reason="StringColumn' object has no attribute 'year'", ) @pytest.mark.notyet(["impala", "oracle"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_timestamp(smallint, smallint, smallint, smallint, smallint, smallint) does not exist", +) def test_timestamp_column_from_ymdhms(backend, con, alltypes, df): c = alltypes.timestamp_col expr = ibis.timestamp( @@ -2172,6 +2270,11 @@ def build_date_col(t): param(lambda _: DATE, build_date_col, id="date_column"), ], ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_timestamp_date_comparison(backend, alltypes, df, left_fn, right_fn): left = left_fn(alltypes) right = right_fn(alltypes) @@ -2294,6 +2397,11 @@ def test_large_timestamp(con): raises=AssertionError, ), pytest.mark.notimpl(["exasol"], raises=AssertionError), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Parse error: timestamp without time zone Can't cast string to timestamp (expected format is YYYY-MM-DD HH:MM:SS[.D+{up to 6 digits}] or YYYY-MM-DD HH:MM or YYYY-MM-DD or ISO 8601 format)", + ), ], ), ], @@ -2325,6 +2433,11 @@ def test_timestamp_precision_output(con, ts, scale, unit): ], raises=com.OperationNotDefinedError, ) +@pytest.mark.notyet( + ["risingwave"], + reason="postgres doesn't have any easy way to accurately compute the delta in specific units", + raises=com.OperationNotDefinedError, +) @pytest.mark.parametrize( ("start", "end", "unit", "expected"), [ @@ -2471,6 +2584,11 @@ def test_delta(con, start, end, unit, expected): ], ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function date_bin(interval, timestamp without time zone, timestamp without time zone) does not exist", +) def test_timestamp_bucket(backend, kws, pd_freq): ts = backend.functional_alltypes.timestamp_col.execute().rename("ts") res = backend.functional_alltypes.timestamp_col.bucket(**kws).execute().rename("ts") @@ -2506,6 +2624,11 @@ def test_timestamp_bucket(backend, kws, pd_freq): ) 
@pytest.mark.parametrize("offset_mins", [2, -2], ids=["pos", "neg"]) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function date_bin(interval, timestamp without time zone, timestamp without time zone) does not exist", +) def test_timestamp_bucket_offset(backend, offset_mins): ts = backend.functional_alltypes.timestamp_col expr = ts.bucket(minutes=5, offset=ibis.interval(minutes=offset_mins)) @@ -2616,6 +2739,11 @@ def test_time_literal_sql(dialect, snapshot, micros): param(datetime.date.fromisoformat, id="fromstring"), ], ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_date_scalar(con, value, func): expr = ibis.date(func(value)).name("tmp") diff --git a/ibis/backends/tests/test_timecontext.py b/ibis/backends/tests/test_timecontext.py index 2488ee5b36594..88376a4f961b9 100644 --- a/ibis/backends/tests/test_timecontext.py +++ b/ibis/backends/tests/test_timecontext.py @@ -19,6 +19,7 @@ "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py index f0f85392053d3..2a55a30355b28 100644 --- a/ibis/backends/tests/test_udf.py +++ b/ibis/backends/tests/test_udf.py @@ -20,6 +20,7 @@ "oracle", "pandas", "trino", + "risingwave", ] ) diff --git a/ibis/backends/tests/test_uuid.py b/ibis/backends/tests/test_uuid.py index a01a1c124ad7c..3f2d49cd6d59f 100644 --- a/ibis/backends/tests/test_uuid.py +++ b/ibis/backends/tests/test_uuid.py @@ -4,6 +4,7 @@ import uuid import pytest +import sqlalchemy import ibis import ibis.common.exceptions as com @@ -28,6 +29,11 @@ @pytest.mark.notimpl(["polars"], raises=NotImplementedError) @pytest.mark.notimpl(["datafusion"], raises=Exception) +@pytest.mark.notimpl( + ["risingwave"], + raises=sqlalchemy.exc.InternalError, + reason="Feature is not yet implemented: unsupported data type: UUID", +) def test_uuid_literal(con, backend): backend_name = backend.name() diff --git a/ibis/backends/tests/test_vectorized_udf.py b/ibis/backends/tests/test_vectorized_udf.py index c2414db0b58da..fa6728acb7f21 100644 --- a/ibis/backends/tests/test_vectorized_udf.py +++ b/ibis/backends/tests/test_vectorized_udf.py @@ -10,7 +10,7 @@ import ibis.expr.datatypes as dt from ibis.legacy.udf.vectorized import analytic, elementwise, reduction -pytestmark = pytest.mark.notimpl(["druid", "oracle"]) +pytestmark = pytest.mark.notimpl(["druid", "oracle", "risingwave"]) def _format_udf_return_type(func, result_formatter): diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index 1e3627afab8f1..f5532f7b7e8b5 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -137,6 +138,11 @@ def calc_zscore(s): raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["dask"], raises=NotImplementedError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: percent_rank", + ), ], ), param( @@ -148,6 +154,12 @@ def calc_zscore(s): ["clickhouse", "exasol"], raises=com.OperationNotDefinedError ), pytest.mark.notimpl(["dask"], raises=NotImplementedError), + pytest.mark.notimpl(["dask"], raises=NotImplementedError), + 
pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: cume_dist", + ), ], ), param( @@ -174,6 +186,11 @@ def calc_zscore(s): raises=com.UnsupportedOperationError, reason="Windows in Flink can only be ordered by a single time column", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: ntile", + ), ], ), param( @@ -212,6 +229,7 @@ def calc_zscore(s): ), pytest.mark.notimpl(["dask"], raises=NotImplementedError), pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError), + pytest.mark.notimpl(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -373,7 +391,14 @@ def test_grouped_bounded_expanding_window( lambda t, win: t.double_col.mean().over(win), lambda df: (df.double_col.expanding().mean()), id="mean", - marks=[pytest.mark.notimpl(["dask"], raises=NotImplementedError)], + marks=[ + pytest.mark.notimpl(["dask"], raises=NotImplementedError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), + ], ), param( # Disabled on PySpark and Spark backends because in pyspark<3.0.0, @@ -393,6 +418,7 @@ def test_grouped_bounded_expanding_window( "mysql", "oracle", "postgres", + "risingwave", "sqlite", "snowflake", "datafusion", @@ -548,6 +574,7 @@ def test_grouped_bounded_preceding_window(backend, alltypes, df, window_fn): "mysql", "oracle", "postgres", + "risingwave", "sqlite", "snowflake", "trino", @@ -625,6 +652,11 @@ def test_grouped_unbounded_window( @pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["flink"], raises=com.UnsupportedOperationError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_simple_ungrouped_unbound_following_window( backend, alltypes, ibis_method, pandas_fn ): @@ -652,6 +684,11 @@ def test_simple_ungrouped_unbound_following_window( ["mssql"], raises=Exception, reason="order by constant is not supported" ) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_simple_ungrouped_window_with_scalar_order_by(alltypes): t = alltypes[alltypes.double_col < 50].order_by("id") w = ibis.window(rows=(0, None), order_by=ibis.NA) @@ -680,6 +717,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): reason="default window semantics are different", raises=AssertionError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), ], ), param( @@ -713,6 +755,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): raises=Py4JJavaError, reason="CalciteContextException: Argument to function 'NTILE' must be a literal", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: ntile", + ), ], ), param( @@ -732,6 +779,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): 
"mysql", "oracle", "postgres", + "risingwave", "sqlite", "snowflake", "trino", @@ -763,6 +811,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "mysql", "oracle", "postgres", + "risingwave", "sqlite", "snowflake", "trino", @@ -783,6 +832,13 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): lambda df: df.float_col.shift(1), True, id="ordered-lag", + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), + ], ), param( lambda t, win: t.float_col.lag().over(win), @@ -812,6 +868,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): reason="backend requires ordering", raises=SnowflakeProgrammingError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), ], ), param( @@ -819,6 +880,13 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): lambda df: df.float_col.shift(-1), True, id="ordered-lead", + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), + ], ), param( lambda t, win: t.float_col.lead().over(win), @@ -851,6 +919,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): reason="backend requires ordering", raises=SnowflakeProgrammingError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), ], ), param( @@ -870,6 +943,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "mysql", "oracle", "postgres", + "risingwave", "pyspark", "sqlite", "snowflake", @@ -902,6 +976,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "mysql", "oracle", "postgres", + "risingwave", "pyspark", "sqlite", "snowflake", @@ -968,6 +1043,11 @@ def test_ungrouped_unbounded_window( raises=MySQLOperationalError, reason="https://github.com/tobymao/sqlglot/issues/2779", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: window frame in `RANGE` mode is not supported yet", +) def test_grouped_bounded_range_window(backend, alltypes, df): # Explanation of the range window spec below: # @@ -1023,6 +1103,11 @@ def gb_fn(df): reason="clickhouse doesn't implement percent_rank", raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: percent_rank", +) def test_percent_rank_whole_table_no_order_by(backend, alltypes, df): expr = alltypes.mutate(val=lambda t: t.id.percent_rank()) @@ -1069,6 +1154,11 @@ def agg(df): @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_mutate_window_filter(backend, alltypes): t = alltypes win = ibis.window(order_by=[t.id]) @@ -1143,6 +1233,11 @@ def test_first_last(backend): raises=ExaQueryError, reason="database can't handle UTC timestamps in DataFrames", ) +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason="sql parser error: Expected 
literal int, found: INTERVAL at line:1, column:99", +) def test_range_expression_bounds(backend): t = ibis.memtable( { @@ -1187,6 +1282,11 @@ def test_range_expression_bounds(backend): @pytest.mark.broken( ["mssql"], reason="lack of support for booleans", raises=PyODBCProgrammingError ) +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: percent_rank", +) def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df): # GH #7631 t = alltypes @@ -1217,6 +1317,11 @@ def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df): ) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.notyet(["flink"], raises=com.UnsupportedOperationError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_windowed_order_by_sequence_is_preserved(con): table = ibis.memtable({"bool_col": [True, False, False, None, True]}) window = ibis.window( diff --git a/poetry.lock b/poetry.lock index 85d8d9e77ddd7..52e3160675580 100644 --- a/poetry.lock +++ b/poetry.lock @@ -442,33 +442,33 @@ files = [ [[package]] name = "black" -version = "23.12.1" +version = "24.1.0" description = "The uncompromising code formatter." optional = true python-versions = ">=3.8" files = [ - {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"}, - {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"}, - {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"}, - {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"}, - {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"}, - {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"}, - {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"}, - {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"}, - {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"}, - {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"}, - {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"}, - {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"}, - {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"}, - {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"}, - {file = 
"black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"}, - {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"}, - {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"}, - {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"}, - {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"}, - {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"}, - {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"}, - {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"}, + {file = "black-24.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:94d5280d020dadfafc75d7cae899609ed38653d3f5e82e7ce58f75e76387ed3d"}, + {file = "black-24.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aaf9aa85aaaa466bf969e7dd259547f4481b712fe7ee14befeecc152c403ee05"}, + {file = "black-24.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec489cae76eac3f7573629955573c3a0e913641cafb9e3bfc87d8ce155ebdb29"}, + {file = "black-24.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:a5a0100b4bdb3744dd68412c3789f472d822dc058bb3857743342f8d7f93a5a7"}, + {file = "black-24.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6cc5a6ba3e671cfea95a40030b16a98ee7dc2e22b6427a6f3389567ecf1b5262"}, + {file = "black-24.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0e367759062dcabcd9a426d12450c6d61faf1704a352a49055a04c9f9ce8f5a"}, + {file = "black-24.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be305563ff4a2dea813f699daaffac60b977935f3264f66922b1936a5e492ee4"}, + {file = "black-24.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:6a8977774929b5db90442729f131221e58cc5d8208023c6af9110f26f75b6b20"}, + {file = "black-24.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d74d4d0da276fbe3b95aa1f404182562c28a04402e4ece60cf373d0b902f33a0"}, + {file = "black-24.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39addf23f7070dbc0b5518cdb2018468ac249d7412a669b50ccca18427dba1f3"}, + {file = "black-24.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:827a7c0da520dd2f8e6d7d3595f4591aa62ccccce95b16c0e94bb4066374c4c2"}, + {file = "black-24.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:0cd59d01bf3306ff7e3076dd7f4435fcd2fafe5506a6111cae1138fc7de52382"}, + {file = "black-24.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bf8dd261ee82df1abfb591f97e174345ab7375a55019cc93ad38993b9ff5c6ad"}, + {file = "black-24.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:82d9452aeabd51d1c8f0d52d4d18e82b9f010ecb30fd55867b5ff95904f427ff"}, + {file = "black-24.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9aede09f72b2a466e673ee9fca96e4bccc36f463cac28a35ce741f0fd13aea8b"}, + {file = "black-24.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:780f13d03066a7daf1707ec723fdb36bd698ffa29d95a2e7ef33a8dd8fe43b5c"}, + {file = "black-24.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:a15670c650668399c4b5eae32e222728185961d6ef6b568f62c1681d57b381ba"}, + {file = "black-24.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1e0fa70b8464055069864a4733901b31cbdbe1273f63a24d2fa9d726723d45ac"}, + {file = "black-24.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7fa8d9aaa22d846f8c0f7f07391148e5e346562e9b215794f9101a8339d8b6d8"}, + {file = "black-24.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:f0dfbfbacfbf9cd1fac7a5ddd3e72510ffa93e841a69fcf4a6358feab1685382"}, + {file = "black-24.1.0-py3-none-any.whl", hash = "sha256:5134a6f6b683aa0a5592e3fd61dd3519d8acd953d93e2b8b76f9981245b65594"}, + {file = "black-24.1.0.tar.gz", hash = "sha256:30fbf768cd4f4576598b1db0202413fafea9a227ef808d1a12230c643cefe9fc"}, ] [package.dependencies] @@ -1429,6 +1429,21 @@ files = [ {file = "duckdb-0.9.2.tar.gz", hash = "sha256:3843afeab7c3fc4a4c0b53686a4cc1d9cdbdadcbb468d60fef910355ecafd447"}, ] +[[package]] +name = "duckdb-engine" +version = "0.10.0" +description = "SQLAlchemy driver for duckdb" +optional = true +python-versions = ">=3.7" +files = [ + {file = "duckdb_engine-0.10.0-py3-none-any.whl", hash = "sha256:c408d002e83630b6bbb05fc3b26a43406085b1c22dd43e8cab00bf0b9c011ea8"}, + {file = "duckdb_engine-0.10.0.tar.gz", hash = "sha256:5e3dad3b3513f055a4f5ec5430842249cfe03015743a7597ed1dcc0447dca565"}, +] + +[package.dependencies] +duckdb = ">=0.4.0" +sqlalchemy = ">=1.3.22" + [[package]] name = "dulwich" version = "0.21.7" @@ -1900,6 +1915,24 @@ requests = {version = "*", extras = ["socks"]} six = "*" tqdm = "*" +[[package]] +name = "geoalchemy2" +version = "0.14.3" +description = "Using SQLAlchemy with Spatial Databases" +optional = true +python-versions = ">=3.7" +files = [ + {file = "GeoAlchemy2-0.14.3-py3-none-any.whl", hash = "sha256:a727198394fcc4760a27c4c5bff8b9f4f79324ec2dd98c4c1b8a7026b8918d81"}, + {file = "GeoAlchemy2-0.14.3.tar.gz", hash = "sha256:79c432b10dd8c48422f794eaf9a1200929de14f41d2396923bfe92f4c6abaf89"}, +] + +[package.dependencies] +packaging = "*" +SQLAlchemy = ">=1.4" + +[package.extras] +shapely = ["Shapely (>=1.7)"] + [[package]] name = "geojson" version = "3.1.0" @@ -4419,62 +4452,62 @@ files = [ [[package]] name = "py4j" -version = "0.10.9.7" +version = "0.10.9.5" description = "Enables Python programs to dynamically access arbitrary Java objects" optional = true python-versions = "*" files = [ - {file = "py4j-0.10.9.7-py2.py3-none-any.whl", hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b"}, - {file = "py4j-0.10.9.7.tar.gz", hash = "sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb"}, + {file = "py4j-0.10.9.5-py2.py3-none-any.whl", hash = "sha256:52d171a6a2b031d8a5d1de6efe451cf4f5baff1a2819aabc3741c8406539ba04"}, + {file = "py4j-0.10.9.5.tar.gz", hash = "sha256:276a4a3c5a2154df1860ef3303a927460e02e97b047dc0a47c1c3fb8cce34db6"}, ] [[package]] name = "pyarrow" -version = "14.0.2" +version = "15.0.0" description = "Python library for Apache Arrow" optional = false python-versions = ">=3.8" files = [ - {file = "pyarrow-14.0.2-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:ba9fe808596c5dbd08b3aeffe901e5f81095baaa28e7d5118e01354c64f22807"}, - {file = "pyarrow-14.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:22a768987a16bb46220cef490c56c671993fbee8fd0475febac0b3e16b00a10e"}, - {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dbba05e98f247f17e64303eb876f4a80fcd32f73c7e9ad975a83834d81f3fda"}, - {file = 
"pyarrow-14.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a898d134d00b1eca04998e9d286e19653f9d0fcb99587310cd10270907452a6b"}, - {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:87e879323f256cb04267bb365add7208f302df942eb943c93a9dfeb8f44840b1"}, - {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:76fc257559404ea5f1306ea9a3ff0541bf996ff3f7b9209fc517b5e83811fa8e"}, - {file = "pyarrow-14.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0c4a18e00f3a32398a7f31da47fefcd7a927545b396e1f15d0c85c2f2c778cd"}, - {file = "pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b"}, - {file = "pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02"}, - {file = "pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b"}, - {file = "pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944"}, - {file = "pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379"}, - {file = "pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d"}, - {file = "pyarrow-14.0.2-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:e354fba8490de258be7687f341bc04aba181fc8aa1f71e4584f9890d9cb2dec2"}, - {file = "pyarrow-14.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:20e003a23a13da963f43e2b432483fdd8c38dc8882cd145f09f21792e1cf22a1"}, - {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc0de7575e841f1595ac07e5bc631084fd06ca8b03c0f2ecece733d23cd5102a"}, - {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66e986dc859712acb0bd45601229021f3ffcdfc49044b64c6d071aaf4fa49e98"}, - {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f7d029f20ef56673a9730766023459ece397a05001f4e4d13805111d7c2108c0"}, - {file = 
"pyarrow-14.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:209bac546942b0d8edc8debda248364f7f668e4aad4741bae58e67d40e5fcf75"}, - {file = "pyarrow-14.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:1e6987c5274fb87d66bb36816afb6f65707546b3c45c44c28e3c4133c010a881"}, - {file = "pyarrow-14.0.2-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:a01d0052d2a294a5f56cc1862933014e696aa08cc7b620e8c0cce5a5d362e976"}, - {file = "pyarrow-14.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a51fee3a7db4d37f8cda3ea96f32530620d43b0489d169b285d774da48ca9785"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64df2bf1ef2ef14cee531e2dfe03dd924017650ffaa6f9513d7a1bb291e59c15"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c0fa3bfdb0305ffe09810f9d3e2e50a2787e3a07063001dcd7adae0cee3601a"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c65bf4fd06584f058420238bc47a316e80dda01ec0dfb3044594128a6c2db794"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:63ac901baec9369d6aae1cbe6cca11178fb018a8d45068aaf5bb54f94804a866"}, - {file = "pyarrow-14.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:75ee0efe7a87a687ae303d63037d08a48ef9ea0127064df18267252cfe2e9541"}, - {file = "pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025"}, -] - -[package.dependencies] -numpy = ">=1.16.6" + {file = "pyarrow-15.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:0a524532fd6dd482edaa563b686d754c70417c2f72742a8c990b322d4c03a15d"}, + {file = "pyarrow-15.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60a6bdb314affa9c2e0d5dddf3d9cbb9ef4a8dddaa68669975287d47ece67642"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66958fd1771a4d4b754cd385835e66a3ef6b12611e001d4e5edfcef5f30391e2"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f500956a49aadd907eaa21d4fff75f73954605eaa41f61cb94fb008cf2e00c6"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6f87d9c4f09e049c2cade559643424da84c43a35068f2a1c4653dc5b1408a929"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:85239b9f93278e130d86c0e6bb455dcb66fc3fd891398b9d45ace8799a871a1e"}, + {file = "pyarrow-15.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5b8d43e31ca16aa6e12402fcb1e14352d0d809de70edd185c7650fe80e0769e3"}, + {file = "pyarrow-15.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:fa7cd198280dbd0c988df525e50e35b5d16873e2cdae2aaaa6363cdb64e3eec5"}, + {file = "pyarrow-15.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8780b1a29d3c8b21ba6b191305a2a607de2e30dab399776ff0aa09131e266340"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe0ec198ccc680f6c92723fadcb97b74f07c45ff3fdec9dd765deb04955ccf19"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036a7209c235588c2f07477fe75c07e6caced9b7b61bb897c8d4e52c4b5f9555"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2bd8a0e5296797faf9a3294e9fa2dc67aa7f10ae2207920dbebb785c77e9dbe5"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e8ebed6053dbe76883a822d4e8da36860f479d55a762bd9e70d8494aed87113e"}, + {file = "pyarrow-15.0.0-cp311-cp311-win_amd64.whl", 
hash = "sha256:17d53a9d1b2b5bd7d5e4cd84d018e2a45bc9baaa68f7e6e3ebed45649900ba99"}, + {file = "pyarrow-15.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9950a9c9df24090d3d558b43b97753b8f5867fb8e521f29876aa021c52fda351"}, + {file = "pyarrow-15.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:003d680b5e422d0204e7287bb3fa775b332b3fce2996aa69e9adea23f5c8f970"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f75fce89dad10c95f4bf590b765e3ae98bcc5ba9f6ce75adb828a334e26a3d40"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ca9cb0039923bec49b4fe23803807e4ef39576a2bec59c32b11296464623dc2"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ed5a78ed29d171d0acc26a305a4b7f83c122d54ff5270810ac23c75813585e4"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6eda9e117f0402dfcd3cd6ec9bfee89ac5071c48fc83a84f3075b60efa96747f"}, + {file = "pyarrow-15.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a3a6180c0e8f2727e6f1b1c87c72d3254cac909e609f35f22532e4115461177"}, + {file = "pyarrow-15.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:19a8918045993349b207de72d4576af0191beef03ea655d8bdb13762f0cd6eac"}, + {file = "pyarrow-15.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0ec076b32bacb6666e8813a22e6e5a7ef1314c8069d4ff345efa6246bc38593"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5db1769e5d0a77eb92344c7382d6543bea1164cca3704f84aa44e26c67e320fb"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2617e3bf9df2a00020dd1c1c6dce5cc343d979efe10bc401c0632b0eef6ef5b"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:d31c1d45060180131caf10f0f698e3a782db333a422038bf7fe01dace18b3a31"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:c8c287d1d479de8269398b34282e206844abb3208224dbdd7166d580804674b7"}, + {file = "pyarrow-15.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:07eb7f07dc9ecbb8dace0f58f009d3a29ee58682fcdc91337dfeb51ea618a75b"}, + {file = "pyarrow-15.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:47af7036f64fce990bb8a5948c04722e4e3ea3e13b1007ef52dfe0aa8f23cf7f"}, + {file = "pyarrow-15.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93768ccfff85cf044c418bfeeafce9a8bb0cee091bd8fd19011aff91e58de540"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6ee87fd6892700960d90abb7b17a72a5abb3b64ee0fe8db6c782bcc2d0dc0b4"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:001fca027738c5f6be0b7a3159cc7ba16a5c52486db18160909a0831b063c4e4"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:d1c48648f64aec09accf44140dccb92f4f94394b8d79976c426a5b79b11d4fa7"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:972a0141be402bb18e3201448c8ae62958c9c7923dfaa3b3d4530c835ac81aed"}, + {file = "pyarrow-15.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:f01fc5cf49081426429127aa2d427d9d98e1cb94a32cb961d583a70b7c4504e6"}, + {file = "pyarrow-15.0.0.tar.gz", hash = "sha256:876858f549d540898f927eba4ef77cd549ad8d24baa3207cf1b72e5788b50e83"}, +] + +[package.dependencies] +numpy = ">=1.16.6,<2" [[package]] name = "pyarrow-hotfix" @@ -4701,6 +4734,7 @@ files = [ [package.dependencies] requests = "*" +sqlalchemy = 
{version = "*", optional = true, markers = "extra == \"sqlalchemy\""} [package.extras] async = ["tornado"] @@ -4721,7 +4755,6 @@ files = [ [package.dependencies] packaging = "*" -pandas = {version = "*", optional = true, markers = "extra == \"pandas\""} pyopenssl = "*" rsa = "*" websocket-client = ">=1.0.1" @@ -5006,23 +5039,22 @@ files = [ [[package]] name = "pyspark" -version = "3.5.0" +version = "3.3.4" description = "Apache Spark Python API" optional = true -python-versions = ">=3.8" +python-versions = ">=3.7" files = [ - {file = "pyspark-3.5.0.tar.gz", hash = "sha256:d41a9b76bd2aca370a6100d075c029e22ba44c5940927877e9435a3a9c566558"}, + {file = "pyspark-3.3.4.tar.gz", hash = "sha256:1f866be47130a522355240949ed50d9812a8f327bd7619f043ffe07fbcf7f7b6"}, ] [package.dependencies] -py4j = "0.10.9.7" +py4j = "0.10.9.5" [package.extras] -connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.56.0)", "grpcio-status (>=1.56.0)", "numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"] ml = ["numpy (>=1.15)"] mllib = ["numpy (>=1.15)"] -pandas-on-spark = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"] -sql = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"] +pandas-on-spark = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"] +sql = ["pandas (>=1.0.5)", "pyarrow (>=1.0.0)"] [[package]] name = "pystac" @@ -5070,13 +5102,13 @@ docs = ["Sphinx (>=6.2,<7.0)", "boto3 (>=1.26,<2.0)", "cartopy (>=0.21,<1.0)", " [[package]] name = "pytest" -version = "7.4.4" +version = "8.0.0" description = "pytest: simple powerful testing with Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, - {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, + {file = "pytest-8.0.0-py3-none-any.whl", hash = "sha256:50fb9cbe836c3f20f0dfa99c565201fb75dc54c8d76373cd1bde06b06657bdb6"}, + {file = "pytest-8.0.0.tar.gz", hash = "sha256:249b1b0864530ba251b7438274c4d251c58d868edaaec8762893ad4a0d71c36c"}, ] [package.dependencies] @@ -5084,7 +5116,7 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" -pluggy = ">=0.12,<2.0" +pluggy = ">=1.3.0,<2.0" tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] @@ -5353,6 +5385,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = 
"sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -6246,6 +6279,25 @@ development = ["Cython", "coverage", "more-itertools", "numpy (<1.27.0)", "pendu pandas = ["pandas (>=1.0.0,<2.2.0)", "pyarrow"] secure-local-storage = ["keyring (!=16.1.0,<25.0.0)"] +[[package]] +name = "snowflake-sqlalchemy" +version = "1.5.1" +description = "Snowflake SQLAlchemy Dialect" +optional = true +python-versions = ">=3.7" +files = [ + {file = "snowflake-sqlalchemy-1.5.1.tar.gz", hash = "sha256:4f1383402ffc89311974bd810dee22003aef4af0f312a0fdb55778333ad1abf7"}, + {file = "snowflake_sqlalchemy-1.5.1-py2.py3-none-any.whl", hash = "sha256:df022fb73bc04d68dfb3216ebf7a1bfbd14d22def9c38bbe05275beb258adcd0"}, +] + +[package.dependencies] +snowflake-connector-python = "<4.0.0" +sqlalchemy = ">=1.4.0,<2.0.0" + +[package.extras] +development = ["mock", "numpy", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-timeout", "pytz"] +pandas = ["snowflake-connector-python[pandas] (<4.0.0)"] + [[package]] name = "sortedcontainers" version = "2.4.0" @@ -6363,6 +6415,40 @@ postgresql-psycopg2cffi = ["psycopg2cffi"] pymysql = ["pymysql", "pymysql (<1)"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sqlalchemy-exasol" +version = "4.6.3" +description = "EXASOL dialect for SQLAlchemy" +optional = true +python-versions = ">=3.8,<4.0" +files = [ + {file = "sqlalchemy_exasol-4.6.3-py3-none-any.whl", hash = "sha256:d524d14bd84935087fb4e9fed273c1b5f6d23f0008ef3460a0278aa332e646ea"}, + {file = "sqlalchemy_exasol-4.6.3.tar.gz", hash = "sha256:03a424886cc90480a2127ca0531779e8b0a415d4b113d85dd23025d6c0b52cd3"}, +] + +[package.dependencies] +packaging = ">=21.3" +pyexasol = ">=0.25.1,<0.26.0" +pyodbc = ">=4.0.34,<6" +sqlalchemy = ">=1.4,<2" + +[package.extras] +turbodbc = ["turbodbc (==4.5.4)"] + +[[package]] +name = "sqlalchemy-risingwave" +version = "1.0.0" +description = "RisingWave dialect for SQLAlchemy" +optional = true +python-versions = "*" +files = [ + {file = "sqlalchemy-risingwave-1.0.0.tar.gz", hash = "sha256:856a3c44b98cba34d399c3cc9785a74896caca152b3685d87553e4210e3e07a4"}, + {file = "sqlalchemy_risingwave-1.0.0-py3-none-any.whl", hash = "sha256:c733365abc38e88f4d23d83713cfc3f21c0b0d3c81210cbc2f569b49a912ba08"}, +] + +[package.dependencies] +SQLAlchemy = ">=1.4,<2" + [[package]] name = "sqlalchemy-views" version = "0.3.2" @@ -6686,6 +6772,7 @@ files = [ python-dateutil = "*" pytz = "*" requests = ">=2.31.0" +sqlalchemy = {version = ">=1.3", optional = true, markers = "extra == \"sqlalchemy\""} tzlocal = "*" [package.extras] @@ -7309,33 +7396,34 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["black", "clickhouse-connect", "dask", "datafusion", "db-dtypes", "deltalake", "duckdb", "geopandas", "google-cloud-bigquery", "google-cloud-bigquery-storage", "graphviz", "impyla", "oracledb", "packaging", "pins", "polars", "psycopg2", "pydata-google-auth", "pydruid", "pyexasol", "pymysql", "pyodbc", "pyspark", "regex", "shapely", "snowflake-connector-python", "sqlalchemy", "sqlalchemy-views", "trino"] +all = ["black", "clickhouse-connect", "dask", "datafusion", "db-dtypes", "deltalake", "duckdb", "duckdb-engine", "geoalchemy2", "geopandas", "google-cloud-bigquery", "google-cloud-bigquery-storage", "graphviz", "impyla", "oracledb", "packaging", "pins", "polars", "psycopg2", "pydata-google-auth", "pydruid", "pymysql", "pyodbc", "pyspark", "regex", "shapely", "snowflake-connector-python", "snowflake-sqlalchemy", 
"sqlalchemy", "sqlalchemy-exasol", "sqlalchemy-risingwave", "sqlalchemy-views", "trino"] bigquery = ["db-dtypes", "google-cloud-bigquery", "google-cloud-bigquery-storage", "pydata-google-auth"] -clickhouse = ["clickhouse-connect"] +clickhouse = ["clickhouse-connect", "sqlalchemy"] dask = ["dask", "regex"] datafusion = ["datafusion"] decompiler = ["black"] deltalake = ["deltalake"] -druid = ["pydruid"] -duckdb = ["duckdb"] +druid = ["pydruid", "sqlalchemy"] +duckdb = ["duckdb", "duckdb-engine", "sqlalchemy", "sqlalchemy-views"] examples = ["pins"] -exasol = ["pyexasol"] +exasol = ["sqlalchemy", "sqlalchemy-exasol", "sqlalchemy-views"] flink = [] -geospatial = ["geopandas", "shapely"] -impala = ["impyla"] -mssql = ["pyodbc"] -mysql = ["pymysql"] -oracle = ["oracledb", "packaging"] +geospatial = ["geoalchemy2", "geopandas", "shapely"] +impala = ["impyla", "sqlalchemy"] +mssql = ["pyodbc", "sqlalchemy", "sqlalchemy-views"] +mysql = ["pymysql", "sqlalchemy", "sqlalchemy-views"] +oracle = ["oracledb", "packaging", "sqlalchemy", "sqlalchemy-views"] pandas = ["regex"] polars = ["packaging", "polars"] -postgres = ["psycopg2"] -pyspark = ["packaging", "pyspark"] -snowflake = ["packaging", "snowflake-connector-python"] +postgres = ["psycopg2", "sqlalchemy", "sqlalchemy-views"] +pyspark = ["packaging", "pyspark", "sqlalchemy"] +risingwave = ["psycopg2", "sqlalchemy", "sqlalchemy-risingwave", "sqlalchemy-views"] +snowflake = ["packaging", "snowflake-connector-python", "snowflake-sqlalchemy", "sqlalchemy-views"] sqlite = ["regex", "sqlalchemy", "sqlalchemy-views"] -trino = ["trino"] +trino = ["sqlalchemy", "sqlalchemy-views", "trino"] visualization = ["graphviz"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "11da6bdc8c65ae8790ee2cbc799ca82af0c1f783f8c5ec6d0ab1477fd21b03b7" +content-hash = "3c1dfc652d2d025e6ea434033966154b44ba3a4452cbe3c7439ea4754c6ec420" diff --git a/pyproject.toml b/pyproject.toml index 5ecbbf06a5a78..834eae8d759f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ snowflake-connector-python = { version = ">=3.0.2,<4,!=3.3.0b1", optional = true sqlalchemy = { version = ">=1.4,<3", optional = true } sqlalchemy-views = { version = ">=0.3.1,<1", optional = true } trino = { version = ">=0.321,<1", optional = true } +sqlalchemy-risingwave = { version = ">=1.0.0,<2", optional = true } [tool.poetry.group.dev.dependencies] codespell = { version = ">=2.2.6,<3", extras = [ @@ -169,6 +170,7 @@ all = [ "snowflake-connector-python", "sqlalchemy", "sqlalchemy-views", + "sqlalchemy-risingwave", "trino", ] bigquery = [ @@ -194,6 +196,12 @@ polars = ["polars", "packaging"] postgres = ["psycopg2"] pyspark = ["pyspark", "packaging"] snowflake = ["snowflake-connector-python", "packaging"] +risingwave = [ + "psycopg2", + "sqlalchemy", + "sqlalchemy-views", + "sqlalchemy-risingwave", +] sqlite = ["regex", "sqlalchemy", "sqlalchemy-views"] trino = ["trino"] # non-backend extras @@ -218,6 +226,7 @@ oracle = "ibis.backends.oracle" pandas = "ibis.backends.pandas" polars = "ibis.backends.polars" postgres = "ibis.backends.postgres" +risingwave = "ibis.backends.risingwave" pyspark = "ibis.backends.pyspark" snowflake = "ibis.backends.snowflake" sqlite = "ibis.backends.sqlite" @@ -354,6 +363,7 @@ markers = [ "pandas: Pandas tests", "polars: Polars tests", "postgres: PostgreSQL tests", + "risingwave: Risingwave tests", "pyspark: PySpark tests", "snowflake: Snowflake tests", "sqlite: SQLite tests", diff --git a/requirements-dev.txt b/requirements-dev.txt index 
538b128b742d0..ccc3ab5e0f62e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -234,6 +234,7 @@ snowflake-connector-python==3.6.0 ; python_version >= "3.9" and python_version < sortedcontainers==2.4.0 ; python_version >= "3.9" and python_version < "4.0" soupsieve==2.5 ; python_version >= "3.10" and python_version < "3.13" sphobjinv==2.3.1 ; python_version >= "3.10" and python_version < "3.13" +sqlalchemy-risingwave==1.0.0 ; python_version >= "3.9" and python_version < "4.0" sqlalchemy-views==0.3.2 ; python_version >= "3.9" and python_version < "4.0" sqlalchemy==1.4.51 ; python_version >= "3.9" and python_version < "4.0" sqlglot==20.8.0 ; python_version >= "3.9" and python_version < "4.0"