Skip to content

Commit

Permalink
feat: support empty arrays, improve ibis.array() API
Browse files: browse the repository at this point in the history
Picking out the array stuff from ibis-project#8666
  • Loading branch information
NickCrews committed Jun 28, 2024
1 parent 33ec754 commit 61e1a2c
Show file tree
Hide file tree
Showing 17 changed files with 194 additions and 65 deletions.
6 changes: 4 additions & 2 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
plan,
)
from ibis.common.exceptions import UnboundExpressionError, UnsupportedOperationError
from ibis.formats.numpy import NumpyType
from ibis.formats.pandas import PandasData, PandasType
from ibis.util import gen_name

Expand Down Expand Up @@ -155,9 +156,10 @@ def mapper(df, cases):
return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.Array, exprs, dtype):
np_type = NumpyType.from_ibis(dtype)
return cls.rowwise(
lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
lambda row: np.array(row, dtype=np_type), exprs, name=op.name, dtype=object
)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/dask/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def concat(cls, dfs, **kwargs):

@classmethod
def asseries(cls, value, like=None):
"""Ensure that value is a pandas Series object, broadcast if necessary."""
"""Ensure that value is a dask Series object, broadcast if necessary."""

if isinstance(value, dd.Series):
return value
Expand All @@ -50,7 +50,7 @@ def asseries(cls, value, like=None):
elif isinstance(value, pd.Series):
return dd.from_pandas(value, npartitions=1)
elif like is not None:
if isinstance(value, (tuple, list, dict)):
if isinstance(value, (tuple, list, dict, np.ndarray)):
fn = lambda df: pd.Series([value] * len(df), index=df.index)
else:
fn = lambda df: pd.Series(value, index=df.index)
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def visit(cls, op: ops.Node, **kwargs):

@classmethod
def visit(cls, op: ops.Literal, value, dtype):
if dtype.is_interval():
if value is None:
value = None
elif dtype.is_interval():
value = pd.Timedelta(value, dtype.unit.short)
elif dtype.is_array():
value = np.array(value)
Expand Down Expand Up @@ -220,7 +222,7 @@ def visit(cls, op: ops.FindInSet, needle, values):
return pd.Series(result, name=op.name)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.Array, exprs, dtype):
return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)

@classmethod
Expand Down
27 changes: 16 additions & 11 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,25 +87,27 @@ def literal(op, **_):
value = op.value
dtype = op.dtype

if dtype.is_array():
value = pl.Series("", value)
typ = PolarsType.from_ibis(dtype)
val = pl.lit(value, dtype=typ)
return val.implode()
# There are some interval types that _make_duration() can handle,
# but PolarsType.from_ibis can't, so we need to handle them here.
if dtype.is_interval():
return _make_duration(value, dtype)

typ = PolarsType.from_ibis(dtype)
if value is None:
return pl.lit(None, dtype=typ)
elif dtype.is_array():
return pl.lit(pl.Series("", value).implode(), dtype=typ)
elif dtype.is_struct():
values = [
pl.lit(v, dtype=PolarsType.from_ibis(dtype[k])).alias(k)
for k, v in value.items()
]
return pl.struct(values)
elif dtype.is_interval():
return _make_duration(value, dtype)
elif dtype.is_null():
return pl.lit(value)
elif dtype.is_binary():
return pl.lit(value)
else:
typ = PolarsType.from_ibis(dtype)
return pl.lit(op.value, dtype=typ)


Expand Down Expand Up @@ -974,9 +976,12 @@ def array_concat(op, **kw):


@translate.register(ops.Array)
def array_column(op, **kw):
cols = [translate(col, **kw) for col in op.exprs]
return pl.concat_list(cols)
def array_literal(op, **kw):
pdt = PolarsType.from_ibis(op.dtype)
if op.exprs:
return pl.concat_list([translate(col, **kw) for col in op.exprs]).cast(pdt)
else:
return pl.lit([], dtype=pdt)


@translate.register(ops.ArrayCollect)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,8 +1019,8 @@ def visit_InSubquery(self, op, *, rel, needle):
query = sg.select(STAR).from_(query)
return needle.isin(query=query)

