Skip to content

Commit

Permalink
feat: support type kwarg in array() and map()
Browse files Browse the repository at this point in the history
fixes #8289

This does a lot of changes. It was hard for me to separate them out as I implemented them. But now that it's all hashed out, I can try to split this up into separate commits if you want. But that might be sorta hard in
some cases.

The big structural change is that now the core Operations for Array and Structs have a different internal representation, so they can distringuish between
- the entire value is NULL
- the contained values are NULL

Before, ops.Array held onto a `VarTuple[Value]`. So the contained Values could be NULL, but there was no way to say the entire thing was null. Now, ops.Array stores a `None | VarTuple[Value]`. The same thing for ops.StructValue. ops.Map didn't suffer from this, because it stores a `ops.Array`s internally, so since `ops.Array` can distinguish between entirely-NULL and contains-NULL, so can ops.Map

A fallout of this is that ops.Array needs a way to explicitly store its dtype. Before, it derived its dtype based on the dtype of its args. But
now that `None` is a valid value, it is now possible for there to be no values to inspect! So the Op actually stores its dtype explicitly. If you pass in values, then supplying the dtype on construction is optional, we go back to the old behavior of deriving it from the inputs.

This requires the backend compilers to now deal with that case.

Several of the backends were always broken here, they just weren't getting caught. I marked them as broken,
we can fix them in a followup.

You can test this locally with eg
`pytest -m <backend> -k factory ibis/backends/tests/test_array.py  ibis/backends/tests/test_map.py ibis/backends/tests/test_struct.py`

Also, fix a typing bug: map() can accept ArrayValues, not just ArrayColumns

Also, fix executing NULL arrays on pandas.

Also, fixing converting dtypes on clickhouse, Structs should be converted to nonnullable dtypes.

Also, fix casting structs on pandas.
See #8687

Also, support passing in None to all these constructors.

Also, error when the value type can't be inferred from empty python literals
(eg what is the value type for the elements of []?)

Also, make the type argument for struct() always have an effect, not just when passing in python literals.
So basically it can act like a cast.

Also, make these constructors idempotent.
  • Loading branch information
NickCrews committed May 1, 2024
1 parent 62ec580 commit a00c26d
Show file tree
Hide file tree
Showing 19 changed files with 324 additions and 106 deletions.
8 changes: 4 additions & 4 deletions ibis/backends/clickhouse/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,10 +442,10 @@ def visit_GroupConcat(self, op, *, arg, sep, where):
def visit_Cot(self, op, *, arg):
return 1.0 / self.f.tan(arg)

def visit_StructColumn(self, op, *, values, names):
# ClickHouse struct types cannot be nullable
# (non-nested fields can be nullable)
return self.cast(self.f.tuple(*values), op.dtype.copy(nullable=False))
def visit_StructColumn(self, op, *, values, names, dtype):
if values is None:
return self.cast(NULL, dtype)
return self.cast(self.f.tuple(*values), dtype)

def visit_Clip(self, op, *, arg, lower, upper):
if upper is not None:
Expand Down
21 changes: 19 additions & 2 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,26 @@ def mapper(df, cases):
return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.StructColumn, names, values, dtype):
if values is None:
return None

def process_row(row):
return {
name: DaskConverter.convert_scalar(value, dty)
for name, value, dty in zip(names, row, dtype.fields.values())
}

pdt = PandasType.from_ibis(op.dtype)
return cls.rowwise(process_row, values, name=op.name, dtype=pdt)

@classmethod
def visit(cls, op: ops.Array, exprs, dtype):
if exprs is None:
return None
pdt = PandasType.from_ibis(op.dtype)
return cls.rowwise(
lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=pdt
)

