Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sdk): unblock valid topology. #8416

Merged
merged 18 commits into from
Dec 2, 2022
167 changes: 57 additions & 110 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,8 @@ def producer_op() -> str:
def dummy_op(msg: str = ''):
pass

with self.assertRaisesRegex(RuntimeError, r'Task'):
with self.assertRaisesRegex(
RuntimeError, r'Tasks cannot depend on upstream tasks inside'):

@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
Expand Down Expand Up @@ -475,7 +476,8 @@ def producer_op() -> str:
def dummy_op(msg: str = ''):
pass

with self.assertRaisesRegex(RuntimeError, r'Task'):
with self.assertRaisesRegex(
RuntimeError, r'Tasks cannot depend on upstream tasks inside'):

@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
Expand Down Expand Up @@ -515,7 +517,8 @@ def producer_op() -> str:
def dummy_op(msg: str = ''):
pass

with self.assertRaisesRegex(RuntimeError, r'Task'):
with self.assertRaisesRegex(
RuntimeError, r'Tasks cannot depend on upstream tasks inside'):

@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
Expand Down Expand Up @@ -1484,25 +1487,33 @@ def pipeline_with_input(boolean: bool = False):
.default_value.bool_value, True)


class ValidLegalTopologies(unittest.TestCase):
# helper component definitions for the ValidLegalTopologies tests
@dsl.component
def print_op(message: str):
print(message)

def test_inside_of_root_group_permitted(self):

@dsl.component
def print_op(message: str):
print(message)
@dsl.component
def return_1() -> int:
return 1

@dsl.component
def return_1() -> int:
return 1

@dsl.component
def args_generator_op() -> List[Dict[str, str]]:
return [{'A_a': '1', 'B_b': '2'}, {'A_a': '10', 'B_b': '20'}]


class ValidLegalTopologies(unittest.TestCase):
JOCSTAA marked this conversation as resolved.
Show resolved Hide resolved

def test_inside_of_root_group_permitted(self):

@dsl.pipeline()
def my_pipeline():
return_1_task = return_1()

one = print_op(message='1')
two = print_op(message='2')
three = print_op(message='3').after(one)
three = print_op(message=str(return_1_task.output))

with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
Expand All @@ -1511,15 +1522,8 @@ def my_pipeline():

def test_upstream_inside_deeper_condition_blocked(self):

with self.assertRaisesRegex(RuntimeError, r'Task'):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def return_1() -> int:
return 1
with self.assertRaisesRegex(
RuntimeError, r'Tasks cannot depend on upstream tasks inside'):

@dsl.pipeline()
def my_pipeline():
Expand All @@ -1538,22 +1542,14 @@ def my_pipeline():

def test_upstream_in_the_same_condition_permitted(self):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def return_1() -> int:
return 1

@dsl.pipeline()
def my_pipeline():
return_1_task = return_1()

with dsl.Condition(return_1_task.output == 1):
one = print_op(message='1')
one = return_1()
two = print_op(message='2')
three = print_op(message='3').after(one)
three = print_op(message=str(one.output))

with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
Expand All @@ -1562,14 +1558,6 @@ def my_pipeline():

def test_downstream_inside_deeper_condition_permitted(self):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def return_1() -> int:
return

@dsl.pipeline()
def my_pipeline():
return_1_task = return_1()
Expand All @@ -1587,15 +1575,8 @@ def my_pipeline():
def test_downstream_and_upstream_in_different_condition_on_same_level_blocked(
self):

with self.assertRaisesRegex(RuntimeError, r'Task'):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def return_1() -> int:
return 1
with self.assertRaisesRegex(
RuntimeError, r'Tasks cannot depend on upstream tasks inside'):

@dsl.pipeline()
def my_pipeline():
Expand All @@ -1615,24 +1596,16 @@ def my_pipeline():

def test_downstream_inside_deeper_nested_condition_permitted(self):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def return_1() -> int:
return 1

@dsl.pipeline()
def my_pipeline():
return_1_task = return_1()
return_1_task2 = return_1()

