Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): improve ibis.array() #9458

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
plan,
)
from ibis.common.exceptions import UnboundExpressionError, UnsupportedOperationError
from ibis.formats.numpy import NumpyType
from ibis.formats.pandas import PandasData, PandasType
from ibis.util import gen_name

Expand Down Expand Up @@ -155,9 +156,10 @@ def mapper(df, cases):
return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.Array, exprs, dtype):
np_type = NumpyType.from_ibis(dtype)
return cls.rowwise(
lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
lambda row: np.array(row, dtype=np_type), exprs, name=op.name, dtype=object
)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/dask/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def concat(cls, dfs, **kwargs):

@classmethod
def asseries(cls, value, like=None):
"""Ensure that value is a pandas Series object, broadcast if necessary."""
"""Ensure that value is a dask Series object, broadcast if necessary."""

if isinstance(value, dd.Series):
return value
Expand All @@ -50,7 +50,7 @@ def asseries(cls, value, like=None):
elif isinstance(value, pd.Series):
return dd.from_pandas(value, npartitions=1)
elif like is not None:
if isinstance(value, (tuple, list, dict)):
if isinstance(value, (tuple, list, dict, np.ndarray)):
fn = lambda df: pd.Series([value] * len(df), index=df.index)
else:
fn = lambda df: pd.Series(value, index=df.index)
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def visit(cls, op: ops.Node, **kwargs):

@classmethod
def visit(cls, op: ops.Literal, value, dtype):
if dtype.is_interval():
if value is None:
value = None
elif dtype.is_interval():
value = pd.Timedelta(value, dtype.unit.short)
elif dtype.is_array():
value = np.array(value)
Expand Down Expand Up @@ -220,7 +222,7 @@ def visit(cls, op: ops.FindInSet, needle, values):
return pd.Series(result, name=op.name)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.Array, exprs, dtype):
    """Build an array column by collecting ``exprs`` row-wise into ndarrays.

    Each row's element values are packed into a ``np.array`` with object
    dtype, preserving Python-level values (including None).
    """
    # NOTE(review): ``dtype`` is accepted (the Array op now carries one) but
    # is not used here — elements keep object dtype rather than being cast
    # to the requested element type. Confirm whether a cast is intended.
    return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)

@classmethod
Expand Down
27 changes: 16 additions & 11 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,25 +87,27 @@ def literal(op, **_):
value = op.value
dtype = op.dtype

if dtype.is_array():
value = pl.Series("", value)
typ = PolarsType.from_ibis(dtype)
val = pl.lit(value, dtype=typ)
return val.implode()
# There are some interval types that _make_duration() can handle,
# but PolarsType.from_ibis can't, so we need to handle them here.
if dtype.is_interval():
return _make_duration(value, dtype)

typ = PolarsType.from_ibis(dtype)
if value is None:
return pl.lit(None, dtype=typ)
elif dtype.is_array():
return pl.lit(pl.Series("", value).implode(), dtype=typ)
elif dtype.is_struct():
values = [
pl.lit(v, dtype=PolarsType.from_ibis(dtype[k])).alias(k)
for k, v in value.items()
]
return pl.struct(values)
elif dtype.is_interval():
return _make_duration(value, dtype)
elif dtype.is_null():
return pl.lit(value)
elif dtype.is_binary():
return pl.lit(value)
else:
typ = PolarsType.from_ibis(dtype)
return pl.lit(op.value, dtype=typ)


Expand Down Expand Up @@ -974,9 +976,12 @@ def array_concat(op, **kw):


@translate.register(ops.Array)
def array_column(op, **kw):
cols = [translate(col, **kw) for col in op.exprs]
return pl.concat_list(cols)
def array_literal(op, **kw):
    """Compile an ``ops.Array`` node to a polars list expression.

    Empty arrays become a typed empty-list literal; otherwise the
    translated elements are concatenated and cast to the array's
    polars dtype.
    """
    target = PolarsType.from_ibis(op.dtype)
    if not op.exprs:
        return pl.lit([], dtype=target)
    elements = [translate(expr, **kw) for expr in op.exprs]
    return pl.concat_list(elements).cast(target)


@translate.register(ops.ArrayCollect)
Expand Down
16 changes: 14 additions & 2 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,8 +1019,20 @@ def visit_InSubquery(self, op, *, rel, needle):
query = sg.select(STAR).from_(query)
return needle.isin(query=query)

