Skip to content

Commit

Permalink
feat: support empty arrays, improve ibis.array() API
Browse files Browse the repository at this point in the history
Picking out the array stuff from #8666
  • Loading branch information
NickCrews committed Jun 28, 2024
1 parent 33ec754 commit f7e4931
Show file tree
Hide file tree
Showing 13 changed files with 201 additions and 54 deletions.
6 changes: 4 additions & 2 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
plan,
)
from ibis.common.exceptions import UnboundExpressionError, UnsupportedOperationError
from ibis.formats.numpy import NumpyType
from ibis.formats.pandas import PandasData, PandasType
from ibis.util import gen_name

Expand Down Expand Up @@ -155,9 +156,10 @@ def mapper(df, cases):
return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.Array, exprs, dtype):
np_type = NumpyType.from_ibis(dtype)
return cls.rowwise(
lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
lambda row: np.array(row, dtype=np_type), exprs, name=op.name, dtype=object
)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/dask/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def concat(cls, dfs, **kwargs):

@classmethod
def asseries(cls, value, like=None):
"""Ensure that value is a pandas Series object, broadcast if necessary."""
"""Ensure that value is a dask Series object, broadcast if necessary."""

if isinstance(value, dd.Series):
return value
Expand All @@ -50,7 +50,7 @@ def asseries(cls, value, like=None):
elif isinstance(value, pd.Series):
return dd.from_pandas(value, npartitions=1)
elif like is not None:
if isinstance(value, (tuple, list, dict)):
if isinstance(value, (tuple, list, dict, np.ndarray)):
fn = lambda df: pd.Series([value] * len(df), index=df.index)
else:
fn = lambda df: pd.Series(value, index=df.index)
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def visit(cls, op: ops.Node, **kwargs):

@classmethod
def visit(cls, op: ops.Literal, value, dtype):
if dtype.is_interval():
if value is None:
value = None
elif dtype.is_interval():
value = pd.Timedelta(value, dtype.unit.short)
elif dtype.is_array():
value = np.array(value)
Expand Down Expand Up @@ -220,7 +222,7 @@ def visit(cls, op: ops.FindInSet, needle, values):
return pd.Series(result, name=op.name)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.Array, exprs, dtype):
return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)

@classmethod
Expand Down
27 changes: 16 additions & 11 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,25 +87,27 @@ def literal(op, **_):
value = op.value
dtype = op.dtype

if dtype.is_array():
value = pl.Series("", value)
typ = PolarsType.from_ibis(dtype)
val = pl.lit(value, dtype=typ)
return val.implode()
# There are some interval types that _make_duration() can handle,
# but PolarsType.from_ibis can't, so we need to handle them here.
if dtype.is_interval():
return _make_duration(value, dtype)

typ = PolarsType.from_ibis(dtype)
if value is None:
return pl.lit(None, dtype=typ)
elif dtype.is_array():
return pl.lit(pl.Series("", value).implode(), dtype=typ)
elif dtype.is_struct():
values = [
pl.lit(v, dtype=PolarsType.from_ibis(dtype[k])).alias(k)
for k, v in value.items()
]
return pl.struct(values)
elif dtype.is_interval():
return _make_duration(value, dtype)
elif dtype.is_null():
return pl.lit(value)
elif dtype.is_binary():
return pl.lit(value)
else:
typ = PolarsType.from_ibis(dtype)
return pl.lit(op.value, dtype=typ)


Expand Down Expand Up @@ -974,9 +976,12 @@ def array_concat(op, **kw):


@translate.register(ops.Array)
def array_column(op, **kw):
cols = [translate(col, **kw) for col in op.exprs]
return pl.concat_list(cols)
def array_literal(op, **kw):
    """Compile an Array op to a polars expression, handling the zero-element case."""
    target = PolarsType.from_ibis(op.dtype)
    if not op.exprs:
        # concat_list needs at least one input; an empty array is a typed literal.
        return pl.lit([], dtype=target)
    columns = [translate(col, **kw) for col in op.exprs]
    return pl.concat_list(columns).cast(target)


@translate.register(ops.ArrayCollect)
Expand Down
16 changes: 14 additions & 2 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,8 +1019,20 @@ def visit_InSubquery(self, op, *, rel, needle):
query = sg.select(STAR).from_(query)
return needle.isin(query=query)

