Skip to content

Commit

Permalink
feat(api): add support for passing an optional index parameter to arr…
Browse files Browse the repository at this point in the history
…ay map and filter (#10205)

Co-authored-by: Jim Crist-Harif <[email protected]>
  • Loading branch information
cpcloud and jcrist authored Sep 25, 2024
1 parent 1a131a5 commit dfe7c34
Show file tree
Hide file tree
Showing 11 changed files with 462 additions and 160 deletions.
12 changes: 8 additions & 4 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,13 +774,17 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
array = self.f.array_reverse(self.f.array_agg(arg))
return array[self.f.safe_offset(0)]

def visit_ArrayFilter(self, op, *, arg, body, param):
def visit_ArrayFilter(self, op, *, arg, body, param, index):
return self.f.array(
sg.select(param).from_(self._unnest(arg, as_=param)).where(body)
sg.select(param)
.from_(self._unnest(arg, as_=param, offset=index))
.where(body)
)

def visit_ArrayMap(self, op, *, arg, body, param):
return self.f.array(sg.select(body).from_(self._unnest(arg, as_=param)))
def visit_ArrayMap(self, op, *, arg, body, param, index):
return self.f.array(
sg.select(body).from_(self._unnest(arg, as_=param, offset=index))
)

def visit_ArrayZip(self, op, *, arg):
lengths = [self.f.array_length(arr) - 1 for arr in arg]
Expand Down
28 changes: 22 additions & 6 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,13 +597,29 @@ def visit_ExtractQuery(self, op, *, arg, key):
def visit_ArrayStringJoin(self, op, *, arg, sep):
return self.f.arrayStringConcat(arg, sep)

def visit_ArrayMap(self, op, *, arg, param, body):
func = sge.Lambda(this=body, expressions=[param])
return self.f.arrayMap(func, arg)
def visit_ArrayMap(self, op, *, arg, param, body, index):
expressions = [param]
args = [arg]

if index is not None:
expressions.append(index)
args.append(self.f.range(0, self.f.length(arg)))

func = sge.Lambda(this=body, expressions=expressions)

return self.f.arrayMap(func, *args)

def visit_ArrayFilter(self, op, *, arg, param, body, index):
expressions = [param]
args = [arg]

if index is not None:
expressions.append(index)
args.append(self.f.range(0, self.f.length(arg)))

func = sge.Lambda(this=body, expressions=expressions)

def visit_ArrayFilter(self, op, *, arg, param, body):
func = sge.Lambda(this=body, expressions=[param])
return self.f.arrayFilter(func, arg)
return self.f.arrayFilter(func, *args)

def visit_ArrayRemove(self, op, *, arg, other):
x = sg.to_identifier(util.gen_name("x"))
Expand Down
24 changes: 19 additions & 5 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from ibis import util
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType
from ibis.backends.sql.rewrites import lower_sample
from ibis.backends.sql.rewrites import (
lower_sample,
subtract_one_from_array_map_filter_index,
)
from ibis.util import gen_name

if TYPE_CHECKING:
Expand Down Expand Up @@ -42,6 +45,7 @@ class DuckDBCompiler(SQLGlotCompiler):
type_mapper = DuckDBType

agg = AggGen(supports_filter=True, supports_order_by=True)
rewrites = (subtract_one_from_array_map_filter_index, *SQLGlotCompiler.rewrites)

supports_qualify = True

Expand Down Expand Up @@ -187,12 +191,22 @@ def visit_ArraySlice(self, op, *, arg, start, stop):

return self.f.list_slice(arg, start + 1, stop)

def visit_ArrayMap(self, op, *, arg, body, param):
lamduh = sge.Lambda(this=body, expressions=[sg.to_identifier(param)])
def visit_ArrayMap(self, op, *, arg, body, param, index):
expressions = [param]

if index is not None:
expressions.append(index)

lamduh = sge.Lambda(this=body, expressions=expressions)
return self.f.list_apply(arg, lamduh)

def visit_ArrayFilter(self, op, *, arg, body, param):
lamduh = sge.Lambda(this=body, expressions=[sg.to_identifier(param)])
def visit_ArrayFilter(self, op, *, arg, body, param, index):
expressions = [sg.to_identifier(param)]

if index is not None:
expressions.append(sg.to_identifier(index))

lamduh = sge.Lambda(this=body, expressions=expressions)
return self.f.list_filter(arg, lamduh)

def visit_ArrayIntersect(self, op, *, left, right):
Expand Down
26 changes: 21 additions & 5 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import PostgresType
from ibis.backends.sql.dialects import Postgres
from ibis.backends.sql.rewrites import lower_sample, split_select_distinct_with_order_by
from ibis.backends.sql.rewrites import (
lower_sample,
split_select_distinct_with_order_by,
subtract_one_from_array_map_filter_index,
)
from ibis.common.exceptions import InvalidDecoratorError
from ibis.util import gen_name

Expand All @@ -42,6 +46,7 @@ class PostgresCompiler(SQLGlotCompiler):

dialect = Postgres
type_mapper = PostgresType
rewrites = (subtract_one_from_array_map_filter_index, *SQLGlotCompiler.rewrites)
post_rewrites = (split_select_distinct_with_order_by,)

agg = AggGen(supports_filter=True, supports_order_by=True)
Expand Down Expand Up @@ -323,16 +328,27 @@ def visit_ArrayContains(self, op, *, arg, other):
expression=self.f.array(self.cast(other, arg_dtype.value_type)),
)

