Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(flink): add map support #8425

Merged
merged 1 commit into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
ops.IsInf,
ops.IsNan,
ops.Levenshtein,
ops.MapMerge,
ops.Median,
ops.MultiQuantile,
ops.NthValue,
Expand All @@ -81,7 +80,8 @@
ops.ExtractDayOfYear: "dayofyear",
ops.First: "first_value",
ops.Last: "last_value",
ops.Map: "map_from_arrays",
ops.MapKeys: "map_keys",
ops.MapValues: "map_values",
ops.Power: "power",
ops.RandomScalar: "rand",
ops.RegexSearch: "regexp",
Expand Down Expand Up @@ -548,3 +548,21 @@
if where is not None:
arg = self.if_(where, arg, self.f.array(arg)[2])
return self.f.count(sge.Distinct(expressions=[arg]))

def visit_MapContains(self, op: ops.MapContains, *, arg, key):
return self.f.array_contains(self.f.map_keys(arg), key)

Check warning on line 553 in ibis/backends/flink/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/flink/compiler.py#L552-L553

Added lines #L552 - L553 were not covered by tests

def visit_Map(self, op: ops.Map, *, keys, values):
return self.cast(self.f.map_from_arrays(keys, values), op.dtype)

Check warning on line 556 in ibis/backends/flink/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/flink/compiler.py#L555-L556

Added lines #L555 - L556 were not covered by tests

def visit_MapMerge(self, op: ops.MapMerge, *, left, right):
left_keys = self.f.map_keys(left)
left_values = self.f.map_values(left)

Check warning on line 560 in ibis/backends/flink/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/flink/compiler.py#L558-L560

Added lines #L558 - L560 were not covered by tests

right_keys = self.f.map_keys(right)
right_values = self.f.map_values(right)

Check warning on line 563 in ibis/backends/flink/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/flink/compiler.py#L562-L563

Added lines #L562 - L563 were not covered by tests

keys = self.f.array_concat(left_keys, right_keys)
values = self.f.array_concat(left_values, right_values)

Check warning on line 566 in ibis/backends/flink/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/flink/compiler.py#L565-L566

Added lines #L565 - L566 were not covered by tests

return self.cast(self.f.map_from_arrays(keys, values), op.dtype)

Check warning on line 568 in ibis/backends/flink/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/flink/compiler.py#L568

Added line #L568 was not covered by tests
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 3 additions & 2 deletions ibis/backends/flink/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ def from_ibis(cls, dtype: dt.DataType) -> DataType:
return DataTypes.ARRAY(cls.from_ibis(dtype.value_type), nullable=nullable)
elif dtype.is_map():
return DataTypes.MAP(
key_type=cls.from_ibis(dtype.key_type),
value_type=cls.from_ibis(dtype.key_type),
# keys *must* be non-nullable
key_type=cls.from_ibis(dtype.key_type.copy(nullable=False)),
value_type=cls.from_ibis(dtype.value_type),
nullable=nullable,
)
elif dtype.is_struct():
Expand Down
12 changes: 12 additions & 0 deletions ibis/backends/flink/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
class TestConf(BackendTest):
force_sort = True
stateful = False
supports_map = True

Check warning on line 17 in ibis/backends/flink/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/flink/tests/conftest.py#L17

Added line #L17 was not covered by tests
deps = "pandas", "pyflink"

@staticmethod
Expand Down Expand Up @@ -63,6 +64,17 @@
con.create_table("json_t", json_types, temp=True)
con.create_table("struct", struct_types, temp=True)
con.create_table("win", win, temp=True)
con.create_table(

Check warning on line 67 in ibis/backends/flink/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/flink/tests/conftest.py#L67

Added line #L67 was not covered by tests
"map",
pd.DataFrame(
{
"idx": [1, 2],
"kv": [{"a": 1, "b": 2, "c": 3}, {"d": 4, "e": 5, "f": 6}],
}
),
schema=ibis.schema({"idx": "int64", "kv": "map<string, int64>"}),
temp=True,
)


class TestConfForStreaming(TestConf):
Expand Down
18 changes: 17 additions & 1 deletion ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,9 @@
# key cannot be nullable in clickhouse
key_type = cls.from_ibis(dtype.key_type.copy(nullable=False))
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type])
return sge.DataType(
this=typecode.MAP, expressions=[key_type, value_type], nested=True
)


class FlinkType(SqlglotType):
Expand All @@ -1041,3 +1043,17 @@
@classmethod
def _from_ibis_Binary(cls, dtype: dt.Binary) -> sge.DataType:
return sge.DataType(this=sge.DataType.Type.VARBINARY)

@classmethod
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
# key cannot be nullable in clickhouse
key_type = cls.from_ibis(dtype.key_type.copy(nullable=False))
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(

Check warning on line 1052 in ibis/backends/sql/datatypes.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/datatypes.py#L1050-L1052

Added lines #L1050 - L1052 were not covered by tests
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
this=typecode.MAP,
expressions=[
sge.Var(this=key_type.sql(cls.dialect) + " NOT NULL"),
value_type,
],
nested=True,
)
53 changes: 8 additions & 45 deletions ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
import pytest
from pytest import param

