diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index ae3aa7db4e..1ad74a56bf 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -7,6 +7,7 @@ from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.context_manager import Image, ImageConfig +from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.node_creation import create_node from flytekit.core.task import task from flytekit.core.workflow import workflow @@ -17,13 +18,22 @@ def test_normal_task(): def t1(a: str) -> str: return a + " world" + @dynamic + def my_subwf(a: int) -> typing.List[str]: + s = [] + for i in range(a): + s.append(t1(a=str(i))) + return s + @workflow - def my_wf(a: str) -> str: + def my_wf(a: str) -> (str, typing.List[str]): t1_node = create_node(t1, a=a) - return t1_node.o0 + dyn_node = create_node(my_subwf, a=3) + return t1_node.o0, dyn_node.o0 - r = my_wf(a="hello") + r, x = my_wf(a="hello") assert r == "hello world" + assert x == ["0 world", "1 world", "2 world"] serialization_settings = context_manager.SerializationSettings( project="test_proj", @@ -33,8 +43,8 @@ def my_wf(a: str) -> str: env={}, ) sdk_wf = get_serializable(OrderedDict(), serialization_settings, my_wf) - assert len(sdk_wf.nodes) == 1 - assert len(sdk_wf.outputs) == 1 + assert len(sdk_wf.nodes) == 2 + assert len(sdk_wf.outputs) == 2 @task def t2():