From a8e6dd8038d66b7701aac74dbfb5c19b3ce2bc1d Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Fri, 28 Jun 2024 18:59:26 -0800 Subject: [PATCH] feat: support empty arrays, improve ibis.array() API Picking out the array stuff from https://github.com/ibis-project/ibis/pull/8666 --- ibis/backends/dask/helpers.py | 4 +- ibis/backends/pandas/executor.py | 9 ++- ibis/backends/polars/compiler.py | 24 ++++-- ibis/backends/sql/compiler.py | 3 + ibis/backends/tests/test_array.py | 74 +++++++++++++++++-- ibis/backends/tests/test_generic.py | 7 +- ibis/backends/tests/test_sql.py | 18 ++++- ibis/backends/tests/test_string.py | 10 +-- ibis/expr/operations/arrays.py | 16 +++- ibis/expr/rules.py | 4 + .../test_format_dummy_table/repr.txt | 2 +- ibis/expr/types/arrays.py | 64 +++++++++++++--- 12 files changed, 192 insertions(+), 43 deletions(-) diff --git a/ibis/backends/dask/helpers.py b/ibis/backends/dask/helpers.py index 1ca8d191c29a7..dec137a8f9319 100644 --- a/ibis/backends/dask/helpers.py +++ b/ibis/backends/dask/helpers.py @@ -30,7 +30,7 @@ def concat(cls, dfs, **kwargs): @classmethod def asseries(cls, value, like=None): - """Ensure that value is a pandas Series object, broadcast if necessary.""" + """Ensure that value is a dask Series object, broadcast if necessary.""" if isinstance(value, dd.Series): return value @@ -50,7 +50,7 @@ def asseries(cls, value, like=None): elif isinstance(value, pd.Series): return dd.from_pandas(value, npartitions=1) elif like is not None: - if isinstance(value, (tuple, list, dict)): + if isinstance(value, (tuple, list, dict, np.ndarray)): fn = lambda df: pd.Series([value] * len(df), index=df.index) else: fn = lambda df: pd.Series(value, index=df.index) diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index a3153d17b8b47..868a08a98dfc8 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -49,7 +49,9 @@ def visit(cls, op: ops.Node, **kwargs): @classmethod def visit(cls, op: 
ops.Literal, value, dtype): - if dtype.is_interval(): + if value is None: + value = None + elif dtype.is_interval(): value = pd.Timedelta(value, dtype.unit.short) elif dtype.is_array(): value = np.array(value) @@ -219,6 +221,11 @@ def visit(cls, op: ops.FindInSet, needle, values): result = np.select(condlist, choicelist, default=-1) return pd.Series(result, name=op.name) + @classmethod + def visit(cls, op: ops.EmptyArray, dtype): + pdt = PandasType.from_ibis(dtype) + return np.array([], dtype=pdt) + @classmethod def visit(cls, op: ops.Array, exprs): return cls.rowwise(lambda row: np.array(row, dtype=object), exprs) diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 4d9a497191b4a..9ad461d0d120c 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -87,25 +87,27 @@ def literal(op, **_): value = op.value dtype = op.dtype - if dtype.is_array(): - value = pl.Series("", value) - typ = PolarsType.from_ibis(dtype) - val = pl.lit(value, dtype=typ) - return val.implode() + # There are some interval types that _make_duration() can handle, + # but PolarsType.from_ibis can't, so we need to handle them here. 
+ if dtype.is_interval(): + return _make_duration(value, dtype) + + typ = PolarsType.from_ibis(dtype) + if value is None: + return pl.lit(None, dtype=typ) + elif dtype.is_array(): + return pl.lit(pl.Series("", value).implode(), dtype=typ) elif dtype.is_struct(): values = [ pl.lit(v, dtype=PolarsType.from_ibis(dtype[k])).alias(k) for k, v in value.items() ] return pl.struct(values) - elif dtype.is_interval(): - return _make_duration(value, dtype) elif dtype.is_null(): return pl.lit(value) elif dtype.is_binary(): return pl.lit(value) else: - typ = PolarsType.from_ibis(dtype) return pl.lit(op.value, dtype=typ) @@ -973,6 +975,12 @@ def array_concat(op, **kw): return result +@translate.register(ops.EmptyArray) +def empty_array(op, **kw): + pdt = PolarsType.from_ibis(op.dtype) + return pl.lit([], dtype=pdt) + + @translate.register(ops.Array) def array_column(op, **kw): cols = [translate(col, **kw) for col in op.exprs] diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index 09f26b80d62fe..18699aa560e54 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -1019,6 +1019,9 @@ def visit_InSubquery(self, op, *, rel, needle): query = sg.select(STAR).from_(query) return needle.isin(query=query) + def visit_EmptyArray(self, op, *, dtype): + return self.cast(self.f.array(), dtype) + def visit_Array(self, op, *, exprs): return self.f.array(*exprs) diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 8b55c189e9d6b..70fb7cbe18a1a 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -31,6 +31,7 @@ PySparkAnalysisException, TrinoUserError, ) +from ibis.common.annotations import ValidationError from ibis.common.collections import frozendict pytestmark = [ @@ -72,6 +73,74 @@ # list. 
+def test_array_factory(con): + a = ibis.array([1, 2, 3]) + assert a.type() == dt.Array(value_type=dt.Int8) + assert con.execute(a) == [1, 2, 3] + + a2 = ibis.array(a) + assert a.type() == dt.Array(value_type=dt.Int8) + assert con.execute(a2) == [1, 2, 3] + + +def test_array_factory_typed(con): + typed = ibis.array([1, 2, 3], type="array") + assert con.execute(typed) == ["1", "2", "3"] + + typed2 = ibis.array(ibis.array([1, 2, 3]), type="array") + assert con.execute(typed2) == ["1", "2", "3"] + + +@pytest.mark.notimpl("flink", raises=Py4JJavaError) +def test_array_factory_empty(con): + with pytest.raises(ValidationError): + ibis.array([]) + + empty_typed = ibis.array([], type="array") + assert empty_typed.type() == dt.Array(value_type=dt.string) + assert con.execute(empty_typed) == [] + + +@pytest.mark.notyet( + "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL" +) +@pytest.mark.notyet( + "flink", raises=Py4JJavaError, reason="Parameters must be of the same type" +) +def test_array_factory_null(con): + with pytest.raises(ValidationError): + ibis.array(None) + with pytest.raises(ValidationError): + ibis.array(None, type="int64") + none_typed = ibis.array(None, type="array") + assert none_typed.type() == dt.Array(value_type=dt.string) + assert con.execute(none_typed) is None + + nones = ibis.array([None, None], type="array") + assert nones.type() == dt.Array(value_type=dt.string) + assert con.execute(nones) == [None, None] + + # Execute a real value here, so the backends that don't support arrays + # actually xfail as we expect them to. + # Otherwise would have to @mark.xfail every test in this file besides this one. 
+ assert con.execute(ibis.array([1, 2])) == [1, 2] + + +@pytest.mark.broken( + ["datafusion", "flink", "polars"], + raises=AssertionError, + reason="[None, 1] executes to [np.nan, 1.0]", +) +def test_array_factory_null_mixed(con): + none_and_val = ibis.array([None, 1]) + assert none_and_val.type() == dt.Array(value_type=dt.Int8) + assert con.execute(none_and_val) == [None, 1] + + none_and_val_typed = ibis.array([None, 1], type="array") + assert none_and_val_typed.type() == dt.Array(value_type=dt.String) + assert con.execute(none_and_val_typed) == [None, "1"] + + def test_array_column(backend, alltypes, df): expr = ibis.array( [alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)] @@ -1354,11 +1423,6 @@ def test_unnest_range(con): id="array", marks=[ pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest), - pytest.mark.broken( - ["polars"], - reason="expression input not supported with nested arrays", - raises=TypeError, - ), ], ), ], diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 3ed4a9db8cc5e..9f9a5ce5e4c96 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -1431,13 +1431,12 @@ def query(t, group_cols): snapshot.assert_match(str(ibis.to_sql(t3, dialect=con.name)), "out.sql") -@pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl(["druid"], raises=AssertionError) @pytest.mark.notyet( - ["datafusion", "impala", "mssql", "mysql", "sqlite"], + ["datafusion", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"], reason="backend doesn't support arrays and we don't implement pivot_longer with unions yet", - raises=com.OperationNotDefinedError, + raises=(com.OperationNotDefinedError, com.UnsupportedBackendType), ) +@pytest.mark.notimpl(["druid"], raises=AssertionError) @pytest.mark.broken( ["trino"], reason="invalid code generated for unnesting a struct", diff --git a/ibis/backends/tests/test_sql.py 
b/ibis/backends/tests/test_sql.py index 777cfa3db8bb3..ca5b6488d6096 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -19,7 +19,7 @@ ibis.array([432]), marks=[ pytest.mark.never( - ["mysql", "mssql", "oracle", "impala", "sqlite"], + ["exasol", "mysql", "mssql", "oracle", "impala", "sqlite"], raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType), reason="arrays not supported in the backend", ), @@ -30,8 +30,18 @@ ibis.struct(dict(abc=432)), marks=[ pytest.mark.never( - ["impala", "mysql", "sqlite", "mssql", "exasol"], - raises=(NotImplementedError, exc.UnsupportedBackendType), + [ + "exasol", + "impala", + "mysql", + "sqlite", + "mssql", + ], + raises=( + exc.OperationNotDefinedError, + NotImplementedError, + exc.UnsupportedBackendType, + ), reason="structs not supported in the backend", ), pytest.mark.notimpl( @@ -104,7 +114,7 @@ def test_isin_bug(con, snapshot): @pytest.mark.notyet( ["datafusion", "exasol", "oracle", "flink", "risingwave"], reason="no unnest support", - raises=exc.OperationNotDefinedError, + raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType), ) @pytest.mark.notyet( ["sqlite", "mysql", "druid", "impala", "mssql"], reason="no unnest support upstream" diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index ceb9fdc77711b..d56f5f934024c 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -835,6 +835,11 @@ def test_capitalize(con, inp, expected): assert pd.isnull(result) +@pytest.mark.never( + ["exasol", "impala", "mssql", "mysql", "sqlite"], + reason="Backend doesn't support arrays", + raises=(com.OperationNotDefinedError, com.UnsupportedBackendType), +) @pytest.mark.notimpl( [ "dask", @@ -842,11 +847,6 @@ def test_capitalize(con, inp, expected): "polars", "oracle", "flink", - "sqlite", - "mssql", - "mysql", - "exasol", - "impala", ], raises=com.OperationNotDefinedError, ) diff --git a/ibis/expr/operations/arrays.py 
b/ibis/expr/operations/arrays.py index 6d68baab94c32..16b55d4efdb54 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -14,15 +14,25 @@ from ibis.expr.operations.core import Unary, Value +@public +class EmptyArray(Value): + """Construct an Empty array.""" + + dtype: dt.Array + shape = ds.scalar + + @public class Array(Value): """Construct an array.""" exprs: VarTuple[Value] - @attribute - def shape(self): - return rlz.highest_precedence_shape(self.exprs) + shape = rlz.shape_like("exprs") + + def __init__(self, exprs): + assert len(exprs) > 0, "Use EmptyArray to create an empty array" + super().__init__(exprs=exprs) @attribute def dtype(self): diff --git a/ibis/expr/rules.py b/ibis/expr/rules.py index 0c865297889f4..5f681cab8ec80 100644 --- a/ibis/expr/rules.py +++ b/ibis/expr/rules.py @@ -5,6 +5,7 @@ from public import public +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import util @@ -16,6 +17,9 @@ @public def highest_precedence_shape(nodes): + nodes = tuple(nodes) + if len(nodes) == 0: + return ds.scalar return max(node.shape for node in nodes) diff --git a/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt index fbda1a87cc5fc..480a404803c71 100644 --- a/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt @@ -1,2 +1,2 @@ DummyTable - foo: Array([1]) \ No newline at end of file + foo: Array(exprs=[1], dtype=array) \ No newline at end of file diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index 2d9e5a8f5b3a3..ce489b28ba37a 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -5,14 +5,17 @@ from public import public +import ibis +import ibis.expr.datatypes as dt import ibis.expr.operations as ops +import ibis.expr.types as ir +from 
ibis.common.annotations import ValidationError from ibis.common.deferred import Deferred, deferrable from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: from collections.abc import Callable, Iterable - import ibis.expr.types as ir from ibis.expr.types.typing import V import ibis.common.exceptions as com @@ -1067,7 +1070,11 @@ def __getitem__(self, index: int | ir.IntegerValue | slice) -> ir.Column: @public @deferrable -def array(values: Iterable[V]) -> ArrayValue: +def array( + values: ArrayValue | Iterable[V] | ir.NullValue | None, + *, + type: str | dt.DataType | None = None, +) -> ArrayValue: """Create an array expression. If any values are [column expressions](../concepts/datatypes.qmd) the @@ -1078,6 +1085,9 @@ def array(values: Iterable[V]) -> ArrayValue: ---------- values An iterable of Ibis expressions or Python literals + type + An instance of `ibis.expr.datatypes.DataType` or a string indicating + the Ibis type of `values`, e.g. `array<float>`. Returns ------- @@ -1099,7 +1109,7 @@ def array(values: Iterable[V]) -> ArrayValue: >>> t = ibis.memtable({"a": [1, 2, 3], "b": [4, 5, 6]}) >>> ibis.array([t.a, 42, ibis.literal(None)]) ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ Array() ┃ + ┃ Array(Array) ┃ ┡━━━━━━━━━━━━━━━━━━━━━━┩ │ array<int64> │ ├──────────────────────┤ @@ -1108,15 +1118,49 @@ def array(values: Iterable[V]) -> ArrayValue: │ [3, 42, ... 
+1] │ └──────────────────────┘ - >>> ibis.array([t.a, 42 + ibis.literal(5)]) + >>> ibis.array([t.a, 42 + ibis.literal(5)], type="array<float>") ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ Array() ┃ + ┃ Array(Array) ┃ ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ array<int64> │ + │ array<float64> │ ├──────────────────────┤ - │ [1, 47] │ - │ [2, 47] │ - │ [3, 47] │ + │ [1.0, 47.0] │ + │ [2.0, 47.0] │ + │ [3.0, 47.0] │ └──────────────────────┘ """ - return ops.Array(tuple(values)).to_expr() + type = dt.dtype(type) if type is not None else None + if type is not None and not isinstance(type, dt.Array): + raise ValidationError(f"type must be an array, got {type}") + + if isinstance(values, ir.Value): + if type is not None: + return values.cast(type) + elif isinstance(values, ArrayValue): + return values + else: + raise ValidationError( + f"If no type passed, values must be an array, got {values.type()}" + ) + + if values is None: + if type is None: + raise ValidationError("If values is None/NULL, type must be provided") + return ir.null(type) + + values = tuple(values) + if len(values) == 0: + if type is None: + raise ValidationError("If values is empty, type must be provided") + return ops.EmptyArray(type).to_expr() + else: + value_type = type.value_type if type is not None else None + values = [_value(v, value_type) for v in values] + return ops.Array(values).to_expr() + + +def _value(x, type) -> ir.Value: + if isinstance(x, (ir.Value, Deferred)): + return x.cast(type) if type is not None else x + else: + return ibis.literal(x, type=type)