Skip to content

Commit

Permalink
Backport to v1.13 - Set map task metadata only for subnode (#2993)
Browse files Browse the repository at this point in the history
* Set map task metadata only for subnode (#2979)

* set metadata for subnode only

Signed-off-by: Paul Dittamo <[email protected]>

* update unit test

Signed-off-by: Paul Dittamo <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>

* Set map task metadata only for subnode (#2979)

* set metadata for subnode only

Signed-off-by: Paul Dittamo <[email protected]>

* update unit test

Signed-off-by: Paul Dittamo <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Paul Dittamo <[email protected]>
  • Loading branch information
eapolinario and pvditt authored Dec 9, 2024
1 parent 01c51b9 commit 272c9c5
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
5 changes: 5 additions & 0 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ def python_interface(self):

def construct_node_metadata(self) -> NodeMetadata:
# TODO: add support for other Flyte entities
return NodeMetadata(
name=self.name,
)

def construct_sub_node_metadata(self) -> NodeMetadata:
nm = super().construct_node_metadata()
nm._name = self.name
return nm
Expand Down
2 changes: 1 addition & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def get_serializable_array_node_map_task(
)
node = workflow_model.Node(
id=entity.name,
metadata=entity.construct_node_metadata(),
metadata=entity.construct_sub_node_metadata(),
inputs=node.bindings,
upstream_node_ids=[],
output_aliases=[],
Expand Down
12 changes: 9 additions & 3 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
from datetime import timedelta
import os
import typing
from collections import OrderedDict
Expand Down Expand Up @@ -377,7 +378,12 @@ def test_serialization_metadata2(serialization_settings):
def t1(a: int) -> int:
return a + 1

arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2, interruptible=True))
arraynode_maptask = map_task(
t1,
min_success_ratio=0.9,
concurrency=10,
metadata=TaskMetadata(retries=2, interruptible=True, timeout=timedelta(seconds=10))
)
assert arraynode_maptask.metadata.interruptible

@workflow
Expand All @@ -387,11 +393,11 @@ def wf(x: typing.List[int]):
od = OrderedDict()
wf_spec = get_serializable(od, serialization_settings, wf)

assert arraynode_maptask.construct_node_metadata().interruptible
assert wf_spec.template.nodes[0].metadata.interruptible
assert wf_spec.template.nodes[0].metadata.timeout == timedelta()
task_spec = od[arraynode_maptask]
assert task_spec.template.metadata.retries.retries == 2
assert task_spec.template.metadata.interruptible
assert task_spec.template.metadata.timeout == timedelta(seconds=10)


def test_serialization_extended_resources(serialization_settings):
Expand Down

0 comments on commit 272c9c5

Please sign in to comment.