Skip to content

Commit

Permalink
feat: support empty arrays, improve ibis.array() API
Browse files Browse the repository at this point in the history
Picking out the array stuff from ibis-project#8666

Instead of trying to fit the two cases of 0-length and 1+ length arrays into the same op, I split them up into separate ones.
By doing this, if we guarantee that all the elements of ops.Array() have the right type before construction,
we don't have to do any fancy casting during compilation: all the elements will
already have been cast as needed.

This allows for the compilation of array<structs> on some SQL backends like Postgres.
If we tried to cast the entire array, we would end up with SQL like `cast [..] as STRUCT<...>[]`,
which Postgres rejects.
Instead, if we cast each individual element,
such as `[cast({...} as ROW...), cast({...} as ROW...)]`, then this is valid SQL.

I added a Length annotation to ops.Array to verify the length is 1+. I'm not sure this is really needed, since if you ever did construct an empty one, then rlz.highest_precedence_dtype([]) would fail. But that might fail at a later time,
and I wanted to raise an error at construction time. That said, end users should never be constructing ops.Arrays directly,
so this is a guardrail just for us ibis devs.
So we could remove it, but I think it is a nice hint for future us.
  • Loading branch information
NickCrews committed Jun 29, 2024
1 parent 33ec754 commit a655252
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 37 deletions.
9 changes: 8 additions & 1 deletion ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def visit(cls, op: ops.Node, **kwargs):

@classmethod
def visit(cls, op: ops.Literal, value, dtype):
if dtype.is_interval():
if value is None:
value = None
elif dtype.is_interval():
value = pd.Timedelta(value, dtype.unit.short)
elif dtype.is_array():
value = np.array(value)
Expand Down Expand Up @@ -219,6 +221,11 @@ def visit(cls, op: ops.FindInSet, needle, values):
result = np.select(condlist, choicelist, default=-1)
return pd.Series(result, name=op.name)

@classmethod
def visit(cls, op: ops.EmptyArray, dtype):
    """Execute an `EmptyArray` op as a zero-length NumPy array.

    The array's dtype is whatever `PandasType.from_ibis` returns for the
    op's Ibis `dtype`.
    """
    return np.array([], dtype=PandasType.from_ibis(dtype))

@classmethod
def visit(cls, op: ops.Array, exprs):
    """Execute an `Array` op by packing each row of `exprs` into an object array."""

    def as_object_array(row):
        # dtype=object keeps the row's elements as-is instead of letting
        # NumPy coerce them to a common numeric/string dtype.
        return np.array(row, dtype=object)

    return cls.rowwise(as_object_array, exprs)
Expand Down
15 changes: 8 additions & 7 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,7 @@ def literal(op, **_):
value = op.value
dtype = op.dtype

