Skip to content

Commit

Permalink
Fix dependencies for operation fusion (#37878)
Browse files Browse the repository at this point in the history
The op fusion rule only updates the input deps, but not the output deps. This bug makes the limit op non-effective for fused DAGs.

Signed-off-by: Hao Chen <[email protected]>
  • Loading branch information
raulchen authored Jul 28, 2023
1 parent d650f03 commit cf4d7d1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
16 changes: 16 additions & 0 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,24 @@ def apply(self, plan: PhysicalPlan) -> PhysicalPlan:
# we fuse together MapOperator -> AllToAllOperator pairs.
fused_dag = self._fuse_all_to_all_operators_in_dag(fused_dag)

# Update output dependencies after fusion.
# TODO(hchen): Instead of updating the depdencies manually,
# we need a better abstraction for manipulating the DAG.
self._remove_output_depes(fused_dag)
self._update_output_depes(fused_dag)

return PhysicalPlan(fused_dag, self._op_map)

def _remove_output_depes(self, op: PhysicalOperator) -> None:
for input in op._input_dependencies:
input._output_dependencies = []
self._remove_output_depes(input)

def _update_output_depes(self, op: PhysicalOperator) -> None:
for input in op._input_dependencies:
input._output_dependencies.append(op)
self._update_output_depes(input)

def _fuse_map_operators_in_dag(self, dag: PhysicalOperator) -> MapOperator:
"""Starting at the given operator, traverses up the DAG of operators
and recursively fuses compatible MapOperator -> MapOperator pairs.
Expand Down
4 changes: 3 additions & 1 deletion python/ray/data/tests/test_execution_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,9 @@ def test_read_map_batches_operator_fusion(ray_start_regular_shared, enable_optim
assert physical_op.name == "ReadParquet->MapBatches(<lambda>)"
assert isinstance(physical_op, MapOperator)
assert len(physical_op.input_dependencies) == 1
assert isinstance(physical_op.input_dependencies[0], InputDataBuffer)
input = physical_op.input_dependencies[0]
assert isinstance(input, InputDataBuffer)
assert physical_op in input.output_dependencies, input.output_dependencies


def test_read_map_chain_operator_fusion(ray_start_regular_shared, enable_optimizer):
Expand Down

0 comments on commit cf4d7d1

Please sign in to comment.