def visit_Array(self, op, *, exprs):
return self.f.array(*exprs)
def visit_Array(self, op, *, exprs, dtype):
    """Compile an Array op, casting elements so the result type matches ``dtype``.

    With no elements there is nothing to infer the type from, so the empty
    array itself is cast to the target array type.
    """
    if not exprs:
        return self.cast(self.f.array(), dtype)

    target = dtype.value_type
    # Only insert a cast for elements whose ibis dtype differs from the target.
    cast_exprs = [
        sg_expr if ibis_val.dtype == target else self.cast(sg_expr, target)
        for ibis_val, sg_expr in zip(op.exprs, exprs)
    ]
    return self.f.array(*cast_exprs)

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
Expand Down
85 changes: 80 additions & 5 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
PySparkAnalysisException,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.collections import frozendict

pytestmark = [
Expand Down Expand Up @@ -72,6 +73,85 @@
# list.


def test_array_factory(con):
    """ibis.array() infers an element type from a list and round-trips through execution."""
    a = ibis.array([1, 2, 3])
    assert a.type() == dt.Array(value_type=dt.Int8)
    assert con.execute(a) == [1, 2, 3]

    # Wrapping an existing array expression should preserve its type.
    a2 = ibis.array(a)
    # BUG FIX: the original re-asserted `a.type()`, so a2's type was never checked.
    assert a2.type() == dt.Array(value_type=dt.Int8)
    assert con.execute(a2) == [1, 2, 3]


@pytest.mark.broken(
    ["pandas", "dask"],
    raises=AssertionError,
    reason="results in [1, 2, 3]",
)
def test_array_factory_typed(con):
    """An explicit ``type=`` casts the elements, for both raw lists and expressions."""
    for source in ([1, 2, 3], ibis.array([1, 2, 3])):
        typed = ibis.array(source, type="array<string>")
        assert con.execute(typed) == ["1", "2", "3"]


@pytest.mark.notimpl("flink", raises=Py4JJavaError)
@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError)
def test_array_factory_empty(con):
    """An empty list needs an explicit dtype; without one, construction is rejected."""
    with pytest.raises(ValidationError):
        ibis.array([])

    empty = ibis.array([], type="array<string>")
    assert empty.type() == dt.Array(value_type=dt.string)
    assert con.execute(empty) == []


@pytest.mark.notyet(
    "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
)
@pytest.mark.notyet(
    "flink", raises=Py4JJavaError, reason="Parameters must be of the same type"
)
def test_array_factory_null(con):
    """NULL handling: bare None is rejected, typed None executes to a NULL array."""
    with pytest.raises(ValidationError):
        ibis.array(None)
    with pytest.raises(ValidationError):
        ibis.array(None, type="int64")

    null_array = ibis.array(None, type="array<string>")
    assert null_array.type() == dt.Array(value_type=dt.string)
    assert con.execute(null_array) is None

    all_null = ibis.array([None, None], type="array<string>")
    assert all_null.type() == dt.Array(value_type=dt.string)
    assert con.execute(all_null) == [None, None]

    # Execute a real value here, so the backends that don't support arrays
    # actually xfail as we expect them to.
    # Otherwise would have to @mark.xfail every test in this file besides this one.
    assert con.execute(ibis.array([1, 2])) == [1, 2]


@pytest.mark.broken(
    ["datafusion", "flink", "polars"],
    raises=AssertionError,
    reason="[None, 1] executes to [np.nan, 1.0]",
)
@pytest.mark.broken(
    ["pandas", "dask"],
    raises=AssertionError,
    reason="even with explicit cast, results in [None, 1]",
)
def test_array_factory_null_mixed(con):
    """NULL elements survive both an inferred and an explicitly cast element type."""
    inferred = ibis.array([None, 1])
    assert inferred.type() == dt.Array(value_type=dt.Int8)
    assert con.execute(inferred) == [None, 1]

    as_strings = ibis.array([None, 1], type="array<string>")
    assert as_strings.type() == dt.Array(value_type=dt.String)
    assert con.execute(as_strings) == [None, "1"]


def test_array_column(backend, alltypes, df):
expr = ibis.array(
[alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
Expand Down Expand Up @@ -1354,11 +1434,6 @@ def test_unnest_range(con):
id="array",
marks=[
pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest),
pytest.mark.broken(
["polars"],
reason="expression input not supported with nested arrays",
raises=TypeError,
),
],
),
],
Expand Down
7 changes: 3 additions & 4 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,13 +1431,12 @@ def query(t, group_cols):
snapshot.assert_match(str(ibis.to_sql(t3, dialect=con.name)), "out.sql")


@pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.notyet(
["datafusion", "impala", "mssql", "mysql", "sqlite"],
["datafusion", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"],
reason="backend doesn't support arrays and we don't implement pivot_longer with unions yet",
raises=com.OperationNotDefinedError,
raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.broken(
["trino"],
reason="invalid code generated for unnesting a struct",
Expand Down
18 changes: 14 additions & 4 deletions ibis/backends/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ibis.array([432]),
marks=[
pytest.mark.never(
["mysql", "mssql", "oracle", "impala", "sqlite"],
["exasol", "mysql", "mssql", "oracle", "impala", "sqlite"],
raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
reason="arrays not supported in the backend",
),
Expand All @@ -30,8 +30,18 @@
ibis.struct(dict(abc=432)),
marks=[
pytest.mark.never(
["impala", "mysql", "sqlite", "mssql", "exasol"],
raises=(NotImplementedError, exc.UnsupportedBackendType),
[
"exasol",
"impala",
"mysql",
"sqlite",
"mssql",
],
raises=(
exc.OperationNotDefinedError,
NotImplementedError,
exc.UnsupportedBackendType,
),
reason="structs not supported in the backend",
),
pytest.mark.notimpl(
Expand Down Expand Up @@ -104,7 +114,7 @@ def test_isin_bug(con, snapshot):
@pytest.mark.notyet(
["datafusion", "exasol", "oracle", "flink", "risingwave"],
reason="no unnest support",
raises=exc.OperationNotDefinedError,
raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
)
@pytest.mark.notyet(
["sqlite", "mysql", "druid", "impala", "mssql"], reason="no unnest support upstream"
Expand Down
16 changes: 11 additions & 5 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ClickHouseDatabaseError,
OracleDatabaseError,
PsycoPg2InternalError,
PyDruidProgrammingError,
PyODBCProgrammingError,
)
from ibis.common.annotations import ValidationError
Expand Down Expand Up @@ -835,21 +836,26 @@ def test_capitalize(con, inp, expected):
assert pd.isnull(result)


@pytest.mark.never(
["exasol", "impala", "mssql", "mysql", "sqlite"],
reason="Backend doesn't support arrays",
raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
)
@pytest.mark.notimpl(
[
"dask",
"pandas",
"polars",
"oracle",
"flink",
"sqlite",
"mssql",
"mysql",
"exasol",
"impala",
],
raises=com.OperationNotDefinedError,
)
@pytest.mark.broken(
"druid",
raises=PyDruidProgrammingError,
reason="ibis.array() has a cast, and we compile the dtype to 'VARCHAR[] instead of 'ARRAY<STRING>' as needed",
)
def test_array_string_join(con):
s = ibis.array(["a", "b", "c"])
expected = "a,b,c"
Expand Down
13 changes: 7 additions & 6 deletions ibis/expr/operations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ class Array(Value):
"""Construct an array."""

exprs: VarTuple[Value]
dtype: Optional[dt.Array] = None

@attribute
def shape(self):
return rlz.highest_precedence_shape(self.exprs)
shape = rlz.shape_like("exprs")

@attribute
def dtype(self):
return dt.Array(rlz.highest_precedence_dtype(self.exprs))
def __init__(self, exprs, dtype: dt.Array | None = None):
    """Construct an Array op, inferring the dtype from ``exprs`` when not given.

    Parameters
    ----------
    exprs
        The element expressions of the array.
    dtype
        Explicit array type; required when ``exprs`` is empty, since there
        are then no elements to infer an element type from.
    """
    # If len(exprs) == 0, the caller is responsible for providing a dtype
    if dtype is None:
        dtype = dt.Array(rlz.highest_precedence_dtype(exprs))
    super().__init__(exprs=exprs, dtype=dtype)


@public
Expand Down
4 changes: 4 additions & 0 deletions ibis/expr/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from public import public

import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
Expand All @@ -16,6 +17,9 @@

@public
def highest_precedence_shape(nodes):
    """Return the highest-precedence shape among ``nodes`` (scalar when empty)."""
    materialized = tuple(nodes)
    if not materialized:
        # No nodes to compare: degenerate to the scalar shape.
        return ds.scalar
    return max(node.shape for node in materialized)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
DummyTable
foo: Array([1])
foo: Array(exprs=[1], dtype=array<int8>)
Loading

0 comments on commit f7e4931

Please sign in to comment.