Skip to content

Commit

Permalink
Customized container tasks and Shim tasks/executors (#449)
Browse files Browse the repository at this point in the history
  • Loading branch information
wild-endeavor authored May 4, 2021
1 parent 2340bfa commit ad6c322
Show file tree
Hide file tree
Showing 22 changed files with 745 additions and 195 deletions.
110 changes: 55 additions & 55 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import datetime as _datetime
import importlib as _importlib
import logging as _logging
Expand Down Expand Up @@ -26,7 +27,6 @@
from flytekit.core.context_manager import ExecutionState, FlyteContext, SerializationSettings, get_image_config
from flytekit.core.map_task import MapPythonTask
from flytekit.core.promise import VoidPromise
from flytekit.core.python_auto_container import TaskResolverMixin
from flytekit.engines import loader as _engine_loader
from flytekit.interfaces import random as _flyte_random
from flytekit.interfaces.data import data_proxy as _data_proxy
Expand All @@ -38,6 +38,7 @@
from flytekit.models.core import errors as _error_models
from flytekit.models.core import identifier as _identifier
from flytekit.tools.fast_registration import download_distribution as _download_distribution
from flytekit.tools.module_loader import load_object_from_module


def _compute_array_job_index():
Expand Down Expand Up @@ -73,7 +74,12 @@ def _map_job_index_to_child_index(local_input_dir, datadir, index):
return mapping_proto.literals[index].scalar.primitive.integer


def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str, output_prefix: str):
def _dispatch_execute(
ctx: FlyteContext,
task_def: PythonTask,
inputs_path: str,
output_prefix: str,
):
"""
Dispatches execute to PythonTask
Step1: Download inputs and load into a literal map
Expand All @@ -90,6 +96,7 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str,
ctx.file_access.get_data(inputs_path, local_inputs_file)
input_proto = _utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file)
idl_input_literals = _literal_models.LiteralMap.from_flyte_idl(input_proto)

# Step2
outputs = task_def.dispatch_execute(ctx, idl_input_literals)
# Step3a
Expand Down Expand Up @@ -122,7 +129,7 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str,
_logging.warning(f"IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}")
return
# Step 3c
_logging.error(f"Exception when executing task {task_def.name}, reason {str(e)}")
_logging.error(f"Exception when executing task {task_def.name or task_def.id.name}, reason {str(e)}")
_logging.error("!! Begin Unknown System Error Captured by Flyte !!")
exc_str = _traceback.format_exc()
output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
Expand All @@ -142,18 +149,12 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str,
_logging.info(f"Engine folder written successfully to the output prefix {output_prefix}")


def _handle_annotated_task(
task_def: PythonTask,
inputs: str,
output_prefix: str,
@contextlib.contextmanager
def setup_execution(
raw_output_data_prefix: str,
dynamic_addl_distro: str = None,
dynamic_dest_dir: str = None,
):
"""
Entrypoint for all PythonTask extensions
"""
_click.echo("Running native-typed task")
cloud_provider = _platform_config.CLOUD_PROVIDER.get()
log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
_logging.getLogger().setLevel(log_level)
Expand Down Expand Up @@ -235,7 +236,20 @@ def _handle_annotated_task(
execution_params=execution_parameters,
additional_context={"dynamic_addl_distro": dynamic_addl_distro, "dynamic_dest_dir": dynamic_dest_dir},
) as ctx:
_dispatch_execute(ctx, task_def, inputs, output_prefix)
yield ctx


def _handle_annotated_task(
    ctx: FlyteContext,
    task_def: PythonTask,
    inputs: str,
    output_prefix: str,
):
    """
    Entrypoint for all PythonTask extensions.

    Prints a short banner and then hands the already-loaded task over to ``_dispatch_execute``.

    :param ctx: Context the task is dispatched with; this function only forwards it.
    :param task_def: The task object to run.
    :param inputs: Forwarded to ``_dispatch_execute`` as its ``inputs_path`` argument.
    :param output_prefix: Forwarded unchanged to ``_dispatch_execute``.
    """
    _click.echo("Running native-typed task")
    _dispatch_execute(
        ctx=ctx,
        task_def=task_def,
        inputs_path=inputs,
        output_prefix=output_prefix,
    )


@_scopes.system_entry_point
Expand Down Expand Up @@ -277,18 +291,6 @@ def _legacy_execute_task(task_module, task_name, inputs, output_prefix, raw_outp
)


def _load_resolver(resolver_location: str) -> TaskResolverMixin:
# Load the actual resolver - this cannot be a nested thing, whatever kind of resolver it is, it has to be loadable
# directly from importlib
# TODO: Handle corner cases, like where the first part is [] maybe
# e.g. flytekit.core.python_auto_container.default_task_resolver
resolver = resolver_location.split(".")
resolver_mod = resolver[:-1] # e.g. ['flytekit', 'core', 'python_auto_container']
resolver_key = resolver[-1] # e.g. 'default_task_resolver'
resolver_mod = _importlib.import_module(".".join(resolver_mod))
return getattr(resolver_mod, resolver_key)


@_scopes.system_entry_point
def _execute_task(
inputs,
Expand Down Expand Up @@ -326,18 +328,17 @@ def _execute_task(
if len(resolver_args) < 1:
raise Exception("cannot be <1")

resolver_obj = _load_resolver(resolver)
with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()):
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
if test:
_click.echo(
f"Test detected, returning. Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}"
)
return
_handle_annotated_task(
_task_def, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir
)
with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx:
resolver_obj = load_object_from_module(resolver)
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
if test:
_click.echo(
f"Test detected, returning. Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}"
)
return
_handle_annotated_task(ctx, _task_def, inputs, output_prefix)


@_scopes.system_entry_point
Expand All @@ -355,28 +356,27 @@ def _execute_map_task(
if len(resolver_args) < 1:
raise Exception(f"Resolver args cannot be <1, got {resolver_args}")

resolver_obj = _load_resolver(resolver)
with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()):
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
if not isinstance(_task_def, PythonFunctionTask):
raise Exception("Map tasks cannot be run with instance tasks.")
map_task = MapPythonTask(_task_def, max_concurrency)

task_index = _compute_array_job_index()
output_prefix = _os.path.join(output_prefix, str(task_index))

if test:
_click.echo(
f"Test detected, returning. Inputs: {inputs} Computed task index: {task_index} "
f"New output prefix: {output_prefix} Raw output path: {raw_output_data_prefix} "
f"Resolver and args: {resolver} {resolver_args}"
)
return
with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx:
resolver_obj = load_object_from_module(resolver)
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
if not isinstance(_task_def, PythonFunctionTask):
raise Exception("Map tasks cannot be run with instance tasks.")
map_task = MapPythonTask(_task_def, max_concurrency)

task_index = _compute_array_job_index()
output_prefix = _os.path.join(output_prefix, str(task_index))

if test:
_click.echo(
f"Test detected, returning. Inputs: {inputs} Computed task index: {task_index} "
f"New output prefix: {output_prefix} Raw output path: {raw_output_data_prefix} "
f"Resolver and args: {resolver} {resolver_args}"
)
return

_handle_annotated_task(
map_task, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir
)
_handle_annotated_task(ctx, map_task, inputs, output_prefix)


@_click.group()
Expand Down
77 changes: 76 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union

from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.tasks.sdk_runnable import ExecutionParameters
Expand Down Expand Up @@ -486,3 +486,78 @@ def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
@property
def environment(self) -> Dict[str, str]:
return self._environment


class TaskResolverMixin:
    """
    Flytekit tasks interact with the Flyte platform very, very broadly in two steps. They need to be uploaded to Admin,
    and then they are run by the user upon request (either as a single task execution or as part of a workflow). In any
    case, at execution time, the container image containing the task needs to be spun up again (for container tasks at
    least which most tasks are) at which point the container needs to know which task it's supposed to run and
    how to rehydrate the task object.

    For example, the serialization of a simple task ::

        # in repo_root/workflows/example.py
        @task
        def t1(...) -> ...: ...

    might result in a container with arguments like ::

        pyflyte-execute --inputs s3://path/inputs.pb --output-prefix s3://outputs/location \
        --raw-output-data-prefix /tmp/data \
        --resolver flytekit.core.python_auto_container.default_task_resolver \
        -- \
        task-module repo_root.workflows.example task-name t1

    At serialization time, the container created for the task will start out automatically with the
    ``pyflyte-execute`` bit, along with the requisite input/output args and the offloaded data prefix.
    Appended to that will be two things,

    #. the ``location`` of the task's task resolver, followed by two dashes, followed by
    #. the arguments provided by calling the ``loader_args`` function below.

    The ``default_task_resolver`` named in the example command above knows that

    * When ``loader_args`` is called on a task, to look up the module the task is in, and the name of the task
      (the key of the task in the module, either the function name, or the variable it was assigned to).
    * When ``load_task`` is called, it interprets the first part of the command as the module to call
      ``importlib.import_module`` on, and then looks for a key ``t1``.

    This is just the default behavior. Users should feel free to implement their own resolvers.

    NOTE(review): the members below are marked ``@abstractmethod``, but this class does not use an ABC
    metaclass, so the markers are not enforced at instantiation time from what is visible here.
    """

    @property
    @abstractmethod
    def location(self) -> str:
        """
        Location of this resolver, i.e. the value emitted after ``--resolver`` in the container arguments
        (a dotted path such as ``flytekit.core.python_auto_container.default_task_resolver``).
        """

    @abstractmethod
    def name(self) -> str:
        """
        Name of this resolver. Unlike ``location`` this is a plain method, not a property.
        """

    @abstractmethod
    def load_task(self, loader_args: List[str]) -> Task:
        """
        Given the set of identifier keys, should return one Python Task or raise an error if not found
        """

    @abstractmethod
    def loader_args(self, settings: SerializationSettings, t: Task) -> List[str]:
        """
        Return a list of strings that can help identify the parameter Task
        """

    @abstractmethod
    def get_all_tasks(self) -> List[Task]:
        """
        Future proof method. Just making it easy to access all tasks (Not required today as we auto register them)
        """

    def task_name(self, t: Task) -> Optional[str]:
        """
        Overridable function that can optionally return a custom name for a given task.
        The base implementation supplies no custom name and returns ``None``.
        """
        return None
3 changes: 2 additions & 1 deletion flytekit/core/class_based_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import List

from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.context_manager import SerializationSettings
from flytekit.core.python_auto_container import PythonAutoContainerTask, TaskResolverMixin
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.tracker import TrackedInstance


Expand Down
Loading

0 comments on commit ad6c322

Please sign in to comment.