Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(snowflake): make semantics of array filtering match everything else #10469

Merged
merged 2 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,31 +800,50 @@ def visit_ArrayMap(self, op, *, arg, param, body, index):
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))

def visit_ArrayFilter(self, op, *, arg, param, body, index):
if index is not None:
arg = self.f.arrays_zip(
arg, self.f.array_generate_range(0, self.f.array_size(arg))
)
null_filter_arg = self.f.get(param, "$1")
# extract the field we care about
placeholder = sg.to_identifier("__ibis_snowflake_arg__")
post_process = lambda arg: self.f.transform(
if index is None:
return self.f.filter(
arg,
# nulls are considered false when they are returned from a
# `filter` predicate
#
# we're using is_null_value here because snowflake
# automatically converts embedded SQL NULLs to JSON nulls in
# higher order functions
sge.Lambda(
this=self.f.get(placeholder, "$1"), expressions=[placeholder]
this=sg.and_(sg.not_(self.f.is_null_value(param)), body),
expressions=[param],
),
)
else:
null_filter_arg = param
post_process = lambda arg: arg

# null_filter is necessary otherwise null values are treated as JSON
# nulls instead of SQL NULLs
null_filter = self.cast(null_filter_arg, op.dtype.value_type).is_(sg.not_(NULL))
zipped = self.f.arrays_zip(
arg, self.f.array_generate_range(0, self.f.array_size(arg))
)
# extract the field we care about
keeps = self.f.transform(
zipped,
sge.Lambda(
this=self.f.object_construct_keep_null(
"keep", body, "value", self.f.get(param, "$1")
),
expressions=[param],
),
)

return post_process(
self.f.filter(
arg, sge.Lambda(this=sg.and_(body, null_filter), expressions=[param])
)
# then, filter out elements that are null
placeholder1 = sg.to_identifier("__f1__")
placeholder2 = sg.to_identifier("__f2__")
filtered = self.f.filter(
keeps,
sge.Lambda(
this=self.cast(self.f.get(placeholder1, "keep"), dt.boolean),
expressions=[placeholder1],
),
)
return self.f.transform(
filtered,
sge.Lambda(
this=self.f.get(placeholder2, "value"), expressions=[placeholder2]
),
)

def visit_JoinLink(self, op, *, how, table, predicates):
Expand Down
89 changes: 45 additions & 44 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,57 +172,58 @@ def visit_ArrayMap(self, op, *, arg, param, body, index):
)

def visit_ArrayFilter(self, op, *, arg, param, body, index):
# no index, life is simpler
if index is None:
return self.f.filter(arg, sge.Lambda(this=body, expressions=[param]))
else:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no changes here except dedenting.

placeholder = sg.to_identifier("__trino_filter__")
index = sg.to_identifier(index)
keep, value = map(sg.to_identifier, ("keep", "value"))

# first, zip the array with the index and call the user's function,
# returning a struct of {"keep": value-of-predicate, "value": array-element}
zipped = self.f.zip_with(
arg,
# users are limited to 10_000 elements here because it
# seems like trino won't ever actually address the limit
self.f.sequence(0, self.f.cardinality(arg) - 1),
sge.Lambda(
this=self.cast(
sge.Struct(
expressions=[
sge.PropertyEQ(this=keep, expression=body),
sge.PropertyEQ(this=value, expression=param),
]
),
dt.Struct(
{
"keep": dt.boolean,
"value": op.arg.dtype.value_type,
}
),
placeholder = sg.to_identifier("__trino_filter__")
index = sg.to_identifier(index)
keep, value = map(sg.to_identifier, ("keep", "value"))

# first, zip the array with the index and call the user's function,
# returning a struct of {"keep": value-of-predicate, "value": array-element}
zipped = self.f.zip_with(
arg,
# users are limited to 10_000 elements here because it
# seems like trino won't ever actually address the limit
self.f.sequence(0, self.f.cardinality(arg) - 1),
sge.Lambda(
this=self.cast(
sge.Struct(
expressions=[
sge.PropertyEQ(this=keep, expression=body),
sge.PropertyEQ(this=value, expression=param),
]
),
dt.Struct(
{
"keep": dt.boolean,
"value": op.arg.dtype.value_type,
}
),
expressions=[param, index],
),
)
expressions=[param, index],
),
)

# second, keep only the elements whose predicate returned true
filtered = self.f.filter(
# then, filter out elements that are null
zipped,
sge.Lambda(
this=sge.Dot(this=placeholder, expression=keep),
expressions=[placeholder],
),
)
# second, keep only the elements whose predicate returned true
filtered = self.f.filter(
# then, filter out elements that are null
zipped,
sge.Lambda(
this=sge.Dot(this=placeholder, expression=keep),
expressions=[placeholder],
),
)

# finally, extract the "value" field from the struct
return self.f.transform(
filtered,
sge.Lambda(
this=sge.Dot(this=placeholder, expression=value),
expressions=[placeholder],
),
)
# finally, extract the "value" field from the struct
return self.f.transform(
filtered,
sge.Lambda(
this=sge.Dot(this=placeholder, expression=value),
expressions=[placeholder],
),
)

def visit_ArrayContains(self, op, *, arg, other):
return self.if_(
Expand Down