diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index d25d332625..cfad71d99f 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -22,6 +22,7 @@ from flytekit.configuration import internal as _internal_config from flytekit.configuration import platform as _platform_config from flytekit.configuration import sdk as _sdk_config +from flytekit.control_plane.tasks.task import FlyteTask from flytekit.core.base_task import IgnoreOutputs, PythonTask, TaskResolverMixin from flytekit.core.context_manager import ExecutionState, FlyteContext, SerializationSettings, get_image_config from flytekit.core.map_task import MapPythonTask @@ -34,7 +35,6 @@ from flytekit.interfaces.stats.taggable import get_stats as _get_stats from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import literals as _literal_models -from flytekit.models import task as task_models from flytekit.models.core import errors as _error_models from flytekit.models.core import identifier as _identifier from flytekit.tools.fast_registration import download_distribution as _download_distribution @@ -96,12 +96,11 @@ def _dispatch_execute( ctx.file_access.get_data(inputs_path, local_inputs_file) input_proto = _utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file) idl_input_literals = _literal_models.LiteralMap.from_flyte_idl(input_proto) - from flytekit.core.python_third_party_task import ExecutorTask # Step2 if isinstance(task_def, PythonTask): outputs = task_def.dispatch_execute(ctx, idl_input_literals) - elif isinstance(task_def, ExecutorTask): + elif isinstance(task_def, FlyteTask): outputs = task_def.dispatch_execute(ctx, idl_input_literals) else: raise Exception("Task def was neither PythonTask nor TaskTemplate") diff --git a/flytekit/control_plane/__init__.py b/flytekit/control_plane/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/control_plane/tasks/executor.py b/flytekit/control_plane/tasks/executor.py new file mode 100644 index 0000000000..853852545b --- /dev/null +++ b/flytekit/control_plane/tasks/executor.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import Any, Generic, TypeVar, Union + +from flytekit.common.tasks.sdk_runnable import ExecutionParameters +from flytekit.core.context_manager import FlyteContext +from flytekit.core.tracker import TrackedInstance +from flytekit.core.type_engine import TypeEngine +from flytekit.loggers import logger +from flytekit.models import dynamic_job as _dynamic_job +from flytekit.models import literals as _literal_models +from flytekit.models import task as _task_model + +T = TypeVar("T") + + +class FlyteTaskExecutor(TrackedInstance, Generic[T]): + @classmethod + def execute_from_model(cls, tt: _task_model.TaskTemplate, **kwargs) -> Any: + raise NotImplementedError + + @classmethod + def pre_execute(cls, user_params: ExecutionParameters) -> ExecutionParameters: + """ + This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. + """ + return user_params + + @classmethod + def post_execute(cls, user_params: ExecutionParameters, rval: Any) -> Any: + """ + This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. + """ + return rval + + @classmethod + def dispatch_execute( + cls, ctx: FlyteContext, tt: _task_model.TaskTemplate, input_literal_map: _literal_models.LiteralMap + ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]: + """ + This function is copied from PythonTask.dispatch_execute. Will need to make it a mixin and refactor in the + future. + """ + + # Invoked before the task is executed + new_user_params = cls.pre_execute(ctx.user_space_params) + + # Create another execution context with the new user params, but let's keep the same working dir + with ctx.new_execution_context( + mode=ctx.execution_state.mode, + execution_params=new_user_params, + working_dir=ctx.execution_state.working_dir, + ) as exec_ctx: + # Added: Have to reverse the Python interface from the task template Flyte interface + # This will be moved into the FlyteTask promote logic instead + guessed_python_input_types = TypeEngine.guess_python_types(tt.interface.inputs) + native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, guessed_python_input_types) + + logger.info(f"Invoking FlyteTask executor {tt.id.name} with inputs: {native_inputs}") + try: + native_outputs = cls.execute_from_model(tt, **native_inputs) + except Exception as e: + logger.exception(f"Exception when executing {e}") + raise e + + logger.info(f"Task executed successfully in user level, outputs: {native_outputs}") + # Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is + # bubbled up to be handled at the callee layer. + native_outputs = cls.post_execute(new_user_params, native_outputs) + + # Short circuit the translation to literal map because what's returned may be a dj spec (or an + # already-constructed LiteralMap if the dynamic task was a no-op), not python native values + if isinstance(native_outputs, _literal_models.LiteralMap) or isinstance( + native_outputs, _dynamic_job.DynamicJobSpec + ): + return native_outputs + + expected_output_names = list(tt.interface.outputs.keys()) + if len(expected_output_names) == 1: + # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of + # length one. That convention is used for naming outputs - and single-length-NamedTuples are + # particularly troublesome but elegant handling of them is not a high priority + # Again, we're using the output_tuple_name as a proxy. + # Deleted some stuff + native_outputs_as_map = {expected_output_names[0]: native_outputs} + elif len(expected_output_names) == 0: + native_outputs_as_map = {} + else: + native_outputs_as_map = { + expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs) + } + + # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption + # built into the IDL that all the values of a literal map are of the same type. + literals = {} + for k, v in native_outputs_as_map.items(): + literal_type = tt.interface.outputs[k].type + py_type = type(v) + + if isinstance(v, tuple): + raise AssertionError(f"Output({k}) in task{tt.id.name} received a tuple {v}, instead of {py_type}") + try: + literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type) + except Exception as e: + raise AssertionError(f"failed to convert return value for var {k}") from e + + outputs_literal_map = _literal_models.LiteralMap(literals=literals) + # After the execute has been successfully completed + return outputs_literal_map diff --git a/flytekit/control_plane/tasks/task.py b/flytekit/control_plane/tasks/task.py index 71f159e523..dfd0145122 100644 --- a/flytekit/control_plane/tasks/task.py +++ b/flytekit/control_plane/tasks/task.py @@ -1,16 +1,35 @@ +from typing import Any, Optional, Union + from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.mixins import hash as _hash_mixin from flytekit.control_plane import identifier as _identifier from flytekit.control_plane import interface as _interfaces +from flytekit.core.base_task import ExecutableTaskMixin +from flytekit.core.context_manager import FlyteContext +from flytekit.core.task_executor import FlyteTaskExecutor from flytekit.engines.flyte import engine as _flyte_engine from flytekit.models import common as _common_model +from flytekit.models import dynamic_job as dynamic_job +from flytekit.models import literals as literal_models from flytekit.models import task as _task_model from flytekit.models.admin import common as _admin_common from flytekit.models.core import identifier as _identifier_model -class FlyteTask(_hash_mixin.HashOnReferenceMixin, _task_model.TaskTemplate): - def __init__(self, id, type, metadata, interface, custom, container=None, task_type_version=0, config=None): +class FlyteTask(ExecutableTaskMixin, _hash_mixin.HashOnReferenceMixin, _task_model.TaskTemplate): + def __init__( + self, + id, + type, + metadata, + interface, + custom, + container=None, + task_type_version=0, + config=None, + executor: Optional[FlyteTaskExecutor] = None, + ): super(FlyteTask, self).__init__( id, type, @@ -22,6 +41,12 @@ def __init__(self, id, type, metadata, interface, custom, container=None, task_t config=config, ) + self._executor = executor + + @property + def executor(self) -> Optional[FlyteTaskExecutor]: + return self._executor + @property def interface(self) -> _interfaces.TypedInterface: return super(FlyteTask, self).interface @@ -93,3 +118,23 @@ def fetch_latest(cls, project: str, domain: str, name: str) -> "FlyteTask": flyte_task = cls.promote_from_model(admin_task.closure.compiled_task.template) flyte_task._id = admin_task.id return flyte_task + + def execute(self, **kwargs) -> Any: + """ + This function directs execute to the executor instead of attempting to run itself. + """ + if self.executor is None: + raise ValueError(f"Cannot execute without an executor") + + return self.executor.execute_from_model(self, **kwargs) + + def dispatch_execute( + self, ctx: FlyteContext, input_literal_map: literal_models.LiteralMap + ) -> Union[literal_models.LiteralMap, dynamic_job.DynamicJobSpec]: + """ + This function directs execute to the executor instead of attempting to run itself. + """ + if self.executor is None: + raise ValueError(f"Cannot run dispatch_execute without an executor") + + return self.executor.dispatch_execute(ctx, self, input_literal_map) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index e6572f4268..6fb2b9aa27 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -1,6 +1,6 @@ import collections import datetime -from abc import abstractmethod +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union @@ -21,6 +21,7 @@ create_task_output, translate_inputs_to_literals, ) +from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger @@ -110,7 +111,52 @@ class IgnoreOutputs(Exception): pass -class Task(object): +class ExecutableTaskMixin(object): + def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: + """ + This is the method that will be invoked directly before executing the task method and before all the inputs + are converted. One particular case where this is useful is if the context is to be modified for the user process + to get some user space parameters. This also ensures that things like SparkSession are already correctly + setup before the type transformers are called + + This should return either the same context of the mutated context + """ + return user_params + + def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + """ + Post execute is called after the execution has completed, with the user_params and can be used to clean-up, + or alter the outputs to match the intended tasks outputs. If not overriden, then this function is a No-op + + Args: + rval is returned value from call to execute + user_params: are the modified user params as created during the pre_execute step + """ + return rval + + @abstractmethod + def dispatch_execute( + self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap + ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]: + """ + This method translates Flyte's Type system based input values and invokes the actual call to the executor + This method is also invoked during runtime. It takes in a Flyte LiteralMap and returns a Flyte LiteralMap. + This should call execute below. + + To support dynamic tasks, it can optionally return a DynamicJobSpec instead. + """ + raise NotImplementedError + + @abstractmethod + def execute(self, **kwargs) -> Any: + """ + This function should take in Python native kwargs and return Python native values. This should be called by + dispatch execute. + """ + raise NotImplementedError + + +class Task(ExecutableTaskMixin, ABC): """ The base of all Tasks in flytekit. This task is closest to the FlyteIDL TaskTemplate and captures information in FlyteIDL specification and does not have python native interfaces associated. For any real extension please @@ -276,39 +322,11 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: def get_config(self, settings: SerializationSettings) -> Dict[str, str]: return None - @abstractmethod - def dispatch_execute( - self, - ctx: FlyteContext, - input_literal_map: _literal_models.LiteralMap, - ) -> _literal_models.LiteralMap: - """ - This method translates Flyte's Type system based input values and invokes the actual call to the executor - This method is also invoked during runtime. - """ - pass - - @abstractmethod - def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: - """ - This is the method that will be invoked directly before executing the task method and before all the inputs - are converted. One particular case where this is useful is if the context is to be modified for the user process - to get some user space parameters. This also ensures that things like SparkSession are already correctly - setup before the type transformers are called - - This should return either the same context of the mutated context - """ - pass - - @abstractmethod - def execute(self, **kwargs) -> Any: - pass - T = TypeVar("T") -class PythonTask(TrackedInstance, Task, Generic[T]): +class PythonTask(TrackedInstance, Task, Generic[T], metaclass=FlyteTrackedABC): # noqa: doesn't realize it's ABC """ Base Class for all Tasks with a Python native ``Interface``. This should be directly used for task types, that do not have a python function to be executed. Otherwise refer to :py:class:`flytekit.PythonFunctionTask`. @@ -343,7 +361,6 @@ def __init__( self._environment = environment if environment else {} self._task_config = task_config - # TODO lets call this interface and the other as flyte_interface? @property def python_interface(self) -> Interface: return self._python_interface @@ -457,32 +474,6 @@ def dispatch_execute( # After the execute has been successfully completed return outputs_literal_map - def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: - """ - This is the method that will be invoked directly before executing the task method and before all the inputs - are converted. One particular case where this is useful is if the context is to be modified for the user process - to get some user space parameters. This also ensures that things like SparkSession are already correctly - setup before the type transformers are called - - This should return either the same context of the mutated context - """ - return user_params - - @abstractmethod - def execute(self, **kwargs) -> Any: - pass - - def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: - """ - Post execute is called after the execution has completed, with the user_params and can be used to clean-up, - or alter the outputs to match the intended tasks outputs. If not overriden, then this function is a No-op - - Args: - rval is returned value from call to execute - user_params: are the modified user params as created during the pre_execute step - """ - return rval - @property def environment(self) -> Dict[str, str]: return self._environment @@ -536,7 +527,7 @@ def name(self) -> str: pass @abstractmethod - def load_task(self, loader_args: List[str]) -> Task: + def load_task(self, loader_args: List[str]) -> ExecutableTaskMixin: """ Given the set of identifier keys, should return one Python Task or raise an error if not found """ diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 0951467ff7..eecceaa8dc 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -2,7 +2,7 @@ import importlib import re -from abc import abstractmethod +from abc import ABC from typing import Dict, List, Optional, TypeVar from flytekit.common.tasks.raw_container import _get_container_definition @@ -18,7 +18,7 @@ T = TypeVar("T") -class PythonAutoContainerTask(PythonTask[T], metaclass=FlyteTrackedABC): +class PythonAutoContainerTask(PythonTask[T], metaclass=FlyteTrackedABC): # noqa: doesn't realize it's ABC """ A Python AutoContainer task should be used as the base for all extensions that want the user's code to be in the container and the container information to be automatically captured. diff --git a/flytekit/core/python_third_party_task.py b/flytekit/core/python_third_party_task.py index c9ad9f68e8..186ec83459 100644 --- a/flytekit/core/python_third_party_task.py +++ b/flytekit/core/python_third_party_task.py @@ -2,18 +2,17 @@ import importlib import os -from typing import Any, Dict, Generic, List, Optional, TypeVar, Union +from typing import Any, Dict, List, Optional, TypeVar, Union from flyteidl.core import tasks_pb2 as _tasks_pb2 from flytekit.common import utils as common_utils from flytekit.common.tasks.raw_container import _get_container_definition -from flytekit.common.tasks.sdk_runnable import ExecutionParameters -from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin +from flytekit.core.base_task import ExecutableTaskMixin, PythonTask, Task, TaskResolverMixin from flytekit.core.context_manager import FlyteContext, Image, ImageConfig, SerializationSettings from flytekit.core.resources import Resources, ResourceSpec +from flytekit.core.task_executor import FlyteTaskExecutor from flytekit.core.tracker import TrackedInstance -from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import literals as _literal_models @@ -25,137 +24,6 @@ TC = TypeVar("TC") -class TaskTemplateExecutor(TrackedInstance, Generic[T]): - @classmethod - def execute_from_model(cls, tt: _task_model.TaskTemplate, **kwargs) -> Any: - raise NotImplementedError - - @classmethod - def pre_execute(cls, user_params: ExecutionParameters) -> ExecutionParameters: - """ - This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. - """ - return user_params - - @classmethod - def post_execute(cls, user_params: ExecutionParameters, rval: Any) -> Any: - """ - This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. - """ - return rval - - @classmethod - def dispatch_execute( - cls, ctx: FlyteContext, tt: _task_model.TaskTemplate, input_literal_map: _literal_models.LiteralMap - ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]: - """ - This function is copied from PythonTask.dispatch_execute. Will need to make it a mixin and refactor in the - future. - """ - - # Invoked before the task is executed - new_user_params = cls.pre_execute(ctx.user_space_params) - - # Create another execution context with the new user params, but let's keep the same working dir - with ctx.new_execution_context( - mode=ctx.execution_state.mode, - execution_params=new_user_params, - working_dir=ctx.execution_state.working_dir, - ) as exec_ctx: - # Added: Have to reverse the Python interface from the task template Flyte interface - # This will be moved into the FlyteTask promote logic instead - guessed_python_input_types = TypeEngine.guess_python_types(tt.interface.inputs) - native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, guessed_python_input_types) - - logger.info(f"Invoking FlyteTask executor {tt.id.name} with inputs: {native_inputs}") - try: - native_outputs = cls.execute_from_model(tt, **native_inputs) - except Exception as e: - logger.exception(f"Exception when executing {e}") - raise e - - logger.info(f"Task executed successfully in user level, outputs: {native_outputs}") - # Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is - # bubbled up to be handled at the callee layer. - native_outputs = cls.post_execute(new_user_params, native_outputs) - - # Short circuit the translation to literal map because what's returned may be a dj spec (or an - # already-constructed LiteralMap if the dynamic task was a no-op), not python native values - if isinstance(native_outputs, _literal_models.LiteralMap) or isinstance( - native_outputs, _dynamic_job.DynamicJobSpec - ): - return native_outputs - - expected_output_names = list(tt.interface.outputs.keys()) - if len(expected_output_names) == 1: - # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of - # length one. That convention is used for naming outputs - and single-length-NamedTuples are - # particularly troublesome but elegant handling of them is not a high priority - # Again, we're using the output_tuple_name as a proxy. - # Deleted some stuff - native_outputs_as_map = {expected_output_names[0]: native_outputs} - elif len(expected_output_names) == 0: - native_outputs_as_map = {} - else: - native_outputs_as_map = { - expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs) - } - - # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption - # built into the IDL that all the values of a literal map are of the same type. - literals = {} - for k, v in native_outputs_as_map.items(): - literal_type = tt.interface.outputs[k].type - py_type = type(v) - - if isinstance(v, tuple): - raise AssertionError(f"Output({k}) in task{tt.id.name} received a tuple {v}, instead of {py_type}") - try: - literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type) - except Exception as e: - raise AssertionError(f"failed to convert return value for var {k}") from e - - outputs_literal_map = _literal_models.LiteralMap(literals=literals) - # After the execute has been successfully completed - return outputs_literal_map - - -class ExecutorTask(object): - def __init__(self, tt: _task_model.TaskTemplate, executor: TaskTemplateExecutor): - self._executor = executor - self._task_template = tt - - @property - def task_template(self) -> _task_model.TaskTemplate: - return self._task_template - - @property - def executor(self) -> TaskTemplateExecutor: - return self._executor - - def execute(self, **kwargs) -> Any: - """ - This function overrides the default task execute behavior. - - Execution for third-party tasks is different from tasks that run the user workflow container. - 1. Serialize the task out to a TaskTemplate. - 2. Pass the template over to the Executor to run, along with the input arguments. - 3. Executor will reconstruct the Python task class object, before running the e - - When overridden for unit testing using the patch operator, all these steps will be skipped and the mocked code, - which should just take in and return Python native values, will be run. - """ - return self.executor.execute_from_model(self.task_template, **kwargs) - - def dispatch_execute( - self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap - ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]: - """ - This function overrides the default task execute behavior. - """ - return self.executor.dispatch_execute(ctx, self.task_template, input_literal_map) - - class PythonThirdPartyContainerTask(PythonTask[TC]): SERIALIZE_SETTINGS = SerializationSettings( project="PLACEHOLDER_PROJECT", @@ -170,8 +38,8 @@ def __init__( name: str, task_config: TC, container_image: str, - executor: TaskTemplateExecutor, - task_resolver: Optional[TaskTemplateResolver] = None, + executor: FlyteTaskExecutor, + task_resolver: Optional[FlyteTaskResolver] = None, task_type="python-task", requests: Optional[Resources] = None, limits: Optional[Resources] = None, @@ -224,7 +92,7 @@ def __init__( # Because instances of these tasks rely on the task template in order to run even locally, we'll cache it self._task_template = None - self._task_resolver = task_resolver or default_task_template_resolver + self._task_resolver = task_resolver or default_flyte_task_resolver def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: # Overriding base implementation to raise an error, force third-party task author to implement @@ -243,11 +111,11 @@ def resources(self) -> ResourceSpec: return self._resources @property - def executor(self) -> TaskTemplateExecutor: + def executor(self) -> FlyteTaskExecutor: return self._executor @property - def task_resolver(self) -> TaskTemplateResolver: + def task_resolver(self) -> FlyteTaskResolver: return self._task_resolver @property @@ -349,14 +217,16 @@ def load_executor(object_location: str) -> Any: return getattr(class_obj_mod, class_obj_key) -class TaskTemplateResolver(TrackedInstance, TaskResolverMixin): +class FlyteTaskResolver(TrackedInstance, TaskResolverMixin): def __init__(self): - super(TaskTemplateResolver, self).__init__() + super(FlyteTaskResolver, self).__init__() def name(self) -> str: return "task template resolver" - def load_task(self, loader_args: List[str]) -> Task: + def load_task(self, loader_args: List[str]) -> ExecutableTaskMixin: + from flytekit.control_plane.tasks.task import FlyteTask + logger.info(f"Task template loader args: {loader_args}") ctx = FlyteContext.current_context() task_template_local_path = os.path.join(ctx.execution_state.working_dir, "task_template.pb") @@ -365,7 +235,9 @@ def load_task(self, loader_args: List[str]) -> Task: task_template_model = _task_model.TaskTemplate.from_flyte_idl(task_template_proto) executor = load_executor(loader_args[1]) - return ExecutorTask(task_template_model, executor) + ft = FlyteTask.promote_from_model(task_template_model) + ft._executor = executor + return ft def loader_args(self, settings: SerializationSettings, t: PythonThirdPartyContainerTask) -> List[str]: return ["{{.taskTemplatePath}}", f"{t.executor.__module__}.{t.executor.__name__}"] @@ -374,4 +246,4 @@ def get_all_tasks(self) -> List[Task]: return [] -default_task_template_resolver = TaskTemplateResolver() +default_flyte_task_resolver = FlyteTaskResolver() diff --git a/flytekit/core/task_executor.py b/flytekit/core/task_executor.py new file mode 100644 index 0000000000..853852545b --- /dev/null +++ b/flytekit/core/task_executor.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import Any, Generic, TypeVar, Union + +from flytekit.common.tasks.sdk_runnable import ExecutionParameters +from flytekit.core.context_manager import FlyteContext +from flytekit.core.tracker import TrackedInstance +from flytekit.core.type_engine import TypeEngine +from flytekit.loggers import logger +from flytekit.models import dynamic_job as _dynamic_job +from flytekit.models import literals as _literal_models +from flytekit.models import task as _task_model + +T = TypeVar("T") + + +class FlyteTaskExecutor(TrackedInstance, Generic[T]): + @classmethod + def execute_from_model(cls, tt: _task_model.TaskTemplate, **kwargs) -> Any: + raise NotImplementedError + + @classmethod + def pre_execute(cls, user_params: ExecutionParameters) -> ExecutionParameters: + """ + This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. + """ + return user_params + + @classmethod + def post_execute(cls, user_params: ExecutionParameters, rval: Any) -> Any: + """ + This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. + """ + return rval + + @classmethod + def dispatch_execute( + cls, ctx: FlyteContext, tt: _task_model.TaskTemplate, input_literal_map: _literal_models.LiteralMap + ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]: + """ + This function is copied from PythonTask.dispatch_execute. Will need to make it a mixin and refactor in the + future. + """ + + # Invoked before the task is executed + new_user_params = cls.pre_execute(ctx.user_space_params) + + # Create another execution context with the new user params, but let's keep the same working dir + with ctx.new_execution_context( + mode=ctx.execution_state.mode, + execution_params=new_user_params, + working_dir=ctx.execution_state.working_dir, + ) as exec_ctx: + # Added: Have to reverse the Python interface from the task template Flyte interface + # This will be moved into the FlyteTask promote logic instead + guessed_python_input_types = TypeEngine.guess_python_types(tt.interface.inputs) + native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, guessed_python_input_types) + + logger.info(f"Invoking FlyteTask executor {tt.id.name} with inputs: {native_inputs}") + try: + native_outputs = cls.execute_from_model(tt, **native_inputs) + except Exception as e: + logger.exception(f"Exception when executing {e}") + raise e + + logger.info(f"Task executed successfully in user level, outputs: {native_outputs}") + # Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is + # bubbled up to be handled at the callee layer. + native_outputs = cls.post_execute(new_user_params, native_outputs) + + # Short circuit the translation to literal map because what's returned may be a dj spec (or an + # already-constructed LiteralMap if the dynamic task was a no-op), not python native values + if isinstance(native_outputs, _literal_models.LiteralMap) or isinstance( + native_outputs, _dynamic_job.DynamicJobSpec + ): + return native_outputs + + expected_output_names = list(tt.interface.outputs.keys()) + if len(expected_output_names) == 1: + # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of + # length one. That convention is used for naming outputs - and single-length-NamedTuples are + # particularly troublesome but elegant handling of them is not a high priority + # Again, we're using the output_tuple_name as a proxy. + # Deleted some stuff + native_outputs_as_map = {expected_output_names[0]: native_outputs} + elif len(expected_output_names) == 0: + native_outputs_as_map = {} + else: + native_outputs_as_map = { + expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs) + } + + # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption + # built into the IDL that all the values of a literal map are of the same type. + literals = {} + for k, v in native_outputs_as_map.items(): + literal_type = tt.interface.outputs[k].type + py_type = type(v) + + if isinstance(v, tuple): + raise AssertionError(f"Output({k}) in task{tt.id.name} received a tuple {v}, instead of {py_type}") + try: + literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type) + except Exception as e: + raise AssertionError(f"failed to convert return value for var {k}") from e + + outputs_literal_map = _literal_models.LiteralMap(literals=literals) + # After the execute has been successfully completed + return outputs_literal_map diff --git a/flytekit/extras/external_container_sqlite3/Dockerfile b/flytekit/extras/external_container_sqlite3/Dockerfile index 882d919234..d982af53ad 100644 --- a/flytekit/extras/external_container_sqlite3/Dockerfile +++ b/flytekit/extras/external_container_sqlite3/Dockerfile @@ -4,4 +4,4 @@ ENV FLYTE_INTERNAL_IMAGE=flytecli-sqlite3:v0.1.0 RUN pip install awscli -RUN pip install -U https://github.com/flyteorg/flytekit/archive/4089babbdc7797d8e12203bda4c83c172f1c5758.zip#egg=flytekit +RUN pip install -U https://github.com/flyteorg/flytekit/archive/ceb80db9a09928f64afded5fe9ff46ce7cce03fd.zip#egg=flytekit diff --git a/flytekit/extras/external_container_sqlite3/task.py b/flytekit/extras/external_container_sqlite3/task.py index 84a54779d2..3bf0d0bd97 100644 --- a/flytekit/extras/external_container_sqlite3/task.py +++ b/flytekit/extras/external_container_sqlite3/task.py @@ -11,7 +11,7 @@ from flytekit import FlyteContext, kwtypes from flytekit.core.base_sql_task import SQLTask from flytekit.core.context_manager import SerializationSettings -from flytekit.core.python_third_party_task import PythonThirdPartyContainerTask, TaskTemplateExecutor +from flytekit.core.python_third_party_task import FlyteTaskExecutor, PythonThirdPartyContainerTask from flytekit.models import task as task_models from flytekit.types.schema import FlyteSchema @@ -96,7 +96,7 @@ def get_custom(self, settings: SerializationSettings) -> typing.Dict[str, typing } -class SQLite3TaskExecutor(TaskTemplateExecutor[SQLite3Task]): +class SQLite3TaskExecutor(FlyteTaskExecutor[SQLite3Task]): @classmethod def execute_from_model(cls, tt: task_models.TaskTemplate, **kwargs) -> typing.Any: with tempfile.TemporaryDirectory() as temp_dir: diff --git a/plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py b/plugins/flytesqlalchemy/flytekitplugins/flytesqlalchemy/__init__.py similarity index 100% rename from plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py rename to plugins/flytesqlalchemy/flytekitplugins/flytesqlalchemy/__init__.py diff --git a/plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/flytesqlalchemy/flytekitplugins/flytesqlalchemy/task.py similarity index 100% rename from plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py rename to plugins/flytesqlalchemy/flytekitplugins/flytesqlalchemy/task.py diff --git a/plugins/sqlalchemy/setup.py b/plugins/flytesqlalchemy/setup.py similarity index 97% rename from plugins/sqlalchemy/setup.py rename to plugins/flytesqlalchemy/setup.py index e39b139552..ab296c3f3f 100644 --- a/plugins/sqlalchemy/setup.py +++ b/plugins/flytesqlalchemy/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -PLUGIN_NAME = "sqlalchemy" +PLUGIN_NAME = "flytesqlalchemy" microlib_name = f"flytekitplugins-{PLUGIN_NAME}" diff --git a/tests/flytekit/unit/extras/external_container_sqlite3/test_task.py b/tests/flytekit/unit/extras/external_container_sqlite3/test_task.py index 556f35efd5..b4df7d031b 100644 --- a/tests/flytekit/unit/extras/external_container_sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/external_container_sqlite3/test_task.py @@ -88,10 +88,10 @@ def test_task_serialization(): "--raw-output-data-prefix", "{{.rawOutputDataPrefix}}", "--resolver", - "flytekit.core.python_third_party_task.default_task_template_resolver", + "flytekit.core.python_third_party_task.default_flyte_task_resolver", "--", "{{.taskTemplatePath}}", - "flytekit.extras.sqlite3.task.SQLite3TaskExecutor", + "flytekit.extras.external_container_sqlite3.task.SQLite3TaskExecutor", ] assert tt.custom["query_template"] == "select TrackId, Name from tracks limit {{.inputs.limit}}"