Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sdk): unblock valid topology. #8416

Merged
merged 18 commits into from
Dec 2, 2022
7 changes: 6 additions & 1 deletion sdk/python/kfp/components/pipeline_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,13 @@ def my_pipeline():
task1 = my_component(text='1st task')
task2 = my_component(text='2nd task').after(task1)
"""
from kfp.components.tasks_group import TasksGroupType

for task in tasks:
if task.parent_task_group is not self.parent_task_group:
if task.parent_task_group is not self.parent_task_group and task.parent_task_group.group_type in [
TasksGroupType.CONDITION, TasksGroupType.FOR_LOOP,
JOCSTAA marked this conversation as resolved.
Show resolved Hide resolved
JOCSTAA marked this conversation as resolved.
Show resolved Hide resolved
TasksGroupType.EXIT_HANDLER
]:
raise ValueError(
f'Cannot use .after() across inner pipelines or DSL control flow features. Tried to set {self.name} after {task.name}, but these tasks do not belong to the same pipeline or are not enclosed in the same control flow content manager.'
JOCSTAA marked this conversation as resolved.
Show resolved Hide resolved
)
Expand Down
26 changes: 25 additions & 1 deletion sdk/python/kfp/components/pipeline_task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def my_pipeline():
second_exit_task = print_op(message='Second exit task.')
with dsl.ExitHandler(second_exit_task):
x = print_op(message='Inside second exit handler.')
x.after(first_exit_task)
x.after(first_print_op)
JOCSTAA marked this conversation as resolved.
Show resolved Hide resolved

with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
Expand Down Expand Up @@ -411,6 +411,30 @@ def my_pipeline():
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)

def test_outside_of_condition_exception_permitted(self):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def return_1() -> int:
return 1

@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
return_1_task = return_1()

one = print_op(message='1')
with dsl.Condition(return_1_task.output == 1):
two = print_op(message='2')
three = print_op(message='3').after(one)

with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)

def test_inside_of_condition_permitted(self):

@dsl.component
Expand Down