Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[easy] Show failed_node_id in failure node local execution #2334

Merged
merged 4 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis
except Exception as exc:
if self.on_failure:
if self.on_failure.python_interface and "err" in self.on_failure.python_interface.inputs:
input_kwargs["err"] = FlyteError(failed_node_id="", message=str(exc))
id = self.failure_node.id if self.failure_node else ""
input_kwargs["err"] = FlyteError(failed_node_id=id, message=str(exc))
self.on_failure(**input_kwargs)
raise exc

Expand Down
48 changes: 47 additions & 1 deletion tests/flytekit/unit/core/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import typing
from collections import OrderedDict
from unittest.mock import patch

import pytest
from typing_extensions import Annotated # type: ignore
Expand All @@ -15,6 +16,7 @@
from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow
from flytekit.exceptions.user import FlyteValidationException, FlyteValueException
from flytekit.tools.translator import get_serializable
from flytekit.types.error.error import FlyteError

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
Expand Down Expand Up @@ -51,7 +53,7 @@ def t1(a: int) -> typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c",
@workflow(interruptible=True, failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE)
def wf(a: int) -> typing.Tuple[str, str]:
x, y = t1(a=a)
u, v = t1(a=x)
_, v = t1(a=x)
return y, v

wf_spec = get_serializable(OrderedDict(), serialization_settings, wf)
Expand Down Expand Up @@ -435,3 +437,47 @@ def wf():
t4()

assert ctx.compilation_state is None


@patch("builtins.print")
def test_failure_node_local_execution(mock_print):
@task
def clean_up(name: str, err: typing.Optional[FlyteError] = None):
print(f"Deleting cluster {name} due to {err}")
print("This is err:", str(err))

@task
def create_cluster(name: str):
print(f"Creating cluster: {name}")

@task
def delete_cluster(name: str, err: typing.Optional[FlyteError] = None):
print(f"Deleting cluster {name}")
print(err)

@task
def t1(a: int, b: str):
print(f"{a} {b}")
raise ValueError("Fail!")

@workflow(on_failure=clean_up)
def wf(name: str = "flyteorg"):
c = create_cluster(name=name)
t = t1(a=1, b="2")
d = delete_cluster(name=name)
c >> t >> d

with pytest.raises(ValueError):
wf()

# Adjusted the error message to match the one in the failure
expected_error_message = str(
FlyteError(message="Error encountered while executing 'wf':\n Fail!", failed_node_id="fn0")
)

assert mock_print.call_count > 0

mock_print.assert_any_call("Creating cluster: flyteorg")
mock_print.assert_any_call("1 2")
mock_print.assert_any_call(f"Deleting cluster flyteorg due to {expected_error_message}")
mock_print.assert_any_call("This is err:", expected_error_message)
Loading