diff --git a/ibis/backends/dask/executor.py b/ibis/backends/dask/executor.py index 58481ed654c9e..ac0c28473187e 100644 --- a/ibis/backends/dask/executor.py +++ b/ibis/backends/dask/executor.py @@ -155,7 +155,7 @@ def mapper(df, cases): return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype) @classmethod - def visit(cls, op: ops.Array, exprs): + def visit(cls, op: ops.Array, exprs, dtype): return cls.rowwise( lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object ) diff --git a/ibis/backends/dask/helpers.py b/ibis/backends/dask/helpers.py index 1ca8d191c29a7..dec137a8f9319 100644 --- a/ibis/backends/dask/helpers.py +++ b/ibis/backends/dask/helpers.py @@ -30,7 +30,7 @@ def concat(cls, dfs, **kwargs): @classmethod def asseries(cls, value, like=None): - """Ensure that value is a pandas Series object, broadcast if necessary.""" + """Ensure that value is a dask Series object, broadcast if necessary.""" if isinstance(value, dd.Series): return value @@ -50,7 +50,7 @@ def asseries(cls, value, like=None): elif isinstance(value, pd.Series): return dd.from_pandas(value, npartitions=1) elif like is not None: - if isinstance(value, (tuple, list, dict)): + if isinstance(value, (tuple, list, dict, np.ndarray)): fn = lambda df: pd.Series([value] * len(df), index=df.index) else: fn = lambda df: pd.Series(value, index=df.index) diff --git a/ibis/backends/exasol/compiler.py b/ibis/backends/exasol/compiler.py index 7a837e41b9c8f..1ac4a4999e3c4 100644 --- a/ibis/backends/exasol/compiler.py +++ b/ibis/backends/exasol/compiler.py @@ -75,6 +75,7 @@ class ExasolCompiler(SQLGlotCompiler): ops.StringSplit, ops.StringToDate, ops.StringToTimestamp, + ops.StructColumn, ops.TimeDelta, ops.TimestampAdd, ops.TimestampBucket, diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index e0b3e19940f90..e78dae9b48365 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -49,6 +49,8 @@ def visit(cls, op: 
ops.Node, **kwargs): @classmethod def visit(cls, op: ops.Literal, value, dtype): + if value is None: + return None if dtype.is_interval(): value = pd.Timedelta(value, dtype.unit.short) elif dtype.is_array(): @@ -220,7 +222,7 @@ def visit(cls, op: ops.FindInSet, needle, values): return pd.Series(result, name=op.name) @classmethod - def visit(cls, op: ops.Array, exprs): + def visit(cls, op: ops.Array, exprs, dtype): return cls.rowwise(lambda row: np.array(row, dtype=object), exprs) @classmethod diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 7cfd3e6a81343..b64198da95426 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -86,10 +86,14 @@ def _make_duration(value, dtype): def literal(op, **_): value = op.value dtype = op.dtype + if dtype.is_interval(): + return _make_duration(value, dtype) - if dtype.is_array(): + typ = PolarsType.from_ibis(dtype) + if value is None: + return pl.lit(None, dtype=typ) + elif dtype.is_array(): value = pl.Series("", value) - typ = PolarsType.from_ibis(dtype) val = pl.lit(value, dtype=typ) return val.implode() elif dtype.is_struct(): @@ -98,14 +102,11 @@ def literal(op, **_): for k, v in value.items() ] return pl.struct(values) - elif dtype.is_interval(): - return _make_duration(value, dtype) elif dtype.is_null(): return pl.lit(value) elif dtype.is_binary(): return pl.lit(value) else: - typ = PolarsType.from_ibis(dtype) return pl.lit(op.value, dtype=typ) @@ -982,9 +983,11 @@ def array_concat(op, **kw): @translate.register(ops.Array) -def array_column(op, **kw): - cols = [translate(col, **kw) for col in op.exprs] - return pl.concat_list(cols) +def array_literal(op, **kw): + if op.exprs: + return pl.concat_list([translate(col, **kw) for col in op.exprs]) + else: + return pl.lit([], dtype=PolarsType.from_ibis(op.dtype)) @translate.register(ops.ArrayCollect) diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index 619916843c3b5..1548521e597fa 
100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -1019,8 +1019,11 @@ def visit_InSubquery(self, op, *, rel, needle): query = sg.select(STAR).from_(query) return needle.isin(query=query) - def visit_Array(self, op, *, exprs): - return self.f.array(*exprs) + def visit_Array(self, op, *, exprs, dtype): + result = self.f.array(*exprs) + if not exprs: + return self.cast(result, dtype) + return result def visit_StructColumn(self, op, *, names, values): return sge.Struct.from_arg_list( diff --git a/ibis/backends/sqlite/compiler.py b/ibis/backends/sqlite/compiler.py index c2232a8856699..af82ab38eade9 100644 --- a/ibis/backends/sqlite/compiler.py +++ b/ibis/backends/sqlite/compiler.py @@ -62,6 +62,7 @@ class SQLiteCompiler(SQLGlotCompiler): ops.TimestampDiff, ops.StringToDate, ops.StringToTimestamp, + ops.StructColumn, ops.TimeDelta, ops.DateDelta, ops.TimestampDelta, diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 82beee0ddb686..9410f720a2636 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -32,6 +32,7 @@ SnowflakeProgrammingError, TrinoUserError, ) +from ibis.common.annotations import ValidationError from ibis.common.collections import frozendict pytestmark = [ @@ -73,6 +74,75 @@ # list. 
+def test_array_factory(con): + a = ibis.array([1, 2, 3]) + assert a.type() == dt.Array(value_type=dt.Int8) + assert con.execute(a) == [1, 2, 3] + + a2 = ibis.array(a) + assert a2.type() == dt.Array(value_type=dt.Int8) + assert con.execute(a2) == [1, 2, 3] + + +def test_array_factory_typed(con): + typed = ibis.array([1, 2, 3], type="array<string>") + assert con.execute(typed) == ["1", "2", "3"] + + typed2 = ibis.array(ibis.array([1, 2, 3]), type="array<string>") + assert con.execute(typed2) == ["1", "2", "3"] + + +@pytest.mark.notimpl("flink", raises=Py4JJavaError) +@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError) +def test_array_factory_empty(con): + with pytest.raises(ValidationError): + ibis.array([]) + + empty_typed = ibis.array([], type="array<string>") + assert empty_typed.type() == dt.Array(value_type=dt.string) + assert con.execute(empty_typed) == [] + + +@pytest.mark.notyet( + "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL" +) +@pytest.mark.notyet( + "flink", raises=Py4JJavaError, reason="Parameters must be of the same type" +) +def test_array_factory_null(con): + with pytest.raises(ValidationError): + ibis.array(None) + with pytest.raises(ValidationError): + ibis.array(None, type="int64") + none_typed = ibis.array(None, type="array<string>") + assert none_typed.type() == dt.Array(value_type=dt.string) + assert con.execute(none_typed) is None + + nones = ibis.array([None, None], type="array<string>") + assert nones.type() == dt.Array(value_type=dt.string) + assert con.execute(nones) == [None, None] + + # Execute a real value here, so the backends that don't support arrays + # actually xfail as we expect them to. + # Otherwise would have to @mark.xfail every test in this file besides this one.
+ assert con.execute(ibis.array([1, 2])) == [1, 2] + + +@pytest.mark.broken( + ["datafusion", "flink", "polars"], + raises=AssertionError, + reason="[None, 1] executes to [np.nan, 1.0]", +) +def test_array_factory_null_mixed(con): + none_and_val = ibis.array([None, 1]) + assert none_and_val.type() == dt.Array(value_type=dt.Int8) + assert con.execute(none_and_val) == [None, 1] + + none_and_val_typed = ibis.array([None, 1], type="array<string>") + assert none_and_val_typed.type() == dt.Array(value_type=dt.String) + assert con.execute(none_and_val_typed) == [None, "1"] + + def test_array_column(backend, alltypes, df): expr = ibis.array( [alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)] diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index 491efc46f2812..089528ad6ada8 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -11,6 +11,7 @@ import ibis.common.exceptions as exc import ibis.expr.datatypes as dt from ibis.backends.tests.errors import PsycoPg2InternalError, Py4JJavaError +from ibis.common.annotations import ValidationError pytestmark = [ pytest.mark.never( @@ -39,6 +40,31 @@ ) + +@mark_notimpl_risingwave_hstore +def test_map_factory_dict(con): + assert con.execute(ibis.map({"a": "b"})) == {"a": "b"} + assert con.execute(ibis.map({"a": "b"}, type="map<string, string>")) == {"a": "b"} + with pytest.raises(ValidationError): + ibis.map({1: 2}, type="array<string>") + + +@mark_notimpl_risingwave_hstore +def test_map_factory_keys_vals(con): + assert con.execute(ibis.map(["a"], ["b"])) == {"a": "b"} + assert con.execute(ibis.map(["a"], ["b"], type="map<string, string>")) == {"a": "b"} + with pytest.raises(ValidationError): + ibis.map(["a"], ["b"], type="array<string>") + + +@mark_notimpl_risingwave_hstore +def test_map_factory_expr(con): + m = ibis.map({"a": "b"}) + assert con.execute(ibis.map(m)) == {"a": "b"} + assert con.execute(ibis.map(m, type="map<string, string>")) == {"a": "b"} + with pytest.raises(ValidationError): + ibis.map(m, type="array<string>") + +
@pytest.mark.notyet("clickhouse", reason="nested types can't be NULL") @pytest.mark.broken(["pandas", "dask"], reason="TypeError: iteration over a 0-d array") @pytest.mark.notimpl( @@ -60,6 +86,12 @@ def test_map_nulls(con, k, v): m = ibis.map(k, v) assert con.execute(m) is None + assert con.execute(ibis.map(None, type="map")) is None + with pytest.raises(ValidationError): + ibis.map(None) + with pytest.raises(ValidationError): + ibis.map(None, type="array") + @pytest.mark.notyet("clickhouse", reason="nested types can't be NULL") @pytest.mark.broken(["pandas", "dask"], reason="TypeError: iteration over a 0-d array") @@ -503,6 +535,11 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df): marks=[ pytest.mark.notyet("clickhouse", reason="nested types can't be null"), mark_notyet_postgres, + pytest.mark.notimpl( + "flink", + raises=Py4JJavaError, + reason="Unexpected error in type inference logic of function 'COALESCE'", + ), ], id="struct", ), diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index 67dcfc048d880..784236cbc5027 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -12,11 +12,12 @@ @pytest.mark.parametrize( - "expr", + "expr,contains", [ - param(ibis.literal(1), id="simple_literal"), + param(ibis.literal(432), "432", id="simple_literal"), param( - ibis.array([1]), + ibis.array([432]), + "432", marks=[ pytest.mark.never( ["mysql", "mssql", "oracle", "impala", "sqlite"], @@ -27,11 +28,16 @@ id="array_literal", ), param( - ibis.struct(dict(a=1)), + ibis.struct(dict(abc=432)), + "432", marks=[ pytest.mark.never( ["impala", "mysql", "sqlite", "mssql", "exasol"], - raises=(NotImplementedError, exc.UnsupportedBackendType), + raises=( + exc.OperationNotDefinedError, + NotImplementedError, + exc.UnsupportedBackendType, + ), reason="structs not supported in the backend", ), pytest.mark.notimpl( @@ -48,8 +54,8 @@ reason="Not a SQL backend", ) @pytest.mark.notimpl(["polars"], reason="Not clear 
how to extract SQL from the backend") -def test_literal(backend, expr): - assert ibis.to_sql(expr, dialect=backend.name()) +def test_literal(backend, expr, contains): + assert contains in ibis.to_sql(expr, dialect=backend.name()) @pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL") diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index c791318f15d67..30089daf8a7de 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -14,12 +14,14 @@ import ibis.expr.datatypes as dt from ibis import util from ibis.backends.tests.errors import ( + ClickHouseDatabaseError, PolarsColumnNotFoundError, PsycoPg2InternalError, PsycoPg2SyntaxError, Py4JJavaError, PySparkAnalysisException, ) +from ibis.common.annotations import ValidationError from ibis.common.exceptions import IbisError pytestmark = [ @@ -28,6 +30,64 @@ pytest.mark.notimpl(["datafusion", "druid", "oracle", "exasol"]), ] +mark_notimpl_postgres_literals = pytest.mark.notimpl( + "postgres", reason="struct literals not implemented", raises=PsycoPg2SyntaxError +) + + +@pytest.mark.notimpl("risingwave") +@pytest.mark.broken("postgres", reason="JSON handling is buggy") +@pytest.mark.notimpl( + "flink", + raises=Py4JJavaError, + reason="Unexpected error in type inference logic of function 'COALESCE'", +) +def test_struct_factory(con): + s = ibis.struct({"a": 1, "b": 2}) + assert con.execute(s) == {"a": 1, "b": 2} + + s2 = ibis.struct(s) + assert con.execute(s2) == {"a": 1, "b": 2} + + typed = ibis.struct({"a": 1, "b": 2}, type="struct") + assert con.execute(typed) == {"a": "1", "b": "2"} + + typed2 = ibis.struct(s, type="struct") + assert con.execute(typed2) == {"a": "1", "b": "2"} + + items = ibis.struct([("a", 1), ("b", 2)]) + assert con.execute(items) == {"a": 1, "b": 2} + + +@pytest.mark.parametrize("val", [{}, []]) +def test_struct_factory_empty(con, val): + with pytest.raises(ValidationError): + ibis.struct(val, type="struct") + s = 
ibis.struct(val, type="struct<>") + result = con.execute(s) + assert result == {} + + +@pytest.mark.notimpl("risingwave") +@mark_notimpl_postgres_literals +@pytest.mark.notyet( + "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL" +) +@pytest.mark.broken( + "polars", + reason=r"pl.lit(None, type='struct') gives {'a': None}: https://github.com/pola-rs/polars/issues/3462", +) +def test_struct_factory_null(con): + with pytest.raises(ValidationError): + ibis.struct(None) + none_typed = ibis.struct(None, type="struct") + assert none_typed.type() == dt.Struct(fields={"a": dt.float64, "b": dt.float64}) + assert con.execute(none_typed) is None + # Execute a real value here, so the backends that don't support structs + # actually xfail as we expect them to. + # Otherwise would have to @mark.xfail every test in this file besides this one. + assert con.execute(ibis.struct({"a": 1, "b": 2})) == {"a": 1, "b": 2} + @pytest.mark.notimpl(["dask"]) @pytest.mark.parametrize( @@ -78,6 +138,9 @@ def test_all_fields(struct, struct_df): @pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.parametrize("field", ["a", "b", "c"]) +@pytest.mark.notyet( + ["flink"], reason="flink doesn't support creating struct columns from literals" +) def test_literal(backend, con, field): query = _STRUCT_LITERAL[field] dtype = query.type().to_pandas() @@ -87,7 +150,7 @@ def test_literal(backend, con, field): backend.assert_series_equal(result, expected.astype(dtype)) -@pytest.mark.notimpl(["postgres"]) +@mark_notimpl_postgres_literals @pytest.mark.parametrize("field", ["a", "b", "c"]) @pytest.mark.notyet( ["clickhouse"], reason="clickhouse doesn't support nullable nested types" @@ -137,14 +200,6 @@ def test_collect_into_struct(alltypes): assert len(val.loc[result.group == "1"].iat[0]["key"]) == 730 -@pytest.mark.notimpl( - ["postgres"], reason="struct literals not implemented", raises=PsycoPg2SyntaxError -) -@pytest.mark.notimpl( - ["risingwave"], - reason="struct 
literals not implemented", - raises=PsycoPg2InternalError, -) @pytest.mark.notimpl(["flink"], raises=Py4JJavaError, reason="not implemented in ibis") def test_field_access_after_case(con): s = ibis.struct({"a": 3}) @@ -240,12 +295,6 @@ def test_keyword_fields(con, nullable): raises=PolarsColumnNotFoundError, reason="doesn't seem to support IN-style subqueries on structs", ) -@pytest.mark.notimpl( - # https://github.com/pandas-dev/pandas/issues/58909 - ["pandas", "dask"], - raises=TypeError, - reason="unhashable type: 'dict'", -) @pytest.mark.xfail_version( pyspark=["pyspark<3.5"], reason="requires pyspark 3.5", diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index 6d68baab94c32..08734ab3eb6a5 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -9,7 +9,7 @@ import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz -from ibis.common.annotations import attribute +from ibis.common.annotations import ValidationError, attribute from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.operations.core import Unary, Value @@ -19,15 +19,24 @@ class Array(Value): """Construct an array.""" exprs: VarTuple[Value] + dtype: Optional[dt.Array] = None + + def __init__(self, exprs, dtype: dt.Array | None = None): + if len(exprs) == 0: + if dtype is None: + raise ValidationError("If values is empty, dtype must be provided") + if not isinstance(dtype, dt.Array): + raise ValidationError(f"dtype must be an array, got {dtype}") + elif dtype is None: + dtype = dt.Array(rlz.highest_precedence_dtype(exprs)) + super().__init__(exprs=exprs, dtype=dtype) @attribute - def shape(self): + def shape(self) -> ds.DataShape: + if len(self.exprs) == 0: + return ds.scalar return rlz.highest_precedence_shape(self.exprs) - @attribute - def dtype(self): - return dt.Array(rlz.highest_precedence_dtype(self.exprs)) - @public class ArrayLength(Unary): diff --git a/ibis/expr/operations/structs.py 
b/ibis/expr/operations/structs.py index 3e921cdd633c8..82596e415bf93 100644 --- a/ibis/expr/operations/structs.py +++ b/ibis/expr/operations/structs.py @@ -4,6 +4,7 @@ from public import public +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import ValidationError, attribute @@ -38,8 +39,6 @@ class StructColumn(Value): names: VarTuple[str] values: VarTuple[Value] - shape = rlz.shape_like("values") - def __init__(self, names, values): if len(names) != len(values): raise ValidationError( @@ -52,3 +51,9 @@ def __init__(self, names, values): def dtype(self) -> dt.DataType: dtypes = (value.dtype for value in self.values) return dt.Struct.from_tuples(zip(self.names, dtypes)) + + @attribute + def shape(self) -> ds.DataShape: + if len(self.values) == 0: + return ds.scalar + return rlz.highest_precedence_shape(self.values) diff --git a/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt index fbda1a87cc5fc..480a404803c71 100644 --- a/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt @@ -1,2 +1,2 @@ DummyTable - foo: Array([1]) \ No newline at end of file + foo: Array(exprs=[1], dtype=array) \ No newline at end of file diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index 2d9e5a8f5b3a3..3e732d2c3c567 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -5,14 +5,16 @@ from public import public +import ibis.expr.datatypes as dt import ibis.expr.operations as ops +import ibis.expr.types as ir +from ibis.common.annotations import ValidationError from ibis.common.deferred import Deferred, deferrable from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: from collections.abc import Callable, Iterable - import ibis.expr.types as ir from ibis.expr.types.typing 
import V import ibis.common.exceptions as com @@ -1067,7 +1069,11 @@ def __getitem__(self, index: int | ir.IntegerValue | slice) -> ir.Column: @public @deferrable -def array(values: Iterable[V]) -> ArrayValue: +def array( + values: ArrayValue | Iterable[V] | ir.NullValue | None, + *, + type: str | dt.DataType | None = None, +) -> ArrayValue: """Create an array expression. If any values are [column expressions](../concepts/datatypes.qmd) the @@ -1078,6 +1084,9 @@ def array(values: Iterable[V]) -> ArrayValue: ---------- values An iterable of Ibis expressions or Python literals + type + An instance of `ibis.expr.datatypes.DataType` or a string indicating + the Ibis type of `value`. eg `array`. Returns ------- @@ -1099,7 +1108,7 @@ def array(values: Iterable[V]) -> ArrayValue: >>> t = ibis.memtable({"a": [1, 2, 3], "b": [4, 5, 6]}) >>> ibis.array([t.a, 42, ibis.literal(None)]) ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ Array() ┃ + ┃ Array(Array) ┃ ┡━━━━━━━━━━━━━━━━━━━━━━┩ │ array │ ├──────────────────────┤ @@ -1108,15 +1117,38 @@ def array(values: Iterable[V]) -> ArrayValue: │ [3, 42, ... 
+1] │ └──────────────────────┘ - >>> ibis.array([t.a, 42 + ibis.literal(5)]) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ Array() ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ array │ - ├──────────────────────┤ - │ [1, 47] │ - │ [2, 47] │ - │ [3, 47] │ - └──────────────────────┘ + >>> ibis.array([t.a, 42 + ibis.literal(5)], type="array") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Cast(Array(Array), array) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├────────────────────────────────────┤ + │ [1.0, 47.0] │ + │ [2.0, 47.0] │ + │ [3.0, 47.0] │ + └────────────────────────────────────┘ """ - return ops.Array(tuple(values)).to_expr() + type = dt.dtype(type) if type is not None else None + if type is not None and not isinstance(type, dt.Array): + raise ValidationError(f"dtype must be an array, got {type}") + + if isinstance(values, ir.Value): + if type is not None: + return values.cast(type) + elif isinstance(values, ArrayValue): + return values + else: + raise ValidationError( + f"If no type passed, values must be an array, got {values.type()}" + ) + + if values is None: + if type is None: + raise ValidationError("If values is None/NULL, dtype must be provided") + return ir.null(type) + + values = tuple(values) + if len(values) > 0 and type is not None: + return ops.Array(values).to_expr().cast(type) + else: + return ops.Array(values, type).to_expr() diff --git a/ibis/expr/types/maps.py b/ibis/expr/types/maps.py index 8672206a6873e..cb8a9947a6f50 100644 --- a/ibis/expr/types/maps.py +++ b/ibis/expr/types/maps.py @@ -1,18 +1,21 @@ from __future__ import annotations +from collections.abc import Mapping from typing import TYPE_CHECKING, Any from public import public +import ibis import ibis.expr.operations as ops +import ibis.expr.types as ir +from ibis.common.annotations import ValidationError from ibis.common.deferred import deferrable from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc 
import Iterable - import ibis.expr.types as ir - from ibis.expr.types.arrays import ArrayValue + import ibis.expr.datatypes as dt @public @@ -441,8 +444,15 @@ def __getitem__(self, key: ir.Value) -> ir.Column: @public @deferrable def map( - keys: Iterable[Any] | Mapping[Any, Any] | ArrayValue, - values: Iterable[Any] | ArrayValue | None = None, + keys: Iterable[Any] + | Mapping[Any, Any] + | ir.ArrayValue + | MapValue + | ir.NullValue + | None, + values: Iterable[Any] | ir.ArrayValue | None = None, + *, + type: str | dt.DataType | None = None, ) -> MapValue: """Create a MapValue. @@ -455,6 +465,9 @@ def map( Keys of the map or `Mapping`. If `keys` is a `Mapping`, `values` must be `None`. values Values of the map or `None`. If `None`, the `keys` argument must be a `Mapping`. + type + An instance of `ibis.expr.datatypes.DataType` or a string indicating + the Ibis type of `value`. eg `map`. Returns ------- @@ -484,16 +497,42 @@ def map( │ ['a', 'b'] │ [1, 2] │ │ ['b'] │ [3] │ └──────────────────────┴──────────────────────┘ - >>> ibis.map(t.keys, t.values) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ Map(keys, values) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ map │ - ├──────────────────────┤ - │ {'a': 1, 'b': 2} │ - │ {'b': 3} │ - └──────────────────────┘ + >>> ibis.map(t.keys, t.values, type="map") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Map(keys, Cast(values, array)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ map │ + ├─────────────────────────────────────────┤ + │ {'a': 1.0, 'b': 2.0} │ + │ {'b': 3.0} │ + └─────────────────────────────────────────┘ """ - if values is None: + from ibis.expr import datatypes as dt + + type = dt.dtype(type) if type is not None else None + if type is not None and not isinstance(type, dt.Map): + raise ValidationError(f"dtype must be a map, got {type}") + + if isinstance(keys, Mapping) and values is None: keys, values = tuple(keys.keys()), tuple(keys.values()) + + if isinstance(keys, ir.Value) and values is None: + if type is not None: + 
return keys.cast(type) + elif isinstance(keys, MapValue): + return keys + else: + raise ValidationError( + f"If no type passed, value must be a map, got {keys.type()}" + ) + + if keys is None or values is None: + if type is None: + raise ValidationError("If keys is None/NULL, dtype must be provided") + return ir.null(type) + + k_type = dt.Array(value_type=type.key_type) if type is not None else None + v_type = dt.Array(value_type=type.value_type) if type is not None else None + keys = ibis.array(keys, type=k_type) + values = ibis.array(values, type=v_type) return ops.Map(keys, values).to_expr() diff --git a/ibis/expr/types/structs.py b/ibis/expr/types/structs.py index 7de4c28451906..14d02ac16fc11 100644 --- a/ibis/expr/types/structs.py +++ b/ibis/expr/types/structs.py @@ -1,28 +1,33 @@ from __future__ import annotations -import collections from keyword import iskeyword from typing import TYPE_CHECKING from public import public +import ibis.expr.datatypes as dt import ibis.expr.operations as ops +import ibis.expr.types as ir +from ibis.common.annotations import ValidationError from ibis.common.deferred import deferrable from ibis.common.exceptions import IbisError -from ibis.expr.types.generic import Column, Scalar, Value, literal +from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: from collections.abc import Iterable, Mapping, Sequence - import ibis.expr.datatypes as dt - import ibis.expr.types as ir from ibis.expr.types.typing import V @public @deferrable def struct( - value: Iterable[tuple[str, V]] | Mapping[str, V], + value: Iterable[tuple[str, V]] + | Mapping[str, V] + | StructValue + | ir.NullValue + | None, + *, type: str | dt.DataType | None = None, ) -> StructValue: """Create a struct expression. @@ -37,8 +42,7 @@ def struct( `(str, Value)`. type An instance of `ibis.expr.datatypes.DataType` or a string indicating - the Ibis type of `value`. This is only used if all of the input values - are Python literals. eg `struct`. 
+ the Ibis type of `value`. eg `struct`. Returns ------- @@ -66,26 +70,49 @@ def struct( Create a struct column from a column and a scalar literal >>> t = ibis.memtable({"a": [1, 2, 3]}) - >>> ibis.struct([("a", t.a), ("b", "foo")]) - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ StructColumn() ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ struct │ - ├─────────────────────────────┤ - │ {'a': 1, 'b': 'foo'} │ - │ {'a': 2, 'b': 'foo'} │ - │ {'a': 3, 'b': 'foo'} │ - └─────────────────────────────┘ + >>> ibis.struct([("a", t.a), ("b", "foo")], type="struct") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Cast(StructColumn(), struct) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ struct │ + ├─────────────────────────────────────────────────────┤ + │ {'a': 1.0, 'b': 'foo'} │ + │ {'a': 2.0, 'b': 'foo'} │ + │ {'a': 3.0, 'b': 'foo'} │ + └─────────────────────────────────────────────────────┘ """ import ibis.expr.operations as ops + type = dt.dtype(type) if type is not None else None + if type is not None and not isinstance(type, dt.Struct): + raise ValidationError(f"dtype must be an struct, got {type}") + + if isinstance(value, ir.Value): + if type is not None: + return value.cast(type) + elif isinstance(value, StructValue): + return value + else: + raise ValidationError( + f"If no type passed, value must be a struct, got {value.type()}" + ) + + if value is None: + if type is None: + raise ValidationError("If values is None/NULL, dtype must be provided") + return ir.null(type) + fields = dict(value) - if any(isinstance(value, Value) for value in fields.values()): - names = tuple(fields.keys()) - values = tuple(fields.values()) - return ops.StructColumn(names=names, values=values).to_expr() - else: - return literal(collections.OrderedDict(fields), type=type) + names = fields.keys() + result = ops.StructColumn(names=names, values=fields.values()).to_expr() + if type is not None: + if not set(names).issuperset(type.names): + raise ValidationError( + f"The 
passed type requires fields {type.names}", + f" but only found fields {names}", + ) + result = result.cast(type) + return result @public diff --git a/ibis/tests/expr/test_literal.py b/ibis/tests/expr/test_literal.py index da056ff140f6c..0b84109030566 100644 --- a/ibis/tests/expr/test_literal.py +++ b/ibis/tests/expr/test_literal.py @@ -8,7 +8,7 @@ import ibis import ibis.expr.datatypes as dt -from ibis.common.collections import frozendict +from ibis.common.annotations import ValidationError from ibis.expr.operations import Literal from ibis.tests.util import assert_pickle_roundtrip @@ -109,23 +109,19 @@ def test_normalized_underlying_value(userinput, literal_type, expected_type): def test_struct_literal(value): typestr = "struct" a = ibis.struct(value, type=typestr) - assert a.op().value == frozendict( - field1=str(value["field1"]), field2=float(value["field2"]) - ) assert a.type() == dt.dtype(typestr) @pytest.mark.parametrize( "value", [ - dict(field1="value1", field3=3.14), # wrong field name - dict(field1="value1"), # missing field + pytest.param(dict(field1="value1", field3=3.14), id="wrong_field"), + pytest.param(dict(field1="value1"), id="missing_field"), ], ) def test_struct_literal_non_castable(value): - typestr = "struct" - with pytest.raises(TypeError, match="Unable to normalize"): - ibis.struct(value, type=typestr) + with pytest.raises(ValidationError): + ibis.struct(value, type="struct") def test_struct_cast_to_empty_struct(): @@ -133,18 +129,9 @@ def test_struct_cast_to_empty_struct(): assert value.type().castable(dt.Struct({})) -@pytest.mark.parametrize( - "value", - [ - dict(key1="value1", key2="value2"), - ], -) -def test_map_literal(value): - typestr = "map" +def test_map_literal(): a = ibis.map(["a", "b"], [1, 2]) - assert a.op().keys.value == ("a", "b") - assert a.op().values.value == (1, 2) - assert a.type() == dt.dtype(typestr) + assert a.type() == dt.dtype("map") @pytest.mark.parametrize(