From 2da69b660b749f9fe98f21d513c8b0c6cad06966 Mon Sep 17 00:00:00 2001
From: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
Date: Fri, 8 Jan 2021 09:57:38 -0800
Subject: [PATCH] Annotations - dynamic conditional (#319)

---
 flytekit/annotated/dynamic_workflow_task.py   |  9 ++
 .../annotated/test_dynamic_conditional.py     | 91 +++++++++++++++++++
 2 files changed, 100 insertions(+)
 create mode 100644 tests/flytekit/unit/annotated/test_dynamic_conditional.py

diff --git a/flytekit/annotated/dynamic_workflow_task.py b/flytekit/annotated/dynamic_workflow_task.py
index 9875e7dca7..7b6ae08156 100644
--- a/flytekit/annotated/dynamic_workflow_task.py
+++ b/flytekit/annotated/dynamic_workflow_task.py
@@ -78,6 +78,15 @@ def aggregate(tasks, workflows, node) -> None:
                 workflows.add(node.workflow_node.sdk_workflow)
                 for sub_node in node.workflow_node.sdk_workflow.nodes:
                     DynamicWorkflowTask.aggregate(tasks, workflows, sub_node)
+        if node.branch_node is not None:
+            if node.branch_node.if_else.case.then_node is not None:
+                DynamicWorkflowTask.aggregate(tasks, workflows, node.branch_node.if_else.case.then_node)
+            if node.branch_node.if_else.other:
+                for oth in node.branch_node.if_else.other:
+                    if oth.then_node:
+                        DynamicWorkflowTask.aggregate(tasks, workflows, oth.then_node)
+            if node.branch_node.if_else.else_node is not None:
+                DynamicWorkflowTask.aggregate(tasks, workflows, node.branch_node.if_else.else_node)
 
     def execute(self, **kwargs) -> Any:
         """
diff --git a/tests/flytekit/unit/annotated/test_dynamic_conditional.py b/tests/flytekit/unit/annotated/test_dynamic_conditional.py
new file mode 100644
index 0000000000..a312ea8554
--- /dev/null
+++ b/tests/flytekit/unit/annotated/test_dynamic_conditional.py
@@ -0,0 +1,91 @@
+import typing
+from datetime import datetime
+from random import seed
+
+from flytekit import dynamic, task, workflow
+from flytekit.annotated import context_manager
+from flytekit.annotated.condition import conditional
+from flytekit.annotated.context_manager import ExecutionState, Image, ImageConfig
+
+# seed random number generator
+seed(datetime.now().microsecond)
+
+
+def test_dynamic_conditional():
+    @task
+    def split(in1: typing.List[int]) -> (typing.List[int], typing.List[int], int):
+        return in1[0 : int(len(in1) / 2)], in1[int(len(in1) / 2) + 1 :], len(in1) / 2
+
+    # One sample implementation for merging. In a more real world example, this might merge file streams and only load
+    # chunks into the memory.
+    @task
+    def merge(x: typing.List[int], y: typing.List[int]) -> typing.List[int]:
+        n1 = len(x)
+        n2 = len(y)
+        result = list[int]()
+        i = 0
+        j = 0
+
+        # Traverse both array
+        while i < n1 and j < n2:
+            # Check if current element of first array is smaller than current element of second array. If yes,
+            # store first array element and increment first array index. Otherwise do same with second array
+            if x[i] < y[j]:
+                result.append(x[i])
+                i = i + 1
+            else:
+                result.append(y[j])
+                j = j + 1
+
+        # Store remaining elements of first array
+        while i < n1:
+            result.append(x[i])
+            i = i + 1
+
+        # Store remaining elements of second array
+        while j < n2:
+            result.append(y[j])
+            j = j + 1
+
+        return result
+
+    # This runs the sorting completely locally. It's faster and more efficient to do so if the entire list fits in memory.
+    @task
+    def merge_sort_locally(in1: typing.List[int]) -> typing.List[int]:
+        return sorted(in1)
+
+    @task
+    def also_merge_sort_locally(in1: typing.List[int]) -> typing.List[int]:
+        return sorted(in1)
+
+    @dynamic
+    def merge_sort_remotely(in1: typing.List[int]) -> typing.List[int]:
+        x, y, new_count = split(in1=in1)
+        sorted_x = merge_sort(in1=x, count=new_count)
+        sorted_y = merge_sort(in1=y, count=new_count)
+        return merge(x=sorted_x, y=sorted_y)
+
+    @workflow
+    def merge_sort(in1: typing.List[int], count: int) -> typing.List[int]:
+        return (
+            conditional("terminal_case")
+            .if_(count < 500)
+            .then(merge_sort_locally(in1=in1))
+            .elif_(count < 1000)
+            .then(also_merge_sort_locally(in1=in1))
+            .else_()
+            .then(merge_sort_remotely(in1=in1))
+        )
+
+    with context_manager.FlyteContext.current_context().new_registration_settings(
+        registration_settings=context_manager.RegistrationSettings(
+            project="test_proj",
+            domain="test_domain",
+            version="abc",
+            image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
+            env={},
+        )
+    ) as ctx:
+        with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
+            dynamic_job_spec = merge_sort_remotely.compile_into_workflow(ctx, in1=[2, 3, 4, 5])
+            assert len(dynamic_job_spec.tasks) == 5