Skip to content

Commit

Permalink
fix(snowflake): make semantics of array filtering match everything else
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Nov 11, 2024
1 parent 983cd5d commit 6d90f18
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,31 +800,50 @@ def visit_ArrayMap(self, op, *, arg, param, body, index):
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))

def visit_ArrayFilter(self, op, *, arg, param, body, index):
if index is not None:
arg = self.f.arrays_zip(
arg, self.f.array_generate_range(0, self.f.array_size(arg))
)
null_filter_arg = self.f.get(param, "$1")
# extract the field we care about
placeholder = sg.to_identifier("__ibis_snowflake_arg__")
post_process = lambda arg: self.f.transform(
if index is None:
return self.f.filter(
arg,
# nulls are considered false when they are returned from a
# `filter` predicate
#
# we're using is_null_value here because snowflake
# automatically converts embedded SQL NULLs to JSON nulls in
# higher order functions
sge.Lambda(
this=self.f.get(placeholder, "$1"), expressions=[placeholder]
this=sg.and_(sg.not_(self.f.is_null_value(param)), body),
expressions=[param],
),
)
else:
null_filter_arg = param
post_process = lambda arg: arg

# null_filter is necessary otherwise null values are treated as JSON
# nulls instead of SQL NULLs
null_filter = self.cast(null_filter_arg, op.dtype.value_type).is_(sg.not_(NULL))
zipped = self.f.arrays_zip(
arg, self.f.array_generate_range(0, self.f.array_size(arg))
)
# extract the field we care about
keeps = self.f.transform(
zipped,
sge.Lambda(
this=self.f.object_construct_keep_null(
"keep", body, "value", self.f.get(param, "$1")
),
expressions=[param],
),
)

return post_process(
self.f.filter(
arg, sge.Lambda(this=sg.and_(body, null_filter), expressions=[param])
)
# then, filter out elements that are null
placeholder1 = sg.to_identifier("__f1__")
placeholder2 = sg.to_identifier("__f2__")
filtered = self.f.filter(
keeps,
sge.Lambda(
this=self.cast(self.f.get(placeholder1, "keep"), dt.boolean),
expressions=[placeholder1],
),
)
return self.f.transform(
filtered,
sge.Lambda(
this=self.f.get(placeholder2, "value"), expressions=[placeholder2]
),
)

def visit_JoinLink(self, op, *, how, table, predicates):
Expand Down

0 comments on commit 6d90f18

Please sign in to comment.