diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 108b323a48..0f25374717 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -288,7 +288,8 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis except Exception as exc: if self.on_failure: if self.on_failure.python_interface and "err" in self.on_failure.python_interface.inputs: - input_kwargs["err"] = FlyteError(failed_node_id="", message=str(exc)) + id = self.failure_node.id if self.failure_node else "" + input_kwargs["err"] = FlyteError(failed_node_id=id, message=str(exc)) self.on_failure(**input_kwargs) raise exc diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index f0ba150f73..60daf80af9 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -2,6 +2,7 @@ import sys import typing from collections import OrderedDict +from unittest.mock import patch import pytest from typing_extensions import Annotated # type: ignore @@ -15,6 +16,7 @@ from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow from flytekit.exceptions.user import FlyteValidationException, FlyteValueException from flytekit.tools.translator import get_serializable +from flytekit.types.error.error import FlyteError default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -51,7 +53,7 @@ def t1(a: int) -> typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", @workflow(interruptible=True, failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) def wf(a: int) -> typing.Tuple[str, str]: x, y = t1(a=a) - u, v = t1(a=x) + _, v = t1(a=x) return y, v wf_spec = get_serializable(OrderedDict(), serialization_settings, wf) @@ -435,3 +437,47 @@ def wf(): t4() assert ctx.compilation_state is None + + +@patch("builtins.print") +def test_failure_node_local_execution(mock_print): + @task + def clean_up(name: str, err: typing.Optional[FlyteError] = None): + print(f"Deleting cluster {name} due to {err}") + print("This is err:", str(err)) + + @task + def create_cluster(name: str): + print(f"Creating cluster: {name}") + + @task + def delete_cluster(name: str, err: typing.Optional[FlyteError] = None): + print(f"Deleting cluster {name}") + print(err) + + @task + def t1(a: int, b: str): + print(f"{a} {b}") + raise ValueError("Fail!") + + @workflow(on_failure=clean_up) + def wf(name: str = "flyteorg"): + c = create_cluster(name=name) + t = t1(a=1, b="2") + d = delete_cluster(name=name) + c >> t >> d + + with pytest.raises(ValueError): + wf() + + # Adjusted the error message to match the one in the failure + expected_error_message = str( + FlyteError(message="Error encountered while executing 'wf':\n Fail!", failed_node_id="fn0") + ) + + assert mock_print.call_count > 0 + + mock_print.assert_any_call("Creating cluster: flyteorg") + mock_print.assert_any_call("1 2") + mock_print.assert_any_call(f"Deleting cluster flyteorg due to {expected_error_message}") + mock_print.assert_any_call("This is err:", expected_error_message)