with dsl.Condition(return_1_task.output == 1):
one = print_op(message='1')
one = return_1()
with dsl.Condition(return_1_task2.output == 1):
two = print_op(message='2')
three = print_op(message='3').after(one)
three = print_op(message=str(one.output))

with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
Expand All @@ -1641,15 +1614,8 @@ def my_pipeline():

def test_upstream_inside_deeper_nested_condition_blocked(self):

with self.assertRaisesRegex(RuntimeError, r'Task'):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def return_1() -> int:
return 1
with self.assertRaisesRegex(
RuntimeError, r'Tasks cannot depend on upstream tasks inside'):

@dsl.pipeline()
def my_pipeline():
Expand All @@ -1668,14 +1634,6 @@ def my_pipeline():

def test_upstream_in_same_for_loop_with_downstream_permitted(self):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def args_generator_op() -> List[Dict[str, str]]:
return [{'A_a': '1', 'B_b': '2'}, {'A_a': '10', 'B_b': '20'}]

@dsl.pipeline()
def my_pipeline():
args_generator = args_generator_op()
Expand All @@ -1691,15 +1649,8 @@ def my_pipeline():

def test_downstream_not_in_same_for_loop_with_upstream_blocked(self):

with self.assertRaisesRegex(RuntimeError, r'Task'):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def args_generator_op() -> List[Dict[str, str]]:
return [{'A_a': '1', 'B_b': '2'}, {'A_a': '10', 'B_b': '20'}]
with self.assertRaisesRegex(
RuntimeError, r'Tasks cannot depend on upstream tasks inside'):

@dsl.pipeline()
def my_pipeline():
Expand All @@ -1717,15 +1668,8 @@ def my_pipeline():
def test_downstream_not_in_same_for_loop_with_upstream_seperate_blocked(
self):

with self.assertRaisesRegex(RuntimeError, r'Task'):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def args_generator_op() -> List[Dict[str, str]]:
return [{'A_a': '1', 'B_b': '2'}, {'A_a': '10', 'B_b': '20'}]
with self.assertRaisesRegex(
RuntimeError, r'Tasks cannot depend on upstream tasks inside'):

@dsl.pipeline()
def my_pipeline():
Expand All @@ -1744,15 +1688,10 @@ def my_pipeline():

def test_downstream_not_in_same_for_loop_with_upstream_nested_blocked(self):

with self.assertRaisesRegex(RuntimeError, r'Task'):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def args_generator_op() -> List[Dict[str, str]]:
return [{'A_a': '1', 'B_b': '2'}, {'A_a': '10', 'B_b': '20'}]
with self.assertRaisesRegex(
RuntimeError,
r'Downstream task cannot depend on an upstream task while in a nested'
):

@dsl.pipeline()
def my_pipeline():
Expand All @@ -1771,14 +1710,6 @@ def my_pipeline():

def test_downstream_in_condition_nested_in_a_for_loop(self):

@dsl.component
def print_op(message: str):
print(message)

@dsl.component
def return_1() -> int:
return 1

@dsl.pipeline()
def my_pipeline():
return_1_task = return_1()
Expand All @@ -1793,6 +1724,22 @@ def my_pipeline():
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)

JOCSTAA marked this conversation as resolved.
Show resolved Hide resolved
def test_downstream_in_a_for_loop_nested_in_a_condition(self):

@dsl.pipeline()
def my_pipeline():
return_1_task = return_1()

with dsl.Condition(return_1_task.output == 1):
one = print_op(message='1')
with dsl.ParallelFor([1, 2, 3]):
two = print_op(message='2').after(one)

with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)


