diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index b9bf71c0b77f3..0bbf456df2153 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -43,10 +43,7 @@ class DuckDBCompiler(SQLGlotCompiler): ops.Hash: "hash", ops.IntegerRange: "range", ops.TimestampRange: "range", - ops.MapKeys: "map_keys", ops.MapLength: "cardinality", - ops.MapMerge: "map_concat", - ops.MapValues: "map_values", ops.Mode: "mode", ops.TimeFromHMS: "make_time", ops.TypeOf: "typeof", @@ -201,17 +198,36 @@ def visit_ArrayZip(self, op, *, arg): def visit_Map(self, op, *, keys, values): # workaround for https://github.com/ibis-project/ibis/issues/8632 - regular = self.f.map(keys, values) - either_null = sg.or_(keys.is_(NULL), values.is_(NULL)) - return self.if_(either_null, NULL, regular) + return self.if_( + sg.or_(keys.is_(NULL), values.is_(NULL)), NULL, self.f.map(keys, values) + ) def visit_MapGet(self, op, *, arg, key, default): - return self.f.ifnull( - self.f.list_extract(self.f.element_at(arg, key), 1), default + return self.if_( + sg.or_(arg.is_(NULL), key.is_(NULL)), + NULL, + self.f.ifnull(self.f.list_extract(self.f.element_at(arg, key), 1), default), ) def visit_MapContains(self, op, *, arg, key): - return self.f.len(self.f.element_at(arg, key)).neq(0) + return self.if_( + sg.or_(arg.is_(NULL), key.is_(NULL)), + NULL, + self.f.len(self.f.element_at(arg, key)).neq(0), + ) + + def visit_MapKeys(self, op, *, arg): + return self.if_(arg.is_(NULL), NULL, self.f.map_keys(arg)) + + def visit_MapValues(self, op, *, arg): + return self.if_(arg.is_(NULL), NULL, self.f.map_values(arg)) + + def visit_MapMerge(self, op, *, left, right): + return self.if_( + sg.or_(left.is_(NULL), right.is_(NULL)), + NULL, + self.f.map_concat(left, right), + ) def visit_ToJSONMap(self, op, *, arg): return self.if_( diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index fe30da1f41992..6bcfdee068c3c 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -61,6 +61,171 @@ def test_map_nulls(con, k, v): assert con.execute(m) is None +@pytest.mark.notyet("clickhouse", reason="nested types can't be NULL") +@pytest.mark.broken(["pandas", "dask"], reason="TypeError: iteration over a 0-d array") +@pytest.mark.notimpl( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) +@pytest.mark.parametrize( + ("k", "v"), + [ + param(None, ["c", "d"], id="null_keys"), + param(None, None, id="null_both"), + ], +) +def test_map_keys_nulls(con, k, v): + k = ibis.literal(k, type="array") + v = ibis.literal(v, type="array") + m = ibis.map(k, v) + assert con.execute(m.keys()) is None + + +@pytest.mark.notyet("clickhouse", reason="nested types can't be NULL") +@pytest.mark.notimpl( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) +@pytest.mark.parametrize( + "map", + [ + param( + ibis.map( + ibis.literal(["a", "b"]), ibis.literal(None, type="array") + ), + marks=[ + pytest.mark.broken( + ["pandas", "dask"], reason="TypeError: iteration over a 0-d array" + ) + ], + id="null_values", + ), + param( + ibis.map( + ibis.literal(None, type="array"), + ibis.literal(None, type="array"), + ), + marks=[ + pytest.mark.broken( + ["pandas", "dask"], reason="TypeError: iteration over a 0-d array" + ) + ], + id="null_both", + ), + param(ibis.literal(None, type="map"), id="null_map"), + ], +) +def test_map_values_nulls(con, map): + assert con.execute(map.values()) is None + + +@pytest.mark.notimpl( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) +@pytest.mark.parametrize( + ("map", "key"), + [ + param( + ibis.map( + ibis.literal(["a", "b"]), ibis.literal(["c", "d"], type="array") + ), + ibis.literal(None, type="string"), + marks=[ + pytest.mark.broken( + ["pandas", "dask"], + reason="result is False instead of None", + strict=False, # passes for contains, but not for get + ) + ], + id="non_null_map_null_key", + ), + param( + ibis.map( + ibis.literal(None, type="array"), + ibis.literal(None, type="array"), + ), + "a", + marks=[ + pytest.mark.notyet("clickhouse", reason="nested types can't be NULL"), + pytest.mark.broken( + ["pandas", "dask"], reason="TypeError: iteration over a 0-d array" + ), + ], + id="null_both_non_null_key", + ), + param( + ibis.map( + ibis.literal(None, type="array"), + ibis.literal(None, type="array"), + ), + ibis.literal(None, type="string"), + marks=[ + pytest.mark.notyet("clickhouse", reason="nested types can't be NULL"), + pytest.mark.broken( + ["pandas", "dask"], reason="TypeError: iteration over a 0-d array" + ), + ], + id="null_both_null_key", + ), + param( + ibis.literal(None, type="map"), + "a", + marks=[ + pytest.mark.notyet("clickhouse", reason="nested types can't be NULL") + ], + id="null_map_non_null_key", + ), + param( + ibis.literal(None, type="map"), + ibis.literal(None, type="string"), + marks=[ + pytest.mark.notyet("clickhouse", reason="nested types can't be NULL") + ], + id="null_map_null_key", + ), + ], +) +@pytest.mark.parametrize("method", ["get", "contains"]) +def test_map_get_contains_nulls(con, map, key, method): + expr = getattr(map, method) + assert con.execute(expr(key)) is None + + +@pytest.mark.notyet("clickhouse", reason="nested types can't be NULL") +@pytest.mark.notimpl( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) +@pytest.mark.parametrize( + ("m1", "m2"), + [ + param( + ibis.literal(None, type="map"), + ibis.literal({"a": "b"}, type="map"), + id="null_and_non_null", + ), + param( + ibis.literal({"a": "b"}, type="map"), + ibis.literal(None, type="map"), + id="non_null_and_null", + ), + param( + ibis.literal(None, type="map"), + ibis.literal(None, type="map"), + id="null_and_null", + ), + ], +) +def test_map_merge_nulls(con, m1, m2): + concatted = m1 + m2 + assert con.execute(concatted) is None + + @pytest.mark.notimpl(["pandas", "dask"]) def test_map_table(backend): table = backend.map