
Commit

fix(trino,pyspark): improve null handling in array filter (#10448)
Co-authored-by: Phillip Cloud <[email protected]>
stephen-bowser and cpcloud authored Nov 10, 2024
1 parent 85f0693 commit 860b9ca
Showing 3 changed files with 87 additions and 16 deletions.
7 changes: 2 additions & 5 deletions ibis/backends/sql/compilers/pyspark.py
@@ -397,11 +397,8 @@ def visit_ArrayFilter(self, op, *, arg, body, param, index):
         if index is not None:
             expressions.append(index)
 
-        func = sge.Lambda(this=self.if_(body, param, NULL), expressions=expressions)
-        transform = self.f.transform(arg, func)
-
-        func = sge.Lambda(this=param.is_(sg.not_(NULL)), expressions=expressions)
-        return self.f.filter(transform, func)
+        lamduh = sge.Lambda(this=body, expressions=expressions)
+        return self.f.filter(arg, lamduh)
 
     def visit_ArrayIndex(self, op, *, arg, index):
         return self.f.element_at(arg, index + 1)
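
The old PySpark lowering mapped rejected elements to NULL with transform and then dropped NULLs with filter, which also discarded NULL elements the predicate had accepted; the new lowering hands the predicate straight to filter. Below is a minimal sketch of the intended behavior; it is not part of the commit, and DuckDB is used only as an easily runnable stand-in for the affected backends:

import ibis

# hypothetical local connection; the commit itself only touches Trino and PySpark
con = ibis.duckdb.connect()

t = ibis.memtable(
    {"a": [[1, None, None], [4]]}, schema=ibis.schema(dict(a="!array<int8>"))
)
# keep elements at even positions; the element at index 2 is a legitimate NULL
expr = t.select(a=t.a.filter(lambda x, i: i % 2 == 0))
print(con.to_pyarrow(expr.a).to_pylist())
# expected: [[1, None], [4]] -- the old Trino/PySpark lowerings dropped the None
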
51 changes: 40 additions & 11 deletions ibis/backends/sql/compilers/trino.py
@@ -177,21 +177,50 @@ def visit_ArrayFilter(self, op, *, arg, param, body, index):
         else:
             placeholder = sg.to_identifier("__trino_filter__")
             index = sg.to_identifier(index)
-            return self.f.filter(
-                self.f.zip_with(
-                    arg,
-                    # users are limited to 10_000 elements here because it
-                    # seems like trino won't ever actually address the limit
-                    self.f.sequence(0, self.f.cardinality(arg) - 1),
-                    sge.Lambda(
-                        # semantics are: arg if predicate(arg, index) else null
-                        this=self.if_(body, param, NULL),
-                        expressions=[param, index],
-                    ),
-                ),
-                # then, filter out elements that are null
-                sge.Lambda(
-                    this=placeholder.is_(sg.not_(NULL)), expressions=[placeholder]
-                ),
-            )
+            keep, value = map(sg.to_identifier, ("keep", "value"))
+
+            # first, zip the array with the index and call the user's function,
+            # returning a struct of {"keep": value-of-predicate, "value": array-element}
+            zipped = self.f.zip_with(
+                arg,
+                # users are limited to 10_000 elements here because it
+                # seems like trino won't ever actually address the limit
+                self.f.sequence(0, self.f.cardinality(arg) - 1),
+                sge.Lambda(
+                    this=self.cast(
+                        sge.Struct(
+                            expressions=[
+                                sge.PropertyEQ(this=keep, expression=body),
+                                sge.PropertyEQ(this=value, expression=param),
+                            ]
+                        ),
+                        dt.Struct(
+                            {
+                                "keep": dt.boolean,
+                                "value": op.arg.dtype.value_type,
+                            }
+                        ),
+                    ),
+                    expressions=[param, index],
+                ),
+            )
+
+            # second, keep only the elements whose predicate returned true
+            filtered = self.f.filter(
+                zipped,
+                sge.Lambda(
+                    this=sge.Dot(this=placeholder, expression=keep),
+                    expressions=[placeholder],
+                ),
+            )
+
+            # finally, extract the "value" field from the struct
+            return self.f.transform(
+                filtered,
+                sge.Lambda(
+                    this=sge.Dot(this=placeholder, expression=value),
+                    expressions=[placeholder],
+                ),
+            )

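
The Trino rewrite drops the NULL-sentinel trick (predicate failures became NULL, then NULLs were filtered out) in favor of tagging each element with a {keep, value} struct, filtering on keep, and projecting value back out, so genuine NULL elements survive. A rough sketch, not from the commit, of the SQL shape the new lowering produces; ibis.to_sql is used only for illustration and the commented output is approximate (identifier names and casts will differ):

import ibis

t = ibis.table({"a": "!array<int8>"}, name="t")
expr = t.a.filter(lambda x, i: i % 2 == 0)
print(ibis.to_sql(expr, dialect="trino"))
# approximately (placeholder identifiers, not the exact emitted SQL):
#   TRANSFORM(
#     FILTER(
#       ZIP_WITH(
#         a,
#         SEQUENCE(0, CARDINALITY(a) - 1),
#         (x, i) -> CAST(ROW(i % 2 = 0, x) AS ROW(keep BOOLEAN, value TINYINT))),
#       s -> s.keep),
#     s -> s.value)
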
45 changes: 45 additions & 0 deletions ibis/backends/tests/test_array.py
@@ -662,6 +662,51 @@ def test_array_filter_with_index(con, input, output, predicate):
     )
 
 
+@builtin_array
+@pytest.mark.notimpl(
+    ["datafusion", "flink", "polars"], raises=com.OperationNotDefinedError
+)
+@pytest.mark.notimpl(
+    ["sqlite"], raises=com.UnsupportedBackendType, reason="Unsupported type: Array..."
+)
+@pytest.mark.parametrize(
+    ("input", "output"),
+    [
+        param(
+            {"a": [[1, None, None], [4]]},
+            {"a": [[1, None], [4]]},
+            id="nulls",
+            marks=[
+                pytest.mark.notyet(
+                    ["bigquery"],
+                    raises=GoogleBadRequest,
+                    reason="NULLs are not allowed as array elements",
+                )
+            ],
+        ),
+        param({"a": [[1, 2], [1]]}, {"a": [[1], [1]]}, id="no_nulls"),
+    ],
+)
+@pytest.mark.notyet(
+    "risingwave",
+    raises=PsycoPg2InternalError,
+    reason="no support for not null column constraint",
+)
+@pytest.mark.parametrize(
+    "predicate",
+    [lambda x, i: i % 2 == 0, partial(lambda x, y, i: i % 2 == 0, y=1)],
+    ids=["lambda", "partial"],
+)
+def test_array_filter_with_index_lambda(con, input, output, predicate):
+    t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
+
+    expr = t.select(a=t.a.filter(predicate))
+    result = con.to_pyarrow(expr.a)
+    assert frozenset(map(tuple, result.to_pylist())) == frozenset(
+        map(tuple, output["a"])
+    )
+
+
 @builtin_array
 @pytest.mark.parametrize(
     ("col", "value"),
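
As a plain-Python cross-check of the expected values in the new test (a sketch, not part of the commit): keeping the elements at even indices must retain the NULL sitting at index 2.

# reference computation for the "nulls" case above
input_rows = [[1, None, None], [4]]
expected = [[x for i, x in enumerate(row) if i % 2 == 0] for row in input_rows]
print(expected)  # [[1, None], [4]]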
