diff --git a/ibis/backends/flink/compiler.py b/ibis/backends/flink/compiler.py
index dcbd8544d3c0..4ee2ae4115f9 100644
--- a/ibis/backends/flink/compiler.py
+++ b/ibis/backends/flink/compiler.py
@@ -566,3 +566,6 @@ def visit_MapMerge(self, op: ops.MapMerge, *, left, right):
         values = self.f.array_concat(left_values, right_values)
 
         return self.cast(self.f.map_from_arrays(keys, values), op.dtype)
+
+    def visit_StructColumn(self, op, *, names, values):
+        return self.cast(sge.Struct(expressions=list(values)), op.dtype)
diff --git a/ibis/backends/flink/tests/test_memtable.py b/ibis/backends/flink/tests/test_memtable.py
new file mode 100644
index 000000000000..bfeb06d2b4e6
--- /dev/null
+++ b/ibis/backends/flink/tests/test_memtable.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+import pytest
+from pyflink.common.types import Row
+
+import ibis
+from ibis.backends.tests.errors import Py4JJavaError
+
+
+@pytest.mark.parametrize(
+    "data,schema,expected",
+    [
+        pytest.param(
+            {"value": [{"a": 1}, {"a": 2}]},
+            {"value": "!struct<a: !int64>"},
+            [Row(Row([1])), Row(Row([2]))],
+            id="simple_named_struct",
+        ),
+        pytest.param(
+            {"value": [[{"a": 1}, {"a": 2}], [{"a": 3}, {"a": 4}]]},
+            {"value": "!array<!struct<a: !int64>>"},
+            [Row([Row([1]), Row([2])]), Row([Row([3]), Row([4])])],
+            id="single_field_named_struct_array",
+        ),
+        pytest.param(
+            {"value": [[{"a": 1, "b": 2}, {"a": 2, "b": 2}]]},
+            {"value": "!array<!struct<a: !int64, b: !int64>>"},
+            [Row([Row([1, 2]), Row([2, 2])])],
+            id="named_struct_array",
+        ),
+    ],
+)
+def test_create_memtable(con, data, schema, expected):
+    t = ibis.memtable(data, schema=ibis.schema(schema))
+    # cannot use con.execute(t) directly because of some behavioral discrepancy between
+    # `TableEnvironment.execute_sql()` and `TableEnvironment.sql_query()`
+    result = con.raw_sql(con.compile(t))
+    # raw_sql() returns a `TableResult` object and doesn't natively convert to pandas
+    assert list(result.collect()) == expected
+
+
+@pytest.mark.notyet(
+    ["flink"],
+    raises=Py4JJavaError,
+    reason="cannot create an ARRAY of named STRUCTs directly from the ARRAY[] constructor; https://issues.apache.org/jira/browse/FLINK-34898",
+)
+def test_create_named_struct_array_with_array_constructor(con):
+    con.raw_sql("SELECT ARRAY[cast(ROW(1) as ROW<a INT>)];")
diff --git a/ibis/backends/sql/dialects.py b/ibis/backends/sql/dialects.py
index bd9e77d148db..f1aadc4d64e3 100644
--- a/ibis/backends/sql/dialects.py
+++ b/ibis/backends/sql/dialects.py
@@ -2,6 +2,7 @@
 
 import contextlib
 import math
+from copy import deepcopy
 
 import sqlglot.expressions as sge
 from sqlglot import transforms
@@ -18,6 +19,7 @@
     Trino,
 )
 from sqlglot.dialects.dialect import rename_func
