Skip to content

Commit

Permalink
refactor: Simplify projection nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
TrevorBergeron committed Sep 5, 2024
1 parent 6fdb6b1 commit 6b1c27c
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 76 deletions.
77 changes: 23 additions & 54 deletions bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,49 +192,38 @@ def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue:
)

def project_to_id(self, expression: ex.Expression, output_id: str):
    """Append a single computed column to this ArrayValue.

    Args:
        expression: scalar expression evaluated against the existing columns.
        output_id: id for the new column.

    Returns:
        ArrayValue with the new column appended.

    Note: ProjectionNode is append-only (it asserts the output id is not
    already in the child schema), so `output_id` must be a fresh id; callers
    that previously relied on in-place mutation must select/rename afterwards.
    """
    return ArrayValue(
        nodes.ProjectionNode(
            child=self.node,
            assignments=((expression, output_id),),
        )
    )

def assign(self, source_id: str, destination_id: str) -> ArrayValue:
    """Re-emit the column `source_id` under the id `destination_id`.

    All existing columns are passed through unchanged via a SelectionNode,
    then the source column is emitted a second time under the destination id.
    """
    # NOTE(review): if destination_id already names an existing column, it
    # appears twice among the output ids of the SelectionNode — confirm that
    # downstream consumers (schema construction) handle the overwrite case.
    self_projection = ((col_id, col_id) for col_id in self.column_ids)
    exprs = [*self_projection, (source_id, destination_id)]
    return ArrayValue(
        nodes.SelectionNode(
            child=self.node,
            input_output_pairs=tuple(exprs),
        )
    )

