From ad6c322008eb4d40a4ff3e8fe143349a152b5c76 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 4 May 2021 16:54:24 -0700 Subject: [PATCH] Customized container tasks and Shim tasks/executors (#449) --- flytekit/bin/entrypoint.py | 110 ++++---- flytekit/core/base_task.py | 77 +++++- flytekit/core/class_based_resolver.py | 3 +- flytekit/core/python_auto_container.py | 104 ++------ .../core/python_customized_container_task.py | 249 ++++++++++++++++++ flytekit/core/python_function_task.py | 39 +-- flytekit/core/shim_task.py | 163 ++++++++++++ flytekit/core/tracked_abc.py | 11 + flytekit/core/type_engine.py | 39 +++ flytekit/extend/__init__.py | 3 +- flytekit/extras/cloud_pickle_resolver.py | 3 +- flytekit/models/task.py | 2 +- flytekit/tools/module_loader.py | 13 +- flytekit/types/schema/types.py | 21 ++ .../flytekitplugins/sqlalchemy/__init__.py | 0 .../flytekitplugins/sqlalchemy/task.py | 0 .../setup.py | 0 plugins/setup.py | 2 +- tests/flytekit/unit/core/test_resolver.py | 8 +- tests/flytekit/unit/core/test_schema_types.py | 38 +++ tests/flytekit/unit/core/test_type_engine.py | 50 +++- .../flytekit/unit/tools/test_module_loader.py | 5 + 22 files changed, 745 insertions(+), 195 deletions(-) create mode 100644 flytekit/core/python_customized_container_task.py create mode 100644 flytekit/core/shim_task.py create mode 100644 flytekit/core/tracked_abc.py rename plugins/{sqlalchemy => flytekit.sqlalchemy}/flytekitplugins/sqlalchemy/__init__.py (100%) rename plugins/{sqlalchemy => flytekit.sqlalchemy}/flytekitplugins/sqlalchemy/task.py (100%) rename plugins/{sqlalchemy => flytekit.sqlalchemy}/setup.py (100%) create mode 100644 tests/flytekit/unit/core/test_schema_types.py diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 3c33b13f30..87a01014b6 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -1,3 +1,4 @@ +import contextlib import datetime as _datetime import importlib as _importlib import logging as _logging @@ -26,7 +27,6 @@ from flytekit.core.context_manager import ExecutionState, FlyteContext, SerializationSettings, get_image_config from flytekit.core.map_task import MapPythonTask from flytekit.core.promise import VoidPromise -from flytekit.core.python_auto_container import TaskResolverMixin from flytekit.engines import loader as _engine_loader from flytekit.interfaces import random as _flyte_random from flytekit.interfaces.data import data_proxy as _data_proxy @@ -38,6 +38,7 @@ from flytekit.models.core import errors as _error_models from flytekit.models.core import identifier as _identifier from flytekit.tools.fast_registration import download_distribution as _download_distribution +from flytekit.tools.module_loader import load_object_from_module def _compute_array_job_index(): @@ -73,7 +74,12 @@ def _map_job_index_to_child_index(local_input_dir, datadir, index): return mapping_proto.literals[index].scalar.primitive.integer -def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str, output_prefix: str): +def _dispatch_execute( + ctx: FlyteContext, + task_def: PythonTask, + inputs_path: str, + output_prefix: str, +): """ Dispatches execute to PythonTask Step1: Download inputs and load into a literal map @@ -90,6 +96,7 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str, ctx.file_access.get_data(inputs_path, local_inputs_file) input_proto = _utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file) idl_input_literals = 
_literal_models.LiteralMap.from_flyte_idl(input_proto) + # Step2 outputs = task_def.dispatch_execute(ctx, idl_input_literals) # Step3a @@ -122,7 +129,7 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str, _logging.warning(f"IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}") return # Step 3c - _logging.error(f"Exception when executing task {task_def.name}, reason {str(e)}") + _logging.error(f"Exception when executing task {task_def.name or task_def.id.name}, reason {str(e)}") _logging.error("!! Begin Unknown System Error Captured by Flyte !!") exc_str = _traceback.format_exc() output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( @@ -142,18 +149,12 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str, _logging.info(f"Engine folder written successfully to the output prefix {output_prefix}") -def _handle_annotated_task( - task_def: PythonTask, - inputs: str, - output_prefix: str, +@contextlib.contextmanager +def setup_execution( raw_output_data_prefix: str, dynamic_addl_distro: str = None, dynamic_dest_dir: str = None, ): - """ - Entrypoint for all PythonTask extensions - """ - _click.echo("Running native-typed task") cloud_provider = _platform_config.CLOUD_PROVIDER.get() log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get() _logging.getLogger().setLevel(log_level) @@ -235,7 +236,20 @@ def _handle_annotated_task( execution_params=execution_parameters, additional_context={"dynamic_addl_distro": dynamic_addl_distro, "dynamic_dest_dir": dynamic_dest_dir}, ) as ctx: - _dispatch_execute(ctx, task_def, inputs, output_prefix) + yield ctx + + +def _handle_annotated_task( + ctx: FlyteContext, + task_def: PythonTask, + inputs: str, + output_prefix: str, +): + """ + Entrypoint for all PythonTask extensions + """ + _click.echo("Running native-typed task") + _dispatch_execute(ctx, task_def, inputs, output_prefix) @_scopes.system_entry_point @@ -277,18 +291,6 @@ def _legacy_execute_task(task_module, task_name, inputs, output_prefix, raw_outp ) -def _load_resolver(resolver_location: str) -> TaskResolverMixin: - # Load the actual resolver - this cannot be a nested thing, whatever kind of resolver it is, it has to be loadable - # directly from importlib - # TODO: Handle corner cases, like where the first part is [] maybe - # e.g. flytekit.core.python_auto_container.default_task_resolver - resolver = resolver_location.split(".") - resolver_mod = resolver[:-1] # e.g. ['flytekit', 'core', 'python_auto_container'] - resolver_key = resolver[-1] # e.g. 'default_task_resolver' - resolver_mod = _importlib.import_module(".".join(resolver_mod)) - return getattr(resolver_mod, resolver_key) - - @_scopes.system_entry_point def _execute_task( inputs, @@ -326,18 +328,17 @@ def _execute_task( if len(resolver_args) < 1: raise Exception("cannot be <1") - resolver_obj = _load_resolver(resolver) with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()): - # Use the resolver to load the actual task object - _task_def = resolver_obj.load_task(loader_args=resolver_args) - if test: - _click.echo( - f"Test detected, returning. 
Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}" - ) - return - _handle_annotated_task( - _task_def, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir - ) + with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx: + resolver_obj = load_object_from_module(resolver) + # Use the resolver to load the actual task object + _task_def = resolver_obj.load_task(loader_args=resolver_args) + if test: + _click.echo( + f"Test detected, returning. Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}" + ) + return + _handle_annotated_task(ctx, _task_def, inputs, output_prefix) @_scopes.system_entry_point @@ -355,28 +356,27 @@ def _execute_map_task( if len(resolver_args) < 1: raise Exception(f"Resolver args cannot be <1, got {resolver_args}") - resolver_obj = _load_resolver(resolver) with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()): - # Use the resolver to load the actual task object - _task_def = resolver_obj.load_task(loader_args=resolver_args) - if not isinstance(_task_def, PythonFunctionTask): - raise Exception("Map tasks cannot be run with instance tasks.") - map_task = MapPythonTask(_task_def, max_concurrency) - - task_index = _compute_array_job_index() - output_prefix = _os.path.join(output_prefix, str(task_index)) - - if test: - _click.echo( - f"Test detected, returning. Inputs: {inputs} Computed task index: {task_index} " - f"New output prefix: {output_prefix} Raw output path: {raw_output_data_prefix} " - f"Resolver and args: {resolver} {resolver_args}" - ) - return + with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx: + resolver_obj = load_object_from_module(resolver) + # Use the resolver to load the actual task object + _task_def = resolver_obj.load_task(loader_args=resolver_args) + if not isinstance(_task_def, PythonFunctionTask): + raise Exception("Map tasks cannot be run with instance tasks.") + map_task = MapPythonTask(_task_def, max_concurrency) + + task_index = _compute_array_job_index() + output_prefix = _os.path.join(output_prefix, str(task_index)) + + if test: + _click.echo( + f"Test detected, returning. Inputs: {inputs} Computed task index: {task_index} " + f"New output prefix: {output_prefix} Raw output path: {raw_output_data_prefix} " + f"Resolver and args: {resolver} {resolver_args}" + ) + return - _handle_annotated_task( - map_task, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir - ) + _handle_annotated_task(ctx, map_task, inputs, output_prefix) @_click.group() diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 6a1bf0b7dd..e6572f4268 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -2,7 +2,7 @@ import datetime from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Generic, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.tasks.sdk_runnable import ExecutionParameters @@ -486,3 +486,78 @@ def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: @property def environment(self) -> Dict[str, str]: return self._environment + + +class TaskResolverMixin(object): + """ + Flytekit tasks interact with the Flyte platform very, very broadly in two steps. 
They need to be uploaded to Admin, + and then they are run by the user upon request (either as a single task execution or as part of a workflow). In any + case, at execution time, the container image containing the task needs to be spun up again (for container tasks at + least which most tasks are) at which point the container needs to know which task it's supposed to run and + how to rehydrate the task object. + + For example, the serialization of a simple task :: + + # in repo_root/workflows/example.py + @task + def t1(...) -> ...: ... + + might result in a container with arguments like :: + + pyflyte-execute --inputs s3://path/inputs.pb --output-prefix s3://outputs/location \ + --raw-output-data-prefix /tmp/data \ + --resolver flytekit.core.python_auto_container.default_task_resolver \ + -- \ + task-module repo_root.workflows.example task-name t1 + + At serialization time, the container created for the task will start out automatically with the ``pyflyte-execute`` + bit, along with the requisite input/output args and the offloaded data prefix. Appended to that will be two things, + + #. the ``location`` of the task's task resolver, followed by two dashes, followed by + #. the arguments provided by calling the ``loader_args`` function below. + + The ``default_task_resolver`` declared below knows that :: + + * When ``loader_args`` is called on a task, to look up the module the task is in, and the name of the task (the + key of the task in the module, either the function name, or the variable it was assigned to). + * When ``load_task`` is called, it interprets the first part of the command as the module to call + ``importlib.import_module`` on, and then looks for a key ``t1``. + + This is just the default behavior. Users should feel free to implement their own resolvers. + """ + + @property + @abstractmethod + def location(self) -> str: + pass + + @abstractmethod + def name(self) -> str: + pass + + @abstractmethod + def load_task(self, loader_args: List[str]) -> Task: + """ + Given the set of identifier keys, should return one Python Task or raise an error if not found + """ + pass + + @abstractmethod + def loader_args(self, settings: SerializationSettings, t: Task) -> List[str]: + """ + Return a list of strings that can help identify the parameter Task + """ + pass + + @abstractmethod + def get_all_tasks(self) -> List[Task]: + """ + Future proof method. 
Just making it easy to access all tasks (Not required today as we auto register them) + """ + pass + + def task_name(self, t: Task) -> Optional[str]: + """ + Overridable function that can optionally return a custom name for a given task + """ + return None diff --git a/flytekit/core/class_based_resolver.py b/flytekit/core/class_based_resolver.py index 7709fb2f52..33addbe598 100644 --- a/flytekit/core/class_based_resolver.py +++ b/flytekit/core/class_based_resolver.py @@ -1,7 +1,8 @@ from typing import List +from flytekit.core.base_task import TaskResolverMixin from flytekit.core.context_manager import SerializationSettings -from flytekit.core.python_auto_container import PythonAutoContainerTask, TaskResolverMixin +from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.tracker import TrackedInstance diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 6a574e5576..313e064d1c 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -2,13 +2,13 @@ import importlib import re -from abc import ABC, abstractmethod from typing import Dict, List, Optional, TypeVar from flytekit.common.tasks.raw_container import _get_container_definition -from flytekit.core.base_task import PythonTask +from flytekit.core.base_task import PythonTask, TaskResolverMixin from flytekit.core.context_manager import FlyteContext, ImageConfig, SerializationSettings from flytekit.core.resources import Resources, ResourceSpec +from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance from flytekit.loggers import logger from flytekit.models import task as _task_model @@ -17,14 +17,6 @@ T = TypeVar("T") -class FlyteTrackedABC(type(TrackedInstance), type(ABC)): - """ - This class exists because if you try to inherit from abc.ABC and TrackedInstance by itself, you'll get the - well-known ``TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass - of the metaclasses of all its bases`` error. - """ - - class PythonAutoContainerTask(PythonTask[T], metaclass=FlyteTrackedABC): """ A Python AutoContainer task should be used as the base for all extensions that want the user's code to be in the @@ -114,9 +106,22 @@ def container_image(self) -> Optional[str]: def resources(self) -> ResourceSpec: return self._resources - @abstractmethod def get_command(self, settings: SerializationSettings) -> List[str]: - pass + container_args = [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--resolver", + self.task_resolver.location, + "--", + *self.task_resolver.loader_args(settings, self), + ] + + return container_args def get_container(self, settings: SerializationSettings) -> _task_model.Container: env = {**settings.env, **self.environment} if self.environment else settings.env @@ -137,81 +142,6 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe ) -class TaskResolverMixin(object): - """ - Flytekit tasks interact with the Flyte platform very, very broadly in two steps. They need to be uploaded to Admin, - and then they are run by the user upon request (either as a single task execution or as part of a workflow). 
In any - case, at execution time, the container image containing the task needs to be spun up again (for container tasks at - least which most tasks are) at which point the container needs to know which task it's supposed to run and - how to rehydrate the task object. - - For example, the serialization of a simple task :: - - # in repo_root/workflows/example.py - @task - def t1(...) -> ...: ... - - might result in a container with arguments like :: - - pyflyte-execute --inputs s3://path/inputs.pb --output-prefix s3://outputs/location \ - --raw-output-data-prefix /tmp/data \ - --resolver flytekit.core.python_auto_container.default_task_resolver \ - -- \ - task-module repo_root.workflows.example task-name t1 - - At serialization time, the container created for the task will start out automatically with the ``pyflyte-execute`` - bit, along with the requisite input/output args and the offloaded data prefix. Appended to that will be two things, - - #. the ``location`` of the task's task resolver, followed by two dashes, followed by - #. the arguments provided by calling the ``loader_args`` function below. - - The ``default_task_resolver`` declared below knows that :: - - * When ``loader_args`` is called on a task, to look up the module the task is in, and the name of the task (the - key of the task in the module, either the function name, or the variable it was assigned to). - * When ``load_task`` is called, it interprets the first part of the command as the module to call - ``importlib.import_module`` on, and then looks for a key ``t1``. - - This is just the default behavior. Users should feel free to implement their own resolvers. - """ - - @property - @abstractmethod - def location(self) -> str: - pass - - @abstractmethod - def name(self) -> str: - pass - - @abstractmethod - def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: - """ - Given the set of identifier keys, should return one Python Task or raise an error if not found - """ - pass - - @abstractmethod - def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: - """ - Return a list of strings that can help identify the parameter Task - """ - pass - - @abstractmethod - def get_all_tasks(self) -> List[PythonAutoContainerTask]: - """ - Future proof method. Just making it easy to access all tasks (Not required today as we auto register them) - """ - pass - - def task_name(self, t: PythonAutoContainerTask) -> Optional[str]: - """ - Overridable function that can optionally return a custom name for a given task - """ - return None - - class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): """ Please see the notes in the TaskResolverMixin as it describes this default behavior. 
diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py new file mode 100644 index 0000000000..a615f2f6f5 --- /dev/null +++ b/flytekit/core/python_customized_container_task.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import os +from typing import Any, Dict, List, Optional, Type, TypeVar + +from flyteidl.core import tasks_pb2 as _tasks_pb2 + +from flytekit.common import utils as common_utils +from flytekit.common.tasks.raw_container import _get_container_definition +from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin +from flytekit.core.context_manager import FlyteContext, Image, ImageConfig, SerializationSettings +from flytekit.core.resources import Resources, ResourceSpec +from flytekit.core.shim_task import ExecutableTemplateShimTask, ShimTaskExecutor +from flytekit.core.tracker import TrackedInstance +from flytekit.loggers import logger +from flytekit.models import task as _task_model +from flytekit.models.core import identifier as identifier_models +from flytekit.models.security import Secret, SecurityContext +from flytekit.tools.module_loader import load_object_from_module + +TC = TypeVar("TC") + + +class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]): + """ + Please take a look at the comments for ``ExecutableTemplateShimTask`` as well. This class should be subclassed + and a custom Executor provided as a default to this parent class constructor when building a new external-container + flytekit-only plugin. + + This class provides authors of new task types the basic scaffolding to create task-template based tasks. In order + to write such a task, authors need to: + + * subclass the ``ShimTaskExecutor`` class and override the ``execute_from_model`` function. This function is + where all the business logic should go. Keep in mind though that you, the plugin author, will not have access + to anything that's not serialized within the ``TaskTemplate`` which is why you'll also need to + * subclass this class, and override the ``get_custom`` function to include all the information the executor + will need to run. + * Also pass the executor you created as the ``executor_type`` argument of this class's constructor. + + Keep in mind that the total size of the ``TaskTemplate`` still needs to be small, since these will be accessed + frequently by the Flyte engine. + """ + + SERIALIZE_SETTINGS = SerializationSettings( + project="PLACEHOLDER_PROJECT", + domain="LOCAL", + version="PLACEHOLDER_VERSION", + env=None, + image_config=ImageConfig( + default_image=Image(name="custom_container_task", fqn="flyteorg.io/placeholder", tag="image") + ), + ) + + def __init__( + self, + name: str, + task_config: TC, + container_image: str, + executor_type: Type[ShimTaskExecutor], + task_resolver: Optional[TaskTemplateResolver] = None, + task_type="python-task", + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, + environment: Optional[Dict[str, str]] = None, + secret_requests: Optional[List[Secret]] = None, + **kwargs, + ): + """ + :param name: unique name for the task, usually the function's module and name. + :param task_config: Configuration object for Task. Should be a unique type for that specific Task + :param container_image: This is the external container image the task should run at platform-run-time. + :param executor: This is an executor which will actually provide the business logic. 
+ :param task_resolver: Custom resolver - if you don't make one, use the default task template resolver. + :param task_type: String task type to be associated with this Task + :param requests: custom resource request settings. + :param limits: custom resource limit settings. + :param environment: Environment variables you want the task to have when run. + :param List[Secret] secret_requests: Secrets that are requested by this container execution. These secrets will + be mounted based on the configuration in the Secret and available through + the SecretManager using the name of the secret as the group + Ideally the secret keys should also be semi-descriptive. + The key values will be available from runtime, if the backend is configured + to provide secrets and if secrets are available in the configured secrets store. + Possible options for secret stores are + - `Vault `, + - `Confidant `, + - `Kube secrets ` + - `AWS Parameter store `_ + etc + """ + sec_ctx = None + if secret_requests: + for s in secret_requests: + if not isinstance(s, Secret): + raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") + sec_ctx = SecurityContext(secrets=secret_requests) + super().__init__( + tt=None, + executor_type=executor_type, + task_type=task_type, + name=name, + task_config=task_config, + security_ctx=sec_ctx, + **kwargs, + ) + self._resources = ResourceSpec( + requests=requests if requests else Resources(), limits=limits if limits else Resources() + ) + self._environment = environment + self._container_image = container_image + self._task_resolver = task_resolver or default_task_template_resolver + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + # Overriding base implementation to raise an error, force task author to implement + raise NotImplementedError + + def get_config(self, settings: SerializationSettings) -> Dict[str, str]: + # Overriding base implementation but not doing anything. Technically this should be the task config, + # but the IDL limitation that the value also has to be a string is very limiting. + # Recommend putting information you need in the config into custom instead, because when serializing + # the custom field, we jsonify custom and the place it into a protobuf struct. This config field + # just gets put into a Dict[str, str] + return {} + + @property + def resources(self) -> ResourceSpec: + return self._resources + + @property + def task_resolver(self) -> TaskTemplateResolver: + return self._task_resolver + + @property + def task_template(self) -> Optional[_task_model.TaskTemplate]: + """ + Override the base class implementation to serialize on first call. 
+ """ + return self._task_template or self.serialize_to_model(settings=PythonCustomizedContainerTask.SERIALIZE_SETTINGS) + + @property + def container_image(self) -> str: + return self._container_image + + def get_command(self, settings: SerializationSettings) -> List[str]: + container_args = [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--resolver", + self.task_resolver.location, + "--", + *self.task_resolver.loader_args(settings, self), + ] + + return container_args + + def get_container(self, settings: SerializationSettings) -> _task_model.Container: + env = {**settings.env, **self.environment} if self.environment else settings.env + return _get_container_definition( + image=self.container_image, + command=[], + args=self.get_command(settings=settings), + data_loading_config=None, + environment=env, + storage_request=self.resources.requests.storage, + cpu_request=self.resources.requests.cpu, + gpu_request=self.resources.requests.gpu, + memory_request=self.resources.requests.mem, + storage_limit=self.resources.limits.storage, + cpu_limit=self.resources.limits.cpu, + gpu_limit=self.resources.limits.gpu, + memory_limit=self.resources.limits.mem, + ) + + def serialize_to_model(self, settings: SerializationSettings) -> _task_model.TaskTemplate: + # This doesn't get called from translator unfortunately. Will need to move the translator to use the model + # objects directly first. + # Note: This doesn't settle the issue of duplicate registrations. We'll need to figure that out somehow. + # TODO: After new control plane classes are in, promote the template to a FlyteTask, so that authors of + # customized-container tasks have a familiar thing to work with. + obj = _task_model.TaskTemplate( + identifier_models.Identifier( + identifier_models.ResourceType.TASK, settings.project, settings.domain, self.name, settings.version + ), + self.task_type, + self.metadata.to_taskmetadata_model(), + self.interface, + self.get_custom(settings), + container=self.get_container(settings), + config=self.get_config(settings), + ) + self._task_template = obj + return obj + + +class TaskTemplateResolver(TrackedInstance, TaskResolverMixin): + """ + This is a special resolver that resolves the task above at execution time, using only the ``TaskTemplate``, + meaning it should only be used for tasks that contain all pertinent information within the template itself. + + This class differs from some TaskResolverMixin pattern a bit. Most of the other resolvers you'll find, + + * restores the same task when ``load_task`` is called as the object that ``loader_args`` was called on. + That is, even though at run time it's in a container on a cluster and is obviously a different Python process, + the Python object in memory should look the same. + * offers a one-to-one mapping between the list of strings returned by the ``loader_args`` function, an the task, + at least within the container. + + This resolver differs in that, + * when loading a task, the task that is a loaded is always an ``ExecutableTemplateShimTask``, regardless of what + kind of task it was originally. It will only ever have what's available to it from the ``TaskTemplate``. No + information that wasn't serialized into the template will be available. + * all tasks will result in the same list of strings for a given subclass of the ``ShimTaskExecutor`` + executor. 
The strings will be ``["{{.taskTemplatePath}}", "path.to.your.executor"]`` + + Also, ``get_all_tasks`` will always return an empty list, at least for now. + """ + + def __init__(self): + super(TaskTemplateResolver, self).__init__() + + def name(self) -> str: + return "task template resolver" + + # The return type of this function is different, it should be a Task, but it's not because it doesn't make + # sense for ExecutableTemplateShimTask to inherit from Task. + def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: + logger.info(f"Task template loader args: {loader_args}") + ctx = FlyteContext.current_context() + task_template_local_path = os.path.join(ctx.execution_state.working_dir, "task_template.pb") + ctx.file_access.get_data(loader_args[0], task_template_local_path) + task_template_proto = common_utils.load_proto_from_file(_tasks_pb2.TaskTemplate, task_template_local_path) + task_template_model = _task_model.TaskTemplate.from_flyte_idl(task_template_proto) + + executor_class = load_object_from_module(loader_args[1]) + return ExecutableTemplateShimTask(task_template_model, executor_class) + + def loader_args(self, settings: SerializationSettings, t: PythonCustomizedContainerTask) -> List[str]: + return ["{{.taskTemplatePath}}", f"{t.executor_type.__module__}.{t.executor_type.__name__}"] + + def get_all_tasks(self) -> List[Task]: + return [] + + +default_task_template_resolver = TaskTemplateResolver() diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index eba817084e..3ba8b0fc76 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -4,9 +4,10 @@ from enum import Enum from typing import Any, Callable, List, Optional, TypeVar, Union -from flytekit.core.context_manager import ExecutionState, FlyteContext, SerializationSettings +from flytekit.core.base_task import TaskResolverMixin +from flytekit.core.context_manager import ExecutionState, FlyteContext from flytekit.core.interface import transform_signature_to_interface -from flytekit.core.python_auto_container import PythonAutoContainerTask, TaskResolverMixin, default_task_resolver +from flytekit.core.python_auto_container import PythonAutoContainerTask, default_task_resolver from flytekit.core.tracker import isnested, istestfunction from flytekit.core.workflow import ( PythonFunctionWorkflow, @@ -46,23 +47,6 @@ def __init__( ): super().__init__(name=name, task_config=task_config, task_type=task_type, task_resolver=task_resolver, **kwargs) - def get_command(self, settings: SerializationSettings) -> List[str]: - container_args = [ - "pyflyte-execute", - "--inputs", - "{{.input}}", - "--output-prefix", - "{{.outputPrefix}}", - "--raw-output-data-prefix", - "{{.rawOutputDataPrefix}}", - "--resolver", - self.task_resolver.location, - "--", - *self.task_resolver.loader_args(settings, self), - ] - - return container_args - class PythonFunctionTask(PythonAutoContainerTask[T]): """ @@ -147,23 +131,6 @@ def execute(self, **kwargs) -> Any: elif self.execution_mode == self.ExecutionBehavior.DYNAMIC: return self.dynamic_execute(self._task_function, **kwargs) - def get_command(self, settings: SerializationSettings) -> List[str]: - container_args = [ - "pyflyte-execute", - "--inputs", - "{{.input}}", - "--output-prefix", - "{{.outputPrefix}}", - "--raw-output-data-prefix", - "{{.rawOutputDataPrefix}}", - "--resolver", - self.task_resolver.location, - "--", - *self.task_resolver.loader_args(settings, self), - ] - - return container_args - def 
compile_into_workflow( self, ctx: FlyteContext, is_fast_execution: bool, task_function: Callable, **kwargs ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]: diff --git a/flytekit/core/shim_task.py b/flytekit/core/shim_task.py new file mode 100644 index 0000000000..c8349c0339 --- /dev/null +++ b/flytekit/core/shim_task.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from typing import Any, Generic, Type, TypeVar, Union + +from flytekit import ExecutionParameters, FlyteContext, logger +from flytekit.core.tracker import TrackedInstance +from flytekit.core.type_engine import TypeEngine +from flytekit.models import dynamic_job as _dynamic_job +from flytekit.models import literals as _literal_models +from flytekit.models import task as _task_model + + +class ExecutableTemplateShimTask(object): + """ + The canonical ``@task`` decorated Python function task is pretty simple to reason about. At execution time (either + locally or on a Flyte cluster), the function runs. + + This class, along with the ``ShimTaskExecutor`` class below, represents another execution pattern. This pattern, + has two components: + * The ``TaskTemplate``, or something like it like a ``FlyteTask``. + * An executor, which can use information from the task template (including the ``custom`` field) + + Basically at execution time (both locally and on a Flyte cluster), the task template is given to the executor, + which is responsible for computing and returning the results. + + .. note:: + + The interface at execution time will have to derived from the Flyte IDL interface, which means it may be lossy. + This is because when a task is serialized from Python into the ``TaskTemplate`` some information is lost because + Flyte IDL can't keep track of every single Python type (or Java type if writing in the Java flytekit). + + This class also implements the ``dispatch_execute`` and ``execute`` functions to make it look like a ``PythonTask`` + that the ``entrypoint.py`` can execute, even though this class doesn't inherit from ``PythonTask``. + """ + + def __init__(self, tt: _task_model.TaskTemplate, executor_type: Type[ShimTaskExecutor], *args, **kwargs): + self._executor_type = executor_type + self._executor = executor_type() + self._task_template = tt + super().__init__(*args, **kwargs) + + @property + def task_template(self) -> _task_model.TaskTemplate: + return self._task_template + + @property + def executor(self) -> ShimTaskExecutor: + return self._executor + + @property + def executor_type(self) -> Type[ShimTaskExecutor]: + return self._executor_type + + def execute(self, **kwargs) -> Any: + """ + Send things off to the executor instead of running here. + """ + return self.executor.execute_from_model(self.task_template, **kwargs) + + def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: + """ + This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. + """ + return user_params + + def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + """ + This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. + """ + return rval + + def dispatch_execute( + self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap + ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]: + """ + This function is mostly copied from the base PythonTask, but differs in that we have to infer the Python + interface before executing. 
Also, we refer to ``self.task_template`` rather than just ``self`` like in task + classes that derive from the base ``PythonTask``. + """ + # Invoked before the task is executed + new_user_params = self.pre_execute(ctx.user_space_params) + + # Create another execution context with the new user params, but let's keep the same working dir + with ctx.new_execution_context( + mode=ctx.execution_state.mode, + execution_params=new_user_params, + working_dir=ctx.execution_state.working_dir, + ) as exec_ctx: + # Added: Have to reverse the Python interface from the task template Flyte interface + # See docstring for more details. + guessed_python_input_types = TypeEngine.guess_python_types(self.task_template.interface.inputs) + native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, guessed_python_input_types) + + logger.info(f"Invoking FlyteTask executor {self.task_template.id.name} with inputs: {native_inputs}") + try: + native_outputs = self.execute(**native_inputs) + except Exception as e: + logger.exception(f"Exception when executing {e}") + raise e + + logger.info(f"Task executed successfully in user level, outputs: {native_outputs}") + # Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is + # bubbled up to be handled at the callee layer. + native_outputs = self.post_execute(new_user_params, native_outputs) + + # Short circuit the translation to literal map because what's returned may be a dj spec (or an + # already-constructed LiteralMap if the dynamic task was a no-op), not python native values + if isinstance(native_outputs, _literal_models.LiteralMap) or isinstance( + native_outputs, _dynamic_job.DynamicJobSpec + ): + return native_outputs + + expected_output_names = list(self.task_template.interface.outputs.keys()) + if len(expected_output_names) == 1: + # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of + # length one. That convention is used for naming outputs - and single-length-NamedTuples are + # particularly troublesome but elegant handling of them is not a high priority + # Again, we're using the output_tuple_name as a proxy. + # Deleted some stuff + native_outputs_as_map = {expected_output_names[0]: native_outputs} + elif len(expected_output_names) == 0: + native_outputs_as_map = {} + else: + native_outputs_as_map = { + expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs) + } + + # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption + # built into the IDL that all the values of a literal map are of the same type. 
+ literals = {} + for k, v in native_outputs_as_map.items(): + literal_type = self.task_template.interface.outputs[k].type + py_type = type(v) + + if isinstance(v, tuple): + raise AssertionError( + f"Output({k}) in task{self.task_template.id.name} received a tuple {v}, instead of {py_type}" + ) + try: + literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type) + except Exception as e: + raise AssertionError(f"failed to convert return value for var {k}") from e + + outputs_literal_map = _literal_models.LiteralMap(literals=literals) + # After the execute has been successfully completed + return outputs_literal_map + + +T = TypeVar("T") + + +class ShimTaskExecutor(TrackedInstance, Generic[T]): + def execute_from_model(self, tt: _task_model.TaskTemplate, **kwargs) -> Any: + """ + This function must be overridden and is where all the business logic for running a task should live. Keep in + mind that you're only working with the ``TaskTemplate``. You won't have access to any information in the task + that wasn't serialized into the template. + + :param tt: This is the template, the serialized form of the task. + :param kwargs: These are the Python native input values to the task. + :return: Python native output values from the task. + """ + raise NotImplementedError diff --git a/flytekit/core/tracked_abc.py b/flytekit/core/tracked_abc.py new file mode 100644 index 0000000000..bad4f8c555 --- /dev/null +++ b/flytekit/core/tracked_abc.py @@ -0,0 +1,11 @@ +from abc import ABC + +from flytekit.core.tracker import TrackedInstance + + +class FlyteTrackedABC(type(TrackedInstance), type(ABC)): + """ + This class exists because if you try to inherit from abc.ABC and TrackedInstance by itself, you'll get the + well-known ``TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass + of the metaclasses of all its bases`` error. 
+ """ diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 0c7f7e12cb..9ac21d8bcd 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -19,6 +19,7 @@ from flytekit.common.types import primitives as _primitives from flytekit.core.context_manager import FlyteContext +from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types @@ -63,6 +64,9 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: """ raise NotImplementedError("Conversion to LiteralType should be implemented") + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + raise ValueError("By default, transformers do not translate from Flyte types back to Python types") + @abstractmethod def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: """ @@ -123,6 +127,11 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: return self._from_literal_transformer(lv) + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + if literal_type.simple is not None and literal_type.simple == self._lt.simple: + return self.python_type + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + class RestrictedTypeError(Exception): pass @@ -303,6 +312,24 @@ def get_available_transformers(cls) -> typing.KeysView[Type]: """ return cls._REGISTRY.keys() + @classmethod + def guess_python_types( + cls, flyte_variable_dict: typing.Dict[str, _interface_models.Variable] + ) -> typing.Dict[str, type]: + python_types = {} + for k, v in flyte_variable_dict.items(): + python_types[k] = cls.guess_python_type(v.type) + return python_types + + @classmethod + def guess_python_type(cls, flyte_type: LiteralType) -> type: + for _, transformer in cls._REGISTRY.items(): + try: + return transformer.guess_python_type(flyte_type) + except ValueError: + logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}") + raise ValueError(f"No transformers could reverse Flyte literal type {flyte_type}") + class ListTransformer(TypeTransformer[T]): """ @@ -341,6 +368,12 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: st = self.get_sub_type(expected_python_type) return [TypeEngine.to_python_value(ctx, x, st) for x in lv.collection.literals] + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + if literal_type.collection_type: + ct = TypeEngine.guess_python_type(literal_type.collection_type) + return typing.List[ct] + raise ValueError(f"List transformer cannot reverse {literal_type}") + class DictTransformer(TypeTransformer[dict]): """ @@ -412,6 +445,12 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return _json.loads(_json_format.MessageToJson(lv.scalar.generic)) raise TypeError(f"Cannot convert from {lv} to {expected_python_type}") + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + if literal_type.map_value_type: + mt = TypeEngine.guess_python_type(literal_type.map_value_type) + return typing.Dict[str, mt] + raise ValueError(f"Dictionary transformer cannot reverse {literal_type}") + class TextIOTransformer(TypeTransformer[typing.TextIO]): """ diff --git a/flytekit/extend/__init__.py b/flytekit/extend/__init__.py index 3b3d05f4b4..ef0b69d5df 100644 --- 
a/flytekit/extend/__init__.py +++ b/flytekit/extend/__init__.py @@ -32,11 +32,10 @@ from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.base_sql_task import SQLTask -from flytekit.core.base_task import IgnoreOutputs, PythonTask +from flytekit.core.base_task import IgnoreOutputs, PythonTask, TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver from flytekit.core.context_manager import ExecutionState, Image, ImageConfig, SerializationSettings from flytekit.core.interface import Interface from flytekit.core.promise import Promise -from flytekit.core.python_auto_container import TaskResolverMixin from flytekit.core.task import TaskPlugins from flytekit.core.type_engine import DictTransformer, T, TypeEngine, TypeTransformer diff --git a/flytekit/extras/cloud_pickle_resolver.py b/flytekit/extras/cloud_pickle_resolver.py index d6e5c3a008..3ea5fb0e5c 100644 --- a/flytekit/extras/cloud_pickle_resolver.py +++ b/flytekit/extras/cloud_pickle_resolver.py @@ -3,8 +3,9 @@ import cloudpickle # intentionally not yet part of setup.py +from flytekit.core.base_task import TaskResolverMixin from flytekit.core.context_manager import SerializationSettings -from flytekit.core.python_auto_container import PythonAutoContainerTask, TaskResolverMixin +from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.tracker import TrackedInstance diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 8e34b90cf7..f48ade4bf4 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -372,7 +372,7 @@ def metadata(self): def interface(self): """ The interface definition for this task. - :rtype: flytekit.common.interface.TypedInterface + :rtype: flytekit.models.interface.TypedInterface """ return self._interface diff --git a/flytekit/tools/module_loader.py b/flytekit/tools/module_loader.py index 7149e532af..182b11de72 100644 --- a/flytekit/tools/module_loader.py +++ b/flytekit/tools/module_loader.py @@ -3,7 +3,7 @@ import os import pkgutil import sys -from typing import Iterator, List, Union +from typing import Any, Iterator, List, Union from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.local_workflow import SdkRunnableWorkflow as _SdkRunnableWorkflow @@ -190,3 +190,14 @@ def iterate_registerable_entities_in_order( detect_unreferenced_entities=detect_unreferenced_entities, ): yield m, k, o2 + + +def load_object_from_module(object_location: str) -> Any: + """ + # TODO: Handle corner cases, like where the first part is [] maybe + """ + class_obj = object_location.split(".") + class_obj_mod = class_obj[:-1] # e.g. ['flytekit', 'core', 'python_auto_container'] + class_obj_key = class_obj[-1] # e.g. 
'default_task_class_obj' + class_obj_mod = importlib.import_module(".".join(class_obj_mod)) + return getattr(class_obj_mod, class_obj_key) diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index d51f572b4e..8d073e809a 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -374,5 +374,26 @@ def downloader(x, y): supported_mode=SchemaOpenMode.READ, ) + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + if not literal_type.schema: + raise ValueError(f"Cannot reverse {literal_type}") + columns = {} + for literal_column in literal_type.schema.columns: + if literal_column.type == SchemaType.SchemaColumn.SchemaColumnType.INTEGER: + columns[literal_column.name] = int + elif literal_column.type == SchemaType.SchemaColumn.SchemaColumnType.FLOAT: + columns[literal_column.name] = float + elif literal_column.type == SchemaType.SchemaColumn.SchemaColumnType.STRING: + columns[literal_column.name] = str + elif literal_column.type == SchemaType.SchemaColumn.SchemaColumnType.DATETIME: + columns[literal_column.name] = _datetime.datetime + elif literal_column.type == SchemaType.SchemaColumn.SchemaColumnType.DURATION: + columns[literal_column.name] = _datetime.timedelta + elif literal_column.type == SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN: + columns[literal_column.name] = bool + else: + raise ValueError(f"Unknown schema column type {literal_column}") + return FlyteSchema[columns] + TypeEngine.register(FlyteSchemaTransformer()) diff --git a/plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py b/plugins/flytekit.sqlalchemy/flytekitplugins/sqlalchemy/__init__.py similarity index 100% rename from plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py rename to plugins/flytekit.sqlalchemy/flytekitplugins/sqlalchemy/__init__.py diff --git a/plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/flytekit.sqlalchemy/flytekitplugins/sqlalchemy/task.py similarity index 100% rename from plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py rename to plugins/flytekit.sqlalchemy/flytekitplugins/sqlalchemy/task.py diff --git a/plugins/sqlalchemy/setup.py b/plugins/flytekit.sqlalchemy/setup.py similarity index 100% rename from plugins/sqlalchemy/setup.py rename to plugins/flytekit.sqlalchemy/setup.py diff --git a/plugins/setup.py b/plugins/setup.py index 48f302bc5a..958fe0ffd0 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -15,7 +15,7 @@ "flytekitplugins-awssagemaker": "awssagemaker", "flytekitplugins-kftensorflow": "kftensorflow", "flytekitplugins-pandera": "pandera", - "flytekitplugins-sqlalchemy": "sqlalchemy", + "flytekitplugins-sqlalchemy": "flytekit.sqlalchemy", "flytekitplugins-dolt": "flytekitplugins.dolt", } diff --git a/tests/flytekit/unit/core/test_resolver.py b/tests/flytekit/unit/core/test_resolver.py index 0077c2b373..15d69ceffa 100644 --- a/tests/flytekit/unit/core/test_resolver.py +++ b/tests/flytekit/unit/core/test_resolver.py @@ -4,9 +4,10 @@ from flytekit.common.translator import get_serializable from flytekit.core import context_manager +from flytekit.core.base_task import TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver from flytekit.core.context_manager import Image, ImageConfig -from flytekit.core.python_auto_container import TaskResolverMixin +from flytekit.core.python_auto_container import default_task_resolver from flytekit.core.task import task from flytekit.core.workflow import workflow @@ -101,3 +102,8 @@ def test_mixin(): x.loader_args(None, None) 
x.get_all_tasks() x.load_task([]) + + +def test_error(): + with pytest.raises(Exception): + default_task_resolver.get_all_tasks() diff --git a/tests/flytekit/unit/core/test_schema_types.py b/tests/flytekit/unit/core/test_schema_types.py new file mode 100644 index 0000000000..dce40a3f3e --- /dev/null +++ b/tests/flytekit/unit/core/test_schema_types.py @@ -0,0 +1,38 @@ +from datetime import datetime, timedelta + +import pytest + +from flytekit import kwtypes +from flytekit.core.type_engine import TypeEngine +from flytekit.types.schema import FlyteSchema, SchemaFormat + + +def test_typed_schema(): + s = FlyteSchema[kwtypes(x=int, y=float)] + assert s.format() == SchemaFormat.PARQUET + assert s.columns() == {"x": int, "y": float} + + +def test_schema_back_and_forth(): + orig = FlyteSchema[kwtypes(TrackId=int, Name=str)] + lt = TypeEngine.to_literal_type(orig) + pt = TypeEngine.guess_python_type(lt) + lt2 = TypeEngine.to_literal_type(pt) + assert lt == lt2 + + +def test_remaining_prims(): + orig = FlyteSchema[kwtypes(my_dt=datetime, my_td=timedelta, my_b=bool)] + lt = TypeEngine.to_literal_type(orig) + pt = TypeEngine.guess_python_type(lt) + lt2 = TypeEngine.to_literal_type(pt) + assert lt == lt2 + + +def test_bad_conversion(): + orig = FlyteSchema[kwtypes(my_custom=bool)] + lt = TypeEngine.to_literal_type(orig) + # Make a not real column type + lt.schema.columns[0]._type = 15 + with pytest.raises(ValueError): + TypeEngine.guess_python_type(lt) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 04a9e57160..26d262f6f0 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -6,7 +6,6 @@ import pytest from flyteidl.core import errors_pb2 -from flytekit import kwtypes from flytekit.core.context_manager import FlyteContext from flytekit.core.type_engine import ( DictTransformer, @@ -20,7 +19,6 @@ from flytekit.models.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar from flytekit.models.types import LiteralType, SimpleType from flytekit.types.file.file import FlyteFile -from flytekit.types.schema import FlyteSchema, SchemaFormat def test_type_engine(): @@ -92,12 +90,6 @@ def test_file_format_getting_python_value(): assert pv.extension() == "txt" -def test_typed_schema(): - s = FlyteSchema[kwtypes(x=int, y=float)] - assert s.format() == SchemaFormat.PARQUET - assert s.columns() == {"x": int, "y": float} - - def test_dict_transformer(): d = DictTransformer() @@ -199,6 +191,48 @@ def test_protos(): TypeEngine.to_python_value(ctx, l0, errors_pb2.ContainerError) +def test_guessing_basic(): + b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN) + pt = TypeEngine.guess_python_type(b) + assert pt is bool + + lt = model_types.LiteralType(simple=model_types.SimpleType.INTEGER) + pt = TypeEngine.guess_python_type(lt) + assert pt is int + + lt = model_types.LiteralType(simple=model_types.SimpleType.STRING) + pt = TypeEngine.guess_python_type(lt) + assert pt is str + + lt = model_types.LiteralType(simple=model_types.SimpleType.DURATION) + pt = TypeEngine.guess_python_type(lt) + assert pt is timedelta + + lt = model_types.LiteralType(simple=model_types.SimpleType.DATETIME) + pt = TypeEngine.guess_python_type(lt) + assert pt is datetime.datetime + + lt = model_types.LiteralType(simple=model_types.SimpleType.FLOAT) + pt = TypeEngine.guess_python_type(lt) + assert pt is float + + lt = model_types.LiteralType(simple=model_types.SimpleType.NONE) + 
pt = TypeEngine.guess_python_type(lt) + assert pt is None + + +def test_guessing_containers(): + b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN) + lt = model_types.LiteralType(collection_type=b) + pt = TypeEngine.guess_python_type(lt) + assert pt == typing.List[bool] + + dur = model_types.LiteralType(simple=model_types.SimpleType.DURATION) + lt = model_types.LiteralType(map_value_type=dur) + pt = TypeEngine.guess_python_type(lt) + assert pt == typing.Dict[str, timedelta] + + def test_zero_floats(): ctx = FlyteContext.current_context() diff --git a/tests/flytekit/unit/tools/test_module_loader.py b/tests/flytekit/unit/tools/test_module_loader.py index 5b0df20416..aa7fdd255c 100644 --- a/tests/flytekit/unit/tools/test_module_loader.py +++ b/tests/flytekit/unit/tools/test_module_loader.py @@ -37,3 +37,8 @@ def test_module_loading(): assert [ pkg.__file__ for pkg in module_loader.iterate_modules(["top.a", "top.middle.a", "top.middle.bottom.a"]) ] == [os.path.join(lvl, "a.py") for lvl in (top_level, middle_level, bottom_level)] + + +def test_load_object(): + loader_self = module_loader.load_object_from_module(f"{module_loader.__name__}.load_object_from_module") + assert loader_self.__module__ == f"{module_loader.__name__}"
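
Editor's note: the docstrings added in ``python_customized_container_task.py`` and ``shim_task.py`` above describe the intended plugin-author workflow (subclass ``ShimTaskExecutor`` and override ``execute_from_model``, then subclass ``PythonCustomizedContainerTask`` and override ``get_custom``, passing the executor as ``executor_type``). A minimal, hypothetical sketch of that workflow follows; ``EchoTask``, ``EchoExecutor``, the module layout, and the image name are made up for illustration and are not part of this patch ::

    # hypothetical example -- assumes this file lives in a real module inside the container image,
    # since the executor is re-imported at run time by its "module.ClassName" path.
    from typing import Any, Dict

    from flytekit.core.context_manager import SerializationSettings
    from flytekit.core.interface import Interface
    from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask
    from flytekit.core.shim_task import ShimTaskExecutor
    from flytekit.models import task as task_models


    class EchoExecutor(ShimTaskExecutor):
        # All business logic lives here. At execution time only the serialized
        # TaskTemplate is available -- nothing else from the original Python object.
        def execute_from_model(self, tt: task_models.TaskTemplate, **kwargs) -> Any:
            # Read back whatever the task author stashed in the custom field at serialization time.
            return tt.custom["message"]


    class EchoTask(PythonCustomizedContainerTask[dict]):
        def __init__(self, name: str, message: str, container_image: str, **kwargs):
            super().__init__(
                name=name,
                task_config={},
                container_image=container_image,
                executor_type=EchoExecutor,
                task_type="echo",
                interface=Interface(inputs={}, outputs={"o0": str}),
                **kwargs,
            )
            self._message = message

        # Everything the executor will need must fit into the (small) custom field
        # of the TaskTemplate, since that is all the TaskTemplateResolver rehydrates.
        def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
            return {"message": self._message}


    # Instantiating the task is enough for registration; at serialization time the
    # container args will carry the task template path and "this_module.EchoExecutor".
    echo = EchoTask(name="echo_hello", message="hello", container_image="ghcr.io/example/echo:latest")

The key design point the patch makes, and which the sketch tries to reflect, is that the executor never sees the original ``EchoTask`` instance: the ``TaskTemplateResolver`` loads an ``ExecutableTemplateShimTask`` from the serialized template alone, so anything not returned by ``get_custom`` (or otherwise present on the template) is unavailable at run time.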