From 32e7636e1bddc9fda917d73c24d9273325bae8c3 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Wed, 8 May 2024 17:24:13 -0800 Subject: [PATCH] feat: Improve array(), map(), and struct fixes https://github.com/ibis-project/ibis/issues/8289 This does a lot of changes. It was hard for me to separate them out as I implemented them. But now that it's all hashed out, I can try to split this up into separate commits if you want. But that might be sorta hard in some cases. One this is adding support for passing in None to all these constructors. These use the new `ibis.null()` API to return `op.Literal(None, )`s Make these constructors idempotent: you can pass in existing Expressions into array(), etc. The type argument for all of these now always has an effect, not just when passing in python literals. So basically it acts like a cast. A big structural change is that now ops.Array has an optional attribute "dtype", so if you pass in a 0-length sequence of values the op still knows what dtype it is. Several of the backends were always broken here, they just weren't getting caught. I marked them as broken, we can fix them in a followup. You can test this locally with eg `pytest -m -k factory ibis/backends/tests/test_array.py ibis/backends/tests/test_map.py ibis/backends/tests/test_struct.py` Also, fix a typing bug: map() can accept ArrayValues, not just ArrayColumns. Also, fix executing Literal(None) on pandas and polars, 0-length arrays on polars Also, fixing converting dtypes on clickhouse, Structs should be converted to nonnullable dtypes. Also, implement ops.StructColumn on pandas and dask --- ibis/backends/dask/executor.py | 8 +++- ibis/backends/pandas/executor.py | 18 +++++--- ibis/backends/polars/compiler.py | 13 +++--- ibis/backends/sql/compiler.py | 7 ++- ibis/backends/sql/datatypes.py | 6 ++- ibis/backends/tests/test_array.py | 38 ++++++++++++++++ ibis/backends/tests/test_map.py | 13 +++++- ibis/backends/tests/test_struct.py | 59 +++++++++++++++++++++---- ibis/expr/operations/arrays.py | 19 +++++--- ibis/expr/operations/structs.py | 9 ++-- ibis/expr/types/arrays.py | 58 +++++++++++++++++++------ ibis/expr/types/maps.py | 69 +++++++++++++++++++++++------- ibis/expr/types/structs.py | 69 ++++++++++++++++++++---------- 13 files changed, 301 insertions(+), 85 deletions(-) diff --git a/ibis/backends/dask/executor.py b/ibis/backends/dask/executor.py index e35a4140481ce..ac0c28473187e 100644 --- a/ibis/backends/dask/executor.py +++ b/ibis/backends/dask/executor.py @@ -155,11 +155,17 @@ def mapper(df, cases): return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype) @classmethod - def visit(cls, op: ops.Array, exprs): + def visit(cls, op: ops.Array, exprs, dtype): return cls.rowwise( lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object ) + @classmethod + def visit(cls, op: ops.StructColumn, names, values): + return cls.rowwise( + lambda row: dict(zip(names, row)), values, name=op.name, dtype=object + ) + @classmethod def visit(cls, op: ops.ArrayConcat, arg): dtype = PandasType.from_ibis(op.dtype) diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index 858f49173464b..2ab04d52791b5 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -49,12 +49,14 @@ def visit(cls, op: ops.Node, **kwargs): @classmethod def visit(cls, op: ops.Literal, value, dtype): + if value is None: + return None if dtype.is_interval(): - value = pd.Timedelta(value, dtype.unit.short) - elif dtype.is_array(): - value = np.array(value) - elif dtype.is_date(): - value = pd.Timestamp(value, tz="UTC").tz_localize(None) + return pd.Timedelta(value, dtype.unit.short) + if dtype.is_array(): + return np.array(value) + if dtype.is_date(): + return pd.Timestamp(value, tz="UTC").tz_localize(None) return value @classmethod @@ -220,9 +222,13 @@ def visit(cls, op: ops.FindInSet, needle, values): return pd.Series(result, name=op.name) @classmethod - def visit(cls, op: ops.Array, exprs): + def visit(cls, op: ops.Array, exprs, dtype): return cls.rowwise(lambda row: np.array(row, dtype=object), exprs) + @classmethod + def visit(cls, op: ops.StructColumn, names, values): + return cls.rowwise(lambda row: dict(zip(names, row)), values) + @classmethod def visit(cls, op: ops.ArrayConcat, arg): return cls.rowwise(lambda row: np.concatenate(row.values), arg) diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index a793db0c609b8..cc4e2495eaf79 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -86,10 +86,12 @@ def _make_duration(value, dtype): def literal(op, **_): value = op.value dtype = op.dtype + typ = PolarsType.from_ibis(dtype) + if value is None: + return pl.lit(None, dtype=typ) if dtype.is_array(): value = pl.Series("", value) - typ = PolarsType.from_ibis(dtype) val = pl.lit(value, dtype=typ) return val.implode() elif dtype.is_struct(): @@ -105,7 +107,6 @@ def literal(op, **_): elif dtype.is_binary(): return pl.lit(value) else: - typ = PolarsType.from_ibis(dtype) return pl.lit(op.value, dtype=typ) @@ -980,9 +981,11 @@ def array_concat(op, **kw): @translate.register(ops.Array) -def array_column(op, **kw): - cols = [translate(col, **kw) for col in op.exprs] - return pl.concat_list(cols) +def array_literal(op, **kw): + if len(op.exprs) > 0: + return pl.concat_list([translate(col, **kw) for col in op.exprs]) + else: + return pl.lit([], dtype=PolarsType.from_ibis(op.dtype)) @translate.register(ops.ArrayCollect) diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index fb0c5a974c8b7..158b0ebc0d362 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -970,8 +970,11 @@ def visit_InSubquery(self, op, *, rel, needle): query = sg.select(STAR).from_(query) return needle.isin(query=query) - def visit_Array(self, op, *, exprs): - return self.f.array(*exprs) + def visit_Array(self, op, *, exprs, dtype): + result = self.f.array(*exprs) + if len(exprs) == 0: + return self.cast(result, dtype) + return result def visit_StructColumn(self, op, *, names, values): return sge.Struct.from_arg_list( diff --git a/ibis/backends/sql/datatypes.py b/ibis/backends/sql/datatypes.py index 4e492fb15f527..80c57d876b04b 100644 --- a/ibis/backends/sql/datatypes.py +++ b/ibis/backends/sql/datatypes.py @@ -1007,8 +1007,10 @@ class ClickHouseType(SqlglotType): def from_ibis(cls, dtype: dt.DataType) -> sge.DataType: """Convert a sqlglot type to an ibis type.""" typ = super().from_ibis(dtype) - if dtype.nullable and not (dtype.is_map() or dtype.is_array()): - # map cannot be nullable in clickhouse + # nested types cannot be nullable in clickhouse + if dtype.nullable and not ( + dtype.is_map() or dtype.is_array() or dtype.is_struct() + ): return sge.DataType(this=typecode.NULLABLE, expressions=[typ]) else: return typ diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 7340540f7269d..2a2004067511f 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -30,6 +30,7 @@ PySparkAnalysisException, TrinoUserError, ) +from ibis.common.annotations import ValidationError from ibis.common.collections import frozendict pytestmark = [ @@ -71,6 +72,43 @@ # list. +def test_array_factory(con): + a = ibis.array([1, 2, 3]) + assert con.execute(a) == [1, 2, 3] + + a2 = ibis.array(a) + assert con.execute(a2) == [1, 2, 3] + + typed = ibis.array([1, 2, 3], type="array") + assert con.execute(typed) == ["1", "2", "3"] + + typed2 = ibis.array(a, type="array") + assert con.execute(typed2) == ["1", "2", "3"] + + +@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError) +def test_array_factory_empty(con): + with pytest.raises(ValidationError): + ibis.array([]) + + empty_typed = ibis.array([], type="array") + assert empty_typed.type() == dt.Array(value_type=dt.string) + assert con.execute(empty_typed) == [] + + +@pytest.mark.notyet( + "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL" +) +def test_array_factory_null(con): + with pytest.raises(ValidationError): + ibis.array(None) + with pytest.raises(ValidationError): + ibis.array(None, type="int64") + none_typed = ibis.array(None, type="array") + assert none_typed.type() == dt.Array(value_type=dt.string) + assert con.execute(none_typed) is None + + def test_array_column(backend, alltypes, df): expr = ibis.array( [alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)] diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index 63cdf728aabf4..d0c8993748014 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -11,6 +11,7 @@ import ibis.common.exceptions as exc import ibis.expr.datatypes as dt from ibis.backends.tests.errors import PsycoPg2InternalError, Py4JJavaError +from ibis.common.annotations import ValidationError pytestmark = [ pytest.mark.never( @@ -121,6 +122,16 @@ def test_map_values_nulls(con, map): assert con.execute(map.values()) is None +def test_map_factory(con): + assert con.execute(ibis.map(None, type="map")) is None + with pytest.raises(ValidationError): + ibis.map(None) + with pytest.raises(ValidationError): + ibis.map(None, type="array") + with pytest.raises(ValidationError): + ibis.map({1: 2}, type="array") + + @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, @@ -669,6 +680,6 @@ def test_map_keys_unnest(backend): @mark_notimpl_risingwave_hstore def test_map_contains_null(con): - expr = ibis.map(["a"], ibis.literal([None], type="array")) + expr = ibis.map(["a"], ibis.array([None], type="array")) assert con.execute(expr.contains("a")) assert not con.execute(expr.contains("b")) diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index e7f078f7591a1..d20de15529131 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -21,7 +21,8 @@ Py4JJavaError, PySparkAnalysisException, ) -from ibis.common.exceptions import IbisError, OperationNotDefinedError +from ibis.common.annotations import ValidationError +from ibis.common.exceptions import IbisError pytestmark = [ pytest.mark.never(["mysql", "sqlite", "mssql"], reason="No struct support"), @@ -29,6 +30,49 @@ pytest.mark.notimpl(["datafusion", "druid", "oracle", "exasol"]), ] +mark_notimpl_postgres_literals = pytest.mark.notimpl( + "postgres", reason="struct literals not implemented", raises=PsycoPg2SyntaxError +) + + +@pytest.mark.broken("postgres", reason="JSON handling is buggy") +def test_struct_factory(con): + s = ibis.struct({"a": 1, "b": 2}) + assert con.execute(s) == {"a": 1, "b": 2} + + s2 = ibis.struct(s) + assert con.execute(s2) == {"a": 1, "b": 2} + + typed = ibis.struct({"a": 1, "b": 2}, type="struct") + assert con.execute(typed) == {"a": "1", "b": "2"} + + typed2 = ibis.struct(s, type="struct") + assert con.execute(typed2) == {"a": "1", "b": "2"} + + +def test_struct_factory_empty(): + with pytest.raises(ValidationError): + ibis.struct({}) + with pytest.raises(ValidationError): + ibis.struct({}, type="struct<>") + with pytest.raises(ValidationError): + ibis.struct({}, type="struct") + + +@mark_notimpl_postgres_literals +@pytest.mark.notyet( + "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL" +) +@pytest.mark.broken( + "polars", reason=r"pl.lit(None, type='struct') gives {'a': None}" +) +def test_struct_factory_null(con): + with pytest.raises(ValidationError): + ibis.struct(None) + none_typed = ibis.struct(None, type="struct") + assert none_typed.type() == dt.Struct(fields={"a": dt.float64, "b": dt.float64}) + assert con.execute(none_typed) is None + @pytest.mark.notimpl(["dask"]) @pytest.mark.parametrize( @@ -79,6 +123,9 @@ def test_all_fields(struct, struct_df): @pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.parametrize("field", ["a", "b", "c"]) +@pytest.mark.notyet( + ["flink"], reason="flink doesn't support creating struct columns from literals" +) def test_literal(backend, con, field): query = _STRUCT_LITERAL[field] dtype = query.type().to_pandas() @@ -88,7 +135,7 @@ def test_literal(backend, con, field): backend.assert_series_equal(result, expected.astype(dtype)) -@pytest.mark.notimpl(["postgres"]) +@mark_notimpl_postgres_literals @pytest.mark.parametrize("field", ["a", "b", "c"]) @pytest.mark.notyet( ["clickhouse"], reason="clickhouse doesn't support nullable nested types" @@ -101,7 +148,7 @@ def test_null_literal(backend, con, field): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) def test_struct_column(alltypes, df): t = alltypes expr = t.select(s=ibis.struct(dict(a=t.string_col, b=1, c=t.bigint_col))) @@ -113,7 +160,7 @@ def test_struct_column(alltypes, df): tm.assert_frame_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave", "polars"]) +@pytest.mark.notimpl(["postgres", "risingwave", "polars"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from collect" ) @@ -138,9 +185,6 @@ def test_collect_into_struct(alltypes): assert len(val.loc[result.group == "1"].iat[0]["key"]) == 730 -@pytest.mark.notimpl( - ["postgres"], reason="struct literals not implemented", raises=PsycoPg2SyntaxError -) @pytest.mark.notimpl( ["risingwave"], reason="struct literals not implemented", @@ -253,7 +297,6 @@ def test_keyword_fields(con, nullable): raises=PolarsColumnNotFoundError, reason="doesn't seem to support IN-style subqueries on structs", ) -@pytest.mark.notimpl(["pandas", "dask"], raises=OperationNotDefinedError) @pytest.mark.xfail_version( pyspark=["pyspark<3.5"], reason="requires pyspark 3.5", diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index 68ee711a2da6b..75c9648da47fa 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -7,7 +7,7 @@ import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz -from ibis.common.annotations import attribute +from ibis.common.annotations import ValidationError, attribute from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.operations.core import Unary, Value @@ -15,15 +15,24 @@ @public class Array(Value): exprs: VarTuple[Value] + dtype: Optional[dt.Array] = None + + def __init__(self, exprs, dtype: dt.Array | None = None): + if len(exprs) == 0: + if dtype is None: + raise ValidationError("If values is empty, dtype must be provided") + if not isinstance(dtype, dt.Array): + raise ValidationError(f"dtype must be an array, got {dtype}") + elif dtype is None: + dtype = dt.Array(rlz.highest_precedence_dtype(exprs)) + super().__init__(exprs=exprs, dtype=dtype) @attribute def shape(self): + if len(self.exprs) == 0: + return ds.scalar return rlz.highest_precedence_shape(self.exprs) - @attribute - def dtype(self): - return dt.Array(rlz.highest_precedence_dtype(self.exprs)) - @public class ArrayLength(Unary): diff --git a/ibis/expr/operations/structs.py b/ibis/expr/operations/structs.py index 20c0c3dc0a4ef..2b3de036948ac 100644 --- a/ibis/expr/operations/structs.py +++ b/ibis/expr/operations/structs.py @@ -34,7 +34,9 @@ class StructColumn(Value): shape = rlz.shape_like("values") - def __init__(self, names, values): + def __init__(self, names: VarTuple[str], values: VarTuple[Value]): + if len(names) == 0: + raise ValidationError("StructColumn must have at least one field") if len(names) != len(values): raise ValidationError( f"Length of names ({len(names)}) does not match length of " @@ -43,6 +45,5 @@ def __init__(self, names, values): super().__init__(names=names, values=values) @attribute - def dtype(self) -> dt.DataType: - dtypes = (value.dtype for value in self.values) - return dt.Struct.from_tuples(zip(self.names, dtypes)) + def dtype(self) -> dt.Struct: + return dt.Struct.from_tuples(zip(self.names, (v.dtype for v in self.values))) diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index bea22cf537925..f1f11849c0156 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -5,14 +5,16 @@ from public import public +import ibis.expr.datatypes as dt import ibis.expr.operations as ops +import ibis.expr.types as ir +from ibis.common.annotations import ValidationError from ibis.common.deferred import Deferred, deferrable from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: from collections.abc import Iterable - import ibis.expr.types as ir from ibis.expr.types.typing import V import ibis.common.exceptions as com @@ -1067,7 +1069,11 @@ def __getitem__(self, index: int | ir.IntegerValue | slice) -> ir.Column: @public @deferrable -def array(values: Iterable[V]) -> ArrayValue: +def array( + values: ArrayValue | Iterable[V] | ir.NullValue | None, + *, + type: str | dt.DataType | None = None, +) -> ArrayValue: """Create an array expression. If any values are [column expressions](../concepts/datatypes.qmd) the @@ -1078,6 +1084,9 @@ def array(values: Iterable[V]) -> ArrayValue: ---------- values An iterable of Ibis expressions or Python literals + type + An instance of `ibis.expr.datatypes.DataType` or a string indicating + the Ibis type of `value`. eg `array`. Returns ------- @@ -1106,15 +1115,38 @@ def array(values: Iterable[V]) -> ArrayValue: │ [3, 42, ... +1] │ └──────────────────────┘ - >>> ibis.array([t.a, 42 + ibis.literal(5)]) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ Array() ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ array │ - ├──────────────────────┤ - │ [1, 47] │ - │ [2, 47] │ - │ [3, 47] │ - └──────────────────────┘ + >>> ibis.array([t.a, 42 + ibis.literal(5)], type="array") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Cast(Array(), array) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├───────────────────────────────┤ + │ [1.0, 47.0] │ + │ [2.0, 47.0] │ + │ [3.0, 47.0] │ + └───────────────────────────────┘ """ - return ops.Array(tuple(values)).to_expr() + type = dt.dtype(type) if type is not None else None + if type is not None and not isinstance(type, dt.Array): + raise ValidationError(f"dtype must be an array, got {type}") + + if isinstance(values, ir.Value): + if type is not None: + return values.cast(type) + elif isinstance(values, ArrayValue): + return values + else: + raise ValidationError( + f"If no type passed, values must be an array, got {values.type()}" + ) + + if values is None: + if type is None: + raise ValidationError("If values is None/NULL, dtype must be provided") + return ir.null(type) + + values = tuple(values) + if len(values) > 0 and type is not None: + return ops.Array(values).to_expr().cast(type) + else: + return ops.Array(values, type).to_expr() diff --git a/ibis/expr/types/maps.py b/ibis/expr/types/maps.py index 79d45d16e188e..f8fd6244c48e4 100644 --- a/ibis/expr/types/maps.py +++ b/ibis/expr/types/maps.py @@ -1,18 +1,21 @@ from __future__ import annotations +from collections.abc import Mapping from typing import TYPE_CHECKING, Any from public import public +import ibis import ibis.expr.operations as ops +import ibis.expr.types as ir +from ibis.common.annotations import ValidationError from ibis.common.deferred import deferrable from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Iterable - import ibis.expr.types as ir - from ibis.expr.types.arrays import ArrayColumn + import ibis.expr.datatypes as dt @public @@ -435,8 +438,15 @@ def __getitem__(self, key: ir.Value) -> ir.Column: @public @deferrable def map( - keys: Iterable[Any] | Mapping[Any, Any] | ArrayColumn, - values: Iterable[Any] | ArrayColumn | None = None, + keys: Iterable[Any] + | Mapping[Any, Any] + | ir.ArrayValue + | MapValue + | ir.NullValue + | None, + values: Iterable[Any] | ir.ArrayValue | None = None, + *, + type: str | dt.DataType | None = None, ) -> MapValue: """Create a MapValue. @@ -449,6 +459,9 @@ def map( Keys of the map or `Mapping`. If `keys` is a `Mapping`, `values` must be `None`. values Values of the map or `None`. If `None`, the `keys` argument must be a `Mapping`. + type + An instance of `ibis.expr.datatypes.DataType` or a string indicating + the Ibis type of `value`. eg `map`. Returns ------- @@ -476,16 +489,42 @@ def map( │ ['a', 'b'] │ [1, 2] │ │ ['b'] │ [3] │ └──────────────────────┴──────────────────────┘ - >>> ibis.map(t.keys, t.values) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ Map(keys, values) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ map │ - ├──────────────────────┤ - │ {'a': 1, 'b': 2} │ - │ {'b': 3} │ - └──────────────────────┘ + >>> ibis.map(t.keys, t.values, type="map") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Map(keys, Cast(values, array)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ map │ + ├─────────────────────────────────────────┤ + │ {'a': 1.0, 'b': 2.0} │ + │ {'b': 3.0} │ + └─────────────────────────────────────────┘ """ - if values is None: + from ibis.expr import datatypes as dt + + type = dt.dtype(type) if type is not None else None + if type is not None and not isinstance(type, dt.Map): + raise ValidationError(f"dtype must be a map, got {type}") + + if isinstance(keys, Mapping) and values is None: keys, values = tuple(keys.keys()), tuple(keys.values()) + + if isinstance(keys, ir.Value) and values is None: + if type is not None: + return keys.cast(type) + elif isinstance(keys, MapValue): + return keys + else: + raise ValidationError( + f"If no type passed, value must be a map, got {keys.type()}" + ) + + if keys is None or values is None: + if type is None: + raise ValidationError("If keys is None/NULL, dtype must be provided") + return ir.null(type) + + k_type = dt.Array(value_type=type.key_type) if type is not None else None + v_type = dt.Array(value_type=type.value_type) if type is not None else None + keys = ibis.array(keys, type=k_type) + values = ibis.array(values, type=v_type) return ops.Map(keys, values).to_expr() diff --git a/ibis/expr/types/structs.py b/ibis/expr/types/structs.py index 65a16700318a8..8876b459b2da3 100644 --- a/ibis/expr/types/structs.py +++ b/ibis/expr/types/structs.py @@ -1,28 +1,33 @@ from __future__ import annotations -import collections from keyword import iskeyword from typing import TYPE_CHECKING from public import public +import ibis.expr.datatypes as dt import ibis.expr.operations as ops +import ibis.expr.types as ir +from ibis.common.annotations import ValidationError from ibis.common.deferred import deferrable from ibis.common.exceptions import IbisError -from ibis.expr.types.generic import Column, Scalar, Value, literal +from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: from collections.abc import Iterable, Mapping, Sequence - import ibis.expr.datatypes as dt - import ibis.expr.types as ir from ibis.expr.types.typing import V @public @deferrable def struct( - value: Iterable[tuple[str, V]] | Mapping[str, V], + value: Iterable[tuple[str, V]] + | Mapping[str, V] + | StructValue + | ir.NullValue + | None, + *, type: str | dt.DataType | None = None, ) -> StructValue: """Create a struct expression. @@ -37,8 +42,7 @@ def struct( `(str, Value)`. type An instance of `ibis.expr.datatypes.DataType` or a string indicating - the Ibis type of `value`. This is only used if all of the input values - are Python literals. eg `struct`. + the Ibis type of `value`. eg `struct`. Returns ------- @@ -62,26 +66,45 @@ def struct( Create a struct column from a column and a scalar literal >>> t = ibis.memtable({"a": [1, 2, 3]}) - >>> ibis.struct([("a", t.a), ("b", "foo")]) - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ StructColumn() ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ struct │ - ├─────────────────────────────┤ - │ {'a': 1, 'b': 'foo'} │ - │ {'a': 2, 'b': 'foo'} │ - │ {'a': 3, 'b': 'foo'} │ - └─────────────────────────────┘ + >>> ibis.struct([("a", t.a), ("b", "foo")], type="struct") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Cast(StructColumn(), struct) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ struct │ + ├─────────────────────────────────────────────────────┤ + │ {'a': 1.0, 'b': 'foo'} │ + │ {'a': 2.0, 'b': 'foo'} │ + │ {'a': 3.0, 'b': 'foo'} │ + └─────────────────────────────────────────────────────┘ """ import ibis.expr.operations as ops + type = dt.dtype(type) if type is not None else None + if type is not None and not isinstance(type, dt.Struct): + raise ValidationError(f"dtype must be an struct, got {type}") + + if isinstance(value, ir.Value): + if type is not None: + return value.cast(type) + elif isinstance(value, StructValue): + return value + else: + raise ValidationError( + f"If no type passed, value must be a struct, got {value.type()}" + ) + + if value is None: + if type is None: + raise ValidationError("If values is None/NULL, dtype must be provided") + return ir.null(type) + fields = dict(value) - if any(isinstance(value, Value) for value in fields.values()): - names = tuple(fields.keys()) - values = tuple(fields.values()) - return ops.StructColumn(names=names, values=values).to_expr() - else: - return literal(collections.OrderedDict(fields), type=type) + names = tuple(fields.keys()) + values = tuple(fields.values()) + result = ops.StructColumn(names=names, values=values).to_expr() + if type is not None: + result = result.cast(type) + return result @public