Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support failure node #840

Merged
merged 32 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
823c528
wip
kumare3 Jan 31, 2022
0508159
wip
kumare3 Jan 31, 2022
703b681
updated
kumare3 Feb 1, 2022
1a04695
Failure node added
kumare3 Feb 1, 2022
72d2e84
update
kumare3 Feb 2, 2022
15f5761
Merge branch 'master' into error-handler
kumare3 Feb 15, 2022
052768f
merged master
pingsutw Oct 24, 2023
69163c3
Merge branch 'master' into error-handler
kumare3 Oct 24, 2023
789f295
updated failed node
kumare3 Oct 24, 2023
1c3b6c5
error transformer
pingsutw Oct 25, 2023
12356aa
nit
pingsutw Oct 25, 2023
2bfed27
updated
kumare3 Oct 25, 2023
325bdf2
updated node id
kumare3 Oct 25, 2023
5de495b
wip
pingsutw Nov 14, 2023
c71fcf2
wip
pingsutw Nov 17, 2023
5a41510
Add FlyteError
pingsutw Nov 17, 2023
683964d
merged master
pingsutw Nov 17, 2023
c9b60dd
update tests
pingsutw Nov 17, 2023
c489c64
lint
pingsutw Nov 17, 2023
c83f625
debug
pingsutw Nov 17, 2023
b27f3b7
Merged master
pingsutw Nov 28, 2023
2582a8b
Merge branch 'master' of github.com:flyteorg/flytekit into error-handler
pingsutw Nov 28, 2023
38a67bf
Merge branch 'master' of github.com:flyteorg/flytekit into error-handler
pingsutw Nov 30, 2023
0e5b5dc
Merge branch 'master' of github.com:flyteorg/flytekit into error-handler
pingsutw Nov 30, 2023
c2c1feb
Serialize failure subworkflow node
pingsutw Dec 1, 2023
4559823
Fix lint
pingsutw Dec 1, 2023
7384315
fix tests
pingsutw Dec 1, 2023
2cb4530
fix tests
pingsutw Dec 1, 2023
3a43dc6
nit
pingsutw Dec 1, 2023
a6b9da5
lint
pingsutw Dec 1, 2023
8c075ef
merged master
pingsutw Dec 2, 2023
f6fa9b3
nit
pingsutw Dec 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload

from flytekit.core import constants as _common_constants
from flytekit.core.base_task import PythonTask
from flytekit.core.base_task import PythonTask, Task
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.condition import ConditionalSection, conditional
from flytekit.core.context_manager import (
Expand Down Expand Up @@ -49,6 +49,7 @@
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.documentation import Description, Documentation
from flytekit.types.error import FlyteError

GLOBAL_START_NODE = Node(
id=_common_constants.GLOBAL_INPUT_NODE_ID,
Expand Down Expand Up @@ -115,7 +116,7 @@
return _workflow_model.WorkflowMetadataDefaults(interruptible=self.interruptible)


def construct_input_promises(inputs: List[str]):
def construct_input_promises(inputs: List[str]) -> Dict[str, Promise]:
return {
input_name: Promise(var=input_name, val=NodeOutput(node=GLOBAL_START_NODE, var=input_name))
for input_name in inputs
Expand Down Expand Up @@ -181,6 +182,7 @@
workflow_metadata: WorkflowMetadata,
workflow_metadata_defaults: WorkflowMetadataDefaults,
python_interface: Interface,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
**kwargs,
):
Expand All @@ -190,9 +192,11 @@
self._python_interface = python_interface
self._interface = transform_interface_to_typed_interface(python_interface)
self._inputs: Dict[str, Promise] = {}
self._unbound_inputs: set = set()
self._unbound_inputs: typing.Set[Promise] = set()
self._nodes: List[Node] = []
self._output_bindings: List[_literal_models.Binding] = []
self._on_failure = on_failure
self._failure_node = None
self._docs = docs

if self._python_interface.docstring:
Expand Down Expand Up @@ -250,6 +254,14 @@
self.compile()
return self._nodes

@property
def on_failure(self) -> Optional[Union[WorkflowBase, Task]]:
return self._on_failure

@property
def failure_node(self) -> Optional[Node]:
return self._failure_node

def __repr__(self):
return (
f"WorkflowBase - {self._name} && "
Expand All @@ -275,6 +287,10 @@
try:
return flyte_entity_call_handler(self, *args, **input_kwargs)
except Exception as exc:
if self.on_failure:
if self.on_failure.python_interface and "err" in self.on_failure.python_interface.inputs:
input_kwargs["err"] = FlyteError(failed_node_id="", message=str(exc))
self.on_failure(**input_kwargs)
exc.args = (f"Encountered error while executing workflow '{self.name}':\n {exc}", *exc.args[1:])
raise exc

Expand Down Expand Up @@ -629,10 +645,11 @@

def __init__(
self,
workflow_function: Callable[..., Any],
workflow_function: Callable,
metadata: WorkflowMetadata,
default_metadata: WorkflowMetadataDefaults,
docstring: Optional[Docstring] = None,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
):
name, _, _, _ = extract_task_module(workflow_function)
Expand All @@ -648,6 +665,7 @@
workflow_metadata=metadata,
workflow_metadata_defaults=default_metadata,
python_interface=native_interface,
on_failure=on_failure,
docs=docs,
)
self.compiled = False
Expand All @@ -659,13 +677,30 @@
def task_name(self, t: PythonAutoContainerTask) -> str: # type: ignore
return f"{self.name}.{t.__module__}.{t.name}"

def _validate_add_on_failure_handler(self, ctx: FlyteContext, prefix: str, wf_args: Dict[str, Promise]):
# Compare
with FlyteContextManager.with_context(
ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self))
) as inner_comp_ctx:
# Now lets compile the failure-node if it exists
if self.on_failure:
c = wf_args.copy()
exception_scopes.user_entry_point(self.on_failure)(**c)
inner_nodes = None
if inner_comp_ctx.compilation_state and inner_comp_ctx.compilation_state.nodes:
inner_nodes = inner_comp_ctx.compilation_state.nodes
if not inner_nodes or len(inner_nodes) > 1:
raise AssertionError("Unable to compile failure node, only either a task or a workflow can be used")