def visit_Array(self, op, *, exprs):
return self.f.array(*exprs)
def visit_Array(self, op, *, exprs, dtype):
return self.cast(self.f.array(*exprs), dtype)

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ WITH "t5" AS (
SELECT
"t0"."field_of_study",
arrayJoin(
[
CAST([
CAST(tuple('1970-71', "t0"."1970-71") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64))),
CAST(tuple('1975-76', "t0"."1975-76") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64))),
CAST(tuple('1980-81', "t0"."1980-81") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64))),
Expand All @@ -45,7 +45,7 @@ WITH "t5" AS (
CAST(tuple('2017-18', "t0"."2017-18") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64))),
CAST(tuple('2018-19', "t0"."2018-19") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64))),
CAST(tuple('2019-20', "t0"."2019-20") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64)))
]
] AS Array(Tuple("years" Nullable(String), "degrees" Nullable(Int64))))
) AS "__pivoted__"
FROM "humanities" AS "t0"
) AS "t1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ WITH "t5" AS (
SELECT
"t0"."field_of_study",
UNNEST(
[
CAST([
{'years': '1970-71', 'degrees': "t0"."1970-71"},
{'years': '1975-76', 'degrees': "t0"."1975-76"},
{'years': '1980-81', 'degrees': "t0"."1980-81"},
Expand All @@ -45,7 +45,7 @@ WITH "t5" AS (
{'years': '2017-18', 'degrees': "t0"."2017-18"},
{'years': '2018-19', 'degrees': "t0"."2018-19"},
{'years': '2019-20', 'degrees': "t0"."2019-20"}
]
] AS STRUCT("years" TEXT, "degrees" BIGINT)[])
) AS "__pivoted__"
FROM "humanities" AS "t0"
) AS "t1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ WITH "t5" AS (
SELECT
"t0"."field_of_study",
UNNEST(
ARRAY[ROW(CAST('1970-71' AS VARCHAR), CAST("t0"."1970-71" AS BIGINT)), ROW(CAST('1975-76' AS VARCHAR), CAST("t0"."1975-76" AS BIGINT)), ROW(CAST('1980-81' AS VARCHAR), CAST("t0"."1980-81" AS BIGINT)), ROW(CAST('1985-86' AS VARCHAR), CAST("t0"."1985-86" AS BIGINT)), ROW(CAST('1990-91' AS VARCHAR), CAST("t0"."1990-91" AS BIGINT)), ROW(CAST('1995-96' AS VARCHAR), CAST("t0"."1995-96" AS BIGINT)), ROW(CAST('2000-01' AS VARCHAR), CAST("t0"."2000-01" AS BIGINT)), ROW(CAST('2005-06' AS VARCHAR), CAST("t0"."2005-06" AS BIGINT)), ROW(CAST('2010-11' AS VARCHAR), CAST("t0"."2010-11" AS BIGINT)), ROW(CAST('2011-12' AS VARCHAR), CAST("t0"."2011-12" AS BIGINT)), ROW(CAST('2012-13' AS VARCHAR), CAST("t0"."2012-13" AS BIGINT)), ROW(CAST('2013-14' AS VARCHAR), CAST("t0"."2013-14" AS BIGINT)), ROW(CAST('2014-15' AS VARCHAR), CAST("t0"."2014-15" AS BIGINT)), ROW(CAST('2015-16' AS VARCHAR), CAST("t0"."2015-16" AS BIGINT)), ROW(CAST('2016-17' AS VARCHAR), CAST("t0"."2016-17" AS BIGINT)), ROW(CAST('2017-18' AS VARCHAR), CAST("t0"."2017-18" AS BIGINT)), ROW(CAST('2018-19' AS VARCHAR), CAST("t0"."2018-19" AS BIGINT)), ROW(CAST('2019-20' AS VARCHAR), CAST("t0"."2019-20" AS BIGINT))]
CAST(ARRAY[ROW(CAST('1970-71' AS VARCHAR), CAST("t0"."1970-71" AS BIGINT)), ROW(CAST('1975-76' AS VARCHAR), CAST("t0"."1975-76" AS BIGINT)), ROW(CAST('1980-81' AS VARCHAR), CAST("t0"."1980-81" AS BIGINT)), ROW(CAST('1985-86' AS VARCHAR), CAST("t0"."1985-86" AS BIGINT)), ROW(CAST('1990-91' AS VARCHAR), CAST("t0"."1990-91" AS BIGINT)), ROW(CAST('1995-96' AS VARCHAR), CAST("t0"."1995-96" AS BIGINT)), ROW(CAST('2000-01' AS VARCHAR), CAST("t0"."2000-01" AS BIGINT)), ROW(CAST('2005-06' AS VARCHAR), CAST("t0"."2005-06" AS BIGINT)), ROW(CAST('2010-11' AS VARCHAR), CAST("t0"."2010-11" AS BIGINT)), ROW(CAST('2011-12' AS VARCHAR), CAST("t0"."2011-12" AS BIGINT)), ROW(CAST('2012-13' AS VARCHAR), CAST("t0"."2012-13" AS BIGINT)), ROW(CAST('2013-14' AS VARCHAR), CAST("t0"."2013-14" AS BIGINT)), ROW(CAST('2014-15' AS VARCHAR), CAST("t0"."2014-15" AS BIGINT)), ROW(CAST('2015-16' AS VARCHAR), CAST("t0"."2015-16" AS BIGINT)), ROW(CAST('2016-17' AS VARCHAR), CAST("t0"."2016-17" AS BIGINT)), ROW(CAST('2017-18' AS VARCHAR), CAST("t0"."2017-18" AS BIGINT)), ROW(CAST('2018-19' AS VARCHAR), CAST("t0"."2018-19" AS BIGINT)), ROW(CAST('2019-20' AS VARCHAR), CAST("t0"."2019-20" AS BIGINT))] AS STRUCT<"years" VARCHAR, "degrees" BIGINT>[])
) AS "__pivoted__"
FROM "humanities" AS "t0"
) AS "t1"
Expand Down
Loading

0 comments on commit 61e1a2c

Please sign in to comment.