if dtype.is_array():
value = pl.Series("", value)
typ = PolarsType.from_ibis(dtype)
val = pl.lit(value, dtype=typ)
return val.implode()
elif dtype.is_struct():
if dtype.is_struct():
values = [
pl.lit(v, dtype=PolarsType.from_ibis(dtype[k])).alias(k)
for k, v in value.items()
Expand All @@ -106,7 +101,7 @@ def literal(op, **_):
return pl.lit(value)
else:
typ = PolarsType.from_ibis(dtype)
return pl.lit(op.value, dtype=typ)
return pl.lit(value, dtype=typ)


_TIMESTAMP_SCALE_TO_UNITS = {
Expand Down Expand Up @@ -973,6 +968,12 @@ def array_concat(op, **kw):
return result


@translate.register(ops.EmptyArray)
def empty_array(op, **kw):
    """Translate an `EmptyArray` op into an empty-list polars literal.

    The literal's dtype is obtained by converting `op.dtype` with
    `PolarsType.from_ibis`.
    """
    return pl.lit([], dtype=PolarsType.from_ibis(op.dtype))


@translate.register(ops.Array)
def array_column(op, **kw):
cols = [translate(col, **kw) for col in op.exprs]
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,9 @@ def visit_InSubquery(self, op, *, rel, needle):
query = sg.select(STAR).from_(query)
return needle.isin(query=query)

def visit_EmptyArray(self, op, *, dtype):
    """Compile an `EmptyArray` op: an argument-less array call cast to `dtype`."""
    no_elements = self.f.array()
    return self.cast(no_elements, dtype)

def visit_Array(self, op, *, exprs):
    """Compile an `Array` op by passing the compiled elements to the array function."""
    make_array = self.f.array
    return make_array(*exprs)

Expand Down
68 changes: 67 additions & 1 deletion ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
PySparkAnalysisException,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.collections import frozendict

pytestmark = [
Expand Down Expand Up @@ -72,6 +73,71 @@
# list.


def test_array_factory(con):
    """`ibis.array` builds arrays from literals and accepts array expressions."""
    a = ibis.array([1, 2, 3])
    assert a.type() == dt.Array(value_type=dt.Int8)
    assert con.execute(a) == [1, 2, 3]

    # Passing an existing array expression back through the factory must keep
    # its type; check the new expression (`a2`) rather than `a` a second time.
    a2 = ibis.array(a)
    assert a2.type() == dt.Array(value_type=dt.Int8)
    assert con.execute(a2) == [1, 2, 3]


def test_array_factory_typed(con):
    """An explicit `type=` casts the elements, for literals and array expressions alike."""
    expected = ["1", "2", "3"]

    from_literals = ibis.array([1, 2, 3], type="array<string>")
    assert con.execute(from_literals) == expected

    from_expr = ibis.array(ibis.array([1, 2, 3]), type="array<string>")
    assert con.execute(from_expr) == expected


@pytest.mark.notimpl("flink", raises=Py4JJavaError)
def test_array_factory_empty(con):
    """An empty array needs an explicit type; with one it executes to `[]`."""
    # Without a type there is nothing to infer the element type from.
    with pytest.raises(ValidationError):
        ibis.array([])

    expr = ibis.array([], type="array<string>")
    assert expr.type() == dt.Array(value_type=dt.string)

    result = con.execute(expr)
    assert result == []


@pytest.mark.notyet(
    "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
)
def test_array_factory_null(con):
    """NULL arrays and arrays of NULLs require (and honor) an explicit array type."""
    array_of_string = dt.Array(value_type=dt.string)

    # No type at all, or a non-array type, is rejected up front.
    with pytest.raises(ValidationError):
        ibis.array(None)
    with pytest.raises(ValidationError):
        ibis.array(None, type="int64")

    null_array = ibis.array(None, type="array<string>")
    assert null_array.type() == array_of_string
    assert con.execute(null_array) is None

    array_of_nulls = ibis.array([None, None], type="array<string>")
    assert array_of_nulls.type() == array_of_string
    assert con.execute(array_of_nulls) == [None, None]

    # Execute a real value here, so the backends that don't support arrays
    # actually xfail as we expect them to.
    # Otherwise would have to @mark.xfail every test in this file besides this one.
    real_values = ibis.array([1, 2])
    assert con.execute(real_values) == [1, 2]


@pytest.mark.broken(
    ["datafusion", "flink", "polars"],
    raises=AssertionError,
    reason="[None, 1] executes to [np.nan, 1.0]",
)
def test_array_factory_null_mixed(con):
    """Arrays mixing NULL and real values keep the NULL, with or without `type=`."""
    inferred = ibis.array([None, 1])
    assert inferred.type() == dt.Array(value_type=dt.Int8)
    assert con.execute(inferred) == [None, 1]

    as_strings = ibis.array([None, 1], type="array<string>")
    assert as_strings.type() == dt.Array(value_type=dt.String)
    assert con.execute(as_strings) == [None, "1"]


def test_array_column(backend, alltypes, df):
expr = ibis.array(
[alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
Expand Down Expand Up @@ -1356,7 +1422,7 @@ def test_unnest_range(con):
pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest),
pytest.mark.broken(
["polars"],
reason="expression input not supported with nested arrays",
reason="upstream polars bug: https://github.com/pola-rs/polars/issues/17294",
raises=TypeError,
),
],
Expand Down
7 changes: 3 additions & 4 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,13 +1431,12 @@ def query(t, group_cols):
snapshot.assert_match(str(ibis.to_sql(t3, dialect=con.name)), "out.sql")


@pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.notyet(
["datafusion", "impala", "mssql", "mysql", "sqlite"],
["datafusion", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"],
reason="backend doesn't support arrays and we don't implement pivot_longer with unions yet",
raises=com.OperationNotDefinedError,
raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.broken(
["trino"],
reason="invalid code generated for unnesting a struct",
Expand Down
16 changes: 13 additions & 3 deletions ibis/backends/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,18 @@
ibis.struct(dict(abc=432)),
marks=[
pytest.mark.never(
["impala", "mysql", "sqlite", "mssql", "exasol"],
raises=(NotImplementedError, exc.UnsupportedBackendType),
[
"exasol",
"impala",
"mysql",
"sqlite",
"mssql",
],
raises=(
exc.OperationNotDefinedError,
NotImplementedError,
exc.UnsupportedBackendType,
),
reason="structs not supported in the backend",
),
pytest.mark.notimpl(
Expand Down Expand Up @@ -104,7 +114,7 @@ def test_isin_bug(con, snapshot):
@pytest.mark.notyet(
["datafusion", "exasol", "oracle", "flink", "risingwave"],
reason="no unnest support",
raises=exc.OperationNotDefinedError,
raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
)
@pytest.mark.notyet(
["sqlite", "mysql", "druid", "impala", "mssql"], reason="no unnest support upstream"
Expand Down
10 changes: 5 additions & 5 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,18 +835,18 @@ def test_capitalize(con, inp, expected):
assert pd.isnull(result)


@pytest.mark.never(
["exasol", "impala", "mssql", "mysql", "sqlite"],
reason="Backend doesn't support arrays",
raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
)
@pytest.mark.notimpl(
[
"dask",
"pandas",
"polars",
"oracle",
"flink",
"sqlite",
"mssql",
"mysql",
"exasol",
"impala",
],
raises=com.OperationNotDefinedError,
)
Expand Down
19 changes: 13 additions & 6 deletions ibis/expr/operations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,34 @@

from __future__ import annotations

from typing import Optional
from typing import Annotated, Optional

from public import public

import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
from ibis.common.annotations import attribute
from ibis.common.patterns import Length # noqa: TCH001
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.core import Unary, Value


@public
class EmptyArray(Value):
    """Construct an array with 0 elements.

    With no elements to derive a type from, the array type is supplied
    explicitly through `dtype`.
    """

    # Array type of the resulting (empty) value, given by the caller.
    dtype: dt.Array
    # The shape is fixed to scalar for this op.
    shape = ds.scalar


@public
class Array(Value):
"""Construct an array."""
"""Construct an array with 1+ elements. Use `EmptyArray` for empty arrays."""

exprs: VarTuple[Value]
exprs: Annotated[VarTuple[Value], Length(at_least=1)]

@attribute
def shape(self):
return rlz.highest_precedence_shape(self.exprs)
shape = rlz.shape_like("exprs")

@attribute
def dtype(self):
Expand Down
62 changes: 52 additions & 10 deletions ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@

from public import public

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.common.annotations import ValidationError
from ibis.common.deferred import Deferred, deferrable
from ibis.expr.types.generic import Column, Scalar, Value

if TYPE_CHECKING:
from collections.abc import Callable, Iterable

import ibis.expr.types as ir
from ibis.expr.types.typing import V

import ibis.common.exceptions as com


Expand Down Expand Up @@ -1067,7 +1068,11 @@ def __getitem__(self, index: int | ir.IntegerValue | slice) -> ir.Column:

@public
@deferrable
def array(values: Iterable[V]) -> ArrayValue:
def array(
values: ArrayValue | Iterable | ir.NullValue | None,
*,
type: str | dt.DataType | None = None,
) -> ArrayValue:
"""Create an array expression.
If any values are [column expressions](../concepts/datatypes.qmd) the
Expand All @@ -1078,6 +1083,9 @@ def array(values: Iterable[V]) -> ArrayValue:
----------
values
An iterable of Ibis expressions or Python literals
type
An instance of `ibis.expr.datatypes.DataType` or a string indicating
the Ibis type of `value`. eg `array<float>`.
Returns
-------
Expand Down Expand Up @@ -1108,15 +1116,49 @@ def array(values: Iterable[V]) -> ArrayValue:
│ [3, 42, ... +1] │
└──────────────────────┘
>>> ibis.array([t.a, 42 + ibis.literal(5)])
>>> ibis.array([t.a, 42 + ibis.literal(5)], type="array<float>")
┏━━━━━━━━━━━━━━━━━━━━━━┓
┃ Array() ┃
┡━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64>
│ array<float64>
├──────────────────────┤
│ [1, 47]
│ [2, 47]
│ [3, 47]
│ [1.0, 47.0]
│ [2.0, 47.0]
│ [3.0, 47.0]
└──────────────────────┘
"""
return ops.Array(tuple(values)).to_expr()
type = dt.dtype(type) if type is not None else None
if type is not None and not isinstance(type, dt.Array):
raise ValidationError(f"type must be an array, got {type}")

if isinstance(values, ir.Value):
if type is not None:
return values.cast(type)
elif isinstance(values, ArrayValue):
return values
else:
raise ValidationError(
f"If no type passed, values must be an array, got {values.type()}"
)

if values is None:
if type is None:
raise ValidationError("If values is None/NULL, type must be provided")
return ir.null(type)

values = tuple(values)
if len(values) == 0:
if type is None:
raise ValidationError("If values is empty, type must be provided")
return ops.EmptyArray(type).to_expr()
else:
value_type = type.value_type if type is not None else None
values = [_value(v, value_type) for v in values]
return ops.Array(values).to_expr()


def _value(x, type) -> ir.Value:
    """Coerce one array element to an expression, optionally of a given type.

    Parameters
    ----------
    x
        An Ibis value expression, a `Deferred`, or a plain Python object.
    type
        Target element type, or `None` to leave the type to inference.

    Returns
    -------
    Value
        `x` itself (cast to `type` when one is given) if it is already an
        expression or deferred; otherwise a literal built from `x`.
    """
    if not isinstance(x, (ir.Value, Deferred)):
        return ibis.literal(x, type=type)
    if type is None:
        return x
    return x.cast(type)

0 comments on commit a655252

Please sign in to comment.