def visit_ArrayFilter(self, op, *, arg, body, param):
def visit_ArrayFilter(self, op, *, arg, body, param, index):
if index is None:
alias = param
else:
alias = sge.TableAlias(this=sg.to_identifier("_"), columns=[param])

return self.f.array(
sg.select(sg.column(param, quoted=self.quoted))
.from_(sge.Unnest(expressions=[arg], alias=param))
.from_(sge.Unnest(expressions=[arg], alias=alias, offset=index))
.where(body)
)

def visit_ArrayMap(self, op, *, arg, body, param):
def visit_ArrayMap(self, op, *, arg, body, param, index):
if index is None:
alias = param
else:
alias = sge.TableAlias(this=sg.to_identifier("_"), columns=[param])
return self.f.array(
sg.select(body).from_(sge.Unnest(expressions=[arg], alias=param))
sg.select(body).from_(
sge.Unnest(expressions=[arg], alias=alias, offset=index)
)
)

def visit_ArrayPosition(self, op, *, arg, other):
Expand Down
23 changes: 16 additions & 7 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,16 +381,25 @@ def visit_MapGet(self, op, *, arg, key, default):
def visit_ArrayZip(self, op, *, arg):
return self.cast(self.f.arrays_zip(*arg), op.dtype)

def visit_ArrayMap(self, op, *, arg, body, param):
param = sge.Identifier(this=param)
func = sge.Lambda(this=body, expressions=[param])
def visit_ArrayMap(self, op, *, arg, body, param, index):
expressions = [param]

if index is not None:
expressions.append(index)

func = sge.Lambda(this=body, expressions=expressions)
return self.f.transform(arg, func)

def visit_ArrayFilter(self, op, *, arg, body, param):
param = sge.Identifier(this=param)
func = sge.Lambda(this=self.if_(body, param, NULL), expressions=[param])
def visit_ArrayFilter(self, op, *, arg, body, param, index):
expressions = [param]

if index is not None:
expressions.append(index)

func = sge.Lambda(this=self.if_(body, param, NULL), expressions=expressions)
transform = self.f.transform(arg, func)
func = sge.Lambda(this=param.is_(sg.not_(NULL)), expressions=[param])

func = sge.Lambda(this=param.is_(sg.not_(NULL)), expressions=expressions)
return self.f.filter(transform, func)

def visit_ArrayIndex(self, op, *, arg, index):
Expand Down
79 changes: 65 additions & 14 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,31 @@
lower_log10,
lower_sample,
rewrite_empty_order_by_window,
x,
)
from ibis.common.patterns import replace
from ibis.expr.rewrites import p


@replace(p.ArrayMap | p.ArrayFilter)
def multiple_args_to_zipped_struct_field_access(_, **kwargs):
# no index argument, so do nothing
if _.index is None:
return _

param = _.param.name

@replace(x @ p.Argument(name=param))
def argument_replacer(_, x, **kwargs):
return ops.StructField(x.copy(dtype=dt.Struct({"$1": _.dtype})), "$1")

@replace(x @ p.Argument(name=_.index.name))
def index_replacer(_, x, **kwargs):
return ops.StructField(
x.copy(name=param, dtype=dt.Struct({"$2": _.dtype})), "$2"
)

