From 751639c2b7ccdb498018efc82868abab3297ac90 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Fri, 9 Feb 2024 15:38:46 -0500 Subject: [PATCH] refactor(flink): port to sqlglot --- .github/workflows/ibis-backends.yml | 108 ++- ibis/backends/base/__init__.py | 9 +- ibis/backends/base/sqlglot/compiler.py | 2 +- ibis/backends/base/sqlglot/datatypes.py | 10 + ibis/backends/base/sqlglot/dialects.py | 17 + ibis/backends/flink/__init__.py | 124 ++-- ibis/backends/flink/compiler.py | 604 +++++++++++++++++ ibis/backends/flink/compiler/__init__.py | 0 ibis/backends/flink/compiler/core.py | 170 ----- ibis/backends/flink/datatypes.py | 77 ++- ibis/backends/flink/ddl.py | 14 +- ibis/backends/flink/registry.py | 614 ------------------ ibis/backends/flink/tests/conftest.py | 63 +- .../test_complex_filtered_agg/out.sql | 14 +- .../test_complex_groupby_aggregation/out.sql | 14 +- .../test_complex_projections/out.sql | 20 +- .../test_compiler/test_count_star/out.sql | 9 +- .../test_extract_fields/day/out.sql | 5 +- .../test_extract_fields/day_of_year/out.sql | 5 +- .../test_extract_fields/hour/out.sql | 5 +- .../test_extract_fields/minute/out.sql | 5 +- .../test_extract_fields/month/out.sql | 5 +- .../test_extract_fields/quarter/out.sql | 5 +- .../test_extract_fields/second/out.sql | 5 +- .../test_extract_fields/week_of_year/out.sql | 5 +- .../test_extract_fields/year/out.sql | 5 +- .../test_compiler/test_filter/out.sql | 25 +- .../test_compiler/test_having/out.sql | 18 +- .../test_simple_filtered_agg/out.sql | 5 +- .../snapshots/test_compiler/test_sum/out.sql | 5 +- .../timestamp_ms/out.sql | 5 +- .../timestamp_s/out.sql | 5 +- .../test_compiler/test_value_counts/out.sql | 14 +- .../test_window_aggregation/out.sql | 17 +- .../test_compiler/test_window_topn/out.sql | 47 +- .../cumulate_window/out.sql | 7 +- .../test_windowing_tvf/hop_window/out.sql | 5 +- .../test_windowing_tvf/tumble_window/out.sql | 5 +- .../datetime/out.sql | 3 +- .../datetime_with_microseconds/out.sql | 3 +- .../string_time/out.sql | 3 +- .../string_timestamp/out.sql | 3 +- .../time/out.sql | 3 +- .../timestamp/out.sql | 3 +- .../test_window/test_range_window/out.sql | 5 +- .../test_window/test_rows_window/out.sql | 5 +- ibis/backends/flink/tests/test_ddl.py | 230 ++++--- ibis/backends/flink/tests/test_join.py | 7 +- ibis/backends/flink/tests/test_literals.py | 80 --- ibis/backends/flink/tests/test_window.py | 49 +- ibis/backends/flink/translator.py | 16 - ibis/backends/polars/__init__.py | 3 + ibis/backends/postgres/compiler.py | 4 +- .../test_many_subqueries/flink/out.sql | 42 ++ .../test_default_limit/flink/out.sql | 5 + .../test_disable_query_limit/flink/out.sql | 5 + .../flink/out.sql | 3 + .../test_respect_set_limit/flink/out.sql | 10 + .../test_cte_refs_in_topo_order/flink/out.sql | 20 + .../test_group_by_has_index/flink/out.sql | 38 ++ .../test_sql/test_isin_bug/flink/out.sql | 9 + ibis/backends/tests/test_aggregation.py | 25 +- ibis/backends/tests/test_array.py | 49 +- ibis/backends/tests/test_asof_join.py | 2 + ibis/backends/tests/test_benchmarks.py | 4 +- ibis/backends/tests/test_client.py | 49 +- ibis/backends/tests/test_dot_sql.py | 56 +- ibis/backends/tests/test_examples.py | 2 +- ibis/backends/tests/test_export.py | 40 +- ibis/backends/tests/test_generic.py | 37 +- ibis/backends/tests/test_join.py | 1 - ibis/backends/tests/test_map.py | 42 +- ibis/backends/tests/test_network.py | 4 +- ibis/backends/tests/test_numeric.py | 74 ++- ibis/backends/tests/test_param.py | 15 +- 
ibis/backends/tests/test_sql.py | 2 - ibis/backends/tests/test_string.py | 5 +- ibis/backends/tests/test_temporal.py | 233 +------ ibis/backends/tests/test_timecontext.py | 2 +- ibis/backends/tests/test_window.py | 75 +-- ibis/tests/expr/test_table.py | 2 +- poetry-overrides.nix | 19 + 82 files changed, 1512 insertions(+), 1847 deletions(-) create mode 100644 ibis/backends/flink/compiler.py delete mode 100644 ibis/backends/flink/compiler/__init__.py delete mode 100644 ibis/backends/flink/compiler/core.py delete mode 100644 ibis/backends/flink/registry.py delete mode 100644 ibis/backends/flink/tests/test_literals.py delete mode 100644 ibis/backends/flink/translator.py create mode 100644 ibis/backends/tests/snapshots/test_generic/test_many_subqueries/flink/out.sql create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_default_limit/flink/out.sql create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/flink/out.sql create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/flink/out.sql create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/flink/out.sql create mode 100644 ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/flink/out.sql create mode 100644 ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/flink/out.sql create mode 100644 ibis/backends/tests/snapshots/test_sql/test_isin_bug/flink/out.sql diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml index 46cfba02f80fc..38aa5bbe1520d 100644 --- a/.github/workflows/ibis-backends.yml +++ b/.github/workflows/ibis-backends.yml @@ -177,22 +177,28 @@ jobs: - oracle services: - oracle - # - name: flink - # title: Flink - # serial: true - # extras: - # - flink - # additional_deps: - # - apache-flink - # - pytest-split - # services: - # - flink - # - name: risingwave - # title: Risingwave - # services: - # - risingwave - # extras: - # - risingwave + - name: flink + title: Flink + serial: true + extras: + - flink + additional_deps: + - apache-flink + services: + - flink + include: + - os: ubuntu-latest + python-version: "3.10" + backend: + name: flink + title: Flink + serial: true + extras: + - flink + additional_deps: + - apache-flink + services: + - flink exclude: - os: windows-latest backend: @@ -296,32 +302,29 @@ jobs: - oracle services: - oracle - # - os: windows-latest - # backend: - # name: flink - # title: Flink - # serial: true - # extras: - # - flink - # services: - # - flink - # - python-version: "3.11" - # backend: - # name: flink - # title: Flink - # serial: true - # extras: - # - flink - # services: - # - flink - # - os: windows-latest - # backend: - # name: risingwave - # title: Risingwave - # services: - # - risingwave - # extras: - # - risingwave + - os: windows-latest + backend: + name: flink + title: Flink + serial: true + extras: + - flink + additional_deps: + - apache-flink + services: + - flink + - os: ubuntu-latest + python-version: "3.11" + backend: + name: flink + title: Flink + serial: true + extras: + - flink + additional_deps: + - apache-flink + services: + - flink - os: windows-latest backend: name: exasol @@ -390,29 +393,18 @@ jobs: IBIS_TEST_IMPALA_PORT: 21050 IBIS_EXAMPLES_DATA: ${{ runner.temp }}/examples-${{ matrix.backend.name }}-${{ matrix.os }}-${{ steps.install_python.outputs.python-version }} - # FIXME(deepyaman): If some backend-specific test, in test_ddl.py, - # executes before common tests, they will fail with: - # 
org.apache.flink.table.api.ValidationException: Table `default_catalog`.`default_database`.`functional_alltypes` was not found. - # Therefore, we run backend-specific tests second to avoid this. - # - name: "run serial tests: ${{ matrix.backend.name }}" - # if: matrix.backend.serial && matrix.backend.name == 'flink' - # run: | - # just ci-check -m ${{ matrix.backend.name }} ibis/backends/tests - # just ci-check -m ${{ matrix.backend.name }} ibis/backends/flink/tests - # env: - # IBIS_EXAMPLES_DATA: ${{ runner.temp }}/examples-${{ matrix.backend.name }}-${{ matrix.os }}-${{ steps.install_python.outputs.python-version }} - # FLINK_REMOTE_CLUSTER_ADDR: localhost - # FLINK_REMOTE_CLUSTER_PORT: "8081" - # - name: "run serial tests: ${{ matrix.backend.name }}" - if: matrix.backend.serial # && matrix.backend.name != 'flink' + if: matrix.backend.serial run: just ci-check -m ${{ matrix.backend.name }} env: + FLINK_REMOTE_CLUSTER_ADDR: localhost + FLINK_REMOTE_CLUSTER_PORT: "8081" IBIS_EXAMPLES_DATA: ${{ runner.temp }}/examples-${{ matrix.backend.name }}-${{ matrix.os }}-${{ steps.install_python.outputs.python-version }} - name: check that no untracked files were produced shell: bash - run: git checkout poetry.lock pyproject.toml && ! git status --porcelain | tee /dev/stderr | grep . + run: | + ! git status --porcelain | tee /dev/stderr | grep . - name: upload code coverage if: success() diff --git a/ibis/backends/base/__init__.py b/ibis/backends/base/__init__.py index 075fb94bb0f2f..7e2dbd46b1e3e 100644 --- a/ibis/backends/base/__init__.py +++ b/ibis/backends/base/__init__.py @@ -1277,9 +1277,14 @@ def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str: @functools.cache -def _get_backend_names() -> frozenset[str]: +def _get_backend_names(*, exclude: tuple[str] = ()) -> frozenset[str]: """Return the set of known backend names. + Parameters + ---------- + exclude + Exclude these backend names from the result + Notes ----- This function returns a frozenset to prevent cache pollution. 
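# A minimal usage sketch (hypothetical call site) for the new keyword-only
# `exclude` parameter: it drops the named backends from the returned frozenset,
# e.g. when a backend such as Flink should be skipped:
#
#     names = _get_backend_names(exclude=("flink",))
#     assert "flink" not in names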
@@ -1293,7 +1298,7 @@ def _get_backend_names() -> frozenset[str]: entrypoints = importlib.metadata.entry_points()["ibis.backends"] else: entrypoints = importlib.metadata.entry_points(group="ibis.backends") - return frozenset(ep.name for ep in entrypoints) + return frozenset(ep.name for ep in entrypoints).difference(exclude) def connect(resource: Path | str, **kwargs: Any) -> BaseBackend: diff --git a/ibis/backends/base/sqlglot/compiler.py b/ibis/backends/base/sqlglot/compiler.py index d6c2f90d1e5c5..c6760b06cde12 100644 --- a/ibis/backends/base/sqlglot/compiler.py +++ b/ibis/backends/base/sqlglot/compiler.py @@ -129,7 +129,7 @@ def __getitem__(self, key: str) -> sge.Column: def paren(expr): """Wrap a sqlglot expression in parentheses.""" - return sge.Paren(this=expr) + return sge.Paren(this=sge.convert(expr)) def parenthesize(op, arg): diff --git a/ibis/backends/base/sqlglot/datatypes.py b/ibis/backends/base/sqlglot/datatypes.py index dd9827d2a5498..9cd03adbd6222 100644 --- a/ibis/backends/base/sqlglot/datatypes.py +++ b/ibis/backends/base/sqlglot/datatypes.py @@ -1029,3 +1029,13 @@ def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType: key_type = cls.from_ibis(dtype.key_type.copy(nullable=False)) value_type = cls.from_ibis(dtype.value_type) return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type]) + + +class FlinkType(SqlglotType): + dialect = "flink" + default_decimal_precision = 38 + default_decimal_scale = 18 + + @classmethod + def _from_ibis_Binary(cls, dtype: dt.Binary) -> sge.DataType: + return sge.DataType(this=sge.DataType.Type.VARBINARY) diff --git a/ibis/backends/base/sqlglot/dialects.py b/ibis/backends/base/sqlglot/dialects.py index 74a603b268ba8..45d3d8638853b 100644 --- a/ibis/backends/base/sqlglot/dialects.py +++ b/ibis/backends/base/sqlglot/dialects.py @@ -70,6 +70,10 @@ class Generator(Postgres.Generator): class Flink(Hive): class Generator(Hive.Generator): + TYPE_MAPPING = Hive.Generator.TYPE_MAPPING.copy() | { + sge.DataType.Type.TIME: "TIME", + } + TRANSFORMS = Hive.Generator.TRANSFORMS.copy() | { sge.Stddev: rename_func("stddev_samp"), sge.StddevPop: rename_func("stddev_pop"), @@ -82,8 +86,21 @@ class Generator(Hive.Generator): ), sge.ArrayConcat: rename_func("array_concat"), sge.Length: rename_func("char_length"), + sge.TryCast: lambda self, + e: f"TRY_CAST({e.this.sql(self.dialect)} AS {e.to.sql(self.dialect)})", + sge.DayOfYear: rename_func("dayofyear"), + sge.DayOfWeek: rename_func("dayofweek"), + sge.DayOfMonth: rename_func("dayofmonth"), } + class Tokenizer(Hive.Tokenizer): + # In Flink, embedded single quotes are escaped like most other SQL + # dialects: doubling up the single quote + # + # We override it here because we inherit from Hive's dialect and Hive + # uses a backslash to escape single quotes + STRING_ESCAPES = ["'"] + class Impala(Hive): class Generator(Hive.Generator): diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py index 679c7680de62c..e052d3d6ad535 100644 --- a/ibis/backends/flink/__init__.py +++ b/ibis/backends/flink/__init__.py @@ -1,23 +1,21 @@ from __future__ import annotations import itertools -from functools import lru_cache from typing import TYPE_CHECKING, Any -import pyarrow as pa import sqlglot as sg +import sqlglot.expressions as sge import ibis.common.exceptions as exc import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir -from ibis.backends.base import BaseBackend, CanCreateDatabase, NoUrl -from ibis.backends.base.sql.ddl import fully_qualified_re, 
is_fully_qualified -from ibis.backends.flink.compiler.core import FlinkCompiler +from ibis.backends.base import CanCreateDatabase, NoUrl +from ibis.backends.base.sqlglot import SQLGlotBackend +from ibis.backends.flink.compiler import FlinkCompiler from ibis.backends.flink.ddl import ( CreateDatabase, CreateTableFromConnector, - CreateView, DropDatabase, DropTable, DropView, @@ -32,15 +30,16 @@ from pathlib import Path import pandas as pd + import pyarrow as pa from pyflink.table import Table, TableEnvironment from pyflink.table.table_result import TableResult - from ibis.api import Watermark + from ibis.expr.api import Watermark -class Backend(BaseBackend, CanCreateDatabase, NoUrl): +class Backend(SQLGlotBackend, CanCreateDatabase, NoUrl): name = "flink" - compiler = FlinkCompiler + compiler = FlinkCompiler() supports_temporary_tables = True supports_python_udfs = True @@ -71,6 +70,17 @@ def do_connect(self, table_env: TableEnvironment) -> None: def raw_sql(self, query: str) -> TableResult: return self._table_env.execute_sql(query) + def _metadata(self, query: str): + from pyflink.table.types import create_arrow_schema + + table = self._table_env.sql_query(query) + schema = table.get_schema() + pa_schema = create_arrow_schema( + schema.get_field_names(), schema.get_field_data_types() + ) + # sort of wasteful, but less code to write + return sch.Schema.from_pyarrow(pa_schema).items() + def list_databases(self, like: str | None = None) -> list[str]: databases = self._table_env.list_databases() return self._filter_with_like(databases, like) @@ -207,21 +217,6 @@ def list_views( return self._filter_with_like(views, like) - def _fully_qualified_name( - self, - name: str, - database: str | None = None, - catalog: str | None = None, - ) -> str: - if name and is_fully_qualified(name): - return name - - return sg.table( - name, - db=database or self.current_database, - catalog=catalog or self.current_catalog, - ).sql(dialect="hive") - def table( self, name: str, @@ -250,15 +245,12 @@ def table( f"`database` must be a string; got {type(database)}" ) schema = self.get_schema(name, catalog=catalog, database=database) - qualified_name = self._fully_qualified_name(name, catalog, database) - _, quoted, unquoted = fully_qualified_re.search(qualified_name).groups() - unqualified_name = quoted or unquoted node = ops.DatabaseTable( - unqualified_name, - schema, - self, - namespace=ops.Namespace(schema=database, database=catalog), - ) # TODO(chloeh13q): look into namespacing with catalog + db + name, + schema=schema, + source=self, + namespace=ops.Namespace(schema=catalog, database=database), + ) return node.to_expr() def get_schema( @@ -288,7 +280,9 @@ def get_schema( from ibis.backends.flink.datatypes import get_field_data_types - qualified_name = self._fully_qualified_name(table_name, catalog, database) + qualified_name = sg.table(table_name, db=catalog, catalog=database).sql( + self.name + ) table = self._table_env.from_path(qualified_name) pyflink_schema = table.get_schema() @@ -305,12 +299,9 @@ def version(self) -> str: return pyflink.version.__version__ def compile( - self, - expr: ir.Expr, - params: Mapping[ir.Expr, Any] | None = None, - **kwargs: Any, + self, expr: ir.Expr, params: Mapping[ir.Expr, Any] | None = None, **_: Any ) -> Any: - """Compile an expression.""" + """Compile an Ibis expression to Flink.""" return super().compile(expr, params=params) # Discard `limit` and other kwargs. 
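# A minimal usage sketch (hypothetical table and column names; `con` is a
# connected Flink backend). `compile` returns the generated Flink SQL string
# and deliberately discards `limit` and any other keyword arguments:
#
#     t = con.table("orders")
#     sql = con.compile(t.group_by("id").agg(total=t.amount.sum()))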
def _to_sql(self, expr: ir.Expr, **kwargs: Any) -> str: @@ -604,7 +595,9 @@ def create_view( ) if isinstance(obj, pd.DataFrame): - qualified_name = self._fully_qualified_name(name, database, catalog) + qualified_name = sg.table( + name, db=database, catalog=catalog, quoted=self.compiler.quoted + ).sql(self.name) if schema: table = self._table_env.from_pandas( obj, FlinkRowSchema.from_ibis(schema) @@ -617,15 +610,18 @@ def create_view( elif isinstance(obj, ir.Table): query_expression = self.compile(obj) - statement = CreateView( - name=name, - query_expression=query_expression, - database=database, - can_exist=force, - temporary=temp, + stmt = sge.Create( + kind="VIEW", + this=sg.table( + name, db=database, catalog=catalog, quoted=self.compiler.quoted + ), + expression=query_expression, + exists=force, + properties=sge.Properties(expressions=[sge.TemporaryProperty()]) + if temp + else None, ) - sql = statement.compile() - self.raw_sql(sql) + self.raw_sql(stmt.sql(self.name)) else: raise exc.IbisError(f"Unsupported `obj` type: {type(obj)}") @@ -803,16 +799,6 @@ def read_json( file_type="json", path=path, schema=schema, table_name=table_name ) - @classmethod - @lru_cache - def _get_operations(cls): - translator = cls.compiler.translator_class - return translator._registry.keys() - - @classmethod - def has_operation(cls, operation: type[ops.Value]) -> bool: - return operation in cls._get_operations() - def insert( self, table_name: str, @@ -852,12 +838,9 @@ def insert( import pyarrow_hotfix # noqa: F401 if isinstance(obj, ir.Table): - expr = obj - ast = self.compiler.to_ast(expr) - select = ast.queries[0] statement = InsertSelect( table_name, - select, + self.compile(obj), database=database, catalog=catalog, overwrite=overwrite, @@ -891,6 +874,9 @@ def to_pyarrow( limit: int | str | None = None, **kwargs: Any, ) -> pa.Table: + import pyarrow as pa + import pyarrow_hotfix # noqa: F401 + pyarrow_batches = iter( self.to_pyarrow_batches(expr, params=params, limit=limit, **kwargs) ) @@ -914,6 +900,9 @@ def to_pyarrow_batches( limit: int | str | None = None, **kwargs: Any, ): + import pyarrow as pa + import pyarrow_hotfix # noqa: F401 + ibis_table = expr.as_table() if params is None and limit is None: @@ -946,7 +935,9 @@ def _from_ibis_table_to_pyflink_table(self, table: ir.Table) -> Table | None: # `table` is not a registered table in Flink. return None - qualified_name = self._fully_qualified_name(table_name) + qualified_name = sg.table(table_name, quoted=self.compiler.quoted).sql( + self.name + ) try: return self._table_env.from_path(qualified_name) except Py4JJavaError: @@ -959,17 +950,16 @@ def _from_pyflink_table_to_pyarrow_batches( *, chunk_size: int | None = None, ): - # Note (mehmet): Implementation of this is based on - # pyflink/table/table.py: to_pandas(). - + import pyarrow as pa + import pyarrow_hotfix # noqa: F401 import pytz from pyflink.java_gateway import get_gateway from pyflink.table.serializers import ArrowSerializer from pyflink.table.types import create_arrow_schema from ibis.backends.flink.datatypes import get_field_data_types - - pa = self._import_pyarrow() + # Note (mehmet): Implementation of this is based on + # pyflink/table/table.py: to_pandas(). 
gateway = get_gateway() if chunk_size: diff --git a/ibis/backends/flink/compiler.py b/ibis/backends/flink/compiler.py new file mode 100644 index 0000000000000..af89583f0de98 --- /dev/null +++ b/ibis/backends/flink/compiler.py @@ -0,0 +1,604 @@ +"""Flink Ibis expression to SQL compiler.""" + +from __future__ import annotations + +from functools import singledispatchmethod + +import sqlglot as sg +import sqlglot.expressions as sge + +import ibis.common.exceptions as com +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +from ibis.backends.base.sqlglot.compiler import STAR, SQLGlotCompiler, paren +from ibis.backends.base.sqlglot.datatypes import FlinkType +from ibis.backends.base.sqlglot.dialects import Flink +from ibis.backends.base.sqlglot.rewrites import ( + exclude_unsupported_window_frame_from_ops, + exclude_unsupported_window_frame_from_rank, + exclude_unsupported_window_frame_from_row_number, + rewrite_first_to_first_value, + rewrite_last_to_last_value, + rewrite_sample_as_filter, +) + + +class FlinkCompiler(SQLGlotCompiler): + quoted = True + dialect = Flink + type_mapper = FlinkType + rewrites = ( + rewrite_sample_as_filter, + exclude_unsupported_window_frame_from_row_number, + exclude_unsupported_window_frame_from_ops, + exclude_unsupported_window_frame_from_rank, + rewrite_first_to_first_value, + rewrite_last_to_last_value, + *SQLGlotCompiler.rewrites, + ) + + @property + def NAN(self): + raise NotImplementedError("Flink does not support NaN") + + @property + def POS_INF(self): + raise NotImplementedError("Flink does not support Infinity") + + NEG_INF = POS_INF + + def _aggregate(self, funcname: str, *args, where): + func = self.f[funcname] + if where is not None: + # FILTER (WHERE ...) is broken for one or both of: + # + # 1. certain aggregates: std/var doesn't return the right result + # 2. certain kinds of predicates: x IN y doesn't filter the right + # values out + # 3. certain aggregates AND predicates: STD(w) FILTER (WHERE x IN Y) + # returns an incorrect result + # + # One solution is to try `IF(predicate, arg, NULL)`. + # + # Unfortunately that won't work without casting the NULL to a + # specific type. + # + # At this point in the Ibis compiler we don't have any of the Ibis + # operation's type information because we threw it away. In every + # other engine Ibis supports, the type of a NULL literal is inferred + # by the engine. + # + # Using a CASE statement and leaving out the explicit NULL does the + # trick for Flink. + # + # Le sigh.
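+ # A hedged illustration (hypothetical columns): an aggregate such as
+ # `t.x.sum(where=t.y > 0)` therefore compiles to
+ #   SUM(CASE WHEN y > 0 THEN x END)
+ # rather than
+ #   SUM(x) FILTER (WHERE y > 0)
+ # so Flink infers the type of the implicit NULL branch on its own.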
+ args = tuple(sge.Case(ifs=[sge.If(this=where, true=arg)]) for arg in args) + return func(*args) + + @singledispatchmethod + def visit_node(self, op, **kwargs): + return super().visit_node(op, **kwargs) + + @staticmethod + def _minimize_spec(start, end, spec): + if ( + start is None + and isinstance(getattr(end, "value", None), ops.Literal) + and end.value.value == 0 + and end.following + ): + return None + return spec + + @visit_node.register(ops.TumbleWindowingTVF) + def visit_TumbleWindowingTVF(self, op, *, table, time_col, window_size, offset): + args = [ + self.v[f"TABLE {table.this.sql(self.dialect)}"], + # `time_col` has the table _alias_, instead of the table, but it is + # required to be bound to the table, this happens because of the + # way we construct the op in the tumble API using bind + # + # perhaps there's a better way to deal with this + self.f.descriptor(time_col.this), + window_size, + offset, + ] + + return sg.select( + sge.Column( + this=STAR, table=sg.to_identifier(table.alias_or_name, quoted=True) + ) + ).from_( + self.f.table(self.f.tumble(*filter(None, args))).as_( + table.alias_or_name, quoted=True + ) + ) + + @visit_node.register(ops.HopWindowingTVF) + def visit_HopWindowingTVF( + self, op, *, table, time_col, window_size, window_slide, offset + ): + args = [ + self.v[f"TABLE {table.this.sql(self.dialect)}"], + self.f.descriptor(time_col.this), + window_slide, + window_size, + offset, + ] + return sg.select( + sge.Column( + this=STAR, table=sg.to_identifier(table.alias_or_name, quoted=True) + ) + ).from_( + self.f.table(self.f.hop(*filter(None, args))).as_( + table.alias_or_name, quoted=True + ) + ) + + @visit_node.register(ops.CumulateWindowingTVF) + def visit_CumulateWindowingTVF( + self, op, *, table, time_col, window_size, window_step, offset + ): + args = [ + self.v[f"TABLE {table.this.sql(self.dialect)}"], + self.f.descriptor(time_col.this), + window_step, + window_size, + offset, + ] + return sg.select( + sge.Column( + this=STAR, table=sg.to_identifier(table.alias_or_name, quoted=True) + ) + ).from_( + self.f.table(self.f.cumulate(*filter(None, args))).as_( + table.alias_or_name, quoted=True + ) + ) + + @visit_node.register(ops.InMemoryTable) + def visit_InMemoryTable(self, op, *, name, schema, data): + # the performance of this is rather terrible + tuples = data.to_frame().itertuples(index=False) + quoted = self.quoted + columns = [sg.column(col, quoted=quoted) for col in schema.names] + alias = sge.TableAlias( + this=sg.to_identifier(name, quoted=quoted), columns=columns + ) + expressions = [ + sge.Tuple( + expressions=[ + self.visit_Literal( + ops.Literal(col, dtype=dtype), value=col, dtype=dtype + ) + for col, dtype in zip(row, schema.types) + ] + ) + for row in tuples + ] + + expr = sge.Values(expressions=expressions, alias=alias) + return sg.select(*columns).from_(expr) + + def visit_NonNullLiteral(self, op, *, value, dtype): + if dtype.is_binary(): + # TODO: is this decode safe? 
+ return self.cast(value.decode(), dtype) + elif dtype.is_uuid(): + return sge.convert(str(value)) + elif dtype.is_array(): + value_type = dtype.value_type + result = self.f.array( + *( + self.visit_Literal( + ops.Literal(v, dtype=value_type), value=v, dtype=value_type + ) + for v in value + ) + ) + if value: + return result + return sge.Cast(this=result, to=self.type_mapper.from_ibis(dtype)) + elif dtype.is_map(): + key_type = dtype.key_type + value_type = dtype.value_type + keys = self.f.array( + *( + self.visit_Literal( + ops.Literal(v, dtype=key_type), value=v, dtype=key_type + ) + for v in value.keys() + ) + ) + values = self.f.array( + *( + self.visit_Literal( + ops.Literal(v, dtype=value_type), value=v, dtype=value_type + ) + for v in value.values() + ) + ) + return self.cast(self.f.map_from_arrays(keys, values), dtype) + elif dtype.is_timestamp(): + return self.cast( + value.replace(tzinfo=None).isoformat(sep=" ", timespec="microseconds"), + dtype, + ) + elif dtype.is_date(): + return self.cast(value.isoformat(), dtype) + elif dtype.is_time(): + return self.cast(value.isoformat(timespec="microseconds"), dtype) + return None + + @visit_node.register(ops.ArrayIndex) + def visit_ArrayIndex(self, op, *, arg, index): + return sge.Bracket(this=arg, expressions=[index + 1]) + + @visit_node.register(ops.Xor) + def visit_Xor(self, op, *, left, right): + return sg.or_(sg.and_(left, sg.not_(right)), sg.and_(sg.not_(left), right)) + + @visit_node.register(ops.Literal) + def visit_Literal(self, op, *, value, dtype): + if value is None: + assert dtype.nullable, "dtype is not nullable but value is None" + if not dtype.is_null(): + return self.cast(sge.NULL, dtype) + return sge.NULL + return super().visit_Literal(op, value=value, dtype=dtype) + + @visit_node.register(ops.MapGet) + def visit_MapGet(self, op, *, arg, key, default): + if default is sge.NULL: + default = self.cast(default, op.dtype) + return self.f.coalesce(arg[key], default) + + @visit_node.register(ops.ArraySlice) + def visit_ArraySlice(self, op, *, arg, start, stop): + args = [arg, self.if_(start >= 0, start + 1, start)] + + if stop is not None: + args.append( + self.if_(stop >= 0, stop, self.f.cardinality(arg) - self.f.abs(stop)) + ) + + return self.f.array_slice(*args) + + @visit_node.register(ops.Not) + def visit_Not(self, op, *, arg): + return sg.not_(self.cast(arg, dt.boolean)) + + @visit_node.register(ops.Date) + def visit_Date(self, op, *, arg): + return self.cast(arg, dt.date) + + @visit_node.register(ops.TryCast) + def visit_TryCast(self, op, *, arg, to): + type_mapper = self.type_mapper + if op.arg.dtype.is_temporal() and to.is_numeric(): + return self.f.unix_timestamp( + sge.TryCast(this=arg, to=type_mapper.from_ibis(dt.string)) + ) + return sge.TryCast(this=arg, to=type_mapper.from_ibis(to)) + + @visit_node.register(ops.FloorDivide) + def visit_FloorDivide(self, op, *, left, right): + return self.f.floor(left / right) + + @visit_node.register(ops.JSONGetItem) + def visit_JSONGetItem(self, op, *, arg, index): + assert isinstance(op.index, ops.Literal) + idx = op.index + val = idx.value + if idx.dtype.is_integer(): + query_path = f"$[{val}]" + else: + assert idx.dtype.is_string(), idx.dtype + query_path = f"$.{val}" + + key_hack = f"{sge.convert(query_path).sql(self.dialect)} WITH CONDITIONAL ARRAY WRAPPER" + return self.f.json_query(arg, self.v[key_hack]) + + @visit_node.register(ops.TimestampFromUNIX) + def visit_TimestampFromUNIX(self, op, *, arg, unit): + from ibis.common.temporal import TimestampUnit + + if unit == 
TimestampUnit.MILLISECOND: + precision = 3 + elif unit == TimestampUnit.SECOND: + precision = 0 + else: + raise ValueError(f"{unit!r} unit is not supported!") + + return self.cast(self.f.to_timestamp_ltz(arg, precision), dt.timestamp) + + @visit_node.register(ops.Time) + def visit_Time(self, op, *, arg): + return self.cast(arg, op.dtype) + + @visit_node.register(ops.TimeFromHMS) + def visit_TimeFromHMS(self, op, *, hours, minutes, seconds): + padded_hour = self.f.lpad(self.cast(hours, dt.string), 2, "0") + padded_minute = self.f.lpad(self.cast(minutes, dt.string), 2, "0") + padded_second = self.f.lpad(self.cast(seconds, dt.string), 2, "0") + return self.cast( + self.f.concat(padded_hour, ":", padded_minute, ":", padded_second), op.dtype + ) + + @visit_node.register(ops.DateFromYMD) + def visit_DateFromYMD(self, op, *, year, month, day): + padded_year = self.f.lpad(self.cast(year, dt.string), 4, "0") + padded_month = self.f.lpad(self.cast(month, dt.string), 2, "0") + padded_day = self.f.lpad(self.cast(day, dt.string), 2, "0") + return self.cast( + self.f.concat(padded_year, "-", padded_month, "-", padded_day), op.dtype + ) + + @visit_node.register(ops.TimestampFromYMDHMS) + def visit_TimestampFromYMDHMS( + self, op, *, year, month, day, hours, minutes, seconds + ): + padded_year = self.f.lpad(self.cast(year, dt.string), 4, "0") + padded_month = self.f.lpad(self.cast(month, dt.string), 2, "0") + padded_day = self.f.lpad(self.cast(day, dt.string), 2, "0") + padded_hour = self.f.lpad(self.cast(hours, dt.string), 2, "0") + padded_minute = self.f.lpad(self.cast(minutes, dt.string), 2, "0") + padded_second = self.f.lpad(self.cast(seconds, dt.string), 2, "0") + return self.cast( + self.f.concat( + padded_year, + "-", + padded_month, + "-", + padded_day, + " ", + padded_hour, + ":", + padded_minute, + ":", + padded_second, + ), + op.dtype, + ) + + @visit_node.register(ops.ExtractEpochSeconds) + def visit_ExtractEpochSeconds(self, op, *, arg): + return self.f.unix_timestamp(self.cast(arg, dt.string)) + + @visit_node.register(ops.Cast) + def visit_Cast(self, op, *, arg, to): + from_ = op.arg.dtype + if to.is_timestamp(): + if from_.is_numeric(): + arg = self.f.from_unixtime(arg) + if (tz := to.timezone) is not None: + return self.f.to_timestamp( + self.f.convert_tz(self.cast(arg, dt.string), "UTC+0", tz) + ) + else: + return self.f.to_timestamp(arg, "yyyy-MM-dd HH:mm:ss.SSS") + elif to.is_json(): + return arg + elif from_.is_temporal() and to.is_int64(): + return 1_000_000 * self.f.unix_timestamp(arg) + else: + return self.cast(arg, to) + + @visit_node.register(ops.IfElse) + def visit_IfElse(self, op, *, bool_expr, true_expr, false_null_expr): + return self.if_( + bool_expr, + true_expr if true_expr != sge.NULL else self.cast(true_expr, op.dtype), + ( + false_null_expr + if false_null_expr != sge.NULL + else self.cast(false_null_expr, op.dtype) + ), + ) + + @visit_node.register(ops.Log10) + def visit_Log10(self, op, *, arg): + return self.f.anon.log(10, arg) + + @visit_node.register(ops.ExtractMillisecond) + def visit_ExtractMillisecond(self, op, *, arg): + return self.f.extract(self.v.millisecond, arg) + + @visit_node.register(ops.ExtractMicrosecond) + def visit_ExtractMicrosecond(self, op, *, arg): + return self.f.extract(self.v.microsecond, arg) + + @visit_node.register(ops.DayOfWeekIndex) + def visit_DayOfWeekIndex(self, op, *, arg): + return (self.f.dayofweek(arg) + 5) % 7 + + @visit_node.register(ops.DayOfWeekName) + def visit_DayOfWeekName(self, op, *, arg): + index = 
self.cast(self.f.dayofweek(self.cast(arg, dt.date)), op.dtype) + lookup_table = self.f.str_to_map( + "1=Sunday,2=Monday,3=Tuesday,4=Wednesday,5=Thursday,6=Friday,7=Saturday" + ) + return lookup_table[index] + + @visit_node.register(ops.TimestampNow) + def visit_TimestampNow(self, op): + return self.v.current_timestamp + + @visit_node.register(ops.TimestampBucket) + def visit_TimestampBucket(self, op, *, arg, interval, offset): + unit = op.interval.dtype.unit.name + unit_var = self.v[unit] + + if offset is None: + offset = 0 + else: + offset = op.offset.value + + bucket_width = op.interval.value + unit_func = self.f["dayofmonth" if unit.upper() == "DAY" else unit] + + arg = self.f.anon.timestampadd(unit_var, -paren(offset), arg) + mod = unit_func(arg) % bucket_width + + return self.f.anon.timestampadd( + unit_var, + -paren(mod) + offset, + self.v[f"FLOOR({arg.sql(self.dialect)} TO {unit_var.sql(self.dialect)})"], + ) + + @visit_node.register(ops.TimeDelta) + @visit_node.register(ops.TimestampDelta) + @visit_node.register(ops.DateDelta) + def visit_TemporalDelta(self, op, *, part, left, right): + right = self.visit_TemporalTruncate(None, arg=right, unit=part) + left = self.visit_TemporalTruncate(None, arg=left, unit=part) + return self.f.anon.timestampdiff( + self.v[part.this], + self.cast(right, dt.timestamp), + self.cast(left, dt.timestamp), + ) + + @visit_node.register(ops.TimestampTruncate) + @visit_node.register(ops.DateTruncate) + @visit_node.register(ops.TimeTruncate) + def visit_TemporalTruncate(self, op, *, arg, unit): + unit_var = self.v[unit.name] + arg_sql = arg.sql(self.dialect) + unit_sql = unit_var.sql(self.dialect) + return self.f.floor(self.v[f"{arg_sql} TO {unit_sql}"]) + + @visit_node.register(ops.StringContains) + def visit_StringContains(self, op, *, haystack, needle): + return self.f.instr(haystack, needle) > 0 + + @visit_node.register(ops.StringFind) + def visit_StringFind(self, op, *, arg, substr, start, end): + if end is not None: + raise com.UnsupportedOperationError( + "String find doesn't support `end` argument" + ) + + if start is not None: + arg = self.f.substr(arg, start + 1) + pos = self.f.instr(arg, substr) + return self.if_(pos > 0, pos + start, 0) + + return self.f.instr(arg, substr) + + @visit_node.register(ops.StartsWith) + def visit_StartsWith(self, op, *, arg, start): + return self.f.left(arg, self.f.char_length(start)).eq(start) + + @visit_node.register(ops.EndsWith) + def visit_EndsWith(self, op, *, arg, end): + return self.f.right(arg, self.f.char_length(end)).eq(end) + + @visit_node.register(ops.ExtractProtocol) + @visit_node.register(ops.ExtractAuthority) + @visit_node.register(ops.ExtractUserInfo) + @visit_node.register(ops.ExtractHost) + @visit_node.register(ops.ExtractFile) + @visit_node.register(ops.ExtractPath) + def visit_ExtractUrlField(self, op, *, arg): + return self.f.parse_url(arg, type(op).__name__[len("Extract") :].upper()) + + @visit_node.register(ops.ExtractQuery) + def visit_ExtractQuery(self, op, *, arg, key): + return self.f.parse_url(*filter(None, (arg, "QUERY", key))) + + @visit_node.register(ops.ExtractFragment) + def visit_ExtractFragment(self, op, *, arg): + return self.f.parse_url(arg, "REF") + + @visit_node.register(ops.CountStar) + def visit_CountStar(self, op, *, arg, where): + if where is None: + return self.f.count(STAR) + return self.f.sum(self.cast(where, dt.int64)) + + @visit_node.register(ops.CountDistinct) + def visit_CountDistinct(self, op, *, arg, where): + if where is not None: + arg = self.if_(where, arg, 
self.f.array(arg)[2]) + return self.f.count(sge.Distinct(expressions=[arg])) + + @visit_node.register(ops.AnalyticVectorizedUDF) + @visit_node.register(ops.ApproxMedian) + @visit_node.register(ops.Arbitrary) + @visit_node.register(ops.ArgMax) + @visit_node.register(ops.ArgMin) + @visit_node.register(ops.ArrayCollect) + @visit_node.register(ops.ArrayFlatten) + @visit_node.register(ops.ArraySort) + @visit_node.register(ops.ArrayStringJoin) + @visit_node.register(ops.Correlation) + @visit_node.register(ops.CountDistinctStar) + @visit_node.register(ops.Covariance) + @visit_node.register(ops.DateDiff) + @visit_node.register(ops.ExtractURLField) + @visit_node.register(ops.FindInSet) + @visit_node.register(ops.IsInf) + @visit_node.register(ops.IsNan) + @visit_node.register(ops.Levenshtein) + @visit_node.register(ops.MapMerge) + @visit_node.register(ops.Median) + @visit_node.register(ops.MultiQuantile) + @visit_node.register(ops.NthValue) + @visit_node.register(ops.Quantile) + @visit_node.register(ops.ReductionVectorizedUDF) + @visit_node.register(ops.RegexSplit) + @visit_node.register(ops.RowID) + @visit_node.register(ops.ScalarUDF) + @visit_node.register(ops.StringSplit) + @visit_node.register(ops.Translate) + @visit_node.register(ops.Unnest) + def visit_Undefined(self, op, **_): + raise com.OperationNotDefinedError(type(op).__name__) + + @staticmethod + def _generate_groups(groups): + return groups + + +_SIMPLE_OPS = { + ops.All: "min", + ops.Any: "max", + ops.ApproxCountDistinct: "approx_count_distinct", + ops.ArrayDistinct: "array_distinct", + ops.ArrayLength: "cardinality", + ops.ArrayPosition: "array_position", + ops.ArrayRemove: "array_remove", + ops.ArrayUnion: "array_union", + ops.ExtractDayOfYear: "dayofyear", + ops.First: "first_value", + ops.Last: "last_value", + ops.Map: "map_from_arrays", + ops.Power: "power", + ops.RandomScalar: "rand", + ops.RegexSearch: "regexp", + ops.StrRight: "right", + ops.StringLength: "char_length", + ops.StringToTimestamp: "to_timestamp", + ops.Strip: "trim", + ops.TypeOf: "typeof", +} + + +for _op, _name in _SIMPLE_OPS.items(): + assert isinstance(type(_op), type), type(_op) + if issubclass(_op, ops.Reduction): + + @FlinkCompiler.visit_node.register(_op) + def _fmt(self, op, *, _name: str = _name, where, **kw): + return self.agg[_name](*kw.values(), where=where) + + else: + + @FlinkCompiler.visit_node.register(_op) + def _fmt(self, op, *, _name: str = _name, **kw): + return self.f[_name](*kw.values()) + + setattr(FlinkCompiler, f"visit_{_op.__name__}", _fmt) + + +del _op, _name, _fmt diff --git a/ibis/backends/flink/compiler/__init__.py b/ibis/backends/flink/compiler/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/ibis/backends/flink/compiler/core.py b/ibis/backends/flink/compiler/core.py deleted file mode 100644 index be187ebdb664a..0000000000000 --- a/ibis/backends/flink/compiler/core.py +++ /dev/null @@ -1,170 +0,0 @@ -"""Flink ibis expression to SQL string compiler.""" - -from __future__ import annotations - -import functools - -import ibis.common.exceptions as com -import ibis.expr.operations as ops -import ibis.expr.types as ir -from ibis.backends.base.sql.compiler import ( - Compiler, - Select, - SelectBuilder, - TableSetFormatter, -) -from ibis.backends.base.sql.registry import quote_identifier -from ibis.backends.base.sqlglot.dialects import Flink -from ibis.backends.flink.translator import FlinkExprTranslator - - -class FlinkTableSetFormatter(TableSetFormatter): - def _quote_identifier(self, name): - return 
quote_identifier(name, force=True) - - def _format_in_memory_table(self, op): - names = op.schema.names - raw_rows = [] - for row in op.data.to_frame().itertuples(index=False): - raw_row = [] - for val, name in zip(row, names): - lit = ops.Literal(val, dtype=op.schema[name]) - raw_row.append(self._translate(lit)) - raw_rows.append(", ".join(raw_row)) - rows = ", ".join(f"({raw_row})" for raw_row in raw_rows) - return f"(VALUES {rows})" - - def _format_window_tvf(self, op) -> str: - if isinstance(op, ops.TumbleWindowingTVF): - function_type = "TUMBLE" - elif isinstance(op, ops.HopWindowingTVF): - function_type = "HOP" - elif isinstance(op, ops.CumulateWindowingTVF): - function_type = "CUMULATE" - return f"TABLE({function_type}({format_windowing_tvf_params(op, self)}))" - - def _format_table(self, op) -> str: - ctx = self.context - if isinstance(op, ops.WindowingTVF): - formatted_table = self._format_window_tvf(op) - return f"{formatted_table} {ctx.get_ref(op)}" - else: - result = super()._format_table(op) - - ref_op = op - if isinstance(op, ops.SelfReference): - ref_op = op.table - - if isinstance(ref_op, ops.InMemoryTable): - names = op.schema.names - result += f"({', '.join(self._quote_identifier(name) for name in names)})" - - return result - - -class FlinkSelectBuilder(SelectBuilder): - def _convert_group_by(self, exprs): - return exprs - - -class FlinkSelect(Select): - def format_group_by(self) -> str: - if not len(self.group_by): - # There is no aggregation, nothing to see here - return None - - lines = [] - if len(self.group_by) > 0: - group_keys = map(self._translate, self.group_by) - clause = "GROUP BY {}".format(", ".join(list(group_keys))) - lines.append(clause) - - if len(self.having) > 0: - trans_exprs = [] - for expr in self.having: - translated = self._translate(expr) - trans_exprs.append(translated) - lines.append("HAVING {}".format(" AND ".join(trans_exprs))) - - return "\n".join(lines) - - -class FlinkCompiler(Compiler): - translator_class = FlinkExprTranslator - table_set_formatter_class = FlinkTableSetFormatter - select_builder_class = FlinkSelectBuilder - select_class = FlinkSelect - - cheap_in_memory_tables = True - - dialect = Flink - - @classmethod - def to_sql(cls, node, context=None, params=None): - if isinstance(node, ir.Expr): - node = node.op() - - if isinstance(node, ops.Literal): - from ibis.backends.flink.utils import translate_literal - - return translate_literal(node) - - return super().to_sql(node, context, params) - - -@functools.singledispatch -def format_windowing_tvf_params( - op: ops.WindowingTVF, formatter: TableSetFormatter -) -> str: - raise com.OperationNotDefinedError(f"No formatting rule for {type(op)}") - - -@format_windowing_tvf_params.register(ops.TumbleWindowingTVF) -def _tumble_window_params( - op: ops.TumbleWindowingTVF, formatter: TableSetFormatter -) -> str: - return ", ".join( - filter( - None, - [ - f"TABLE {formatter._quote_identifier(op.table.name)}", - f"DESCRIPTOR({formatter._translate(op.time_col)})", - formatter._translate(op.window_size), - formatter._translate(op.offset) if op.offset else None, - ], - ) - ) - - -@format_windowing_tvf_params.register(ops.HopWindowingTVF) -def _hop_window_params(op: ops.HopWindowingTVF, formatter: TableSetFormatter) -> str: - return ", ".join( - filter( - None, - [ - f"TABLE {formatter._quote_identifier(op.table.name)}", - f"DESCRIPTOR({formatter._translate(op.time_col)})", - formatter._translate(op.window_slide), - formatter._translate(op.window_size), - formatter._translate(op.offset) if 
op.offset else None, - ], - ) - ) - - -@format_windowing_tvf_params.register(ops.CumulateWindowingTVF) -def _cumulate_window_params( - op: ops.CumulateWindowingTVF, formatter: TableSetFormatter -) -> str: - return ", ".join( - filter( - None, - [ - f"TABLE {formatter._quote_identifier(op.table.name)}", - f"DESCRIPTOR({formatter._translate(op.time_col)})", - formatter._translate(op.window_step), - formatter._translate(op.window_size), - formatter._translate(op.offset) if op.offset else None, - ], - ) - ) diff --git a/ibis/backends/flink/datatypes.py b/ibis/backends/flink/datatypes.py index 3291314513e32..664fb365d01b8 100644 --- a/ibis/backends/flink/datatypes.py +++ b/ibis/backends/flink/datatypes.py @@ -2,7 +2,14 @@ from typing import TYPE_CHECKING -from pyflink.table.types import DataType, DataTypes, RowType, _from_java_data_type +from pyflink.table.types import ( + ArrayType, + DataType, + DataTypes, + MapType, + RowType, + _from_java_data_type, +) import ibis.expr.datatypes as dt import ibis.expr.schema as sch @@ -28,8 +35,9 @@ def from_ibis(cls, schema: sch.Schema | None) -> list[RowType]: class FlinkType(TypeMapper): @classmethod - def to_ibis(cls, typ: DataType, nullable=True) -> dt.DataType: + def to_ibis(cls, typ: DataType) -> dt.DataType: """Convert a flink type to an ibis type.""" + nullable = typ.nullable if typ == DataTypes.STRING(): return dt.String(nullable=nullable) elif typ == DataTypes.BOOLEAN(): @@ -53,8 +61,18 @@ def to_ibis(cls, typ: DataType, nullable=True) -> dt.DataType: elif typ == DataTypes.TIME(): return dt.Time(nullable=nullable) elif typ == DataTypes.TIMESTAMP(): - return dt.Timestamp( - scale=typ.precision, + return dt.Timestamp(scale=typ.precision, nullable=nullable) + elif isinstance(typ, ArrayType): + return dt.Array(value_type=cls.to_ibis(typ.element_type), nullable=nullable) + elif isinstance(typ, MapType): + return dt.Map( + key_type=cls.to_ibis(typ.key_type), + value_type=cls.to_ibis(typ.value_type), + nullable=nullable, + ) + elif isinstance(typ, RowType): + return dt.Struct( + {field.name: cls.to_ibis(field.data_type) for field in typ.fields}, nullable=nullable, ) else: @@ -63,38 +81,39 @@ def to_ibis(cls, typ: DataType, nullable=True) -> dt.DataType: @classmethod def from_ibis(cls, dtype: dt.DataType) -> DataType: """Convert an ibis type to a flink type.""" + nullable = dtype.nullable if dtype.is_string(): - return DataTypes.STRING(nullable=dtype.nullable) + return DataTypes.STRING(nullable=nullable) elif dtype.is_boolean(): - return DataTypes.BOOLEAN(nullable=dtype.nullable) + return DataTypes.BOOLEAN(nullable=nullable) elif dtype.is_binary(): - return DataTypes.BYTES(nullable=dtype.nullable) + return DataTypes.BYTES(nullable=nullable) elif dtype.is_int8(): - return DataTypes.TINYINT(nullable=dtype.nullable) + return DataTypes.TINYINT(nullable=nullable) elif dtype.is_int16(): - return DataTypes.SMALLINT(nullable=dtype.nullable) + return DataTypes.SMALLINT(nullable=nullable) elif dtype.is_int32(): - return DataTypes.INT(nullable=dtype.nullable) + return DataTypes.INT(nullable=nullable) elif dtype.is_int64(): - return DataTypes.BIGINT(nullable=dtype.nullable) + return DataTypes.BIGINT(nullable=nullable) elif dtype.is_uint8(): - return DataTypes.TINYINT(nullable=dtype.nullable) + return DataTypes.TINYINT(nullable=nullable) elif dtype.is_uint16(): - return DataTypes.SMALLINT(nullable=dtype.nullable) + return DataTypes.SMALLINT(nullable=nullable) elif dtype.is_uint32(): - return DataTypes.INT(nullable=dtype.nullable) + return 
DataTypes.INT(nullable=nullable) elif dtype.is_uint64(): - return DataTypes.BIGINT(nullable=dtype.nullable) + return DataTypes.BIGINT(nullable=nullable) elif dtype.is_float16(): - return DataTypes.FLOAT(nullable=dtype.nullable) + return DataTypes.FLOAT(nullable=nullable) elif dtype.is_float32(): - return DataTypes.FLOAT(nullable=dtype.nullable) + return DataTypes.FLOAT(nullable=nullable) elif dtype.is_float64(): - return DataTypes.DOUBLE(nullable=dtype.nullable) + return DataTypes.DOUBLE(nullable=nullable) elif dtype.is_date(): - return DataTypes.DATE(nullable=dtype.nullable) + return DataTypes.DATE(nullable=nullable) elif dtype.is_time(): - return DataTypes.TIME(nullable=dtype.nullable) + return DataTypes.TIME(nullable=nullable) elif dtype.is_timestamp(): # Note (mehmet): If `precision` is None, set it to 6. # This is because `DataTypes.TIMESTAMP` throws TypeError @@ -102,7 +121,23 @@ def from_ibis(cls, dtype: dt.DataType) -> DataType: # if it is not provided. return DataTypes.TIMESTAMP( precision=dtype.scale if dtype.scale is not None else 6, - nullable=dtype.nullable, + nullable=nullable, + ) + elif dtype.is_array(): + return DataTypes.ARRAY(cls.from_ibis(dtype.value_type), nullable=nullable) + elif dtype.is_map(): + return DataTypes.MAP( + key_type=cls.from_ibis(dtype.key_type), + value_type=cls.from_ibis(dtype.key_type), + nullable=nullable, + ) + elif dtype.is_struct(): + return DataTypes.ROW( + [ + DataTypes.FIELD(name, data_type=cls.from_ibis(typ)) + for name, typ in dtype.items() + ], + nullable=nullable, ) else: return super().from_ibis(dtype) diff --git a/ibis/backends/flink/ddl.py b/ibis/backends/flink/ddl.py index c48f5063defda..13e90e6eb75e1 100644 --- a/ibis/backends/flink/ddl.py +++ b/ibis/backends/flink/ddl.py @@ -19,13 +19,13 @@ is_fully_qualified, ) from ibis.backends.base.sql.registry import quote_identifier -from ibis.backends.flink.registry import type_to_sql_string +from ibis.backends.base.sqlglot.datatypes import FlinkType from ibis.util import promote_list if TYPE_CHECKING: from collections.abc import Sequence - from ibis.api import Watermark + from ibis.expr.api import Watermark def format_schema(schema: sch.Schema): @@ -41,13 +41,11 @@ def _format_schema_element(name, t): def type_to_flink_sql_string(tval): + sql_string = FlinkType.from_ibis(tval) if tval.is_timestamp(): - return f"timestamp({tval.scale})" if tval.scale is not None else "timestamp" + return f"TIMESTAMP({tval.scale})" if tval.scale is not None else "TIMESTAMP" else: - sql_string = type_to_sql_string(tval) - if not tval.nullable: - sql_string += " NOT NULL" - return sql_string + return sql_string.sql("flink") + " NOT NULL" * (not tval.nullable) def _format_watermark_strategy(watermark: Watermark) -> str: @@ -364,7 +362,7 @@ def compile(self): else: partition = "" - select_query = self.select.compile() + select_query = self.select scoped_name = self._get_scoped_name( self.table_name, self.database, self.catalog ) diff --git a/ibis/backends/flink/registry.py b/ibis/backends/flink/registry.py deleted file mode 100644 index 3fc518de536d3..0000000000000 --- a/ibis/backends/flink/registry.py +++ /dev/null @@ -1,614 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import ibis.common.exceptions as com -import ibis.expr.operations as ops -from ibis.backends.base.sql.registry import ( - aggregate, - fixed_arity, - helpers, - quote_identifier, - unary, -) -from ibis.backends.base.sql.registry import ( - operation_registry as base_operation_registry, -) -from 
ibis.backends.base.sql.registry.main import varargs -from ibis.common.temporal import TimestampUnit - -if TYPE_CHECKING: - from ibis.backends.base.sql.compiler import ExprTranslator - -operation_registry = base_operation_registry.copy() - - -def type_to_sql_string(tval): - if tval.is_array(): - return f"array<{helpers.type_to_sql_string(tval.value_type)}>" - return helpers.type_to_sql_string(tval) - - -def _not(translator: ExprTranslator, op: ops.Node) -> str: - formatted_arg = translator.translate(op.arg) - if helpers.needs_parens(op.arg): - formatted_arg = helpers.parenthesize(formatted_arg) - return f"NOT CAST({formatted_arg} AS boolean)" - - -def _count_star(translator: ExprTranslator, op: ops.Node) -> str: - if (where := op.where) is not None: - condition = f" FILTER (WHERE {translator.translate(where)})" - else: - condition = "" - - return f"COUNT(*){condition}" - - -def _string_concat(translator: ExprTranslator, op: ops.StringConcat) -> str: - joined_args = ", ".join(map(translator.translate, op.arg)) - return f"CONCAT({joined_args})" - - -def _strftime(translator: ExprTranslator, op: ops.Strftime) -> str: - import sqlglot as sg - - import ibis.expr.datatypes as dt - - hive_dialect = sg.dialects.hive.Hive - if (time_mapping := getattr(hive_dialect, "TIME_MAPPING", None)) is None: - time_mapping = hive_dialect.time_mapping - reverse_hive_mapping = {v: k for k, v in time_mapping.items()} - - format_str = translator.translate(op.format_str) - transformed_format_str = sg.time.format_time(format_str, reverse_hive_mapping) - arg = translator.translate(ops.Cast(op.arg, to=dt.string)) - - return f"FROM_UNIXTIME(UNIX_TIMESTAMP({arg}), {transformed_format_str})" - - -def _date(translator: ExprTranslator, op: ops.Node) -> str: - (arg,) = op.args - return f"CAST({translator.translate(arg)} AS DATE)" - - -def _extract_field(sql_attr: str) -> str: - def extract_field_formatter(translator: ExprTranslator, op: ops.Node) -> str: - arg = translator.translate(op.args[0]) - return f"EXTRACT({sql_attr} from {arg})" - - return extract_field_formatter - - -def _cast(translator: ExprTranslator, op: ops.generic.Cast) -> str: - arg, to = op.arg, op.to - arg_translated = translator.translate(arg) - if to.is_timestamp(): - if arg.dtype.is_numeric(): - arg_translated = f"FROM_UNIXTIME({arg_translated})" - - if to.timezone: - return f"TO_TIMESTAMP(CONVERT_TZ(CAST({arg_translated} AS STRING), 'UTC+0', '{to.timezone}'))" - else: - return f"TO_TIMESTAMP({arg_translated}, 'yyyy-MM-dd HH:mm:ss.SSS')" - - elif to.is_date(): - return f"CAST({arg_translated} AS date)" - elif to.is_json(): - return arg_translated - elif op.arg.dtype.is_temporal() and op.to.is_int64(): - return f"1000000 * unix_timestamp({arg_translated})" - else: - sql_type = type_to_sql_string(op.to) - return f"CAST({arg_translated} AS {sql_type})" - - -def _left_op_right(translator: ExprTranslator, op_node: ops.Node, op_sign: str) -> str: - """Utility to be used in operators that perform '{op.left} {op_sign} {op.right}'.""" - return f"{translator.translate(op_node.left)} {op_sign} {translator.translate(op_node.right)}" - - -def _interval_add(translator: ExprTranslator, op: ops.temporal.IntervalSubtract) -> str: - return _left_op_right(translator=translator, op_node=op, op_sign="+") - - -def _interval_subtract( - translator: ExprTranslator, op: ops.temporal.IntervalSubtract -) -> str: - return _left_op_right(translator=translator, op_node=op, op_sign="-") - - -def _literal(translator: ExprTranslator, op: ops.Literal) -> str: - from 
ibis.backends.flink.utils import translate_literal - - return translate_literal(op) - - -def _try_cast(translator: ExprTranslator, op: ops.Node) -> str: - arg_formatted = translator.translate(op.arg) - - if op.arg.dtype.is_temporal() and op.to.is_numeric(): - # The cast from TIMESTAMP type to NUMERIC type is not allowed. - # It's recommended to use UNIX_TIMESTAMP(CAST(timestamp_col AS STRING)) instead. - return f"UNIX_TIMESTAMP(TRY_CAST({arg_formatted} AS STRING))" - else: - sql_type = type_to_sql_string(op.to) - return f"TRY_CAST({arg_formatted} AS {sql_type})" - - -def _filter(translator: ExprTranslator, op: ops.Node) -> str: - bool_expr = translator.translate(op.bool_expr) - true_expr = translator.translate(op.true_expr) - false_null_expr = translator.translate(op.false_null_expr) - - # [TODO](chloeh13q): It's preferable to use the FILTER syntax instead of CASE WHEN - # to let the planner do more optimizations to reduce the state size; besides, FILTER - # is more compliant with the SQL standard. - # For example, - # ``` - # COUNT(DISTINCT CASE WHEN flag = 'app' THEN user_id ELSE NULL END) AS app_uv - # ``` - # is equivalent to - # ``` - # COUNT(DISTINCT) FILTER (WHERE flag = 'app') AS app_uv - # ``` - return f"CASE WHEN {bool_expr} THEN {true_expr} ELSE {false_null_expr} END" - - -def _format_window_start(translator: ExprTranslator, boundary): - if boundary is None: - return "UNBOUNDED PRECEDING" - - if isinstance(boundary.value, ops.Literal) and boundary.value.value == 0: - return "CURRENT ROW" - - value = translator.translate(boundary.value) - return f"{value} PRECEDING" - - -def _format_window_end(translator: ExprTranslator, boundary): - if boundary is None: - raise com.UnsupportedOperationError( - "OVER RANGE FOLLOWING windows are not supported in Flink yet" - ) - - value = boundary.value - if isinstance(value, ops.Cast): - value = boundary.value.arg - if isinstance(value, ops.Literal): - if value.value != 0: - raise com.UnsupportedOperationError( - "OVER RANGE FOLLOWING windows are not supported in Flink yet" - ) - - return "CURRENT ROW" - - -def _format_window_frame(translator: ExprTranslator, func, frame): - components = [] - - if frame.group_by: - partition_args = ", ".join(map(translator.translate, frame.group_by)) - components.append(f"PARTITION BY {partition_args}") - - (order_by,) = frame.order_by - components.append(f"ORDER BY {translator.translate(order_by)}") - - if frame.start is None and frame.end is None: - # no-op, default is full sample - pass - elif not isinstance(func, translator._forbids_frame_clause): - # [NOTE] Flink allows - # "ROWS BETWEEN INTERVAL [...] PRECEDING AND CURRENT ROW" - # but not - # "RANGE BETWEEN [...] 
PRECEDING AND CURRENT ROW", - # but `.over(rows=(-ibis.interval(...), 0)` is not allowed in Ibis - if isinstance(frame, ops.RangeWindowFrame): - if not frame.start.value.dtype.is_interval(): - # [TODO] need to expand support for range-based interval windowing on expr - # side, for now only ibis intervals can be used - raise com.UnsupportedOperationError( - "Data Type mismatch between ORDER BY and RANGE clause" - ) - - start = _format_window_start(translator, frame.start) - end = _format_window_end(translator, frame.end) - - frame = f"{frame.how.upper()} BETWEEN {start} AND {end}" - components.append(frame) - - return "OVER ({})".format(" ".join(components)) - - -def _window(translator: ExprTranslator, op: ops.Node) -> str: - frame = op.frame - if not frame.order_by: - raise com.UnsupportedOperationError( - "Flink engine does not support generic window clause with no order by" - ) - if len(frame.order_by) > 1: - raise com.UnsupportedOperationError( - "Windows in Flink can only be ordered by a single time column" - ) - - _unsupported_reductions = translator._unsupported_reductions - - func = op.func.__window_op__ - - if isinstance(func, _unsupported_reductions): - raise com.UnsupportedOperationError( - f"{type(func)} is not supported in window functions" - ) - - if isinstance(frame, ops.RowsWindowFrame): - if frame.max_lookback is not None: - raise NotImplementedError( - "Rows with max lookback is not implemented for SQL-based backends." - ) - - window_formatted = _format_window_frame(translator, func, frame) - - arg_formatted = translator.translate(func.__window_op__) - result = f"{arg_formatted} {window_formatted}" - - if isinstance(func, (ops.RankBase, ops.NTile)): - return f"({result} - 1)" - return result - - -def _clip(translator: ExprTranslator, op: ops.Node) -> str: - from ibis.backends.flink.datatypes import FlinkType - - arg = translator.translate(op.arg) - - if op.upper is not None: - upper = translator.translate(op.upper) - arg = f"IF({arg} > {upper} AND {arg} IS NOT NULL, {upper}, {arg})" - - if op.lower is not None: - lower = translator.translate(op.lower) - arg = f"IF({arg} < {lower} AND {arg} IS NOT NULL, {lower}, {arg})" - - return f"CAST({arg} AS {FlinkType.from_ibis(op.dtype)!s})" - - -def _ntile(translator: ExprTranslator, op: ops.NTile) -> str: - return f"NTILE({op.buckets.value})" - - -def _floor_divide(translator: ExprTranslator, op: ops.Node) -> str: - left = translator.translate(op.left) - right = translator.translate(op.right) - return f"FLOOR(({left}) / ({right}))" - - -def _array(translator: ExprTranslator, op: ops.Array) -> str: - return f"ARRAY[{', '.join(map(translator.translate, op.exprs))}]" - - -def _array_index(translator: ExprTranslator, op: ops.ArrayIndex): - table_column = op.arg - index = op.index - - table_column_translated = translator.translate(table_column) - index_translated = translator.translate(index) - - return f"{table_column_translated} [ {index_translated} + 1 ]" - - -def _array_length(translator: ExprTranslator, op: ops.ArrayLength) -> str: - return f"CARDINALITY({translator.translate(op.arg)})" - - -def _array_position(translator: ExprTranslator, op: ops.ArrayPosition) -> str: - arg = translator.translate(op.arg) - other = translator.translate(op.other) - return f"ARRAY_POSITION({arg}, {other}) - 1" - - -def _array_slice(translator: ExprTranslator, op: ops.ArraySlice) -> str: - array = translator.translate(op.arg) - start = op.start.value - # The offsets are 1-based for ARRAY_SLICE. 
- # Ref: https://nightlies.apache.org/flink/flink-docs-master/docs/dev/table/functions/systemfunctions - if start >= 0: - start += 1 - - if op.stop is None: - return f"ARRAY_SLICE({array}, {start})" - - stop = op.stop.value - if stop >= 0: - return f"ARRAY_SLICE({array}, {start}, {stop})" - else: - # To imitate the behavior of pandas array slicing. - return f"ARRAY_SLICE({array}, {start}, CARDINALITY({array}) - {abs(stop)})" - - -def _json_get_item(translator: ExprTranslator, op: ops.json.JSONGetItem) -> str: - arg_translated = translator.translate(op.arg) - if op.index.dtype.is_integer(): - query_path = f"$[{op.index.value}]" - else: # is string - query_path = f"$.{op.index.value}" - - return ( - f"JSON_QUERY({arg_translated}, '{query_path}' WITH CONDITIONAL ARRAY WRAPPER)" - ) - - -def _map(translator: ExprTranslator, op: ops.maps.Map) -> str: - key_array = translator.translate(op.keys) - value_array = translator.translate(op.values) - - return f"MAP_FROM_ARRAYS({key_array}, {value_array})" - - -def _map_get(translator: ExprTranslator, op: ops.maps.MapGet) -> str: - map_ = translator.translate(op.arg) - key = translator.translate(op.key) - return f"{map_} [ {key} ]" - - -def _struct_field(translator: ExprTranslator, op: ops.StructField) -> str: - arg = translator.translate(op.arg) - return f"{arg}.`{op.field}`" - - -def _day_of_week_index( - translator: ExprTranslator, op: ops.temporal.DayOfWeekIndex -) -> str: - arg = translator.translate(op.arg) - return f"MOD(DAYOFWEEK({arg}) + 5, 7)" - - -def _day_of_week_name( - translator: ExprTranslator, op: ops.temporal.DayOfWeekName -) -> str: - arg = translator.translate(op.arg) - map_str = "1=Sunday,2=Monday,3=Tuesday,4=Wednesday,5=Thursday,6=Friday,7=Saturday" - return f"STR_TO_MAP('{map_str}')[CAST(DAYOFWEEK(CAST({arg} AS DATE)) AS STRING)]" - - -def _date_add(translator: ExprTranslator, op: ops.temporal.DateAdd) -> str: - return _left_op_right(translator=translator, op_node=op, op_sign="+") - - -def _date_delta(translator: ExprTranslator, op: ops.temporal.DateDelta) -> str: - left = translator.translate(op.left) - right = translator.translate(op.right) - unit = op.part.value.upper() - - return ( - f"TIMESTAMPDIFF({unit}, CAST({right} AS TIMESTAMP), CAST({left} AS TIMESTAMP))" - ) - - -def _date_diff(translator: ExprTranslator, op: ops.temporal.DateDiff) -> str: - raise com.UnsupportedOperationError("DATE_DIFF is not supported in Flink.") - - -def _date_from_ymd(translator: ExprTranslator, op: ops.temporal.DateFromYMD) -> str: - year, month, day = ( - f"CAST({translator.translate(e)} AS STRING)" - for e in [op.year, op.month, op.day] - ) - concat_string = f"CONCAT({year}, '-', {month}, '-', {day})" - return f"CAST({concat_string} AS DATE)" - - -def _date_sub(translator: ExprTranslator, op: ops.temporal.DateSub) -> str: - return _left_op_right(translator=translator, op_node=op, op_sign="-") - - -def _extract_epoch_seconds(translator: ExprTranslator, op: ops.Node) -> str: - arg = translator.translate(op.arg) - return f"UNIX_TIMESTAMP(CAST({arg} AS STRING))" - - -def _string_to_timestamp( - translator: ExprTranslator, op: ops.temporal.StringToTimestamp -) -> str: - arg = translator.translate(op.arg) - format_string = translator.translate(op.format_str) - return f"TO_TIMESTAMP({arg}, {format_string})" - - -def _time(translator: ExprTranslator, op: ops.temporal.Time) -> str: - if op.arg.dtype.is_timestamp(): - datetime = op.arg.value - return f"TIME '{datetime.hour}:{datetime.minute}:{datetime.second}.{datetime.microsecond}'" - - else: - raise 
com.UnsupportedOperationError(f"Does NOT support dtype= {op.arg.dtype}") - - -def _time_delta(translator: ExprTranslator, op: ops.temporal.TimeDiff) -> str: - left = translator.translate(op.left) - right = translator.translate(op.right) - unit = op.part.value.upper() - - return ( - f"TIMESTAMPDIFF({unit}, CAST({right} AS TIMESTAMP), CAST({left} AS TIMESTAMP))" - ) - - -def _time_from_hms(translator: ExprTranslator, op: ops.temporal.TimeFromHMS) -> str: - hours, minutes, seconds = ( - f"CAST({translator.translate(e)} AS STRING)" - for e in [op.hours, op.minutes, op.seconds] - ) - concat_string = f"CONCAT({hours}, ':', {minutes}, ':', {seconds})" - return f"CAST({concat_string} AS TIME)" - - -def _timestamp_add(translator: ExprTranslator, op: ops.temporal.TimestampAdd) -> str: - return _left_op_right(translator=translator, op_node=op, op_sign="+") - - -def _timestamp_bucket( - translator: ExprTranslator, op: ops.temporal.TimestampBucket -) -> str: - arg_translated = translator.translate(op.arg) - - unit = op.interval.dtype.unit.name - unit_for_mod = "DAYOFMONTH" if unit == "DAY" else unit - bucket_width = op.interval.value - offset = op.offset.value if op.offset else 0 - - arg_offset = f"TIMESTAMPADD({unit}, -({offset}), {arg_translated})" - num = f"{unit_for_mod}({arg_offset})" - mod = f"{num} % {bucket_width}" - - return f"TIMESTAMPADD({unit}, -({mod}) + {offset}, FLOOR({arg_offset} TO {unit}))" - - -def _timestamp_delta( - translator: ExprTranslator, op: ops.temporal.TimestampDelta -) -> str: - left = translator.translate(op.left) - right = translator.translate(op.right) - unit = op.part.value.upper() - - return f"TIMESTAMPDIFF({unit}, {right}, {left})" - - -def _timestamp_diff(translator: ExprTranslator, op: ops.temporal.TimestampDiff) -> str: - return _left_op_right(translator=translator, op_node=op, op_sign="-") - - -def _timestamp_sub(translator: ExprTranslator, op: ops.temporal.TimestampSub) -> str: - table_column = op.left - interval = op.right - - table_column_translated = translator.translate(table_column) - interval_translated = translator.translate(interval) - return f"{table_column_translated} - {interval_translated}" - - -def _timestamp_from_unix(translator: ExprTranslator, op: ops.TimestampFromUNIX) -> str: - arg, unit = op.arg, op.unit - - if unit == TimestampUnit.MILLISECOND: - precision = 3 - elif unit == TimestampUnit.SECOND: - precision = 0 - else: - raise ValueError(f"{unit!r} unit is not supported!") - - arg = translator.translate(op.arg) - return f"CAST(TO_TIMESTAMP_LTZ({arg}, {precision}) AS TIMESTAMP)" - - -def _timestamp_from_ymdhms( - translator: ExprTranslator, op: ops.temporal.TimestampFromYMDHMS -) -> str: - year, month, day, hours, minutes, seconds = ( - f"CAST({translator.translate(e)} AS STRING)" - for e in [op.year, op.month, op.day, op.hours, op.minutes, op.seconds] - ) - concat_string = f"CONCAT({year}, '-', {month}, '-', {day}, ' ', {hours}, ':', {minutes}, ':', {seconds})" - return f"CAST({concat_string} AS TIMESTAMP)" - - -def _struct_field(translator, op): - arg = translator.translate(op.arg) - return f"{arg}.{quote_identifier(op.field, force=True)}" - - -operation_registry.update( - { - # Unary operations - ops.Not: _not, - ops.NullIf: fixed_arity("nullif", 2), - ops.RandomScalar: lambda *_: "rand()", - ops.Degrees: unary("degrees"), - ops.Radians: unary("radians"), - # Unary aggregates - ops.ApproxCountDistinct: aggregate.reduction("approx_count_distinct"), - ops.CountStar: _count_star, - # String operations - ops.RegexSearch: fixed_arity("regexp", 
2), - ops.StringConcat: _string_concat, - ops.Strftime: _strftime, - ops.StringLength: unary("char_length"), - ops.StrRight: fixed_arity("right", 2), - # Timestamp operations - ops.Date: _date, - ops.ExtractEpochSeconds: _extract_epoch_seconds, - ops.ExtractYear: _extract_field("year"), # equivalent to YEAR(date) - ops.ExtractMonth: _extract_field("month"), # equivalent to MONTH(date) - ops.ExtractDay: _extract_field("day"), # equivalent to DAYOFMONTH(date) - ops.ExtractQuarter: _extract_field("quarter"), # equivalent to QUARTER(date) - ops.ExtractWeekOfYear: _extract_field("week"), # equivalent to WEEK(date) - ops.ExtractDayOfYear: _extract_field("doy"), # equivalent to DAYOFYEAR(date) - ops.ExtractHour: _extract_field("hour"), # equivalent to HOUR(timestamp) - ops.ExtractMinute: _extract_field("minute"), # equivalent to MINUTE(timestamp) - ops.ExtractSecond: _extract_field("second"), # equivalent to SECOND(timestamp) - ops.ExtractMillisecond: _extract_field("millisecond"), - ops.ExtractMicrosecond: _extract_field("microsecond"), - # Other operations - ops.Cast: _cast, - ops.Coalesce: varargs("coalesce"), - ops.IntervalAdd: _interval_add, - ops.IntervalSubtract: _interval_subtract, - ops.Literal: _literal, - ops.TryCast: _try_cast, - ops.IfElse: _filter, - ops.Window: _window, - ops.Clip: _clip, - ops.NTile: _ntile, - # Binary operations - ops.Power: fixed_arity("power", 2), - ops.FloorDivide: _floor_divide, - # Collection operations - ops.Array: _array, - ops.ArrayContains: fixed_arity("ARRAY_CONTAINS", 2), - ops.ArrayDistinct: fixed_arity("ARRAY_DISTINCT", 1), - ops.ArrayIndex: _array_index, - ops.ArrayLength: _array_length, - ops.ArrayPosition: _array_position, - ops.ArrayRemove: fixed_arity("ARRAY_REMOVE", 2), - ops.ArraySlice: _array_slice, - ops.ArrayUnion: fixed_arity("ARRAY_UNION", 2), - ops.JSONGetItem: _json_get_item, - ops.Map: _map, - ops.MapGet: _map_get, - ops.StructField: _struct_field, - # Temporal functions - ops.DateAdd: _date_add, - ops.DateDelta: _date_delta, - ops.DateDiff: _date_diff, - ops.DateFromYMD: _date_from_ymd, - ops.DateSub: _date_sub, - ops.DayOfWeekIndex: _day_of_week_index, - ops.DayOfWeekName: _day_of_week_name, - ops.StringToTimestamp: _string_to_timestamp, - ops.Time: _time, - ops.TimeDelta: _time_delta, - ops.TimeFromHMS: _time_from_hms, - ops.TimestampAdd: _timestamp_add, - ops.TimestampBucket: _timestamp_bucket, - ops.TimestampDelta: _timestamp_delta, - ops.TimestampDiff: _timestamp_diff, - ops.TimestampFromUNIX: _timestamp_from_unix, - ops.TimestampFromYMDHMS: _timestamp_from_ymdhms, - ops.TimestampSub: _timestamp_sub, - ops.StructField: _struct_field, - } -) - -_invalid_operations = { - # ibis.expr.operations.numeric - ops.IsNan, - ops.IsInf, - # ibis.expr.operations.reductions - ops.ApproxMedian, - # ibis.expr.operations.strings - ops.Translate, - ops.FindInSet, -} - -operation_registry = { - k: v for k, v in operation_registry.items() if k not in _invalid_operations -} diff --git a/ibis/backends/flink/tests/conftest.py b/ibis/backends/flink/tests/conftest.py index 5fa2078432e41..27f6f5f3cba08 100644 --- a/ibis/backends/flink/tests/conftest.py +++ b/ibis/backends/flink/tests/conftest.py @@ -2,15 +2,18 @@ from typing import Any +import pandas as pd import pytest import ibis from ibis.backends.conftest import TEST_TABLES from ibis.backends.tests.base import BackendTest +from ibis.backends.tests.data import array_types, json_types, struct_types, win class TestConf(BackendTest): force_sort = True + stateful = False deps = "pandas", "pyflink" 
@staticmethod @@ -50,18 +53,16 @@ def connect(*, tmpdir, worker_id, **kw: Any): return ibis.flink.connect(stream_table_env, **kw) def _load_data(self, **_: Any) -> None: - import pandas as pd - - from ibis.backends.tests.data import array_types, json_types, struct_types, win + con = self.connection for table_name in TEST_TABLES: path = self.data_dir / "parquet" / f"{table_name}.parquet" - self.connection.create_table(table_name, pd.read_parquet(path), temp=True) + con.create_table(table_name, pd.read_parquet(path), temp=True) - self.connection.create_table("array_types", array_types, temp=True) - self.connection.create_table("json_t", json_types, temp=True) - self.connection.create_table("struct", struct_types, temp=True) - self.connection.create_table("win", win, temp=True) + con.create_table("array_types", array_types, temp=True) + con.create_table("json_t", json_types, temp=True) + con.create_table("struct", struct_types, temp=True) + con.create_table("win", win, temp=True) class TestConfForStreaming(TestConf): @@ -110,16 +111,6 @@ def con(tmp_path_factory, data_dir, worker_id): ).connection -@pytest.fixture(scope="session") -def db(con): - return con.database() - - -@pytest.fixture(scope="session") -def alltypes(con): - return con.tables.functional_alltypes - - @pytest.fixture def awards_players_schema(): return TEST_TABLES["awards_players"] @@ -163,39 +154,3 @@ def generate_csv_configs(csv_file): } return generate_csv_configs - - -@pytest.fixture -def temp_view(con) -> str: - """Return a temporary view name. - - Parameters - ---------- - con : backend connection - - Yields - ------ - name : string - Random view name for a temporary usage. - - Note (mehmet): Added this here because the fixture - ibis/ibis/backends/conftest.py::temp_view() - leads to docker related errors through its parameter `ddl_con`. 
- """ - from ibis import util - - name = util.gen_name("view") - yield name - - con.drop_view(name, force=True) - - -@pytest.fixture(autouse=True) -def reset_con(con): - yield - tables_to_drop = list(set(con.list_tables()) - set(TEST_TABLES.keys())) - for table in tables_to_drop: - con.drop_table(table, force=True) - views_to_drop = list(set(con.list_views()) - set(TEST_TABLES.keys())) - for view in views_to_drop: - con.drop_view(view, temp=True, force=True) diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_filtered_agg/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_filtered_agg/out.sql index 5b9990fe38920..4a4b2ce0ecb80 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_filtered_agg/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_filtered_agg/out.sql @@ -1,5 +1,9 @@ -SELECT t0.`b`, COUNT(*) AS `total`, avg(t0.`a`) AS `avg_a`, - avg(CASE WHEN t0.`g` = 'A' THEN t0.`a` ELSE NULL END) AS `avg_a_A`, - avg(CASE WHEN t0.`g` = 'B' THEN t0.`a` ELSE NULL END) AS `avg_a_B` -FROM table t0 -GROUP BY t0.`b` \ No newline at end of file +SELECT + `t0`.`b`, + COUNT(*) AS `total`, + AVG(`t0`.`a`) AS `avg_a`, + AVG(CASE WHEN `t0`.`g` = 'A' THEN `t0`.`a` END) AS `avg_a_A`, + AVG(CASE WHEN `t0`.`g` = 'B' THEN `t0`.`a` END) AS `avg_a_B` +FROM `table` AS `t0` +GROUP BY + `t0`.`b` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_groupby_aggregation/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_groupby_aggregation/out.sql index 1ce799579cd5a..e86b81ec26ee9 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_groupby_aggregation/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_groupby_aggregation/out.sql @@ -1,5 +1,9 @@ -SELECT EXTRACT(year from t0.`i`) AS `year`, - EXTRACT(month from t0.`i`) AS `month`, COUNT(*) AS `total`, - count(DISTINCT t0.`b`) AS `b_unique` -FROM table t0 -GROUP BY EXTRACT(year from t0.`i`), EXTRACT(month from t0.`i`) \ No newline at end of file +SELECT + EXTRACT(year FROM `t0`.`i`) AS `year`, + EXTRACT(month FROM `t0`.`i`) AS `month`, + COUNT(*) AS `total`, + COUNT(DISTINCT `t0`.`b`) AS `b_unique` +FROM `table` AS `t0` +GROUP BY + EXTRACT(year FROM `t0`.`i`), + EXTRACT(month FROM `t0`.`i`) \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_projections/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_projections/out.sql index c883cc8ab63a8..f9d32940529d7 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_projections/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_complex_projections/out.sql @@ -1,7 +1,15 @@ -SELECT t0.`a`, avg(abs(t0.`the_sum`)) AS `mad` +SELECT + `t1`.`a`, + AVG(ABS(`t1`.`the_sum`)) AS `mad` FROM ( - SELECT t1.`a`, t1.`c`, sum(t1.`b`) AS `the_sum` - FROM table t1 - GROUP BY t1.`a`, t1.`c` -) t0 -GROUP BY t0.`a` \ No newline at end of file + SELECT + `t0`.`a`, + `t0`.`c`, + SUM(`t0`.`b`) AS `the_sum` + FROM `table` AS `t0` + GROUP BY + `t0`.`a`, + `t0`.`c` +) AS `t1` +GROUP BY + `t1`.`a` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_count_star/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_count_star/out.sql index e3e3e49089cda..b94c832490c0b 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_count_star/out.sql +++ 
b/ibis/backends/flink/tests/snapshots/test_compiler/test_count_star/out.sql @@ -1,3 +1,6 @@ -SELECT t0.`i`, COUNT(*) AS `CountStar(table)` -FROM table t0 -GROUP BY t0.`i` \ No newline at end of file +SELECT + `t0`.`i`, + COUNT(*) AS `CountStar(table)` +FROM `table` AS `t0` +GROUP BY + `t0`.`i` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/day/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/day/out.sql index 14d96a04c89ca..1f9573020c9dd 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/day/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/day/out.sql @@ -1,2 +1,3 @@ -SELECT EXTRACT(day from t0.`i`) AS `tmp` -FROM table t0 \ No newline at end of file +SELECT + EXTRACT(day FROM `t0`.`i`) AS `tmp` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/day_of_year/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/day_of_year/out.sql index 9774a20af63a9..39cded443d0c7 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/day_of_year/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/day_of_year/out.sql @@ -1,2 +1,3 @@ -SELECT EXTRACT(doy from t0.`i`) AS `tmp` -FROM table t0 \ No newline at end of file +SELECT + DAYOFYEAR(`t0`.`i`) AS `tmp` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/hour/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/hour/out.sql index e19999b7b1f1d..c91eccac3d062 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/hour/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/hour/out.sql @@ -1,2 +1,3 @@ -SELECT EXTRACT(hour from t0.`i`) AS `tmp` -FROM table t0 \ No newline at end of file +SELECT + EXTRACT(hour FROM `t0`.`i`) AS `tmp` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/minute/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/minute/out.sql index aeed550bdab21..a621ade98023c 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/minute/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/minute/out.sql @@ -1,2 +1,3 @@ -SELECT EXTRACT(minute from t0.`i`) AS `tmp` -FROM table t0 \ No newline at end of file +SELECT + EXTRACT(minute FROM `t0`.`i`) AS `tmp` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/month/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/month/out.sql index 57e3d1c6fa45e..e5596983477c0 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/month/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/month/out.sql @@ -1,2 +1,3 @@ -SELECT EXTRACT(month from t0.`i`) AS `tmp` -FROM table t0 \ No newline at end of file +SELECT + EXTRACT(month FROM `t0`.`i`) AS `tmp` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/quarter/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/quarter/out.sql index 
c6170172bde11..6afeb5e1b26f4 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/quarter/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/quarter/out.sql @@ -1,2 +1,3 @@ -SELECT EXTRACT(quarter from t0.`i`) AS `tmp` -FROM table t0 \ No newline at end of file +SELECT + EXTRACT(quarter FROM `t0`.`i`) AS `tmp` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/second/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/second/out.sql index 0c32dade798e5..0478f739a6c7f 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/second/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/second/out.sql @@ -1,2 +1,3 @@ -SELECT EXTRACT(second from t0.`i`) AS `tmp` -FROM table t0 \ No newline at end of file +SELECT + EXTRACT(second FROM `t0`.`i`) AS `tmp` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/week_of_year/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/week_of_year/out.sql index 3f6ca60700ebc..aac465a39a88c 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/week_of_year/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/week_of_year/out.sql @@ -1,2 +1,3 @@ -SELECT EXTRACT(week from t0.`i`) AS `tmp` -FROM table t0 \ No newline at end of file +SELECT + EXTRACT(week FROM `t0`.`i`) AS `tmp` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/year/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/year/out.sql index ab354a7fd8612..09d387170f993 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/year/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_extract_fields/year/out.sql @@ -1,2 +1,3 @@ -SELECT EXTRACT(year from t0.`i`) AS `tmp` -FROM table t0 \ No newline at end of file +SELECT + EXTRACT(year FROM `t0`.`i`) AS `tmp` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_filter/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_filter/out.sql index a74d83d269b56..079893708b237 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_filter/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_filter/out.sql @@ -1,4 +1,21 @@ -SELECT t0.* -FROM table t0 -WHERE ((t0.`c` > CAST(0 AS TINYINT)) OR (t0.`c` < CAST(0 AS TINYINT))) AND - (t0.`g` IN ('A', 'B')) \ No newline at end of file +SELECT + `t0`.`a`, + `t0`.`b`, + `t0`.`c`, + `t0`.`d`, + `t0`.`e`, + `t0`.`f`, + `t0`.`g`, + `t0`.`h`, + `t0`.`i`, + `t0`.`j`, + `t0`.`k` +FROM `table` AS `t0` +WHERE + ( + ( + `t0`.`c` > 0 + ) OR ( + `t0`.`c` < 0 + ) + ) AND `t0`.`g` IN ('A', 'B') \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_having/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_having/out.sql index 3744dd045f0de..4dc04f8e5686a 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_having/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_having/out.sql @@ -1,4 +1,14 @@ -SELECT t0.`g`, sum(t0.`b`) AS `b_sum` -FROM table t0 -GROUP BY t0.`g` -HAVING COUNT(*) >= CAST(1000 AS SMALLINT) \ No 
newline at end of file +SELECT + `t1`.`g`, + `t1`.`b_sum` +FROM ( + SELECT + `t0`.`g`, + SUM(`t0`.`b`) AS `b_sum`, + COUNT(*) AS `CountStar(table)` + FROM `table` AS `t0` + GROUP BY + `t0`.`g` +) AS `t1` +WHERE + `t1`.`CountStar(table)` >= 1000 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_simple_filtered_agg/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_simple_filtered_agg/out.sql index 19afa7a54cdf1..8ba32b106e1f0 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_simple_filtered_agg/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_simple_filtered_agg/out.sql @@ -1,2 +1,3 @@ -SELECT count(DISTINCT CASE WHEN t0.`g` = 'A' THEN t0.`b` ELSE NULL END) AS `CountDistinct(b, Equals(g, 'A'))` -FROM table t0 \ No newline at end of file +SELECT + COUNT(DISTINCT IF(`t0`.`g` = 'A', `t0`.`b`, ARRAY[`t0`.`b`][2])) AS `CountDistinct(b, Equals(g, 'A'))` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_sum/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_sum/out.sql index 4cd6d4fde779f..3701dafb94c45 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_sum/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_sum/out.sql @@ -1,2 +1,3 @@ -SELECT sum(t0.`a`) AS `Sum(a)` -FROM table t0 \ No newline at end of file +SELECT + SUM(`t0`.`a`) AS `Sum(a)` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_timestamp_from_unix/timestamp_ms/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_timestamp_from_unix/timestamp_ms/out.sql index c2424f7e63e2d..940b6016a3365 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_timestamp_from_unix/timestamp_ms/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_timestamp_from_unix/timestamp_ms/out.sql @@ -1,2 +1,3 @@ -SELECT CAST(TO_TIMESTAMP_LTZ(t0.`d`, 3) AS TIMESTAMP) AS `TimestampFromUNIX(d)` -FROM table t0 \ No newline at end of file +SELECT + CAST(TO_TIMESTAMP_LTZ(`t0`.`d`, 3) AS TIMESTAMP) AS `TimestampFromUNIX(d, MILLISECOND)` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_timestamp_from_unix/timestamp_s/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_timestamp_from_unix/timestamp_s/out.sql index 31766957381dc..91c783b17492e 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_timestamp_from_unix/timestamp_s/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_timestamp_from_unix/timestamp_s/out.sql @@ -1,2 +1,3 @@ -SELECT CAST(TO_TIMESTAMP_LTZ(t0.`d`, 0) AS TIMESTAMP) AS `TimestampFromUNIX(d)` -FROM table t0 \ No newline at end of file +SELECT + CAST(TO_TIMESTAMP_LTZ(`t0`.`d`, 0) AS TIMESTAMP) AS `TimestampFromUNIX(d, SECOND)` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_value_counts/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_value_counts/out.sql index 53e792c45392a..5645f83ecec1d 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_value_counts/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_value_counts/out.sql @@ -1,6 +1,10 @@ -SELECT t0.`ExtractYear(i)`, COUNT(*) AS `ExtractYear(i)_count` +SELECT + `t1`.`ExtractYear(i)`, + COUNT(*) AS `ExtractYear(i)_count` FROM ( - SELECT EXTRACT(year from 
t1.`i`) AS `ExtractYear(i)` - FROM table t1 -) t0 -GROUP BY t0.`ExtractYear(i)` \ No newline at end of file + SELECT + EXTRACT(year FROM `t0`.`i`) AS `ExtractYear(i)` + FROM `table` AS `t0` +) AS `t1` +GROUP BY + `t1`.`ExtractYear(i)` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_window_aggregation/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_window_aggregation/out.sql index 9867d27155f14..713b7ce6343ae 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_window_aggregation/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_window_aggregation/out.sql @@ -1,3 +1,14 @@ -SELECT t0.`window_start`, t0.`window_end`, t0.`g`, avg(t0.`d`) AS `mean` -FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE)) t0 -GROUP BY t0.`window_start`, t0.`window_end`, t0.`g` \ No newline at end of file +SELECT + `t1`.`window_start`, + `t1`.`window_end`, + `t1`.`g`, + AVG(`t1`.`d`) AS `mean` +FROM ( + SELECT + `t0`.* + FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE)) AS `t0` +) AS `t1` +GROUP BY + `t1`.`window_start`, + `t1`.`window_end`, + `t1`.`g` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_window_topn/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_window_topn/out.sql index 92d4e692424ff..fe04f27ee89cb 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_window_topn/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_window_topn/out.sql @@ -1,12 +1,37 @@ -WITH t0 AS ( - SELECT t2.`a`, t2.`b`, t2.`c`, t2.`d`, t2.`g`, t2.`window_start`, - t2.`window_end` - FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '10' MINUTE)) t2 -) -SELECT t1.* +SELECT + `t3`.`a`, + `t3`.`b`, + `t3`.`c`, + `t3`.`d`, + `t3`.`g`, + `t3`.`window_start`, + `t3`.`window_end`, + `t3`.`rownum` FROM ( - SELECT t0.*, - (row_number() OVER (PARTITION BY t0.`window_start`, t0.`window_end` ORDER BY t0.`g` DESC) - 1) AS `rownum` - FROM t0 -) t1 -WHERE t1.`rownum` <= CAST(3 AS TINYINT) \ No newline at end of file + SELECT + `t2`.`a`, + `t2`.`b`, + `t2`.`c`, + `t2`.`d`, + `t2`.`g`, + `t2`.`window_start`, + `t2`.`window_end`, + ROW_NUMBER() OVER (PARTITION BY `t2`.`window_start`, `t2`.`window_end` ORDER BY `t2`.`g` DESC) - 1 AS `rownum` + FROM ( + SELECT + `t1`.`a`, + `t1`.`b`, + `t1`.`c`, + `t1`.`d`, + `t1`.`g`, + `t1`.`window_start`, + `t1`.`window_end` + FROM ( + SELECT + `t0`.* + FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '600' SECOND)) AS `t0` + ) AS `t1` + ) AS `t2` +) AS `t3` +WHERE + `t3`.`rownum` <= 3 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/cumulate_window/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/cumulate_window/out.sql index 522c6d576e1ea..e2f0514f8925b 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/cumulate_window/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/cumulate_window/out.sql @@ -1,2 +1,5 @@ -SELECT t0.* -FROM TABLE(CUMULATE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '10' SECOND, INTERVAL '1' MINUTE)) t0 \ No newline at end of file +SELECT + `t0`.* +FROM TABLE( + CUMULATE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '10' SECOND, INTERVAL '1' MINUTE) +) AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/hop_window/out.sql 
b/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/hop_window/out.sql index 38376568cba92..ff64fce1ccbf9 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/hop_window/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/hop_window/out.sql @@ -1,2 +1,3 @@ -SELECT t0.* -FROM TABLE(HOP(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '1' MINUTE, INTERVAL '15' MINUTE)) t0 \ No newline at end of file +SELECT + `t0`.* +FROM TABLE(HOP(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '1' MINUTE, INTERVAL '15' MINUTE)) AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/tumble_window/out.sql b/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/tumble_window/out.sql index d46ca32d3d239..e85caed0ef7ed 100644 --- a/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/tumble_window/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_compiler/test_windowing_tvf/tumble_window/out.sql @@ -1,2 +1,3 @@ -SELECT t0.* -FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE)) t0 \ No newline at end of file +SELECT + `t0`.* +FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE)) AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/datetime/out.sql b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/datetime/out.sql index 8f93e34a40ddc..bb447abd40a1d 100644 --- a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/datetime/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/datetime/out.sql @@ -1 +1,2 @@ -TIMESTAMP '2017-01-01 04:55:59' \ No newline at end of file +SELECT + CAST('2017-01-01 04:55:59' AS TIMESTAMP) AS `datetime.datetime(2017, 1, 1, 4, 55, 59)` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/datetime_with_microseconds/out.sql b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/datetime_with_microseconds/out.sql index 9260f37f43e61..d5c8151660fe6 100644 --- a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/datetime_with_microseconds/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/datetime_with_microseconds/out.sql @@ -1 +1,2 @@ -TIMESTAMP '2017-01-01 04:55:59.001122' \ No newline at end of file +SELECT + CAST('2017-01-01 04:55:59.001122' AS TIMESTAMP) AS `datetime.datetime(2017, 1, 1, 4, 55, 59, 1122)` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/string_time/out.sql b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/string_time/out.sql index b3976577078cb..8b62e47adfdbc 100644 --- a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/string_time/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/string_time/out.sql @@ -1 +1,2 @@ -TIME '04:55:59' \ No newline at end of file +SELECT + CAST('04:55:59' AS TIMESTAMP) AS `datetime.time(4, 55, 59)` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/string_timestamp/out.sql 
b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/string_timestamp/out.sql index 8f93e34a40ddc..bb447abd40a1d 100644 --- a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/string_timestamp/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/string_timestamp/out.sql @@ -1 +1,2 @@ -TIMESTAMP '2017-01-01 04:55:59' \ No newline at end of file +SELECT + CAST('2017-01-01 04:55:59' AS TIMESTAMP) AS `datetime.datetime(2017, 1, 1, 4, 55, 59)` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/time/out.sql b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/time/out.sql index b3976577078cb..8b62e47adfdbc 100644 --- a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/time/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/time/out.sql @@ -1 +1,2 @@ -TIME '04:55:59' \ No newline at end of file +SELECT + CAST('04:55:59' AS TIMESTAMP) AS `datetime.time(4, 55, 59)` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/timestamp/out.sql b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/timestamp/out.sql index 8f93e34a40ddc..bb447abd40a1d 100644 --- a/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/timestamp/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_literals/test_literal_timestamp_or_time/timestamp/out.sql @@ -1 +1,2 @@ -TIMESTAMP '2017-01-01 04:55:59' \ No newline at end of file +SELECT + CAST('2017-01-01 04:55:59' AS TIMESTAMP) AS `datetime.datetime(2017, 1, 1, 4, 55, 59)` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_window/test_range_window/out.sql b/ibis/backends/flink/tests/snapshots/test_window/test_range_window/out.sql index b43021ac30d0b..a739ff47435ab 100644 --- a/ibis/backends/flink/tests/snapshots/test_window/test_range_window/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_window/test_range_window/out.sql @@ -1,2 +1,3 @@ -SELECT sum(t0.`f`) OVER (ORDER BY t0.`f` ASC RANGE BETWEEN INTERVAL '00 08:20:00.000000' DAY TO SECOND PRECEDING AND CURRENT ROW) AS `Sum(f)` -FROM table t0 \ No newline at end of file +SELECT + SUM(`t0`.`f`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST RANGE BETWEEN INTERVAL '500' MINUTE preceding AND CAST(0 AS INTERVAL MINUTE) following) AS `Sum(f)` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_window/test_rows_window/out.sql b/ibis/backends/flink/tests/snapshots/test_window/test_rows_window/out.sql index 3173072eafb64..7b5bb83d8c921 100644 --- a/ibis/backends/flink/tests/snapshots/test_window/test_rows_window/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_window/test_rows_window/out.sql @@ -1,2 +1,3 @@ -SELECT sum(t0.`f`) OVER (ORDER BY t0.`f` ASC ROWS BETWEEN CAST(1000 AS SMALLINT) PRECEDING AND CURRENT ROW) AS `Sum(f)` -FROM table t0 \ No newline at end of file +SELECT + SUM(`t0`.`f`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN 1000 preceding AND CAST(0 AS SMALLINT) following) AS `Sum(f)` +FROM `table` AS `t0` \ No newline at end of file diff --git a/ibis/backends/flink/tests/test_ddl.py b/ibis/backends/flink/tests/test_ddl.py index d1e741e6eb095..44742bc7c8213 100644 --- a/ibis/backends/flink/tests/test_ddl.py +++ 
b/ibis/backends/flink/tests/test_ddl.py @@ -110,21 +110,14 @@ def test_force_recreate_table_from_schema( assert new_table.schema() == awards_players_schema -@pytest.mark.parametrize( - "employee_df", - [ - pd.DataFrame( - [("fred flintstone", "award", 2002, "lg_id", "tie", "this is a note")] - ) - ], -) @pytest.mark.parametrize( "schema, table_name", [(None, None), (TEST_TABLES["awards_players"], "awards_players")], ) -def test_recreate_in_mem_table( - con, employee_df, schema, table_name, temp_table, csv_source_configs -): +def test_recreate_in_mem_table(con, schema, table_name, temp_table, csv_source_configs): + employee_df = pd.DataFrame( + [("fred flintstone", "award", 2002, "lg_id", "tie", "this is a note")] + ) # create table once if table_name is not None: tbl_properties = csv_source_configs(table_name) @@ -138,39 +131,35 @@ def test_recreate_in_mem_table( tbl_properties=tbl_properties, temp=True, ) - assert temp_table in con.list_tables() - if schema is not None: - assert new_table.schema() == schema - - # create the same table a second time should fail - with pytest.raises( - Py4JJavaError, - match="An error occurred while calling o8.createTemporaryView", - ): - new_table = con.create_table( - name=temp_table, - obj=employee_df, - schema=schema, - tbl_properties=tbl_properties, - overwrite=False, - temp=True, - ) + try: + assert temp_table in con.list_tables() + if schema is not None: + assert new_table.schema() == schema + + # create the same table a second time should fail + with pytest.raises( + Py4JJavaError, + match=r"An error occurred while calling o\d+\.createTemporaryView", + ): + new_table = con.create_table( + name=temp_table, + obj=employee_df, + schema=schema, + tbl_properties=tbl_properties, + overwrite=False, + temp=True, + ) + finally: + con.drop_table(temp_table, force=True) -@pytest.mark.parametrize( - "employee_df", - [ - pd.DataFrame( - [("fred flintstone", "award", 2002, "lg_id", "tie", "this is a note")] - ) - ], -) @pytest.mark.parametrize( "schema_props", [(None, None), (TEST_TABLES["awards_players"], "awards_players")] ) -def test_force_recreate_in_mem_table( - con, employee_df, schema_props, temp_table, csv_source_configs -): +def test_force_recreate_in_mem_table(con, schema_props, temp_table, csv_source_configs): + employee_df = pd.DataFrame( + [("fred flintstone", "award", 2002, "lg_id", "tie", "this is a note")] + ) # create table once schema = schema_props[0] if schema_props[1] is not None: @@ -185,22 +174,25 @@ def test_force_recreate_in_mem_table( tbl_properties=tbl_properties, temp=True, ) - assert temp_table in con.list_tables() - if schema is not None: - assert new_table.schema() == schema + try: + assert temp_table in con.list_tables() + if schema is not None: + assert new_table.schema() == schema - # force recreate the same table a second time should succeed - new_table = con.create_table( - name=temp_table, - obj=employee_df, - schema=schema, - tbl_properties=tbl_properties, - temp=True, - overwrite=True, - ) - assert temp_table in con.list_tables() - if schema is not None: - assert new_table.schema() == schema + # force recreate the same table a second time should succeed + new_table = con.create_table( + name=temp_table, + obj=employee_df, + schema=schema, + tbl_properties=tbl_properties, + temp=True, + overwrite=True, + ) + assert temp_table in con.list_tables() + if schema is not None: + assert new_table.schema() == schema + finally: + con.drop_table(temp_table, force=True) @pytest.fixture @@ -250,8 +242,11 @@ def 
test_create_source_table_with_watermark_and_primary_key( ), primary_key=primary_key, ) - assert temp_table in con.list_tables() - assert new_table.schema() == functional_alltypes_schema_w_nonnullable_columns + try: + assert temp_table in con.list_tables() + assert new_table.schema() == functional_alltypes_schema_w_nonnullable_columns + finally: + con.drop_table(temp_table, force=True) @pytest.mark.parametrize( @@ -279,14 +274,16 @@ def test_create_table_failure_with_invalid_primary_keys( assert temp_table not in con.list_tables() +@pytest.fixture +def temp_view(con): + name = ibis.util.gen_name("view") + yield name + con.drop_view(name, force=True) + + @pytest.mark.parametrize("temp", [True, False]) def test_create_view( - con, - temp_table: str, - awards_players_schema: sch.Schema, - csv_source_configs, - temp_view: str, - temp, + con, temp_table, awards_players_schema, csv_source_configs, temp_view, temp ): table = con.create_table( name=temp_table, @@ -404,10 +401,15 @@ def test_insert_values_into_table(con, tempdir_sink_configs, obj): schema=sink_schema, tbl_properties=tempdir_sink_configs(tempdir), ) - con.insert("tempdir_sink", obj).wait() - temporary_file = next(iter(os.listdir(tempdir))) - with open(os.path.join(tempdir, temporary_file)) as f: - assert f.read() == '"fred flintstone",35,1.28\n"barney rubble",32,2.32\n' + try: + con.insert("tempdir_sink", obj).wait() + temporary_file = next(iter(os.listdir(tempdir))) + with open(os.path.join(tempdir, temporary_file)) as f: + assert ( + f.read() == '"fred flintstone",35,1.28\n"barney rubble",32,2.32\n' + ) + finally: + con.drop_table("tempdir_sink", force=True) def test_insert_simple_select(con, tempdir_sink_configs): @@ -419,20 +421,27 @@ def test_insert_simple_select(con, tempdir_sink_configs): ), temp=True, ) - sink_schema = sch.Schema({"name": dt.string, "age": dt.int64}) - source_table = ibis.table( - sch.Schema({"name": dt.string, "age": dt.int64, "gpa": dt.float64}), "source" - ) - with tempfile.TemporaryDirectory() as tempdir: - con.create_table( - "tempdir_sink", - schema=sink_schema, - tbl_properties=tempdir_sink_configs(tempdir), + try: + sink_schema = sch.Schema({"name": dt.string, "age": dt.int64}) + source_table = ibis.table( + sch.Schema({"name": dt.string, "age": dt.int64, "gpa": dt.float64}), + "source", ) - con.insert("tempdir_sink", source_table[["name", "age"]]).wait() - temporary_file = next(iter(os.listdir(tempdir))) - with open(os.path.join(tempdir, temporary_file)) as f: - assert f.read() == '"fred flintstone",35\n"barney rubble",32\n' + with tempfile.TemporaryDirectory() as tempdir: + con.create_table( + "tempdir_sink", + schema=sink_schema, + tbl_properties=tempdir_sink_configs(tempdir), + ) + try: + con.insert("tempdir_sink", source_table[["name", "age"]]).wait() + temporary_file = next(iter(os.listdir(tempdir))) + with open(os.path.join(tempdir, temporary_file)) as f: + assert f.read() == '"fred flintstone",35\n"barney rubble",32\n' + finally: + con.drop_table("tempdir_sink", force=True) + finally: + con.drop_table("source", force=True) @pytest.mark.parametrize("table_name", ["new_table", None]) @@ -443,13 +452,13 @@ def test_read_csv(con, awards_players_schema, csv_source_configs, table_name): schema=awards_players_schema, table_name=table_name, ) - - if table_name is None: - table_name = table.get_name() - assert table_name in con.list_tables() - assert table.schema() == awards_players_schema - - con.drop_table(table_name) + try: + if table_name is None: + table_name = table.get_name() + assert table_name 
in con.list_tables() + assert table.schema() == awards_players_schema + finally: + con.drop_table(table_name) assert table_name not in con.list_tables() @@ -463,12 +472,13 @@ def test_read_parquet(con, data_dir, tmp_path, table_name, functional_alltypes_s table_name=table_name, ) - if table_name is None: - table_name = table.get_name() - assert table_name in con.list_tables() - assert table.schema() == functional_alltypes_schema - - con.drop_table(table_name) + try: + if table_name is None: + table_name = table.get_name() + assert table_name in con.list_tables() + assert table.schema() == functional_alltypes_schema + finally: + con.drop_table(table_name) assert table_name not in con.list_tables() @@ -487,29 +497,19 @@ def test_read_json(con, data_dir, tmp_path, table_name, functional_alltypes_sche path=path, schema=functional_alltypes_schema, table_name=table_name ) - if table_name is None: - table_name = table.get_name() - assert table_name in con.list_tables() - assert table.schema() == functional_alltypes_schema - assert table.count().execute() == len(pq_table) - - con.drop_table(table_name) + try: + if table_name is None: + table_name = table.get_name() + assert table_name in con.list_tables() + assert table.schema() == functional_alltypes_schema + assert table.count().execute() == len(pq_table) + finally: + con.drop_table(table_name) assert table_name not in con.list_tables() -@pytest.fixture(scope="module") -def functional_alltypes(con): - return con.table("functional_alltypes") - - @pytest.mark.parametrize( - "table_name", - [ - "astronauts", - "awards_players", - "diamonds", - "functional_alltypes", - ], + "table_name", ["astronauts", "awards_players", "diamonds", "functional_alltypes"] ) def test_to_csv(con, tmp_path, table_name): table = con.table(table_name) @@ -523,13 +523,7 @@ def test_to_csv(con, tmp_path, table_name): @pytest.mark.parametrize( - "table_name", - [ - "astronauts", - "awards_players", - "diamonds", - "functional_alltypes", - ], + "table_name", ["astronauts", "awards_players", "diamonds", "functional_alltypes"] ) def test_to_parquet(con, tmp_path, table_name): table = con.table(table_name) diff --git a/ibis/backends/flink/tests/test_join.py b/ibis/backends/flink/tests/test_join.py index 79da7c62016f5..d9ea09a819d5c 100644 --- a/ibis/backends/flink/tests/test_join.py +++ b/ibis/backends/flink/tests/test_join.py @@ -8,6 +8,7 @@ import ibis from ibis.backends.flink.tests.conftest import TestConf as tm +from ibis.backends.tests.errors import Py4JJavaError @pytest.fixture(scope="module") @@ -115,7 +116,11 @@ def remove_temp_files(left_tmp, right_tmp): right_tmp.close() -@pytest.mark.xfail(raises=AssertionError, reason="test seems broken", strict=False) +@pytest.mark.xfail( + raises=(Py4JJavaError, AssertionError), + reason="subquery probably uses too much memory/resources, flink complains about network buffers", + strict=False, +) def test_outer_join(left_tumble, right_tumble): expr = left_tumble.join( right_tumble, diff --git a/ibis/backends/flink/tests/test_literals.py b/ibis/backends/flink/tests/test_literals.py deleted file mode 100644 index 091405d142d02..0000000000000 --- a/ibis/backends/flink/tests/test_literals.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import annotations - -import datetime - -import pandas as pd -import pytest -from pytest import param - -import ibis -import ibis.expr.datatypes as dt - - -@pytest.mark.parametrize( - "value,expected", - [ - param(5, "CAST(5 AS TINYINT)", id="int"), - param(1.5, "CAST(1.5 AS DOUBLE)", id="float"), - 
param(True, "TRUE", id="true"), - param(False, "FALSE", id="false"), - ], -) -def test_simple_literals(con, value, expected): - expr = ibis.literal(value) - result = con.compile(expr) - assert result == expected - - -@pytest.mark.parametrize( - "value,expected", - [ - param("simple", "'simple'", id="simple"), - param("I can't", "'I can''t'", id="nested_quote"), - param('An "escape"', """'An "escape"'""", id="nested_token"), - ], -) -def test_string_literals(con, value, expected): - expr = ibis.literal(value) - result = con.compile(expr) - assert result == expected - - -@pytest.mark.parametrize( - "value,expected", - [ - param( - datetime.timedelta(seconds=70), - "INTERVAL '00 00:01:10.000000' DAY TO SECOND", - id="70seconds", - ), - param( - ibis.interval(months=50), "INTERVAL '04-02' YEAR TO MONTH", id="50months" - ), - param(ibis.interval(seconds=5), "INTERVAL '5' SECOND", id="5seconds"), - ], -) -def test_translate_interval_literal(con, value, expected): - expr = ibis.literal(value) - result = con.compile(expr) - assert result == expected - - -@pytest.mark.parametrize( - ("case", "dtype"), - [ - param(datetime.datetime(2017, 1, 1, 4, 55, 59), dt.timestamp, id="datetime"), - param( - datetime.datetime(2017, 1, 1, 4, 55, 59, 1122), - dt.timestamp, - id="datetime_with_microseconds", - ), - param("2017-01-01 04:55:59", dt.timestamp, id="string_timestamp"), - param(pd.Timestamp("2017-01-01 04:55:59"), dt.timestamp, id="timestamp"), - param(datetime.time(4, 55, 59), dt.time, id="time"), - param("04:55:59", dt.time, id="string_time"), - ], -) -def test_literal_timestamp_or_time(con, snapshot, case, dtype): - expr = ibis.literal(case, type=dtype) - result = con.compile(expr) - snapshot.assert_match(result, "out.sql") diff --git a/ibis/backends/flink/tests/test_window.py b/ibis/backends/flink/tests/test_window.py index ca67a317fa2eb..9104cc558b30e 100644 --- a/ibis/backends/flink/tests/test_window.py +++ b/ibis/backends/flink/tests/test_window.py @@ -1,63 +1,56 @@ from __future__ import annotations import pytest +from pyflink.util.exceptions import TableException from pytest import param import ibis -from ibis.common.exceptions import UnsupportedOperationError +from ibis.backends.tests.errors import Py4JJavaError -def test_window_requires_order_by(con, simple_table): - expr = simple_table.mutate(simple_table.c - simple_table.c.mean()) - with pytest.raises( - UnsupportedOperationError, - match="Flink engine does not support generic window clause with no order by", - ): - con.compile(expr) +@pytest.mark.xfail(raises=TableException) +def test_window_requires_order_by(con): + t = con.tables.functional_alltypes + expr = t.mutate(t.double_col - t.double_col.mean()) + con.execute(expr) -def test_window_does_not_support_multiple_order_by(con, simple_table): - expr = simple_table.f.sum().over( - rows=(-1, 1), - group_by=[simple_table.g, simple_table.a], - order_by=[simple_table.f, simple_table.d], - ) - with pytest.raises( - UnsupportedOperationError, - match="Windows in Flink can only be ordered by a single time column", - ): - con.compile(expr) +@pytest.mark.xfail(raises=TableException) +def test_window_does_not_support_multiple_order_by(con): + t = con.tables.functional_alltypes + expr = t.double_col.sum().over(rows=(-1, 1), order_by=[t.timestamp_col, t.int_col]) + con.execute(expr) @pytest.mark.parametrize( - ("window", "err"), + "window", [ param( {"rows": (-1, 1)}, - "OVER RANGE FOLLOWING windows are not supported in Flink yet", id="bounded_rows_following", + 
marks=[pytest.mark.xfail(raises=TableException)], ), param( {"rows": (-1, None)}, - "OVER RANGE FOLLOWING windows are not supported in Flink yet", id="unbounded_rows_following", + marks=[pytest.mark.xfail(raises=TableException)], ), param( {"rows": (-500, 1)}, - "OVER RANGE FOLLOWING windows are not supported in Flink yet", id="casted_bounded_rows_following", + marks=[pytest.mark.xfail(raises=TableException)], ), param( {"range": (-1000, 0)}, - "Data Type mismatch between ORDER BY and RANGE clause", id="int_range", + marks=[pytest.mark.xfail(raises=Py4JJavaError)], ), ], ) -def test_window_invalid_start_end(con, simple_table, window, err): - expr = simple_table.f.sum().over(**window, order_by=simple_table.f) - with pytest.raises(UnsupportedOperationError, match=err): - con.compile(expr) +def test_window_invalid_start_end(con, window): + t = con.tables.functional_alltypes + expr = t.int_col.sum().over(**window, order_by=t.timestamp_col) + con.execute(expr) def test_range_window(con, snapshot, simple_table): diff --git a/ibis/backends/flink/translator.py b/ibis/backends/flink/translator.py deleted file mode 100644 index 37bbcc0170ab7..0000000000000 --- a/ibis/backends/flink/translator.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -import ibis.expr.operations as ops -from ibis.backends.base.sql.compiler import ExprTranslator -from ibis.backends.flink.registry import operation_registry - - -class FlinkExprTranslator(ExprTranslator): - _dialect_name = "hive" # TODO: make a custom sqlglot dialect for Flink - _registry = operation_registry - _bool_aggs_need_cast_to_int32 = True - - -@FlinkExprTranslator.rewrites(ops.Clip) -def _clip_no_op(op): - return op diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index ef3c9b09ecee5..819fb995dc7fa 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -409,6 +409,9 @@ def compile( return translate(node, ctx=self._context) + def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema: + raise NotImplementedError("table.sql() not yet supported in polars") + def _get_schema_using_query(self, query: str) -> sch.Schema: return schema_from_polars(self._context.execute(query).schema) diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index 0c9155dc3f20f..27fa2cd341223 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -180,8 +180,10 @@ def visit_ArrayConcat(self, op, *, arg): @visit_node.register(ops.ArrayContains) def visit_ArrayContains(self, op, *, arg, other): + arg_dtype = op.arg.dtype return sge.ArrayContains( - this=arg, expression=self.f.array(self.cast(other, op.arg.dtype.value_type)) + this=self.cast(arg, arg_dtype), + expression=self.f.array(self.cast(other, arg_dtype.value_type)), ) @visit_node.register(ops.ArrayFilter) diff --git a/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/flink/out.sql b/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/flink/out.sql new file mode 100644 index 0000000000000..208dc189381db --- /dev/null +++ b/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/flink/out.sql @@ -0,0 +1,42 @@ +WITH `t6` AS ( + SELECT + `t5`.`street`, + ROW_NUMBER() OVER (ORDER BY `t5`.`street` ASC NULLS LAST) - 1 AS `key` + FROM ( + SELECT + `t2`.`street`, + `t2`.`key` + FROM ( + SELECT + `t0`.`street`, + ROW_NUMBER() OVER (ORDER BY `t0`.`street` ASC NULLS LAST) - 1 AS `key` + FROM `data` AS `t0` + ) AS `t2` + INNER JOIN ( + 
SELECT + `t1`.`key` + FROM ( + SELECT + `t0`.`street`, + ROW_NUMBER() OVER (ORDER BY `t0`.`street` ASC NULLS LAST) - 1 AS `key` + FROM `data` AS `t0` + ) AS `t1` + ) AS `t4` + ON `t2`.`key` = `t4`.`key` + ) AS `t5` +), `t1` AS ( + SELECT + `t0`.`street`, + ROW_NUMBER() OVER (ORDER BY `t0`.`street` ASC NULLS LAST) - 1 AS `key` + FROM `data` AS `t0` +) +SELECT + `t8`.`street`, + `t8`.`key` +FROM `t6` AS `t8` +INNER JOIN ( + SELECT + `t7`.`key` + FROM `t6` AS `t7` +) AS `t10` + ON `t8`.`key` = `t10`.`key` \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_default_limit/flink/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_default_limit/flink/out.sql new file mode 100644 index 0000000000000..f63de03c314af --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_default_limit/flink/out.sql @@ -0,0 +1,5 @@ +SELECT + `t0`.`id`, + `t0`.`bool_col` +FROM `functional_alltypes` AS `t0` +LIMIT 11 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/flink/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/flink/out.sql new file mode 100644 index 0000000000000..f63de03c314af --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/flink/out.sql @@ -0,0 +1,5 @@ +SELECT + `t0`.`id`, + `t0`.`bool_col` +FROM `functional_alltypes` AS `t0` +LIMIT 11 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/flink/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/flink/out.sql new file mode 100644 index 0000000000000..d8a9c4090dc11 --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/flink/out.sql @@ -0,0 +1,3 @@ +SELECT + SUM(`t0`.`bigint_col`) AS `Sum(bigint_col)` +FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/flink/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/flink/out.sql new file mode 100644 index 0000000000000..d4b1b19815b09 --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/flink/out.sql @@ -0,0 +1,10 @@ +SELECT + * +FROM ( + SELECT + `t0`.`id`, + `t0`.`bool_col` + FROM `functional_alltypes` AS `t0` + LIMIT 10 +) AS `t2` +LIMIT 11 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/flink/out.sql b/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/flink/out.sql new file mode 100644 index 0000000000000..8d13c9ddda1be --- /dev/null +++ b/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/flink/out.sql @@ -0,0 +1,20 @@ +WITH `t1` AS ( + SELECT + `t0`.`key` + FROM `leaf` AS `t0` + WHERE + TRUE +) +SELECT + `t3`.`key` +FROM `t1` AS `t3` +INNER JOIN `t1` AS `t4` + ON `t3`.`key` = `t4`.`key` +INNER JOIN ( + SELECT + `t3`.`key` + FROM `t1` AS `t3` + INNER JOIN `t1` AS `t4` + ON `t3`.`key` = `t4`.`key` +) AS `t6` + ON `t3`.`key` = `t6`.`key` \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/flink/out.sql b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/flink/out.sql new file mode 100644 index 0000000000000..98e8ba8a8c2f0 --- /dev/null +++ b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/flink/out.sql @@ -0,0 +1,38 @@ +SELECT + CASE 
`t0`.`continent` + WHEN 'NA' + THEN 'North America' + WHEN 'SA' + THEN 'South America' + WHEN 'EU' + THEN 'Europe' + WHEN 'AF' + THEN 'Africa' + WHEN 'AS' + THEN 'Asia' + WHEN 'OC' + THEN 'Oceania' + WHEN 'AN' + THEN 'Antarctica' + ELSE 'Unknown continent' + END AS `cont`, + SUM(`t0`.`population`) AS `total_pop` +FROM `countries` AS `t0` +GROUP BY + CASE `t0`.`continent` + WHEN 'NA' + THEN 'North America' + WHEN 'SA' + THEN 'South America' + WHEN 'EU' + THEN 'Europe' + WHEN 'AF' + THEN 'Africa' + WHEN 'AS' + THEN 'Asia' + WHEN 'OC' + THEN 'Oceania' + WHEN 'AN' + THEN 'Antarctica' + ELSE 'Unknown continent' + END \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/flink/out.sql b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/flink/out.sql new file mode 100644 index 0000000000000..db5ddb124e868 --- /dev/null +++ b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/flink/out.sql @@ -0,0 +1,9 @@ +SELECT + `t0`.`x` IN ( + SELECT + `t0`.`x` + FROM `t` AS `t0` + WHERE + `t0`.`x` > 2 + ) AS `InSubquery(x)` +FROM `t` AS `t0` \ No newline at end of file diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 66988e15bb633..fb49a7f800da3 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -23,6 +23,7 @@ PolarsInvalidOperationError, PsycoPg2InternalError, Py4JError, + Py4JJavaError, PyDruidProgrammingError, PyODBCProgrammingError, PySparkAnalysisException, @@ -637,13 +638,10 @@ def mean_and_std(v): id="first", marks=[ pytest.mark.notimpl( - ["druid", "impala", "mssql", "mysql", "oracle", "flink"], + ["druid", "impala", "mssql", "mysql", "oracle"], raises=com.OperationNotDefinedError, ), - pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - ), + pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError), ], ), param( @@ -652,13 +650,10 @@ def mean_and_std(v): id="last", marks=[ pytest.mark.notimpl( - ["druid", "impala", "mssql", "mysql", "oracle", "flink"], + ["druid", "impala", "mssql", "mysql", "oracle"], raises=com.OperationNotDefinedError, ), - pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - ), + pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError), ], ), param( @@ -1048,7 +1043,7 @@ def test_quantile( reason="backend only implements population correlation coefficient", ), pytest.mark.notyet( - ["impala", "mysql", "sqlite"], + ["impala", "mysql", "sqlite", "flink"], raises=com.OperationNotDefinedError, ), pytest.mark.notyet( @@ -1318,11 +1313,7 @@ def test_date_quantile(alltypes, func): reason="ORA-00904: 'GROUP_CONCAT': invalid identifier", ) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) -@pytest.mark.notyet( - ["flink"], - raises=Py4JError, - reason='SQL parse failed. 
Encountered "group_concat ("', -) +@pytest.mark.notyet(["flink"], raises=Py4JJavaError) def test_group_concat( backend, alltypes, df, ibis_cond, pandas_cond, ibis_sep, pandas_sep ): @@ -1573,7 +1564,7 @@ def test_grouped_case(backend, con): @pytest.mark.notimpl(["datafusion", "polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) -@pytest.mark.notyet(["flink"], raises=com.UnsupportedOperationError) +@pytest.mark.notyet(["flink"], raises=Py4JJavaError) @pytest.mark.notyet(["impala"], raises=ImpalaHiveServer2Error) @pytest.mark.notyet(["clickhouse"], raises=ClickHouseDatabaseError) @pytest.mark.notyet(["druid"], raises=PyDruidProgrammingError) diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 2d203c2c9f926..0bfa55e1d6e98 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -116,8 +116,6 @@ def test_array_repeat(con): assert np.array_equal(result, expected) -# Issues #2370 -@pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError) def test_array_concat(con): left = ibis.literal([1, 2, 3]) right = ibis.literal([2, 1]) @@ -126,8 +124,6 @@ def test_array_concat(con): assert sorted(result) == sorted([1, 2, 3, 2, 1]) -# Issues #2370 -@pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError) def test_array_concat_variadic(con): left = ibis.literal([1, 2, 3]) right = ibis.literal([2, 1]) @@ -138,7 +134,7 @@ def test_array_concat_variadic(con): # Issues #2370 -@pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["flink"], raises=Py4JJavaError) @pytest.mark.notyet(["trino"], raises=TrinoUserError) def test_array_concat_some_empty(con): left = ibis.literal([]) @@ -149,7 +145,6 @@ def test_array_concat_some_empty(con): assert np.array_equal(result, expected) -@pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError) def test_array_radd_concat(con): left = [1] right = ibis.literal([2]) @@ -250,7 +245,7 @@ def test_array_discovery(backend): reason="BigQuery doesn't support casting array to array", raises=GoogleBadRequest, ) -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError) def test_unnest_simple(backend): array_types = backend.array_types expected = ( @@ -361,7 +356,7 @@ def test_unnest_no_nulls(backend): raises=ValueError, reason="all the input arrays must have same number of dimensions", ) -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError) def test_unnest_default_name(backend): array_types = backend.array_types df = array_types.execute() @@ -547,11 +542,28 @@ def test_array_filter(con, input, output): @builtin_array @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) -def test_array_contains(backend, con): +@pytest.mark.parametrize( + ("col", "value"), + [ + param( + "x", + 1, + marks=[ + pytest.mark.broken( + ["flink"], + raises=Py4JJavaError, + reason="unknown; NPE during execution", + ) + ], + ), + ("y", "a"), + ], +) +def test_array_contains(backend, con, col, value): t = backend.array_types - expr = t.x.contains(1) + expr = t[col].contains(value) result = con.execute(expr) - expected = t.x.execute().map(lambda lst: 1 in lst) + expected = t[col].execute().map(lambda lst: value in lst) assert frozenset(result.values) == frozenset(expected.values) @@ -643,9 +655,10 @@ def 
test_array_remove(con, a): reason="bigquery doesn't support null elements in arrays", ) @pytest.mark.broken( - ["risingwave"], - raises=AssertionError, - reason="TODO(Kexiang): seems a bug", + ["risingwave"], raises=AssertionError, reason="TODO(Kexiang): seems a bug" +) +@pytest.mark.notyet( + ["flink"], raises=Py4JJavaError, reason="empty arrays not supported" ) @pytest.mark.parametrize( ("input", "expected"), @@ -662,9 +675,6 @@ def test_array_remove(con, a): ), ], ) -@pytest.mark.notimpl( - ["flink"], raises=NotImplementedError, reason="`from_ibis()` is not implemented" -) def test_array_unique(con, input, expected): t = ibis.memtable(input) expr = t.a.unique() @@ -789,7 +799,7 @@ def test_array_intersect(con, data): ) @pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError) @pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError) -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["trino"], reason="inserting maps into structs doesn't work", raises=TrinoUserError ) @@ -815,6 +825,7 @@ def test_unnest_struct(con): "polars", "postgres", "risingwave", + "flink", ], raises=com.OperationNotDefinedError, ) @@ -844,7 +855,7 @@ def test_zip(backend): ) @pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError) @pytest.mark.notimpl(["risingwave"], raises=PsycoPg2ProgrammingError) -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["polars"], raises=com.OperationNotDefinedError, diff --git a/ibis/backends/tests/test_asof_join.py b/ibis/backends/tests/test_asof_join.py index 2a1901efc520b..5f2d1ac067c26 100644 --- a/ibis/backends/tests/test_asof_join.py +++ b/ibis/backends/tests/test_asof_join.py @@ -96,6 +96,7 @@ def time_keyed_right(time_keyed_df2): "mssql", "sqlite", "risingwave", + "flink", ] ) def test_asof_join(con, time_left, time_right, time_df1, time_df2, direction, op): @@ -137,6 +138,7 @@ def test_asof_join(con, time_left, time_right, time_df1, time_df2, direction, op "mssql", "sqlite", "risingwave", + "flink", ] ) def test_keyed_asof_join_with_tolerance( diff --git a/ibis/backends/tests/test_benchmarks.py b/ibis/backends/tests/test_benchmarks.py index 614f75f419392..5e1295e7b74e2 100644 --- a/ibis/backends/tests/test_benchmarks.py +++ b/ibis/backends/tests/test_benchmarks.py @@ -154,9 +154,7 @@ def test_builtins(benchmark, expr_fn, builtin, t, base, large_expr): benchmark(builtin, expr) -_backends = set(_get_backend_names()) -# compile is a no-op -_backends.remove("pandas") +_backends = _get_backend_names(exclude=("pandas",)) _XFAIL_COMPILE_BACKENDS = ("dask", "polars") diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 3339a2c250505..6728cb6633fa8 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -30,7 +30,6 @@ OracleDatabaseError, PsycoPg2InternalError, PsycoPg2UndefinedObject, - Py4JJavaError, PyODBCProgrammingError, SnowflakeProgrammingError, TrinoUserError, @@ -49,6 +48,10 @@ def new_schema(): def _create_temp_table_with_schema(backend, con, temp_table_name, schema, data=None): if con.name == "druid": pytest.xfail("druid doesn't implement create_table") + elif con.name == "flink": + pytest.xfail( + "flink doesn't implement create_table from schema without additional arguments" + ) temporary = con.create_table(temp_table_name, 
schema=schema) assert temporary.to_pandas().empty @@ -130,12 +133,8 @@ def test_create_table(backend, con, temp_table, func, sch): param( False, True, - marks=[ - pytest.mark.notyet( - ["polars"], raises=com.IbisError, reason="all tables are ephemeral" - ) - ], id="no temp, overwrite", + marks=pytest.mark.notyet(["flink", "polars"]), ), param( True, @@ -182,6 +181,7 @@ def test_create_table_overwrite_temp(backend, con, temp_table, temp, overwrite): ids=["dataframe", "pyarrow table"], ) @pytest.mark.notyet(["druid"], raises=NotImplementedError) +@pytest.mark.notyet(["flink"], raises=com.IbisError) def test_load_data(backend, con, temp_table, lamduh): sch = ibis.schema( [ @@ -221,14 +221,6 @@ def test_load_data(backend, con, temp_table, lamduh): ), ], ) -@pytest.mark.broken( - ["flink"], - raises=Py4JJavaError, - reason=( - "org.apache.flink.table.api.ValidationException: " - "Table `default_catalog`.`default_database`.`functional_alltypes` was not found." - ), -) def test_query_schema(ddl_backend, expr_fn, expected): expr = expr_fn(ddl_backend.functional_alltypes) @@ -250,9 +242,6 @@ def test_query_schema(ddl_backend, expr_fn, expected): @pytest.mark.notimpl(["datafusion", "mssql"]) @pytest.mark.never(["dask", "pandas"], reason="dask and pandas do not support SQL") -@pytest.mark.notimpl( - ["flink"], raises=AttributeError, reason="'Backend' object has no attribute 'sql'" -) def test_sql(backend, con): # execute the expression using SQL query table = backend.format_table("functional_alltypes") @@ -316,6 +305,11 @@ def test_create_table_from_schema(con, new_schema, temp_table): raises=PsycoPg2InternalError, reason="truncate not supported upstream", ) +@pytest.mark.notimpl( + ["flink"], + raises=com.IbisError, + reason="`tbl_properties` is required when creating table with schema", +) def test_create_temporary_table_from_schema(con_no_data, new_schema): temp_table = gen_name(f"test_{con_no_data.name}_tmp") table = con_no_data.create_table(temp_table, schema=new_schema, temp=True) @@ -395,16 +389,6 @@ def test_nullable_input_output(con, temp_table): @mark.notimpl(["druid", "polars"]) -@pytest.mark.broken( - ["flink"], - raises=ValueError, - reason=( - "table `FUNCTIONAL_ALLTYPES` does not exist" - "Note (mehmet): Not raised when only this test function is executed, " - "but can be reproduced by running all the test functions in this file." - "TODO (mehmet): Caused by the test execution order?" 
- ), -) def test_create_drop_view(ddl_con, temp_view): # setup table_name = "functional_alltypes" @@ -609,6 +593,11 @@ def _emp(a, b, c, d): reason="`insert` method not implemented", ) @pytest.mark.notyet(["druid"], raises=NotImplementedError) +@pytest.mark.notimpl( + ["flink"], + raises=com.IbisError, + reason="`tbl_properties` is required when creating table with schema", +) def test_insert_from_memtable(con, temp_table): df = pd.DataFrame({"x": range(3)}) table_name = temp_table @@ -644,6 +633,7 @@ def test_list_databases(con): "pyspark": set(), "sqlite": {"main"}, "trino": {"memory"}, + "flink": set(), } result = set(con.list_databases()) assert test_databases[con.name] <= result @@ -672,6 +662,11 @@ def test_list_databases(con): raises=PsycoPg2InternalError, reason="unsigned integers are not supported", ) +@pytest.mark.notimpl( + ["flink"], + raises=com.IbisError, + reason="`tbl_properties` is required when creating table with schema", +) def test_unsigned_integer_type(con, temp_table): con.create_table( temp_table, diff --git a/ibis/backends/tests/test_dot_sql.py b/ibis/backends/tests/test_dot_sql.py index 26d3936e47561..39c407beb9599 100644 --- a/ibis/backends/tests/test_dot_sql.py +++ b/ibis/backends/tests/test_dot_sql.py @@ -12,7 +12,7 @@ from ibis import _ from ibis.backends.base import _get_backend_names -# here to load the dialect in to sqlglot so we can use it for transpilation +# import here to load the dialect in to sqlglot so we can use it for transpilation from ibis.backends.base.sqlglot.dialects import ( # noqa: F401 MSSQL, DataFusion, @@ -25,12 +25,12 @@ RisingWave, ) from ibis.backends.tests.errors import ( + ExaQueryError, GoogleBadRequest, OracleDatabaseError, PolarsComputeError, ) -dot_sql_notimpl = pytest.mark.notimpl(["flink"]) dot_sql_never = pytest.mark.never( ["dask", "pandas"], reason="dask and pandas do not accept SQL" ) @@ -41,7 +41,6 @@ } -@pytest.mark.notimpl(["flink"]) @pytest.mark.notyet(["oracle"], reason="table quoting behavior") @dot_sql_never @pytest.mark.parametrize( @@ -94,7 +93,6 @@ def test_con_dot_sql(backend, con, schema): @pytest.mark.notyet( ["druid"], raises=com.IbisTypeError, reason="druid does not preserve case" ) -@dot_sql_notimpl @dot_sql_never def test_table_dot_sql(backend): alltypes = backend.functional_alltypes @@ -142,7 +140,6 @@ def test_table_dot_sql(backend): OracleDatabaseError, reason="oracle doesn't know which of the tables in the join to sort from", ) -@dot_sql_notimpl @dot_sql_never def test_table_dot_sql_with_join(backend): alltypes = backend.functional_alltypes @@ -194,7 +191,6 @@ def test_table_dot_sql_with_join(backend): @pytest.mark.notyet( ["bigquery"], raises=GoogleBadRequest, reason="requires a qualified name" ) -@dot_sql_notimpl @dot_sql_never def test_table_dot_sql_repr(backend): alltypes = backend.functional_alltypes @@ -220,7 +216,6 @@ def test_table_dot_sql_repr(backend): assert repr(t) -@dot_sql_notimpl @dot_sql_never def test_dot_sql_alias_with_params(backend, alltypes, df): t = alltypes @@ -230,7 +225,6 @@ def test_dot_sql_alias_with_params(backend, alltypes, df): backend.assert_series_equal(result.x, expected) -@dot_sql_notimpl @dot_sql_never def test_dot_sql_reuse_alias_with_different_types(backend, alltypes, df): foo1 = alltypes.select(x=alltypes.string_col).alias("foo") @@ -241,21 +235,15 @@ def test_dot_sql_reuse_alias_with_different_types(backend, alltypes, df): backend.assert_series_equal(foo2.x.execute(), expected2) -_NO_SQLGLOT_DIALECT = {"pandas", "dask", "druid", "flink"} -no_sqlglot_dialect = 
sorted( - # TODO(cpcloud): remove the strict=False hack once backends are ported to - # sqlglot - param(backend, marks=pytest.mark.xfail(strict=False)) - for backend in _NO_SQLGLOT_DIALECT -) +_NO_SQLGLOT_DIALECT = ("pandas", "dask") +no_sqlglot_dialect = [ + param(dialect, marks=pytest.mark.xfail) for dialect in sorted(_NO_SQLGLOT_DIALECT) +] +dialects = sorted(_get_backend_names(exclude=_NO_SQLGLOT_DIALECT)) + no_sqlglot_dialect -@pytest.mark.parametrize( - "dialect", - [*sorted(_get_backend_names() - _NO_SQLGLOT_DIALECT), *no_sqlglot_dialect], -) +@pytest.mark.parametrize("dialect", dialects) @pytest.mark.notyet(["polars"], raises=PolarsComputeError) -@dot_sql_notimpl @dot_sql_never @pytest.mark.notyet(["druid"], reason="druid doesn't respect column name case") def test_table_dot_sql_transpile(backend, alltypes, dialect, df): @@ -269,18 +257,11 @@ def test_table_dot_sql_transpile(backend, alltypes, dialect, df): backend.assert_series_equal(result.x, expected) -@pytest.mark.parametrize( - "dialect", - [ - *sorted(_get_backend_names() - {"pyspark", *_NO_SQLGLOT_DIALECT}), - *no_sqlglot_dialect, - ], -) +@pytest.mark.parametrize("dialect", dialects) @pytest.mark.notyet( ["druid"], raises=AttributeError, reason="druid doesn't respect column names" ) @pytest.mark.notyet(["bigquery"]) -@dot_sql_notimpl @dot_sql_never def test_con_dot_sql_transpile(backend, con, dialect, df): t = sg.table("functional_alltypes", quoted=True) @@ -294,9 +275,13 @@ def test_con_dot_sql_transpile(backend, con, dialect, df): backend.assert_series_equal(result.x, expected) -@dot_sql_notimpl @dot_sql_never -@pytest.mark.notimpl(["druid", "flink", "polars", "exasol"]) +@pytest.mark.notimpl(["druid", "polars"]) +@pytest.mark.notimpl( + ["exasol"], + raises=ExaQueryError, + reason="loading the test data is a pain because of embedded commas", +) def test_order_by_no_projection(backend): con = backend.connection expr = ( @@ -309,7 +294,6 @@ def test_order_by_no_projection(backend): assert set(result) == {"Ross, Jerry L.", "Chang-Diaz, Franklin R."} -@dot_sql_notimpl @dot_sql_never @pytest.mark.notyet(["polars"], raises=PolarsComputeError) def test_dot_sql_limit(con): @@ -328,15 +312,19 @@ def mem_t(con): pytest.xfail("druid does not support create_table") name = ibis.util.gen_name(con.name) - con.create_table(name, ibis.memtable({"a": list("def")})) + + # flink only supports memtables if `temp` is True, seems like we should + # address that for users + con.create_table( + name, ibis.memtable({"a": list("def")}), temp=con.name == "flink" or None + ) yield name with contextlib.suppress(NotImplementedError): con.drop_table(name, force=True) -@dot_sql_notimpl @dot_sql_never -@pytest.mark.notyet(["polars"], raises=PolarsComputeError) +@pytest.mark.notyet(["polars"], raises=NotImplementedError) def test_cte(con, mem_t): t = con.table(mem_t) foo = t.alias("foo") diff --git a/ibis/backends/tests/test_examples.py b/ibis/backends/tests/test_examples.py index 845a20e3b3df5..cfcd893cfb1de 100644 --- a/ibis/backends/tests/test_examples.py +++ b/ibis/backends/tests/test_examples.py @@ -15,7 +15,7 @@ (LINUX or MACOS) and SANDBOXED, reason="nix on linux cannot download duckdb extensions or data due to sandboxing", ) -@pytest.mark.notimpl(["pyspark", "flink", "exasol"]) +@pytest.mark.notimpl(["pyspark", "exasol"]) @pytest.mark.notyet(["clickhouse", "druid", "impala", "mssql", "trino", "risingwave"]) @pytest.mark.parametrize( ("example", "columns"), diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index 
02eefb296c3a8..03ee2ccebe01f 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -29,16 +29,8 @@ param( 42, id="limit", - marks=[ - pytest.mark.notimpl( - [ - # limit not implemented for flink and pandas backend execution - "dask", - "pandas", - "flink", - ] - ), - ], + # limit not implemented for pandas-family backends + marks=[pytest.mark.notimpl(["dask", "pandas"])], ), ] @@ -301,7 +293,6 @@ def test_memtable_to_file(tmp_path, con, ftype, monkeypatch): assert outfile.is_file() -@pytest.mark.notimpl(["flink"]) def test_table_to_csv(tmp_path, backend, awards_players): outcsv = tmp_path / "out.csv" @@ -315,7 +306,6 @@ def test_table_to_csv(tmp_path, backend, awards_players): backend.assert_frame_equal(awards_players.to_pandas(), df) -@pytest.mark.notimpl(["flink"]) @pytest.mark.notimpl( ["duckdb"], reason="cannot inline WriteOptions objects", @@ -339,10 +329,7 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): dt.Decimal(38, 9), pa.Decimal128Type, id="decimal128", - marks=[ - pytest.mark.notyet(["flink"], raises=NotImplementedError), - pytest.mark.notyet(["exasol"], raises=ExaQueryError), - ], + marks=[pytest.mark.notyet(["exasol"], raises=ExaQueryError)], ), param( dt.Decimal(76, 38), @@ -361,7 +348,6 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): raises=(PySparkParseException, PySparkArithmeticException), reason="precision is out of range", ), - pytest.mark.notyet(["flink"], raises=NotImplementedError), pytest.mark.notyet(["exasol"], raises=ExaQueryError), ], ), @@ -480,14 +466,7 @@ def test_to_pandas_batches_empty_table(backend, con): assert sum(map(len, t.to_pandas_batches())) == n -@pytest.mark.notimpl(["flink"]) -@pytest.mark.parametrize( - "n", - [ - None, - 1, - ], -) +@pytest.mark.parametrize("n", [None, 1]) def test_to_pandas_batches_nonempty_table(backend, con, n): t = backend.functional_alltypes.limit(n) n = t.count().execute() @@ -496,16 +475,7 @@ def test_to_pandas_batches_nonempty_table(backend, con, n): assert sum(map(len, t.to_pandas_batches())) == n -@pytest.mark.notimpl(["flink"]) -@pytest.mark.parametrize( - "n", - [ - None, - 0, - 1, - 2, - ], -) +@pytest.mark.parametrize("n", [None, 0, 1, 2]) def test_to_pandas_batches_column(backend, con, n): t = backend.functional_alltypes.limit(n).timestamp_col n = t.count().execute() diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index ab353c2acfa27..ca876e4a30401 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -26,7 +26,6 @@ MySQLProgrammingError, OracleDatabaseError, PsycoPg2InternalError, - Py4JJavaError, PyDruidProgrammingError, PyODBCDataError, PyODBCProgrammingError, @@ -120,6 +119,11 @@ def test_scalar_fillna_nullif(con, expr, expected): raises=ExaQueryError, reason="no way to test for nan-ness", ), + pytest.mark.notyet( + ["flink"], + "NaN is not supported in Flink SQL", + raises=NotImplementedError, + ), ], id="nan_col", ), @@ -128,7 +132,6 @@ def test_scalar_fillna_nullif(con, expr, expected): ), ], ) -@pytest.mark.notyet(["flink"], "NaN is not supported in Flink SQL", raises=ValueError) def test_isna(backend, alltypes, col, value, filt): table = alltypes.select(**{col: value}) df = table.execute() @@ -168,7 +171,9 @@ def test_isna(backend, alltypes, col, value, filt): reason="NaN != NULL for these backends", ), pytest.mark.notyet( - ["flink"], "NaN is not supported in Flink SQL", raises=ValueError + ["flink"], + "NaN is not supported in 
Flink SQL", + raises=NotImplementedError, ), ], id="nan_col", @@ -373,7 +378,9 @@ def test_case_where(backend, alltypes, df): # TODO: some of these are notimpl (datafusion) others are probably never @pytest.mark.notimpl(["mysql", "sqlite", "mssql", "druid", "exasol"]) -@pytest.mark.notyet(["flink"], "NaN is not supported in Flink SQL", raises=ValueError) +@pytest.mark.notyet( + ["flink"], "NaN is not supported in Flink SQL", raises=NotImplementedError +) def test_select_filter_mutate(backend, alltypes, df): """Test that select, filter and mutate are executed in right order. @@ -484,14 +491,15 @@ def test_dropna_invalid(alltypes): @pytest.mark.parametrize( "subset", [ - None, + param(None, id="none"), param( [], marks=pytest.mark.notimpl(["exasol"], raises=ExaQueryError, strict=False), + id="empty", ), - "col_1", - ["col_1", "col_2"], - ["col_1", "col_3"], + param("col_1", id="single"), + param(["col_1", "col_2"], id="one-and-two"), + param(["col_1", "col_3"], id="one-and-three"), ], ) def test_dropna_table(backend, alltypes, how, subset): @@ -748,11 +756,6 @@ def test_between(backend, alltypes, df): @pytest.mark.notimpl(["druid"]) -@pytest.mark.notimpl( - ["flink"], - raises=Py4JJavaError, - reason="Flink does not support now() - t.`timestamp_col`", -) def test_interactive(alltypes, monkeypatch): monkeypatch.setattr(ibis.options, "interactive", True) @@ -995,7 +998,6 @@ def test_memtable_column_naming_mismatch(backend, con, monkeypatch, df, columns) @pytest.mark.notimpl( ["dask", "pandas", "polars"], raises=NotImplementedError, reason="not a SQL backend" ) -@pytest.mark.notimpl(["flink"], reason="no sqlglot dialect", raises=ValueError) def test_many_subqueries(con, snapshot): def query(t, group_cols): t2 = t.mutate(key=ibis.row_number().over(ibis.window(order_by=group_cols))) @@ -1390,7 +1392,7 @@ def test_try_cast(con, from_val, to_type, expected): "int", marks=[ pytest.mark.never( - ["clickhouse", "pyspark"], reason="casts to 1672531200" + ["clickhouse", "pyspark", "flink"], reason="casts to 1672531200" ), pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest), pytest.mark.notyet(["trino"], raises=TrinoUserError), @@ -1750,7 +1752,7 @@ def test_dynamic_table_slice_with_computed_offset(backend): backend.assert_frame_equal(result, expected) -@pytest.mark.notimpl(["druid", "flink", "polars", "snowflake"]) +@pytest.mark.notimpl(["druid", "polars", "snowflake"]) @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, @@ -1771,7 +1773,7 @@ def test_sample(backend): backend.assert_frame_equal(empty, df.iloc[:0]) -@pytest.mark.notimpl(["druid", "flink", "polars", "snowflake"]) +@pytest.mark.notimpl(["druid", "polars", "snowflake"]) @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, @@ -1829,7 +1831,6 @@ def test_substitute(backend): @pytest.mark.notimpl( ["dask", "pandas", "polars"], raises=NotImplementedError, reason="not a SQL backend" ) -@pytest.mark.notimpl(["flink"], reason="no sqlglot dialect", raises=ValueError) def test_simple_memtable_construct(con): t = ibis.memtable({"a": [1, 2]}) expr = t.a diff --git a/ibis/backends/tests/test_join.py b/ibis/backends/tests/test_join.py index fa36dea7ab0ce..14d3263a7a2b9 100644 --- a/ibis/backends/tests/test_join.py +++ b/ibis/backends/tests/test_join.py @@ -289,7 +289,6 @@ def test_join_with_trivial_predicate(awards_players, predicate, how, pandas_valu @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) -@pytest.mark.notimpl(["flink"], reason="`win` table isn't loaded") @pytest.mark.parametrize( ("how", 
"nrows", "gen_right", "keys"), [ diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index 441332cc8f6d1..4efc9fd0527c1 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -207,11 +207,6 @@ def test_literal_map_merge(con): assert con.execute(expr) == {"a": 1, "b": 2, "c": 3} -@pytest.mark.notimpl( - ["flink"], - raises=NotImplementedError, - reason="No translation rule for map", -) @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, @@ -229,11 +224,6 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl( - ["flink"], - raises=NotImplementedError, - reason="No translation rule for map", -) @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, @@ -296,11 +286,6 @@ def test_map_construct_array_column(con, alltypes, df): @pytest.mark.notyet( ["postgres", "risingwave"], reason="only support maps of string -> string" ) -@pytest.mark.notimpl( - ["flink"], - raises=NotImplementedError, - reason="No translation rule for map", -) def test_map_get_with_compatible_value_smaller(con): value = ibis.literal({"A": 1000, "B": 2000}) expr = value.get("C", 3) @@ -310,11 +295,6 @@ def test_map_get_with_compatible_value_smaller(con): @pytest.mark.notyet( ["postgres", "risingwave"], reason="only support maps of string -> string" ) -@pytest.mark.notimpl( - ["flink"], - raises=NotImplementedError, - reason="No translation rule for map", -) def test_map_get_with_compatible_value_bigger(con): value = ibis.literal({"A": 1, "B": 2}) expr = value.get("C", 3000) @@ -324,11 +304,6 @@ def test_map_get_with_compatible_value_bigger(con): @pytest.mark.notyet( ["postgres", "risingwave"], reason="only support maps of string -> string" ) -@pytest.mark.notimpl( - ["flink"], - raises=NotImplementedError, - reason="NotImplementedError: No translation rule for map", -) def test_map_get_with_incompatible_value_different_kind(con): value = ibis.literal({"A": 1000, "B": 2000}) expr = value.get("C", 3.0) @@ -339,11 +314,6 @@ def test_map_get_with_incompatible_value_different_kind(con): @pytest.mark.notyet( ["postgres", "risingwave"], reason="only support maps of string -> string" ) -@pytest.mark.notimpl( - ["flink"], - raises=NotImplementedError, - reason="No translation rule for map", -) def test_map_get_with_null_on_not_nullable(con, null_value): map_type = dt.Map(dt.string, dt.Int16(nullable=False)) value = ibis.literal({"A": 1000, "B": 2000}).cast(map_type) @@ -353,10 +323,8 @@ def test_map_get_with_null_on_not_nullable(con, null_value): @pytest.mark.parametrize("null_value", [None, ibis.NA]) -@pytest.mark.notimpl( - ["flink"], - raises=NotImplementedError, - reason="No translation rule for map", +@pytest.mark.notyet( + ["flink"], raises=Py4JJavaError, reason="Flink cannot handle typeless nulls" ) @pytest.mark.notimpl( ["risingwave"], @@ -373,10 +341,8 @@ def test_map_get_with_null_on_null_type_with_null(con, null_value): @pytest.mark.notyet( ["postgres", "risingwave"], reason="only support maps of string -> string" ) -@pytest.mark.notimpl( - ["flink"], - raises=NotImplementedError, - reason="No translation rule for map", +@pytest.mark.notyet( + ["flink"], raises=Py4JJavaError, reason="Flink cannot handle typeless nulls" ) def test_map_get_with_null_on_null_type_with_non_null(con): value = ibis.literal({"A": None, "B": None}) diff --git a/ibis/backends/tests/test_network.py b/ibis/backends/tests/test_network.py index dca5815c68554..0947ecf6f4d20 100644 --- 
a/ibis/backends/tests/test_network.py +++ b/ibis/backends/tests/test_network.py @@ -25,7 +25,7 @@ } -@pytest.mark.notimpl(["flink", "polars"], raises=NotImplementedError) +@pytest.mark.notimpl(["polars"], raises=NotImplementedError) def test_macaddr_literal(con, backend): test_macaddr = "00:00:0A:BB:28:FC" expr = ibis.literal(test_macaddr, type=dt.macaddr) @@ -110,7 +110,7 @@ def test_macaddr_literal(con, backend): ), ], ) -@pytest.mark.notimpl(["flink", "polars"], raises=NotImplementedError) +@pytest.mark.notimpl(["polars"], raises=NotImplementedError) @pytest.mark.notimpl(["druid", "oracle"], raises=KeyError) @pytest.mark.notimpl(["exasol"], raises=(ExaQueryError, KeyError)) def test_inet_literal(con, backend, test_value, expected_values, expected_types): diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index 79989561ae883..3dc233ffdcd13 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -26,6 +26,7 @@ PsycoPg2DivisionByZero, PsycoPg2InternalError, Py4JError, + Py4JJavaError, PyDruidProgrammingError, PyODBCDataError, PyODBCProgrammingError, @@ -52,7 +53,7 @@ "duckdb": "TINYINT", "postgres": "integer", "risingwave": "integer", - "flink": "TINYINT NOT NULL", + "flink": "INT NOT NULL", }, id="int8", ), @@ -68,7 +69,7 @@ "duckdb": "SMALLINT", "postgres": "integer", "risingwave": "integer", - "flink": "SMALLINT NOT NULL", + "flink": "INT NOT NULL", }, id="int16", ), @@ -100,7 +101,7 @@ "duckdb": "BIGINT", "postgres": "integer", "risingwave": "integer", - "flink": "BIGINT NOT NULL", + "flink": "INT NOT NULL", }, id="int64", ), @@ -116,7 +117,7 @@ "duckdb": "UTINYINT", "postgres": "integer", "risingwave": "integer", - "flink": "TINYINT NOT NULL", + "flink": "INT NOT NULL", }, id="uint8", ), @@ -132,7 +133,7 @@ "duckdb": "USMALLINT", "postgres": "integer", "risingwave": "integer", - "flink": "SMALLINT NOT NULL", + "flink": "INT NOT NULL", }, id="uint16", ), @@ -164,7 +165,7 @@ "duckdb": "UBIGINT", "postgres": "integer", "risingwave": "integer", - "flink": "BIGINT NOT NULL", + "flink": "INT NOT NULL", }, id="uint64", ), @@ -180,7 +181,7 @@ "duckdb": "FLOAT", "postgres": "numeric", "risingwave": "numeric", - "flink": "FLOAT NOT NULL", + "flink": "DECIMAL(2, 1) NOT NULL", }, marks=[ pytest.mark.notimpl( @@ -208,7 +209,7 @@ "duckdb": "FLOAT", "postgres": "numeric", "risingwave": "numeric", - "flink": "FLOAT NOT NULL", + "flink": "DECIMAL(2, 1) NOT NULL", }, id="float32", ), @@ -224,7 +225,7 @@ "duckdb": "DOUBLE", "postgres": "numeric", "risingwave": "numeric", - "flink": "DOUBLE NOT NULL", + "flink": "DECIMAL(2, 1) NOT NULL", }, id="float64", ), @@ -373,7 +374,7 @@ def test_numeric_literal(con, backend, expr, expected_types): pytest.mark.notyet( ["flink"], "The precision can be up to 38 in Flink", - raises=ValueError, + raises=Py4JJavaError, ), pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError), ], @@ -431,7 +432,7 @@ def test_numeric_literal(con, backend, expr, expected_types): pytest.mark.notyet( ["flink"], "Infinity is not supported in Flink SQL", - raises=ValueError, + raises=Py4JJavaError, ), pytest.mark.notyet( ["snowflake"], @@ -490,7 +491,7 @@ def test_numeric_literal(con, backend, expr, expected_types): pytest.mark.notyet( ["flink"], "Infinity is not supported in Flink SQL", - raises=ValueError, + raises=Py4JJavaError, ), pytest.mark.notyet( ["snowflake"], @@ -557,7 +558,7 @@ def test_numeric_literal(con, backend, expr, expected_types): pytest.mark.notyet( ["flink"], "NaN is not supported in Flink 
SQL", - raises=ValueError, + raises=Py4JJavaError, ), pytest.mark.notyet( ["snowflake"], @@ -659,8 +660,9 @@ def test_decimal_literal(con, backend, expr, expected_types, expected_result): ), ], ) +@pytest.mark.notimpl(["sqlite", "mssql", "druid"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( - ["sqlite", "mssql", "flink", "druid"], raises=com.OperationNotDefinedError + ["flink"], raises=(com.OperationNotDefinedError, NotImplementedError) ) @pytest.mark.notimpl(["mysql"], raises=(MySQLOperationalError, NotImplementedError)) def test_isnan_isinf( @@ -716,16 +718,7 @@ def test_isnan_isinf( param(L(5.556).exp(), math.exp(5.556), id="exp"), param(L(5.556).sign(), 1, id="sign-pos"), param(L(-5.556).sign(), -1, id="sign-neg"), - param( - L(0).sign(), - 0, - id="sign-zero", - marks=pytest.mark.broken( - ["flink"], - "An error occurred while calling z:org.apache.flink.table.runtime.arrow.ArrowUtils.collectAsPandasDataFrame.", - raises=Py4JError, - ), - ), + param(L(0).sign(), 0, id="sign-zero"), param(L(5.556).sqrt(), math.sqrt(5.556), id="sqrt"), param( L(5.556).log(2), @@ -1130,7 +1123,7 @@ def test_floating_mod(backend, alltypes, df): ), pytest.mark.notyet( "flink", - raises=Py4JError, + raises=Py4JJavaError, reason="Flink doesn't do integer division by zero", ), ], @@ -1146,7 +1139,7 @@ def test_floating_mod(backend, alltypes, df): ), pytest.mark.notyet( "flink", - raises=Py4JError, + raises=Py4JJavaError, reason="Flink doesn't do integer division by zero", ), ], @@ -1162,7 +1155,7 @@ def test_floating_mod(backend, alltypes, df): ), pytest.mark.notyet( "flink", - raises=Py4JError, + raises=Py4JJavaError, reason="Flink doesn't do integer division by zero", ), ], @@ -1178,7 +1171,7 @@ def test_floating_mod(backend, alltypes, df): ), pytest.mark.notyet( "flink", - raises=Py4JError, + raises=Py4JJavaError, reason="Flink doesn't do integer division by zero", ), ], @@ -1195,6 +1188,11 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet( + "flink", + raises=Py4JJavaError, + reason="Flink doesn't do integer division by zero", + ), ], ), param( @@ -1207,6 +1205,11 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet( + "flink", + raises=Py4JJavaError, + reason="Flink doesn't do integer division by zero", + ), ], ), param( @@ -1219,6 +1222,11 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet( + "flink", + raises=Py4JJavaError, + reason="Flink doesn't do integer division by zero", + ), ], ), param( @@ -1231,6 +1239,11 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet( + "flink", + raises=Py4JJavaError, + reason="Flink doesn't do integer division by zero", + ), ], ), param( @@ -1343,11 +1356,6 @@ def test_clip(backend, alltypes, df, ibis_func, pandas_func): raises=PyDruidProgrammingError, reason="SQL query requires 'MIN' operator that is not supported.", ) -@pytest.mark.never( - ["flink"], - raises=com.UnsupportedOperationError, - reason="Flink does not support 'MIN' or 'MAX' operation without specifying window.", -) def 
test_histogram(con, alltypes): n = 10 hist = con.execute(alltypes.int_col.histogram(n).name("hist")) diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py index e1f59c1fcbf90..3be9222cee342 100644 --- a/ibis/backends/tests/test_param.py +++ b/ibis/backends/tests/test_param.py @@ -11,11 +11,7 @@ import ibis import ibis.expr.datatypes as dt from ibis import _ -from ibis.backends.tests.errors import ( - OracleDatabaseError, - PsycoPg2InternalError, - Py4JJavaError, -) +from ibis.backends.tests.errors import OracleDatabaseError, PsycoPg2InternalError @pytest.mark.parametrize( @@ -103,15 +99,6 @@ def test_scalar_param_struct(con): reason="mysql and sqlite will never implement map types", ) @pytest.mark.notyet(["bigquery"]) -@pytest.mark.notimpl( - ["flink"], - "WIP", - raises=Py4JJavaError, - reason=( - "SqlParseException: Expecting alias, found character literal" - "sql= SELECT MAP_FROM_ARRAYS(ARRAY['a', 'b', 'c'], ARRAY['ghi', 'def', 'abc']) '[' 'b' ']' AS `MapGet(param_0, 'b', None)`" - ), -) @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index a9d526c30146f..a5edcf3c5db81 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -10,8 +10,6 @@ sg = pytest.importorskip("sqlglot") -pytestmark = pytest.mark.notimpl(["flink"]) - simple_literal = param(ibis.literal(1), id="simple_literal") array_literal = param( ibis.array([1]), diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 989d6dc393db6..a62d59734dbac 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -604,10 +604,7 @@ def uses_java_re(t): lambda t: t.date_string_col[-2], lambda t: t.date_string_col.str[-2], id="negative-index", - marks=[ - pytest.mark.broken(["druid"], raises=PyDruidProgrammingError), - pytest.mark.broken(["flink"], raises=AssertionError), - ], + marks=[pytest.mark.broken(["druid"], raises=PyDruidProgrammingError)], ), param( lambda t: t.date_string_col[t.date_string_col.length() - 1 :], diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 0987e991d4fc8..2cf24eb95ed4f 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -274,14 +274,6 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): raises=AssertionError, reason="numpy array are different", ), - pytest.mark.never( - ["flink"], - raises=Py4JJavaError, - reason=( - "CalciteContextException: No match found for function signature trunc(, )" - "Timestamp truncation is not supported in Flink" - ), - ), ], ), param( @@ -292,14 +284,6 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): raises=AssertionError, reason="numpy array are different", ), - pytest.mark.never( - ["flink"], - raises=Py4JJavaError, - reason=( - "CalciteContextException: No match found for function signature trunc(, )" - "Timestamp truncation is not supported in Flink" - ), - ), ], ), param( @@ -310,14 +294,6 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): raises=AssertionError, reason="numpy array are different", ), - pytest.mark.never( - ["flink"], - raises=Py4JJavaError, - reason=( - "CalciteContextException: No match found for function signature trunc(, )" - "Timestamp truncation is not supported in Flink" - ), - ), ], ), param( @@ -325,17 +301,9 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): marks=[ 
pytest.mark.notimpl(["mysql"], raises=com.UnsupportedOperationError), pytest.mark.broken( - ["polars"], + ["polars", "flink"], raises=AssertionError, - reason="numpy array are different", - ), - pytest.mark.never( - ["flink"], - raises=Py4JJavaError, - reason=( - "CalciteContextException: No match found for function signature trunc(, )" - "Timestamp truncation is not supported in Flink" - ), + reason="implemented, but doesn't match other backends", ), ], ), @@ -348,14 +316,6 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): raises=AssertionError, reason="numpy array are different", ), - pytest.mark.never( - ["flink"], - raises=Py4JJavaError, - reason=( - "CalciteContextException: No match found for function signature trunc(, )" - "Timestamp truncation is not supported in Flink" - ), - ), ], ), param( @@ -367,14 +327,6 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): raises=AssertionError, reason="numpy array are different", ), - pytest.mark.never( - ["flink"], - raises=Py4JJavaError, - reason=( - "CalciteContextException: No match found for function signature trunc(, )" - "Timestamp truncation is not supported in Flink" - ), - ), ], ), param( @@ -386,11 +338,6 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): raises=AssertionError, reason="numpy array are different", ), - pytest.mark.never( - ["flink"], - raises=com.UnsupportedOperationError, - reason=" unit is not supported in timestamp truncate", - ), ], ), param( @@ -405,11 +352,6 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): raises=AssertionError, reason="numpy array are different", ), - pytest.mark.notimpl( - ["flink"], - raises=com.UnsupportedOperationError, - reason=" unit is not supported in timestamp truncate", - ), ], ), param( @@ -424,10 +366,10 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): raises=AssertionError, reason="numpy array are different", ), - pytest.mark.notimpl( + pytest.mark.notyet( ["flink"], - raises=com.UnsupportedOperationError, - reason=" unit is not supported in timestamp truncate", + raises=Py4JJavaError, + reason="microseconds not supported in truncation", ), ], ), @@ -457,10 +399,10 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): raises=PolarsPanicException, reason="attempt to calculate the remainder with a divisor of zero", ), - pytest.mark.notimpl( + pytest.mark.notyet( ["flink"], - raises=com.UnsupportedOperationError, - reason=" unit is not supported in timestamp truncate", + raises=Py4JJavaError, + reason="nanoseconds not supported in truncation", ), ], ), @@ -491,61 +433,17 @@ def test_timestamp_truncate(backend, alltypes, df, unit): @pytest.mark.parametrize( "unit", [ - param( - "Y", - marks=[ - pytest.mark.never( - ["flink"], - raises=Py4JJavaError, - reason=( - "CalciteContextException: No match found for function signature trunc(, )" - "Timestamp truncation is not supported in Flink" - ), - ), - ], - ), - param( - "M", - marks=[ - pytest.mark.never( - ["flink"], - raises=Py4JJavaError, - reason=( - "CalciteContextException: No match found for function signature trunc(, )" - "Timestamp truncation is not supported in Flink" - ), - ), - ], - ), - param( - "D", - marks=[ - pytest.mark.never( - ["flink"], - raises=Py4JJavaError, - reason=( - "CalciteContextException: No match found for function signature trunc(, )" - "Timestamp truncation is not supported in Flink" - ), - ), - ], - ), + "Y", + "M", + "D", param( "W", marks=[ pytest.mark.notyet(["mysql"], 
raises=com.UnsupportedOperationError), - pytest.mark.never( - ["flink"], - raises=Py4JJavaError, - reason=( - "CalciteContextException: No match found for function signature trunc(, )" - "Timestamp truncation is not supported in Flink" - ), - ), pytest.mark.broken( - ["exasol"], + ["flink", "exasol"], raises=AssertionError, - reason="behavior is different than expected", + reason="Implemented, but behavior doesn't match other backends", ), ], ), @@ -984,9 +882,9 @@ def convert_to_offset(x): raises=AttributeError, reason="'StringColumn' object has no attribute 'date'", ), - pytest.mark.broken( + pytest.mark.notimpl( ["flink"], - raises=com.UnsupportedOperationError, + raises=com.OperationNotDefinedError, reason="DATE_DIFF is not supported in Flink", ), pytest.mark.broken( @@ -1207,26 +1105,7 @@ def test_temporal_binop_pandas_timedelta( backend.assert_series_equal(result, expected.astype(result.dtype)) -@pytest.mark.parametrize( - "func_name", - [ - "gt", - "ge", - "lt", - "le", - "eq", - param( - "ne", - marks=[ - pytest.mark.notimpl( - ["flink"], - raises=Py4JJavaError, - reason="SqlParseException: Bang equal '!=' is not allowed under the current SQL conformance level", - ), - ], - ), - ], -) +@pytest.mark.parametrize("func_name", ["gt", "ge", "lt", "le", "eq", "ne"]) @pytest.mark.notimpl( ["polars"], raises=TypeError, @@ -1316,16 +1195,7 @@ def test_timestamp_comparison_filter(backend, con, alltypes, df, func_name): ], ), "eq", - param( - "ne", - marks=[ - pytest.mark.notimpl( - ["flink"], - raises=Py4JJavaError, - reason="SqlParseException: Bang equal '!=' is not allowed under the current SQL conformance level", - ), - ], - ), + "ne", ], ) @pytest.mark.broken( @@ -1370,11 +1240,6 @@ def test_timestamp_comparison_filter_numpy(backend, con, alltypes, df, func_name raises=AttributeError, reason="'StringColumn' object has no attribute 'date'", ) -@pytest.mark.broken( - ["flink"], - raises=Py4JJavaError, - reason="ParseException: Encountered '+ INTERVAL CAST'", -) def test_interval_add_cast_scalar(backend, alltypes): timestamp_date = alltypes.timestamp_col.date() delta = ibis.literal(10).cast("interval('D')") @@ -1393,6 +1258,7 @@ def test_interval_add_cast_scalar(backend, alltypes): raises=AttributeError, reason="'StringColumn' object has no attribute 'date'", ) +@pytest.mark.broken(["flink"], raises=AssertionError, reason="incorrect results") def test_interval_add_cast_column(backend, alltypes, df): timestamp_date = alltypes.timestamp_col.date() delta = alltypes.bigint_col.cast("interval('D')") @@ -1438,10 +1304,6 @@ def test_interval_add_cast_column(backend, alltypes, df): raises=AttributeError, reason="'StringColumn' object has no attribute 'strftime'", ), - pytest.mark.notimpl( - ["flink"], - raises=AssertionError, - ), ], id="column_format_str", ), @@ -1638,11 +1500,6 @@ def test_string_to_timestamp(alltypes, fmt): ], ) @pytest.mark.notimpl(["druid", "oracle"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl( - ["flink"], - raises=Py4JJavaError, - reason="DayOfWeekName is not supported in Flink", -) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["risingwave"], @@ -1702,20 +1559,6 @@ def test_day_of_week_column(backend, alltypes, df): raises=AssertionError, reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", ), - pytest.mark.never( - ["flink"], - raises=Py4JJavaError, - reason=( - "SqlValidatorException: No match found for function signature dayname()" - "`day_of_week_name` is not supported in Flink" - "Ref: 
https://nightlies.apache.org/flink/flink-docs-release-1.13/docs/dev/table/functions/systemfunctions/#temporal-functions" - ), - ), - pytest.mark.broken( - ["risingwave"], - raises=AssertionError, - reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", - ), ], ), ], @@ -1959,13 +1802,10 @@ def test_time_literal(con, backend): reason="doesn't have enough precision to capture microseconds", ), pytest.mark.notyet(["trino"], raises=AssertionError), - pytest.mark.notimpl( + pytest.mark.notyet( ["flink"], raises=AssertionError, - reason=( - "Flink does not support microsecond precision in time." - "assert datetime.time(13, 20, 5) == datetime.time(13, 20, 5, 561021)" - ), + reason="flink doesn't preserve subsecond information", ), ], ), @@ -2176,7 +2016,6 @@ def test_integer_cast_to_timestamp_scalar(alltypes, df): ) @pytest.mark.broken( ["flink"], - reason="Casting from timestamp[s] to timestamp[ns] would result in out of bounds timestamp: 81953424000", raises=ArrowInvalid, ) @pytest.mark.notyet(["polars"], raises=PolarsComputeError) @@ -2272,15 +2111,10 @@ def test_large_timestamp(con): id="us", marks=[ pytest.mark.notyet( - ["sqlite"], + ["sqlite", "flink"], reason="doesn't support microseconds", raises=AssertionError, ), - pytest.mark.broken( - ["flink"], - reason="assert Timestamp('2023-01-07 13:20:05.561000') == Timestamp('2023-01-07 13:20:05.561021')", - raises=AssertionError, - ), pytest.mark.notyet( ["druid"], reason="time_parse truncates to milliseconds", @@ -2408,15 +2242,6 @@ def test_timestamp_precision_output(con, ts, scale, unit): raises=com.OperationNotDefinedError, reason="timestampdiff rounds after subtraction and mysql doesn't have a date_trunc function", ), - pytest.mark.broken( - ["flink"], - raises=AssertionError, - reason=( - "assert 1 == 2" - "Note (mehmet): Flink rounds the time difference down not up." - "Hence computes 1 hour difference between 23:59:59 and 01:58:00." 
- ), - ), ], ), ], @@ -2601,10 +2426,7 @@ def test_timestamp_bucket_offset(backend, offset_mins): ) @pytest.mark.parametrize( "dialect", - [ - *sorted(_get_backend_names() - {*_NO_SQLGLOT_DIALECT}), - *no_sqlglot_dialect, - ], + [*sorted(_get_backend_names(exclude=_NO_SQLGLOT_DIALECT)), *no_sqlglot_dialect], ) def test_temporal_literal_sql(value, dialect, snapshot): expr = ibis.literal(value) @@ -2621,8 +2443,15 @@ def test_temporal_literal_sql(value, dialect, snapshot): "dialect", [ *sorted( - _get_backend_names() - - {"pyspark", "impala", "clickhouse", "oracle", *_NO_SQLGLOT_DIALECT} + _get_backend_names( + exclude=( + "pyspark", + "impala", + "clickhouse", + "oracle", + *_NO_SQLGLOT_DIALECT, + ) + ) ), *no_sqlglot_dialect, ], diff --git a/ibis/backends/tests/test_timecontext.py b/ibis/backends/tests/test_timecontext.py index a974bc9ee2961..5b335fe2b8f3d 100644 --- a/ibis/backends/tests/test_timecontext.py +++ b/ibis/backends/tests/test_timecontext.py @@ -122,7 +122,7 @@ def test_context_adjustment_filter_before_window( @pytest.mark.notimpl(["duckdb"]) @pytest.mark.notimpl( ["flink"], - raises=com.UnsupportedOperationError, + raises=com.OperationNotDefinedError, reason="Flink engine does not support generic window clause with no order by", ) def test_context_adjustment_multi_col_udf_non_grouped( diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index 768872f76f468..3f8370ad24147 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -88,13 +88,6 @@ def calc_zscore(s): lambda t, win: t.float_col.lag().over(win), lambda t: t.float_col.shift(1), id="lag", - marks=[ - pytest.mark.notimpl( - ["flink"], - raises=Py4JJavaError, - reason="CalciteContextException: ROW/RANGE not allowed with RANK, DENSE_RANK or ROW_NUMBER functions", - ), - ], ), param( lambda t, win: t.float_col.lead().over(win), @@ -106,11 +99,6 @@ def calc_zscore(s): reason="upstream is broken; returns all nulls", raises=AssertionError, ), - pytest.mark.notimpl( - ["flink"], - raises=Py4JJavaError, - reason="CalciteContextException: ROW/RANGE not allowed with RANK, DENSE_RANK or ROW_NUMBER functions", - ), ], ), param( @@ -178,11 +166,6 @@ def calc_zscore(s): raises=AssertionError, reason="Results are shifted + 1", ), - pytest.mark.broken( - ["flink"], - raises=com.UnsupportedOperationError, - reason="Windows in Flink can only be ordered by a single time column", - ), pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, @@ -434,11 +417,6 @@ def test_ungrouped_bounded_expanding_window( ], ) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl( - ["flink"], - raises=com.UnsupportedOperationError, - reason="OVER RANGE FOLLOWING windows are not supported in Flink yet", -) def test_grouped_bounded_following_window(backend, alltypes, df, preceding, following): window = ibis.window( preceding=preceding, @@ -569,7 +547,6 @@ def test_grouped_bounded_preceding_window(backend, alltypes, df, window_fn): False, id="unordered", marks=[ - pytest.mark.notimpl(["flink"], raises=com.UnsupportedOperationError), pytest.mark.broken( ["mssql"], raises=PyODBCProgrammingError, @@ -621,7 +598,6 @@ def test_grouped_unbounded_window( @pytest.mark.broken(["dask"], raises=AssertionError) @pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl(["flink"], raises=com.UnsupportedOperationError) @pytest.mark.notimpl( ["risingwave"], 
raises=PsycoPg2InternalError, @@ -642,8 +618,8 @@ def test_simple_ungrouped_unbound_following_window( @pytest.mark.notimpl( ["flink"], - raises=com.UnsupportedOperationError, - reason="OVER RANGE FOLLOWING windows are not supported in Flink yet", + raises=Py4JJavaError, + reason="flink doesn't allow order by NULL without casting the null to a specific type", ) @pytest.mark.never( ["mssql"], raises=Exception, reason="order by constant is not supported" @@ -674,7 +650,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): id="ordered-mean", marks=[ pytest.mark.broken( - ["flink", "impala"], + ["impala"], reason="default window semantics are different", raises=AssertionError, ), @@ -691,11 +667,6 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): False, id="unordered-mean", marks=[ - pytest.mark.notimpl( - ["flink"], - raises=com.UnsupportedOperationError, - reason="Flink engine does not support generic window clause with no order by", - ), pytest.mark.broken( ["mssql"], raises=PyODBCProgrammingError, @@ -712,11 +683,6 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): pytest.mark.notimpl( ["pandas", "dask"], raises=com.OperationNotDefinedError ), - pytest.mark.notimpl( - ["flink"], - raises=Py4JJavaError, - reason="CalciteContextException: Argument to function 'NTILE' must be a literal", - ), pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, @@ -774,14 +740,10 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "trino", "datafusion", "exasol", + "flink", ], raises=com.OperationNotDefinedError, ), - pytest.mark.notimpl( - ["flink"], - raises=com.UnsupportedOperationError, - reason="Flink engine does not support generic window clause with no order by", - ), ], ), # Analytic ops @@ -810,11 +772,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): raises=AssertionError, strict=False, # sometimes it passes ), - pytest.mark.notimpl( - ["flink"], - raises=com.UnsupportedOperationError, - reason="Flink engine does not support generic window clause with no order by", - ), + pytest.mark.notyet(["flink"], raises=Py4JJavaError), pytest.mark.broken(["mssql"], raises=PyODBCProgrammingError), pytest.mark.notyet( ["snowflake"], @@ -856,11 +814,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): raises=AssertionError, strict=False, # sometimes it passes ), - pytest.mark.notimpl( - ["flink"], - raises=com.UnsupportedOperationError, - reason="Flink engine does not support generic window clause with no order by", - ), + pytest.mark.notyet(["flink"], raises=Py4JJavaError), pytest.mark.broken(["mssql"], raises=PyODBCProgrammingError), pytest.mark.notyet( ["snowflake"], @@ -926,14 +880,10 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "trino", "datafusion", "exasol", + "flink", ], raises=com.OperationNotDefinedError, ), - pytest.mark.notimpl( - ["flink"], - raises=com.UnsupportedOperationError, - reason="Flink engine does not support generic window clause with no order by", - ), ], ), ], @@ -970,11 +920,6 @@ def test_ungrouped_unbounded_window( @pytest.mark.notimpl( ["impala"], raises=ImpalaHiveServer2Error, reason="limited RANGE support" ) -@pytest.mark.notimpl( - ["flink"], - raises=com.UnsupportedOperationError, - reason="Data Type mismatch between ORDER BY and RANGE clause", -) @pytest.mark.notyet( ["clickhouse"], reason="RANGE OFFSET frame for 'DB::ColumnNullable' ORDER BY column is not implemented", @@ -1115,11 +1060,6 @@ def test_mutate_window_filter(backend, alltypes): 
@pytest.mark.notimpl(["polars", "exasol"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl( - ["flink"], - raises=Exception, - reason="KeyError: Table with name win doesn't exist.", -) def test_first_last(backend): t = backend.win w = ibis.window(group_by=t.g, order_by=[t.x, t.y], preceding=1, following=0) @@ -1252,7 +1192,6 @@ def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df): raises=PyODBCProgrammingError, ) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) -@pytest.mark.notyet(["flink"], raises=com.UnsupportedOperationError) @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 767b498b08f80..537700e1bed54 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -1774,7 +1774,7 @@ def test_merge_as_of_allows_overlapping_columns(): ] # select columns we care about signal_two = signal_two.rename(voltage="value", signal_two="field") - merged = ibis.api.asof_join(signal_one, signal_two, "timestamp_received") + merged = signal_one.asof_join(signal_two, "timestamp_received") assert merged.columns == [ "current", "timestamp_received", diff --git a/poetry-overrides.nix b/poetry-overrides.nix index 176ce22e308f9..212b84c5da3c7 100644 --- a/poetry-overrides.nix +++ b/poetry-overrides.nix @@ -18,9 +18,28 @@ self: super: { }) ]; }); + pyodbc = super.pyodbc.overridePythonAttrs (attrs: { preFixup = attrs.preFixup or "" + '' addAutoPatchelfSearchPath ${self.pkgs.unixODBC} ''; }); + + avro-python3 = super.avro-python3.overridePythonAttrs (attrs: { + nativeBuildInputs = attrs.nativeBuildInputs or [ ] ++ [ + self.pycodestyle + self.isort + ]; + }); + + apache-flink-libraries = super.apache-flink-libraries.overridePythonAttrs (attrs: { + buildInputs = attrs.nativeBuildInputs or [ ] ++ [ self.setuptools ]; + # apache-flink and apache-flink-libraries both install version.py into the + # pyflink output derivation, which is invalid: whichever gets installed + # last will be used + postInstall = '' + rm $out/${self.python.sitePackages}/pyflink/version.py + rm $out/${self.python.sitePackages}/pyflink/__pycache__/version.*.pyc + ''; + }); }