Skip to content

Commit

Permalink
refactor: Simplify projection nodes (#961)
Browse files Browse the repository at this point in the history
  • Loading branch information
TrevorBergeron authored and arwas11 committed Sep 9, 2024
1 parent 6701dd7 commit 46d8012
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 91 deletions.
77 changes: 23 additions & 54 deletions bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,49 +192,38 @@ def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue:
)

def project_to_id(self, expression: ex.Expression, output_id: str):
if output_id in self.column_ids: # Mutate case
exprs = [
((expression if (col_id == output_id) else ex.free_var(col_id)), col_id)
for col_id in self.column_ids
]
else: # append case
self_projection = (
(ex.free_var(col_id), col_id) for col_id in self.column_ids
)
exprs = [*self_projection, (expression, output_id)]
return ArrayValue(
nodes.ProjectionNode(
child=self.node,
assignments=tuple(exprs),
assignments=(
(
expression,
output_id,
),
),
)
)

def assign(self, source_id: str, destination_id: str) -> ArrayValue:
if destination_id in self.column_ids: # Mutate case
exprs = [
(
(
ex.free_var(source_id)
if (col_id == destination_id)
else ex.free_var(col_id)
),
(source_id if (col_id == destination_id) else col_id),
col_id,
)
for col_id in self.column_ids
]
else: # append case
self_projection = (
(ex.free_var(col_id), col_id) for col_id in self.column_ids
)
exprs = [*self_projection, (ex.free_var(source_id), destination_id)]
self_projection = ((col_id, col_id) for col_id in self.column_ids)
exprs = [*self_projection, (source_id, destination_id)]
return ArrayValue(
nodes.ProjectionNode(
nodes.SelectionNode(
child=self.node,
assignments=tuple(exprs),
input_output_pairs=tuple(exprs),
)
)

def assign_constant(
def create_constant(
self,
destination_id: str,
value: typing.Any,
Expand All @@ -244,49 +233,31 @@ def assign_constant(
# Need to assign a data type when value is NaN.
dtype = dtype or bigframes.dtypes.DEFAULT_DTYPE

if destination_id in self.column_ids: # Mutate case
exprs = [
(
(
ex.const(value, dtype)
if (col_id == destination_id)
else ex.free_var(col_id)
),
col_id,
)
for col_id in self.column_ids
]
else: # append case
self_projection = (
(ex.free_var(col_id), col_id) for col_id in self.column_ids
)
exprs = [*self_projection, (ex.const(value, dtype), destination_id)]
return ArrayValue(
nodes.ProjectionNode(
child=self.node,
assignments=tuple(exprs),
assignments=((ex.const(value, dtype), destination_id),),
)
)

def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
selections = ((ex.free_var(col_id), col_id) for col_id in column_ids)
# This basically just drops and reorders columns - logically a no-op except as a final step
selections = ((col_id, col_id) for col_id in column_ids)
return ArrayValue(
nodes.ProjectionNode(
nodes.SelectionNode(
child=self.node,
assignments=tuple(selections),
input_output_pairs=tuple(selections),
)
)

def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
new_projection = (
(ex.free_var(col_id), col_id)
for col_id in self.column_ids
if col_id not in columns
(col_id, col_id) for col_id in self.column_ids if col_id not in columns
)
return ArrayValue(
nodes.ProjectionNode(
nodes.SelectionNode(
child=self.node,
assignments=tuple(new_projection),
input_output_pairs=tuple(new_projection),
)
)

Expand Down Expand Up @@ -422,15 +393,13 @@ def unpivot(
col_expr = ops.case_when_op.as_expr(*cases)
unpivot_exprs.append((col_expr, col_id))

label_exprs = ((ex.free_var(id), id) for id in index_col_ids)
# passthrough columns are unchanged, just repeated N times each
passthrough_exprs = ((ex.free_var(id), id) for id in passthrough_columns)
unpivot_col_ids = [id for id, _ in unpivot_columns]
return ArrayValue(
nodes.ProjectionNode(
child=joined_array.node,
assignments=(*label_exprs, *unpivot_exprs, *passthrough_exprs),
assignments=(*unpivot_exprs,),
)
)
).select_columns([*index_col_ids, *unpivot_col_ids, *passthrough_columns])

def _cross_join_w_labels(
self, labels_array: ArrayValue, join_side: typing.Literal["left", "right"]
Expand Down
25 changes: 17 additions & 8 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ def multi_apply_unary_op(
for col_id in columns:
label = self.col_id_to_label[col_id]
block, result_id = block.project_expr(
expr.bind_all_variables({input_varname: ex.free_var(col_id)}),
expr.bind_variables({input_varname: ex.free_var(col_id)}),
label=label,
)
block = block.copy_values(result_id, col_id)
Expand Down Expand Up @@ -1006,7 +1006,7 @@ def create_constant(
dtype: typing.Optional[bigframes.dtypes.Dtype] = None,
) -> typing.Tuple[Block, str]:
result_id = guid.generate_guid()
expr = self.expr.assign_constant(result_id, scalar_constant, dtype=dtype)
expr = self.expr.create_constant(result_id, scalar_constant, dtype=dtype)
# Create index copy with label inserted
# See: https://pandas.pydata.org/docs/reference/api/pandas.Index.insert.html
labels = self.column_labels.insert(len(self.column_labels), label)
Expand Down Expand Up @@ -1067,7 +1067,7 @@ def aggregate_all_and_stack(
index_id = guid.generate_guid()
result_expr = self.expr.aggregate(
aggregations, dropna=dropna
).assign_constant(index_id, None, None)
).create_constant(index_id, None, None)
# Transpose as last operation so that final block has valid transpose cache
return Block(
result_expr,
Expand Down Expand Up @@ -1222,7 +1222,7 @@ def aggregate(
names: typing.List[Label] = []
if len(by_column_ids) == 0:
label_id = guid.generate_guid()
result_expr = result_expr.assign_constant(label_id, 0, pd.Int64Dtype())
result_expr = result_expr.create_constant(label_id, 0, pd.Int64Dtype())
index_columns = (label_id,)
names = [None]
else:
Expand Down Expand Up @@ -1614,17 +1614,22 @@ def add_prefix(self, prefix: str, axis: str | int | None = None) -> Block:
axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
if axis_number == 0:
expr = self._expr
new_index_cols = []
for index_col in self._index_columns:
new_col = guid.generate_guid()
expr = expr.project_to_id(
expression=ops.add_op.as_expr(
ex.const(prefix),
ops.AsTypeOp(to_type="string").as_expr(index_col),
),
output_id=index_col,
output_id=new_col,
)
new_index_cols.append(new_col)
expr = expr.select_columns((*new_index_cols, *self.value_columns))

return Block(
expr,
index_columns=self.index_columns,
index_columns=new_index_cols,
column_labels=self.column_labels,
index_labels=self.index.names,
)
Expand All @@ -1635,17 +1640,21 @@ def add_suffix(self, suffix: str, axis: str | int | None = None) -> Block:
axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
if axis_number == 0:
expr = self._expr
new_index_cols = []
for index_col in self._index_columns:
new_col = guid.generate_guid()
expr = expr.project_to_id(
expression=ops.add_op.as_expr(
ops.AsTypeOp(to_type="string").as_expr(index_col),
ex.const(suffix),
),
output_id=index_col,
output_id=new_col,
)
new_index_cols.append(new_col)
expr = expr.select_columns((*new_index_cols, *self.value_columns))
return Block(
expr,
index_columns=self.index_columns,
index_columns=new_index_cols,
column_labels=self.column_labels,
index_labels=self.index.names,
)
Expand Down
15 changes: 14 additions & 1 deletion bigframes/core/compile/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,23 @@ def projection(
) -> T:
"""Apply an expression to the ArrayValue and assign the output to a column."""
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
values = [
new_values = [
op_compiler.compile_expression(expression, bindings).name(id)
for expression, id in expression_id_pairs
]
result = self._select(tuple([*self._columns, *new_values])) # type: ignore
return result

def selection(
self: T,
input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...],
) -> T:
"""Apply an expression to the ArrayValue and assign the output to a column."""
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
values = [
op_compiler.compile_expression(ex.free_var(input), bindings).name(id)
for input, id in input_output_pairs
]
result = self._select(tuple(values)) # type: ignore
return result

Expand Down
5 changes: 5 additions & 0 deletions bigframes/core/compile/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ def compile_reversed(self, node: nodes.ReversedNode, ordered: bool = True):
else:
return self.compile_unordered_ir(node.child)

@_compile_node.register
def compile_selection(self, node: nodes.SelectionNode, ordered: bool = True):
result = self.compile_node(node.child, ordered)
return result.selection(node.input_output_pairs)

@_compile_node.register
def compile_projection(self, node: nodes.ProjectionNode, ordered: bool = True):
result = self.compile_node(node.child, ordered)
Expand Down
29 changes: 22 additions & 7 deletions bigframes/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,13 @@ def output_type(
...

@abc.abstractmethod
def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
"""Replace all variables with expression given in `bindings`."""
def bind_variables(
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
) -> Expression:
"""Replace variables with expression given in `bindings`.
If check_bind_all is True, validate that all free variables are bound to a new value.
"""
...

@property
Expand Down Expand Up @@ -141,7 +146,9 @@ def output_type(
) -> dtypes.ExpressionType:
return self.dtype

def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
def bind_variables(
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
) -> Expression:
return self

@property
Expand Down Expand Up @@ -178,11 +185,14 @@ def output_type(
else:
raise ValueError(f"Type of variable {self.id} has not been fixed.")

def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
def bind_variables(
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
) -> Expression:
if self.id in bindings.keys():
return bindings[self.id]
else:
elif check_bind_all:
raise ValueError(f"Variable {self.id} remains unbound")
return self

@property
def is_bijective(self) -> bool:
Expand Down Expand Up @@ -225,10 +235,15 @@ def output_type(
)
return self.op.output_type(*operand_types)

def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
def bind_variables(
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
) -> Expression:
return OpExpression(
self.op,
tuple(input.bind_all_variables(bindings) for input in self.inputs),
tuple(
input.bind_variables(bindings, check_bind_all=check_bind_all)
for input in self.inputs
),
)

@property
Expand Down
31 changes: 30 additions & 1 deletion bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,15 +622,41 @@ def relation_ops_created(self) -> int:
return 0


@dataclass(frozen=True)
class SelectionNode(UnaryNode):
input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...]

def __hash__(self):
return self._node_hash

@functools.cached_property
def schema(self) -> schemata.ArraySchema:
input_types = self.child.schema._mapping
items = tuple(
schemata.SchemaItem(output, input_types[input])
for input, output in self.input_output_pairs
)
return schemata.ArraySchema(items)

@property
def variables_introduced(self) -> int:
# This operation only renames variables, doesn't actually create new ones
return 0


@dataclass(frozen=True)
class ProjectionNode(UnaryNode):
"""Assigns new variables (without modifying existing ones)"""

assignments: typing.Tuple[typing.Tuple[ex.Expression, str], ...]

def __post_init__(self):
input_types = self.child.schema._mapping
for expression, id in self.assignments:
# throws TypeError if invalid
_ = expression.output_type(input_types)
# Cannot assign to existing variables - append only!
assert all(name not in self.child.schema.names for _, name in self.assignments)

def __hash__(self):
return self._node_hash
Expand All @@ -644,7 +670,10 @@ def schema(self) -> schemata.ArraySchema:
)
for ex, id in self.assignments
)
return schemata.ArraySchema(items)
schema = self.child.schema
for item in items:
schema = schema.append(item)
return schema

@property
def variables_introduced(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion bigframes/core/ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def bind_variables(
self, mapping: Mapping[str, expression.Expression]
) -> OrderingExpression:
return OrderingExpression(
self.scalar_expression.bind_all_variables(mapping),
self.scalar_expression.bind_variables(mapping),
self.direction,
self.na_last,
)
Expand Down
Loading

0 comments on commit 46d8012

Please sign in to comment.