Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support overriding node metadata for array node #2865

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
flyte_entity_call_handler,
translate_inputs_to_literals,
)
from flytekit.core.task import TaskMetadata
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
Expand All @@ -41,7 +40,7 @@ def __init__(
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: Optional[float] = None,
metadata: Optional[Union[_workflow_model.NodeMetadata, TaskMetadata]] = None,
metadata: Optional[_workflow_model.NodeMetadata] = None,
):
"""
:param target: The target Flyte entity to map over
Expand All @@ -53,7 +52,7 @@ def __init__(
min_success_ratio
:param min_success_ratio: The minimum ratio of successful executions.
:param execution_mode: The execution mode for propeller to use when handling ArrayNode
:param metadata: The metadata for the underlying entity
:param metadata: The metadata for the underlying node
"""
from flytekit.remote import FlyteLaunchPlan

Expand All @@ -62,6 +61,7 @@ def __init__(
self._execution_mode = execution_mode
self.id = target.name
self._bindings = bindings or []
self.metadata = metadata

if min_successes is not None:
self._min_successes = min_successes
Expand Down Expand Up @@ -92,22 +92,15 @@ def __init__(
else:
raise ValueError("No interface found for the target entity.")

self.metadata = None
if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE:
raise ValueError("Only execution version 1 is supported for LaunchPlans.")
if metadata:
if isinstance(metadata, _workflow_model.NodeMetadata):
self.metadata = metadata
else:
raise TypeError("Invalid metadata for LaunchPlan. Should be NodeMetadata.")
else:
raise ValueError(f"Only LaunchPlans are supported for now, but got {type(target)}")

def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
# Part of SupportsNodeCreation interface
# TODO - include passed in metadata
return _workflow_model.NodeMetadata(name=self.target.name)
return self.metadata or _workflow_model.NodeMetadata(name=self.target.name)
Copy link
Contributor Author

@pvditt pvditt Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these changes aren't really necessary - just cleaning up code. Don't even support directly passing in node metadata atm.


@property
def name(self) -> str:
Expand Down
4 changes: 3 additions & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def get_serializable_node(
if isinstance(entity.flyte_entity, ArrayNode):
node_model = workflow_model.Node(
id=_dnsify(entity.id),
metadata=entity.flyte_entity.construct_node_metadata(),
metadata=entity.metadata,
inputs=entity.bindings,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
Expand Down Expand Up @@ -587,6 +587,8 @@ def get_serializable_array_node(
options: Optional[Options] = None,
) -> ArrayNodeModel:
array_node = node.flyte_entity
# pass in parent node metadata to be set for subnode
array_node.metadata = node.metadata
return ArrayNodeModel(
node=get_serializable_node(entity_mapping, settings, array_node, options=options),
parallelism=array_node.concurrency,
Expand Down
54 changes: 38 additions & 16 deletions tests/flytekit/unit/core/test_array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ def grandparent_wf() -> typing.List[int]:
return grandparent_wf


def get_grandparent_wf_with_overrides(serialization_settings):
@workflow
def grandparent_wf_with_overrides() -> typing.List[int]:
return array_node(
lp, concurrency=10, min_success_ratio=0.9
)(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9]).with_overrides(cache=True, cache_version="1.0")

return grandparent_wf_with_overrides


def get_grandparent_remote_wf(serialization_settings):
serialized = OrderedDict()
lp_model = get_serializable(serialized, serialization_settings, lp)
Expand All @@ -73,33 +83,34 @@ def grandparent_remote_wf() -> typing.List[int]:


@pytest.mark.parametrize(
"target",
("target", "overrides_metadata"),
[
get_grandparent_wf,
get_grandparent_remote_wf,
(get_grandparent_wf, False),
(get_grandparent_remote_wf, False),
(get_grandparent_wf_with_overrides, True),
],
)
def test_lp_serialization(target, serialization_settings):
def test_lp_serialization(target, overrides_metadata, serialization_settings):
wf_spec = get_serializable(OrderedDict(), serialization_settings, target(serialization_settings))
assert len(wf_spec.template.nodes) == 1

top_level = wf_spec.template.nodes[0]
assert top_level.inputs[0].var == "a"
assert len(top_level.inputs[0].binding.collection.bindings) == 3
for binding in top_level.inputs[0].binding.collection.bindings:
parent_node = wf_spec.template.nodes[0]
assert parent_node.inputs[0].var == "a"
assert len(parent_node.inputs[0].binding.collection.bindings) == 3
for binding in parent_node.inputs[0].binding.collection.bindings:
assert binding.scalar.primitive.integer is not None
assert top_level.inputs[1].var == "b"
for binding in top_level.inputs[1].binding.collection.bindings:
assert parent_node.inputs[1].var == "b"
for binding in parent_node.inputs[1].binding.collection.bindings:
assert (binding.scalar.union is not None or
binding.scalar.primitive.integer is not None or
binding.scalar.primitive.string_value is not None)
assert len(top_level.inputs[1].binding.collection.bindings) == 3
assert top_level.inputs[2].var == "c"
assert len(top_level.inputs[2].binding.collection.bindings) == 3
for binding in top_level.inputs[2].binding.collection.bindings:
assert len(parent_node.inputs[1].binding.collection.bindings) == 3
assert parent_node.inputs[2].var == "c"
assert len(parent_node.inputs[2].binding.collection.bindings) == 3
for binding in parent_node.inputs[2].binding.collection.bindings:
assert binding.scalar.primitive.integer is not None

serialized_array_node = top_level.array_node
serialized_array_node = parent_node.array_node
assert (
serialized_array_node.node.workflow_node.launchplan_ref.resource_type
== identifier_models.ResourceType.LAUNCH_PLAN
Expand All @@ -112,7 +123,18 @@ def test_lp_serialization(target, serialization_settings):
assert serialized_array_node._parallelism == 10

subnode = serialized_array_node.node
assert subnode.inputs == top_level.inputs
assert subnode.inputs == parent_node.inputs

if overrides_metadata:
assert parent_node.metadata.cacheable
assert parent_node.metadata.cache_version == "1.0"
assert subnode.metadata.cacheable
assert subnode.metadata.cache_version == "1.0"
else:
assert not parent_node.metadata.cacheable
assert not parent_node.metadata.cache_version
assert not subnode.metadata.cacheable
assert not subnode.metadata.cache_version


@pytest.mark.parametrize(
Expand Down
Loading