Skip to content

Commit

Permalink
refactor: generalize aggregation to handle 0, 1, or 2 inputs (#360)
Browse files Browse the repository at this point in the history
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea.
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
  • Loading branch information
TrevorBergeron authored Feb 5, 2024
1 parent 1866a26 commit 93835ef
Show file tree
Hide file tree
Showing 15 changed files with 338 additions and 225 deletions.
17 changes: 13 additions & 4 deletions bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def row_count(self) -> ArrayValue:
# Operations
def filter_by_id(self, predicate_id: str, keep_null: bool = False) -> ArrayValue:
"""Filter the table on a given expression, the predicate must be a boolean series aligned with the table expression."""
predicate = ex.free_var(predicate_id)
predicate: ex.Expression = ex.free_var(predicate_id)
if keep_null:
predicate = ops.fillna_op.as_expr(predicate, ex.const(True))
return self.filter(predicate)
Expand Down Expand Up @@ -241,7 +241,7 @@ def drop_columns(self, columns: Iterable[str]) -> ArrayValue:

def aggregate(
self,
aggregations: typing.Sequence[typing.Tuple[str, agg_ops.AggregateOp, str]],
aggregations: typing.Sequence[typing.Tuple[ex.Aggregation, str]],
by_column_ids: typing.Sequence[str] = (),
dropna: bool = True,
) -> ArrayValue:
Expand Down Expand Up @@ -270,14 +270,23 @@ def corr_aggregate(
Arguments:
corr_aggregations: left_column_id, right_column_id, output_column_id tuples
"""
aggregations = tuple(
(
ex.BinaryAggregation(
agg_ops.CorrOp(), ex.free_var(agg[0]), ex.free_var(agg[1])
),
agg[2],
)
for agg in corr_aggregations
)
return ArrayValue(
nodes.CorrNode(child=self.node, corr_aggregations=tuple(corr_aggregations))
nodes.AggregateNode(child=self.node, aggregations=aggregations)
)

def project_window_op(
self,
column_name: str,
op: agg_ops.WindowOp,
op: agg_ops.UnaryWindowOp,
window_spec: WindowSpec,
output_name=None,
*,
Expand Down
36 changes: 23 additions & 13 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ def filter(self, column_id: str, keep_null: bool = False):

def aggregate_all_and_stack(
self,
operation: agg_ops.AggregateOp,
operation: agg_ops.UnaryAggregateOp,
*,
axis: int | str = 0,
value_col_id: str = "values",
Expand All @@ -872,7 +872,8 @@ def aggregate_all_and_stack(
axis_n = utils.get_axis_number(axis)
if axis_n == 0:
aggregations = [
(col_id, operation, col_id) for col_id in self.value_columns
(ex.UnaryAggregation(operation, ex.free_var(col_id)), col_id)
for col_id in self.value_columns
]
index_col_ids = [
guid.generate_guid() for i in range(self.column_labels.nlevels)
Expand Down Expand Up @@ -902,10 +903,13 @@ def aggregate_all_and_stack(
dtype=dtype,
)
index_aggregations = [
(col_id, agg_ops.AnyValueOp(), col_id)
(ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.free_var(col_id)), col_id)
for col_id in [*self.index_columns]
]
main_aggregation = (value_col_id, operation, value_col_id)
main_aggregation = (
ex.UnaryAggregation(operation, ex.free_var(value_col_id)),
value_col_id,
)
result_expr = stacked_expr.aggregate(
[*index_aggregations, main_aggregation],
by_column_ids=[offset_col],
Expand Down Expand Up @@ -966,7 +970,7 @@ def remap_f(x):
def aggregate(
self,
by_column_ids: typing.Sequence[str] = (),
aggregations: typing.Sequence[typing.Tuple[str, agg_ops.AggregateOp]] = (),
aggregations: typing.Sequence[typing.Tuple[str, agg_ops.UnaryAggregateOp]] = (),
*,
dropna: bool = True,
) -> typing.Tuple[Block, typing.Sequence[str]]:
Expand All @@ -979,10 +983,13 @@ def aggregate(
dropna: whether null keys should be dropped
"""
agg_specs = [
(input_id, operation, guid.generate_guid())
(
ex.UnaryAggregation(operation, ex.free_var(input_id)),
guid.generate_guid(),
)
for input_id, operation in aggregations
]
output_col_ids = [agg_spec[2] for agg_spec in agg_specs]
output_col_ids = [agg_spec[1] for agg_spec in agg_specs]
result_expr = self.expr.aggregate(agg_specs, by_column_ids, dropna=dropna)

aggregate_labels = self._get_labels_for_columns(
Expand All @@ -1004,7 +1011,7 @@ def aggregate(
output_col_ids,
)

def get_stat(self, column_id: str, stat: agg_ops.AggregateOp):
def get_stat(self, column_id: str, stat: agg_ops.UnaryAggregateOp):
"""Gets aggregates immediately, and caches it"""
if stat.name in self._stats_cache[column_id]:
return self._stats_cache[column_id][stat.name]
Expand All @@ -1014,7 +1021,10 @@ def get_stat(self, column_id: str, stat: agg_ops.AggregateOp):
standard_stats = self._standard_stats(column_id)
stats_to_fetch = standard_stats if stat in standard_stats else [stat]

aggregations = [(column_id, stat, stat.name) for stat in stats_to_fetch]
aggregations = [
(ex.UnaryAggregation(stat, ex.free_var(column_id)), stat.name)
for stat in stats_to_fetch
]
expr = self.expr.aggregate(aggregations)
offset_index_id = guid.generate_guid()
expr = expr.promote_offsets(offset_index_id)
Expand Down Expand Up @@ -1054,13 +1064,13 @@ def get_corr_stat(self, column_id_left: str, column_id_right: str):
def summarize(
self,
column_ids: typing.Sequence[str],
stats: typing.Sequence[agg_ops.AggregateOp],
stats: typing.Sequence[agg_ops.UnaryAggregateOp],
):
"""Get a list of stats as a deferred block object."""
label_col_id = guid.generate_guid()
labels = [stat.name for stat in stats]
aggregations = [
(col_id, stat, f"{col_id}-{stat.name}")
(ex.UnaryAggregation(stat, ex.free_var(col_id)), f"{col_id}-{stat.name}")
for stat in stats
for col_id in column_ids
]
Expand All @@ -1076,7 +1086,7 @@ def summarize(
labels = self._get_labels_for_columns(column_ids)
return Block(expr, column_labels=labels, index_columns=[label_col_id])

def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.AggregateOp]:
def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.UnaryAggregateOp]:
"""
Gets a standard set of stats to preemptively fetch for a column if
any other stat is fetched.
Expand All @@ -1087,7 +1097,7 @@ def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.AggregateOp]:
"""
# TODO: annotate aggregations themself with this information
dtype = self.expr.get_column_type(column_id)
stats: list[agg_ops.AggregateOp] = [agg_ops.count_op]
stats: list[agg_ops.UnaryAggregateOp] = [agg_ops.count_op]
if dtype not in bigframes.dtypes.UNORDERED_DTYPES:
stats += [agg_ops.min_op, agg_ops.max_op]
if dtype in bigframes.dtypes.NUMERIC_BIGFRAMES_TYPES_PERMISSIVE:
Expand Down
Loading

0 comments on commit 93835ef

Please sign in to comment.