Expand Down Expand Up @@ -56,20 +58,15 @@ def test_column_map_merge(backend):
table = backend.map
expr = table.select(
"idx",
merged=table.kv.cast("map<string, int8>") + ibis.map({"d": 1}),
merged=table.kv + ibis.map({"d": np.int64(1)}),
).order_by("idx")
result = expr.execute().merged
expected = pd.Series(
[{"a": 1, "b": 2, "c": 3, "d": 1}, {"d": 1, "e": 5, "f": 6}], name="merged"
)
backend.assert_series_equal(result, expected)
tm.assert_series_equal(result, expected)


@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapKeys'>",
)
@pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
Expand All @@ -85,11 +82,6 @@ def test_literal_map_keys(con):
assert np.array_equal(result, ["1", "2"])


@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapValues'>",
)
@pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
Expand All @@ -104,11 +96,6 @@ def test_literal_map_values(con):


@pytest.mark.notimpl(["postgres", "risingwave"])
@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.arrays.ArrayContains'>",
)
def test_scalar_isin_literal_map_keys(con):
mapping = ibis.literal({"a": 1, "b": 2})
a = ibis.literal("a")
Expand All @@ -122,11 +109,6 @@ def test_scalar_isin_literal_map_keys(con):
@pytest.mark.notyet(
["postgres", "risingwave"], reason="only support maps of string -> string"
)
@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapContains'>",
)
def test_map_scalar_contains_key_scalar(con):
mapping = ibis.literal({"a": 1, "b": 2})
a = ibis.literal("a")
Expand All @@ -137,11 +119,6 @@ def test_map_scalar_contains_key_scalar(con):
assert con.execute(false) == False # noqa: E712


@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapContains'>",
)
@pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
Expand All @@ -159,11 +136,6 @@ def test_map_scalar_contains_key_column(backend, alltypes, df):
@pytest.mark.notyet(
["postgres", "risingwave"], reason="only support maps of string -> string"
)
@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason=("No translation rule for <class 'ibis.expr.operations.maps.MapContains'>"),
)
def test_map_column_contains_key_scalar(backend, alltypes, df):
expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
series = df.apply(lambda row: {row["string_col"]: row["int_col"]}, axis=1)
Expand All @@ -177,11 +149,6 @@ def test_map_column_contains_key_scalar(backend, alltypes, df):
@pytest.mark.notyet(
["postgres", "risingwave"], reason="only support maps of string -> string"
)
@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapContains'>",
)
def test_map_column_contains_key_column(alltypes):
map_expr = ibis.map(
ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col])
Expand All @@ -194,11 +161,6 @@ def test_map_column_contains_key_column(alltypes):
@pytest.mark.notyet(
["postgres", "risingwave"], reason="only support maps of string -> string"
)
@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapMerge'>",
)
def test_literal_map_merge(con):
a = ibis.literal({"a": 0, "b": 2})
b = ibis.literal({"a": 1, "c": 3})
Expand Down Expand Up @@ -270,10 +232,10 @@ def test_map_construct_dict(con, keys, values):
@pytest.mark.notyet(
["postgres", "risingwave"], reason="only support maps of string -> string"
)
@pytest.mark.notimpl(
@pytest.mark.broken(
["flink"],
raises=Py4JJavaError,
reason="Map key type should be non-nullable",
raises=pa.lib.ArrowInvalid,
reason="Map array child array should have no nulls",
)
def test_map_construct_array_column(con, alltypes, df):
expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
Expand Down Expand Up @@ -383,6 +345,7 @@ def test_map_length(con):
assert con.execute(expr) == 2


@pytest.mark.notimpl(["flink"], raises=exc.OperationNotDefinedError)
def test_map_keys_unnest(backend):
expr = backend.map.kv.keys().unnest()
result = expr.to_pandas()
Expand Down
9 changes: 6 additions & 3 deletions ibis/backends/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import ibis.common.exceptions as com
from ibis import _, udf
from ibis.backends.tests.errors import Py4JJavaError

no_python_udfs = mark.notimpl(
[
Expand Down Expand Up @@ -54,7 +53,9 @@ def num_vowels(s: str, include_y: bool = False) -> int:
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
)
@mark.notimpl(["polars"])
@mark.notimpl(["flink"], raises=Py4JJavaError)
@mark.never(
["flink"], strict=False, reason="broken with Python 3.9; works in Python 3.10"
)
@mark.notyet(["datafusion"], raises=NotImplementedError)
@mark.notyet(
["sqlite"], raises=com.IbisTypeError, reason="sqlite doesn't support map types"
Expand Down Expand Up @@ -84,7 +85,9 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
)
@mark.notimpl(["polars"])
@mark.notimpl(["flink"], raises=Py4JJavaError)
@mark.never(
["flink"], strict=False, reason="broken with Python 3.9; works in Python 3.10"
)
@mark.notyet(["datafusion"], raises=NotImplementedError)
@mark.notyet(["sqlite"], raises=TypeError, reason="sqlite doesn't support map types")
def test_map_merge_udf(batting):
Expand Down
Loading