From 55497bb580fb27f3d30b320a35cd2580ac4d0b83 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Thu, 21 Mar 2024 20:19:30 -0800 Subject: [PATCH] feat: support type kwarg in array() and map() fixes https://github.com/ibis-project/ibis/issues/8289 This does a lot of changes. It was hard for me to separate them out as I implemented them. But now that it's all hashed out, I can try to split this up into separate commits if you want. But that might be sorta hard in some cases. The big structural change is that now the core Operations for Array and Structs have a different internal representation, so they can distringuish between - the entire value is NULL - the contained values are NULL Before, ops.Array held onto a `VarTuple[Value]`. So the contained Values could be NULL, but there was no way to say the entire thing was null. Now, ops.Array stores a `None | VarTuple[Value]`. The same thing for ops.StructValue. ops.Map didn't suffer from this, because it stores a `ops.Array`s internally, so since `ops.Array` can distinguish between entirely-NULL and contains-NULL, so can ops.Map A fallout of this is that ops.Array needs a way to explicitly store its dtype. Before, it derived its dtype based on the dtype of its args. But now that `None` is a valid value, it is now possible for there to be no values to inspect! So the Op actually stores its dtype explicitly. If you pass in values, then supplying the dtype on construction is optional, we go back to the old behavior of deriving it from the inputs. This requires the backend compilers to now deal with that case. Several of the backends were always broken here, they just weren't getting caught. I marked them as broken, we can fix them in a followup. You can test this locally with eg `pytest -m -k factory ibis/backends/tests/test_array.py ibis/backends/tests/test_map.py ibis/backends/tests/test_struct.py` Also, fix a typing bug: map() can accept ArrayValues, not just ArrayColumns Also, fix executing NULL arrays on pandas. Also, fixing converting dtypes on clickhouse, Structs should be converted to nonnullable dtypes. Also, fix casting structs on pandas. See https://github.com/ibis-project/ibis/issues/8687 Also, support passing in None to all these constructors. Also, error when the value type can't be inferred from empty python literals (eg what is the value type for the elements of []?) Also, make the type argument for struct() always have an effect, not just when passing in python literals. So basically it can act like a cast. Also, make these constructors idempotent. --- ibis/backends/clickhouse/compiler.py | 8 ++-- ibis/backends/dask/executor.py | 21 ++++++++- ibis/backends/duckdb/compiler.py | 29 +++++++++---- ibis/backends/pandas/executor.py | 17 +++++++- ibis/backends/pandas/kernels.py | 10 ++++- ibis/backends/polars/compiler.py | 9 +++- ibis/backends/postgres/compiler.py | 11 +++-- ibis/backends/snowflake/compiler.py | 4 +- ibis/backends/sql/compiler.py | 23 +++++++--- ibis/backends/sql/datatypes.py | 6 ++- ibis/backends/tests/test_array.py | 39 +++++++++++++++++ ibis/backends/tests/test_map.py | 45 ++++++++++++++++++- ibis/backends/tests/test_struct.py | 65 ++++++++++++++++++++++++---- ibis/backends/trino/compiler.py | 6 ++- ibis/expr/operations/arrays.py | 26 ++++++++--- ibis/expr/operations/structs.py | 44 +++++++++++++------ ibis/expr/types/arrays.py | 36 +++++++++------ ibis/expr/types/maps.py | 43 ++++++++++++------ ibis/expr/types/structs.py | 48 +++++++++++--------- 19 files changed, 382 insertions(+), 108 deletions(-) diff --git a/ibis/backends/clickhouse/compiler.py b/ibis/backends/clickhouse/compiler.py index ba24173c6ae19..c695a536a8727 100644 --- a/ibis/backends/clickhouse/compiler.py +++ b/ibis/backends/clickhouse/compiler.py @@ -455,10 +455,10 @@ def visit_GroupConcat(self, op, *, arg, sep, where): def visit_Cot(self, op, *, arg): return 1.0 / self.f.tan(arg) - def visit_StructColumn(self, op, *, values, names): - # ClickHouse struct types cannot be nullable - # (non-nested fields can be nullable) - return self.cast(self.f.tuple(*values), op.dtype.copy(nullable=False)) + def visit_StructColumn(self, op, *, values, names, dtype): + if values is None: + return self.cast(NULL, dtype) + return self.cast(self.f.tuple(*values), dtype) def visit_Clip(self, op, *, arg, lower, upper): if upper is not None: diff --git a/ibis/backends/dask/executor.py b/ibis/backends/dask/executor.py index d2d384ff04322..13e2f353697bc 100644 --- a/ibis/backends/dask/executor.py +++ b/ibis/backends/dask/executor.py @@ -118,9 +118,26 @@ def mapper(df, cases): return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype) @classmethod - def visit(cls, op: ops.Array, exprs): + def visit(cls, op: ops.StructColumn, names, values, dtype): + if values is None: + return None + + def process_row(row): + return { + name: DaskConverter.convert_scalar(value, dty) + for name, value, dty in zip(names, row, dtype.fields.values()) + } + + pdt = PandasType.from_ibis(op.dtype) + return cls.rowwise(process_row, values, name=op.name, dtype=pdt) + + @classmethod + def visit(cls, op: ops.Array, exprs, dtype): + if exprs is None: + return None + pdt = PandasType.from_ibis(op.dtype) return cls.rowwise( - lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object + lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=pdt ) @classmethod diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index d2f8280adb597..ffc8897f19a6c 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -89,15 +89,20 @@ def _aggregate(self, funcname: str, *args, where): return sge.Filter(this=expr, expression=sge.Where(this=where)) return expr - def visit_StructColumn(self, op, *, names, values): - return sge.Struct.from_arg_list( - [ - sge.PropertyEQ( - this=sg.to_identifier(name, quoted=self.quoted), expression=value - ) - for name, value in zip(names, values) - ] - ) + def visit_StructColumn(self, op, *, names, values, dtype): + if values is None: + val = NULL + else: + val = sge.Struct.from_arg_list( + [ + sge.PropertyEQ( + this=sg.to_identifier(name, quoted=self.quoted), + expression=value, + ) + for name, value in zip(names, values) + ] + ) + return self.cast(val, dtype) def visit_ArrayDistinct(self, op, *, arg): return self.if_( @@ -199,6 +204,12 @@ def visit_ArrayZip(self, op, *, arg): any_arg_null = sg.or_(*(arr.is_(NULL) for arr in arg)) return self.if_(any_arg_null, NULL, zipped_arrays) + def visit_Map(self, op, *, keys, values): + # workaround for https://github.com/ibis-project/ibis/issues/8632 + regular = self.f.map(keys, values) + either_null = sg.or_(keys.is_(NULL), values.is_(NULL)) + return self.if_(either_null, NULL, regular) + def visit_MapGet(self, op, *, arg, key, default): return self.f.ifnull( self.f.list_extract(self.f.element_at(arg, key), 1), default diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index d1a3f8d73d0de..93834d42fab26 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -212,7 +212,22 @@ def visit(cls, op: ops.FindInSet, needle, values): return pd.Series(result, name=op.name) @classmethod - def visit(cls, op: ops.Array, exprs): + def visit(cls, op: ops.StructColumn, names, values, dtype): + if values is None: + return None + + def process_row(row): + return { + name: PandasConverter.convert_scalar(value, dty) + for name, value, dty in zip(names, row, dtype.fields.values()) + } + + return cls.rowwise(process_row, values) + + @classmethod + def visit(cls, op: ops.Array, exprs, dtype): + if exprs is None: + return None return cls.rowwise(lambda row: np.array(row, dtype=object), exprs) @classmethod diff --git a/ibis/backends/pandas/kernels.py b/ibis/backends/pandas/kernels.py index b9bb577d94392..3b654cc3048e9 100644 --- a/ibis/backends/pandas/kernels.py +++ b/ibis/backends/pandas/kernels.py @@ -254,6 +254,14 @@ def round_serieswise(arg, digits): return np.round(arg, digits).astype("float64") +def map_(row: pd.Series) -> dict: + k = row["keys"] + v = row["values"] + if k is None or v is None: + return None + return dict(zip(k, v)) + + reductions = { ops.Min: lambda x: x.min(), ops.Max: lambda x: x.max(), @@ -362,7 +370,7 @@ def round_serieswise(arg, digits): ops.EndsWith: lambda row: row["arg"].endswith(row["end"]), ops.IntegerRange: integer_range_rowwise, ops.JSONGetItem: lambda row: safe_json_getitem(row["arg"], row["index"]), - ops.Map: lambda row: dict(zip(row["keys"], row["values"])), + ops.Map: map_, ops.MapGet: lambda row: safe_get(row["arg"], row["key"], row["default"]), ops.MapContains: lambda row: safe_contains(row["arg"], row["key"]), ops.MapMerge: lambda row: safe_merge(row["left"], row["right"]), diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 0e84b255fdf66..24a9ea44d5ebd 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -728,6 +728,8 @@ def struct_field(op, **kw): @translate.register(ops.StructColumn) def struct_column(op, **kw): + if op.values is None: + return pl.lit(None) fields = [translate(v, **kw).alias(k) for k, v in zip(op.names, op.values)] return pl.struct(fields) @@ -969,8 +971,13 @@ def array_concat(op, **kw): @translate.register(ops.Array) def array_column(op, **kw): + if op.exprs is None: + return pl.lit(None, dtype=PolarsType.from_ibis(op.dtype)) cols = [translate(col, **kw) for col in op.exprs] - return pl.concat_list(cols) + if len(cols) > 0: + return pl.concat_list(cols) + else: + return pl.lit([], dtype=PolarsType.from_ibis(op.dtype)) @translate.register(ops.ArrayCollect) diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index 7d794984375b3..1e46810a8fce6 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -319,8 +319,10 @@ def visit_StructField(self, op, *, arg, field): op.dtype, ) - def visit_StructColumn(self, op, *, names, values): - return self.f.row(*map(self.cast, values, op.dtype.types)) + def visit_StructColumn(self, op, *, names, values, dtype): + if values is None: + return self.cast(self.f.jsonb_build_object(), op.dtype) + return self.f.row(*map(self.cast, values, dtype.types)) def visit_ToJSONArray(self, op, *, arg): return self.if_( @@ -330,7 +332,10 @@ def visit_ToJSONArray(self, op, *, arg): ) def visit_Map(self, op, *, keys, values): - return self.f.map(self.f.array(*keys), self.f.array(*values)) + # map(["a", "b"], NULL) results in {"a": NULL, "b": NULL} in regular postgres, + # so we need to modify it to return NULL instead + regular = self.f.map(keys, values) + return self.if_(values.is_(NULL), NULL, regular) def visit_MapLength(self, op, *, arg): return self.f.cardinality(self.f.akeys(arg)) diff --git a/ibis/backends/snowflake/compiler.py b/ibis/backends/snowflake/compiler.py index 8b980aa9aa8f3..0295b5d74a5ad 100644 --- a/ibis/backends/snowflake/compiler.py +++ b/ibis/backends/snowflake/compiler.py @@ -267,7 +267,9 @@ def visit_IntegerRange(self, op, *, start, stop, step): step.neq(0), self.f.array_generate_range(start, stop, step), self.f.array() ) - def visit_StructColumn(self, op, *, names, values): + def visit_StructColumn(self, op, *, names, values, dtype): + if values is None: + return self.cast(NULL, dtype) return self.f.object_construct_keep_null( *itertools.chain.from_iterable(zip(names, values)) ) diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index a1642e3e50fa7..81aadf3930e38 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -951,13 +951,24 @@ def visit_ExistsSubquery(self, op, *, rel): def visit_InSubquery(self, op, *, rel, needle): return needle.isin(query=rel.this) - def visit_Array(self, op, *, exprs): - return self.f.array(*exprs) + def visit_Array(self, op, *, exprs, dtype): + if exprs is None: + vals = NULL + else: + vals = self.f.array(*exprs) + return self.cast(vals, dtype) - def visit_StructColumn(self, op, *, names, values): - return sge.Struct.from_arg_list( - [value.as_(name, quoted=self.quoted) for name, value in zip(names, values)] - ) + def visit_StructColumn(self, op, *, names, values, dtype): + if values is None: + vals = NULL + else: + vals = sge.Struct.from_arg_list( + [ + value.as_(name, quoted=self.quoted) + for name, value in zip(names, values) + ] + ) + return self.cast(vals, dtype) def visit_StructField(self, op, *, arg, field): return sge.Dot(this=arg, expression=sg.to_identifier(field, quoted=self.quoted)) diff --git a/ibis/backends/sql/datatypes.py b/ibis/backends/sql/datatypes.py index d391d9ab92b60..e81ae8796b5a8 100644 --- a/ibis/backends/sql/datatypes.py +++ b/ibis/backends/sql/datatypes.py @@ -960,8 +960,10 @@ class ClickHouseType(SqlglotType): def from_ibis(cls, dtype: dt.DataType) -> sge.DataType: """Convert a sqlglot type to an ibis type.""" typ = super().from_ibis(dtype) - if dtype.nullable and not (dtype.is_map() or dtype.is_array()): - # map cannot be nullable in clickhouse + # nested types cannot be nullable in clickhouse + if dtype.nullable and not ( + dtype.is_map() or dtype.is_array() or dtype.is_struct() + ): return sge.DataType(this=typecode.NULLABLE, expressions=[typ]) else: return typ diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index c9768eff3d728..159094696b33f 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -30,6 +30,7 @@ PySparkAnalysisException, TrinoUserError, ) +from ibis.common.annotations import ValidationError pytestmark = [ pytest.mark.never( @@ -70,6 +71,43 @@ # list. +def test_array_factory(con): + a = ibis.array([1, 2, 3]) + assert con.execute(a) == [1, 2, 3] + + a2 = ibis.array(a) + assert con.execute(a2) == [1, 2, 3] + + typed = ibis.array([1, 2, 3], type="array") + assert con.execute(typed) == ["1", "2", "3"] + + typed2 = ibis.array(a, type="array") + assert con.execute(typed2) == ["1", "2", "3"] + + +@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError) +def test_array_factory_empty(con): + with pytest.raises(ValidationError): + ibis.array([]) + + empty_typed = ibis.array([], type="array") + assert empty_typed.type() == dt.Array(value_type=dt.string) + assert con.execute(empty_typed) == [] + + +@pytest.mark.notyet( + "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL" +) +def test_array_factory_null(con): + with pytest.raises(ValidationError): + ibis.array(None) + with pytest.raises(ValidationError): + ibis.array(None, type="int64") + none_typed = ibis.array(None, type="array") + assert none_typed.type() == dt.Array(value_type=dt.string) + assert con.execute(none_typed) is None + + def test_array_column(backend, alltypes, df): expr = ibis.array( [alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)] @@ -107,6 +145,7 @@ def test_array_scalar(con): @pytest.mark.notimpl(["flink", "polars"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl("postgres", raises=PsycoPg2SyntaxError) def test_array_repeat(con): expr = ibis.array([1.0, 2.0]) * 2 diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index 28ac4010ab82b..cb540f81bee3d 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -10,7 +10,11 @@ import ibis import ibis.common.exceptions as exc import ibis.expr.datatypes as dt -from ibis.backends.tests.errors import Py4JJavaError +from ibis.backends.tests.errors import ( + ClickHouseDatabaseError, + Py4JJavaError, +) +from ibis.common.annotations import ValidationError pytestmark = [ pytest.mark.never( @@ -39,6 +43,43 @@ ) +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres +def test_map_factory(con): + m = ibis.map({"a": 1, "b": 2}) + assert con.execute(m) == {"a": 1, "b": 2} + + m2 = ibis.map(m) + assert con.execute(m2) == {"a": 1, "b": 2} + + typed = ibis.map({"a": 1, "b": 2}, type="map") + assert con.execute(typed) == {"a": "1", "b": "2"} + + typed2 = ibis.map(m, type="map") + assert con.execute(typed2) == {"a": "1", "b": "2"} + + +@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError) +@mark_notimpl_risingwave_hstore +def test_map_factory_empty(con): + with pytest.raises(ValidationError): + ibis.map({}) + empty_typed = ibis.map({}, type="map") + assert empty_typed.type() == dt.Map(key_type=dt.string, value_type=dt.string) + assert con.execute(empty_typed) == {} + + +@pytest.mark.notyet( + "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL" +) +def test_map_factory_null(con): + with pytest.raises(ValidationError): + ibis.map(None) + null_typed = ibis.map(None, type="map") + assert null_typed.type() == dt.Map(key_type=dt.string, value_type=dt.string) + assert con.execute(null_typed) is None + + @pytest.mark.notimpl(["pandas", "dask"]) def test_map_table(backend): table = backend.map @@ -474,6 +515,6 @@ def test_map_keys_unnest(backend): @mark_notimpl_risingwave_hstore def test_map_contains_null(con): - expr = ibis.map(["a"], ibis.literal([None], type="array")) + expr = ibis.map(["a"], ibis.array([None], type="array")) assert con.execute(expr.contains("a")) assert not con.execute(expr.contains("b")) diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index 682da63faf273..fff7e01b05eac 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -11,10 +11,12 @@ import ibis import ibis.expr.datatypes as dt from ibis.backends.tests.errors import ( + ClickHouseDatabaseError, PsycoPg2InternalError, PsycoPg2SyntaxError, Py4JJavaError, ) +from ibis.common.annotations import ValidationError pytestmark = [ pytest.mark.never(["mysql", "sqlite", "mssql"], reason="No struct support"), @@ -22,6 +24,46 @@ pytest.mark.notimpl(["datafusion", "druid", "oracle", "exasol"]), ] +mark_notimpl_postgres_literals = pytest.mark.notimpl( + "postgres", reason="struct literals not implemented", raises=PsycoPg2SyntaxError +) + + +@pytest.mark.broken("postgres", reason="JSON handling is buggy") +def test_struct_factory(con): + s = ibis.struct({"a": 1, "b": 2}) + assert con.execute(s) == {"a": 1, "b": 2} + + s2 = ibis.struct(s) + assert con.execute(s2) == {"a": 1, "b": 2} + + typed = ibis.struct({"a": 1, "b": 2}, type="struct") + assert con.execute(typed) == {"a": "1", "b": "2"} + + typed2 = ibis.struct(s, type="struct") + assert con.execute(typed2) == {"a": "1", "b": "2"} + + +def test_struct_factory_empty(): + with pytest.raises(ValidationError): + ibis.struct({}) + with pytest.raises(ValidationError): + ibis.struct({}, type="struct<>") + with pytest.raises(ValidationError): + ibis.struct({}, type="struct") + + +@mark_notimpl_postgres_literals +@pytest.mark.notyet( + "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL" +) +def test_struct_factory_null(con): + with pytest.raises(TypeError): + ibis.struct(None) + none_typed = ibis.struct(None, type="struct") + assert none_typed.type() == dt.Struct(fields={"a": dt.float64, "b": dt.float64}) + assert con.execute(none_typed) is None + @pytest.mark.notimpl(["dask"]) @pytest.mark.parametrize( @@ -70,8 +112,18 @@ def test_all_fields(struct, struct_df): _NULL_STRUCT_LITERAL = ibis.NA.cast("struct") -@pytest.mark.notimpl(["postgres", "risingwave"]) -@pytest.mark.parametrize("field", ["a", "b", "c"]) +@pytest.mark.notimpl(["risingwave"]) +@pytest.mark.parametrize( + "field", + [ + "a", + pytest.param( + "b", + marks=pytest.mark.broken("postgres", reason="""result is `"2"`, not `2`"""), + ), + "c", + ], +) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from literals" ) @@ -84,7 +136,7 @@ def test_literal(backend, con, field): backend.assert_series_equal(result, expected.astype(dtype)) -@pytest.mark.notimpl(["postgres"]) +@mark_notimpl_postgres_literals @pytest.mark.parametrize("field", ["a", "b", "c"]) @pytest.mark.notyet( ["clickhouse"], reason="clickhouse doesn't support nullable nested types" @@ -100,7 +152,7 @@ def test_null_literal(backend, con, field): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from literals" ) @@ -115,7 +167,7 @@ def test_struct_column(alltypes, df): tm.assert_frame_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave", "polars"]) +@pytest.mark.notimpl(["postgres", "risingwave", "polars"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from collect" ) @@ -140,9 +192,6 @@ def test_collect_into_struct(alltypes): assert len(val.loc[result.group == "1"].iat[0]["key"]) == 730 -@pytest.mark.notimpl( - ["postgres"], reason="struct literals not implemented", raises=PsycoPg2SyntaxError -) @pytest.mark.notimpl( ["risingwave"], reason="struct literals not implemented", diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py index a32cc72f15272..fca4d2cbd0c37 100644 --- a/ibis/backends/trino/compiler.py +++ b/ibis/backends/trino/compiler.py @@ -255,8 +255,10 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): raise com.UnsupportedOperationError(f"{unit!r} unit is not supported") return self.cast(res, op.dtype) - def visit_StructColumn(self, op, *, names, values): - return self.cast(sge.Struct(expressions=list(values)), op.dtype) + def visit_StructColumn(self, op, *, names, values, dtype): + if values is None: + return self.cast(NULL, dtype) + return self.cast(sge.Struct(expressions=list(values)), dtype) def visit_NonNullLiteral(self, op, *, value, dtype): if dtype.is_floating(): diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index 68ee711a2da6b..62f91c60f886c 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -7,23 +7,37 @@ import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz -from ibis.common.annotations import attribute +from ibis.common.annotations import ValidationError, attribute from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.operations.core import Unary, Value @public class Array(Value): - exprs: VarTuple[Value] + exprs: Optional[VarTuple[Value]] + dtype: Optional[dt.Array] = None + + def __init__(self, exprs, dtype: dt.Array | None = None): + if exprs is None: + if dtype is None: + raise ValidationError("If values is None, dtype must be provided") + if not isinstance(dtype, dt.Array): + raise ValidationError(f"dtype must be an array, got {dtype}") + elif len(exprs) == 0: + if dtype is None: + raise ValidationError("If values is empty, dtype must be provided") + if not isinstance(dtype, dt.Array): + raise ValidationError(f"dtype must be an array, got {dtype}") + elif dtype is None: + dtype = dt.Array(rlz.highest_precedence_dtype(exprs)) + super().__init__(exprs=exprs, dtype=dtype) @attribute def shape(self): + if self.exprs is None or len(self.exprs) == 0: + return ds.scalar return rlz.highest_precedence_shape(self.exprs) - @attribute - def dtype(self): - return dt.Array(rlz.highest_precedence_dtype(self.exprs)) - @public class ArrayLength(Unary): diff --git a/ibis/expr/operations/structs.py b/ibis/expr/operations/structs.py index 20c0c3dc0a4ef..06492e295b2c0 100644 --- a/ibis/expr/operations/structs.py +++ b/ibis/expr/operations/structs.py @@ -1,7 +1,10 @@ from __future__ import annotations +from typing import Optional + from public import public +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import ValidationError, attribute @@ -30,19 +33,34 @@ def name(self) -> str: @public class StructColumn(Value): names: VarTuple[str] - values: VarTuple[Value] - - shape = rlz.shape_like("values") + values: Optional[VarTuple[Value]] + dtype: Optional[dt.Struct] = None - def __init__(self, names, values): - if len(names) != len(values): - raise ValidationError( - f"Length of names ({len(names)}) does not match length of " - f"values ({len(values)})" - ) - super().__init__(names=names, values=values) + def __init__( + self, + names: VarTuple[str], + values: None | VarTuple[Value], + dtype: dt.Struct | None = None, + ): + if len(names) == 0: + raise ValidationError("StructColumn must have at least one field") + if values is None: + if dtype is None: + raise ValidationError("If values is None, dtype must be provided") + if not isinstance(dtype, dt.Struct): + raise ValidationError(f"dtype must be a struct, got {dtype}") + else: + if len(names) != len(values): + raise ValidationError( + f"Length of names ({len(names)}) does not match length of " + f"values ({len(values)})" + ) + if dtype is None: + dtype = dt.Struct.from_tuples(zip(names, (v.dtype for v in values))) + super().__init__(names=names, values=values, dtype=dtype) @attribute - def dtype(self) -> dt.DataType: - dtypes = (value.dtype for value in self.values) - return dt.Struct.from_tuples(zip(self.names, dtypes)) + def shape(self) -> ds.DataShape: + if self.values is None: + return ds.scalar + return rlz.highest_precedence_shape(self.values) diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index 68fb594a7e166..5f82e57ee09b6 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -6,13 +6,14 @@ from public import public import ibis.expr.operations as ops +import ibis.expr.types as ir from ibis.common.deferred import Deferred, deferrable from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: from collections.abc import Iterable - import ibis.expr.types as ir + import ibis.expr.datatypes as dt from ibis.expr.types.typing import V import ibis.common.exceptions as com @@ -1078,7 +1079,11 @@ def __getitem__(self, index: int | ir.IntegerValue | slice) -> ir.Column: @public @deferrable -def array(values: Iterable[V]) -> ArrayValue: +def array( + values: ArrayValue | Iterable[V] | None, + *, + type: str | dt.DataType | None = None, +) -> ArrayValue: """Create an array expression. If any values are [column expressions](../concepts/datatypes.qmd) the @@ -1089,6 +1094,9 @@ def array(values: Iterable[V]) -> ArrayValue: ---------- values An iterable of Ibis expressions or Python literals + type + An instance of `ibis.expr.datatypes.DataType` or a string indicating + the Ibis type of `value`. eg `array`. Returns ------- @@ -1117,15 +1125,17 @@ def array(values: Iterable[V]) -> ArrayValue: │ [3, 42, ... +1] │ └──────────────────────┘ - >>> ibis.array([t.a, 42 + ibis.literal(5)]) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ Array() ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ array │ - ├──────────────────────┤ - │ [1, 47] │ - │ [2, 47] │ - │ [3, 47] │ - └──────────────────────┘ + >>> ibis.array([t.a, 42 + ibis.literal(5)], type="array") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Cast(Array(), array) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├───────────────────────────────┤ + │ [1.0, 47.0] │ + │ [2.0, 47.0] │ + │ [3.0, 47.0] │ + └───────────────────────────────┘ """ - return ops.Array(tuple(values)).to_expr() + if isinstance(values, ir.ArrayValue): + return values.cast(type) if type is not None else values + return ops.Array(values, type).to_expr() diff --git a/ibis/expr/types/maps.py b/ibis/expr/types/maps.py index b61f7caceedd8..dce03771ecd10 100644 --- a/ibis/expr/types/maps.py +++ b/ibis/expr/types/maps.py @@ -4,15 +4,16 @@ from public import public +import ibis import ibis.expr.operations as ops +import ibis.expr.types as ir from ibis.common.deferred import deferrable from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: from collections.abc import Iterable, Mapping - import ibis.expr.types as ir - from ibis.expr.types.arrays import ArrayColumn + import ibis.expr.datatypes as dt @public @@ -435,8 +436,10 @@ def __getitem__(self, key: ir.Value) -> ir.Column: @public @deferrable def map( - keys: Iterable[Any] | Mapping[Any, Any] | ArrayColumn, - values: Iterable[Any] | ArrayColumn | None = None, + keys: Iterable[Any] | Mapping[Any, Any] | ir.ArrayValue | MapValue | None, + values: Iterable[Any] | ir.ArrayValue | None = None, + *, + type: str | dt.DataType | None = None, ) -> MapValue: """Create a MapValue. @@ -449,6 +452,9 @@ def map( Keys of the map or `Mapping`. If `keys` is a `Mapping`, `values` must be `None`. values Values of the map or `None`. If `None`, the `keys` argument must be a `Mapping`. + type + An instance of `ibis.expr.datatypes.DataType` or a string indicating + the Ibis type of `value`. eg `map`. Returns ------- @@ -476,16 +482,25 @@ def map( │ ['a', 'b'] │ [1, 2] │ │ ['b'] │ [3] │ └──────────────────────┴──────────────────────┘ - >>> ibis.map(t.keys, t.values) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ Map(keys, values) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ map │ - ├──────────────────────┤ - │ {'a': 1, 'b': 2} │ - │ {'b': 3} │ - └──────────────────────┘ + >>> ibis.map(t.keys, t.values, type="map") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Map(keys, Cast(values, array)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ map │ + ├─────────────────────────────────────────┤ + │ {'a': 1.0, 'b': 2.0} │ + │ {'b': 3.0} │ + └─────────────────────────────────────────┘ """ - if values is None: + from ibis.expr import datatypes as dt + + if isinstance(keys, MapValue): + return keys.cast(type) if type is not None else keys + if values is None and keys is not None: keys, values = tuple(keys.keys()), tuple(keys.values()) + type = dt.dtype(type) if type is not None else None + key_type = dt.Array(value_type=type.key_type) if type is not None else None + value_type = dt.Array(value_type=type.value_type) if type is not None else None + keys = ibis.array(keys, type=key_type) + values = ibis.array(values, type=value_type) return ops.Map(keys, values).to_expr() diff --git a/ibis/expr/types/structs.py b/ibis/expr/types/structs.py index 65a16700318a8..5e3110659bcfa 100644 --- a/ibis/expr/types/structs.py +++ b/ibis/expr/types/structs.py @@ -1,20 +1,19 @@ from __future__ import annotations -import collections from keyword import iskeyword from typing import TYPE_CHECKING from public import public +import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.common.deferred import deferrable from ibis.common.exceptions import IbisError -from ibis.expr.types.generic import Column, Scalar, Value, literal +from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: from collections.abc import Iterable, Mapping, Sequence - import ibis.expr.datatypes as dt import ibis.expr.types as ir from ibis.expr.types.typing import V @@ -22,7 +21,8 @@ @public @deferrable def struct( - value: Iterable[tuple[str, V]] | Mapping[str, V], + value: Iterable[tuple[str, V]] | Mapping[str, V] | StructValue | None, + *, type: str | dt.DataType | None = None, ) -> StructValue: """Create a struct expression. @@ -37,8 +37,7 @@ def struct( `(str, Value)`. type An instance of `ibis.expr.datatypes.DataType` or a string indicating - the Ibis type of `value`. This is only used if all of the input values - are Python literals. eg `struct`. + the Ibis type of `value`. eg `struct`. Returns ------- @@ -62,26 +61,35 @@ def struct( Create a struct column from a column and a scalar literal >>> t = ibis.memtable({"a": [1, 2, 3]}) - >>> ibis.struct([("a", t.a), ("b", "foo")]) - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ StructColumn() ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ struct │ - ├─────────────────────────────┤ - │ {'a': 1, 'b': 'foo'} │ - │ {'a': 2, 'b': 'foo'} │ - │ {'a': 3, 'b': 'foo'} │ - └─────────────────────────────┘ + >>> ibis.struct([("a", t.a), ("b", "foo")], type="struct") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Cast(StructColumn(), struct) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ struct │ + ├─────────────────────────────────────────────────────┤ + │ {'a': 1.0, 'b': 'foo'} │ + │ {'a': 2.0, 'b': 'foo'} │ + │ {'a': 3.0, 'b': 'foo'} │ + └─────────────────────────────────────────────────────┘ """ import ibis.expr.operations as ops - fields = dict(value) - if any(isinstance(value, Value) for value in fields.values()): + if isinstance(value, StructValue): + return value.cast(type) if type is not None else value + if value is not None: + fields = dict(value) names = tuple(fields.keys()) values = tuple(fields.values()) - return ops.StructColumn(names=names, values=values).to_expr() else: - return literal(collections.OrderedDict(fields), type=type) + if type is None: + raise TypeError("Must specify type if value is None") + type = dt.dtype(type) + names = type.names + values = None + result = ops.StructColumn(names=names, values=values, dtype=type).to_expr() + if type is not None: + return result.cast(type) + return result @public