diff --git a/sdk/RELEASE.md b/sdk/RELEASE.md
index b67d95a9062..664a2e87c51 100644
--- a/sdk/RELEASE.md
+++ b/sdk/RELEASE.md
@@ -7,6 +7,7 @@
 ## Deprecations
 
 ## Bug fixes and other changes
+* Unblock valid topologies [\#8416](https://github.com/kubeflow/pipelines/pull/8416)
 
 ## Documentation updates
 
diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py
index 3756221fca8..a1c1d272ba1 100644
--- a/sdk/python/kfp/compiler/compiler_test.py
+++ b/sdk/python/kfp/compiler/compiler_test.py
@@ -437,7 +437,7 @@ def dummy_op(msg: str = ''):
 
         with self.assertRaisesRegex(
                 RuntimeError,
-                'Task dummy-op cannot dependent on any task inside the group:'):
+                r'Tasks cannot depend on an upstream task inside'):
 
             @dsl.pipeline(name='test-pipeline')
             def my_pipeline(val: bool):
@@ -479,7 +479,7 @@ def dummy_op(msg: str = ''):
 
         with self.assertRaisesRegex(
                 RuntimeError,
-                'Task dummy-op cannot dependent on any task inside the group:'):
+                r'Tasks cannot depend on an upstream task inside'):
 
             @dsl.pipeline(name='test-pipeline')
             def my_pipeline(val: bool):
@@ -521,7 +521,7 @@ def dummy_op(msg: str = ''):
 
         with self.assertRaisesRegex(
                 RuntimeError,
-                'Task dummy-op cannot dependent on any task inside the group:'):
+                r'Tasks cannot depend on an upstream task inside'):
 
             @dsl.pipeline(name='test-pipeline')
             def my_pipeline(val: bool):
@@ -1490,5 +1490,408 @@ def pipeline_with_input(boolean: bool = False):
                 .default_value.bool_value, True)
 
 
+# helper component definitions for the ValidLegalTopologies tests
+@dsl.component
+def print_op(message: str):
+    print(message)
+
+
+@dsl.component
+def return_1() -> int:
+    return 1
+
+
+@dsl.component
+def args_generator_op() -> List[Dict[str, str]]:
+    return [{'A_a': '1', 'B_b': '2'}, {'A_a': '10', 'B_b': '20'}]
+
+
+class TestValidLegalTopologies(unittest.TestCase):
+
+    def test_inside_of_root_group_permitted(self):
+
+        @dsl.pipeline()
+        def my_pipeline():
+            return_1_task = return_1()
+
+            one = print_op(message='1')
+            two = print_op(message='2')
+            three = print_op(message=str(return_1_task.output))
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_upstream_inside_deeper_condition_blocked(self):
+
+        with self.assertRaisesRegex(
+                RuntimeError,
+                r'Tasks cannot depend on an upstream task inside'):
+
+            @dsl.pipeline()
+            def my_pipeline():
+                return_1_task = return_1()
+
+                one = print_op(message='1')
+                with dsl.Condition(return_1_task.output == 1):
+                    two = print_op(message='2')
+
+                three = print_op(message='3').after(two)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_upstream_in_the_same_condition_permitted(self):
+
+        @dsl.pipeline()
+        def my_pipeline():
+            return_1_task = return_1()
+
+            with dsl.Condition(return_1_task.output == 1):
+                one = return_1()
+                two = print_op(message='2')
+                three = print_op(message=str(one.output))
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_downstream_inside_deeper_condition_permitted(self):
+
+        @dsl.pipeline()
+        def my_pipeline():
+            return_1_task = return_1()
+
+            one = print_op(message='1')
+            with dsl.Condition(return_1_task.output == 1):
+                two = print_op(message='2')
+                three = print_op(message='3').after(one)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_downstream_and_upstream_in_different_condition_on_same_level_blocked(
+            self):
+
+        with self.assertRaisesRegex(
+                RuntimeError,
+                r'Tasks cannot depend on an upstream task inside'):
+
+            @dsl.pipeline()
+            def my_pipeline():
+                return_1_task = return_1()
+
+                one = print_op(message='1')
+                with dsl.Condition(return_1_task.output == 1):
+                    two = print_op(message='2')
+
+                with dsl.Condition(return_1_task.output == 1):
+                    three = print_op(message='3').after(two)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_downstream_inside_deeper_nested_condition_permitted(self):
+
+        @dsl.pipeline()
+        def my_pipeline():
+            return_1_task = return_1()
+            return_1_task2 = return_1()
+
+            with dsl.Condition(return_1_task.output == 1):
+                one = return_1()
+                with dsl.Condition(return_1_task2.output == 1):
+                    two = print_op(message='2')
+                    three = print_op(message=str(one.output))
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_upstream_inside_deeper_nested_condition_blocked(self):
+
+        with self.assertRaisesRegex(
+                RuntimeError,
+                r'Tasks cannot depend on an upstream task inside'):
+
+            @dsl.pipeline()
+            def my_pipeline():
+                return_1_task = return_1()
+
+                with dsl.Condition(return_1_task.output == 1):
+                    one = print_op(message='1')
+                    with dsl.Condition(return_1_task.output == 1):
+                        two = print_op(message='2')
+                    three = print_op(message='3').after(two)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_upstream_in_same_for_loop_with_downstream_permitted(self):
+
+        @dsl.pipeline()
+        def my_pipeline():
+            args_generator = args_generator_op()
+
+            with dsl.ParallelFor(args_generator.output):
+                one = print_op(message='1')
+                two = print_op(message='3').after(one)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_downstream_not_in_same_for_loop_with_upstream_blocked(self):
+
+        with self.assertRaisesRegex(
+                RuntimeError,
+                r'Tasks cannot depend on an upstream task inside'):
+
+            @dsl.pipeline()
+            def my_pipeline():
+                args_generator = args_generator_op()
+
+                with dsl.ParallelFor(args_generator.output):
+                    one = print_op(message='1')
+                two = print_op(message='3').after(one)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_downstream_not_in_same_for_loop_with_upstream_separate_blocked(
+            self):
+
+        with self.assertRaisesRegex(
+                RuntimeError,
+                r'Tasks cannot depend on an upstream task inside'):
+
+            @dsl.pipeline()
+            def my_pipeline():
+                args_generator = args_generator_op()
+
+                with dsl.ParallelFor(args_generator.output):
+                    one = print_op(message='1')
+
+                with dsl.ParallelFor(args_generator.output):
+                    two = print_op(message='3').after(one)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_downstream_not_in_same_for_loop_with_upstream_nested_blocked(self):
+
+        with self.assertRaisesRegex(
+                RuntimeError,
+                r'Downstream tasks in a nested ParallelFor group cannot depend on an upstream task in a shallower ParallelFor group.'
+        ):
+
+            @dsl.pipeline()
+            def my_pipeline():
+                args_generator = args_generator_op()
+
+                with dsl.ParallelFor(args_generator.output):
+                    one = print_op(message='1')
+
+                    with dsl.ParallelFor(args_generator.output):
+                        two = print_op(message='3').after(one)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_downstream_in_condition_nested_in_a_for_loop(self):
+
+        @dsl.pipeline()
+        def my_pipeline():
+            return_1_task = return_1()
+
+            with dsl.ParallelFor([1, 2, 3]):
+                one = print_op(message='1')
+                with dsl.Condition(return_1_task.output == 1):
+                    two = print_op(message='2').after(one)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_downstream_in_a_for_loop_nested_in_a_condition(self):
+
+        @dsl.pipeline()
+        def my_pipeline():
+            return_1_task = return_1()
+
+            with dsl.Condition(return_1_task.output == 1):
+                one = print_op(message='1')
+                with dsl.ParallelFor([1, 2, 3]):
+                    two = print_op(message='2').after(one)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_downstream_in_a_nested_for_loop_not_related_to_upstream(self):
+
+        @dsl.pipeline()
+        def my_pipeline():
+            return_1_task = return_1()
+
+            with dsl.ParallelFor([1, 2, 3]):
+                one = print_op(message='1')
+                with dsl.ParallelFor([1, 2, 3]):
+                    two = print_op(message='2').after(return_1_task)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+
+class TestCannotUseAfterCrossDAG(unittest.TestCase):
+
+    def test_inner_task_prevented(self):
+        with self.assertRaisesRegex(RuntimeError, r'Task'):
+
+            @dsl.component
+            def print_op(message: str):
+                print(message)
+
+            @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+            def my_pipeline():
+                first_exit_task = print_op(message='First exit task.')
+
+                with dsl.ExitHandler(first_exit_task):
+                    first_print_op = print_op(
+                        message='Inside first exit handler.')
+
+                second_exit_task = print_op(message='Second exit task.')
+                with dsl.ExitHandler(second_exit_task):
+                    print_op(message='Inside second exit handler.').after(
+                        first_print_op)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_exit_handler_task_prevented(self):
+        with self.assertRaisesRegex(RuntimeError, r'Task'):
+
+            @dsl.component
+            def print_op(message: str):
+                print(message)
+
+            @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+            def my_pipeline():
+                first_exit_task = print_op(message='First exit task.')
+
+                with dsl.ExitHandler(first_exit_task):
+                    first_print_op = print_op(
+                        message='Inside first exit handler.')
+
+                second_exit_task = print_op(message='Second exit task.')
+                with dsl.ExitHandler(second_exit_task):
+                    x = print_op(message='Inside second exit handler.')
+                    x.after(first_print_op)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_within_same_exit_handler_permitted(self):
+
+        @dsl.component
+        def print_op(message: str):
+            print(message)
+
+        @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+        def my_pipeline():
+            first_exit_task = print_op(message='First exit task.')
+
+            with dsl.ExitHandler(first_exit_task):
+                first_print_op = print_op(
+                    message='First task inside first exit handler.')
+                second_print_op = print_op(
+                    message='Second task inside first exit handler.').after(
+                        first_print_op)
+
+            second_exit_task = print_op(message='Second exit task.')
+            with dsl.ExitHandler(second_exit_task):
+                print_op(message='Inside second exit handler.')
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_outside_of_condition_blocked(self):
+        with self.assertRaisesRegex(RuntimeError, r'Task'):
+
+            @dsl.component
+            def print_op(message: str):
+                print(message)
+
+            @dsl.component
+            def return_1() -> int:
+                return 1
+
+            @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+            def my_pipeline():
+                return_1_task = return_1()
+
+                with dsl.Condition(return_1_task.output == 1):
+                    one = print_op(message='1')
+                    two = print_op(message='2')
+                three = print_op(message='3').after(one)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_inside_of_condition_permitted(self):
+
+        @dsl.component
+        def print_op(message: str):
+            print(message)
+
+        @dsl.component
+        def return_1() -> int:
+            return 1
+
+        @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+        def my_pipeline():
+            return_1_task = return_1()
+
+            with dsl.Condition(return_1_task.output == '1'):
+                one = print_op(message='1')
+                two = print_op(message='2').after(one)
+            three = print_op(message='3')
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/sdk/python/kfp/compiler/compiler_utils.py b/sdk/python/kfp/compiler/compiler_utils.py
index 18e329afcde..6efd9617069 100644
--- a/sdk/python/kfp/compiler/compiler_utils.py
+++ b/sdk/python/kfp/compiler/compiler_utils.py
@@ -14,6 +14,7 @@
 """Utility methods for compiler implementation that is IR-agnostic."""
 
 import collections
+from copy import deepcopy
 from typing import Dict, List, Mapping, Set, Tuple, Union
 
 from kfp.components import for_loop
@@ -427,14 +428,42 @@ def get_dependencies(
                 task2=task,
             )
 
-        # a task cannot depend on a task created in a for loop group since individual PipelineTask variables are reassigned after each loop iteration
-        dependent_group = group_name_to_group.get(upstream_groups[0], None)
-        if isinstance(dependent_group,
-                      (tasks_group.ParallelFor, tasks_group.Condition,
-                       tasks_group.ExitHandler)):
+        # uncommon upstream ancestor check
+        uncommon_upstream_groups = deepcopy(upstream_groups)
+        uncommon_upstream_groups.remove(
+            upstream_task.name
+        )  # because a task's `upstream_groups` list also contains the task's own name
+        if uncommon_upstream_groups:
+            dependent_group = group_name_to_group.get(
+                uncommon_upstream_groups[0], None)
+            if isinstance(dependent_group, tasks_group.ExitHandler):
+                task_group_type = 'an ' + tasks_group.ExitHandler.__name__
+
+            elif isinstance(dependent_group, tasks_group.Condition):
+                task_group_type = 'a ' + tasks_group.Condition.__name__
+
+            else:
+                task_group_type = 'a ' + tasks_group.ParallelFor.__name__
+
             raise RuntimeError(
-                f'Task {task.name} cannot dependent on any task inside'
-                f' the group: {upstream_groups[0]}.')
+                f'Tasks cannot depend on an upstream task inside {task_group_type} that is not a common ancestor of both tasks. Task {task.name} depends on upstream task {upstream_task.name}.'
+            )
+
+        # nested ParallelFor check: if a ParallelFor appears among the
+        # upstream task's parent groups and a ParallelFor also appears among
+        # the downstream task's uncommon ancestor groups, the dependency
+        # crosses a nested for loop boundary in the DAG
+        upstream_parent_tasks = task_name_to_parent_groups[
+            upstream_task.name]
+        for group in downstream_groups:
+            if isinstance(
+                    group_name_to_group.get(group, None),
+                    tasks_group.ParallelFor):
+                for parent_task in upstream_parent_tasks:
+                    if isinstance(
+                            group_name_to_group.get(parent_task, None),
+                            tasks_group.ParallelFor):
+                        raise RuntimeError(
+                            f'Downstream tasks in a nested {tasks_group.ParallelFor.__name__} group cannot depend on an upstream task in a shallower {tasks_group.ParallelFor.__name__} group. Task {task.name} depends on upstream task {upstream_task.name}, while {group} is nested in {parent_task}.'
+                        )
 
         dependencies[downstream_groups[0]].add(upstream_groups[0])
diff --git a/sdk/python/kfp/components/pipeline_task.py b/sdk/python/kfp/components/pipeline_task.py
index 32399355318..03414daa1f7 100644
--- a/sdk/python/kfp/components/pipeline_task.py
+++ b/sdk/python/kfp/components/pipeline_task.py
@@ -463,10 +463,6 @@ def my_pipeline():
             task2 = my_component(text='2nd task').after(task1)
         """
         for task in tasks:
-            if task.parent_task_group is not self.parent_task_group:
-                raise ValueError(
-                    f'Cannot use .after() across inner pipelines or DSL control flow features. Tried to set {self.name} after {task.name}, but these tasks do not belong to the same pipeline or are not enclosed in the same control flow content manager.'
-                )
             self._task_spec.dependent_tasks.append(task.name)
         return self
 
diff --git a/sdk/python/kfp/components/pipeline_task_test.py b/sdk/python/kfp/components/pipeline_task_test.py
index f1d49784dec..672fb85d43f 100644
--- a/sdk/python/kfp/components/pipeline_task_test.py
+++ b/sdk/python/kfp/components/pipeline_task_test.py
@@ -13,14 +13,10 @@
 # limitations under the License.
 """Tests for kfp.components.pipeline_task."""
-import os
-import tempfile
 import textwrap
 import unittest
 
 from absl.testing import parameterized
-from kfp import compiler
-from kfp import dsl
 from kfp.components import pipeline_task
 from kfp.components import placeholders
 from kfp.components import structures
 
@@ -297,136 +293,5 @@ def test_set_display_name(self):
         self.assertEqual('test_name', task._task_spec.display_name)
 
 
-class TestCannotUseAfterCrossDAG(unittest.TestCase):
-
-    def test_inner_task_prevented(self):
-        with self.assertRaisesRegex(ValueError,
-                                    r'Cannot use \.after\(\) across'):
-
-            @dsl.component
-            def print_op(message: str):
-                print(message)
-
-            @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
-            def my_pipeline():
-                first_exit_task = print_op(message='First exit task.')
-
-                with dsl.ExitHandler(first_exit_task):
-                    first_print_op = print_op(
-                        message='Inside first exit handler.')
-
-                second_exit_task = print_op(message='Second exit task.')
-                with dsl.ExitHandler(second_exit_task):
-                    print_op(message='Inside second exit handler.').after(
-                        first_print_op)
-
-            with tempfile.TemporaryDirectory() as tempdir:
-                package_path = os.path.join(tempdir, 'pipeline.yaml')
-                compiler.Compiler().compile(
-                    pipeline_func=my_pipeline, package_path=package_path)
-
-    def test_exit_handler_task_prevented(self):
-        with self.assertRaisesRegex(ValueError,
-                                    r'Cannot use \.after\(\) across'):
-
-            @dsl.component
-            def print_op(message: str):
-                print(message)
-
-            @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
-            def my_pipeline():
-                first_exit_task = print_op(message='First exit task.')
-
-                with dsl.ExitHandler(first_exit_task):
-                    first_print_op = print_op(
-                        message='Inside first exit handler.')
-
-                second_exit_task = print_op(message='Second exit task.')
-                with dsl.ExitHandler(second_exit_task):
-                    x = print_op(message='Inside second exit handler.')
-                    x.after(first_exit_task)
-
-            with tempfile.TemporaryDirectory() as tempdir:
-                package_path = os.path.join(tempdir, 'pipeline.yaml')
-                compiler.Compiler().compile(
-                    pipeline_func=my_pipeline, package_path=package_path)
-
-    def test_within_same_exit_handler_permitted(self):
-
-        @dsl.component
-        def print_op(message: str):
-            print(message)
-
-        @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
-        def my_pipeline():
-            first_exit_task = print_op(message='First exit task.')
-
-            with dsl.ExitHandler(first_exit_task):
-                first_print_op = print_op(
-                    message='First task inside first exit handler.')
-                second_print_op = print_op(
-                    message='Second task inside first exit handler.').after(
-                        first_print_op)
-
-            second_exit_task = print_op(message='Second exit task.')
-            with dsl.ExitHandler(second_exit_task):
-                print_op(message='Inside second exit handler.')
-
-        with tempfile.TemporaryDirectory() as tempdir:
-            package_path = os.path.join(tempdir, 'pipeline.yaml')
-            compiler.Compiler().compile(
-                pipeline_func=my_pipeline, package_path=package_path)
-
-    def test_outside_of_condition_blocked(self):
-        with self.assertRaisesRegex(ValueError,
-                                    r'Cannot use \.after\(\) across'):
-
-            @dsl.component
-            def print_op(message: str):
-                print(message)
-
-            @dsl.component
-            def return_1() -> int:
-                return 1
-
-            @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
-            def my_pipeline():
-                return_1_task = return_1()
-
-                with dsl.Condition(return_1_task.output == 1):
-                    one = print_op(message='1')
-                    two = print_op(message='2')
-                three = print_op(message='3').after(one)
-
-            with tempfile.TemporaryDirectory() as tempdir:
-                package_path = os.path.join(tempdir, 'pipeline.yaml')
-                compiler.Compiler().compile(
-                    pipeline_func=my_pipeline, package_path=package_path)
-
-    def test_inside_of_condition_permitted(self):
-
-        @dsl.component
-        def print_op(message: str):
-            print(message)
-
-        @dsl.component
-        def return_1() -> int:
-            return 1
-
-        @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
-        def my_pipeline():
-            return_1_task = return_1()
-
-            with dsl.Condition(return_1_task.output == '1'):
-                one = print_op(message='1')
-                two = print_op(message='2').after(one)
-            three = print_op(message='3')
-
-        with tempfile.TemporaryDirectory() as tempdir:
-            package_path = os.path.join(tempdir, 'pipeline.yaml')
-            compiler.Compiler().compile(
-                pipeline_func=my_pipeline, package_path=package_path)
-
-
 if __name__ == '__main__':
     unittest.main()
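
Reviewer note, illustration only and not part of the patch: a minimal sketch of the dependency rule the new compiler_utils checks enforce, reusing the same component patterns as the tests above. The pipeline names and the output path below are invented for this example.

from kfp import compiler, dsl


@dsl.component
def return_1() -> int:
    return 1


@dsl.component
def print_op(message: str):
    print(message)


# Permitted after this change: the downstream task is nested deeper than its
# upstream task, so every group enclosing the upstream task is a common
# ancestor of both tasks.
@dsl.pipeline(name='downstream-deeper-permitted')
def downstream_deeper():
    upstream = return_1()
    with dsl.Condition(upstream.output == 1):
        print_op(message=str(upstream.output))


# Still blocked: the upstream task sits inside a Condition that does not
# enclose the downstream task, so compilation raises the new RuntimeError.
@dsl.pipeline(name='upstream-deeper-blocked')
def upstream_deeper():
    flag = return_1()
    with dsl.Condition(flag.output == 1):
        inner = print_op(message='inside')
    print_op(message='outside').after(inner)


compiler.Compiler().compile(
    pipeline_func=downstream_deeper, package_path='permitted.yaml')
# compiler.Compiler().compile(
#     pipeline_func=upstream_deeper,
#     package_path='blocked.yaml')  # raises RuntimeError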