return _.copy(body=_.body.replace(argument_replacer | index_replacer))


class SnowflakeFuncGen(FuncGen):
Expand All @@ -54,6 +78,7 @@ class SnowflakeCompiler(SQLGlotCompiler):
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
rewrite_empty_order_by_window,
multiple_args_to_zipped_struct_field_access,
*SQLGlotCompiler.rewrites,
)

Expand Down Expand Up @@ -768,23 +793,49 @@ def visit_TimestampRange(self, op, *, start, stop, step):
.subquery()
)

def visit_ArrayMap(self, op, *, arg, param, body):
def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
sample = sge.TableSample(
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)
return self._make_sample_backwards_compatible(sample=sample, parent=parent)

def visit_ArrayMap(self, op, *, arg, param, body, index):
if index is not None:
arg = self.f.arrays_zip(
arg, self.f.array_generate_range(0, self.f.array_size(arg))
)
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))

def visit_ArrayFilter(self, op, *, arg, param, body):
return self.f.filter(
arg,
sge.Lambda(
this=sg.and_(
body,
# necessary otherwise null values are treated as JSON nulls
# instead of SQL NULLs
self.cast(sg.to_identifier(param), op.dtype.value_type).is_(
sg.not_(NULL)
),
def visit_ArrayFilter(self, op, *, arg, param, body, index):
if index is not None:
arg = self.f.arrays_zip(
arg, self.f.array_generate_range(0, self.f.array_size(arg))
)
null_filter_arg = self.f.get(param, "$1")
# extract the field we care about
placeholder = sg.to_identifier("__ibis_snowflake_arg__")
post_process = lambda arg: self.f.transform(
arg,
sge.Lambda(
this=self.f.get(placeholder, "$1"), expressions=[placeholder]
),
expressions=[param],
),
)
else:
null_filter_arg = param
post_process = lambda arg: arg

# null_filter is necessary otherwise null values are treated as JSON
# nulls instead of SQL NULLs
null_filter = self.cast(null_filter_arg, op.dtype.value_type).is_(sg.not_(NULL))

return post_process(
self.f.filter(
arg, sge.Lambda(this=sg.and_(body, null_filter), expressions=[param])
)
)

def visit_JoinLink(self, op, *, how, table, predicates):
Expand Down
36 changes: 32 additions & 4 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,39 @@ def _neg_idx_to_pos(n, idx):

return self.f.slice(arg, start + 1, stop - start)

def visit_ArrayMap(self, op, *, arg, param, body):
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))
def visit_ArrayMap(self, op, *, arg, param, body, index):
if index is None:
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))
else:
return self.f.zip_with(
arg,
self.f.sequence(0, self.f.cardinality(arg) - 1),
sge.Lambda(this=body, expressions=[param, index]),
)

def visit_ArrayFilter(self, op, *, arg, param, body):
return self.f.filter(arg, sge.Lambda(this=body, expressions=[param]))
def visit_ArrayFilter(self, op, *, arg, param, body, index):
if index is None:
return self.f.filter(arg, sge.Lambda(this=body, expressions=[param]))
else:
placeholder = sg.to_identifier("__trino_filter__")
index = sg.to_identifier(index)
return self.f.filter(
self.f.zip_with(
arg,
# users are limited to 10_000 elements here because it
# seems like trino won't ever actually address the limit
self.f.sequence(0, self.f.cardinality(arg) - 1),
sge.Lambda(
# semantics are: arg if predicate(arg, index) else null
this=self.if_(body, param, NULL),
expressions=[param, index],
),
),
# then, filter out elements that are null
sge.Lambda(
this=placeholder.is_(sg.not_(NULL)), expressions=[placeholder]
),
)

def visit_ArrayContains(self, op, *, arg, other):
return self.if_(
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,3 +634,16 @@ def lower(_, **kwargs):
return _

return lower


@replace(p.ArrayMap | p.ArrayFilter)
def subtract_one_from_array_map_filter_index(_, **kwargs):
# no index argument, so do nothing
if _.index is None:
return _

@replace(y @ p.Argument(name=_.index.name))
def argument_replacer(_, y, **kwargs):
return ops.Subtract(y, 1)

return _.copy(body=_.body.replace(argument_replacer))
Loading

0 comments on commit dfe7c34

Please sign in to comment.