From 46d80125e449ac53f879ca317da0de28725fa47b Mon Sep 17 00:00:00 2001 From: TrevorBergeron Date: Thu, 5 Sep 2024 15:24:34 -0700 Subject: [PATCH] refactor: Simplify projection nodes (#961) --- bigframes/core/__init__.py | 77 +++++++++--------------------- bigframes/core/blocks.py | 25 ++++++---- bigframes/core/compile/compiled.py | 15 +++++- bigframes/core/compile/compiler.py | 5 ++ bigframes/core/expression.py | 29 ++++++++--- bigframes/core/nodes.py | 31 +++++++++++- bigframes/core/ordering.py | 2 +- bigframes/core/rewrite.py | 39 ++++++++++++--- bigframes/session/executor.py | 4 +- bigframes/session/planner.py | 12 ++++- tests/unit/test_planner.py | 16 +++---- 11 files changed, 164 insertions(+), 91 deletions(-) diff --git a/bigframes/core/__init__.py b/bigframes/core/__init__.py index f3c75f7143..f65509e5b7 100644 --- a/bigframes/core/__init__.py +++ b/bigframes/core/__init__.py @@ -192,20 +192,15 @@ def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue: ) def project_to_id(self, expression: ex.Expression, output_id: str): - if output_id in self.column_ids: # Mutate case - exprs = [ - ((expression if (col_id == output_id) else ex.free_var(col_id)), col_id) - for col_id in self.column_ids - ] - else: # append case - self_projection = ( - (ex.free_var(col_id), col_id) for col_id in self.column_ids - ) - exprs = [*self_projection, (expression, output_id)] return ArrayValue( nodes.ProjectionNode( child=self.node, - assignments=tuple(exprs), + assignments=( + ( + expression, + output_id, + ), + ), ) ) @@ -213,28 +208,22 @@ def assign(self, source_id: str, destination_id: str) -> ArrayValue: if destination_id in self.column_ids: # Mutate case exprs = [ ( - ( - ex.free_var(source_id) - if (col_id == destination_id) - else ex.free_var(col_id) - ), + (source_id if (col_id == destination_id) else col_id), col_id, ) for col_id in self.column_ids ] else: # append case - self_projection = ( - (ex.free_var(col_id), col_id) for col_id in self.column_ids - ) - exprs = [*self_projection, (ex.free_var(source_id), destination_id)] + self_projection = ((col_id, col_id) for col_id in self.column_ids) + exprs = [*self_projection, (source_id, destination_id)] return ArrayValue( - nodes.ProjectionNode( + nodes.SelectionNode( child=self.node, - assignments=tuple(exprs), + input_output_pairs=tuple(exprs), ) ) - def assign_constant( + def create_constant( self, destination_id: str, value: typing.Any, @@ -244,49 +233,31 @@ def assign_constant( # Need to assign a data type when value is NaN. 
dtype = dtype or bigframes.dtypes.DEFAULT_DTYPE - if destination_id in self.column_ids: # Mutate case - exprs = [ - ( - ( - ex.const(value, dtype) - if (col_id == destination_id) - else ex.free_var(col_id) - ), - col_id, - ) - for col_id in self.column_ids - ] - else: # append case - self_projection = ( - (ex.free_var(col_id), col_id) for col_id in self.column_ids - ) - exprs = [*self_projection, (ex.const(value, dtype), destination_id)] return ArrayValue( nodes.ProjectionNode( child=self.node, - assignments=tuple(exprs), + assignments=((ex.const(value, dtype), destination_id),), ) ) def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue: - selections = ((ex.free_var(col_id), col_id) for col_id in column_ids) + # This basically just drops and reorders columns - logically a no-op except as a final step + selections = ((col_id, col_id) for col_id in column_ids) return ArrayValue( - nodes.ProjectionNode( + nodes.SelectionNode( child=self.node, - assignments=tuple(selections), + input_output_pairs=tuple(selections), ) ) def drop_columns(self, columns: Iterable[str]) -> ArrayValue: new_projection = ( - (ex.free_var(col_id), col_id) - for col_id in self.column_ids - if col_id not in columns + (col_id, col_id) for col_id in self.column_ids if col_id not in columns ) return ArrayValue( - nodes.ProjectionNode( + nodes.SelectionNode( child=self.node, - assignments=tuple(new_projection), + input_output_pairs=tuple(new_projection), ) ) @@ -422,15 +393,13 @@ def unpivot( col_expr = ops.case_when_op.as_expr(*cases) unpivot_exprs.append((col_expr, col_id)) - label_exprs = ((ex.free_var(id), id) for id in index_col_ids) - # passthrough columns are unchanged, just repeated N times each - passthrough_exprs = ((ex.free_var(id), id) for id in passthrough_columns) + unpivot_col_ids = [id for id, _ in unpivot_columns] return ArrayValue( nodes.ProjectionNode( child=joined_array.node, - assignments=(*label_exprs, *unpivot_exprs, *passthrough_exprs), + assignments=(*unpivot_exprs,), ) - ) + ).select_columns([*index_col_ids, *unpivot_col_ids, *passthrough_columns]) def _cross_join_w_labels( self, labels_array: ArrayValue, join_side: typing.Literal["left", "right"] diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index a309671842..d7df7801bc 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -939,7 +939,7 @@ def multi_apply_unary_op( for col_id in columns: label = self.col_id_to_label[col_id] block, result_id = block.project_expr( - expr.bind_all_variables({input_varname: ex.free_var(col_id)}), + expr.bind_variables({input_varname: ex.free_var(col_id)}), label=label, ) block = block.copy_values(result_id, col_id) @@ -1006,7 +1006,7 @@ def create_constant( dtype: typing.Optional[bigframes.dtypes.Dtype] = None, ) -> typing.Tuple[Block, str]: result_id = guid.generate_guid() - expr = self.expr.assign_constant(result_id, scalar_constant, dtype=dtype) + expr = self.expr.create_constant(result_id, scalar_constant, dtype=dtype) # Create index copy with label inserted # See: https://pandas.pydata.org/docs/reference/api/pandas.Index.insert.html labels = self.column_labels.insert(len(self.column_labels), label) @@ -1067,7 +1067,7 @@ def aggregate_all_and_stack( index_id = guid.generate_guid() result_expr = self.expr.aggregate( aggregations, dropna=dropna - ).assign_constant(index_id, None, None) + ).create_constant(index_id, None, None) # Transpose as last operation so that final block has valid transpose cache return Block( result_expr, @@ -1222,7 +1222,7 @@ def 
aggregate( names: typing.List[Label] = [] if len(by_column_ids) == 0: label_id = guid.generate_guid() - result_expr = result_expr.assign_constant(label_id, 0, pd.Int64Dtype()) + result_expr = result_expr.create_constant(label_id, 0, pd.Int64Dtype()) index_columns = (label_id,) names = [None] else: @@ -1614,17 +1614,22 @@ def add_prefix(self, prefix: str, axis: str | int | None = None) -> Block: axis_number = utils.get_axis_number("rows" if (axis is None) else axis) if axis_number == 0: expr = self._expr + new_index_cols = [] for index_col in self._index_columns: + new_col = guid.generate_guid() expr = expr.project_to_id( expression=ops.add_op.as_expr( ex.const(prefix), ops.AsTypeOp(to_type="string").as_expr(index_col), ), - output_id=index_col, + output_id=new_col, ) + new_index_cols.append(new_col) + expr = expr.select_columns((*new_index_cols, *self.value_columns)) + return Block( expr, - index_columns=self.index_columns, + index_columns=new_index_cols, column_labels=self.column_labels, index_labels=self.index.names, ) @@ -1635,17 +1640,21 @@ def add_suffix(self, suffix: str, axis: str | int | None = None) -> Block: axis_number = utils.get_axis_number("rows" if (axis is None) else axis) if axis_number == 0: expr = self._expr + new_index_cols = [] for index_col in self._index_columns: + new_col = guid.generate_guid() expr = expr.project_to_id( expression=ops.add_op.as_expr( ops.AsTypeOp(to_type="string").as_expr(index_col), ex.const(suffix), ), - output_id=index_col, + output_id=new_col, ) + new_index_cols.append(new_col) + expr = expr.select_columns((*new_index_cols, *self.value_columns)) return Block( expr, - index_columns=self.index_columns, + index_columns=new_index_cols, column_labels=self.column_labels, index_labels=self.index.names, ) diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index 512238440c..9a9f598e89 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -134,10 +134,23 @@ def projection( ) -> T: """Apply an expression to the ArrayValue and assign the output to a column.""" bindings = {col: self._get_ibis_column(col) for col in self.column_ids} - values = [ + new_values = [ op_compiler.compile_expression(expression, bindings).name(id) for expression, id in expression_id_pairs ] + result = self._select(tuple([*self._columns, *new_values])) # type: ignore + return result + + def selection( + self: T, + input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...], + ) -> T: + """Apply an expression to the ArrayValue and assign the output to a column.""" + bindings = {col: self._get_ibis_column(col) for col in self.column_ids} + values = [ + op_compiler.compile_expression(ex.free_var(input), bindings).name(id) + for input, id in input_output_pairs + ] result = self._select(tuple(values)) # type: ignore return result diff --git a/bigframes/core/compile/compiler.py b/bigframes/core/compile/compiler.py index 3fedf5c0c8..80d5f5a893 100644 --- a/bigframes/core/compile/compiler.py +++ b/bigframes/core/compile/compiler.py @@ -264,6 +264,11 @@ def compile_reversed(self, node: nodes.ReversedNode, ordered: bool = True): else: return self.compile_unordered_ir(node.child) + @_compile_node.register + def compile_selection(self, node: nodes.SelectionNode, ordered: bool = True): + result = self.compile_node(node.child, ordered) + return result.selection(node.input_output_pairs) + @_compile_node.register def compile_projection(self, node: nodes.ProjectionNode, ordered: bool = True): result = 
self.compile_node(node.child, ordered) diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index c216c29717..bbd23b689c 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -110,8 +110,13 @@ def output_type( ... @abc.abstractmethod - def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression: - """Replace all variables with expression given in `bindings`.""" + def bind_variables( + self, bindings: Mapping[str, Expression], check_bind_all: bool = True + ) -> Expression: + """Replace variables with expression given in `bindings`. + + If check_bind_all is True, validate that all free variables are bound to a new value. + """ ... @property @@ -141,7 +146,9 @@ def output_type( ) -> dtypes.ExpressionType: return self.dtype - def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression: + def bind_variables( + self, bindings: Mapping[str, Expression], check_bind_all: bool = True + ) -> Expression: return self @property @@ -178,11 +185,14 @@ def output_type( else: raise ValueError(f"Type of variable {self.id} has not been fixed.") - def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression: + def bind_variables( + self, bindings: Mapping[str, Expression], check_bind_all: bool = True + ) -> Expression: if self.id in bindings.keys(): return bindings[self.id] - else: + elif check_bind_all: raise ValueError(f"Variable {self.id} remains unbound") + return self @property def is_bijective(self) -> bool: @@ -225,10 +235,15 @@ def output_type( ) return self.op.output_type(*operand_types) - def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression: + def bind_variables( + self, bindings: Mapping[str, Expression], check_bind_all: bool = True + ) -> Expression: return OpExpression( self.op, - tuple(input.bind_all_variables(bindings) for input in self.inputs), + tuple( + input.bind_variables(bindings, check_bind_all=check_bind_all) + for input in self.inputs + ), ) @property diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 73780719a9..27e76c7910 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -622,8 +622,32 @@ def relation_ops_created(self) -> int: return 0 +@dataclass(frozen=True) +class SelectionNode(UnaryNode): + input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...] + + def __hash__(self): + return self._node_hash + + @functools.cached_property + def schema(self) -> schemata.ArraySchema: + input_types = self.child.schema._mapping + items = tuple( + schemata.SchemaItem(output, input_types[input]) + for input, output in self.input_output_pairs + ) + return schemata.ArraySchema(items) + + @property + def variables_introduced(self) -> int: + # This operation only renames variables, doesn't actually create new ones + return 0 + + @dataclass(frozen=True) class ProjectionNode(UnaryNode): + """Assigns new variables (without modifying existing ones)""" + assignments: typing.Tuple[typing.Tuple[ex.Expression, str], ...] def __post_init__(self): @@ -631,6 +655,8 @@ def __post_init__(self): for expression, id in self.assignments: # throws TypeError if invalid _ = expression.output_type(input_types) + # Cannot assign to existing variables - append only! 
+ assert all(name not in self.child.schema.names for _, name in self.assignments) def __hash__(self): return self._node_hash @@ -644,7 +670,10 @@ def schema(self) -> schemata.ArraySchema: ) for ex, id in self.assignments ) - return schemata.ArraySchema(items) + schema = self.child.schema + for item in items: + schema = schema.append(item) + return schema @property def variables_introduced(self) -> int: diff --git a/bigframes/core/ordering.py b/bigframes/core/ordering.py index bff7e2ce44..a57d7a18d6 100644 --- a/bigframes/core/ordering.py +++ b/bigframes/core/ordering.py @@ -63,7 +63,7 @@ def bind_variables( self, mapping: Mapping[str, expression.Expression] ) -> OrderingExpression: return OrderingExpression( - self.scalar_expression.bind_all_variables(mapping), + self.scalar_expression.bind_variables(mapping), self.direction, self.na_last, ) diff --git a/bigframes/core/rewrite.py b/bigframes/core/rewrite.py index 60ed4069a9..0e73166ea5 100644 --- a/bigframes/core/rewrite.py +++ b/bigframes/core/rewrite.py @@ -27,6 +27,7 @@ Selection = Tuple[Tuple[scalar_exprs.Expression, str], ...] REWRITABLE_NODE_TYPES = ( + nodes.SelectionNode, nodes.ProjectionNode, nodes.FilterNode, nodes.ReversedNode, @@ -54,7 +55,12 @@ def from_node_span( for id in get_node_column_ids(node) ) return cls(node, selection, None, ()) - if isinstance(node, nodes.ProjectionNode): + + if isinstance(node, nodes.SelectionNode): + return cls.from_node_span(node.child, target).select( + node.input_output_pairs + ) + elif isinstance(node, nodes.ProjectionNode): return cls.from_node_span(node.child, target).project(node.assignments) elif isinstance(node, nodes.FilterNode): return cls.from_node_span(node.child, target).filter(node.predicate) @@ -69,22 +75,39 @@ def from_node_span( def column_lookup(self) -> Mapping[str, scalar_exprs.Expression]: return {col_id: expr for expr, col_id in self.columns} + def select(self, input_output_pairs: Tuple[Tuple[str, str], ...]) -> SquashedSelect: + new_columns = tuple( + ( + scalar_exprs.free_var(input).bind_variables(self.column_lookup), + output, + ) + for input, output in input_output_pairs + ) + return SquashedSelect( + self.root, new_columns, self.predicate, self.ordering, self.reverse_root + ) + def project( self, projection: Tuple[Tuple[scalar_exprs.Expression, str], ...] 
) -> SquashedSelect: + existing_columns = self.columns new_columns = tuple( - (expr.bind_all_variables(self.column_lookup), id) for expr, id in projection + (expr.bind_variables(self.column_lookup), id) for expr, id in projection ) return SquashedSelect( - self.root, new_columns, self.predicate, self.ordering, self.reverse_root + self.root, + (*existing_columns, *new_columns), + self.predicate, + self.ordering, + self.reverse_root, ) def filter(self, predicate: scalar_exprs.Expression) -> SquashedSelect: if self.predicate is None: - new_predicate = predicate.bind_all_variables(self.column_lookup) + new_predicate = predicate.bind_variables(self.column_lookup) else: new_predicate = ops.and_op.as_expr( - self.predicate, predicate.bind_all_variables(self.column_lookup) + self.predicate, predicate.bind_variables(self.column_lookup) ) return SquashedSelect( self.root, self.columns, new_predicate, self.ordering, self.reverse_root @@ -204,7 +227,11 @@ def expand(self) -> nodes.BigFrameNode: root = nodes.FilterNode(child=root, predicate=self.predicate) if self.ordering: root = nodes.OrderByNode(child=root, by=self.ordering) - return nodes.ProjectionNode(child=root, assignments=self.columns) + selection = tuple((id, id) for _, id in self.columns) + return nodes.SelectionNode( + child=nodes.ProjectionNode(child=root, assignments=self.columns), + input_output_pairs=selection, + ) def join_as_projection( diff --git a/bigframes/session/executor.py b/bigframes/session/executor.py index 72d5493294..424e6d7dad 100644 --- a/bigframes/session/executor.py +++ b/bigframes/session/executor.py @@ -457,9 +457,7 @@ def generate_head_plan(node: nodes.BigFrameNode, n: int): predicate = ops.lt_op.as_expr(ex.free_var(offsets_id), ex.const(n)) plan_w_head = nodes.FilterNode(plan_w_offsets, predicate) # Finally, drop the offsets column - return nodes.ProjectionNode( - plan_w_head, tuple((ex.free_var(i), i) for i in node.schema.names) - ) + return nodes.SelectionNode(plan_w_head, tuple((i, i) for i in node.schema.names)) def generate_row_count_plan(node: nodes.BigFrameNode): diff --git a/bigframes/session/planner.py b/bigframes/session/planner.py index 2a74521b43..bc640ec9fa 100644 --- a/bigframes/session/planner.py +++ b/bigframes/session/planner.py @@ -33,7 +33,7 @@ def session_aware_cache_plan( """ node_counts = traversals.count_nodes(session_forest) # These node types are cheap to re-compute, so it makes more sense to cache their children. - de_cachable_types = (nodes.FilterNode, nodes.ProjectionNode) + de_cachable_types = (nodes.FilterNode, nodes.ProjectionNode, nodes.SelectionNode) caching_target = cur_node = root caching_target_refs = node_counts.get(caching_target, 0) @@ -49,7 +49,15 @@ def session_aware_cache_plan( # Projection defines the variables that are used in the filter expressions, need to substitute variables with their scalar expressions # that instead reference variables in the child node. 
bindings = {name: expr for expr, name in cur_node.assignments} - filters = [i.bind_all_variables(bindings) for i in filters] + filters = [ + i.bind_variables(bindings, check_bind_all=False) for i in filters + ] + elif isinstance(cur_node, nodes.SelectionNode): + bindings = { + output: ex.free_var(input) + for input, output in cur_node.input_output_pairs + } + filters = [i.bind_variables(bindings) for i in filters] else: raise ValueError(f"Unexpected de-cached node: {cur_node}") diff --git a/tests/unit/test_planner.py b/tests/unit/test_planner.py index 2e276d0f1a..84dd05ddaa 100644 --- a/tests/unit/test_planner.py +++ b/tests/unit/test_planner.py @@ -46,8 +46,8 @@ def test_session_aware_caching_project_filter(): """ Test that if a node is filtered by a column, the node is cached pre-filter and clustered by the filter column. """ - session_objects = [LEAF, LEAF.assign_constant("col_c", 4, pd.Int64Dtype())] - target = LEAF.assign_constant("col_c", 4, pd.Int64Dtype()).filter( + session_objects = [LEAF, LEAF.create_constant("col_c", 4, pd.Int64Dtype())] + target = LEAF.create_constant("col_c", 4, pd.Int64Dtype()).filter( ops.gt_op.as_expr("col_a", ex.const(3)) ) result, cluster_cols = planner.session_aware_cache_plan( @@ -61,14 +61,14 @@ def test_session_aware_caching_project_multi_filter(): """ Test that if a node is filtered by multiple columns, all of them are in the cluster cols """ - session_objects = [LEAF, LEAF.assign_constant("col_c", 4, pd.Int64Dtype())] + session_objects = [LEAF, LEAF.create_constant("col_c", 4, pd.Int64Dtype())] predicate_1a = ops.gt_op.as_expr("col_a", ex.const(3)) predicate_1b = ops.lt_op.as_expr("col_a", ex.const(55)) predicate_1 = ops.and_op.as_expr(predicate_1a, predicate_1b) predicate_3 = ops.eq_op.as_expr("col_b", ex.const(1)) target = ( LEAF.filter(predicate_1) - .assign_constant("col_c", 4, pd.Int64Dtype()) + .create_constant("col_c", 4, pd.Int64Dtype()) .filter(predicate_3) ) result, cluster_cols = planner.session_aware_cache_plan( @@ -84,8 +84,8 @@ def test_session_aware_caching_unusable_filter(): Most filters with multiple column references cannot be used for scan pruning, as they cannot be converted to fixed value ranges. """ - session_objects = [LEAF, LEAF.assign_constant("col_c", 4, pd.Int64Dtype())] - target = LEAF.assign_constant("col_c", 4, pd.Int64Dtype()).filter( + session_objects = [LEAF, LEAF.create_constant("col_c", 4, pd.Int64Dtype())] + target = LEAF.create_constant("col_c", 4, pd.Int64Dtype()).filter( ops.gt_op.as_expr("col_a", "col_b") ) result, cluster_cols = planner.session_aware_cache_plan( @@ -101,12 +101,12 @@ def test_session_aware_caching_fork_after_window_op(): Windowing is expensive, so caching should always compute the window function, in order to avoid later recomputation. """ - other = LEAF.promote_offsets("offsets_col").assign_constant( + other = LEAF.promote_offsets("offsets_col").create_constant( "col_d", 5, pd.Int64Dtype() ) target = ( LEAF.promote_offsets("offsets_col") - .assign_constant("col_c", 4, pd.Int64Dtype()) + .create_constant("col_c", 4, pd.Int64Dtype()) .filter( ops.eq_op.as_expr("col_a", ops.add_op.as_expr(ex.const(4), ex.const(3))) )
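
The net effect of the split is that a mutate-in-place, which previously compiled to a single ProjectionNode carrying both the recomputed column and every passthrough column, now decomposes into an append followed by a select/rename (the same shape add_prefix/add_suffix take above). A minimal sketch of that shape, assuming a BigFrameNode `leaf` that already has a string column "col_a" and a second column "col_b"; the helper name and the literal "col_a_prefixed" are illustrative only (the library generates a guid for the intermediate id):

    import bigframes.core.expression as ex
    import bigframes.core.nodes as nodes
    import bigframes.operations as ops


    def prefix_col_a(leaf: nodes.BigFrameNode, prefix: str) -> nodes.BigFrameNode:
        # ProjectionNode is append-only now (__post_init__ asserts the output
        # name is not already in the child schema), so the recomputed values
        # land under a fresh name instead of overwriting "col_a" in place.
        appended = nodes.ProjectionNode(
            child=leaf,
            assignments=(
                (
                    ops.add_op.as_expr(
                        ex.const(prefix),
                        ops.AsTypeOp(to_type="string").as_expr("col_a"),
                    ),
                    "col_a_prefixed",
                ),
            ),
        )
        # SelectionNode then renames the new column back to "col_a", drops the
        # original, and keeps "col_b" - a pure pick/rename/reorder step.
        return nodes.SelectionNode(
            child=appended,
            input_output_pairs=(("col_a_prefixed", "col_a"), ("col_b", "col_b")),
        )


    # Expressions written against the post-selection names can be rebased onto
    # the pre-selection names the same way planner.py handles SelectionNode:
    # map each output id back to a free variable on its input id.
    bindings = {"col_a": ex.free_var("col_a_prefixed")}
    predicate = ops.eq_op.as_expr("col_a", ex.const("idx-0"))
    rebased = predicate.bind_variables(bindings, check_bind_all=False)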