Check warning on line 693 in flytekit/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/workflow.py#L693

Added line #L693 was not covered by tests
self._failure_node = inner_nodes[0]

def compile(self, **kwargs):
"""
Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics
a 'closure' in the traditional sense of the word.
"""
if self.compiled:
return

self.compiled = True
ctx = FlyteContextManager.current_context()
self._input_parameters = transform_inputs_to_parameters(ctx, self.python_interface)
Expand All @@ -691,6 +726,8 @@
logger.debug(f"WF {self.name} saving task {n.flyte_entity.name}")
self.add(n.flyte_entity)

self._validate_add_on_failure_handler(comp_ctx, comp_ctx.compilation_state.prefix + "f", input_kwargs)

# Iterate through the workflow outputs
bindings = []
output_names = list(self.interface.outputs.keys())
Expand Down Expand Up @@ -735,9 +772,10 @@
)
bindings.append(b)

# Save all the things necessary to create an SdkWorkflow, except for the missing project and domain
# Save all the things necessary to create an WorkflowTemplate, except for the missing project and domain
self._nodes = all_nodes
self._output_bindings = bindings

if not output_names:
return None
if len(output_names) == 1:
Expand All @@ -758,6 +796,7 @@
_workflow_function: None = ...,
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]:
...
Expand All @@ -768,6 +807,7 @@
_workflow_function: Callable[..., FuncOut],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
) -> Union[PythonFunctionWorkflow, Callable[..., FuncOut]]:
...
Expand All @@ -777,6 +817,7 @@
_workflow_function: Optional[Callable[..., Any]] = None,
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow, Callable[..., FuncOut]]:
"""
Expand Down Expand Up @@ -806,6 +847,8 @@
:param _workflow_function: This argument is implicitly passed and represents the decorated function.
:param failure_policy: Use the options in flytekit.WorkflowFailurePolicy
:param interruptible: Whether or not tasks launched from this workflow are by default interruptible
:param on_failure: Invoke this workflow or task on failure. The Workflow / task has to match the signature of
the current workflow, with an additional parameter called `error` Error
:param docs: Description entity for the workflow
"""

