diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 1e2af4b3be..44458a53d2 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -138,6 +138,11 @@ def python_interface(self): def construct_node_metadata(self) -> NodeMetadata: # TODO: add support for other Flyte entities + return NodeMetadata( + name=self.name, + ) + + def construct_sub_node_metadata(self) -> NodeMetadata: nm = super().construct_node_metadata() nm._name = self.name return nm diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index d4d8629265..868f657610 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -619,7 +619,7 @@ def get_serializable_array_node_map_task( ) node = workflow_model.Node( id=entity.name, - metadata=entity.construct_node_metadata(), + metadata=entity.construct_sub_node_metadata(), inputs=node.bindings, upstream_node_ids=[], output_aliases=[], diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index f025929a13..d4281227db 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -1,4 +1,5 @@ import functools +from datetime import timedelta import os import tempfile import typing @@ -386,7 +387,12 @@ def test_serialization_metadata2(serialization_settings): def t1(a: int) -> typing.Optional[int]: return a + 1 - arraynode_maptask = map_task(t1, min_success_ratio=0.9, concurrency=10, metadata=TaskMetadata(retries=2, interruptible=True)) + arraynode_maptask = map_task( + t1, + min_success_ratio=0.9, + concurrency=10, + metadata=TaskMetadata(retries=2, interruptible=True, timeout=timedelta(seconds=10)) + ) assert arraynode_maptask.metadata.interruptible @workflow @@ -402,9 +408,8 @@ def wf1(x: typing.List[int]): od = OrderedDict() wf_spec = get_serializable(od, serialization_settings, wf) - assert arraynode_maptask.construct_node_metadata().interruptible array_node = wf_spec.template.nodes[0] - assert array_node.metadata.interruptible + assert array_node.metadata.timeout == timedelta() assert array_node.array_node._min_success_ratio == 0.9 assert array_node.array_node._parallelism == 10 assert not array_node.array_node._is_original_sub_node_interface @@ -412,6 +417,7 @@ def wf1(x: typing.List[int]): task_spec = od[arraynode_maptask] assert task_spec.template.metadata.retries.retries == 2 assert task_spec.template.metadata.interruptible + assert task_spec.template.metadata.timeout == timedelta(seconds=10) wf1_spec = get_serializable(od, serialization_settings, wf1) array_node = wf1_spec.template.nodes[0]