feat(api): support literal expressions in array constructor
cpcloud committed Jan 20, 2024
1 parent 1d29e4b commit 3cddbe7
Showing 15 changed files with 67 additions and 44 deletions.
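
Not part of the commit itself — a minimal usage sketch of what this change enables, based on the new test added below (the DuckDB connection and the printed output are illustrative assumptions, not taken from the diff):

```python
import ibis

# Before this commit, ibis.array() with any ibis expression in the input list built
# an ops.ArrayColumn, which always had columnar shape. Mixing Python literals with
# scalar ibis expressions now produces a scalar array expression instead.
expr = ibis.array([1, ibis.literal(2)])

con = ibis.duckdb.connect()      # assumption: any backend with array support behaves similarly
print(list(con.execute(expr)))   # -> [1, 2]
```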
4 changes: 2 additions & 2 deletions ibis/backends/bigquery/registry.py
@@ -129,7 +129,7 @@ def _array_concat(translator, op):


 def _array_column(translator, op):
-    return "[{}]".format(", ".join(map(translator.translate, op.cols)))
+    return "[{}]".format(", ".join(map(translator.translate, op.exprs)))


 def _array_index(translator, op):
@@ -912,7 +912,7 @@ def _timestamp_range(translator, op):
     ops.StructColumn: _struct_column,
     ops.ArrayCollect: _array_agg,
     ops.ArrayConcat: _array_concat,
-    ops.ArrayColumn: _array_column,
+    ops.Array: _array_column,
     ops.ArrayIndex: _array_index,
     ops.ArrayLength: unary("ARRAY_LENGTH"),
     ops.ArrayRepeat: _array_repeat,
6 changes: 3 additions & 3 deletions ibis/backends/clickhouse/compiler/values.py
@@ -554,9 +554,9 @@ def _translate(op, *, arg, where, **_):
     return _translate


-@translate_val.register(ops.ArrayColumn)
-def _array_column(op, *, cols, **_):
-    return F.array(*cols)
+@translate_val.register(ops.Array)
+def _array_column(op, *, exprs, **_):
+    return F.array(*exprs)


 @translate_val.register(ops.StructColumn)
2 changes: 1 addition & 1 deletion ibis/backends/dask/execution/arrays.py
@@ -34,7 +34,7 @@
 )


-@execute_node.register(ops.ArrayColumn, tuple)
+@execute_node.register(ops.Array, tuple)
 def execute_array_column(op, cols, **kwargs):
     cols = [execute(arg, **kwargs) for arg in cols]
     df = dd.concat(cols, axis=1)
6 changes: 3 additions & 3 deletions ibis/backends/datafusion/compiler/values.py
@@ -733,9 +733,9 @@ def _not_null(op, *, arg, **_):
     return sg.not_(arg.is_(NULL))


-@translate_val.register(ops.ArrayColumn)
-def array_column(op, *, cols, **_):
-    return F.make_array(*cols)
+@translate_val.register(ops.Array)
+def array_column(op, *, exprs, **_):
+    return F.make_array(*exprs)


 @translate_val.register(ops.ArrayRepeat)
4 changes: 2 additions & 2 deletions ibis/backends/duckdb/registry.py
@@ -399,9 +399,9 @@ def _array_remove(t, op):

 operation_registry.update(
     {
-        ops.ArrayColumn: (
+        ops.Array: (
             lambda t, op: sa.cast(
-                sa.func.list_value(*map(t.translate, op.cols)),
+                sa.func.list_value(*map(t.translate, op.exprs)),
                 t.get_sqla_type(op.dtype),
             )
         ),
2 changes: 1 addition & 1 deletion ibis/backends/pandas/execution/arrays.py
@@ -17,7 +17,7 @@
 from collections.abc import Collection


-@execute_node.register(ops.ArrayColumn, tuple)
+@execute_node.register(ops.Array, tuple)
 def execute_array_column(op, cols, **kwargs):
     cols = [execute(arg, **kwargs) for arg in cols]
     df = pd.concat(cols, axis=1)
4 changes: 2 additions & 2 deletions ibis/backends/polars/compiler.py
@@ -888,9 +888,9 @@ def array_concat(op, **kw):
     return result


-@translate.register(ops.ArrayColumn)
+@translate.register(ops.Array)
 def array_column(op, **kw):
-    cols = [translate(col, **kw) for col in op.cols]
+    cols = [translate(col, **kw) for col in op.exprs]
     return pl.concat_list(cols)
2 changes: 1 addition & 1 deletion ibis/backends/postgres/registry.py
@@ -750,7 +750,7 @@ def _range(t, op):
     # array operations
     ops.ArrayLength: unary(sa.func.cardinality),
     ops.ArrayCollect: reduction(sa.func.array_agg),
-    ops.ArrayColumn: (lambda t, op: pg.array(list(map(t.translate, op.cols)))),
+    ops.Array: (lambda t, op: pg.array(list(map(t.translate, op.exprs)))),
     ops.ArraySlice: _array_slice(
         index_converter=_neg_idx_to_pos,
         array_length=sa.func.cardinality,
4 changes: 2 additions & 2 deletions ibis/backends/pyspark/compiler.py
@@ -1634,9 +1634,9 @@ def compile_interval_from_integer(t, op, **kwargs):
 # -------------------------- Array Operations ----------------------------


-@compiles(ops.ArrayColumn)
+@compiles(ops.Array)
 def compile_array_column(t, op, **kwargs):
-    cols = [t.translate(col, **kwargs) for col in op.cols]
+    cols = [t.translate(col, **kwargs) for col in op.exprs]
     return F.array(cols)
4 changes: 1 addition & 3 deletions ibis/backends/snowflake/registry.py
@@ -457,9 +457,7 @@ def _timestamp_range(t, op):
     ops.ArrayConcat: varargs(
         lambda *args: functools.reduce(sa.func.array_cat, args)
     ),
-    ops.ArrayColumn: lambda t, op: sa.func.array_construct(
-        *map(t.translate, op.cols)
-    ),
+    ops.Array: lambda t, op: sa.func.array_construct(*map(t.translate, op.exprs)),
     ops.ArraySlice: _array_slice,
     ops.ArrayCollect: reduction(
         lambda arg: sa.func.array_agg(
38 changes: 38 additions & 0 deletions ibis/backends/tests/test_array.py
@@ -15,6 +15,7 @@

 import ibis
 import ibis.common.exceptions as com
+import ibis.expr.datashape as ds
 import ibis.expr.datatypes as dt
 import ibis.expr.types as ir
 from ibis.backends.tests.errors import (
@@ -1070,3 +1071,40 @@ def test_unnest_range(con):
     result = con.execute(expr)
     expected = pd.DataFrame({"x": np.array([0, 1], dtype="int8"), "y": [1.0, 1.0]})
     tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.notyet(["flink"], raises=com.OperationNotDefinedError)
+@pytest.mark.broken(
+    ["pandas"], reason="expression input not supported", raises=TypeError
+)
+@pytest.mark.broken(
+    ["dask"], reason="expression input not supported", raises=AttributeError
+)
+@pytest.mark.parametrize(
+    ("input", "expected"),
+    [
+        param([1, ibis.literal(2)], [1, 2], id="int-int"),
+        param([1.0, ibis.literal(2)], [1.0, 2.0], id="float-int"),
+        param([1.0, ibis.literal(2.0)], [1.0, 2.0], id="float-float"),
+        param([1, ibis.literal(2.0)], [1.0, 2.0], id="int-float"),
+        param([ibis.literal(1), ibis.literal(2.0)], [1.0, 2.0], id="int-float-exprs"),
+        param(
+            [[1], ibis.literal([2])],
+            [[1], [2]],
+            id="array",
+            marks=[
+                pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest),
+                pytest.mark.broken(
+                    ["polars"],
+                    reason="expression input not supported with nested arrays",
+                    raises=TypeError,
+                ),
+            ],
+        ),
+    ],
+)
+def test_array_literal_with_exprs(con, input, expected):
+    expr = ibis.array(input)
+    assert expr.op().shape == ds.scalar
+    result = list(con.execute(expr))
+    assert result == expected
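
For readers skimming the parametrization above, the cases reduce to calls like the following (a sketch; the promoted element types are my reading of `highest_precedence_dtype`, not something the test asserts):

```python
import ibis

ibis.array([1, ibis.literal(2)])      # all-integer inputs -> integer array, evaluates to [1, 2]
ibis.array([1, ibis.literal(2.0)])    # mixed int/float -> promoted to a float array, [1.0, 2.0]
ibis.array([[1], ibis.literal([2])])  # nested arrays; marked notyet/broken for some backends above
```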
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_map.py
@@ -236,7 +236,7 @@ def test_map_construct_dict(con, keys, values):
 @pytest.mark.notimpl(
     ["flink"],
     raises=exc.OperationNotDefinedError,
-    reason="No translation rule for <class 'ibis.expr.operations.arrays.ArrayColumn'>",
+    reason="No translation rule for <class 'ibis.expr.operations.arrays.Array'>",
 )
 def test_map_construct_array_column(con, alltypes, df):
     expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
4 changes: 2 additions & 2 deletions ibis/backends/trino/registry.py
@@ -117,7 +117,7 @@ def _group_concat(t, op):
 def _array_column(t, op):
     args = ", ".join(
         str(t.translate(arg).compile(compile_kwargs={"literal_binds": True}))
-        for arg in op.cols
+        for arg in op.exprs
     )
     return sa.literal_column(f"ARRAY[{args}]", type_=t.get_sqla_type(op.dtype))

@@ -431,7 +431,7 @@ def _range(t, op):
     ops.ArrayIndex: fixed_arity(
         lambda arg, index: sa.func.element_at(arg, index + 1), 2
     ),
-    ops.ArrayColumn: _array_column,
+    ops.Array: _array_column,
     ops.ArrayRepeat: fixed_arity(
         lambda arg, times: sa.func.flatten(sa.func.repeat(arg, times)), 2
     ),
10 changes: 6 additions & 4 deletions ibis/expr/operations/arrays.py
@@ -13,14 +13,16 @@


 @public
-class ArrayColumn(Value):
-    cols: VarTuple[Value]
+class Array(Value):
+    exprs: VarTuple[Value]

-    shape = ds.columnar
+    @attribute
+    def shape(self):
+        return rlz.highest_precedence_shape(self.exprs)

     @attribute
     def dtype(self):
-        return dt.Array(rlz.highest_precedence_dtype(self.cols))
+        return dt.Array(rlz.highest_precedence_dtype(self.exprs))


 @public
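
A short sketch of what the derived `shape` buys (illustrative; the table `t` is a hypothetical example, not part of the diff):

```python
import ibis
import ibis.expr.datashape as ds

t = ibis.table({"x": "int64"}, name="t")  # hypothetical table for illustration

# All scalar operands -> the Array op itself is scalar.
assert ibis.array([1, ibis.literal(2)]).op().shape == ds.scalar

# A columnar operand has higher shape precedence, so the result is columnar,
# matching the old ArrayColumn behavior of building one array per row.
assert ibis.array([t.x, t.x + 1]).op().shape == ds.columnar
```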
19 changes: 2 additions & 17 deletions ibis/expr/types/arrays.py
@@ -1020,28 +1020,13 @@ def __getitem__(self, index: int | ir.IntegerValue | slice) -> ir.Column:
 def array(values: Iterable[V], type: str | dt.DataType | None = None) -> ArrayValue:
     """Create an array expression.

-    If the input expressions are all column expressions, then the output will
-    be an `ArrayColumn`. The input columns will be concatenated row-wise to
-    produce each array in the output array column. Each array will have length
-    _n_, where _n_ is the number of input columns. All input columns should be
-    of the same datatype.
-
-    If the input expressions are Python literals, then the output will be a
-    single `ArrayScalar` of length _n_, where _n_ is the number of input
-    values. This is equivalent to
-
-    ```python
-    values = [1, 2, 3]
-    ibis.literal(values)
-    ```
-
     Parameters
     ----------
     values
         An iterable of Ibis expressions or a list of Python literals
     type
         An instance of `ibis.expr.datatypes.DataType` or a string indicating
-        the ibis type of `value`.
+        the Ibis type of `value`.

     Returns
     -------
@@ -1086,7 +1071,7 @@ def array(values: Iterable[V], type: str | dt.DataType | None = None) -> ArrayValue:
     └──────────────────────┘
     """
     if any(isinstance(value, Value) for value in values):
-        return ops.ArrayColumn(values).to_expr()
+        return ops.Array(values).to_expr()
     else:
         try:
             return literal(list(values), type=type)
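
To summarize the dispatch in `ibis.array()` after this change (a sketch; the table `t` is hypothetical):

```python
import ibis

t = ibis.table({"a": "int64", "b": "int64"}, name="t")  # hypothetical

ibis.array([1, 2, 3])             # no ibis expressions -> falls through to ibis.literal([1, 2, 3])
ibis.array([t.a, t.b])            # column inputs -> ops.Array, one two-element array per row
ibis.array([1, ibis.literal(2)])  # mixed -> ops.Array with scalar shape (the new capability)
```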