+from sqlglot.helper import seq_get
 
 ClickHouse.Generator.TRANSFORMS |= {
     sge.ArraySize: rename_func("length"),
@@ -113,6 +115,7 @@ class Flink(Hive):
     class Generator(Hive.Generator):
         TYPE_MAPPING = Hive.Generator.TYPE_MAPPING.copy() | {
             sge.DataType.Type.TIME: "TIME",
+            sge.DataType.Type.STRUCT: "ROW",
         }
 
         TRANSFORMS = Hive.Generator.TRANSFORMS.copy() | {
@@ -121,10 +124,6 @@ class Generator(Hive.Generator):
             sge.StddevSamp: rename_func("stddev_samp"),
             sge.Variance: rename_func("var_samp"),
             sge.VariancePop: rename_func("var_pop"),
-            sge.Array: (
-                lambda self,
-                e: f"ARRAY[{', '.join(arg.sql(self.dialect) for arg in e.expressions)}]"
-            ),
            sge.ArrayConcat: rename_func("array_concat"),
            sge.Length: rename_func("char_length"),
            sge.TryCast: lambda self,
@@ -135,6 +134,59 @@ class Generator(Hive.Generator):
             sge.Interval: _interval_with_precision,
         }
+
+        def struct_sql(self, expression: sge.Struct) -> str:
+            from sqlglot.optimizer.annotate_types import annotate_types
+
+            expression = annotate_types(expression)
+
+            values = []
+            schema = []
+
+            for e in expression.expressions:
+                if isinstance(e, sge.PropertyEQ):
+                    e = sge.alias_(e.expression, e.this)
+                # named structs
+                if isinstance(e, sge.Alias):
+                    if e.type and e.type.is_type(sge.DataType.Type.UNKNOWN):
+                        self.unsupported(
+                            "Cannot convert untyped key-value definitions (try annotate_types)."
+                        )
+                    else:
+                        schema.append(f"{self.sql(e, 'alias')} {self.sql(e.type)}")
+                    values.append(self.sql(e, "this"))
+                else:
+                    values.append(self.sql(e))
+
+            if not (size := len(expression.expressions)) or len(schema) != size:
+                return self.func("ROW", *values)
+            return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))"
+
+        def array_sql(self, expression: sge.Array) -> str:
+            # workaround for the time being because you cannot construct an array of named
+            # STRUCTs directly from the ARRAY[] constructor
+            # https://issues.apache.org/jira/browse/FLINK-34898
+            from sqlglot.optimizer.annotate_types import annotate_types
+
+            expression = annotate_types(expression)
+            first_arg = seq_get(expression.expressions, 0)
+            # it's an array of structs
+            if isinstance(first_arg, sge.Struct):
+                # get rid of aliasing because we want to compile this as CAST instead
+                args = deepcopy(expression.expressions)
+                for arg in args:
+                    for e in arg.expressions:
+                        arg.set("expressions", [e.unalias() for e in arg.expressions])
+
+                format_values = ", ".join([self.sql(arg) for arg in args])
+                # all elements of the array should have the same type
+                format_dtypes = self.sql(first_arg.type)
+
+                return f"CAST(ARRAY[{format_values}] AS ARRAY<{format_dtypes}>)"
+
+            return (
+                f"ARRAY[{', '.join(self.sql(arg) for arg in expression.expressions)}]"
+            )
 
     class Tokenizer(Hive.Tokenizer):
         # In Flink, embedded single quotes are escaped like most other SQL
         # dialects: doubling up the single quote
diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py
index 28ac4010ab82..f3c9c16a3a1c 100644
--- a/ibis/backends/tests/test_map.py
+++ b/ibis/backends/tests/test_map.py
@@ -240,7 +240,11 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df):
                 ),
                 pytest.mark.notyet(["pandas", "dask"]),
                 mark_notyet_postgres,
-                pytest.mark.notimpl("flink"),
+                pytest.mark.notyet(
+                    ["flink"],
+                    raises=Py4JJavaError,
+                    reason="does not support selecting struct key from map",
+                ),
                 mark_notyet_snowflake,
             ],
             id="struct",
@@ -304,7 +308,6 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df):
             marks=[
                 pytest.mark.notyet("clickhouse", reason="nested types can't be null"),
                 mark_notyet_postgres,
-                pytest.mark.notimpl("flink", reason="can't construct structs"),
             ],
             id="struct",
         ),
diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py
index 4c250eaf47ae..78abe2e99f18 100644
--- a/ibis/backends/tests/test_param.py
+++ b/ibis/backends/tests/test_param.py
@@ -85,7 +85,6 @@ def test_scalar_param_array(con):
     ["mysql", "sqlite", "mssql"],
     reason="mysql and sqlite will never implement struct types",
 )
-@pytest.mark.notimpl(["flink"], "WIP")
 def test_scalar_param_struct(con):
     value = dict(a=1, b="abc", c=3.0)
     param = ibis.param("struct<a: int64, b: string, c: float64>")
diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py
index 682da63faf27..1bfa1a8b64c0 100644
--- a/ibis/backends/tests/test_struct.py
+++ b/ibis/backends/tests/test_struct.py
@@ -72,9 +72,6 @@ def test_all_fields(struct, struct_df):
 
 @pytest.mark.notimpl(["postgres", "risingwave"])
@pytest.mark.parametrize("field", ["a", "b", "c"]) -@pytest.mark.notyet( - ["flink"], reason="flink doesn't support creating struct columns from literals" -) def test_literal(backend, con, field): query = _STRUCT_LITERAL[field] dtype = query.type().to_pandas() @@ -89,9 +86,6 @@ def test_literal(backend, con, field): @pytest.mark.notyet( ["clickhouse"], reason="clickhouse doesn't support nullable nested types" ) -@pytest.mark.notyet( - ["flink"], reason="flink doesn't support creating struct columns from literals" -) def test_null_literal(backend, con, field): query = _NULL_STRUCT_LITERAL[field] result = pd.Series([con.execute(query)]) @@ -101,9 +95,6 @@ def test_null_literal(backend, con, field): @pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave"]) -@pytest.mark.notyet( - ["flink"], reason="flink doesn't support creating struct columns from literals" -) def test_struct_column(alltypes, df): t = alltypes expr = t.select(s=ibis.struct(dict(a=t.string_col, b=1, c=t.bigint_col)))