diff --git a/flytekit/annotated/dynamic_workflow_task.py b/flytekit/annotated/dynamic_workflow_task.py index 9875e7dca7..7b6ae08156 100644 --- a/flytekit/annotated/dynamic_workflow_task.py +++ b/flytekit/annotated/dynamic_workflow_task.py @@ -78,6 +78,15 @@ def aggregate(tasks, workflows, node) -> None: workflows.add(node.workflow_node.sdk_workflow) for sub_node in node.workflow_node.sdk_workflow.nodes: DynamicWorkflowTask.aggregate(tasks, workflows, sub_node) + if node.branch_node is not None: + if node.branch_node.if_else.case.then_node is not None: + DynamicWorkflowTask.aggregate(tasks, workflows, node.branch_node.if_else.case.then_node) + if node.branch_node.if_else.other: + for oth in node.branch_node.if_else.other: + if oth.then_node: + DynamicWorkflowTask.aggregate(tasks, workflows, oth.then_node) + if node.branch_node.if_else.else_node is not None: + DynamicWorkflowTask.aggregate(tasks, workflows, node.branch_node.if_else.else_node) def execute(self, **kwargs) -> Any: """ diff --git a/tests/flytekit/unit/annotated/test_dynamic_conditional.py b/tests/flytekit/unit/annotated/test_dynamic_conditional.py new file mode 100644 index 0000000000..a312ea8554 --- /dev/null +++ b/tests/flytekit/unit/annotated/test_dynamic_conditional.py @@ -0,0 +1,91 @@ +import typing +from datetime import datetime +from random import seed + +from flytekit import dynamic, task, workflow +from flytekit.annotated import context_manager +from flytekit.annotated.condition import conditional +from flytekit.annotated.context_manager import ExecutionState, Image, ImageConfig + +# seed random number generator +seed(datetime.now().microsecond) + + +def test_dynamic_conditional(): + @task + def split(in1: typing.List[int]) -> (typing.List[int], typing.List[int], int): + return in1[0 : int(len(in1) / 2)], in1[int(len(in1) / 2) + 1 :], len(in1) / 2 + + # One sample implementation for merging. In a more real world example, this might merge file streams and only load + # chunks into the memory. + @task + def merge(x: typing.List[int], y: typing.List[int]) -> typing.List[int]: + n1 = len(x) + n2 = len(y) + result = list[int]() + i = 0 + j = 0 + + # Traverse both array + while i < n1 and j < n2: + # Check if current element of first array is smaller than current element of second array. If yes, + # store first array element and increment first array index. Otherwise do same with second array + if x[i] < y[j]: + result.append(x[i]) + i = i + 1 + else: + result.append(y[j]) + j = j + 1 + + # Store remaining elements of first array + while i < n1: + result.append(x[i]) + i = i + 1 + + # Store remaining elements of second array + while j < n2: + result.append(y[j]) + j = j + 1 + + return result + + # This runs the sorting completely locally. It's faster and more efficient to do so if the entire list fits in memory. + @task + def merge_sort_locally(in1: typing.List[int]) -> typing.List[int]: + return sorted(in1) + + @task + def also_merge_sort_locally(in1: typing.List[int]) -> typing.List[int]: + return sorted(in1) + + @dynamic + def merge_sort_remotely(in1: typing.List[int]) -> typing.List[int]: + x, y, new_count = split(in1=in1) + sorted_x = merge_sort(in1=x, count=new_count) + sorted_y = merge_sort(in1=y, count=new_count) + return merge(x=sorted_x, y=sorted_y) + + @workflow + def merge_sort(in1: typing.List[int], count: int) -> typing.List[int]: + return ( + conditional("terminal_case") + .if_(count < 500) + .then(merge_sort_locally(in1=in1)) + .elif_(count < 1000) + .then(also_merge_sort_locally(in1=in1)) + .else_() + .then(merge_sort_remotely(in1=in1)) + ) + + with context_manager.FlyteContext.current_context().new_registration_settings( + registration_settings=context_manager.RegistrationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + ) as ctx: + with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx: + dynamic_job_spec = merge_sort_remotely.compile_into_workflow(ctx, in1=[2, 3, 4, 5]) + assert len(dynamic_job_spec.tasks) == 5