Skip to content

Commit

Permalink
style(trino): dedent
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Nov 11, 2024
1 parent 23c0e81 commit 983cd5d
Showing 1 changed file with 45 additions and 44 deletions.
89 changes: 45 additions & 44 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,57 +172,58 @@ def visit_ArrayMap(self, op, *, arg, param, body, index):
)

def visit_ArrayFilter(self, op, *, arg, param, body, index):
# no index, life is simpler
if index is None:
return self.f.filter(arg, sge.Lambda(this=body, expressions=[param]))
else:
placeholder = sg.to_identifier("__trino_filter__")
index = sg.to_identifier(index)
keep, value = map(sg.to_identifier, ("keep", "value"))

# first, zip the array with the index and call the user's function,
# returning a struct of {"keep": value-of-predicate, "value": array-element}
zipped = self.f.zip_with(
arg,
# users are limited to 10_000 elements here because it
# seems like trino won't ever actually address the limit
self.f.sequence(0, self.f.cardinality(arg) - 1),
sge.Lambda(
this=self.cast(
sge.Struct(
expressions=[
sge.PropertyEQ(this=keep, expression=body),
sge.PropertyEQ(this=value, expression=param),
]
),
dt.Struct(
{
"keep": dt.boolean,
"value": op.arg.dtype.value_type,
}
),
placeholder = sg.to_identifier("__trino_filter__")
index = sg.to_identifier(index)
keep, value = map(sg.to_identifier, ("keep", "value"))

# first, zip the array with the index and call the user's function,
# returning a struct of {"keep": value-of-predicate, "value": array-element}
zipped = self.f.zip_with(
arg,
# users are limited to 10_000 elements here because it
# seems like trino won't ever actually address the limit
self.f.sequence(0, self.f.cardinality(arg) - 1),
sge.Lambda(
this=self.cast(
sge.Struct(
expressions=[
sge.PropertyEQ(this=keep, expression=body),
sge.PropertyEQ(this=value, expression=param),
]
),
dt.Struct(
{
"keep": dt.boolean,
"value": op.arg.dtype.value_type,
}
),
expressions=[param, index],
),
)
expressions=[param, index],
),
)

# second, keep only the elements whose predicate returned true
filtered = self.f.filter(
# then, filter out elements that are null
zipped,
sge.Lambda(
this=sge.Dot(this=placeholder, expression=keep),
expressions=[placeholder],
),
)
# second, keep only the elements whose predicate returned true
filtered = self.f.filter(
# then, filter out elements that are null
zipped,
sge.Lambda(
this=sge.Dot(this=placeholder, expression=keep),
expressions=[placeholder],
),
)

# finally, extract the "value" field from the struct
return self.f.transform(
filtered,
sge.Lambda(
this=sge.Dot(this=placeholder, expression=value),
expressions=[placeholder],
),
)
# finally, extract the "value" field from the struct
return self.f.transform(
filtered,
sge.Lambda(
this=sge.Dot(this=placeholder, expression=value),
expressions=[placeholder],
),
)

def visit_ArrayContains(self, op, *, arg, other):
return self.if_(
Expand Down

0 comments on commit 983cd5d

Please sign in to comment.