Skip to content

Commit

Permalink
Add support failure node (#840)
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
kumare3 and pingsutw authored Dec 12, 2023
1 parent 530ad26 commit 613a655
Show file tree
Hide file tree
Showing 10 changed files with 244 additions and 19 deletions.
54 changes: 49 additions & 5 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload

from flytekit.core import constants as _common_constants
from flytekit.core.base_task import PythonTask
from flytekit.core.base_task import PythonTask, Task
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.condition import ConditionalSection, conditional
from flytekit.core.context_manager import (
Expand Down Expand Up @@ -49,6 +49,7 @@
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.documentation import Description, Documentation
from flytekit.types.error import FlyteError

GLOBAL_START_NODE = Node(
id=_common_constants.GLOBAL_INPUT_NODE_ID,
Expand Down Expand Up @@ -115,7 +116,7 @@ def to_flyte_model(self):
return _workflow_model.WorkflowMetadataDefaults(interruptible=self.interruptible)


def construct_input_promises(inputs: List[str]):
def construct_input_promises(inputs: List[str]) -> Dict[str, Promise]:
return {
input_name: Promise(var=input_name, val=NodeOutput(node=GLOBAL_START_NODE, var=input_name))
for input_name in inputs
Expand Down Expand Up @@ -181,6 +182,7 @@ def __init__(
workflow_metadata: WorkflowMetadata,
workflow_metadata_defaults: WorkflowMetadataDefaults,
python_interface: Interface,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
**kwargs,
):
Expand All @@ -190,9 +192,11 @@ def __init__(
self._python_interface = python_interface
self._interface = transform_interface_to_typed_interface(python_interface)
self._inputs: Dict[str, Promise] = {}
self._unbound_inputs: set = set()
self._unbound_inputs: typing.Set[Promise] = set()
self._nodes: List[Node] = []
self._output_bindings: List[_literal_models.Binding] = []
self._on_failure = on_failure
self._failure_node = None
self._docs = docs

if self._python_interface.docstring:
Expand Down Expand Up @@ -250,6 +254,14 @@ def nodes(self) -> List[Node]:
self.compile()
return self._nodes

@property
def on_failure(self) -> Optional[Union[WorkflowBase, Task]]:
    """
    The workflow or task registered as this workflow's failure handler, or None when no handler was given.
    """
    return self._on_failure

@property
def failure_node(self) -> Optional[Node]:
    """
    The node stored for the failure handler, or None if no failure node has been set.
    """
    return self._failure_node

def __repr__(self):
return (
f"WorkflowBase - {self._name} && "
Expand All @@ -275,6 +287,10 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis
try:
return flyte_entity_call_handler(self, *args, **input_kwargs)
except Exception as exc:
if self.on_failure:
if self.on_failure.python_interface and "err" in self.on_failure.python_interface.inputs:
input_kwargs["err"] = FlyteError(failed_node_id="", message=str(exc))
self.on_failure(**input_kwargs)
exc.args = (f"Encountered error while executing workflow '{self.name}':\n {exc}", *exc.args[1:])
raise exc

Expand Down Expand Up @@ -629,10 +645,11 @@ class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver):

def __init__(
self,
workflow_function: Callable[..., Any],
workflow_function: Callable,
metadata: WorkflowMetadata,
default_metadata: WorkflowMetadataDefaults,
docstring: Optional[Docstring] = None,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
):
name, _, _, _ = extract_task_module(workflow_function)
Expand All @@ -648,6 +665,7 @@ def __init__(
workflow_metadata=metadata,
workflow_metadata_defaults=default_metadata,
python_interface=native_interface,
on_failure=on_failure,
docs=docs,
)
self.compiled = False
Expand All @@ -659,13 +677,30 @@ def function(self):
def task_name(self, t: PythonAutoContainerTask) -> str: # type: ignore
return f"{self.name}.{t.__module__}.{t.name}"

def _validate_add_on_failure_handler(self, ctx: FlyteContext, prefix: str, wf_args: Dict[str, Promise]):
    """
    Compile the ``on_failure`` handler, if one is set, and store the resulting node as the failure node.

    The handler is invoked inside a nested compilation state (built from ``prefix`` with this workflow as the
    task resolver), receiving a shallow copy of ``wf_args`` as keyword arguments. That invocation must leave
    exactly one node in the nested compilation state; the node is saved in ``self._failure_node``.

    :param ctx: Context from which the nested compilation context is derived.
    :param prefix: Prefix passed to the nested ``CompilationState``.
    :param wf_args: Workflow input promises keyed by input name, forwarded to the handler.
    :raises AssertionError: If invoking the handler produced no node, or more than one node.
    """
    with FlyteContextManager.with_context(
        ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self))
    ) as inner_comp_ctx:
        # Nothing to compile when the workflow has no failure handler.
        if not self.on_failure:
            return

        # Copy so the handler call cannot mutate the caller's mapping of input promises.
        handler_inputs = wf_args.copy()
        exception_scopes.user_entry_point(self.on_failure)(**handler_inputs)

        state = inner_comp_ctx.compilation_state
        produced_nodes = state.nodes if state and state.nodes else None

        # A single task or workflow call yields exactly one node; anything else is rejected.
        if not produced_nodes or len(produced_nodes) > 1:
            raise AssertionError("Unable to compile failure node, only either a task or a workflow can be used")
        self._failure_node = produced_nodes[0]

def compile(self, **kwargs):
"""
Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics
a 'closure' in the traditional sense of the word.
"""
if self.compiled:
return

self.compiled = True
ctx = FlyteContextManager.current_context()
self._input_parameters = transform_inputs_to_parameters(ctx, self.python_interface)
Expand All @@ -691,6 +726,8 @@ def compile(self, **kwargs):
logger.debug(f"WF {self.name} saving task {n.flyte_entity.name}")
self.add(n.flyte_entity)

self._validate_add_on_failure_handler(comp_ctx, comp_ctx.compilation_state.prefix + "f", input_kwargs)

# Iterate through the workflow outputs
bindings = []
output_names = list(self.interface.outputs.keys())
Expand Down Expand Up @@ -735,9 +772,10 @@ def compile(self, **kwargs):
)
bindings.append(b)

# Save all the things necessary to create an SdkWorkflow, except for the missing project and domain
# Save all the things necessary to create a WorkflowTemplate, except for the missing project and domain
self._nodes = all_nodes
self._output_bindings = bindings

if not output_names:
return None
if len(output_names) == 1:
Expand All @@ -758,6 +796,7 @@ def workflow(
_workflow_function: None = ...,
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]:
...
Expand All @@ -768,6 +807,7 @@ def workflow(
_workflow_function: Callable[..., FuncOut],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
) -> Union[PythonFunctionWorkflow, Callable[..., FuncOut]]:
...
Expand All @@ -777,6 +817,7 @@ def workflow(
_workflow_function: Optional[Callable[..., Any]] = None,
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow, Callable[..., FuncOut]]:
"""
Expand Down Expand Up @@ -806,6 +847,8 @@ def workflow(
:param _workflow_function: This argument is implicitly passed and represents the decorated function.
:param failure_policy: Use the options in flytekit.WorkflowFailurePolicy
:param interruptible: Whether or not tasks launched from this workflow are by default interruptible
:param on_failure: Invoke this workflow or task on failure. The workflow / task has to match the signature of
the current workflow, and may take an additional parameter called `err` of type FlyteError
:param docs: Description entity for the workflow
"""

Expand All @@ -819,6 +862,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow:
metadata=workflow_metadata,
default_metadata=workflow_metadata_defaults,
docstring=Docstring(callable_=fn),
on_failure=on_failure,
docs=docs,
)
update_wrapper(workflow_instance, fn)
Expand Down
9 changes: 5 additions & 4 deletions flytekit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
from flytekit.models.core import types as _core_types
from flytekit.models.types import Error
from flytekit.models.types import LiteralType as _LiteralType
from flytekit.models.types import OutputReference as _OutputReference
from flytekit.models.types import SchemaType as _SchemaType
Expand Down Expand Up @@ -709,7 +710,7 @@ def __init__(
schema: Schema = None,
union: Union = None,
none_type: Void = None,
error=None,
error: Error = None,
generic: Struct = None,
structured_dataset: StructuredDataset = None,
):
Expand All @@ -721,7 +722,7 @@ def __init__(
:param Binary binary:
:param Schema schema:
:param Void none_type:
:param error:
:param Error error:
:param google.protobuf.struct_pb2.Struct generic:
:param StructuredDataset structured_dataset:
"""
Expand Down Expand Up @@ -781,7 +782,7 @@ def none_type(self):
@property
def error(self):
    """
    The error value carried by this scalar, as stored in ``self._error``.

    :rtype: Error
    """
    return self._error

Expand Down Expand Up @@ -825,7 +826,7 @@ def to_flyte_idl(self):
schema=self.schema.to_flyte_idl() if self.schema is not None else None,
union=self.union.to_flyte_idl() if self.union is not None else None,
none_type=self.none_type.to_flyte_idl() if self.none_type is not None else None,
error=self.error if self.error is not None else None,
error=self.error.to_flyte_idl() if self.error is not None else None,
generic=self.generic,
structured_dataset=self.structured_dataset.to_flyte_idl() if self.structured_dataset is not None else None,
)
Expand Down
10 changes: 7 additions & 3 deletions flytekit/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __init__(
"""
This is a oneof message, only one of the kwargs may be set, representing one of the Flyte types.
:param int simple: Enum type from SimpleType
:param SimpleType simple: Enum type from SimpleType
:param SchemaType schema: Type definition for a dataframe-like object.
:param LiteralType collection_type: For list-like objects, this is the type of each entry in the list.
:param LiteralType map_value_type: For map objects, this is the type of the value. The key must always be a
Expand Down Expand Up @@ -489,6 +489,10 @@ def __init__(self, failed_node_id: str, message: str):
def message(self) -> str:
return self._message

@property
def failed_node_id(self) -> str:
    """
    The failed-node identifier stored on this error model.
    """
    return self._failed_node_id

def to_flyte_idl(self) -> _types_pb2.Error:
return _types_pb2.Error(
message=self._message,
Expand All @@ -498,7 +502,7 @@ def to_flyte_idl(self) -> _types_pb2.Error:
@classmethod
def from_flyte_idl(cls, pb2_object: _types_pb2.Error) -> "Error":
    """
    Build an instance from the protobuf message by copying its ``failed_node_id`` and ``message`` fields.

    :param flyteidl.core.types.Error pb2_object:
    :rtype: Error
    """
    node_id = pb2_object.failed_node_id
    text = pb2_object.message
    return cls(failed_node_id=node_id, message=text)
9 changes: 9 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,14 @@ def get_serializable_workflow(
sub_wfs.append(leaf_node.flyte_entity)
sub_wfs.extend([s for s in leaf_node.flyte_entity.sub_workflows.values()])

serialized_failure_node = None
if entity.failure_node:
serialized_failure_node = get_serializable(entity_mapping, settings, entity.failure_node, options)
if isinstance(entity.failure_node.flyte_entity, WorkflowBase):
sub_wf_spec = get_serializable(entity_mapping, settings, entity.failure_node.flyte_entity, options)
sub_wfs.append(sub_wf_spec.template)
sub_wfs.extend(sub_wf_spec.sub_workflows)

wf_id = _identifier_model.Identifier(
resource_type=_identifier_model.ResourceType.WORKFLOW,
project=settings.project,
Expand All @@ -310,6 +318,7 @@ def get_serializable_workflow(
interface=entity.interface,
nodes=serialized_nodes,
outputs=entity.output_bindings,
failure_node=serialized_failure_node,
)

return admin_workflow_models.WorkflowSpec(
Expand Down
12 changes: 12 additions & 0 deletions flytekit/types/error/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
Flytekit Error Type
==========================================================
.. currentmodule:: flytekit.types.error
.. autosummary::
:toctree: generated/
FlyteError
"""

from .error import FlyteError
58 changes: 58 additions & 0 deletions flytekit/types/error/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from dataclasses import dataclass
from typing import Type, TypeVar

from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models import types as _type_models
from flytekit.models.literals import Error, Literal, Scalar
from flytekit.models.types import LiteralType

T = TypeVar("T")


@dataclass
class FlyteError(DataClassJSONMixin):
    """
    Error value handed to a failure node.

    Propeller will pass this error to the failure task, so users have to add an input with this type to the
    failure task.
    """

    # Text of the error message.
    message: str
    # Identifier of the node that failed.
    failed_node_id: str


class ErrorTransformer(TypeTransformer[FlyteError]):
    """
    Enables converting the python type FlyteError to and from a literal of simple type ERROR.
    """

    def __init__(self):
        super().__init__(name="FlyteError", t=FlyteError)

    def get_literal_type(self, t: Type[T]) -> LiteralType:
        """
        Return the literal type for FlyteError: a simple type of ``SimpleType.ERROR``.
        """
        return LiteralType(simple=_type_models.SimpleType.ERROR)

    def to_literal(
        self, ctx: FlyteContext, python_val: FlyteError, python_type: Type[T], expected: LiteralType
    ) -> Literal:
        """
        Wrap a FlyteError in a scalar error literal.

        :raises TypeTransformerFailedError: If ``python_val`` is not a FlyteError instance.
        """
        # isinstance (rather than an exact type comparison) so subclasses of FlyteError are accepted as well.
        if not isinstance(python_val, FlyteError):
            raise TypeTransformerFailedError(
                f"Expected value of type {FlyteError} but got '{python_val}' of type {type(python_val)}"
            )
        return Literal(scalar=Scalar(error=Error(message=python_val.message, failed_node_id=python_val.failed_node_id)))

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
        """
        Convert a scalar error literal back into a FlyteError.

        :raises TypeTransformerFailedError: If ``lv`` does not carry a scalar error value.
        """
        if not (lv and lv.scalar and lv.scalar.error is not None):
            # The check above is for an error scalar, so the message names that kind of literal.
            raise TypeTransformerFailedError("Can only convert an error literal to FlyteError")
        return FlyteError(message=lv.scalar.error.message, failed_node_id=lv.scalar.error.failed_node_id)

    def guess_python_type(self, literal_type: LiteralType) -> Type[FlyteError]:
        """
        Map a literal type of ``SimpleType.ERROR`` back to FlyteError.

        :raises ValueError: If ``literal_type`` is not the simple ERROR type.
        """
        if literal_type.simple and literal_type.simple == _type_models.SimpleType.ERROR:
            return FlyteError

        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


# Module-level side effect: an ErrorTransformer instance is registered with the TypeEngine on import.
TypeEngine.register(ErrorTransformer())
Loading

0 comments on commit 613a655

Please sign in to comment.