if __name__ == '__main__':
unittest.main()
64 changes: 32 additions & 32 deletions sdk/python/kfp/compiler/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Utility methods for compiler implementation that is IR-agnostic."""

import collections
from copy import deepcopy
from typing import Dict, List, Mapping, Set, Tuple, Union

from kfp.components import for_loop
Expand Down Expand Up @@ -427,44 +428,43 @@ def get_dependencies(
task2=task,
)

# ParralelFor Check
for parent in task_name_to_parent_groups[upstream_task.name]:
parent = group_name_to_group.get(parent, None)
if isinstance(parent, tasks_group.ParallelFor):
exception = True
if parent.name in task_name_to_parent_groups[task.name]:
exception = False
idx = task_name_to_parent_groups[task.name].index(
parent.name)
cnt = 0
for ancestors in task_name_to_parent_groups[
task.name][idx:]:
ancestors = group_name_to_group.get(ancestors, None)
if isinstance(ancestors, tasks_group.ParallelFor):
cnt += 1
if cnt > 1:
exception = True
break
# uncommon upstream ancestor check
uncommon_upstream_groups = deepcopy(upstream_groups)
uncommon_upstream_groups.remove(
upstream_task.name
) # because a task's `upstream_groups` contains the task's name
if uncommon_upstream_groups:
dependent_group = group_name_to_group.get(
upstream_groups[0], None)
JOCSTAA marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(dependent_group, tasks_group.ExitHandler):
task_group_type = 'an ' + tasks_group.ExitHandler.__name__
JOCSTAA marked this conversation as resolved.
Show resolved Hide resolved

if exception:
raise RuntimeError(
f'Tasks cannot depend on upstream tasks inside a ParallelFor. Task {task.name} depends on upstream task {upstream_task.name}.'
)
elif isinstance(dependent_group, tasks_group.Condition):
task_group_type = 'a ' + tasks_group.Condition.__name__

# Condition check
dependent_group = group_name_to_group.get(upstream_groups[0], None)
if isinstance(dependent_group, tasks_group.Condition):
raise RuntimeError(
f'Tasks cannot depend on upstream tasks inside a Condition that is not a common ancestor of both tasks. Task {task.name} depends on upstream task {upstream_task.name}.'
)
else:
task_group_type = 'a ' + tasks_group.ParallelFor.__name__

# ExitHandler check
dependent_group = group_name_to_group.get(upstream_groups[0], None)
if isinstance(dependent_group, tasks_group.ExitHandler):
raise RuntimeError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No action required / not your code: I want to note here that I'm not sure if this should use a RuntimeError. This is a runtime error in the sense that it results an ambiguous runtime topology, but it's not a true "runtime" error at the time it's raised.

This file used RuntimeError before this PR and, since this PR actually reduces the set of topologies for which this error would be raised, we don't necessarily need to reconsider this in this PR.

Furthermore, an Exception is usually used when the error is attributed to user code, whereas an Error is usually used when the error is attributed to something else, such as an environment. In this case, this is user code.

For this reason, I think it would make sense for this to be a custom InvalidTopologyException or something similar.

Relatedly, some of the ValueErrors from pipeline_task.py now become RuntimeErrors in this PR, so perhaps that is a reason to consider this in the short term.

f'Tasks cannot depend on upstream tasks inside a Exithandler that is not a common ancestor of both tasks. Task {task.name} depends on upstream task {upstream_task.name}.'
f'Tasks cannot depend on upstream tasks inside {task_group_type} that is not a common ancestor of both tasks. Task {task.name} depends on upstream task {upstream_task.name}.'
JOCSTAA marked this conversation as resolved.
Show resolved Hide resolved
)

# ParallelFor Nested Check
# if there is a ParallelFor group among the upstream task's parent groups and there also exists a ParallelFor in the uncommon ancestors of the downstream task, then a nested for loop exists in the DAG
JOCSTAA marked this conversation as resolved.
Show resolved Hide resolved
upstream_parent_tasks = task_name_to_parent_groups[
upstream_task.name]
for group in downstream_groups:
if isinstance(
group_name_to_group.get(group, None),
tasks_group.ParallelFor):
for parent_task in upstream_parent_tasks:
if isinstance(
group_name_to_group.get(parent_task, None),
tasks_group.ParallelFor):
raise RuntimeError(
f'Downstream task cannot depend on an upstream task while in a nested {tasks_group.ParallelFor.__name__} group. Task {task.name} depends on upstream task {upstream_task.name}, while {group} is nested in {parent_task}'
JOCSTAA marked this conversation as resolved.
Show resolved Hide resolved
)

dependencies[downstream_groups[0]].add(upstream_groups[0])

return dependencies