diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py
index 3c33b13f30..87a01014b6 100644
--- a/flytekit/bin/entrypoint.py
+++ b/flytekit/bin/entrypoint.py
@@ -1,3 +1,4 @@
+import contextlib
import datetime as _datetime
import importlib as _importlib
import logging as _logging
@@ -26,7 +27,6 @@
from flytekit.core.context_manager import ExecutionState, FlyteContext, SerializationSettings, get_image_config
from flytekit.core.map_task import MapPythonTask
from flytekit.core.promise import VoidPromise
-from flytekit.core.python_auto_container import TaskResolverMixin
from flytekit.engines import loader as _engine_loader
from flytekit.interfaces import random as _flyte_random
from flytekit.interfaces.data import data_proxy as _data_proxy
@@ -38,6 +38,7 @@
from flytekit.models.core import errors as _error_models
from flytekit.models.core import identifier as _identifier
from flytekit.tools.fast_registration import download_distribution as _download_distribution
+from flytekit.tools.module_loader import load_object_from_module
def _compute_array_job_index():
@@ -73,7 +74,12 @@ def _map_job_index_to_child_index(local_input_dir, datadir, index):
return mapping_proto.literals[index].scalar.primitive.integer
-def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str, output_prefix: str):
+def _dispatch_execute(
+ ctx: FlyteContext,
+ task_def: PythonTask,
+ inputs_path: str,
+ output_prefix: str,
+):
"""
Dispatches execute to PythonTask
Step1: Download inputs and load into a literal map
@@ -90,6 +96,7 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str,
ctx.file_access.get_data(inputs_path, local_inputs_file)
input_proto = _utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file)
idl_input_literals = _literal_models.LiteralMap.from_flyte_idl(input_proto)
+
# Step2
outputs = task_def.dispatch_execute(ctx, idl_input_literals)
# Step3a
@@ -122,7 +129,7 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str,
_logging.warning(f"IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}")
return
# Step 3c
- _logging.error(f"Exception when executing task {task_def.name}, reason {str(e)}")
+ _logging.error(f"Exception when executing task {task_def.name or task_def.id.name}, reason {str(e)}")
_logging.error("!! Begin Unknown System Error Captured by Flyte !!")
exc_str = _traceback.format_exc()
output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
@@ -142,18 +149,12 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str,
_logging.info(f"Engine folder written successfully to the output prefix {output_prefix}")
-def _handle_annotated_task(
- task_def: PythonTask,
- inputs: str,
- output_prefix: str,
+@contextlib.contextmanager
+def setup_execution(
raw_output_data_prefix: str,
dynamic_addl_distro: str = None,
dynamic_dest_dir: str = None,
):
- """
- Entrypoint for all PythonTask extensions
- """
- _click.echo("Running native-typed task")
cloud_provider = _platform_config.CLOUD_PROVIDER.get()
log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
_logging.getLogger().setLevel(log_level)
@@ -235,7 +236,20 @@ def _handle_annotated_task(
execution_params=execution_parameters,
additional_context={"dynamic_addl_distro": dynamic_addl_distro, "dynamic_dest_dir": dynamic_dest_dir},
) as ctx:
- _dispatch_execute(ctx, task_def, inputs, output_prefix)
+ yield ctx
+
+
+def _handle_annotated_task(
+ ctx: FlyteContext,
+ task_def: PythonTask,
+ inputs: str,
+ output_prefix: str,
+):
+ """
+ Entrypoint for all PythonTask extensions
+ """
+ _click.echo("Running native-typed task")
+ _dispatch_execute(ctx, task_def, inputs, output_prefix)
@_scopes.system_entry_point
@@ -277,18 +291,6 @@ def _legacy_execute_task(task_module, task_name, inputs, output_prefix, raw_outp
)
-def _load_resolver(resolver_location: str) -> TaskResolverMixin:
- # Load the actual resolver - this cannot be a nested thing, whatever kind of resolver it is, it has to be loadable
- # directly from importlib
- # TODO: Handle corner cases, like where the first part is [] maybe
- # e.g. flytekit.core.python_auto_container.default_task_resolver
- resolver = resolver_location.split(".")
- resolver_mod = resolver[:-1] # e.g. ['flytekit', 'core', 'python_auto_container']
- resolver_key = resolver[-1] # e.g. 'default_task_resolver'
- resolver_mod = _importlib.import_module(".".join(resolver_mod))
- return getattr(resolver_mod, resolver_key)
-
-
@_scopes.system_entry_point
def _execute_task(
inputs,
@@ -326,18 +328,17 @@ def _execute_task(
if len(resolver_args) < 1:
raise Exception("cannot be <1")
- resolver_obj = _load_resolver(resolver)
with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()):
- # Use the resolver to load the actual task object
- _task_def = resolver_obj.load_task(loader_args=resolver_args)
- if test:
- _click.echo(
- f"Test detected, returning. Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}"
- )
- return
- _handle_annotated_task(
- _task_def, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir
- )
+ with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx:
+ resolver_obj = load_object_from_module(resolver)
+ # Use the resolver to load the actual task object
+ _task_def = resolver_obj.load_task(loader_args=resolver_args)
+ if test:
+ _click.echo(
+ f"Test detected, returning. Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}"
+ )
+ return
+ _handle_annotated_task(ctx, _task_def, inputs, output_prefix)
@_scopes.system_entry_point
@@ -355,28 +356,27 @@ def _execute_map_task(
if len(resolver_args) < 1:
raise Exception(f"Resolver args cannot be <1, got {resolver_args}")
- resolver_obj = _load_resolver(resolver)
with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()):
- # Use the resolver to load the actual task object
- _task_def = resolver_obj.load_task(loader_args=resolver_args)
- if not isinstance(_task_def, PythonFunctionTask):
- raise Exception("Map tasks cannot be run with instance tasks.")
- map_task = MapPythonTask(_task_def, max_concurrency)
-
- task_index = _compute_array_job_index()
- output_prefix = _os.path.join(output_prefix, str(task_index))
-
- if test:
- _click.echo(
- f"Test detected, returning. Inputs: {inputs} Computed task index: {task_index} "
- f"New output prefix: {output_prefix} Raw output path: {raw_output_data_prefix} "
- f"Resolver and args: {resolver} {resolver_args}"
- )
- return
+ with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx:
+ resolver_obj = load_object_from_module(resolver)
+ # Use the resolver to load the actual task object
+ _task_def = resolver_obj.load_task(loader_args=resolver_args)
+ if not isinstance(_task_def, PythonFunctionTask):
+ raise Exception("Map tasks cannot be run with instance tasks.")
+ map_task = MapPythonTask(_task_def, max_concurrency)
+
+ task_index = _compute_array_job_index()
+ output_prefix = _os.path.join(output_prefix, str(task_index))
+
+ if test:
+ _click.echo(
+ f"Test detected, returning. Inputs: {inputs} Computed task index: {task_index} "
+ f"New output prefix: {output_prefix} Raw output path: {raw_output_data_prefix} "
+ f"Resolver and args: {resolver} {resolver_args}"
+ )
+ return
- _handle_annotated_task(
- map_task, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir
- )
+ _handle_annotated_task(ctx, map_task, inputs, output_prefix)
@_click.group()
diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py
index 6a1bf0b7dd..e6572f4268 100644
--- a/flytekit/core/base_task.py
+++ b/flytekit/core/base_task.py
@@ -2,7 +2,7 @@
import datetime
from abc import abstractmethod
from dataclasses import dataclass
-from typing import Any, Dict, Generic, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.tasks.sdk_runnable import ExecutionParameters
@@ -486,3 +486,78 @@ def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
@property
def environment(self) -> Dict[str, str]:
return self._environment
+
+
+class TaskResolverMixin(object):
+ """
+ Flytekit tasks interact with the Flyte platform very, very broadly in two steps. They need to be uploaded to Admin,
+ and then they are run by the user upon request (either as a single task execution or as part of a workflow). In any
+ case, at execution time, the container image containing the task needs to be spun up again (for container tasks at
+ least which most tasks are) at which point the container needs to know which task it's supposed to run and
+ how to rehydrate the task object.
+
+ For example, the serialization of a simple task ::
+
+ # in repo_root/workflows/example.py
+ @task
+ def t1(...) -> ...: ...
+
+ might result in a container with arguments like ::
+
+ pyflyte-execute --inputs s3://path/inputs.pb --output-prefix s3://outputs/location \
+ --raw-output-data-prefix /tmp/data \
+ --resolver flytekit.core.python_auto_container.default_task_resolver \
+ -- \
+ task-module repo_root.workflows.example task-name t1
+
+ At serialization time, the container created for the task will start out automatically with the ``pyflyte-execute``
+ bit, along with the requisite input/output args and the offloaded data prefix. Appended to that will be two things,
+
+ #. the ``location`` of the task's task resolver, followed by two dashes, followed by
+ #. the arguments provided by calling the ``loader_args`` function below.
+
+ The ``default_task_resolver`` declared in ``flytekit.core.python_auto_container`` knows that ::
+
+ * When ``loader_args`` is called on a task, to look up the module the task is in, and the name of the task (the
+ key of the task in the module, either the function name, or the variable it was assigned to).
+ * When ``load_task`` is called, it interprets the first part of the command as the module to call
+ ``importlib.import_module`` on, and then looks for a key ``t1``.
+
+ This is just the default behavior. Users should feel free to implement their own resolvers.
+ """
+
+ @property
+ @abstractmethod
+ def location(self) -> str:
+ pass
+
+ @abstractmethod
+ def name(self) -> str:
+ pass
+
+ @abstractmethod
+ def load_task(self, loader_args: List[str]) -> Task:
+ """
+ Given the set of identifier keys, should return one Python Task or raise an error if not found
+ """
+ pass
+
+ @abstractmethod
+ def loader_args(self, settings: SerializationSettings, t: Task) -> List[str]:
+ """
+ Return a list of strings that can help identify the parameter Task
+ """
+ pass
+
+ @abstractmethod
+ def get_all_tasks(self) -> List[Task]:
+ """
+ Future proof method. Just making it easy to access all tasks (Not required today as we auto register them)
+ """
+ pass
+
+ def task_name(self, t: Task) -> Optional[str]:
+ """
+ Overridable function that can optionally return a custom name for a given task
+ """
+ return None
diff --git a/flytekit/core/class_based_resolver.py b/flytekit/core/class_based_resolver.py
index 7709fb2f52..33addbe598 100644
--- a/flytekit/core/class_based_resolver.py
+++ b/flytekit/core/class_based_resolver.py
@@ -1,7 +1,8 @@
from typing import List
+from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.context_manager import SerializationSettings
-from flytekit.core.python_auto_container import PythonAutoContainerTask, TaskResolverMixin
+from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.tracker import TrackedInstance
diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py
index 6a574e5576..313e064d1c 100644
--- a/flytekit/core/python_auto_container.py
+++ b/flytekit/core/python_auto_container.py
@@ -2,13 +2,13 @@
import importlib
import re
-from abc import ABC, abstractmethod
from typing import Dict, List, Optional, TypeVar
from flytekit.common.tasks.raw_container import _get_container_definition
-from flytekit.core.base_task import PythonTask
+from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.context_manager import FlyteContext, ImageConfig, SerializationSettings
from flytekit.core.resources import Resources, ResourceSpec
+from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance
from flytekit.loggers import logger
from flytekit.models import task as _task_model
@@ -17,14 +17,6 @@
T = TypeVar("T")
-class FlyteTrackedABC(type(TrackedInstance), type(ABC)):
- """
- This class exists because if you try to inherit from abc.ABC and TrackedInstance by itself, you'll get the
- well-known ``TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass
- of the metaclasses of all its bases`` error.
- """
-
-
class PythonAutoContainerTask(PythonTask[T], metaclass=FlyteTrackedABC):
"""
A Python AutoContainer task should be used as the base for all extensions that want the user's code to be in the
@@ -114,9 +106,22 @@ def container_image(self) -> Optional[str]:
def resources(self) -> ResourceSpec:
return self._resources
- @abstractmethod
def get_command(self, settings: SerializationSettings) -> List[str]:
- pass
+ container_args = [
+ "pyflyte-execute",
+ "--inputs",
+ "{{.input}}",
+ "--output-prefix",
+ "{{.outputPrefix}}",
+ "--raw-output-data-prefix",
+ "{{.rawOutputDataPrefix}}",
+ "--resolver",
+ self.task_resolver.location,
+ "--",
+ *self.task_resolver.loader_args(settings, self),
+ ]
+
+ return container_args
def get_container(self, settings: SerializationSettings) -> _task_model.Container:
env = {**settings.env, **self.environment} if self.environment else settings.env
@@ -137,81 +142,6 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe
)
-class TaskResolverMixin(object):
- """
- Flytekit tasks interact with the Flyte platform very, very broadly in two steps. They need to be uploaded to Admin,
- and then they are run by the user upon request (either as a single task execution or as part of a workflow). In any
- case, at execution time, the container image containing the task needs to be spun up again (for container tasks at
- least which most tasks are) at which point the container needs to know which task it's supposed to run and
- how to rehydrate the task object.
-
- For example, the serialization of a simple task ::
-
- # in repo_root/workflows/example.py
- @task
- def t1(...) -> ...: ...
-
- might result in a container with arguments like ::
-
- pyflyte-execute --inputs s3://path/inputs.pb --output-prefix s3://outputs/location \
- --raw-output-data-prefix /tmp/data \
- --resolver flytekit.core.python_auto_container.default_task_resolver \
- -- \
- task-module repo_root.workflows.example task-name t1
-
- At serialization time, the container created for the task will start out automatically with the ``pyflyte-execute``
- bit, along with the requisite input/output args and the offloaded data prefix. Appended to that will be two things,
-
- #. the ``location`` of the task's task resolver, followed by two dashes, followed by
- #. the arguments provided by calling the ``loader_args`` function below.
-
- The ``default_task_resolver`` declared below knows that ::
-
- * When ``loader_args`` is called on a task, to look up the module the task is in, and the name of the task (the
- key of the task in the module, either the function name, or the variable it was assigned to).
- * When ``load_task`` is called, it interprets the first part of the command as the module to call
- ``importlib.import_module`` on, and then looks for a key ``t1``.
-
- This is just the default behavior. Users should feel free to implement their own resolvers.
- """
-
- @property
- @abstractmethod
- def location(self) -> str:
- pass
-
- @abstractmethod
- def name(self) -> str:
- pass
-
- @abstractmethod
- def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
- """
- Given the set of identifier keys, should return one Python Task or raise an error if not found
- """
- pass
-
- @abstractmethod
- def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]:
- """
- Return a list of strings that can help identify the parameter Task
- """
- pass
-
- @abstractmethod
- def get_all_tasks(self) -> List[PythonAutoContainerTask]:
- """
- Future proof method. Just making it easy to access all tasks (Not required today as we auto register them)
- """
- pass
-
- def task_name(self, t: PythonAutoContainerTask) -> Optional[str]:
- """
- Overridable function that can optionally return a custom name for a given task
- """
- return None
-
-
class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
"""
Please see the notes in the TaskResolverMixin as it describes this default behavior.
diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py
new file mode 100644
index 0000000000..a615f2f6f5
--- /dev/null
+++ b/flytekit/core/python_customized_container_task.py
@@ -0,0 +1,249 @@
+from __future__ import annotations
+
+import os
+from typing import Any, Dict, List, Optional, Type, TypeVar
+
+from flyteidl.core import tasks_pb2 as _tasks_pb2
+
+from flytekit.common import utils as common_utils
+from flytekit.common.tasks.raw_container import _get_container_definition
+from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin
+from flytekit.core.context_manager import FlyteContext, Image, ImageConfig, SerializationSettings
+from flytekit.core.resources import Resources, ResourceSpec
+from flytekit.core.shim_task import ExecutableTemplateShimTask, ShimTaskExecutor
+from flytekit.core.tracker import TrackedInstance
+from flytekit.loggers import logger
+from flytekit.models import task as _task_model
+from flytekit.models.core import identifier as identifier_models
+from flytekit.models.security import Secret, SecurityContext
+from flytekit.tools.module_loader import load_object_from_module
+
+TC = TypeVar("TC")
+
+
+class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]):
+ """
+ Please take a look at the comments for ``ExecutableTemplateShimTask`` as well. This class should be subclassed
+ and a custom Executor provided as a default to this parent class constructor when building a new external-container
+ flytekit-only plugin.
+
+ This class provides authors of new task types the basic scaffolding to create task-template based tasks. In order
+ to write such a task, authors need to:
+
+ * subclass the ``ShimTaskExecutor`` class and override the ``execute_from_model`` function. This function is
+ where all the business logic should go. Keep in mind though that you, the plugin author, will not have access
+ to anything that's not serialized within the ``TaskTemplate`` which is why you'll also need to
+ * subclass this class, and override the ``get_custom`` function to include all the information the executor
+ will need to run.
+ * Also pass the executor you created as the ``executor_type`` argument of this class's constructor.
+
+ Keep in mind that the total size of the ``TaskTemplate`` still needs to be small, since these will be accessed
+ frequently by the Flyte engine.
+ """
+
+ SERIALIZE_SETTINGS = SerializationSettings(
+ project="PLACEHOLDER_PROJECT",
+ domain="LOCAL",
+ version="PLACEHOLDER_VERSION",
+ env=None,
+ image_config=ImageConfig(
+ default_image=Image(name="custom_container_task", fqn="flyteorg.io/placeholder", tag="image")
+ ),
+ )
+
+ def __init__(
+ self,
+ name: str,
+ task_config: TC,
+ container_image: str,
+ executor_type: Type[ShimTaskExecutor],
+ task_resolver: Optional[TaskTemplateResolver] = None,
+ task_type="python-task",
+ requests: Optional[Resources] = None,
+ limits: Optional[Resources] = None,
+ environment: Optional[Dict[str, str]] = None,
+ secret_requests: Optional[List[Secret]] = None,
+ **kwargs,
+ ):
+ """
+ :param name: unique name for the task, usually the function's module and name.
+ :param task_config: Configuration object for Task. Should be a unique type for that specific Task
+ :param container_image: This is the external container image the task should run at platform-run-time.
+ :param executor_type: This is the executor class which will actually provide the business logic.
+ :param task_resolver: Custom resolver - if you don't make one, use the default task template resolver.
+ :param task_type: String task type to be associated with this Task
+ :param requests: custom resource request settings.
+ :param limits: custom resource limit settings.
+ :param environment: Environment variables you want the task to have when run.
+ :param List[Secret] secret_requests: Secrets that are requested by this container execution. These secrets will
+ be mounted based on the configuration in the Secret and available through
+ the SecretManager using the name of the secret as the group
+ Ideally the secret keys should also be semi-descriptive.
+ The key values will be available from runtime, if the backend is configured
+ to provide secrets and if secrets are available in the configured secrets store.
+ Possible options for secret stores are
+ - Vault,
+ - Confidant,
+ - Kube secrets
+ - AWS Parameter store
+ etc
+ """
+ sec_ctx = None
+ if secret_requests:
+ for s in secret_requests:
+ if not isinstance(s, Secret):
+ raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}")
+ sec_ctx = SecurityContext(secrets=secret_requests)
+ super().__init__(
+ tt=None,
+ executor_type=executor_type,
+ task_type=task_type,
+ name=name,
+ task_config=task_config,
+ security_ctx=sec_ctx,
+ **kwargs,
+ )
+ self._resources = ResourceSpec(
+ requests=requests if requests else Resources(), limits=limits if limits else Resources()
+ )
+ self._environment = environment
+ self._container_image = container_image
+ self._task_resolver = task_resolver or default_task_template_resolver
+
+ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
+ # Overriding base implementation to raise an error, force task author to implement
+ raise NotImplementedError
+
+ def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
+ # Overriding base implementation but not doing anything. Technically this should be the task config,
+ # but the IDL limitation that the value also has to be a string is very limiting.
+ # Recommend putting information you need in the config into custom instead, because when serializing
+ # the custom field, we jsonify custom and then place it into a protobuf struct. This config field
+ # just gets put into a Dict[str, str]
+ return {}
+
+ @property
+ def resources(self) -> ResourceSpec:
+ return self._resources
+
+ @property
+ def task_resolver(self) -> TaskTemplateResolver:
+ return self._task_resolver
+
+ @property
+ def task_template(self) -> Optional[_task_model.TaskTemplate]:
+ """
+ Override the base class implementation to serialize on first call.
+ """
+ return self._task_template or self.serialize_to_model(settings=PythonCustomizedContainerTask.SERIALIZE_SETTINGS)
+
+ @property
+ def container_image(self) -> str:
+ return self._container_image
+
+ def get_command(self, settings: SerializationSettings) -> List[str]:
+ container_args = [
+ "pyflyte-execute",
+ "--inputs",
+ "{{.input}}",
+ "--output-prefix",
+ "{{.outputPrefix}}",
+ "--raw-output-data-prefix",
+ "{{.rawOutputDataPrefix}}",
+ "--resolver",
+ self.task_resolver.location,
+ "--",
+ *self.task_resolver.loader_args(settings, self),
+ ]
+
+ return container_args
+
+ def get_container(self, settings: SerializationSettings) -> _task_model.Container:
+     env = {**(settings.env or {}), **self.environment} if self.environment else settings.env  # env may be None
+     return _get_container_definition(
+         image=self.container_image,
+         command=[],
+         args=self.get_command(settings=settings),
+         data_loading_config=None,
+         environment=env,
+         storage_request=self.resources.requests.storage,
+         cpu_request=self.resources.requests.cpu,
+         gpu_request=self.resources.requests.gpu,
+         memory_request=self.resources.requests.mem,
+         storage_limit=self.resources.limits.storage,
+         cpu_limit=self.resources.limits.cpu,
+         gpu_limit=self.resources.limits.gpu,
+         memory_limit=self.resources.limits.mem,
+     )
+
+ def serialize_to_model(self, settings: SerializationSettings) -> _task_model.TaskTemplate:
+ # This doesn't get called from translator unfortunately. Will need to move the translator to use the model
+ # objects directly first.
+ # Note: This doesn't settle the issue of duplicate registrations. We'll need to figure that out somehow.
+ # TODO: After new control plane classes are in, promote the template to a FlyteTask, so that authors of
+ # customized-container tasks have a familiar thing to work with.
+ obj = _task_model.TaskTemplate(
+ identifier_models.Identifier(
+ identifier_models.ResourceType.TASK, settings.project, settings.domain, self.name, settings.version
+ ),
+ self.task_type,
+ self.metadata.to_taskmetadata_model(),
+ self.interface,
+ self.get_custom(settings),
+ container=self.get_container(settings),
+ config=self.get_config(settings),
+ )
+ self._task_template = obj
+ return obj
+
+
+class TaskTemplateResolver(TrackedInstance, TaskResolverMixin):
+ """
+ This is a special resolver that resolves the task above at execution time, using only the ``TaskTemplate``,
+ meaning it should only be used for tasks that contain all pertinent information within the template itself.
+
+ This class differs a bit from the usual TaskResolverMixin pattern. Most of the other resolvers you'll find,
+
+ * restores the same task when ``load_task`` is called as the object that ``loader_args`` was called on.
+ That is, even though at run time it's in a container on a cluster and is obviously a different Python process,
+ the Python object in memory should look the same.
+ * offers a one-to-one mapping between the list of strings returned by the ``loader_args`` function, and the task,
+ at least within the container.
+
+ This resolver differs in that,
+ * when loading a task, the task that is loaded is always an ``ExecutableTemplateShimTask``, regardless of what
+ kind of task it was originally. It will only ever have what's available to it from the ``TaskTemplate``. No
+ information that wasn't serialized into the template will be available.
+ * all tasks will result in the same list of strings for a given subclass of the ``ShimTaskExecutor``
+ executor. The strings will be ``["{{.taskTemplatePath}}", "path.to.your.executor"]``
+
+ Also, ``get_all_tasks`` will always return an empty list, at least for now.
+ """
+
+ def __init__(self):
+ super(TaskTemplateResolver, self).__init__()
+
+ def name(self) -> str:
+ return "task template resolver"
+
+ # The return type of this function is different, it should be a Task, but it's not because it doesn't make
+ # sense for ExecutableTemplateShimTask to inherit from Task.
+ def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask:
+ logger.info(f"Task template loader args: {loader_args}")
+ ctx = FlyteContext.current_context()
+ task_template_local_path = os.path.join(ctx.execution_state.working_dir, "task_template.pb")
+ ctx.file_access.get_data(loader_args[0], task_template_local_path)
+ task_template_proto = common_utils.load_proto_from_file(_tasks_pb2.TaskTemplate, task_template_local_path)
+ task_template_model = _task_model.TaskTemplate.from_flyte_idl(task_template_proto)
+
+ executor_class = load_object_from_module(loader_args[1])
+ return ExecutableTemplateShimTask(task_template_model, executor_class)
+
+ def loader_args(self, settings: SerializationSettings, t: PythonCustomizedContainerTask) -> List[str]:
+ return ["{{.taskTemplatePath}}", f"{t.executor_type.__module__}.{t.executor_type.__name__}"]
+
+ def get_all_tasks(self) -> List[Task]:
+ return []
+
+
+default_task_template_resolver = TaskTemplateResolver()
diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py
index eba817084e..3ba8b0fc76 100644
--- a/flytekit/core/python_function_task.py
+++ b/flytekit/core/python_function_task.py
@@ -4,9 +4,10 @@
from enum import Enum
from typing import Any, Callable, List, Optional, TypeVar, Union
-from flytekit.core.context_manager import ExecutionState, FlyteContext, SerializationSettings
+from flytekit.core.base_task import TaskResolverMixin
+from flytekit.core.context_manager import ExecutionState, FlyteContext
from flytekit.core.interface import transform_signature_to_interface
-from flytekit.core.python_auto_container import PythonAutoContainerTask, TaskResolverMixin, default_task_resolver
+from flytekit.core.python_auto_container import PythonAutoContainerTask, default_task_resolver
from flytekit.core.tracker import isnested, istestfunction
from flytekit.core.workflow import (
PythonFunctionWorkflow,
@@ -46,23 +47,6 @@ def __init__(
):
super().__init__(name=name, task_config=task_config, task_type=task_type, task_resolver=task_resolver, **kwargs)
- def get_command(self, settings: SerializationSettings) -> List[str]:
- container_args = [
- "pyflyte-execute",
- "--inputs",
- "{{.input}}",
- "--output-prefix",
- "{{.outputPrefix}}",
- "--raw-output-data-prefix",
- "{{.rawOutputDataPrefix}}",
- "--resolver",
- self.task_resolver.location,
- "--",
- *self.task_resolver.loader_args(settings, self),
- ]
-
- return container_args
-
class PythonFunctionTask(PythonAutoContainerTask[T]):
"""
@@ -147,23 +131,6 @@ def execute(self, **kwargs) -> Any:
elif self.execution_mode == self.ExecutionBehavior.DYNAMIC:
return self.dynamic_execute(self._task_function, **kwargs)
- def get_command(self, settings: SerializationSettings) -> List[str]:
- container_args = [
- "pyflyte-execute",
- "--inputs",
- "{{.input}}",
- "--output-prefix",
- "{{.outputPrefix}}",
- "--raw-output-data-prefix",
- "{{.rawOutputDataPrefix}}",
- "--resolver",
- self.task_resolver.location,
- "--",
- *self.task_resolver.loader_args(settings, self),
- ]
-
- return container_args
-
def compile_into_workflow(
self, ctx: FlyteContext, is_fast_execution: bool, task_function: Callable, **kwargs
) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
diff --git a/flytekit/core/shim_task.py b/flytekit/core/shim_task.py
new file mode 100644
index 0000000000..c8349c0339
--- /dev/null
+++ b/flytekit/core/shim_task.py
@@ -0,0 +1,163 @@
+from __future__ import annotations
+
+from typing import Any, Generic, Type, TypeVar, Union
+
+from flytekit import ExecutionParameters, FlyteContext, logger
+from flytekit.core.tracker import TrackedInstance
+from flytekit.core.type_engine import TypeEngine
+from flytekit.models import dynamic_job as _dynamic_job
+from flytekit.models import literals as _literal_models
+from flytekit.models import task as _task_model
+
+
+class ExecutableTemplateShimTask(object):
+    """
+    The canonical ``@task`` decorated Python function task is pretty simple to reason about. At execution time (either
+    locally or on a Flyte cluster), the function runs.
+
+    This class, along with the ``ShimTaskExecutor`` class below, represents another execution pattern. This pattern
+    has two components:
+      * The ``TaskTemplate``, or something like it like a ``FlyteTask``.
+      * An executor, which can use information from the task template (including the ``custom`` field)
+
+    Basically at execution time (both locally and on a Flyte cluster), the task template is given to the executor,
+    which is responsible for computing and returning the results.
+
+    .. note::
+
+       The interface at execution time will have to be derived from the Flyte IDL interface, which means it may be
+       lossy. This is because when a task is serialized from Python into the ``TaskTemplate`` some information is lost
+       because Flyte IDL can't keep track of every single Python type (or Java type if writing in the Java flytekit).
+
+    This class also implements the ``dispatch_execute`` and ``execute`` functions to make it look like a ``PythonTask``
+    that the ``entrypoint.py`` can execute, even though this class doesn't inherit from ``PythonTask``.
+    """
+
+    def __init__(self, tt: _task_model.TaskTemplate, executor_type: Type[ShimTaskExecutor], *args, **kwargs):
+        self._executor_type = executor_type
+        self._executor = executor_type()
+        self._task_template = tt
+        super().__init__(*args, **kwargs)
+
+    @property
+    def task_template(self) -> _task_model.TaskTemplate:
+        return self._task_template
+
+    @property
+    def executor(self) -> ShimTaskExecutor:
+        return self._executor
+
+    @property
+    def executor_type(self) -> Type[ShimTaskExecutor]:
+        return self._executor_type
+
+    def execute(self, **kwargs) -> Any:
+        """
+        Send things off to the executor instead of running here.
+        """
+        return self.executor.execute_from_model(self.task_template, **kwargs)
+
+    def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
+        """
+        This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask.
+        """
+        return user_params
+
+    def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
+        """
+        This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask.
+        """
+        return rval
+
+    def dispatch_execute(
+        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
+    ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
+        """
+        This function is mostly copied from the base PythonTask, but differs in that we have to infer the Python
+        interface before executing. Also, we refer to ``self.task_template`` rather than just ``self`` like in task
+        classes that derive from the base ``PythonTask``.
+        :param input_literal_map: Flyte literals for the inputs; converted to Python values before executing.
+        :return: A LiteralMap of the outputs, or the LiteralMap/DynamicJobSpec the executor returned as-is.
+        """
+        # Invoked before the task is executed
+        new_user_params = self.pre_execute(ctx.user_space_params)
+
+        # Create another execution context with the new user params, but let's keep the same working dir
+        with ctx.new_execution_context(
+            mode=ctx.execution_state.mode,
+            execution_params=new_user_params,
+            working_dir=ctx.execution_state.working_dir,
+        ) as exec_ctx:
+            # Added: Have to reverse the Python interface from the task template Flyte interface
+            # See docstring for more details.
+            guessed_python_input_types = TypeEngine.guess_python_types(self.task_template.interface.inputs)
+            native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, guessed_python_input_types)
+
+            logger.info(f"Invoking FlyteTask executor {self.task_template.id.name} with inputs: {native_inputs}")
+            try:
+                native_outputs = self.execute(**native_inputs)
+            except Exception as e:
+                logger.exception(f"Exception when executing {e}")
+                raise
+
+            logger.info(f"Task executed successfully in user level, outputs: {native_outputs}")
+            # Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is
+            # bubbled up to be handled at the callee layer.
+            native_outputs = self.post_execute(new_user_params, native_outputs)
+
+            # Short circuit the translation to literal map because what's returned may be a dj spec (or an
+            # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
+            if isinstance(native_outputs, (_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec)):
+                return native_outputs
+
+            expected_output_names = list(self.task_template.interface.outputs.keys())
+            if len(expected_output_names) == 1:
+                # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
+                # length one. That convention is used for naming outputs - and single-length-NamedTuples are
+                # particularly troublesome but elegant handling of them is not a high priority
+                # NOTE: the NamedTuple handling copied from the base PythonTask was deleted here; the returned value
+                # is mapped to the single declared output name as-is.
+                native_outputs_as_map = {expected_output_names[0]: native_outputs}
+            elif len(expected_output_names) == 0:
+                native_outputs_as_map = {}
+            else:
+                # Several outputs: pair the returned values with the declared output names by position.
+                native_outputs_as_map = {expected_output_names[i]: value for i, value in enumerate(native_outputs)}
+
+            # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
+            # built into the IDL that all the values of a literal map are of the same type.
+            literals = {}
+            for k, v in native_outputs_as_map.items():
+                literal_type = self.task_template.interface.outputs[k].type
+                py_type = type(v)
+
+                # py_type comes from the value itself, so tuples are rejected explicitly before conversion.
+                if isinstance(v, tuple):
+                    raise AssertionError(
+                        f"Output({k}) in task {self.task_template.id.name} received a tuple {v}, not {literal_type}"
+                    )
+                try:
+                    literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type)
+                except Exception as e:
+                    raise AssertionError(f"failed to convert return value for var {k}") from e
+
+            outputs_literal_map = _literal_models.LiteralMap(literals=literals)
+            # After the execute has been successfully completed
+            return outputs_literal_map
+
+
+T = TypeVar("T")
+
+
+class ShimTaskExecutor(TrackedInstance, Generic[T]):
+    def execute_from_model(self, tt: _task_model.TaskTemplate, **kwargs) -> Any:
+        """
+        Subclasses must override this method; it is where all the business logic for running a task should live.
+        Only the ``TaskTemplate`` is available here, so nothing about the task that wasn't serialized into the
+        template can be used.
+
+        :param tt: This is the template, the serialized form of the task.
+        :param kwargs: These are the Python native input values to the task.
+        :return: Python native output values from the task.
+        """
+        raise NotImplementedError
diff --git a/flytekit/core/tracked_abc.py b/flytekit/core/tracked_abc.py
new file mode 100644
index 0000000000..bad4f8c555
--- /dev/null
+++ b/flytekit/core/tracked_abc.py
@@ -0,0 +1,11 @@
+from abc import ABC
+
+from flytekit.core.tracker import TrackedInstance
+
+
+class FlyteTrackedABC(type(TrackedInstance), type(ABC)):
+    """
+    Metaclass that combines the metaclasses of ``TrackedInstance`` and ``abc.ABC``. Inheriting from both of those
+    classes directly fails with the well-known ``TypeError: metaclass conflict: the metaclass of a derived class
+    must be a (non-strict) subclass of the metaclasses of all its bases`` error.
+    """
diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 0c7f7e12cb..9ac21d8bcd 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -19,6 +19,7 @@
from flytekit.common.types import primitives as _primitives
from flytekit.core.context_manager import FlyteContext
+from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import types as _type_models
from flytekit.models.core import types as _core_types
@@ -63,6 +64,9 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
"""
raise NotImplementedError("Conversion to LiteralType should be implemented")
+ def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
+ raise ValueError("By default, transformers do not translate from Flyte types back to Python types")
+
@abstractmethod
def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
"""
@@ -123,6 +127,11 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
return self._from_literal_transformer(lv)
+    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
+        if literal_type.simple is None or literal_type.simple != self._lt.simple:
+            raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
+        return self.python_type
+
class RestrictedTypeError(Exception):
pass
@@ -303,6 +312,24 @@ def get_available_transformers(cls) -> typing.KeysView[Type]:
"""
return cls._REGISTRY.keys()
+    @classmethod
+    def guess_python_types(
+        cls, flyte_variable_dict: typing.Dict[str, _interface_models.Variable]
+    ) -> typing.Dict[str, type]:
+        """
+        Map each variable name to the Python type guessed from that variable's Flyte literal type.
+        """
+        return {name: cls.guess_python_type(variable.type) for name, variable in flyte_variable_dict.items()}
+
+    @classmethod
+    def guess_python_type(cls, flyte_type: LiteralType) -> type:
+        for transformer in cls._REGISTRY.values():
+            try:
+                return transformer.guess_python_type(flyte_type)
+            except ValueError:
+                logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}")
+        raise ValueError(f"No transformers could reverse Flyte literal type {flyte_type}")
+
class ListTransformer(TypeTransformer[T]):
"""
@@ -341,6 +368,12 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
st = self.get_sub_type(expected_python_type)
return [TypeEngine.to_python_value(ctx, x, st) for x in lv.collection.literals]
+    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
+        if not literal_type.collection_type:
+            raise ValueError(f"List transformer cannot reverse {literal_type}")
+        element_type = TypeEngine.guess_python_type(literal_type.collection_type)
+        return typing.List[element_type]
+
class DictTransformer(TypeTransformer[dict]):
"""
@@ -412,6 +445,12 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
return _json.loads(_json_format.MessageToJson(lv.scalar.generic))
raise TypeError(f"Cannot convert from {lv} to {expected_python_type}")
+    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
+        if not literal_type.map_value_type:
+            raise ValueError(f"Dictionary transformer cannot reverse {literal_type}")
+        value_type = TypeEngine.guess_python_type(literal_type.map_value_type)
+        return typing.Dict[str, value_type]
+
class TextIOTransformer(TypeTransformer[typing.TextIO]):
"""
diff --git a/flytekit/extend/__init__.py b/flytekit/extend/__init__.py
index 3b3d05f4b4..ef0b69d5df 100644
--- a/flytekit/extend/__init__.py
+++ b/flytekit/extend/__init__.py
@@ -32,11 +32,10 @@
from flytekit.common.translator import get_serializable
from flytekit.core import context_manager
from flytekit.core.base_sql_task import SQLTask
-from flytekit.core.base_task import IgnoreOutputs, PythonTask
+from flytekit.core.base_task import IgnoreOutputs, PythonTask, TaskResolverMixin
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.context_manager import ExecutionState, Image, ImageConfig, SerializationSettings
from flytekit.core.interface import Interface
from flytekit.core.promise import Promise
-from flytekit.core.python_auto_container import TaskResolverMixin
from flytekit.core.task import TaskPlugins
from flytekit.core.type_engine import DictTransformer, T, TypeEngine, TypeTransformer
diff --git a/flytekit/extras/cloud_pickle_resolver.py b/flytekit/extras/cloud_pickle_resolver.py
index d6e5c3a008..3ea5fb0e5c 100644
--- a/flytekit/extras/cloud_pickle_resolver.py
+++ b/flytekit/extras/cloud_pickle_resolver.py
@@ -3,8 +3,9 @@
import cloudpickle # intentionally not yet part of setup.py
+from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.context_manager import SerializationSettings
-from flytekit.core.python_auto_container import PythonAutoContainerTask, TaskResolverMixin
+from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.tracker import TrackedInstance
diff --git a/flytekit/models/task.py b/flytekit/models/task.py
index 8e34b90cf7..f48ade4bf4 100644
--- a/flytekit/models/task.py
+++ b/flytekit/models/task.py
@@ -372,7 +372,7 @@ def metadata(self):
def interface(self):
"""
The interface definition for this task.
- :rtype: flytekit.common.interface.TypedInterface
+ :rtype: flytekit.models.interface.TypedInterface
"""
return self._interface
diff --git a/flytekit/tools/module_loader.py b/flytekit/tools/module_loader.py
index 7149e532af..182b11de72 100644
--- a/flytekit/tools/module_loader.py
+++ b/flytekit/tools/module_loader.py
@@ -3,7 +3,7 @@
import os
import pkgutil
import sys
-from typing import Iterator, List, Union
+from typing import Any, Iterator, List, Union
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.local_workflow import SdkRunnableWorkflow as _SdkRunnableWorkflow
@@ -190,3 +190,14 @@ def iterate_registerable_entities_in_order(
detect_unreferenced_entities=detect_unreferenced_entities,
):
yield m, k, o2
+
+
+def load_object_from_module(object_location: str) -> Any:
+    """
+    Import the module before the last dot of ``object_location`` and return the attribute named after it.
+    Raises ``ValueError`` when there is no module part, e.g. ``"foo"`` or ``".foo"``.
+    """
+    module_name, _, attr_name = object_location.rpartition(".")
+    if not module_name:
+        raise ValueError(f"Expected a dotted location such as 'package.module.obj', got {object_location!r}")
+    return getattr(importlib.import_module(module_name), attr_name)
diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py
index d51f572b4e..8d073e809a 100644
--- a/flytekit/types/schema/types.py
+++ b/flytekit/types/schema/types.py
@@ -374,5 +374,26 @@ def downloader(x, y):
supported_mode=SchemaOpenMode.READ,
)
+    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
+        if not literal_type.schema:
+            raise ValueError(f"Cannot reverse {literal_type}")
+        # Table of the schema column types that have a Python counterpart.
+        column_types = SchemaType.SchemaColumn.SchemaColumnType
+        python_types = {
+            column_types.INTEGER: int,
+            column_types.FLOAT: float,
+            column_types.STRING: str,
+            column_types.DATETIME: _datetime.datetime,
+            column_types.DURATION: _datetime.timedelta,
+            column_types.BOOLEAN: bool,
+        }
+        columns = {}
+        for literal_column in literal_type.schema.columns:
+            # Any column type missing from the table cannot be reversed.
+            if literal_column.type not in python_types:
+                raise ValueError(f"Unknown schema column type {literal_column}")
+            columns[literal_column.name] = python_types[literal_column.type]
+        return FlyteSchema[columns]
+
TypeEngine.register(FlyteSchemaTransformer())
diff --git a/plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py b/plugins/flytekit.sqlalchemy/flytekitplugins/sqlalchemy/__init__.py
similarity index 100%
rename from plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py
rename to plugins/flytekit.sqlalchemy/flytekitplugins/sqlalchemy/__init__.py
diff --git a/plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/flytekit.sqlalchemy/flytekitplugins/sqlalchemy/task.py
similarity index 100%
rename from plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py
rename to plugins/flytekit.sqlalchemy/flytekitplugins/sqlalchemy/task.py
diff --git a/plugins/sqlalchemy/setup.py b/plugins/flytekit.sqlalchemy/setup.py
similarity index 100%
rename from plugins/sqlalchemy/setup.py
rename to plugins/flytekit.sqlalchemy/setup.py
diff --git a/plugins/setup.py b/plugins/setup.py
index 48f302bc5a..958fe0ffd0 100644
--- a/plugins/setup.py
+++ b/plugins/setup.py
@@ -15,7 +15,7 @@
"flytekitplugins-awssagemaker": "awssagemaker",
"flytekitplugins-kftensorflow": "kftensorflow",
"flytekitplugins-pandera": "pandera",
- "flytekitplugins-sqlalchemy": "sqlalchemy",
+ "flytekitplugins-sqlalchemy": "flytekit.sqlalchemy",
"flytekitplugins-dolt": "flytekitplugins.dolt",
}
diff --git a/tests/flytekit/unit/core/test_resolver.py b/tests/flytekit/unit/core/test_resolver.py
index 0077c2b373..15d69ceffa 100644
--- a/tests/flytekit/unit/core/test_resolver.py
+++ b/tests/flytekit/unit/core/test_resolver.py
@@ -4,9 +4,10 @@
from flytekit.common.translator import get_serializable
from flytekit.core import context_manager
+from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.context_manager import Image, ImageConfig
-from flytekit.core.python_auto_container import TaskResolverMixin
+from flytekit.core.python_auto_container import default_task_resolver
from flytekit.core.task import task
from flytekit.core.workflow import workflow
@@ -101,3 +102,8 @@ def test_mixin():
x.loader_args(None, None)
x.get_all_tasks()
x.load_task([])
+
+
+def test_error():
+ with pytest.raises(Exception):
+ default_task_resolver.get_all_tasks()
diff --git a/tests/flytekit/unit/core/test_schema_types.py b/tests/flytekit/unit/core/test_schema_types.py
new file mode 100644
index 0000000000..dce40a3f3e
--- /dev/null
+++ b/tests/flytekit/unit/core/test_schema_types.py
@@ -0,0 +1,38 @@
+from datetime import datetime, timedelta
+
+import pytest
+
+from flytekit import kwtypes
+from flytekit.core.type_engine import TypeEngine
+from flytekit.types.schema import FlyteSchema, SchemaFormat
+
+
+def test_typed_schema():
+ s = FlyteSchema[kwtypes(x=int, y=float)]
+ assert s.format() == SchemaFormat.PARQUET
+ assert s.columns() == {"x": int, "y": float}
+
+
+def test_schema_back_and_forth():
+ orig = FlyteSchema[kwtypes(TrackId=int, Name=str)]
+ lt = TypeEngine.to_literal_type(orig)
+ pt = TypeEngine.guess_python_type(lt)
+ lt2 = TypeEngine.to_literal_type(pt)
+ assert lt == lt2
+
+
+def test_remaining_prims():
+ orig = FlyteSchema[kwtypes(my_dt=datetime, my_td=timedelta, my_b=bool)]
+ lt = TypeEngine.to_literal_type(orig)
+ pt = TypeEngine.guess_python_type(lt)
+ lt2 = TypeEngine.to_literal_type(pt)
+ assert lt == lt2
+
+
+def test_bad_conversion():
+ orig = FlyteSchema[kwtypes(my_custom=bool)]
+ lt = TypeEngine.to_literal_type(orig)
+ # Make a not real column type
+ lt.schema.columns[0]._type = 15
+ with pytest.raises(ValueError):
+ TypeEngine.guess_python_type(lt)
diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py
index 04a9e57160..26d262f6f0 100644
--- a/tests/flytekit/unit/core/test_type_engine.py
+++ b/tests/flytekit/unit/core/test_type_engine.py
@@ -6,7 +6,6 @@
import pytest
from flyteidl.core import errors_pb2
-from flytekit import kwtypes
from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import (
DictTransformer,
@@ -20,7 +19,6 @@
from flytekit.models.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar
from flytekit.models.types import LiteralType, SimpleType
from flytekit.types.file.file import FlyteFile
-from flytekit.types.schema import FlyteSchema, SchemaFormat
def test_type_engine():
@@ -92,12 +90,6 @@ def test_file_format_getting_python_value():
assert pv.extension() == "txt"
-def test_typed_schema():
- s = FlyteSchema[kwtypes(x=int, y=float)]
- assert s.format() == SchemaFormat.PARQUET
- assert s.columns() == {"x": int, "y": float}
-
-
def test_dict_transformer():
d = DictTransformer()
@@ -199,6 +191,48 @@ def test_protos():
TypeEngine.to_python_value(ctx, l0, errors_pb2.ContainerError)
+def test_guessing_basic():
+ b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN)
+ pt = TypeEngine.guess_python_type(b)
+ assert pt is bool
+
+ lt = model_types.LiteralType(simple=model_types.SimpleType.INTEGER)
+ pt = TypeEngine.guess_python_type(lt)
+ assert pt is int
+
+ lt = model_types.LiteralType(simple=model_types.SimpleType.STRING)
+ pt = TypeEngine.guess_python_type(lt)
+ assert pt is str
+
+ lt = model_types.LiteralType(simple=model_types.SimpleType.DURATION)
+ pt = TypeEngine.guess_python_type(lt)
+ assert pt is timedelta
+
+ lt = model_types.LiteralType(simple=model_types.SimpleType.DATETIME)
+ pt = TypeEngine.guess_python_type(lt)
+ assert pt is datetime.datetime
+
+ lt = model_types.LiteralType(simple=model_types.SimpleType.FLOAT)
+ pt = TypeEngine.guess_python_type(lt)
+ assert pt is float
+
+ lt = model_types.LiteralType(simple=model_types.SimpleType.NONE)
+ pt = TypeEngine.guess_python_type(lt)
+ assert pt is None
+
+
+def test_guessing_containers():
+ b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN)
+ lt = model_types.LiteralType(collection_type=b)
+ pt = TypeEngine.guess_python_type(lt)
+ assert pt == typing.List[bool]
+
+ dur = model_types.LiteralType(simple=model_types.SimpleType.DURATION)
+ lt = model_types.LiteralType(map_value_type=dur)
+ pt = TypeEngine.guess_python_type(lt)
+ assert pt == typing.Dict[str, timedelta]
+
+
def test_zero_floats():
ctx = FlyteContext.current_context()
diff --git a/tests/flytekit/unit/tools/test_module_loader.py b/tests/flytekit/unit/tools/test_module_loader.py
index 5b0df20416..aa7fdd255c 100644
--- a/tests/flytekit/unit/tools/test_module_loader.py
+++ b/tests/flytekit/unit/tools/test_module_loader.py
@@ -37,3 +37,8 @@ def test_module_loading():
assert [
pkg.__file__ for pkg in module_loader.iterate_modules(["top.a", "top.middle.a", "top.middle.bottom.a"])
] == [os.path.join(lvl, "a.py") for lvl in (top_level, middle_level, bottom_level)]
+
+
+def test_load_object():
+ loader_self = module_loader.load_object_from_module(f"{module_loader.__name__}.load_object_from_module")
+ assert loader_self.__module__ == f"{module_loader.__name__}"