From 572ff3e956609ebbcaf2fe9ff45c7c13f9c0ff38 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 12 Jul 2021 16:52:53 -0700 Subject: [PATCH] Clean up scopes and exception handling for new tasks (#543) Signed-off-by: wild-endeavor --- docs/source/design/authoring.rst | 25 +++ flytekit/bin/entrypoint.py | 45 ++++- flytekit/common/exceptions/scopes.py | 74 ++++--- flytekit/core/map_task.py | 6 +- flytekit/core/python_function_task.py | 5 +- flytekit/core/resources.py | 2 +- flytekit/core/schedule.py | 2 +- flytekit/core/workflow.py | 19 +- flytekit/engines/flyte/engine.py | 6 +- flytekit/models/core/errors.py | 22 +- flytekit/models/core/execution.py | 29 ++- flytekit/models/literals.py | 9 +- .../unit/bin/test_python_entrypoint.py | 190 +++++++++++++++++- tests/flytekit/unit/core/test_imperative.py | 19 +- .../flytekit/unit/models/core/test_errors.py | 12 +- .../unit/models/core/test_execution.py | 4 +- 16 files changed, 373 insertions(+), 96 deletions(-) diff --git a/docs/source/design/authoring.rst b/docs/source/design/authoring.rst index b69d7f270a..9b9fed7a9b 100644 --- a/docs/source/design/authoring.rst +++ b/docs/source/design/authoring.rst @@ -65,6 +65,31 @@ There is also only one :py:class:`LaunchPlan int: """ :rtype: int """ @@ -138,38 +137,42 @@ def _is_base_context(): @_decorator def system_entry_point(wrapped, instance, args, kwargs): """ - Decorator for wrapping functions that enter a system context. This should decorate every method a user might - call. This will allow us to add differentiation between what is a user error and what is a system failure. - Furthermore, we will clean the exception trace so as to make more sense to the user--allowing them to know if they - should take action themselves or pass on to the platform owners. We will dispatch metrics and such appropriately. + The reason these two (see the user one below) decorators exist is to categorize non-Flyte exceptions at arbitrary + locations. 
For example, while there is a separate ecosystem of Flyte-defined user and system exceptions + (see the FlyteException hierarchy), and we can easily understand and categorize those, if flytekit comes upon + a random ``ValueError`` or other non-flytekit defined error, how would we know if it was a bug in flytekit versus an + error with user code or something the user called? The purpose of these decorators is to categorize those (see + the last case in the nested try/catch below). + + Decorator for wrapping functions that enter a system context. This should decorate every method that may invoke some + user code later on down the line. This will allow us to add differentiation between what is a user error and + what is a system failure. Furthermore, we will clean the exception trace so as to make more sense to the + user -- allowing them to know if they should take action themselves or pass on to the platform owners. + We will dispatch metrics and such appropriately. """ try: _CONTEXT_STACK.append(_SYSTEM_CONTEXT) if _is_base_context(): + # If this is the first time either this decorator or the one below is called, then we unwrap the + # exception. The first time these decorators are used is currently in the entrypoint.py file. The scoped + # exceptions are unwrapped because at that point, we want to return the underlying error to the user. try: return wrapped(*args, **kwargs) except FlyteScopedException as ex: - _reraise(ex.type, ex.value, ex.traceback) + raise ex.value else: try: return wrapped(*args, **kwargs) - except FlyteScopedException: - # Just pass-on the exception that is already wrapped and scoped - _reraise(*_exc_info()) + except FlyteScopedException as scoped: + raise scoped except _user_exceptions.FlyteUserException: # Re-raise from here. 
- _reraise( - FlyteScopedUserException, - FlyteScopedUserException(*_exc_info()), - _exc_info()[2], - ) + raise FlyteScopedUserException(*_exc_info()) except Exception: + # This is why this function exists - arbitrary exceptions that we don't know what to do with are + # interpreted as system errors. # System error, raise full stack-trace all the way up the chain. - _reraise( - FlyteScopedSystemException, - FlyteScopedSystemException(*_exc_info(), kind=_error_model.ContainerError.Kind.RECOVERABLE), - _exc_info()[2], - ) + raise FlyteScopedSystemException(*_exc_info(), kind=_error_model.ContainerError.Kind.RECOVERABLE) finally: _CONTEXT_STACK.pop() @@ -177,6 +180,8 @@ def system_entry_point(wrapped, instance, args, kwargs): @_decorator def user_entry_point(wrapped, instance, args, kwargs): """ + See the comment for the system_entry_point above as well. + Decorator for wrapping functions that enter into a user context. This will help us differentiate user-created failures even when it is re-entrant into system code. 
@@ -188,35 +193,24 @@ def user_entry_point(wrapped, instance, args, kwargs): try: _CONTEXT_STACK.append(_USER_CONTEXT) if _is_base_context(): + # See comment at this location for system_entry_point try: return wrapped(*args, **kwargs) except FlyteScopedException as ex: - _reraise(ex.type, ex.value, ex.traceback) + raise ex.value else: try: return wrapped(*args, **kwargs) - except FlyteScopedException: - # Just pass on the already wrapped and scoped exception - _reraise(*_exc_info()) + except FlyteScopedException as scoped: + raise scoped except _user_exceptions.FlyteUserException: - _reraise( - FlyteScopedUserException, - FlyteScopedUserException(*_exc_info()), - _exc_info()[2], - ) + raise FlyteScopedUserException(*_exc_info()) except _system_exceptions.FlyteSystemException: - _reraise( - FlyteScopedSystemException, - FlyteScopedSystemException(*_exc_info()), - _exc_info()[2], - ) + raise FlyteScopedSystemException(*_exc_info()) except Exception: - # Any non-platform raised exception is a user exception. + # This is why this function exists - arbitrary exceptions that we don't know what to do with are + # interpreted as user exceptions. # This will also catch FlyteUserException re-raised by the system_entry_point handler - _reraise( - FlyteScopedUserException, - FlyteScopedUserException(*_exc_info()), - _exc_info()[2], - ) + raise FlyteScopedUserException(*_exc_info()) finally: _CONTEXT_STACK.pop() diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 89ef4343bd..607af0855c 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -2,12 +2,14 @@ Flytekit map tasks specify how to run a single task across a list of inputs. Map tasks themselves are constructed with a reference task as well as run-time parameters that limit execution concurrency and failure tolerations. 
""" + import os from contextlib import contextmanager from itertools import count from typing import Any, Dict, List, Optional, Type from flytekit.common.constants import SdkTaskType +from flytekit.common.exceptions import scopes as exception_scopes from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, SerializationSettings from flytekit.core.interface import transform_interface_to_list_interface @@ -168,7 +170,7 @@ def _execute_map_task(self, ctx: FlyteContext, **kwargs) -> Any: map_task_inputs = {} for k in self.interface.inputs.keys(): map_task_inputs[k] = kwargs[k][task_index] - return self._run_task.execute(**map_task_inputs) + return exception_scopes.user_entry_point(self._run_task.execute)(**map_task_inputs) def _raw_execute(self, **kwargs) -> Any: """ @@ -190,7 +192,7 @@ def _raw_execute(self, **kwargs) -> Any: single_instance_inputs = {} for k in self.interface.inputs.keys(): single_instance_inputs[k] = kwargs[k][i] - o = self._run_task.execute(**single_instance_inputs) + o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs) if outputs_expected: outputs.append(o) diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 43b8e864cb..d2d0b5b341 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -20,6 +20,7 @@ from enum import Enum from typing import Any, Callable, List, Optional, TypeVar, Union +from flytekit.common.exceptions import scopes as exception_scopes from flytekit.core.base_task import TaskResolverMixin from flytekit.core.context_manager import ( ExecutionState, @@ -156,7 +157,7 @@ def execute(self, **kwargs) -> Any: handle dynamic tasks or you will no longer be able to use the task as a dynamic task generator. 
""" if self.execution_mode == self.ExecutionBehavior.DEFAULT: - return self._task_function(**kwargs) + return exception_scopes.user_entry_point(self._task_function)(**kwargs) elif self.execution_mode == self.ExecutionBehavior.DYNAMIC: return self.dynamic_execute(self._task_function, **kwargs) @@ -267,7 +268,7 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: updated_exec_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION) with FlyteContextManager.with_context(ctx.with_execution_state(updated_exec_state)): logger.info("Executing Dynamic workflow, using raw inputs") - return task_function(**kwargs) + return exception_scopes.user_entry_point(task_function)(**kwargs) if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: is_fast_execution = bool( diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index c110a00e50..34a879106f 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -15,7 +15,7 @@ class Resources(object): Storage is not currently supported on the Flyte backend. - Please see the :std:ref:`User Guide ` for detailed examples. + Please see the :std:ref:`User Guide ` for detailed examples. Also refer to the `K8s conventions. `__ """ diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index aa9d35abb0..0c5fe786ae 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -24,7 +24,7 @@ class CronSchedule(_schedule_models.Schedule): cron_expression="0 10 * * ? *", ) - See the :std:ref:`User Guide ` for further examples. + See the :std:ref:`User Guide ` for further examples. 
""" _VALID_CRON_ALIASES = [ diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 1175a59223..72eacdd3eb 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from flytekit.common import constants as _common_constants +from flytekit.common.exceptions import scopes as exception_scopes from flytekit.common.exceptions.user import FlyteValidationException, FlyteValueException from flytekit.core.base_task import PythonTask from flytekit.core.class_based_resolver import ClassStorageTaskResolver @@ -384,11 +385,21 @@ class ImperativeWorkflow(WorkflowBase): .. literalinclude:: ../../../tests/flytekit/unit/core/test_imperative.py :start-after: # docs_start - :end-before: # docs_start + :end-before: # docs_end :language: python :dedent: 4 - This workflow would be identical on the backed to the + This workflow would be identical on the back-end to + + .. literalinclude:: ../../../tests/flytekit/unit/core/test_imperative.py + :start-after: # docs_equivalent_start + :end-before: # docs_equivalent_end + :language: python + :dedent: 4 + + Note that the only reason we need the ``NamedTuple`` is so we can name the output the same thing as in the + imperative example. The imperative paradigm makes the naming of workflow outputs easier, but this isn't a big + deal in function-workflows because names tend to not be necessary. 
""" def __init__( @@ -668,7 +679,7 @@ def compile(self, **kwargs): # Construct the default input promise bindings, but then override with the provided inputs, if any input_kwargs = construct_input_promises([k for k in self.interface.inputs.keys()]) input_kwargs.update(kwargs) - workflow_outputs = self._workflow_function(**input_kwargs) + workflow_outputs = exception_scopes.user_entry_point(self._workflow_function)(**input_kwargs) all_nodes.extend(comp_ctx.compilation_state.nodes) # This little loop was added as part of the task resolver change. The task resolver interface itself is @@ -740,7 +751,7 @@ def execute(self, **kwargs): call execute from dispatch_execute which is in _local_execute, workflows should also call an execute inside _local_execute. This makes mocking cleaner. """ - return self._workflow_function(**kwargs) + return exception_scopes.user_entry_point(self._workflow_function)(**kwargs) def workflow( diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py index 55c20e7dbb..7af35e20b9 100644 --- a/flytekit/engines/flyte/engine.py +++ b/flytekit/engines/flyte/engine.py @@ -378,7 +378,7 @@ def execute(self, inputs, context=None): except _exception_scopes.FlyteScopedException as e: _logging.error("!!! Begin Error Captured by Flyte !!!") output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( - _error_models.ContainerError(e.error_code, e.verbose_message, e.kind) + _error_models.ContainerError(e.error_code, e.verbose_message, e.kind, 0) ) _logging.error(e.verbose_message) _logging.error("!!! 
End Error Captured by Flyte !!!") @@ -387,9 +387,7 @@ def execute(self, inputs, context=None): exc_str = _traceback.format_exc() output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( _error_models.ContainerError( - "SYSTEM:Unknown", - exc_str, - _error_models.ContainerError.Kind.RECOVERABLE, + "SYSTEM:Unknown", exc_str, _error_models.ContainerError.Kind.RECOVERABLE, 0 ) ) _logging.error(exc_str) diff --git a/flytekit/models/core/errors.py b/flytekit/models/core/errors.py index 0a1196caec..28bbdfedd3 100644 --- a/flytekit/models/core/errors.py +++ b/flytekit/models/core/errors.py @@ -8,15 +8,18 @@ class Kind(object): NON_RECOVERABLE = _errors_pb2.ContainerError.NON_RECOVERABLE RECOVERABLE = _errors_pb2.ContainerError.RECOVERABLE - def __init__(self, code, message, kind): + def __init__(self, code: str, message: str, kind: int, origin: int): """ - :param Text code: A succinct code about the error - :param Text message: Whatever message you want to surface about the error - :param int kind: A value from the ContainerError.Kind enum. + :param code: A succinct code about the error + :param message: Whatever message you want to surface about the error + :param kind: A value from the ContainerError.Kind enum. + :param origin: A value from ExecutionError.ErrorKind. Don't confuse this with error kind, even though + both are called kind. 
""" self._code = code self._message = message self._kind = kind + self._origin = origin @property def code(self): @@ -39,11 +42,18 @@ def kind(self): """ return self._kind + @property + def origin(self) -> int: + """ + The origin of the error, an enum value from ExecutionError.ErrorKind + """ + return self._origin + def to_flyte_idl(self): """ :rtype: flyteidl.core.errors_pb2.ContainerError """ - return _errors_pb2.ContainerError(code=self.code, message=self.message, kind=self.kind) + return _errors_pb2.ContainerError(code=self.code, message=self.message, kind=self.kind, origin=self.origin) @classmethod def from_flyte_idl(cls, proto): @@ -51,7 +61,7 @@ def from_flyte_idl(cls, proto): :param flyteidl.core.errors_pb2.ContainerError proto: :rtype: ContainerError """ - return cls(proto.code, proto.message, proto.kind) + return cls(proto.code, proto.message, proto.kind, proto.origin) class ErrorDocument(_common.FlyteIdlEntity): diff --git a/flytekit/models/core/execution.py b/flytekit/models/core/execution.py index 5323e0489c..84b2e95f4e 100644 --- a/flytekit/models/core/execution.py +++ b/flytekit/models/core/execution.py @@ -112,15 +112,22 @@ def enum_to_string(cls, int_value): class ExecutionError(_common.FlyteIdlEntity): - def __init__(self, code, message, error_uri): + class ErrorKind(object): + UNKNOWN = _execution_pb2.ExecutionError.ErrorKind.UNKNOWN + USER = _execution_pb2.ExecutionError.ErrorKind.USER + SYSTEM = _execution_pb2.ExecutionError.ErrorKind.SYSTEM + + def __init__(self, code: str, message: str, error_uri: str, kind: int): """ - :param Text code: - :param Text message: - :param Text uri: + :param code: + :param message: + :param uri: + :param kind: """ self._code = code self._message = message self._error_uri = error_uri + self._kind = kind @property def code(self): @@ -143,6 +150,13 @@ def error_uri(self): """ return self._error_uri + @property + def kind(self) -> int: + """ + Enum value from ErrorKind + """ + return self._kind + def 
to_flyte_idl(self): """ :rtype: flyteidl.core.execution_pb2.ExecutionError @@ -151,6 +165,7 @@ def to_flyte_idl(self): code=self.code, message=self.message, error_uri=self.error_uri, + kind=self.kind, ) @classmethod @@ -159,11 +174,7 @@ def from_flyte_idl(cls, p): :param flyteidl.core.execution_pb2.ExecutionError p: :rtype: ExecutionError """ - return cls( - code=p.code, - message=p.message, - error_uri=p.error_uri, - ) + return cls(code=p.code, message=p.message, error_uri=p.error_uri, kind=p.kind) class TaskLog(_common.FlyteIdlEntity): diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index c1f871bf21..684fef95ba 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -1,7 +1,6 @@ from datetime import datetime as _datetime import pytz as _pytz -import six as _six from flyteidl.core import literals_pb2 as _literals_pb2 from google.protobuf.struct_pb2 import Struct @@ -310,7 +309,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.BindingDataMap """ - return _literals_pb2.BindingDataMap(bindings={k: v.to_flyte_idl() for (k, v) in _six.iteritems(self.bindings)}) + return _literals_pb2.BindingDataMap(bindings={k: v.to_flyte_idl() for (k, v) in self.bindings.items()}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -319,7 +318,7 @@ def from_flyte_idl(cls, pb2_object): :rtype: flytekit.models.literals.BindingDataMap """ - return cls({k: BindingData.from_flyte_idl(v) for (k, v) in _six.iteritems(pb2_object.bindings)}) + return cls({k: BindingData.from_flyte_idl(v) for (k, v) in pb2_object.bindings.items()}) class BindingDataCollection(_common.FlyteIdlEntity): @@ -588,7 +587,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.LiteralMap """ - return _literals_pb2.LiteralMap(literals={k: v.to_flyte_idl() for k, v in _six.iteritems(self.literals)}) + return _literals_pb2.LiteralMap(literals={k: v.to_flyte_idl() for k, v in self.literals.items()}) @classmethod def from_flyte_idl(cls, 
pb2_object): @@ -596,7 +595,7 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.literals_pb2.LiteralMap pb2_object: :rtype: LiteralMap """ - return cls({k: Literal.from_flyte_idl(v) for k, v in _six.iteritems(pb2_object.literals)}) + return cls({k: Literal.from_flyte_idl(v) for k, v in pb2_object.literals.items()}) class Scalar(_common.FlyteIdlEntity): diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 2745e0367e..106156499c 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -1,4 +1,6 @@ import os +import typing +from collections import OrderedDict import mock import six @@ -9,13 +11,19 @@ from flytekit.bin.entrypoint import _dispatch_execute, _legacy_execute_task, execute_task_cmd from flytekit.common import constants as _constants from flytekit.common import utils as _utils +from flytekit.common.exceptions import user as user_exceptions +from flytekit.common.exceptions.scopes import system_entry_point from flytekit.common.types import helpers as _type_helpers from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs +from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.promise import VoidPromise +from flytekit.core.task import task +from flytekit.core.type_engine import TypeEngine from flytekit.models import literals as _literal_models -from flytekit.models import literals as _literals +from flytekit.models.core import errors as error_models +from flytekit.models.core import execution as execution_models from tests.flytekit.common import task_definitions as _task_defs @@ -110,8 +118,10 @@ def test_arrayjob_entrypoint_in_proc(): _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file) # construct indexlookup.pb which has array: [1] - mapped_index = 
_literals.Literal(_literals.Scalar(primitive=_literals.Primitive(integer=1))) - index_lookup_collection = _literals.LiteralCollection([mapped_index]) + mapped_index = _literal_models.Literal( + _literal_models.Scalar(primitive=_literal_models.Primitive(integer=1)) + ) + index_lookup_collection = _literal_models.LiteralCollection([mapped_index]) index_lookup_file = os.path.join(dir.name, "indexlookup.pb") _utils.write_proto_to_file(index_lookup_collection.to_flyte_idl(), index_lookup_file) @@ -223,7 +233,10 @@ def test_dispatch_execute_ignore(mock_write_to_file, mock_upload_dir, mock_get_d empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl() mock_load_proto.return_value = empty_literal_map - _dispatch_execute(ctx, python_task, "inputs path", "outputs prefix") + # The system_entry_point decorator does different things based on whether or not it's the + # first time it's called. Using it here to mimic the fact that _dispatch_execute is + # called by _execute_task, which also has a system_entry_point + system_entry_point(_dispatch_execute)(ctx, python_task, "inputs path", "outputs prefix") assert mock_write_to_file.call_count == 0 @@ -254,3 +267,172 @@ def verify_output(*args, **kwargs): mock_write_to_file.side_effect = verify_output _dispatch_execute(ctx, python_task, "inputs path", "outputs prefix") assert mock_write_to_file.call_count == 1 + + +# This function collects outputs instead of writing them to a file. 
+# See flytekit.common.utils.write_proto_to_file for the original +def get_output_collector(results: OrderedDict): + def output_collector(proto, path): + results[path] = proto + + return output_collector + + +@mock.patch("flytekit.common.utils.load_proto_from_file") +@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.get_data") +@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.upload_directory") +@mock.patch("flytekit.common.utils.write_proto_to_file") +def test_dispatch_execute_normal(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): + # Just leave these here, mock them out so nothing happens + mock_get_data.return_value = True + mock_upload_dir.return_value = True + + @task + def t1(a: int) -> str: + return f"string is: {a}" + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ) as ctx: + input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5}) + mock_load_proto.return_value = input_literal_map.to_flyte_idl() + + files = OrderedDict() + mock_write_to_file.side_effect = get_output_collector(files) + # See comment in test_dispatch_execute_ignore for why we need to decorate + system_entry_point(_dispatch_execute)(ctx, t1, "inputs path", "outputs prefix") + assert len(files) == 1 + + # A successful run should've written an outputs file. 
+ k = list(files.keys())[0] + assert "outputs.pb" in k + + v = list(files.values())[0] + lm = _literal_models.LiteralMap.from_flyte_idl(v) + assert lm.literals["o0"].scalar.primitive.string_value == "string is: 5" + + +@mock.patch("flytekit.common.utils.load_proto_from_file") +@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.get_data") +@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.upload_directory") +@mock.patch("flytekit.common.utils.write_proto_to_file") +def test_dispatch_execute_user_error_non_recov(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): + # Just leave these here, mock them out so nothing happens + mock_get_data.return_value = True + mock_upload_dir.return_value = True + + @task + def t1(a: int) -> str: + # Should be interpreted as a non-recoverable user error + raise ValueError(f"some exception {a}") + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ) as ctx: + input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5}) + mock_load_proto.return_value = input_literal_map.to_flyte_idl() + + files = OrderedDict() + mock_write_to_file.side_effect = get_output_collector(files) + # See comment in test_dispatch_execute_ignore for why we need to decorate + system_entry_point(_dispatch_execute)(ctx, t1, "inputs path", "outputs prefix") + assert len(files) == 1 + + # Exception should've caused an error file + k = list(files.keys())[0] + assert "error.pb" in k + + v = list(files.values())[0] + ed = error_models.ErrorDocument.from_flyte_idl(v) + assert ed.error.kind == error_models.ContainerError.Kind.NON_RECOVERABLE + assert "some exception 5" in ed.error.message + assert ed.error.origin == execution_models.ExecutionError.ErrorKind.USER + + +@mock.patch("flytekit.common.utils.load_proto_from_file") 
+@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.get_data") +@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.upload_directory") +@mock.patch("flytekit.common.utils.write_proto_to_file") +def test_dispatch_execute_user_error_recoverable(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): + # Just leave these here, mock them out so nothing happens + mock_get_data.return_value = True + mock_upload_dir.return_value = True + + @task + def t1(a: int) -> str: + return f"A is {a}" + + @dynamic + def my_subwf(a: int) -> typing.List[str]: + # This also tests the dynamic/compile path + raise user_exceptions.FlyteRecoverableException(f"recoverable {a}") + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ) as ctx: + input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5}) + mock_load_proto.return_value = input_literal_map.to_flyte_idl() + + files = OrderedDict() + mock_write_to_file.side_effect = get_output_collector(files) + # See comment in test_dispatch_execute_ignore for why we need to decorate + system_entry_point(_dispatch_execute)(ctx, my_subwf, "inputs path", "outputs prefix") + assert len(files) == 1 + + # Exception should've caused an error file + k = list(files.keys())[0] + assert "error.pb" in k + + v = list(files.values())[0] + ed = error_models.ErrorDocument.from_flyte_idl(v) + assert ed.error.kind == error_models.ContainerError.Kind.RECOVERABLE + assert "recoverable 5" in ed.error.message + assert ed.error.origin == execution_models.ExecutionError.ErrorKind.USER + + +@mock.patch("flytekit.common.utils.load_proto_from_file") +@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.get_data") +@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.upload_directory") 
+@mock.patch("flytekit.common.utils.write_proto_to_file") +def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): + # Just leave these here, mock them out so nothing happens + mock_get_data.return_value = True + mock_upload_dir.return_value = True + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ) as ctx: + input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5}) + mock_load_proto.return_value = input_literal_map.to_flyte_idl() + + python_task = mock.MagicMock() + python_task.dispatch_execute.side_effect = Exception("some system exception") + + files = OrderedDict() + mock_write_to_file.side_effect = get_output_collector(files) + # See comment in test_dispatch_execute_ignore for why we need to decorate + system_entry_point(_dispatch_execute)(ctx, python_task, "inputs path", "outputs prefix") + assert len(files) == 1 + + # Exception should've caused an error file + k = list(files.keys())[0] + assert "error.pb" in k + + v = list(files.values())[0] + ed = error_models.ErrorDocument.from_flyte_idl(v) + # System errors default to recoverable + assert ed.error.kind == error_models.ContainerError.Kind.RECOVERABLE + assert "some system exception" in ed.error.message + assert ed.error.origin == execution_models.ExecutionError.ErrorKind.SYSTEM diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index 250c0a0e11..fedb2a7553 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -46,7 +46,7 @@ def t2(): # docs_start # Create the workflow with a name. This needs to be unique within the project and takes the place of the function # name that's used for regular decorated function-based workflows. 
- wb = Workflow(name="my.workflow") + wb = Workflow(name="my_workflow") # Adds a top level input to the workflow. This is like an input to a workflow function. wb.add_workflow_input("in1", str) # Call your tasks. @@ -54,7 +54,7 @@ def t2(): wb.add_entity(t2) # This is analagous to a return statement wb.add_workflow_output("from_n0t1", node.outputs["o0"]) - # docs_start + # docs_end assert wb(in1="hello") == "hello world" @@ -66,10 +66,21 @@ def t2(): assert len(wf_spec.template.interface.inputs) == 1 assert len(wf_spec.template.interface.outputs) == 1 + # docs_equivalent_start + nt = typing.NamedTuple("wf_output", from_n0t1=str) + + @workflow + def my_workflow(in1: str) -> nt: + x = t1(a=in1) + t2() + return (x,) + + # docs_equivalent_end + # Create launch plan from wf, that can also be serialized. lp = LaunchPlan.create("test_wb", wb) lp_model = get_serializable(OrderedDict(), serialization_settings, lp) - assert lp_model.spec.workflow_id.name == "my.workflow" + assert lp_model.spec.workflow_id.name == "my_workflow" wb2 = ImperativeWorkflow(name="parent.imperative") p_in1 = wb2.add_workflow_input("p_in1", str) @@ -81,7 +92,7 @@ def t2(): assert wb2_spec.template.interface.inputs["p_in1"].type.simple is not None assert len(wb2_spec.template.interface.outputs) == 1 assert wb2_spec.template.interface.outputs["parent_wf_output"].type.simple is not None - assert wb2_spec.template.nodes[0].workflow_node.sub_workflow_ref.name == "my.workflow" + assert wb2_spec.template.nodes[0].workflow_node.sub_workflow_ref.name == "my_workflow" assert len(wb2_spec.sub_workflows) == 1 wb3 = ImperativeWorkflow(name="parent.imperative") diff --git a/tests/flytekit/unit/models/core/test_errors.py b/tests/flytekit/unit/models/core/test_errors.py index fd55ba5ebf..080a396ce5 100644 --- a/tests/flytekit/unit/models/core/test_errors.py +++ b/tests/flytekit/unit/models/core/test_errors.py @@ -1,21 +1,27 @@ -from flytekit.models.core import errors +from flytekit.models.core import errors, 
execution def test_container_error(): - obj = errors.ContainerError("code", "my message", errors.ContainerError.Kind.RECOVERABLE) + obj = errors.ContainerError( + "code", "my message", errors.ContainerError.Kind.RECOVERABLE, execution.ExecutionError.ErrorKind.SYSTEM + ) assert obj.code == "code" assert obj.message == "my message" assert obj.kind == errors.ContainerError.Kind.RECOVERABLE + assert obj.origin == 2 obj2 = errors.ContainerError.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.code == "code" assert obj2.message == "my message" assert obj2.kind == errors.ContainerError.Kind.RECOVERABLE + assert obj2.origin == 2 def test_error_document(): - ce = errors.ContainerError("code", "my message", errors.ContainerError.Kind.RECOVERABLE) + ce = errors.ContainerError( + "code", "my message", errors.ContainerError.Kind.RECOVERABLE, execution.ExecutionError.ErrorKind.USER + ) obj = errors.ErrorDocument(ce) assert obj.error == ce diff --git a/tests/flytekit/unit/models/core/test_execution.py b/tests/flytekit/unit/models/core/test_execution.py index a0294967a7..bb277280ad 100644 --- a/tests/flytekit/unit/models/core/test_execution.py +++ b/tests/flytekit/unit/models/core/test_execution.py @@ -19,13 +19,15 @@ def test_task_logs(): def test_execution_error(): - obj = execution.ExecutionError("code", "message", "uri") + obj = execution.ExecutionError("code", "message", "uri", execution.ExecutionError.ErrorKind.UNKNOWN) assert obj.code == "code" assert obj.message == "message" assert obj.error_uri == "uri" + assert obj.kind == 0 obj2 = execution.ExecutionError.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.code == "code" assert obj2.message == "message" assert obj2.error_uri == "uri" + assert obj2.kind == 0