def assign_constant(
def create_constant(
self,
destination_id: str,
value: typing.Any,
Expand All @@ -244,49 +233,31 @@ def assign_constant(
# Need to assign a data type when value is NaN.
dtype = dtype or bigframes.dtypes.DEFAULT_DTYPE

if destination_id in self.column_ids: # Mutate case
exprs = [
(
(
ex.const(value, dtype)
if (col_id == destination_id)
else ex.free_var(col_id)
),
col_id,
)
for col_id in self.column_ids
]
else: # append case
self_projection = (
(ex.free_var(col_id), col_id) for col_id in self.column_ids
)
exprs = [*self_projection, (ex.const(value, dtype), destination_id)]
return ArrayValue(
nodes.ProjectionNode(
child=self.node,
assignments=tuple(exprs),
assignments=((ex.const(value, dtype), destination_id),),
)
)

def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
    """Keep only `column_ids`, in the given order.

    This basically just drops and reorders columns - logically a no-op
    except as a final step.
    """
    selections = ((col_id, col_id) for col_id in column_ids)
    return ArrayValue(
        nodes.SelectionNode(
            child=self.node,
            input_output_pairs=tuple(selections),
        )
    )

def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
    """Return an ArrayValue without the columns named in `columns`.

    Implemented as an identity SelectionNode over the surviving columns;
    ids not present in `columns` keep their original names and order.
    """
    new_projection = (
        (col_id, col_id) for col_id in self.column_ids if col_id not in columns
    )
    return ArrayValue(
        nodes.SelectionNode(
            child=self.node,
            input_output_pairs=tuple(new_projection),
        )
    )

Expand Down Expand Up @@ -422,15 +393,13 @@ def unpivot(
col_expr = ops.case_when_op.as_expr(*cases)
unpivot_exprs.append((col_expr, col_id))

label_exprs = ((ex.free_var(id), id) for id in index_col_ids)
# passthrough columns are unchanged, just repeated N times each
passthrough_exprs = ((ex.free_var(id), id) for id in passthrough_columns)
unpivot_col_ids = [id for id, _ in unpivot_columns]
return ArrayValue(
nodes.ProjectionNode(
child=joined_array.node,
assignments=(*label_exprs, *unpivot_exprs, *passthrough_exprs),
assignments=(*unpivot_exprs,),
)
)
).select_columns([*index_col_ids, *unpivot_col_ids, *passthrough_columns])

def _cross_join_w_labels(
self, labels_array: ArrayValue, join_side: typing.Literal["left", "right"]
Expand Down
19 changes: 14 additions & 5 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ def create_constant(
dtype: typing.Optional[bigframes.dtypes.Dtype] = None,
) -> typing.Tuple[Block, str]:
result_id = guid.generate_guid()
expr = self.expr.assign_constant(result_id, scalar_constant, dtype=dtype)
expr = self.expr.create_constant(result_id, scalar_constant, dtype=dtype)
# Create index copy with label inserted
# See: https://pandas.pydata.org/docs/reference/api/pandas.Index.insert.html
labels = self.column_labels.insert(len(self.column_labels), label)
Expand Down Expand Up @@ -1067,7 +1067,7 @@ def aggregate_all_and_stack(
index_id = guid.generate_guid()
result_expr = self.expr.aggregate(
aggregations, dropna=dropna
).assign_constant(index_id, None, None)
).create_constant(index_id, None, None)
# Transpose as last operation so that final block has valid transpose cache
return Block(
result_expr,
Expand Down Expand Up @@ -1222,7 +1222,7 @@ def aggregate(
names: typing.List[Label] = []
if len(by_column_ids) == 0:
label_id = guid.generate_guid()
result_expr = result_expr.assign_constant(label_id, 0, pd.Int64Dtype())
result_expr = result_expr.create_constant(label_id, 0, pd.Int64Dtype())
index_columns = (label_id,)
names = [None]
else:
Expand Down Expand Up @@ -1614,14 +1614,19 @@ def add_prefix(self, prefix: str, axis: str | int | None = None) -> Block:
axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
if axis_number == 0:
expr = self._expr
new_index_cols = []
for index_col in self._index_columns:
new_col = guid.generate_guid()
expr = expr.project_to_id(
expression=ops.add_op.as_expr(
ex.const(prefix),
ops.AsTypeOp(to_type="string").as_expr(index_col),
),
output_id=index_col,
output_id=new_col,
)
new_index_cols.append(new_col)
expr = expr.select_columns((*new_index_cols, *self.value_columns))

return Block(
expr,
index_columns=self.index_columns,
Expand All @@ -1635,14 +1640,18 @@ def add_suffix(self, suffix: str, axis: str | int | None = None) -> Block:
axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
if axis_number == 0:
expr = self._expr
new_index_cols = []
for index_col in self._index_columns:
new_col = guid.generate_guid()
expr = expr.project_to_id(
expression=ops.add_op.as_expr(
ops.AsTypeOp(to_type="string").as_expr(index_col),
ex.const(suffix),
),
output_id=index_col,
output_id=new_col,
)
new_index_cols.append(new_col)
expr = expr.select_columns((*new_index_cols, *self.value_columns))
return Block(
expr,
index_columns=self.index_columns,
Expand Down
16 changes: 15 additions & 1 deletion bigframes/core/compile/compiled.py
Original file line number Diff line number Diff line change
def projection(
    self: T,
    expression_id_pairs: typing.Tuple[typing.Tuple[ex.Expression, str], ...],
) -> T:
    """Apply expressions to the ArrayValue, appending each result as a new column.

    NOTE(review): the parameter list was reconstructed from body usage
    (`self`, `expression_id_pairs`) — confirm against the full source.
    """
    bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
    new_values = [
        op_compiler.compile_expression(expression, bindings).name(id)
        for expression, id in expression_id_pairs
    ]
    existing_columns = tuple(ex.free_var(id) for id in self.column_ids)
    # BUG FIX: `tuple(*existing_columns, *new_values)` unpacks into multiple
    # positional arguments to tuple(), raising TypeError whenever there is
    # more than one element; build a single tuple literal instead.
    # NOTE(review): existing_columns are ex.free_var expressions while
    # new_values are compiled ibis values — confirm _select accepts the mix.
    result = self._select((*existing_columns, *new_values))  # type: ignore
    return result

def selection(
    self: T,
    input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...],
) -> T:
    """Select, reorder and/or rename existing columns.

    Each (input, output) pair re-emits the column named `input` under the
    name `output`; no new values are computed.
    """
    # Fixed docstring: previous one was copy-pasted from `projection`.
    bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
    values = [
        op_compiler.compile_expression(ex.free_var(in_id), bindings).name(out_id)
        for in_id, out_id in input_output_pairs
    ]
    result = self._select(tuple(values))  # type: ignore
    return result

Expand Down
5 changes: 5 additions & 0 deletions bigframes/core/compile/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ def compile_reversed(self, node: nodes.ReversedNode, ordered: bool = True):
else:
return self.compile_unordered_ir(node.child)

@_compile_node.register
def compile_selection(self, node: nodes.SelectionNode, ordered: bool = True):
    # Compile the child plan first, then layer the column
    # selection/rename on top of the resulting IR.
    child_ir = self.compile_node(node.child, ordered)
    return child_ir.selection(node.input_output_pairs)

@_compile_node.register
def compile_projection(self, node: nodes.ProjectionNode, ordered: bool = True):
result = self.compile_node(node.child, ordered)
Expand Down
31 changes: 30 additions & 1 deletion bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,15 +622,41 @@ def relation_ops_created(self) -> int:
return 0


@dataclass(frozen=True)
class SelectionNode(UnaryNode):
    """Selects, reorders and/or renames columns of its child without computing new values."""

    # (id in child schema, id in this node's schema) pairs
    input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...]

    def __hash__(self):
        # frozen dataclass; reuse the precomputed structural hash
        return self._node_hash

    @functools.cached_property
    def schema(self) -> schemata.ArraySchema:
        # Output schema carries the renamed ids with the child's types.
        # Renamed loop vars so as not to shadow the `input` builtin.
        input_types = self.child.schema._mapping
        items = tuple(
            schemata.SchemaItem(out_id, input_types[in_id])
            for in_id, out_id in self.input_output_pairs
        )
        return schemata.ArraySchema(items)

    @property
    def variables_introduced(self) -> int:
        # This operation only renames variables, doesn't actually create new ones
        return 0


@dataclass(frozen=True)
class ProjectionNode(UnaryNode):
"""Assigns new variables (without modifying existing ones)"""

assignments: typing.Tuple[typing.Tuple[ex.Expression, str], ...]

def __post_init__(self):
input_types = self.child.schema._mapping
for expression, id in self.assignments:
# throws TypeError if invalid
_ = expression.output_type(input_types)
# Cannot assign to existing variables - append only!
assert all(name not in self.child.schema.names for _, name in self.assignments)

def __hash__(self):
return self._node_hash
Expand All @@ -644,7 +670,10 @@ def schema(self) -> schemata.ArraySchema:
)
for ex, id in self.assignments
)
return schemata.ArraySchema(items)
schema = self.child.schema
for item in items:
schema.append(item)
return schema

@property
def variables_introduced(self) -> int:
Expand Down
27 changes: 25 additions & 2 deletions bigframes/core/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Selection = Tuple[Tuple[scalar_exprs.Expression, str], ...]

REWRITABLE_NODE_TYPES = (
nodes.SelectionNode,
nodes.ProjectionNode,
nodes.FilterNode,
nodes.ReversedNode,
Expand Down Expand Up @@ -54,7 +55,12 @@ def from_node_span(
for id in get_node_column_ids(node)
)
return cls(node, selection, None, ())
if isinstance(node, nodes.ProjectionNode):

if isinstance(node, nodes.SelectionNode):
return cls.from_node_span(node.child, target).select(
node.input_output_pairs
)
elif isinstance(node, nodes.ProjectionNode):
return cls.from_node_span(node.child, target).project(node.assignments)
elif isinstance(node, nodes.FilterNode):
return cls.from_node_span(node.child, target).filter(node.predicate)
Expand All @@ -69,14 +75,31 @@ def from_node_span(
def column_lookup(self) -> Mapping[str, scalar_exprs.Expression]:
return {col_id: expr for expr, col_id in self.columns}

def select(self, input_output_pairs: Tuple[Tuple[str, str], ...]) -> SquashedSelect:
    """Rebind a column selection/rename through this squashed select.

    Each (input, output) pair looks up the expression currently producing
    `input` and re-labels it as `output`; predicate, ordering and root are
    unchanged. Loop vars renamed to avoid shadowing the `input` builtin.
    """
    new_columns = tuple(
        (
            scalar_exprs.free_var(in_id).bind_all_variables(self.column_lookup),
            out_id,
        )
        for in_id, out_id in input_output_pairs
    )
    return SquashedSelect(
        self.root, new_columns, self.predicate, self.ordering, self.reverse_root
    )

def project(
    self, projection: Tuple[Tuple[scalar_exprs.Expression, str], ...]
) -> SquashedSelect:
    """Append projected expressions after the existing columns.

    Expressions are rebound against this node's column lookup so they
    reference variables of the underlying root; predicate, ordering and
    root are unchanged. (Removed a stray duplicated `self.root` argument
    left over from diff residue; avoid shadowing the `id` builtin.)
    """
    existing_columns = self.columns
    new_columns = tuple(
        (expr.bind_all_variables(self.column_lookup), col_id)
        for expr, col_id in projection
    )
    return SquashedSelect(
        self.root,
        (*existing_columns, *new_columns),
        self.predicate,
        self.ordering,
        self.reverse_root,
    )

def filter(self, predicate: scalar_exprs.Expression) -> SquashedSelect:
Expand Down
4 changes: 1 addition & 3 deletions bigframes/session/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,7 @@ def generate_head_plan(node: nodes.BigFrameNode, n: int):
predicate = ops.lt_op.as_expr(ex.free_var(offsets_id), ex.const(n))
plan_w_head = nodes.FilterNode(plan_w_offsets, predicate)
# Finally, drop the offsets column
return nodes.ProjectionNode(
plan_w_head, tuple((ex.free_var(i), i) for i in node.schema.names)
)
return nodes.SelectionNode(plan_w_head, tuple((i, i) for i in node.schema.names))


def generate_row_count_plan(node: nodes.BigFrameNode):
Expand Down
9 changes: 7 additions & 2 deletions bigframes/session/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def session_aware_cache_plan(
"""
node_counts = traversals.count_nodes(session_forest)
# These node types are cheap to re-compute, so it makes more sense to cache their children.
de_cachable_types = (nodes.FilterNode, nodes.ProjectionNode)
de_cachable_types = (nodes.FilterNode, nodes.ProjectionNode, nodes.SelectionNode)
caching_target = cur_node = root
caching_target_refs = node_counts.get(caching_target, 0)

Expand All @@ -46,9 +46,14 @@ def session_aware_cache_plan(
# Filter node doesn't define any variables, so no need to chain expressions
filters.append(cur_node.predicate)
elif isinstance(cur_node, nodes.ProjectionNode):
...
elif isinstance(cur_node, nodes.SelectionNode):
# Projection defines the variables that are used in the filter expressions, need to substitute variables with their scalar expressions
# that instead reference variables in the child node.
bindings = {name: expr for expr, name in cur_node.assignments}
bindings = {
output: ex.free_var(input)
for input, output in cur_node.input_output_pairs
}
filters = [i.bind_all_variables(bindings) for i in filters]
else:
raise ValueError(f"Unexpected de-cached node: {cur_node}")
Expand Down
Loading

0 comments on commit 6b1c27c

Please sign in to comment.