Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New-style tasks/workflows use user exception scope #540

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ def _dispatch_execute(
c: OR if an unhandled exception is retrieved - record it as an errors.pb
"""
output_file_dict = {}
try:

@_scopes.system_entry_point
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although _dispatch_execute is already being executed in a system exception scope, that scope is too far up the stack to wrap exceptions raised in this function so the subsequent except blocks below can handle them.

def do_dispatch_execute():
nonlocal output_file_dict

# Step1
local_inputs_file = _os.path.join(ctx.execution_state.working_dir, "inputs.pb")
ctx.file_access.get_data(inputs_path, local_inputs_file)
Expand All @@ -122,6 +126,9 @@ def _dispatch_execute(
_error_models.ContainerError.Kind.RECOVERABLE,
)
)

try:
do_dispatch_execute()
except _scoped_exceptions.FlyteScopedException as e:
_logging.error("!! Begin Error Captured by Flyte !!")
output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
Expand All @@ -130,6 +137,7 @@ def _dispatch_execute(
_logging.error(e.verbose_message)
_logging.error("!! End Error Captured by Flyte !!")
except Exception as e:
# TODO: need to preserve IgnoreOutputs exception from system_entry_point handling
if isinstance(e, IgnoreOutputs):
# Step 3b
_logging.warning(f"IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}")
Expand Down
5 changes: 3 additions & 2 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Dict, List, Optional, Type

from flytekit.common.constants import SdkTaskType
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, SerializationSettings
from flytekit.core.interface import transform_interface_to_list_interface
Expand Down Expand Up @@ -168,7 +169,7 @@ def _execute_map_task(self, ctx: FlyteContext, **kwargs) -> Any:
map_task_inputs = {}
for k in self.interface.inputs.keys():
map_task_inputs[k] = kwargs[k][task_index]
return self._run_task.execute(**map_task_inputs)
return _exception_scopes.user_entry_point(self._run_task.execute)(**map_task_inputs)

def _raw_execute(self, **kwargs) -> Any:
"""
Expand All @@ -190,7 +191,7 @@ def _raw_execute(self, **kwargs) -> Any:
single_instance_inputs = {}
for k in self.interface.inputs.keys():
single_instance_inputs[k] = kwargs[k][i]
o = self._run_task.execute(**single_instance_inputs)
o = _exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs)
if outputs_expected:
outputs.append(o)

Expand Down
5 changes: 3 additions & 2 deletions flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from enum import Enum
from typing import Any, Callable, List, Optional, TypeVar, Union

from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.context_manager import (
ExecutionState,
Expand Down Expand Up @@ -156,7 +157,7 @@ def execute(self, **kwargs) -> Any:
handle dynamic tasks or you will no longer be able to use the task as a dynamic task generator.
"""
if self.execution_mode == self.ExecutionBehavior.DEFAULT:
return self._task_function(**kwargs)
return _exception_scopes.user_entry_point(self._task_function)(**kwargs)
elif self.execution_mode == self.ExecutionBehavior.DYNAMIC:
return self.dynamic_execute(self._task_function, **kwargs)

Expand Down Expand Up @@ -267,7 +268,7 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
updated_exec_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
with FlyteContextManager.with_context(ctx.with_execution_state(updated_exec_state)):
logger.info("Executing Dynamic workflow, using raw inputs")
return task_function(**kwargs)
return _exception_scopes.user_entry_point(task_function)(**kwargs)

if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
is_fast_execution = bool(
Expand Down
5 changes: 3 additions & 2 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from flytekit.common import constants as _common_constants
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.exceptions.user import FlyteValidationException, FlyteValueException
from flytekit.core.base_task import PythonTask
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
Expand Down Expand Up @@ -668,7 +669,7 @@ def compile(self, **kwargs):
# Construct the default input promise bindings, but then override with the provided inputs, if any
input_kwargs = construct_input_promises([k for k in self.interface.inputs.keys()])
input_kwargs.update(kwargs)
workflow_outputs = self._workflow_function(**input_kwargs)
workflow_outputs = _exception_scopes.user_entry_point(self._workflow_function)(**input_kwargs)
all_nodes.extend(comp_ctx.compilation_state.nodes)

# This little loop was added as part of the task resolver change. The task resolver interface itself is
Expand Down Expand Up @@ -740,7 +741,7 @@ def execute(self, **kwargs):
call execute from dispatch_execute which is in _local_execute, workflows should also call an execute inside
_local_execute. This makes mocking cleaner.
"""
return self._workflow_function(**kwargs)
return _exception_scopes.user_entry_point(self._workflow_function)(**kwargs)


def workflow(
Expand Down