diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 613eb0b507d2..81bbee2b7fb6 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -45,8 +45,24 @@ def apply(self, plan: PhysicalPlan) -> PhysicalPlan: # we fuse together MapOperator -> AllToAllOperator pairs. fused_dag = self._fuse_all_to_all_operators_in_dag(fused_dag) + # Update output dependencies after fusion. + # TODO(hchen): Instead of updating the depdencies manually, + # we need a better abstraction for manipulating the DAG. + self._remove_output_depes(fused_dag) + self._update_output_depes(fused_dag) + return PhysicalPlan(fused_dag, self._op_map) + def _remove_output_depes(self, op: PhysicalOperator) -> None: + for input in op._input_dependencies: + input._output_dependencies = [] + self._remove_output_depes(input) + + def _update_output_depes(self, op: PhysicalOperator) -> None: + for input in op._input_dependencies: + input._output_dependencies.append(op) + self._update_output_depes(input) + def _fuse_map_operators_in_dag(self, dag: PhysicalOperator) -> MapOperator: """Starting at the given operator, traverses up the DAG of operators and recursively fuses compatible MapOperator -> MapOperator pairs. diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 3c20cf7abd7c..a9972c819d00 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -467,7 +467,9 @@ def test_read_map_batches_operator_fusion(ray_start_regular_shared, enable_optim assert physical_op.name == "ReadParquet->MapBatches()" assert isinstance(physical_op, MapOperator) assert len(physical_op.input_dependencies) == 1 - assert isinstance(physical_op.input_dependencies[0], InputDataBuffer) + input = physical_op.input_dependencies[0] + assert isinstance(input, InputDataBuffer) + assert physical_op in input.output_dependencies, input.output_dependencies def test_read_map_chain_operator_fusion(ray_start_regular_shared, enable_optimizer):