Expand All @@ -819,6 +862,7 @@
metadata=workflow_metadata,
default_metadata=workflow_metadata_defaults,
docstring=Docstring(callable_=fn),
on_failure=on_failure,
docs=docs,
)
update_wrapper(workflow_instance, fn)
Expand Down
9 changes: 5 additions & 4 deletions flytekit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
from flytekit.models.core import types as _core_types
from flytekit.models.types import Error
from flytekit.models.types import LiteralType as _LiteralType
from flytekit.models.types import OutputReference as _OutputReference
from flytekit.models.types import SchemaType as _SchemaType
Expand Down Expand Up @@ -709,7 +710,7 @@ def __init__(
schema: Schema = None,
union: Union = None,
none_type: Void = None,
error=None,
error: Error = None,
generic: Struct = None,
structured_dataset: StructuredDataset = None,
):
Expand All @@ -721,7 +722,7 @@ def __init__(
:param Binary binary:
:param Schema schema:
:param Void none_type:
:param error:
:param Error error:
:param google.protobuf.struct_pb2.Struct generic:
:param StructuredDataset structured_dataset:
"""
Expand Down Expand Up @@ -781,7 +782,7 @@ def none_type(self):
@property
def error(self):
"""
:rtype: TODO
:rtype: Error
"""
return self._error

Expand Down Expand Up @@ -825,7 +826,7 @@ def to_flyte_idl(self):
schema=self.schema.to_flyte_idl() if self.schema is not None else None,
union=self.union.to_flyte_idl() if self.union is not None else None,
none_type=self.none_type.to_flyte_idl() if self.none_type is not None else None,
error=self.error if self.error is not None else None,
error=self.error.to_flyte_idl() if self.error is not None else None,
generic=self.generic,
structured_dataset=self.structured_dataset.to_flyte_idl() if self.structured_dataset is not None else None,
)
Expand Down
10 changes: 7 additions & 3 deletions flytekit/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __init__(
"""
This is a oneof message, only one of the kwargs may be set, representing one of the Flyte types.

:param int simple: Enum type from SimpleType
:param SimpleType simple: Enum type from SimpleType
:param SchemaType schema: Type definition for a dataframe-like object.
:param LiteralType collection_type: For list-like objects, this is the type of each entry in the list.
:param LiteralType map_value_type: For map objects, this is the type of the value. The key must always be a
Expand Down Expand Up @@ -489,6 +489,10 @@ def __init__(self, failed_node_id: str, message: str):
def message(self) -> str:
return self._message

@property
def failed_node_id(self) -> str:
return self._failed_node_id

def to_flyte_idl(self) -> _types_pb2.Error:
return _types_pb2.Error(
message=self._message,
Expand All @@ -498,7 +502,7 @@ def to_flyte_idl(self) -> _types_pb2.Error:
@classmethod
def from_flyte_idl(cls, pb2_object: _types_pb2.Error) -> "Error":
"""
:param flyteidl.core.types.OutputReference pb2_object:
:rtype: OutputReference
:param flyteidl.core.types.Error pb2_object:
:rtype: Error
"""
return cls(failed_node_id=pb2_object.failed_node_id, message=pb2_object.message)
9 changes: 9 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,14 @@
sub_wfs.append(leaf_node.flyte_entity)
sub_wfs.extend([s for s in leaf_node.flyte_entity.sub_workflows.values()])

serialized_failure_node = None
if entity.failure_node:
serialized_failure_node = get_serializable(entity_mapping, settings, entity.failure_node, options)
if isinstance(entity.failure_node.flyte_entity, WorkflowBase):
sub_wf_spec = get_serializable(entity_mapping, settings, entity.failure_node.flyte_entity, options)
sub_wfs.append(sub_wf_spec.template)
sub_wfs.extend(sub_wf_spec.sub_workflows)

Check warning on line 307 in flytekit/tools/translator.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/translator.py#L305-L307

Added lines #L305 - L307 were not covered by tests

wf_id = _identifier_model.Identifier(
resource_type=_identifier_model.ResourceType.WORKFLOW,
project=settings.project,
Expand All @@ -312,6 +320,7 @@
interface=entity.interface,
nodes=serialized_nodes,
outputs=entity.output_bindings,
failure_node=serialized_failure_node,
)

return admin_workflow_models.WorkflowSpec(
Expand Down
12 changes: 12 additions & 0 deletions flytekit/types/error/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
Flytekit Error Type
==========================================================
.. currentmodule:: flytekit.types.error

.. autosummary::
:toctree: generated/

FlyteError
"""

from .error import FlyteError
58 changes: 58 additions & 0 deletions flytekit/types/error/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from dataclasses import dataclass
from typing import Type, TypeVar

from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models import types as _type_models
from flytekit.models.literals import Error, Literal, Scalar
from flytekit.models.types import LiteralType

T = TypeVar("T")


@dataclass
class FlyteError(DataClassJSONMixin):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i do not think we need this right, though lets keep it

"""
Special Task type that will be used in the failure node. Propeller will pass this error to failure task, so users
have to add an input with this type to the failure task.
"""

message: str
failed_node_id: str


class ErrorTransformer(TypeTransformer[FlyteError]):
"""
Enables converting a python type FlyteError to LiteralType.Error
"""

def __init__(self):
super().__init__(name="FlyteError", t=FlyteError)

def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(simple=_type_models.SimpleType.ERROR)

def to_literal(
self, ctx: FlyteContext, python_val: FlyteError, python_type: Type[T], expected: LiteralType
) -> Literal:
if type(python_val) != FlyteError:
raise TypeTransformerFailedError(

Check warning on line 41 in flytekit/types/error/error.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/error/error.py#L41

Added line #L41 was not covered by tests
f"Expected value of type {FlyteError} but got '{python_val}' of type {type(python_val)}"
)
return Literal(scalar=Scalar(error=Error(message=python_val.message, failed_node_id=python_val.failed_node_id)))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
if not (lv and lv.scalar and lv.scalar.error is not None):
raise TypeTransformerFailedError("Can only convert a generic literal to FlyteError")

Check warning on line 48 in flytekit/types/error/error.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/error/error.py#L48

Added line #L48 was not covered by tests
return FlyteError(message=lv.scalar.error.message, failed_node_id=lv.scalar.error.failed_node_id)

def guess_python_type(self, literal_type: LiteralType) -> Type[FlyteError]:
if literal_type.simple and literal_type.simple == _type_models.SimpleType.ERROR:
return FlyteError

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


TypeEngine.register(ErrorTransformer())
Loading