@classmethod
Expand Down
23 changes: 14 additions & 9 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,20 @@ def _aggregate(self, funcname: str, *args, where):
return sge.Filter(this=expr, expression=sge.Where(this=where))
return expr

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
[
sge.PropertyEQ(
this=sg.to_identifier(name, quoted=self.quoted), expression=value
)
for name, value in zip(names, values)
]
)
def visit_StructColumn(self, op, *, names, values, dtype):
if values is None:
val = NULL
else:
val = sge.Struct.from_arg_list(
[
sge.PropertyEQ(
this=sg.to_identifier(name, quoted=self.quoted),
expression=value,
)
for name, value in zip(names, values)
]
)
return self.cast(val, dtype)

def visit_ArrayDistinct(self, op, *, arg):
return self.if_(
Expand Down
17 changes: 16 additions & 1 deletion ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,22 @@ def visit(cls, op: ops.FindInSet, needle, values):
return pd.Series(result, name=op.name)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.StructColumn, names, values, dtype):
if values is None:
return None

def process_row(row):
return {
name: PandasConverter.convert_scalar(value, dty)
for name, value, dty in zip(names, row, dtype.fields.values())
}

return cls.rowwise(process_row, values)

@classmethod
def visit(cls, op: ops.Array, exprs, dtype):
if exprs is None:
return None
return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)

@classmethod
Expand Down
10 changes: 9 additions & 1 deletion ibis/backends/pandas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,14 @@ def arbitrary(arg):
return arg.iat[0] if len(arg) else None


def map_(row: pd.Series) -> dict:
k = row["keys"]
v = row["values"]
if k is None or v is None:
return None
return dict(zip(k, v))


reductions = {
ops.Min: lambda x: x.min(),
ops.Max: lambda x: x.max(),
Expand Down Expand Up @@ -369,7 +377,7 @@ def arbitrary(arg):
ops.EndsWith: lambda row: row["arg"].endswith(row["end"]),
ops.IntegerRange: integer_range_rowwise,
ops.JSONGetItem: lambda row: safe_json_getitem(row["arg"], row["index"]),
ops.Map: lambda row: dict(zip(row["keys"], row["values"])),
ops.Map: map_,
ops.MapGet: lambda row: safe_get(row["arg"], row["key"], row["default"]),
ops.MapContains: lambda row: safe_contains(row["arg"], row["key"]),
ops.MapMerge: lambda row: safe_merge(row["left"], row["right"]),
Expand Down
9 changes: 8 additions & 1 deletion ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,8 @@ def struct_field(op, **kw):

@translate.register(ops.StructColumn)
def struct_column(op, **kw):
if op.values is None:
return pl.lit(None)
fields = [translate(v, **kw).alias(k) for k, v in zip(op.names, op.values)]
return pl.struct(fields)

Expand Down Expand Up @@ -981,8 +983,13 @@ def array_concat(op, **kw):

@translate.register(ops.Array)
def array_column(op, **kw):
if op.exprs is None:
return pl.lit(None, dtype=PolarsType.from_ibis(op.dtype))
cols = [translate(col, **kw) for col in op.exprs]
return pl.concat_list(cols)
if len(cols) > 0:
return pl.concat_list(cols)
else:
return pl.lit([], dtype=PolarsType.from_ibis(op.dtype))


@translate.register(ops.ArrayCollect)
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,10 @@ def visit_UnwrapJSONBoolean(self, op, *, arg):
NULL,
)

def visit_StructColumn(self, op, *, names, values):
return self.f.row(*map(self.cast, values, op.dtype.types))
def visit_StructColumn(self, op, *, names, values, dtype):
if values is None:
return self.cast(self.f.jsonb_build_object(), op.dtype)
return self.f.row(*map(self.cast, values, dtype.types))

def visit_ToJSONArray(self, op, *, arg):
return self.if_(
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def visit_IntegerRange(self, op, *, start, stop, step):
step.neq(0), self.f.array_generate_range(start, stop, step), self.f.array()
)

def visit_StructColumn(self, op, *, names, values):
def visit_StructColumn(self, op, *, names, values, dtype):
if values is None:
return self.cast(NULL, dtype)
return self.f.object_construct_keep_null(
*itertools.chain.from_iterable(zip(names, values))
)
Expand Down
23 changes: 17 additions & 6 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,13 +951,24 @@ def visit_InSubquery(self, op, *, rel, needle):
query = sg.select(STAR).from_(query)
return needle.isin(query=query)

