Skip to content

Commit

Permalink
fix(dask): don't call compute when executing argmin/argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Apr 18, 2024
1 parent a876c47 commit 1204c56
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ def limit_df(
return df[df[col].between(offset, offset + n - 1)]


def argminmax_chunk(df, keycol, valcol, method):
return df.iloc[getattr(df[keycol], method)()]


def argminmax_aggregate(df, keycol, valcol, method):
return df[valcol].iloc[getattr(df[keycol], method)()]


class DaskExecutor(PandasExecutor, DaskUtils):
name = "dask"
kernels = dask_kernels
Expand Down Expand Up @@ -184,30 +192,25 @@ def mapper(df):

@classmethod
def visit(cls, op: ops.ArgMin | ops.ArgMax, arg, key, where):
# TODO(kszucs): raise a warning about triggering compute()?
if isinstance(op, ops.ArgMin):
func = lambda x: x.idxmin()
else:
func = lambda x: x.idxmax()

if where is None:
method = "argmin" if isinstance(op, ops.ArgMin) else "argmax"

def agg(df):
indices = func(df[key.name])
if isinstance(indices, (dd.Series, dd.core.Scalar)):
# to support both aggregating within a group and globally
indices = indices.compute()
return df[arg.name].loc[indices]
else:
def agg(df):
if where is not None:
df = df.where(df[where.name])

def agg(df):
mask = df[where.name]
filtered = df[mask]
indices = func(filtered[key.name])
if isinstance(indices, (dd.Series, dd.core.Scalar)):
# to support both aggregating within a group and globally
indices = indices.compute()
return filtered[arg.name].loc[indices]
if isinstance(df, dd.DataFrame):
return df.reduction(
chunk=argminmax_chunk,
combine=argminmax_chunk,
aggregate=argminmax_aggregate,
meta=op.dtype.to_pandas(),
token=method,
keycol=key.name,
valcol=arg.name,
method=method,
)
else:
return argminmax_aggregate(df, key.name, arg.name, method)

return agg

Expand Down

0 comments on commit 1204c56

Please sign in to comment.