From ce15e5fcef214acf923f0c7a349b97a3cb77fa42 Mon Sep 17 00:00:00 2001 From: mfatihaktas Date: Thu, 22 Feb 2024 08:58:37 -0500 Subject: [PATCH] feat(flink): add map support --- ibis/backends/flink/compiler.py | 22 ++++++++++- ibis/backends/flink/datatypes.py | 5 ++- ibis/backends/flink/tests/conftest.py | 12 ++++++ ibis/backends/sql/datatypes.py | 18 ++++++++- ibis/backends/tests/test_map.py | 53 ++++----------------------- ibis/backends/tests/test_udf.py | 9 +++-- 6 files changed, 66 insertions(+), 53 deletions(-) diff --git a/ibis/backends/flink/compiler.py b/ibis/backends/flink/compiler.py index cba5dbaf7c68..5c13ac15cbad 100644 --- a/ibis/backends/flink/compiler.py +++ b/ibis/backends/flink/compiler.py @@ -55,7 +55,6 @@ class FlinkCompiler(SQLGlotCompiler): ops.IsInf, ops.IsNan, ops.Levenshtein, - ops.MapMerge, ops.Median, ops.MultiQuantile, ops.NthValue, @@ -81,7 +80,8 @@ class FlinkCompiler(SQLGlotCompiler): ops.ExtractDayOfYear: "dayofyear", ops.First: "first_value", ops.Last: "last_value", - ops.Map: "map_from_arrays", + ops.MapKeys: "map_keys", + ops.MapValues: "map_values", ops.Power: "power", ops.RandomScalar: "rand", ops.RegexSearch: "regexp", @@ -548,3 +548,21 @@ def visit_CountDistinct(self, op, *, arg, where): if where is not None: arg = self.if_(where, arg, self.f.array(arg)[2]) return self.f.count(sge.Distinct(expressions=[arg])) + + def visit_MapContains(self, op: ops.MapContains, *, arg, key): + return self.f.array_contains(self.f.map_keys(arg), key) + + def visit_Map(self, op: ops.Map, *, keys, values): + return self.cast(self.f.map_from_arrays(keys, values), op.dtype) + + def visit_MapMerge(self, op: ops.MapMerge, *, left, right): + left_keys = self.f.map_keys(left) + left_values = self.f.map_values(left) + + right_keys = self.f.map_keys(right) + right_values = self.f.map_values(right) + + keys = self.f.array_concat(left_keys, right_keys) + values = self.f.array_concat(left_values, right_values) + + return self.cast(self.f.map_from_arrays(keys, 
values), op.dtype) diff --git a/ibis/backends/flink/datatypes.py b/ibis/backends/flink/datatypes.py index 664fb365d01b..0455ffb38366 100644 --- a/ibis/backends/flink/datatypes.py +++ b/ibis/backends/flink/datatypes.py @@ -127,8 +127,9 @@ def from_ibis(cls, dtype: dt.DataType) -> DataType: return DataTypes.ARRAY(cls.from_ibis(dtype.value_type), nullable=nullable) elif dtype.is_map(): return DataTypes.MAP( - key_type=cls.from_ibis(dtype.key_type), - value_type=cls.from_ibis(dtype.key_type), + # keys *must* be non-nullable + key_type=cls.from_ibis(dtype.key_type.copy(nullable=False)), + value_type=cls.from_ibis(dtype.value_type), nullable=nullable, ) elif dtype.is_struct(): diff --git a/ibis/backends/flink/tests/conftest.py b/ibis/backends/flink/tests/conftest.py index 27f6f5f3cba0..b4ac8e4fe4a4 100644 --- a/ibis/backends/flink/tests/conftest.py +++ b/ibis/backends/flink/tests/conftest.py @@ -14,6 +14,7 @@ class TestConf(BackendTest): force_sort = True stateful = False + supports_map = True deps = "pandas", "pyflink" @staticmethod @@ -63,6 +64,17 @@ def _load_data(self, **_: Any) -> None: con.create_table("json_t", json_types, temp=True) con.create_table("struct", struct_types, temp=True) con.create_table("win", win, temp=True) + con.create_table( + "map", + pd.DataFrame( + { + "idx": [1, 2], + "kv": [{"a": 1, "b": 2, "c": 3}, {"d": 4, "e": 5, "f": 6}], + } + ), + schema=ibis.schema({"idx": "int64", "kv": "map<string, int64>"}), + temp=True, + ) class TestConfForStreaming(TestConf): diff --git a/ibis/backends/sql/datatypes.py b/ibis/backends/sql/datatypes.py index fb104c11746f..9fe2f6f72ef9 100644 --- a/ibis/backends/sql/datatypes.py +++ b/ibis/backends/sql/datatypes.py @@ -1030,7 +1030,9 @@ def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType: # key cannot be nullable in clickhouse key_type = cls.from_ibis(dtype.key_type.copy(nullable=False)) value_type = cls.from_ibis(dtype.value_type) - return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type]) + return 
sge.DataType( + this=typecode.MAP, expressions=[key_type, value_type], nested=True + ) class FlinkType(SqlglotType): @@ -1041,3 +1043,17 @@ class FlinkType(SqlglotType): @classmethod def _from_ibis_Binary(cls, dtype: dt.Binary) -> sge.DataType: return sge.DataType(this=sge.DataType.Type.VARBINARY) + + @classmethod + def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType: + # key cannot be nullable in flink + key_type = cls.from_ibis(dtype.key_type.copy(nullable=False)) + value_type = cls.from_ibis(dtype.value_type) + return sge.DataType( + this=typecode.MAP, + expressions=[ + sge.Var(this=key_type.sql(cls.dialect) + " NOT NULL"), + value_type, + ], + nested=True, + ) diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index 4efc9fd0527c..a80945653325 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -2,6 +2,8 @@ import numpy as np import pandas as pd +import pandas.testing as tm +import pyarrow as pa import pytest from pytest import param @@ -56,20 +58,15 @@ def test_column_map_merge(backend): table = backend.map expr = table.select( "idx", - merged=table.kv.cast("map") + ibis.map({"d": 1}), + merged=table.kv + ibis.map({"d": np.int64(1)}), ).order_by("idx") result = expr.execute().merged expected = pd.Series( [{"a": 1, "b": 2, "c": 3, "d": 1}, {"d": 1, "e": 5, "f": 6}], name="merged" ) - backend.assert_series_equal(result, expected) + tm.assert_series_equal(result, expected) -@pytest.mark.notimpl( - ["flink"], - raises=exc.OperationNotDefinedError, - reason="No translation rule for ", -) @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, @@ -85,11 +82,6 @@ def test_literal_map_keys(con): assert np.array_equal(result, ["1", "2"]) -@pytest.mark.notimpl( - ["flink"], - raises=exc.OperationNotDefinedError, - reason="No translation rule for ", -) @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, @@ -104,11 +96,6 @@ def test_literal_map_values(con): 
@pytest.mark.notimpl(["postgres", "risingwave"]) -@pytest.mark.notimpl( - ["flink"], - raises=exc.OperationNotDefinedError, - reason="No translation rule for ", -) def test_scalar_isin_literal_map_keys(con): mapping = ibis.literal({"a": 1, "b": 2}) a = ibis.literal("a") @@ -122,11 +109,6 @@ def test_scalar_isin_literal_map_keys(con): @pytest.mark.notyet( ["postgres", "risingwave"], reason="only support maps of string -> string" ) -@pytest.mark.notimpl( - ["flink"], - raises=exc.OperationNotDefinedError, - reason="No translation rule for ", -) def test_map_scalar_contains_key_scalar(con): mapping = ibis.literal({"a": 1, "b": 2}) a = ibis.literal("a") @@ -137,11 +119,6 @@ def test_map_scalar_contains_key_scalar(con): assert con.execute(false) == False # noqa: E712 -@pytest.mark.notimpl( - ["flink"], - raises=exc.OperationNotDefinedError, - reason="No translation rule for ", -) @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, @@ -159,11 +136,6 @@ def test_map_scalar_contains_key_column(backend, alltypes, df): @pytest.mark.notyet( ["postgres", "risingwave"], reason="only support maps of string -> string" ) -@pytest.mark.notimpl( - ["flink"], - raises=exc.OperationNotDefinedError, - reason=("No translation rule for "), -) def test_map_column_contains_key_scalar(backend, alltypes, df): expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col])) series = df.apply(lambda row: {row["string_col"]: row["int_col"]}, axis=1) @@ -177,11 +149,6 @@ def test_map_column_contains_key_scalar(backend, alltypes, df): @pytest.mark.notyet( ["postgres", "risingwave"], reason="only support maps of string -> string" ) -@pytest.mark.notimpl( - ["flink"], - raises=exc.OperationNotDefinedError, - reason="No translation rule for ", -) def test_map_column_contains_key_column(alltypes): map_expr = ibis.map( ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]) @@ -194,11 +161,6 @@ def test_map_column_contains_key_column(alltypes): 
@pytest.mark.notyet( ["postgres", "risingwave"], reason="only support maps of string -> string" ) -@pytest.mark.notimpl( - ["flink"], - raises=exc.OperationNotDefinedError, - reason="No translation rule for ", -) def test_literal_map_merge(con): a = ibis.literal({"a": 0, "b": 2}) b = ibis.literal({"a": 1, "c": 3}) @@ -270,10 +232,10 @@ def test_map_construct_dict(con, keys, values): @pytest.mark.notyet( ["postgres", "risingwave"], reason="only support maps of string -> string" ) -@pytest.mark.notimpl( +@pytest.mark.broken( ["flink"], - raises=Py4JJavaError, - reason="Map key type should be non-nullable", + raises=pa.lib.ArrowInvalid, + reason="Map array child array should have no nulls", ) def test_map_construct_array_column(con, alltypes, df): expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col])) @@ -383,6 +345,7 @@ def test_map_length(con): assert con.execute(expr) == 2 +@pytest.mark.notimpl(["flink"], raises=exc.OperationNotDefinedError) def test_map_keys_unnest(backend): expr = backend.map.kv.keys().unnest() result = expr.to_pandas() diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py index a40770f173d5..fb99c8207c88 100644 --- a/ibis/backends/tests/test_udf.py +++ b/ibis/backends/tests/test_udf.py @@ -4,7 +4,6 @@ import ibis.common.exceptions as com from ibis import _, udf -from ibis.backends.tests.errors import Py4JJavaError no_python_udfs = mark.notimpl( [ @@ -54,7 +53,9 @@ def num_vowels(s: str, include_y: bool = False) -> int: ["postgres"], raises=TypeError, reason="postgres only supports map" ) @mark.notimpl(["polars"]) -@mark.notimpl(["flink"], raises=Py4JJavaError) +@mark.never( + ["flink"], strict=False, reason="broken with Python 3.9; works in Python 3.10" +) @mark.notyet(["datafusion"], raises=NotImplementedError) @mark.notyet( ["sqlite"], raises=com.IbisTypeError, reason="sqlite doesn't support map types" @@ -84,7 +85,9 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]: 
["postgres"], raises=TypeError, reason="postgres only supports map" ) @mark.notimpl(["polars"]) -@mark.notimpl(["flink"], raises=Py4JJavaError) +@mark.never( + ["flink"], strict=False, reason="broken with Python 3.9; works in Python 3.10" +) @mark.notyet(["datafusion"], raises=NotImplementedError) @mark.notyet(["sqlite"], raises=TypeError, reason="sqlite doesn't support map types") def test_map_merge_udf(batting):