def visit_Array(self, op, *, exprs):
return self.f.array(*exprs)
def visit_Array(self, op, *, exprs, dtype):
if exprs is None:
vals = NULL
else:
vals = self.f.array(*exprs)
return self.cast(vals, dtype)

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
[value.as_(name, quoted=self.quoted) for name, value in zip(names, values)]
)
def visit_StructColumn(self, op, *, names, values, dtype):
if values is None:
vals = NULL
else:
vals = sge.Struct.from_arg_list(
[
value.as_(name, quoted=self.quoted)
for name, value in zip(names, values)
]
)
return self.cast(vals, dtype)

def visit_StructField(self, op, *, arg, field):
return sge.Dot(this=arg, expression=sg.to_identifier(field, quoted=self.quoted))
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,8 +999,10 @@ class ClickHouseType(SqlglotType):
def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
"""Convert a sqlglot type to an ibis type."""
typ = super().from_ibis(dtype)
if dtype.nullable and not (dtype.is_map() or dtype.is_array()):
# map cannot be nullable in clickhouse
# nested types cannot be nullable in clickhouse
if dtype.nullable and not (
dtype.is_map() or dtype.is_array() or dtype.is_struct()
):
return sge.DataType(this=typecode.NULLABLE, expressions=[typ])
else:
return typ
Expand Down
39 changes: 39 additions & 0 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
PySparkAnalysisException,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.collections import frozendict

pytestmark = [
Expand Down Expand Up @@ -71,6 +72,43 @@
# list.


def test_array_factory(con):
a = ibis.array([1, 2, 3])
assert con.execute(a) == [1, 2, 3]

a2 = ibis.array(a)
assert con.execute(a2) == [1, 2, 3]

typed = ibis.array([1, 2, 3], type="array<string>")
assert con.execute(typed) == ["1", "2", "3"]

typed2 = ibis.array(a, type="array<string>")
assert con.execute(typed2) == ["1", "2", "3"]


@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError)
def test_array_factory_empty(con):
with pytest.raises(ValidationError):
ibis.array([])

empty_typed = ibis.array([], type="array<string>")
assert empty_typed.type() == dt.Array(value_type=dt.string)
assert con.execute(empty_typed) == []


@pytest.mark.notyet(
"clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
)
def test_array_factory_null(con):
with pytest.raises(ValidationError):
ibis.array(None)
with pytest.raises(ValidationError):
ibis.array(None, type="int64")
none_typed = ibis.array(None, type="array<string>")
assert none_typed.type() == dt.Array(value_type=dt.string)
assert con.execute(none_typed) is None