def visit_Array(self, op, *, exprs):
return self.f.array(*exprs)
def visit_Array(self, op, *, exprs, dtype):
    """Compile an ``ops.Array`` node to a SQL array constructor.

    An empty element list compiles to an empty ``array()`` cast to the
    requested array type (the cast carries the element type, since an
    empty constructor alone is typeless). Non-empty elements are each
    cast to ``dtype.value_type`` unless their ibis dtype already matches,
    so mixed-precedence inputs produce a uniformly-typed array.
    """
    if not exprs:
        return self.cast(self.f.array(), dtype)

    def maybe_cast(ibis_val, sg_expr):
        # Skip the cast when the element already has the target type,
        # keeping the generated SQL minimal.
        if ibis_val.dtype == dtype.value_type:
            return sg_expr
        else:
            return self.cast(sg_expr, dtype.value_type)

    # op.exprs (ibis values, for dtype inspection) and exprs (compiled
    # sqlglot expressions) are parallel sequences.
    cast_exprs = [
        maybe_cast(ibis_val, sg_expr) for ibis_val, sg_expr in zip(op.exprs, exprs)
    ]
    return self.f.array(*cast_exprs)

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
Expand Down
85 changes: 80 additions & 5 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
PySparkAnalysisException,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.collections import frozendict

pytestmark = [
Expand Down Expand Up @@ -72,6 +73,85 @@
# list.


def test_array_factory(con):
    """``ibis.array()`` infers the element type and round-trips expressions.

    A plain Python list infers ``array<int8>``; passing an existing array
    expression back through ``ibis.array()`` preserves its type and value.
    """
    a = ibis.array([1, 2, 3])
    assert a.type() == dt.Array(value_type=dt.Int8)
    assert con.execute(a) == [1, 2, 3]

    a2 = ibis.array(a)
    # BUGFIX: previously this re-checked ``a.type()``; the round-tripped
    # expression ``a2`` is the one whose type must be preserved.
    assert a2.type() == dt.Array(value_type=dt.Int8)
    assert con.execute(a2) == [1, 2, 3]


@pytest.mark.broken(
    ["pandas", "dask"],
    raises=AssertionError,
    reason="results in [1, 2, 3]",
)
def test_array_factory_typed(con):
    """An explicit ``type=`` casts elements, for both lists and expressions."""
    from_list = ibis.array([1, 2, 3], type="array<string>")
    assert con.execute(from_list) == ["1", "2", "3"]

    from_expr = ibis.array(ibis.array([1, 2, 3]), type="array<string>")
    assert con.execute(from_expr) == ["1", "2", "3"]


@pytest.mark.notimpl("flink", raises=Py4JJavaError)
@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError)
def test_array_factory_empty(con):
    """An empty array needs an explicit type; without one it is rejected."""
    with pytest.raises(ValidationError):
        ibis.array([])

    empty_strings = ibis.array([], type="array<string>")
    assert empty_strings.type() == dt.Array(value_type=dt.string)
    assert con.execute(empty_strings) == []


@pytest.mark.notyet(
    "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
)
@pytest.mark.notyet(
    "flink", raises=Py4JJavaError, reason="Parameters must be of the same type"
)
def test_array_factory_null(con):
    """NULL arrays require an explicit array type; NULL elements are kept."""
    with pytest.raises(ValidationError):
        ibis.array(None)
    with pytest.raises(ValidationError):
        ibis.array(None, type="int64")

    null_array = ibis.array(None, type="array<string>")
    assert null_array.type() == dt.Array(value_type=dt.string)
    assert con.execute(null_array) is None

    all_nulls = ibis.array([None, None], type="array<string>")
    assert all_nulls.type() == dt.Array(value_type=dt.string)
    assert con.execute(all_nulls) == [None, None]

    # Also execute a concrete non-NULL array so that backends lacking array
    # support fail here as expected; otherwise every other test in this file
    # would need its own @mark.xfail annotation.
    assert con.execute(ibis.array([1, 2])) == [1, 2]


@pytest.mark.broken(
    ["datafusion", "flink", "polars"],
    raises=AssertionError,
    reason="[None, 1] executes to [np.nan, 1.0]",
)
@pytest.mark.broken(
    ["pandas", "dask"],
    raises=AssertionError,
    reason="even with explicit cast, results in [None, 1]",
)
def test_array_factory_null_mixed(con):
    """Mixing NULLs with values infers from the non-NULL elements."""
    inferred = ibis.array([None, 1])
    assert inferred.type() == dt.Array(value_type=dt.Int8)
    assert con.execute(inferred) == [None, 1]

    cast_to_string = ibis.array([None, 1], type="array<string>")
    assert cast_to_string.type() == dt.Array(value_type=dt.String)
    assert con.execute(cast_to_string) == [None, "1"]


def test_array_column(backend, alltypes, df):
expr = ibis.array(
[alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
Expand Down Expand Up @@ -1354,11 +1434,6 @@ def test_unnest_range(con):
id="array",
marks=[
pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest),
pytest.mark.broken(
["polars"],
reason="expression input not supported with nested arrays",
raises=TypeError,
),
],
),
],
Expand Down
7 changes: 3 additions & 4 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,13 +1431,12 @@ def query(t, group_cols):
snapshot.assert_match(str(ibis.to_sql(t3, dialect=con.name)), "out.sql")


@pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.notyet(
["datafusion", "impala", "mssql", "mysql", "sqlite"],
["datafusion", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"],
reason="backend doesn't support arrays and we don't implement pivot_longer with unions yet",
raises=com.OperationNotDefinedError,
raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.broken(
["trino"],
reason="invalid code generated for unnesting a struct",
Expand Down
18 changes: 14 additions & 4 deletions ibis/backends/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ibis.array([432]),
marks=[
pytest.mark.never(
["mysql", "mssql", "oracle", "impala", "sqlite"],
["exasol", "mysql", "mssql", "oracle", "impala", "sqlite"],
raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
reason="arrays not supported in the backend",
),
Expand All @@ -30,8 +30,18 @@
ibis.struct(dict(abc=432)),
marks=[
pytest.mark.never(
["impala", "mysql", "sqlite", "mssql", "exasol"],
raises=(NotImplementedError, exc.UnsupportedBackendType),
[
"exasol",
"impala",
"mysql",
"sqlite",
"mssql",
],
raises=(
exc.OperationNotDefinedError,
NotImplementedError,
exc.UnsupportedBackendType,
),
reason="structs not supported in the backend",
),
pytest.mark.notimpl(
Expand Down Expand Up @@ -104,7 +114,7 @@ def test_isin_bug(con, snapshot):
@pytest.mark.notyet(
["datafusion", "exasol", "oracle", "flink", "risingwave"],
reason="no unnest support",
raises=exc.OperationNotDefinedError,
raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
)
@pytest.mark.notyet(
["sqlite", "mysql", "druid", "impala", "mssql"], reason="no unnest support upstream"
Expand Down
16 changes: 11 additions & 5 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ClickHouseDatabaseError,
OracleDatabaseError,
PsycoPg2InternalError,
PyDruidProgrammingError,
PyODBCProgrammingError,
)
from ibis.common.annotations import ValidationError
Expand Down Expand Up @@ -835,21 +836,26 @@ def test_capitalize(con, inp, expected):
assert pd.isnull(result)


@pytest.mark.never(
["exasol", "impala", "mssql", "mysql", "sqlite"],
reason="Backend doesn't support arrays",
raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
)
@pytest.mark.notimpl(
[
"dask",
"pandas",
"polars",
"oracle",
"flink",
"sqlite",
"mssql",
"mysql",
"exasol",
"impala",
],
raises=com.OperationNotDefinedError,
)
@pytest.mark.broken(
"druid",
raises=PyDruidProgrammingError,
reason="ibis.array() has a cast, and we compile the dtype to 'VARCHAR[] instead of 'ARRAY<STRING>' as needed",
)
def test_array_string_join(con):
s = ibis.array(["a", "b", "c"])
expected = "a,b,c"
Expand Down
13 changes: 7 additions & 6 deletions ibis/expr/operations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ class Array(Value):
"""Construct an array."""

exprs: VarTuple[Value]
dtype: Optional[dt.Array] = None

@attribute
def shape(self):
return rlz.highest_precedence_shape(self.exprs)
shape = rlz.shape_like("exprs")

@attribute
def dtype(self):
return dt.Array(rlz.highest_precedence_dtype(self.exprs))
def __init__(self, exprs, dtype: dt.Array | None = None):
# If len(exprs) == 0, the caller is responsible for providing a dtype
if dtype is None:
dtype = dt.Array(rlz.highest_precedence_dtype(exprs))
super().__init__(exprs=exprs, dtype=dtype)


@public
Expand Down
4 changes: 4 additions & 0 deletions ibis/expr/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from public import public

import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
Expand All @@ -16,6 +17,9 @@

@public
def highest_precedence_shape(nodes):
    """Return the highest-precedence datashape among ``nodes``.

    Parameters
    ----------
    nodes
        Iterable of ops nodes; each must expose a ``.shape`` attribute.

    Returns
    -------
    The maximum (highest-precedence) shape, or ``ds.scalar`` when the
    iterable is empty (e.g. an empty array literal has no element nodes
    to take a shape from).
    """
    # Materialize first: ``nodes`` may be a one-shot iterator, and we need
    # both an emptiness check and an iteration over it.
    nodes = tuple(nodes)
    if not nodes:  # idiomatic emptiness test instead of len(...) == 0
        return ds.scalar
    return max(node.shape for node in nodes)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
DummyTable
foo: Array([1])
foo: Array(exprs=[1], dtype=array<int8>)
Loading
Loading