def test_array_column(backend, alltypes, df):
expr = ibis.array(
[alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
Expand Down Expand Up @@ -108,6 +146,7 @@ def test_array_scalar(con):


@pytest.mark.notimpl(["flink", "polars"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl("postgres", raises=PsycoPg2SyntaxError)
def test_array_repeat(con):
expr = ibis.array([1.0, 2.0]) * 2

Expand Down
7 changes: 5 additions & 2 deletions ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
import ibis
import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
from ibis.backends.tests.errors import PsycoPg2InternalError, Py4JJavaError
from ibis.backends.tests.errors import (
PsycoPg2InternalError,
Py4JJavaError,
)

pytestmark = [
pytest.mark.never(
Expand Down Expand Up @@ -669,6 +672,6 @@ def test_map_keys_unnest(backend):

@mark_notimpl_risingwave_hstore
def test_map_contains_null(con):
expr = ibis.map(["a"], ibis.literal([None], type="array<string>"))
expr = ibis.map(["a"], ibis.array([None], type="array<string>"))
assert con.execute(expr.contains("a"))
assert not con.execute(expr.contains("b"))
54 changes: 47 additions & 7 deletions ibis/backends/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,54 @@
Py4JJavaError,
PySparkAnalysisException,
)
from ibis.common.exceptions import IbisError, OperationNotDefinedError
from ibis.common.exceptions import IbisError, OperationNotDefinedError, ValidationError

pytestmark = [
pytest.mark.never(["mysql", "sqlite", "mssql"], reason="No struct support"),
pytest.mark.notyet(["impala"]),
pytest.mark.notimpl(["datafusion", "druid", "oracle", "exasol"]),
]

mark_notimpl_postgres_literals = pytest.mark.notimpl(
"postgres", reason="struct literals not implemented", raises=PsycoPg2SyntaxError
)


@pytest.mark.broken("postgres", reason="JSON handling is buggy")
def test_struct_factory(con):
s = ibis.struct({"a": 1, "b": 2})
assert con.execute(s) == {"a": 1, "b": 2}

s2 = ibis.struct(s)
assert con.execute(s2) == {"a": 1, "b": 2}

typed = ibis.struct({"a": 1, "b": 2}, type="struct<a: string, b: string>")
assert con.execute(typed) == {"a": "1", "b": "2"}

typed2 = ibis.struct(s, type="struct<a: string, b: string>")
assert con.execute(typed2) == {"a": "1", "b": "2"}


def test_struct_factory_empty():
with pytest.raises(ValidationError):
ibis.struct({})
with pytest.raises(ValidationError):
ibis.struct({}, type="struct<>")
with pytest.raises(ValidationError):
ibis.struct({}, type="struct<a: float64, b: float64>")


@mark_notimpl_postgres_literals
@pytest.mark.notyet(
"clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
)
def test_struct_factory_null(con):
with pytest.raises(TypeError):
ibis.struct(None)
none_typed = ibis.struct(None, type="struct<a: float64, b: float>")
assert none_typed.type() == dt.Struct(fields={"a": dt.float64, "b": dt.float64})
assert con.execute(none_typed) is None


@pytest.mark.notimpl(["dask"])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -79,6 +119,9 @@ def test_all_fields(struct, struct_df):

@pytest.mark.notimpl(["postgres", "risingwave"])
@pytest.mark.parametrize("field", ["a", "b", "c"])
@pytest.mark.notyet(
["flink"], reason="flink doesn't support creating struct columns from literals"
)
def test_literal(backend, con, field):
query = _STRUCT_LITERAL[field]
dtype = query.type().to_pandas()
Expand All @@ -88,7 +131,7 @@ def test_literal(backend, con, field):
backend.assert_series_equal(result, expected.astype(dtype))


@pytest.mark.notimpl(["postgres"])
@mark_notimpl_postgres_literals
@pytest.mark.parametrize("field", ["a", "b", "c"])
@pytest.mark.notyet(
["clickhouse"], reason="clickhouse doesn't support nullable nested types"
Expand All @@ -101,7 +144,7 @@ def test_null_literal(backend, con, field):
backend.assert_series_equal(result, expected)


@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave"])
@pytest.mark.notimpl(["postgres", "risingwave"])
def test_struct_column(alltypes, df):
t = alltypes
expr = t.select(s=ibis.struct(dict(a=t.string_col, b=1, c=t.bigint_col)))
Expand All @@ -113,7 +156,7 @@ def test_struct_column(alltypes, df):
tm.assert_frame_equal(result, expected)


@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave", "polars"])
@pytest.mark.notimpl(["postgres", "risingwave", "polars"])
@pytest.mark.notyet(
["flink"], reason="flink doesn't support creating struct columns from collect"
)
Expand All @@ -138,9 +181,6 @@ def test_collect_into_struct(alltypes):
assert len(val.loc[result.group == "1"].iat[0]["key"]) == 730


@pytest.mark.notimpl(
["postgres"], reason="struct literals not implemented", raises=PsycoPg2SyntaxError
)
@pytest.mark.notimpl(
["risingwave"],
reason="struct literals not implemented",
Expand Down
Loading

0 comments on commit a00c26d

Please sign in to comment.