From 9c748c913eaa6fb7fbf55e5e66053b813bc4cd4c Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 27 Jul 2021 15:07:17 -0700 Subject: [PATCH] wip - pr into #559 (#565) Signed-off-by: wild-endeavor Signed-off-by: Ketan Umare --- .gitignore | 2 +- dev-requirements.txt | 13 +- doc-requirements.txt | 32 +- docs/source/_templates/custom.rst | 4 + docs/source/conf.py | 2 + docs/source/design/control_plane.rst | 115 +++-- docs/source/design/index.rst | 4 +- docs/source/index.rst | 39 ++ flytekit/bin/entrypoint.py | 14 +- flytekit/clients/friendly.py | 2 +- flytekit/core/base_task.py | 4 +- flytekit/core/context_manager.py | 2 +- flytekit/core/data_persistence.py | 67 +-- flytekit/core/docstring.py | 27 ++ flytekit/core/interface.py | 32 +- flytekit/core/python_auto_container.py | 3 + flytekit/core/python_function_task.py | 2 + flytekit/core/workflow.py | 11 +- flytekit/extras/persistence/__init__.py | 6 +- flytekit/extras/persistence/gcs_gsutil.py | 28 +- flytekit/extras/persistence/http.py | 15 +- flytekit/extras/persistence/s3_awscli.py | 32 +- flytekit/interfaces/data/data_proxy.py | 211 -------- flytekit/remote/__init__.py | 56 ++- flytekit/remote/component_nodes.py | 4 + flytekit/remote/launch_plan.py | 2 + flytekit/remote/nodes.py | 4 + flytekit/remote/remote.py | 454 +++++++++++------- flytekit/remote/tasks/executions.py | 2 + flytekit/remote/tasks/task.py | 2 + flytekit/remote/workflow.py | 2 +- flytekit/remote/workflow_execution.py | 3 + requirements-spark2.txt | 16 +- requirements.txt | 16 +- setup.py | 1 + .../integration/remote/test_remote.py | 21 +- .../unit/bin/test_python_entrypoint.py | 28 +- tests/flytekit/unit/core/test_docstring.py | 95 ++++ .../unit/core/test_flyte_directory.py | 12 +- tests/flytekit/unit/core/test_flyte_file.py | 87 ++-- tests/flytekit/unit/core/test_interface.py | 75 +++ .../flytekit/unit/core/test_serialization.py | 18 + tests/flytekit/unit/core/test_type_hints.py | 6 +- tests/flytekit/unit/core/test_workflows.py | 16 +- 
tests/flytekit/unit/remote/test_remote.py | 4 +- 45 files changed, 957 insertions(+), 634 deletions(-) create mode 100644 flytekit/core/docstring.py create mode 100644 tests/flytekit/unit/core/test_docstring.py diff --git a/.gitignore b/.gitignore index df3edf57c3..ec56a0f239 100644 --- a/.gitignore +++ b/.gitignore @@ -26,5 +26,5 @@ dist .python-version _build/ docs/source/generated/ -.pytest-flyte +.pytest_flyte htmlcov diff --git a/dev-requirements.txt b/dev-requirements.txt index a79927f114..49266067c6 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -89,12 +89,16 @@ dockerpty==0.4.1 # via docker-compose docopt==0.6.2 # via docker-compose +docstring-parser==0.9.1 + # via + # -c requirements.txt + # flytekit flake8==3.9.2 # via # -r dev-requirements.in # flake8-black # flake8-isort -flake8-black==0.2.2 +flake8-black==0.2.3 # via -r dev-requirements.in flake8-isort==4.0.0 # via -r dev-requirements.in @@ -136,7 +140,7 @@ markupsafe==2.0.1 # via # -c requirements.txt # jinja2 -marshmallow==3.12.2 +marshmallow==3.13.0 # via # -c requirements.txt # dataclasses-json @@ -166,7 +170,7 @@ natsort==7.1.1 # via # -c requirements.txt # flytekit -numpy==1.21.0 +numpy==1.21.1 # via # -c requirements.txt # pandas @@ -183,7 +187,7 @@ paramiko==2.7.2 # via # -c requirements.txt # docker -pathspec==0.8.1 +pathspec==0.9.0 # via # -c requirements.txt # black @@ -315,6 +319,7 @@ texttable==1.6.4 toml==0.10.2 # via # coverage + # flake8-black # mypy # pytest tomli==1.0.4 diff --git a/doc-requirements.txt b/doc-requirements.txt index 7c2d90bcc6..e41581fd90 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -16,7 +16,7 @@ appnope==0.1.2 # via # ipykernel # ipython -astroid==2.6.2 +astroid==2.6.5 # via sphinx-autoapi async-generator==1.10 # via nbclient @@ -39,9 +39,9 @@ black==21.7b0 # via papermill bleach==3.3.1 # via nbconvert -boto3==1.18.1 +boto3==1.18.6 # via sagemaker-training -botocore==1.21.1 +botocore==1.21.6 # via # boto3 # s3transfer @@ -70,7 
+70,7 @@ css-html-js-minify==2.5.5 # via sphinx-material dataclasses-json==0.5.4 # via flytekit -debugpy==1.3.0 +debugpy==1.4.1 # via ipykernel decorator==5.0.9 # via @@ -82,9 +82,11 @@ deprecated==1.2.12 # via flytekit dirhash==0.2.1 # via flytekit -docker-image-py==0.1.10 +docker-image-py==0.1.11 # via flytekit -docutils==0.16 +docstring-parser==0.9.1 + # via flytekit +docutils==0.17.1 # via sphinx entrypoints==0.3 # via @@ -98,7 +100,7 @@ gevent==21.1.2 # via sagemaker-training greenlet==1.1.0 # via gevent -grpcio==1.38.1 +grpcio==1.39.0 # via # -r doc-requirements.in # flytekit @@ -112,7 +114,7 @@ importlib-metadata==4.6.1 # via keyring inotify_simple==1.2.1 # via sagemaker-training -ipykernel==6.0.2 +ipykernel==6.0.3 # via flytekit ipython==7.25.0 # via ipykernel @@ -154,7 +156,7 @@ lxml==4.6.3 # via sphinx-material markupsafe==2.0.1 # via jinja2 -marshmallow==3.12.2 +marshmallow==3.13.0 # via # dataclasses-json # marshmallow-enum @@ -188,7 +190,7 @@ nbformat==5.1.3 # papermill nest-asyncio==1.5.1 # via nbclient -numpy==1.21.0 +numpy==1.21.1 # via # flytekit # pandas @@ -199,7 +201,7 @@ packaging==21.0 # via # bleach # sphinx -pandas==1.3.0 +pandas==1.3.1 # via flytekit pandocfilters==1.4.3 # via nbconvert @@ -209,7 +211,7 @@ paramiko==2.7.2 # via sagemaker-training parso==0.8.2 # via jedi -pathspec==0.8.1 +pathspec==0.9.0 # via # black # scantree @@ -322,7 +324,7 @@ sortedcontainers==2.4.0 # via flytekit soupsieve==2.2.1 # via beautifulsoup4 -sphinx==3.5.4 +sphinx==4.1.2 # via # -r doc-requirements.in # furo @@ -343,7 +345,7 @@ sphinx-fontawesome==0.0.6 # via -r doc-requirements.in sphinx-gallery==0.9.0 # via -r doc-requirements.in -sphinx-material==0.0.32 +sphinx-material==0.0.34 # via -r doc-requirements.in sphinx-prompt==1.4.0 # via -r doc-requirements.in @@ -373,7 +375,7 @@ textwrap3==0.9.2 # via ansiwrap thrift==0.13.0 # via hmsclient -tomli==1.0.4 +tomli==1.1.0 # via black tornado==6.1 # via diff --git a/docs/source/_templates/custom.rst 
b/docs/source/_templates/custom.rst index f7cfbd333b..a1447ac0e9 100644 --- a/docs/source/_templates/custom.rst +++ b/docs/source/_templates/custom.rst @@ -11,7 +11,11 @@ .. rubric:: {{ _('Methods') }} {% for item in methods %} + + {% if item != '__init__' %} .. automethod:: {{ item }} + {% endif %} + {%- endfor %} {% endif %} {% endblock %} diff --git a/docs/source/conf.py b/docs/source/conf.py index 24415d22ca..c403747ca9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -52,6 +52,8 @@ # build the templated autosummary files autosummary_generate = True +autodoc_typehints = "description" + # autosectionlabel throws warnings if section names are duplicated. # The following tells autosectionlabel to not throw a warning for # duplicated section names that are in different documents. diff --git a/docs/source/design/control_plane.rst b/docs/source/design/control_plane.rst index 1d24b50b8d..012277b1c3 100644 --- a/docs/source/design/control_plane.rst +++ b/docs/source/design/control_plane.rst @@ -1,55 +1,100 @@ .. _design-control-plane: -############################ -Control Plane Objects -############################ -For those who require programmatic access to the control plane, certain APIs are available through "control plane classes". +################################################### +FlyteRemote: A Programmatic Control Plane Interface +################################################### -.. warning:: +For those who require programmatic access to the control plane, the :mod:`~flytekit.remote` module enables you to perform +certain operations in a python runtime environment. - The syntax of this section, while it will continue to work, is subject to change. +Since this section naturally deals with the control plane, this discussion is only relevant for those who have a Flyte +backend set up and have access to it (a :std:ref:`local sandbox ` will suffice as well). 
-Since this section naturally deals with the control plane, this discussion is only relevant for those who have a Flyte backend set up, and have access to it (a local backend will suffice as well of course). These objects do not rely on the underlying code they represent being locally available. +*************************** +Create a FlyteRemote Object +*************************** -******* -Setup -******* -Similar to the CLIs, this section requires proper configuration. Please follow the setup guide there. +The :class:`~flytekit.remote.remote.FlyteRemote` class is the entrypoint for programmatically performing operations in a python +runtime. There are two ways of creating a remote object. -Unlike the CLI case however, you may need to explicitly target the configuration file like so :: +**Initialize directly** - from flytekit.configuration.common import CONFIGURATION_SINGLETON - CONFIGURATION_SINGLETON.reset_config("/Users/full/path/to/config") +.. code-block:: python + + from flytekit.remote import FlyteRemote + + remote = FlyteRemote( + default_project="project", + default_domain="domain", + flyte_admin_url="", + insecure=True, + ) + +**Initialize from flyte config** + +.. TODO: link documentation to flyte config and environment variables + +This will initialize a :class:`~flytekit.remote.remote.FlyteRemote` object from your flyte config file or environment variable +overrides: -******* -Classes -******* -This is not an exhaustive list of the objects available, but should provide the reader with the wherewithal to further ascertain for herself additional functionality. +.. code-block:: python -Task -====== -To fetch a Task :: + remote = FlyteRemote.from_config() - from flytekit.common.tasks.task import SdkTask - SdkTask.fetch('flytetester', 'staging', 'recipes.core_basic.task.square', '49b6c6bdbb86e974ffd9875cab1f721ada8066a7') +***************************** +Fetching Flyte Admin Entities +***************************** +.. 
code-block:: python -Workflow -======== -To inspect a Workflow :: + flyte_task = remote.fetch_task(name="my_task", version="v1") + flyte_workflow = remote.fetch_workflow(name="my_workflow", version="v1") + flyte_launch_plan = remote.fetch_launch_plan(name="my_launch_plan", version="v1") - from flytekit.common.workflow import SdkWorkflow - wf = SdkWorkflow.fetch('flytetester', 'staging', 'recipes.core_basic.basic_workflow.my_wf', '49b6c6bdbb86e974ffd9875cab1f721ada8066a7') +****************** +Executing Entities +****************** -WorkflowExecution -================= -This class represents one run of a workflow. The ``execution_name`` used here is just the tail end of the URL you see in your browser when looking at a workflow run. +You can execute all of these flyte entities, which returns a :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` object. +For more information on flyte entities, see the :ref:`remote flyte entities ` +reference. .. code-block:: python - from flytekit.common.workflow_execution import SdkWorkflowExecution + flyte_entity = ... # one of FlyteTask, FlyteWorkflow, or FlyteLaunchPlan + execution = remote.execute(flyte_entity, inputs={...}) + +******************************** +Waiting for Execution Completion +******************************** + +You can use the :meth:`~flytekit.remote.remote.FlyteRemote.wait` method to synchronously wait for the execution to complete: + +.. code-block:: python + + completed_execution = remote.wait(execution) + +You can also pass in ``wait=True`` to the :meth:`~flytekit.remote.remote.FlyteRemote.execute` method. + +.. code-block:: python + + completed_execution = remote.execute(flyte_entity, inputs={...}, wait=True) + +******************** +Syncing Remote State +******************** + +Use the :meth:`~flytekit.remote.remote.FlyteRemote.sync` method to sync the entity object's state with the remote state + +.. 
code-block:: python + + synced_execution = remote.sync(execution) + - e = SdkWorkflowExecution.fetch('project_name', 'development', 'execution_name') - e.sync() +**************************** +Inspecting Execution Objects +**************************** -As a workflow is made up of nodes (each of which can contain a task, a subworkflow, or a launch plan), a workflow execution is made up of node executions (each of which can contain a task execution, a subworkflow execution, or a launch plan execution). +At any time you can inspect the inputs, outputs, completion status, error status, and other aspects of a workflow +execution object. See the :ref:`remote execution objects ` reference for a list +of all the available attributes. diff --git a/docs/source/design/index.rst b/docs/source/design/index.rst index 24eef916a1..7da87efba6 100644 --- a/docs/source/design/index.rst +++ b/docs/source/design/index.rst @@ -8,7 +8,7 @@ Flytekit is comprised of a handful of different logical components, each discuss * :ref:`Models Files ` - These are almost Protobuf generated files. * :ref:`Authoring ` - This provides the core Flyte authoring experiences, allowing users to write tasks, workflows, and launch plans. -* :ref:`Control Plane ` - The code here allows users to interact with the control plane through Python objecs. +* :ref:`Control Plane ` - The code here allows users to interact with the control plane through Python objects. * :ref:`Execution ` - A small shim layer basically that handles interaction with the Flyte ecosystem at execution time. * :ref:`CLIs and Clients ` - Command line tools users may find themselves interacting with and the control plane client the CLIs call. 
@@ -19,6 +19,6 @@ Flytekit is comprised of a handful of different logical components, each discuss models authoring - control_plane + Control Plane execution clis diff --git a/docs/source/index.rst b/docs/source/index.rst index 83993af0dc..4bc00a3a43 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,6 +11,45 @@ This section of the documentation provides more detailed descriptions of the hig API reference for specific usage details of python functions, classes, and decorators that you import to specify tasks, build workflows, and extend ``flytekit``. +Installation +============ + +.. code-block:: + + pip install flytekit + +For developer environment setup instructions, see the :ref:`contributor guide `. + + +Quickstart +========== + +.. code-block:: python + + from flytekit import task, workflow + + @task + def sum(x: int, y: int) -> int: + return x + y + + @task + def square(z: int) -> int: + return z * z + + @workflow + def my_workflow(x: int, y: int) -> int: + return sum(x=square(z=x), y=square(z=y)) + + print(f"my_workflow output: {my_workflow(x=1, y=2)}") + + +Expected output: + +.. code-block:: + + my_workflow output: 5 + + .. 
toctree:: :maxdepth: 1 :hidden: diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 7b70309435..1e9a6cb269 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -31,13 +31,12 @@ SerializationSettings, get_image_config, ) +from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.map_task import MapPythonTask from flytekit.core.promise import VoidPromise from flytekit.engines import loader as _engine_loader from flytekit.interfaces import random as _flyte_random from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.interfaces.data.gcs import gcs_proxy as _gcs_proxy -from flytekit.interfaces.data.s3 import s3proxy as _s3proxy from flytekit.interfaces.stats.taggable import get_stats as _get_stats from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import literals as _literal_models @@ -227,18 +226,15 @@ def setup_execution( ) if cloud_provider == _constants.CloudProvider.AWS: - file_access = _data_proxy.FileAccessProvider( + file_access = FileAccessProvider( local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), - remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix), + raw_output_prefix=raw_output_data_prefix, ) elif cloud_provider == _constants.CloudProvider.GCP: - file_access = _data_proxy.FileAccessProvider( + file_access = FileAccessProvider( local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), - remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix), + raw_output_prefix=raw_output_data_prefix, ) - elif cloud_provider == _constants.CloudProvider.LOCAL: - # A fake remote using the local disk will automatically be created - file_access = _data_proxy.FileAccessProvider(local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get()) else: raise Exception(f"Bad cloud provider {cloud_provider}") diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index daac8ee088..e016d55485 100644 --- a/flytekit/clients/friendly.py +++ 
b/flytekit/clients/friendly.py @@ -293,7 +293,7 @@ def list_workflows_paginated(self, identifier, limit=100, token=None, filters=No def get_workflow(self, id): """ - This returns a single task for a given ID. + This returns a single workflow for a given ID. :param flytekit.models.core.identifier.Identifier id: The ID representing a given task. :raises: TODO diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index fa1726fbf2..7aa03d41fc 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -33,6 +33,7 @@ FlyteEntities, SerializationSettings, ) +from flytekit.core.docstring import Docstring from flytekit.core.interface import Interface, transform_interface_to_typed_interface from flytekit.core.promise import ( Promise, @@ -372,6 +373,7 @@ def __init__( task_config: T, interface: Optional[Interface] = None, environment: Optional[Dict[str, str]] = None, + docstring: Optional[Docstring] = None, **kwargs, ): """ @@ -389,7 +391,7 @@ def __init__( super().__init__( task_type=task_type, name=name, - interface=transform_interface_to_typed_interface(interface), + interface=transform_interface_to_typed_interface(interface, docstring), **kwargs, ) self._python_interface = interface if interface else Interface() diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 361d88eb06..4fd82ccbf8 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -481,7 +481,7 @@ def new_execution_state(self, working_dir: Optional[os.PathLike] = None) -> Exec in all other cases it is preferable to use with_execution_state """ if not working_dir: - working_dir = self.file_access.get_random_local_directory() + working_dir = self.file_access.local_sandbox_dir return ExecutionState(working_dir=working_dir, user_space_params=self.user_space_params) @staticmethod diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 77c727fd44..03707fd674 100644 --- 
a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -22,12 +22,11 @@ import datetime import os -import os as _os import pathlib import typing from abc import abstractmethod from distutils import dir_util as _dir_util -from shutil import copyfile as _copyfile +from shutil import copyfile from typing import Dict, Union from uuid import UUID @@ -48,16 +47,21 @@ def __init__(self, message: str): class DataPersistence(object): """ - Base abstract type for all DataPersistence operations. This can be plugged in using the flytekitplugins architecture + Base abstract type for all DataPersistence operations. This can be plugged in using the flytekitplugins architecture """ - def __init__(self, name: str, *args, **kwargs): + def __init__(self, name: str, default_prefix: typing.Optional[str] = None, **kwargs): self._name = name + self._default_prefix = default_prefix @property def name(self) -> str: return self._name + @property + def default_prefix(self) -> typing.Optional[str]: + return self._default_prefix + def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, None, None]: """ Returns true if the given path exists, else false @@ -86,7 +90,7 @@ def put(self, from_path: str, to_path: str, recursive: bool = False): pass @abstractmethod - def construct_path(self, add_protocol: bool, *paths) -> str: + def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: """ if add_protocol is true then is prefixed else Constructs a path in the format *args @@ -107,10 +111,10 @@ class DataPersistencePlugins(object): These plugins should always be registered. Follow the plugin registration guidelines to auto-discover your plugins. 
""" - _PLUGINS: Dict[str, DataPersistence] = {} + _PLUGINS: Dict[str, typing.Type[DataPersistence]] = {} @classmethod - def register_plugin(cls, protocol: str, plugin: DataPersistence, force: bool = False): + def register_plugin(cls, protocol: str, plugin: typing.Type[DataPersistence], force: bool = False): """ Registers the supplied plugin for the specified protocol if one does not already exists. If one exists and force is default or False, then a TypeError is raised. @@ -130,7 +134,7 @@ def register_plugin(cls, protocol: str, plugin: DataPersistence, force: bool = F cls._PLUGINS[protocol] = plugin @classmethod - def find_plugin(cls, path: str) -> DataPersistence: + def find_plugin(cls, path: str) -> typing.Type[DataPersistence]: """ Returns a plugin for the given protocol, else raise a TypeError """ @@ -164,19 +168,16 @@ class DiskPersistence(DataPersistence): PROTOCOL = "file://" - def __init__(self, *args, **kwargs): - """ - :param Text sandbox: - """ - super().__init__(name="local", *args, **kwargs) + def __init__(self, default_prefix: typing.Optional[str] = None, **kwargs): + super().__init__(name="local", default_prefix=default_prefix, **kwargs) @staticmethod def _make_local_path(path): - if not _os.path.exists(path): + if not os.path.exists(path): try: - _os.makedirs(path) + pathlib.Path(path).mkdir(parents=True, exist_ok=True) except OSError: # Guard against race condition - if not _os.path.isdir(path): + if not os.path.isdir(path): raise @staticmethod @@ -201,14 +202,14 @@ def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, N return def exists(self, path: str): - return _os.path.exists(self.strip_file_header(path)) + return os.path.exists(self.strip_file_header(path)) def get(self, from_path: str, to_path: str, recursive: bool = False): if from_path != to_path: if recursive: _dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path)) else: - _copyfile(self.strip_file_header(from_path), 
self.strip_file_header(to_path)) + copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) def put(self, from_path: str, to_path: str, recursive: bool = False): if from_path != to_path: @@ -216,13 +217,14 @@ def put(self, from_path: str, to_path: str, recursive: bool = False): _dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path)) else: # Emulate s3's flat storage by automatically creating directory path - self._make_local_path(_os.path.dirname(self.strip_file_header(to_path))) + self._make_local_path(os.path.dirname(self.strip_file_header(to_path))) # Write the object to a local file in the sandbox - _copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) + copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) - def construct_path(self, add_protocol: bool, *args) -> str: - if add_protocol: - return os.path.join(self.PROTOCOL, *args) + def construct_path(self, _: bool, add_prefix: bool, *args) -> str: + # Ignore add_protocol for now. Only complicates things + if add_prefix: + return os.path.join(self.default_prefix, *args) return os.path.join(*args) @@ -231,6 +233,7 @@ class FileAccessProvider(object): This is the class that is available through the FlyteContext and can be used for persisting data to the remote durable store. 
""" + def __init__(self, local_sandbox_dir: Union[str, os.PathLike], raw_output_prefix: str): # Local access if local_sandbox_dir is None or local_sandbox_dir == "": @@ -238,9 +241,9 @@ def __init__(self, local_sandbox_dir: Union[str, os.PathLike], raw_output_prefix local_sandbox_dir_appended = os.path.join(local_sandbox_dir, "local_flytekit") self._local_sandbox_dir = pathlib.Path(local_sandbox_dir_appended) self._local_sandbox_dir.mkdir(parents=True, exist_ok=True) - self._local = DiskPersistence() + self._local = DiskPersistence(default_prefix=local_sandbox_dir_appended) - self._default_remote = DataPersistencePlugins.find_plugin(raw_output_prefix) + self._default_remote = DataPersistencePlugins.find_plugin(raw_output_prefix)(default_prefix=raw_output_prefix) self._raw_output_prefix = raw_output_prefix @staticmethod @@ -270,10 +273,10 @@ def construct_random_path( if file_path_or_file_name: _, tail = os.path.split(file_path_or_file_name) if tail: - return persist.construct_path(False, self._raw_output_prefix, key, tail) + return persist.construct_path(False, True, key, tail) else: logger.warning(f"No filename detected in {file_path_or_file_name}, generating random path") - return persist.construct_path(False, self._raw_output_prefix, key) + return persist.construct_path(False, True, key) def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = None) -> str: """ @@ -302,7 +305,7 @@ def exists(self, path: str) -> bool: """ checks if the given path exists """ - return DataPersistencePlugins.find_plugin(path).exists(path) + return DataPersistencePlugins.find_plugin(path)().exists(path) def download_directory(self, remote_path: str, local_path: str): """ @@ -338,7 +341,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart=False): """ try: with PerformanceTimer("Copying ({} -> {})".format(remote_path, local_path)): - DataPersistencePlugins.find_plugin(remote_path).get(remote_path, local_path, recursive=is_multipart) + 
DataPersistencePlugins.find_plugin(remote_path)().get(remote_path, local_path, recursive=is_multipart) except Exception as ex: raise FlyteAssertion( "Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n" @@ -361,7 +364,7 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul """ try: with PerformanceTimer("Writing ({} -> {})".format(local_path, remote_path)): - DataPersistencePlugins.find_plugin(remote_path).put(local_path, remote_path, recursive=is_multipart) + DataPersistencePlugins.find_plugin(remote_path)().put(local_path, remote_path, recursive=is_multipart) except Exception as ex: raise FlyteAssertion( f"Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n" @@ -369,8 +372,8 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul ) from ex -DataPersistencePlugins.register_plugin("file://", DiskPersistence()) -DataPersistencePlugins.register_plugin("/", DiskPersistence()) +DataPersistencePlugins.register_plugin("file://", DiskPersistence) +DataPersistencePlugins.register_plugin("/", DiskPersistence) # TODO make this use tmpdir tmp_dir = os.path.join("/tmp/flyte", datetime.datetime.now().strftime("%Y%m%d_%H%M%S")) diff --git a/flytekit/core/docstring.py b/flytekit/core/docstring.py new file mode 100644 index 0000000000..420f26f8f5 --- /dev/null +++ b/flytekit/core/docstring.py @@ -0,0 +1,27 @@ +from typing import Callable, Dict, Optional + +from docstring_parser import parse + + +class Docstring(object): + def __init__(self, docstring: str = None, callable_: Callable = None): + if docstring is not None: + self._parsed_docstring = parse(docstring) + else: + self._parsed_docstring = parse(callable_.__doc__) + + @property + def input_descriptions(self) -> Dict[str, str]: + return {p.arg_name: p.description for p in self._parsed_docstring.params} + + @property + def output_descriptions(self) -> Dict[str, str]: + return {p.return_name: 
p.description for p in self._parsed_docstring.many_returns} + + @property + def short_description(self) -> Optional[str]: + return self._parsed_docstring.short_description + + @property + def long_description(self) -> Optional[str]: + return self._parsed_docstring.long_description diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 992fae3a0d..64beca8e6a 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -9,6 +9,7 @@ from flytekit.common.exceptions.user import FlyteValidationException from flytekit.core import context_manager +from flytekit.core.docstring import Docstring from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger from flytekit.models import interface as _interface_models @@ -178,6 +179,7 @@ def transform_inputs_to_parameters( def transform_interface_to_typed_interface( interface: typing.Optional[Interface], + docstring: Optional[Docstring] = None, ) -> typing.Optional[_interface_models.TypedInterface]: """ Transform the given simple python native interface to FlyteIDL's interface @@ -185,8 +187,14 @@ def transform_interface_to_typed_interface( if interface is None: return None - inputs_map = transform_variable_map(interface.inputs) - outputs_map = transform_variable_map(interface.outputs) + if docstring is None: + input_descriptions = output_descriptions = {} + else: + input_descriptions = docstring.input_descriptions + output_descriptions = remap_shared_output_descriptions(docstring.output_descriptions, interface.outputs) + + inputs_map = transform_variable_map(interface.inputs, input_descriptions) + outputs_map = transform_variable_map(interface.outputs, output_descriptions) return _interface_models.TypedInterface(inputs_map, outputs_map) @@ -253,7 +261,9 @@ def transform_signature_to_interface(signature: inspect.Signature) -> Interface: return Interface(inputs, outputs, output_tuple_name=custom_name) -def transform_variable_map(variable_map: Dict[str, type]) -> Dict[str, 
_interface_models.Variable]: +def transform_variable_map( + variable_map: Dict[str, type], descriptions: Dict[str, str] = {} +) -> Dict[str, _interface_models.Variable]: """ Given a map of str (names of inputs for instance) to their Python native types, return a map of the name to a Flyte Variable object with that type. @@ -261,7 +271,7 @@ def transform_variable_map(variable_map: Dict[str, type]) -> Dict[str, _interfac res = OrderedDict() if variable_map: for k, v in variable_map.items(): - res[k] = transform_type(v, k) + res[k] = transform_type(v, descriptions.get(k, k)) return res @@ -345,3 +355,17 @@ def t(a: int, b: str) -> Dict[str, int]: ... # Handle all other single return types logger.debug(f"Task returns unnamed native tuple {return_annotation}") return {default_output_name(): return_annotation} + + +def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]: + """ + Deals with mixed styles of return value descriptions used in docstrings. If the docstring contains a single entry of return value description, that output description is shared by each output variable. 
+ :param output_descriptions: Dict of output variable names mapping to output description + :param outputs: Interface outputs + :return: Dict of output variable names mapping to shared output description + """ + # no need to remap + if len(output_descriptions) != 1: + return output_descriptions + _, shared_description = next(iter(output_descriptions.items())) + return {k: shared_description for k, _ in outputs.items()} diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 38331d59bc..bd737e9e92 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -7,6 +7,7 @@ from flytekit.common.tasks.raw_container import _get_container_definition from flytekit.core.base_task import PythonTask, TaskResolverMixin from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings +from flytekit.core.docstring import Docstring from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance @@ -37,6 +38,7 @@ def __init__( environment: Optional[Dict[str, str]] = None, task_resolver: Optional[TaskResolverMixin] = None, secret_requests: Optional[List[Secret]] = None, + docstring: Optional[Docstring] = None, **kwargs, ): """ @@ -73,6 +75,7 @@ def __init__( name=name, task_config=task_config, security_ctx=sec_ctx, + docstring=docstring, **kwargs, ) self._container_image = container_image diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 71522e365b..9e823d7219 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -23,6 +23,7 @@ from flytekit.common.exceptions import scopes as exception_scopes from flytekit.core.base_task import Task, TaskResolverMixin from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, FlyteContext, FlyteContextManager +from 
flytekit.core.docstring import Docstring from flytekit.core.interface import transform_signature_to_interface from flytekit.core.python_auto_container import PythonAutoContainerTask, default_task_resolver from flytekit.core.tracker import isnested, istestfunction @@ -121,6 +122,7 @@ def __init__( interface=mutated_interface, task_config=task_config, task_resolver=task_resolver, + docstring=Docstring(callable_=task_function), **kwargs, ) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 72eacdd3eb..9ec0e3fd22 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -19,6 +19,7 @@ FlyteContextManager, FlyteEntities, ) +from flytekit.core.docstring import Docstring from flytekit.core.interface import ( Interface, transform_inputs_to_parameters, @@ -173,13 +174,14 @@ def __init__( workflow_metadata: WorkflowMetadata, workflow_metadata_defaults: WorkflowMetadataDefaults, python_interface: Interface, + docstring: Optional[Docstring] = None, **kwargs, ): self._name = name self._workflow_metadata = workflow_metadata self._workflow_metadata_defaults = workflow_metadata_defaults self._python_interface = python_interface - self._interface = transform_interface_to_typed_interface(python_interface) + self._interface = transform_interface_to_typed_interface(python_interface, docstring) self._inputs = {} self._unbound_inputs = set() self._nodes = [] @@ -640,6 +642,7 @@ def __init__( workflow_function: Callable, metadata: Optional[WorkflowMetadata], default_metadata: Optional[WorkflowMetadataDefaults], + docstring: Docstring = None, ): name = f"{workflow_function.__module__}.{workflow_function.__name__}" self._workflow_function = workflow_function @@ -654,6 +657,7 @@ def __init__( workflow_metadata=metadata, workflow_metadata_defaults=default_metadata, python_interface=native_interface, + docstring=docstring, ) @property @@ -794,7 +798,10 @@ def wrapper(fn): workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) 
workflow_instance = PythonFunctionWorkflow( - fn, metadata=workflow_metadata, default_metadata=workflow_metadata_defaults + fn, + metadata=workflow_metadata, + default_metadata=workflow_metadata_defaults, + docstring=Docstring(callable_=fn), ) workflow_instance.compile() return workflow_instance diff --git a/flytekit/extras/persistence/__init__.py b/flytekit/extras/persistence/__init__.py index fe5f463190..a677632fd8 100644 --- a/flytekit/extras/persistence/__init__.py +++ b/flytekit/extras/persistence/__init__.py @@ -21,6 +21,6 @@ S3Persistence """ -from gcs_gsutil import GCSPersistence -from s3_awscli import S3Persistence -from .http import HttpPersistence +from flytekit.extras.persistence.gcs_gsutil import GCSPersistence +from flytekit.extras.persistence.http import HttpPersistence +from flytekit.extras.persistence.s3_awscli import S3Persistence diff --git a/flytekit/extras/persistence/gcs_gsutil.py b/flytekit/extras/persistence/gcs_gsutil.py index c590c09551..a1cfaf7f87 100644 --- a/flytekit/extras/persistence/gcs_gsutil.py +++ b/flytekit/extras/persistence/gcs_gsutil.py @@ -1,24 +1,20 @@ -import os as _os -import sys as _sys +import os +import typing +from shutil import which as shell_which from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException from flytekit.configuration import gcp as _gcp_config from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.tools import subprocess as _subprocess -if _sys.version_info >= (3,): - from shutil import which as _which -else: - from distutils.spawn import find_executable as _which - def _update_cmd_config_and_execute(cmd): - env = _os.environ.copy() + env = os.environ.copy() return _subprocess.check_call(cmd, env=env) def _amend_path(path): - return _os.path.join(path, "*") if not path.endswith("*") else path + return os.path.join(path, "*") if not path.endswith("*") else path class GCSPersistence(DataPersistence): @@ -32,18 +28,19 @@ class 
GCSPersistence(DataPersistence): pip install gsutil """ + _GS_UTIL_CLI = "gsutil" PROTOCOL = "gs://" - def __init__(self): - super(GCSPersistence, self).__init__(name="gcs-gsutil") + def __init__(self, default_prefix: typing.Optional[str] = None): + super(GCSPersistence, self).__init__(name="gcs-gsutil", default_prefix=default_prefix) @staticmethod def _check_binary(): """ Make sure that the `gsutil` cli is present """ - if not _which(GCSPersistence._GS_UTIL_CLI): + if not shell_which(GCSPersistence._GS_UTIL_CLI): raise _FlyteUserException("gsutil (gcloud cli) not found! Please install.") @staticmethod @@ -104,11 +101,14 @@ def put(self, from_path: str, to_path: str, recursive: bool = False): cmd = self._maybe_with_gsutil_parallelism("cp", from_path, to_path) return _update_cmd_config_and_execute(cmd) - def construct_path(self, add_protocol: bool, *paths) -> str: + def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: + paths = list(paths) # make type check happy + if add_prefix: + paths.insert(0, self.default_prefix) path = f"{'/'.join(paths)}" if add_protocol: return f"{self.PROTOCOL}{path}" return path -DataPersistencePlugins.register_plugin("gcs://", GCSPersistence()) +DataPersistencePlugins.register_plugin(GCSPersistence.PROTOCOL, GCSPersistence) diff --git a/flytekit/extras/persistence/http.py b/flytekit/extras/persistence/http.py index c6324ea914..00fbfd9177 100644 --- a/flytekit/extras/persistence/http.py +++ b/flytekit/extras/persistence/http.py @@ -1,7 +1,11 @@ +import os +import pathlib + import requests as _requests from flytekit.common.exceptions import user as _user_exceptions from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins +from flytekit.loggers import logger class HttpPersistence(DataPersistence): @@ -9,6 +13,7 @@ class HttpPersistence(DataPersistence): DataPersistence implementation for the HTTP protocol. only supports downloading from an http source. 
Uploads are not supported currently. """ + PROTOCOL_HTTP = "http" PROTOCOL_HTTPS = "https" _HTTP_OK = 200 @@ -43,17 +48,21 @@ def get(self, from_path: str, to_path: str, recursive: bool = False): rsp.status_code, "Request for data @ {} failed. Expected status code {}".format(from_path, type(self)._HTTP_OK), ) + head, _ = os.path.split(to_path) + if head and head.startswith("/"): + logger.debug(f"HttpPersistence creating {head} so that parent dirs exist") + pathlib.Path(head).mkdir(parents=True, exist_ok=True) with open(to_path, "wb") as writer: writer.write(rsp.content) def put(self, from_path: str, to_path: str, recursive: bool = False): raise _user_exceptions.FlyteAssertion("Writing data to HTTP endpoint is not currently supported.") - def construct_path(self, add_protocol: bool, *paths) -> str: + def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: raise _user_exceptions.FlyteAssertion( "There are multiple ways of creating http links / paths, this is not supported by the persistence layer" ) -DataPersistencePlugins.register_plugin("http://", HttpPersistence()) -DataPersistencePlugins.register_plugin("https://", HttpPersistence()) +DataPersistencePlugins.register_plugin("http://", HttpPersistence) +DataPersistencePlugins.register_plugin("https://", HttpPersistence) diff --git a/flytekit/extras/persistence/s3_awscli.py b/flytekit/extras/persistence/s3_awscli.py index 134ed26c25..88bc71a49d 100644 --- a/flytekit/extras/persistence/s3_awscli.py +++ b/flytekit/extras/persistence/s3_awscli.py @@ -2,23 +2,15 @@ import os as _os import re as _re import string as _string -import sys as _sys import time -from typing import Dict, List - -from six import moves as _six_moves -from six import text_type as _text_type +from shutil import which as shell_which +from typing import Dict, List, Optional from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException from flytekit.configuration import aws as _aws_config from 
flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.tools import subprocess as _subprocess -if _sys.version_info >= (3,): - from shutil import which as _which -else: - from distutils.spawn import find_executable as _which - def _update_cmd_config_and_execute(cmd: List[str]): env = _os.environ.copy() @@ -67,19 +59,20 @@ class S3Persistence(DataPersistence): DataPersistence plugin for AWS S3 (and Minio). Use aws cli to manage the transfer. The binary needs to be installed separately """ + PROTOCOL = "s3://" _AWS_CLI = "aws" - _SHARD_CHARACTERS = [_text_type(x) for x in _six_moves.range(10)] + list(_string.ascii_lowercase) + _SHARD_CHARACTERS = [str(x) for x in range(10)] + list(_string.ascii_lowercase) - def __init__(self): - super().__init__(name="awscli-s3") + def __init__(self, default_prefix: Optional[str] = None): + super().__init__(name="awscli-s3", default_prefix=default_prefix) @staticmethod def _check_binary(): """ Make sure that the AWS cli is present """ - if not _which(S3Persistence._AWS_CLI): + if not shell_which(S3Persistence._AWS_CLI): raise _FlyteUserException("AWS CLI not found at Please install.") @staticmethod @@ -120,7 +113,7 @@ def exists(self, remote_path): # the http status code: "An error occurred (404) when calling the HeadObject operation: Not Found" # This is a best effort for returning if the object does not exist by searching # for existence of (404) in the error message. 
This should not be needed when we get off the cli and use lib - if _re.search("(404)", _text_type(ex)): + if _re.search("(404)", str(ex)): return False else: raise ex @@ -134,7 +127,7 @@ def get(self, from_path: str, to_path: str, recursive: bool = False): if recursive: cmd = [S3Persistence._AWS_CLI, "s3", "cp", "--recursive", from_path, to_path] else: - cmd = [S3Persistence._AWS_CLI, "s3", "cp", remote_path, local_path] + cmd = [S3Persistence._AWS_CLI, "s3", "cp", from_path, to_path] return _update_cmd_config_and_execute(cmd) def put(self, from_path: str, to_path: str, recursive: bool = False): @@ -153,11 +146,14 @@ def put(self, from_path: str, to_path: str, recursive: bool = False): cmd += [from_path, to_path] return _update_cmd_config_and_execute(cmd) - def construct_path(self, add_protocol: bool, *paths) -> str: + def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: + paths = list(paths) # make type check happy + if add_prefix: + paths.insert(0, self.default_prefix) path = f"{'/'.join(paths)}" if add_protocol: return f"{self.PROTOCOL}{path}" return path -DataPersistencePlugins.register_plugin("s3://", S3Persistence()) +DataPersistencePlugins.register_plugin(S3Persistence.PROTOCOL, S3Persistence) diff --git a/flytekit/interfaces/data/data_proxy.py b/flytekit/interfaces/data/data_proxy.py index a0babeb9ed..7cf5dcae58 100644 --- a/flytekit/interfaces/data/data_proxy.py +++ b/flytekit/interfaces/data/data_proxy.py @@ -1,8 +1,3 @@ -import datetime -import os -import pathlib -from typing import Optional, Union - from flytekit.common import constants as _constants from flytekit.common import utils as _common_utils from flytekit.common.exceptions import user as _user_exception @@ -12,7 +7,6 @@ from flytekit.interfaces.data.http import http_data_proxy as _http_data_proxy from flytekit.interfaces.data.local import local_file_proxy as _local_file_proxy from flytekit.interfaces.data.s3 import s3proxy as _s3proxy -from flytekit.loggers 
import logger class LocalWorkingDirectoryContext(object): @@ -176,208 +170,3 @@ def get_remote_directory(cls): :rtype: Text """ return _OutputDataContext.get_active_proxy().get_random_directory() - - -class FileAccessProvider(object): - def __init__( - self, - local_sandbox_dir: Union[str, os.PathLike], - remote_proxy: Union[_s3proxy.AwsS3Proxy, _gcs_proxy.GCSProxy, None] = None, - ): - - # Local access - if local_sandbox_dir is None or local_sandbox_dir == "": - raise Exception("Can't use empty path") - local_sandbox_dir_appended = os.path.join(local_sandbox_dir, "local_flytekit") - pathlib.Path(local_sandbox_dir_appended).mkdir(parents=True, exist_ok=True) - self._local_sandbox_dir = local_sandbox_dir_appended - self._local = _local_file_proxy.LocalFileProxy(local_sandbox_dir_appended) - - # Remote/cloud stuff - if isinstance(remote_proxy, _s3proxy.AwsS3Proxy): - self._aws = remote_proxy - if isinstance(remote_proxy, _gcs_proxy.GCSProxy): - self._gcs = remote_proxy - if remote_proxy is not None: - self._remote = remote_proxy - else: - mock_remote = os.path.join(local_sandbox_dir, "mock_remote") - pathlib.Path(mock_remote).mkdir(parents=True, exist_ok=True) - self._remote = _local_file_proxy.LocalFileProxy(mock_remote) - - # HTTP access - self._http_proxy = _http_data_proxy.HttpFileProxy() - - @staticmethod - def is_remote(path: Union[str, os.PathLike]) -> bool: - if path.startswith("s3:/") or path.startswith("gs:/") or path.startswith("file:/") or path.startswith("http"): - return True - return False - - def _get_data_proxy_by_path(self, path: Union[str, os.PathLike]): - """ - :param Text path: - :rtype: flytekit.interfaces.data.common.DataProxy - """ - if path.startswith("s3:/"): - return self.aws - elif path.startswith("gs:/"): - return self.gcs - elif path.startswith("http"): - return self.http - elif path.startswith("file://"): - # Note that we default to the local one here, not the remote one. 
- return self.local_access - elif path.startswith("/"): - # Note that we default to the local one here, not the remote one. - return self.local_access - raise Exception(f"Unknown file access {path}") - - @property - def aws(self) -> _s3proxy.AwsS3Proxy: - if self._aws is None: - raise Exception("No AWS handler found") - return self._aws - - @property - def gcs(self) -> _gcs_proxy.GCSProxy: - if self._gcs is None: - raise Exception("No GCP handler found") - return self._gcs - - @property - def remote(self): - if self._remote is not None: - return self._remote - raise Exception("No cloud provider specified") - - @property - def http(self) -> _http_data_proxy.HttpFileProxy: - return self._http_proxy - - @property - def local_sandbox_dir(self) -> os.PathLike: - return self._local_sandbox_dir - - @property - def local_access(self) -> _local_file_proxy.LocalFileProxy: - return self._local - - def get_random_remote_path(self, file_path_or_file_name: Optional[str] = None) -> str: - """ - :param file_path_or_file_name: For when you want a random directory, but want to preserve the leaf file name - """ - if file_path_or_file_name: - _, tail = os.path.split(file_path_or_file_name) - if tail: - return f"{self.remote.get_random_directory()}{tail}" - else: - logger.warning(f"No filename detected in {file_path_or_file_name}, using random remote path...") - return self.remote.get_random_path() - - def get_random_remote_directory(self): - return self.remote.get_random_directory() - - def get_random_local_path(self, file_path_or_file_name: Optional[str] = None) -> str: - """ - :param file_path_or_file_name: For when you want a random directory, but want to preserve the leaf file name - """ - if file_path_or_file_name: - _, tail = os.path.split(file_path_or_file_name) - if tail: - return os.path.join(self.local_access.get_random_directory(), tail) - else: - logger.warning(f"No filename detected in {file_path_or_file_name}, using random local path...") - return 
self.local_access.get_random_path() - - def get_random_local_directory(self) -> str: - dir = self.local_access.get_random_directory() - pathlib.Path(dir).mkdir(parents=True, exist_ok=True) - return dir - - def exists(self, remote_path: str) -> bool: - """ - :param Text remote_path: remote s3:// or gs:// path - """ - return self._get_data_proxy_by_path(remote_path).exists(remote_path) - - def download_directory(self, remote_path: str, local_path: str): - """ - :param Text remote_path: remote s3:// path - :param Text local_path: directory to copy to - """ - return self._get_data_proxy_by_path(remote_path).download_directory(remote_path, local_path) - - def download(self, remote_path: str, local_path: str): - """ - :param Text remote_path: remote s3:// path - :param Text local_path: directory to copy to - """ - return self._get_data_proxy_by_path(remote_path).download(remote_path, local_path) - - def upload(self, file_path: str, to_path: str): - """ - :param Text file_path: - :param Text to_path: - """ - return self.remote.upload(file_path, to_path) - - def upload_directory(self, local_path: str, remote_path: str): - """ - :param Text local_path: - :param Text remote_path: - """ - # TODO: Clean this up, this is a minor hack in lieu of https://github.com/flyteorg/flyte/issues/762 - if remote_path.startswith("/"): - return self.local_access.upload_directory(local_path, remote_path) - return self.remote.upload_directory(local_path, remote_path) - - def get_data(self, remote_path: str, local_path: str, is_multipart=False): - """ - :param Text remote_path: - :param Text local_path: - :param bool is_multipart: - """ - try: - with _common_utils.PerformanceTimer("Copying ({} -> {})".format(remote_path, local_path)): - if is_multipart: - self.download_directory(remote_path, local_path) - else: - self.download(remote_path, local_path) - except Exception as ex: - raise _user_exception.FlyteAssertion( - "Failed to get data from {remote_path} to {local_path} 
(recursive={is_multipart}).\n\n" - "Original exception: {error_string}".format( - remote_path=remote_path, - local_path=local_path, - is_multipart=is_multipart, - error_string=str(ex), - ) - ) - - def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart=False): - """ - The implication here is that we're always going to put data to the remote location, so we .remote to ensure - we don't use the true local proxy if the remote path is a file:// - - :param Text local_path: - :param Text remote_path: - :param bool is_multipart: - """ - try: - with _common_utils.PerformanceTimer("Writing ({} -> {})".format(local_path, remote_path)): - if is_multipart: - self.remote.upload_directory(local_path, remote_path) - else: - self.remote.upload(local_path, remote_path) - except Exception as ex: - raise _user_exception.FlyteAssertion( - f"Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n" - f"Original exception: {str(ex)}" - ) from ex - - -timestamped_default_sandbox_location = os.path.join( - _sdk_config.LOCAL_SANDBOX.get(), datetime.datetime.now().strftime("%Y%m%d_%H%M%S") -) -default_local_file_access_provider = FileAccessProvider(local_sandbox_dir=timestamped_default_sandbox_location) diff --git a/flytekit/remote/__init__.py b/flytekit/remote/__init__.py index 3eb7a14c8b..9dc4a5f0ed 100644 --- a/flytekit/remote/__init__.py +++ b/flytekit/remote/__init__.py @@ -10,8 +10,8 @@ .. code-block:: python - # create a remote object from environment variables - remote = FlyteRemote.from_environment() + # create a remote object from flyte config and environment variables + remote = FlyteRemote.from_config() # fetch a workflow from the flyte backend flyte_workflow = remote.fetch_workflow(name="my_workflow", version="v1") @@ -22,31 +22,59 @@ # inspect the execution's outputs print(workflow_execution.outputs) +.. _remote-entrypoint: + Entrypoint ========== .. 
autosummary:: :template: custom.rst :toctree: generated/ + :nosignatures: + + ~remote.FlyteRemote + +.. _remote-flyte-entities: + +Entities +======== + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + :nosignatures: + + ~tasks.task.FlyteTask + ~workflow.FlyteWorkflow + ~launch_plan.FlyteLaunchPlan + +.. _remote-flyte-entity-components: + +Entity Components +================= + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + :nosignatures: + + ~nodes.FlyteNode + ~component_nodes.FlyteTaskNode + ~component_nodes.FlyteWorkflowNode - FlyteRemote +.. _remote-flyte-execution-objects: -Flyte Entities -============== +Execution Objects +================= .. autosummary:: :template: custom.rst :toctree: generated/ + :nosignatures: - FlyteTask - FlyteWorkflow - FlyteLaunchPlan - FlyteNode - FlyteTaskNode - FlyteWorkflowNode - FlyteWorkflowExecution - FlyteTaskExecution - FlyteNodeExecution + ~workflow_execution.FlyteWorkflowExecution + ~tasks.executions.FlyteTaskExecution + ~nodes.FlyteNodeExecution """ diff --git a/flytekit/remote/component_nodes.py b/flytekit/remote/component_nodes.py index 290bcdbd5d..a0933a75f8 100644 --- a/flytekit/remote/component_nodes.py +++ b/flytekit/remote/component_nodes.py @@ -9,6 +9,8 @@ class FlyteTaskNode(_workflow_model.TaskNode): + """A class encapsulating a task that a Flyte node needs to execute.""" + def __init__(self, flyte_task: "flytekit.remote.tasks.task.FlyteTask"): self._flyte_task = flyte_task super(FlyteTaskNode, self).__init__(None) @@ -56,6 +58,8 @@ def promote_from_model( class FlyteWorkflowNode(_workflow_model.WorkflowNode): + """A class encapsulating a workflow that a Flyte node needs to execute.""" + def __init__( self, flyte_workflow: "flytekit.remote.workflow.FlyteWorkflow" = None, diff --git a/flytekit/remote/launch_plan.py b/flytekit/remote/launch_plan.py index c8ee1f2ff0..c472333913 100644 --- a/flytekit/remote/launch_plan.py +++ b/flytekit/remote/launch_plan.py @@ -9,6 +9,8 @@ 
class FlyteLaunchPlan(_launch_plan_models.LaunchPlanSpec): + """A class encapsulating a remote Flyte launch plan.""" + def __init__(self, *args, **kwargs): super(FlyteLaunchPlan, self).__init__(*args, **kwargs) # Set all the attributes we expect this class to have diff --git a/flytekit/remote/nodes.py b/flytekit/remote/nodes.py index 49c041ff74..456d6b1892 100644 --- a/flytekit/remote/nodes.py +++ b/flytekit/remote/nodes.py @@ -21,6 +21,8 @@ class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): + """A class encapsulating a remote Flyte node.""" + def __init__( self, id, @@ -150,6 +152,8 @@ def __repr__(self) -> str: class FlyteNodeExecution(_node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact): + """A class encapsulating a node execution being run on a Flyte remote backend.""" + def __init__(self, *args, **kwargs): super(FlyteNodeExecution, self).__init__(*args, **kwargs) self._task_executions = None diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index fb467ff8f0..a3e7ffa0d1 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -7,6 +7,7 @@ import uuid from collections import OrderedDict from copy import deepcopy +from dataclasses import asdict, dataclass from datetime import datetime, timedelta from flyteidl.core import literals_pb2 as literals_pb2 @@ -30,12 +31,10 @@ from flytekit.configuration.internal import DOMAIN, PROJECT from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings, get_image_config +from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.launch_plan import LaunchPlan from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import WorkflowBase -from flytekit.interfaces.data.data_proxy import FileAccessProvider -from flytekit.interfaces.data.gcs.gcs_proxy import GCSProxy -from flytekit.interfaces.data.s3.s3proxy import AwsS3Proxy from 
flytekit.models import common as common_models from flytekit.models import launch_plan as launch_plan_models from flytekit.models import literals as literal_models @@ -60,6 +59,14 @@ ExecutionDataResponse = typing.Union[WorkflowExecutionGetDataResponse, NodeExecutionGetDataResponse] +@dataclass +class ResolvedIdentifiers: + project: str + domain: str + name: str + version: str + + def _get_latest_version(list_entities_method: typing.Callable, project: str, domain: str, name: str): named_entity = common_models.NamedEntityIdentifier(project, domain, name) entity_list, _ = list_entities_method( @@ -91,7 +98,7 @@ def _get_entity_identifier( class FlyteRemote(object): - """Main entrypoint for programmatically accessing Flyte remote backend. + """Main entrypoint for programmatically accessing a Flyte remote backend. The term 'remote' is synonymous with 'backend' or 'deployment' and refers to a hosted instance of the Flyte platform, which comes with a Flyte Admin server on some known URI. @@ -102,31 +109,30 @@ class FlyteRemote(object): """ - @staticmethod - def from_environment( - default_project: typing.Optional[str] = None, default_domain: typing.Optional[str] = None + @classmethod + def from_config( + cls, default_project: typing.Optional[str] = None, default_domain: typing.Optional[str] = None ) -> FlyteRemote: - """Create a FlyteRemote object using environment variables. + """Create a FlyteRemote object using flyte configuration variables and/or environment variable overrides. :param default_project: default project to use when fetching or executing flyte entities. :param default_domain: default domain to use when fetching or executing flyte entities. 
""" - raw_output_data_prefix = auth_config.RAW_OUTPUT_DATA_PREFIX.get() - raw_output_data_prefix = raw_output_data_prefix if raw_output_data_prefix else None + raw_output_data_prefix = auth_config.RAW_OUTPUT_DATA_PREFIX.get() or os.path.join( + sdk_config.LOCAL_SANDBOX.get(), "raw" + ) + + file_access = FileAccessProvider( + local_sandbox_dir=os.path.join(sdk_config.LOCAL_SANDBOX.get(), "control_plane_metadata"), + raw_output_prefix=raw_output_data_prefix, + ) - return FlyteRemote( - default_project=default_project or PROJECT.get(), - default_domain=default_domain or DOMAIN.get(), + return cls( flyte_admin_url=platform_config.URL.get(), insecure=platform_config.INSECURE.get(), - file_access=FileAccessProvider( - local_sandbox_dir=sdk_config.LOCAL_SANDBOX.get(), - remote_proxy={ - constants.CloudProvider.AWS: AwsS3Proxy(raw_output_data_prefix), - constants.CloudProvider.GCP: GCSProxy(raw_output_data_prefix), - constants.CloudProvider.LOCAL: None, - }.get(platform_config.CLOUD_PROVIDER.get(), None), - ), + default_project=default_project or PROJECT.get() or None, + default_domain=default_domain or DOMAIN.get() or None, + file_access=file_access, auth_role=common_models.AuthRole( assumable_iam_role=auth_config.ASSUMABLE_IAM_ROLE.get(), kubernetes_service_account=auth_config.KUBERNETES_SERVICE_ACCOUNT.get(), @@ -142,10 +148,10 @@ def from_environment( def __init__( self, - default_project: str, - default_domain: str, flyte_admin_url: str, insecure: bool, + default_project: typing.Optional[str] = None, + default_domain: typing.Optional[str] = None, file_access: typing.Optional[FileAccessProvider] = None, auth_role: typing.Optional[common_models.AuthRole] = None, notifications: typing.Optional[typing.List[common_models.Notification]] = None, @@ -156,10 +162,10 @@ def __init__( ): """Initilize a FlyteRemote object. - :param default_project: default project to use when fetching or executing flyte entities. 
- :param default_domain: default domain to use when fetching or executing flyte entities. :param flyte_admin_url: url pointing to the remote backend. :param insecure: whether or not the enable SSL. + :param default_project: default project to use when fetching or executing flyte entities. + :param default_domain: default domain to use when fetching or executing flyte entities. :param file_access: file access provider to use for offloading non-literal inputs/outputs. :param auth_role: auth role config :param notifications: notification config @@ -173,24 +179,74 @@ def __init__( self._client = SynchronousFlyteClient(flyte_admin_url, insecure=insecure) # read config files, env vars, host, ssl options for admin client - self.default_project = default_project - self.default_domain = default_domain - self.image_config = image_config - self.file_access = file_access - self.auth_role = auth_role - self.notifications = notifications - self.labels = labels - self.annotations = annotations - self.raw_output_data_config = raw_output_data_config + self._flyte_admin_url = flyte_admin_url + self._insecure = insecure + self._default_project = default_project + self._default_domain = default_domain + self._image_config = image_config + self._auth_role = auth_role + self._notifications = notifications + self._labels = labels + self._annotations = annotations + self._raw_output_data_config = raw_output_data_config + + # Save the file access object locally, but also make it available for use from the context. + FlyteContextManager.with_context(FlyteContextManager.current_context().with_file_access(file_access).build()) + self._file_access = file_access # TODO: Reconsider whether we want this. Probably best to not cache. 
- self.serialized_entity_cache = OrderedDict() + self._serialized_entity_cache = OrderedDict() @property def client(self) -> SynchronousFlyteClient: """Return a SynchronousFlyteClient for additional operations.""" return self._client + @property + def default_project(self) -> str: + """Default project to use when fetching or executing flyte entities.""" + return self._default_project + + @property + def default_domain(self) -> str: + """Default domain to use when fetching or executing flyte entities.""" + return self._default_domain + + @property + def image_config(self) -> ImageConfig: + """Image config.""" + return self._image_config + + @property + def file_access(self) -> FileAccessProvider: + """File access provider to use for offloading non-literal inputs/outputs.""" + return self._file_access + + @property + def auth_role(self): + """Auth role config.""" + return self._auth_role + + @property + def notifications(self): + """Notification config.""" + return self._notifications + + @property + def labels(self): + """Label config.""" + return self._labels + + @property + def annotations(self): + """Annotation config.""" + return self._annotations + + @property + def raw_output_data_config(self): + """Location for offloaded data, e.g. 
in S3""" + return self._raw_output_data_config + @property def version(self) -> str: """Get a randomly generated version string.""" @@ -204,10 +260,10 @@ def remote_context(self): def with_overrides( self, - default_project: str = None, - default_domain: str = None, - flyte_admin_url: str = None, - insecure: bool = None, + default_project: typing.Optional[str] = None, + default_domain: typing.Optional[str] = None, + flyte_admin_url: typing.Optional[str] = None, + insecure: typing.Optional[bool] = None, file_access: typing.Optional[FileAccessProvider] = None, auth_role: typing.Optional[common_models.AuthRole] = None, notifications: typing.Optional[typing.List[common_models.Notification]] = None, @@ -219,27 +275,29 @@ def with_overrides( """Create a copy of the remote object, overriding the specified attributes.""" new_remote = deepcopy(self) if default_project: - new_remote.default_project = default_project + new_remote._default_project = default_project if default_domain: - new_remote.default_domain = default_domain + new_remote._default_domain = default_domain if flyte_admin_url: - new_remote.flyte_admin_url = flyte_admin_url + new_remote._flyte_admin_url = flyte_admin_url + new_remote._client = SynchronousFlyteClient(flyte_admin_url, self._insecure) if insecure: - new_remote.insecure = insecure + new_remote._insecure = insecure + new_remote._client = SynchronousFlyteClient(self._flyte_admin_url, insecure) if file_access: - new_remote.file_access = file_access + new_remote._file_access = file_access if auth_role: - new_remote.auth_role = auth_role + new_remote._auth_role = auth_role if notifications: - new_remote.notifications = notifications + new_remote._notifications = notifications if labels: - new_remote.labels = labels + new_remote._labels = labels if annotations: - new_remote.annotations = annotations + new_remote._annotations = annotations if image_config: - new_remote.image_config = image_config + new_remote._image_config = image_config if 
raw_output_data_config: - new_remote.raw_output_data_config = raw_output_data_config + new_remote._raw_output_data_config = raw_output_data_config return new_remote def fetch_task(self, project: str = None, domain: str = None, name: str = None, version: str = None) -> FlyteTask: @@ -373,7 +431,7 @@ def _serialize( """Serialize an entity for registration.""" # TODO: Revisit cache return get_serializable( - self.serialized_entity_cache, + self._serialized_entity_cache, SerializationSettings( project or self.default_project, domain or self.default_domain, @@ -387,34 +445,15 @@ def _serialize( # Register Entities # ##################### - def _resolve_identifier_kwargs( - self, - entity, - project: typing.Optional[str], - domain: typing.Optional[str], - name: typing.Optional[str], - version: typing.Optional[str], - ): - """ - Resolves the identifier attributes based on user input, falling back on default project/domain and - auto-generated version. - """ - return { - "project": project or self.default_project, - "domain": domain or self.default_domain, - "name": name or entity.name, - "version": version or self.version, - } - @singledispatchmethod def register( self, - entity, + entity: typing.Union[PythonTask, WorkflowBase, LaunchPlan], project: str = None, domain: str = None, name: str = None, version: str = None, - ): + ) -> typing.Union[FlyteTask, FlyteWorkflow, FlyteLaunchPlan]: """Register an entity to flyte admin. :param entity: entity to register. @@ -422,7 +461,6 @@ def register( :param domain: register entity into this domain. If None, uses ``default_domain`` attribute :param name: register entity with this name. If None, uses ``entity.name`` :param version: register entity with this version. If None, uses auto-generated version. 
- :returns: flyte entity """ raise NotImplementedError(f"entity type {type(entity)} not recognized for registration") @@ -431,24 +469,24 @@ def _( self, entity: PythonTask, project: str = None, domain: str = None, name: str = None, version: str = None ) -> FlyteTask: """Register an @task-decorated function or TaskTemplate task to flyte admin.""" - flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) + resolved_identifiers = asdict(self._resolve_identifier_kwargs(entity, project, domain, name, version)) self.client.create_task( - Identifier(ResourceType.TASK, **flyte_id_kwargs), - task_spec=self._serialize(entity, **flyte_id_kwargs), + Identifier(ResourceType.TASK, **resolved_identifiers), + task_spec=self._serialize(entity, **resolved_identifiers), ) - return self.fetch_task(**flyte_id_kwargs) + return self.fetch_task(**resolved_identifiers) @register.register def _( self, entity: WorkflowBase, project: str = None, domain: str = None, name: str = None, version: str = None ) -> FlyteWorkflow: """Register an @workflow-decorated function to flyte admin.""" - flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) + resolved_identifiers = asdict(self._resolve_identifier_kwargs(entity, project, domain, name, version)) self.client.create_workflow( - Identifier(ResourceType.WORKFLOW, **flyte_id_kwargs), - workflow_spec=self._serialize(entity, **flyte_id_kwargs), + Identifier(ResourceType.WORKFLOW, **resolved_identifiers), + workflow_spec=self._serialize(entity, **resolved_identifiers), ) - return self.fetch_workflow(**flyte_id_kwargs) + return self.fetch_workflow(**resolved_identifiers) @register.register def _( @@ -457,8 +495,8 @@ def _( """Register a LaunchPlan object to flyte admin.""" # See _get_patch_launch_plan_fn for what we need to patch. These are the elements of a launch plan # that are not set at serialization time and are filled in either by flyte-cli register files or flytectl. 
- flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) - serialized_lp: launch_plan_models.LaunchPlan = self._serialize(entity, **flyte_id_kwargs) + resolved_identifiers = asdict(self._resolve_identifier_kwargs(entity, project, domain, name, version)) + serialized_lp: launch_plan_models.LaunchPlan = self._serialize(entity, **resolved_identifiers) if self.auth_role: serialized_lp.spec._auth_role = common_models.AuthRole( self.auth_role.assumable_iam_role, self.auth_role.kubernetes_service_account @@ -478,19 +516,68 @@ def _( serialized_lp.spec._annotations.values[k] = v self.client.create_launch_plan( - Identifier(ResourceType.LAUNCH_PLAN, **flyte_id_kwargs), + Identifier(ResourceType.LAUNCH_PLAN, **resolved_identifiers), launch_plan_spec=serialized_lp.spec, ) - return self.fetch_launch_plan(**flyte_id_kwargs) + return self.fetch_launch_plan(**resolved_identifiers) #################### # Execute Entities # #################### + def _resolve_identifier_kwargs( + self, + entity, + project: typing.Optional[str], + domain: typing.Optional[str], + name: typing.Optional[str], + version: typing.Optional[str], + ) -> ResolvedIdentifiers: + """ + Resolves the identifier attributes based on user input, falling back on the default project/domain and + auto-generated version, and ultimately the entity project/domain if entity is a remote flyte entity. + """ + error_msg = ( + "entity {entity} of type {entity_type} is not associated with a {arg_name}. Please specify the {arg_name} " + "argument when invoking the FlyteRemote.execute method or a default_{arg_name} value when initializig the " + "FlyteRemote object." 
+ ) + + if project: + resolved_project, msg_project = project, "execute-method" + elif self.default_project: + resolved_project, msg_project = self.default_project, "remote" + elif hasattr(entity, "id"): + resolved_project, msg_project = entity.id.project, "entity" + else: + raise TypeError(error_msg.format(entity=entity, entity_type=type(entity), arg_name="project")) + + if domain: + resolved_domain, msg_domain = domain, "execute-method" + elif self.default_domain: + resolved_domain, msg_domain = self.default_domain, "remote" + elif hasattr(entity, "id"): + resolved_domain, msg_domain = entity.id.domain, "entity" + else: + raise TypeError(error_msg.format(entity=entity, entity_type=type(entity), arg_name="domain")) + + remote_logger.debug( + f"Using {msg_project}-supplied value for project and {msg_domain}-supplied value for domain." + ) + + return ResolvedIdentifiers( + resolved_project, + resolved_domain, + name or entity.name, + version or self.version, + ) + def _execute( self, flyte_id: Identifier, inputs: typing.Dict[str, typing.Any], + project: str, + domain: str, execution_name: typing.Optional[str] = None, wait: bool = False, ) -> FlyteWorkflowExecution: @@ -498,6 +585,8 @@ def _execute( :param flyte_id: entity identifier :param inputs: dictionary mapping argument names to values + :param project: project on which to execute the entity referenced by flyte_id + :param domain: domain on which to execute the entity referenced by flyte_id :param execution_name: name of the execution :param wait: if True, waits for execution to complete :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` @@ -519,8 +608,8 @@ def _execute( # in the case that I want to use a flyte entity from e.g. project "A" but actually execute the entity on a # different project "B". For now, this method doesn't support this use case. 
exec_id = self.client.create_execution( - flyte_id.project, - flyte_id.domain, + project, + domain, execution_name, ExecutionSpec( flyte_id, @@ -547,17 +636,23 @@ def _execute( @singledispatchmethod def execute( self, - entity, + entity: typing.Union[FlyteTask, FlyteLaunchPlan, FlyteWorkflow, PythonTask, WorkflowBase, LaunchPlan], inputs: typing.Dict[str, typing.Any], project: str = None, domain: str = None, name: str = None, version: str = None, - execution_name=None, - wait=False, + execution_name: str = None, + wait: bool = False, ) -> FlyteWorkflowExecution: """Execute a task, workflow, or launchplan. + This method supports: + - ``Flyte{Task, Workflow, LaunchPlan}`` remote module objects. + - ``@task``-decorated functions and ``TaskTemplate`` tasks. + - ``@workflow``-decorated functions. + - ``LaunchPlan`` objects. + :param entity: entity to execute :param inputs: dictionary mapping argument names to values :param project: execute entity in this project. If entity doesn't exist in the project, register the entity @@ -566,14 +661,13 @@ def execute( first before executing. :param name: execute entity using this name. If not None, use this value instead of ``entity.name`` :param version: execute entity using this version. If None, uses auto-generated value. - :param execution_name: name of the execution + :param execution_name: name of the execution. If None, uses auto-generated value. :param wait: if True, waits for execution to complete - :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` - .. note:: + .. note: - The ``project``, ``domain``, ``name``. and ``version`` arguments do not apply to ``FlyteTask``, - ``FlyteLaunchPlan``, and ``FlyteWorkflow`` objects. + The ``name`` and ``version`` arguments do not apply to ``FlyteTask``, ``FlyteLaunchPlan``, and + ``FlyteWorkflow`` entity inputs. These values are determined by referencing the entity identifier values. 
""" raise NotImplementedError(f"entity type {type(entity)} not recognized for execution") @@ -586,38 +680,57 @@ def _( self, entity, inputs: typing.Dict[str, typing.Any], - execution_name=None, - wait=False, + project: str = None, + domain: str = None, + name: str = None, + version: str = None, + execution_name: str = None, + wait: bool = False, ) -> FlyteWorkflowExecution: """Execute a FlyteTask, or FlyteLaunchplan. - :param entity: entity to execute - :param inputs: dictionary mapping argument names to values - :param execution_name: name of the execution - :param wait: if True, waits for execution to complete - :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` + NOTE: the name and version arguments are currently not used and only there consistency in the function signature """ - return self._execute(entity.id, inputs, execution_name, wait) + if name or version: + remote_logger.warn(f"The 'name' and 'version' arguments are ignored for entities of type {type(entity)}") + resolved_identifiers = self._resolve_identifier_kwargs( + entity, project, domain, entity.id.name, entity.id.version + ) + return self._execute( + entity.id, + inputs, + project=resolved_identifiers.project, + domain=resolved_identifiers.domain, + execution_name=execution_name, + wait=wait, + ) @execute.register def _( self, entity: FlyteWorkflow, inputs: typing.Dict[str, typing.Any], - execution_name=None, - wait=False, + project: str = None, + domain: str = None, + name: str = None, + version: str = None, + execution_name: str = None, + wait: bool = False, ) -> FlyteWorkflowExecution: """Execute a FlyteWorkflow. 
- :param entity: entity to execute - :param inputs: dictionary mapping argument names to values - :param execution_name: name of the execution - :param wait: if True, waits for execution to complete - :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` + NOTE: the name and version arguments are currently not used and only there consistency in the function signature """ + if name or version: + remote_logger.warn(f"The 'name' and 'version' arguments are ignored for entities of type {type(entity)}") + resolved_identifiers = self._resolve_identifier_kwargs( + entity, project, domain, entity.id.name, entity.id.version + ) return self.execute( self.fetch_launch_plan(entity.id.project, entity.id.domain, entity.id.name, entity.id.version), inputs, + project=resolved_identifiers.project, + domain=resolved_identifiers.domain, execution_name=execution_name, wait=wait, ) @@ -635,28 +748,23 @@ def _( name: str = None, version: str = None, execution_name: str = None, - wait=False, + wait: bool = False, ) -> FlyteWorkflowExecution: - """Execute an @task-decorated function or TaskTemplate task. - - :param entity: entity to execute - :param inputs: dictionary mapping argument names to values - :param project: execute entity in this project. If entity doesn't exist in the project, register the entity - first before executing. - :param domain: execute entity in this domain. If entity doesn't exist in the domain, register the entity - first before executing. - :param name: execute entity using this name. If not None, use this value instead of ``entity.name`` - :param version: execute entity using this version. If None, uses auto-generated value. 
- :param execution_name: name of the execution - :param wait: if True, waits for execution to complete - :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` - """ - flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) + """Execute an @task-decorated function or TaskTemplate task.""" + resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) + resolved_identifiers_dict = asdict(resolved_identifiers) try: - flyte_task: FlyteTask = self.fetch_task(**flyte_id_kwargs) + flyte_task: FlyteTask = self.fetch_task(**resolved_identifiers_dict) except Exception: - flyte_task: FlyteTask = self.register(entity, **flyte_id_kwargs) - return self.execute(flyte_task, inputs, execution_name=execution_name, wait=wait) + flyte_task: FlyteTask = self.register(entity, **resolved_identifiers_dict) + return self.execute( + flyte_task, + inputs, + project=resolved_identifiers.project, + domain=resolved_identifiers.domain, + execution_name=execution_name, + wait=wait, + ) @execute.register def _( @@ -667,29 +775,24 @@ def _( domain: str = None, name: str = None, version: str = None, - execution_name=None, - wait=False, + execution_name: str = None, + wait: bool = False, ) -> FlyteWorkflowExecution: - """Execute an @workflow-decorated function. - - :param entity: entity to execute - :param inputs: dictionary mapping argument names to values - :param project: execute entity in this project. If entity doesn't exist in the project, register the entity - first before executing. - :param domain: execute entity in this domain. If entity doesn't exist in the domain, register the entity - first before executing. - :param name: execute entity using this name. If not None, use this value instead of ``entity.name`` - :param version: execute entity using this version. If None, uses auto-generated value. 
- :param execution_name: name of the execution - :param wait: if True, waits for execution to complete - :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` - """ - flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) + """Execute an @workflow-decorated function.""" + resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) + resolved_identifiers_dict = asdict(resolved_identifiers) try: - flyte_workflow: FlyteWorkflow = self.fetch_workflow(**flyte_id_kwargs) + flyte_workflow: FlyteWorkflow = self.fetch_workflow(**resolved_identifiers_dict) except Exception: - flyte_workflow: FlyteWorkflow = self.register(entity, **flyte_id_kwargs) - return self.execute(flyte_workflow, inputs, execution_name=execution_name, wait=wait) + flyte_workflow: FlyteWorkflow = self.register(entity, **resolved_identifiers_dict) + return self.execute( + flyte_workflow, + inputs, + project=resolved_identifiers.project, + domain=resolved_identifiers.domain, + execution_name=execution_name, + wait=wait, + ) @execute.register def _( @@ -700,29 +803,24 @@ def _( domain: str = None, name: str = None, version: str = None, - execution_name=None, - wait=False, + execution_name: str = None, + wait: bool = False, ) -> FlyteWorkflowExecution: - """Execute a LaunchPlan object. - - :param entity: entity to execute - :param inputs: dictionary mapping argument names to values - :param project: execute entity in this project. If entity doesn't exist in the project, register the entity - first before executing. - :param domain: execute entity in this domain. If entity doesn't exist in the domain, register the entity - first before executing. - :param name: execute entity using this name. If not None, use this value instead of ``entity.name`` - :param version: execute entity using this version. If None, uses auto-generated value. 
- :param execution_name: name of the execution - :param wait: if True, waits for execution to complete - :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` - """ - flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) + """Execute a LaunchPlan object.""" + resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) + resolved_identifiers_dict = asdict(resolved_identifiers) try: - flyte_launchplan: FlyteLaunchPlan = self.fetch_launch_plan(**flyte_id_kwargs) + flyte_launchplan: FlyteLaunchPlan = self.fetch_launch_plan(**resolved_identifiers_dict) except Exception: - flyte_launchplan: FlyteLaunchPlan = self.register(entity, **flyte_id_kwargs) - return self.execute(flyte_launchplan, inputs, execution_name=execution_name, wait=wait) + flyte_launchplan: FlyteLaunchPlan = self.register(entity, **resolved_identifiers_dict) + return self.execute( + flyte_launchplan, + inputs, + project=resolved_identifiers.project, + domain=resolved_identifiers.domain, + execution_name=execution_name, + wait=wait, + ) ################################### # Wait for Executions to Complete # @@ -756,7 +854,7 @@ def wait( ######################## @singledispatchmethod - def sync(self, execution): + def sync(self, execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution]): """Sync a flyte execution object with its corresponding remote state. This method syncs the inputs and outputs of the execution object and all of its child node executions. 
@@ -836,6 +934,22 @@ def _(self, execution: FlyteTaskExecution) -> FlyteTaskExecution: task = self.fetch_task(task_id.project, task_id.domain, task_id.name, task_id.version) return self._assign_inputs_and_outputs(synced_execution, execution_data, task.interface) + ############################# + # Terminate Execution State # + ############################# + + def terminate(self, execution: FlyteWorkflowExecution, cause: str): + """Terminate a workflow execution. + + :param execution: workflow execution to terminate + :param cause: reason for termination + """ + self.client.terminate_execution(execution.id, cause) + + ################## + # Helper Methods # + ################## + def _assign_inputs_and_outputs(self, execution, execution_data, interface): """Helper for assigning synced inputs and outputs to an execution object.""" with self.remote_context() as ctx: diff --git a/flytekit/remote/tasks/executions.py b/flytekit/remote/tasks/executions.py index 0687d0e9c1..9937b4be77 100644 --- a/flytekit/remote/tasks/executions.py +++ b/flytekit/remote/tasks/executions.py @@ -9,6 +9,8 @@ class FlyteTaskExecution(_task_execution_model.TaskExecution, _artifact_mixin.ExecutionArtifact): + """A class encapsulating a task execution being run on a Flyte remote backend.""" + def __init__(self, *args, **kwargs): super(FlyteTaskExecution, self).__init__(*args, **kwargs) self._inputs = None diff --git a/flytekit/remote/tasks/task.py b/flytekit/remote/tasks/task.py index dfd9dade89..0f539bce00 100644 --- a/flytekit/remote/tasks/task.py +++ b/flytekit/remote/tasks/task.py @@ -6,6 +6,8 @@ class FlyteTask(_hash_mixin.HashOnReferenceMixin, _task_model.TaskTemplate): + """A class encapsulating a remote Flyte task.""" + def __init__(self, id, type, metadata, interface, custom, container=None, task_type_version=0, config=None): super(FlyteTask, self).__init__( id, diff --git a/flytekit/remote/workflow.py b/flytekit/remote/workflow.py index c4cc9a59cc..6153faa02c 100644 --- 
a/flytekit/remote/workflow.py +++ b/flytekit/remote/workflow.py @@ -13,7 +13,7 @@ class FlyteWorkflow(_hash_mixin.HashOnReferenceMixin, _workflow_models.WorkflowTemplate): - """A Flyte control plane construct.""" + """A class encapsulating a remote Flyte workflow.""" def __init__( self, diff --git a/flytekit/remote/workflow_execution.py b/flytekit/remote/workflow_execution.py index 2c5e44bd2e..f98aae4f87 100644 --- a/flytekit/remote/workflow_execution.py +++ b/flytekit/remote/workflow_execution.py @@ -12,6 +12,8 @@ class FlyteWorkflowExecution(_execution_models.Execution, _artifact.ExecutionArtifact): + """A class encapsulating a workflow execution being run on a Flyte remote backend.""" + def __init__(self, *args, **kwargs): super(FlyteWorkflowExecution, self).__init__(*args, **kwargs) self._node_executions = None @@ -20,6 +22,7 @@ def __init__(self, *args, **kwargs): @property def node_executions(self) -> Dict[str, _nodes.FlyteNodeExecution]: + """Get a dictionary of node executions that are a part of this workflow execution.""" return self._node_executions or {} @property diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 9adfd5a684..48fe3ee2c0 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -29,9 +29,9 @@ black==21.7b0 # via papermill bleach==3.3.1 # via nbconvert -boto3==1.18.1 +boto3==1.18.4 # via sagemaker-training -botocore==1.21.1 +botocore==1.21.4 # via # boto3 # s3transfer @@ -56,7 +56,7 @@ cryptography==3.4.7 # via paramiko dataclasses-json==0.5.4 # via flytekit -debugpy==1.3.0 +debugpy==1.4.0 # via ipykernel decorator==5.0.9 # via @@ -70,6 +70,8 @@ dirhash==0.2.1 # via flytekit docker-image-py==0.1.10 # via flytekit +docstring-parser==0.9.1 + # via flytekit entrypoints==0.3 # via # nbconvert @@ -90,7 +92,7 @@ importlib-metadata==4.6.1 # via keyring inotify_simple==1.2.1 # via sagemaker-training -ipykernel==6.0.2 +ipykernel==6.0.3 # via flytekit ipython==7.25.0 # via ipykernel @@ -125,7 +127,7 @@ keyring==23.0.1 # 
via flytekit markupsafe==2.0.1 # via jinja2 -marshmallow==3.12.2 +marshmallow==3.13.0 # via # dataclasses-json # marshmallow-enum @@ -159,7 +161,7 @@ nbformat==5.1.3 # papermill nest-asyncio==1.5.1 # via nbclient -numpy==1.21.0 +numpy==1.21.1 # via # flytekit # pandas @@ -178,7 +180,7 @@ paramiko==2.7.2 # via sagemaker-training parso==0.8.2 # via jedi -pathspec==0.8.1 +pathspec==0.9.0 # via # black # scantree diff --git a/requirements.txt b/requirements.txt index 5529877fa9..c61aa47892 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,9 +29,9 @@ black==21.7b0 # via papermill bleach==3.3.1 # via nbconvert -boto3==1.18.1 +boto3==1.18.4 # via sagemaker-training -botocore==1.21.1 +botocore==1.21.4 # via # boto3 # s3transfer @@ -56,7 +56,7 @@ cryptography==3.4.7 # via paramiko dataclasses-json==0.5.4 # via flytekit -debugpy==1.3.0 +debugpy==1.4.0 # via ipykernel decorator==5.0.9 # via @@ -70,6 +70,8 @@ dirhash==0.2.1 # via flytekit docker-image-py==0.1.10 # via flytekit +docstring-parser==0.9.1 + # via flytekit entrypoints==0.3 # via # nbconvert @@ -90,7 +92,7 @@ importlib-metadata==4.6.1 # via keyring inotify_simple==1.2.1 # via sagemaker-training -ipykernel==6.0.2 +ipykernel==6.0.3 # via flytekit ipython==7.25.0 # via ipykernel @@ -125,7 +127,7 @@ keyring==23.0.1 # via flytekit markupsafe==2.0.1 # via jinja2 -marshmallow==3.12.2 +marshmallow==3.13.0 # via # dataclasses-json # marshmallow-enum @@ -159,7 +161,7 @@ nbformat==5.1.3 # papermill nest-asyncio==1.5.1 # via nbclient -numpy==1.21.0 +numpy==1.21.1 # via # flytekit # pandas @@ -178,7 +180,7 @@ paramiko==2.7.2 # via sagemaker-training parso==0.8.2 # via jedi -pathspec==0.8.1 +pathspec==0.9.0 # via # black # scantree diff --git a/setup.py b/setup.py index 0d86dd2d4f..b4f3dd23c1 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,7 @@ "dirhash>=0.2.1", "docker-image-py>=0.1.10", "singledispatchmethod; python_version < '3.8.0'", + "docstring-parser>=0.9.0", ], extras_require=extras_require, scripts=[ diff 
--git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 4f43bda7ec..6b17226821 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -47,14 +47,14 @@ def test_client(flyteclient, flyte_workflows_register, docker_services): def test_fetch_execute_launch_plan(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_environment(PROJECT, "development") + remote = FlyteRemote.from_config(PROJECT, "development") flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.hello_world.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {}, wait=True) assert execution.outputs["o0"] == "hello world" def fetch_execute_launch_plan_with_args(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_environment(PROJECT, "development") + remote = FlyteRemote.from_config(PROJECT, "development") flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.basic_workflow.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {"a": 10, "b": "foobar"}, wait=True) assert execution.node_executions["n0"].inputs == {"a": 10} @@ -72,7 +72,7 @@ def fetch_execute_launch_plan_with_args(flyteclient, flyte_workflows_register): def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte_remote_env): - remote = FlyteRemote.from_environment(PROJECT, "development") + remote = FlyteRemote.from_config(PROJECT, "development") flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.hello_world.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {}) @@ -108,7 +108,7 @@ def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte def test_fetch_execute_launch_plan_with_subworkflows(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_environment(PROJECT, "development") + remote = FlyteRemote.from_config(PROJECT, 
"development") flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.subworkflows.parent_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {"a": 101}, wait=True) # check node execution inputs and outputs @@ -124,14 +124,17 @@ def test_fetch_execute_launch_plan_with_subworkflows(flyteclient, flyte_workflow def test_fetch_execute_workflow(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_environment(PROJECT, "development") + remote = FlyteRemote.from_config(PROJECT, "development") flyte_workflow = remote.fetch_workflow(name="workflows.basic.hello_world.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_workflow, {}, wait=True) assert execution.outputs["o0"] == "hello world" + execution_to_terminate = remote.execute(flyte_workflow, {}) + remote.terminate(execution_to_terminate, cause="just because") + def test_fetch_execute_task(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_environment(PROJECT, "development") + remote = FlyteRemote.from_config(PROJECT, "development") flyte_task = remote.fetch_task(name="workflows.basic.basic_workflow.t1", version=f"v{VERSION}") execution = remote.execute(flyte_task, {"a": 10}, wait=True) assert execution.outputs["t1_int_output"] == 12 @@ -145,7 +148,7 @@ def test_execute_python_task(flyteclient, flyte_workflows_register, flyte_remote # make sure the task name is the same as the name used during registration t1._name = t1.name.replace("mock_flyte_repo.", "") - remote = FlyteRemote.from_environment(PROJECT, "development") + remote = FlyteRemote.from_config(PROJECT, "development") execution = remote.execute(t1, inputs={"a": 10}, version=f"v{VERSION}", wait=True) assert execution.outputs["t1_int_output"] == 12 assert execution.outputs["c"] == "world" @@ -158,7 +161,7 @@ def test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_re # make sure the task name is the same as the name used during registration my_wf._name = 
my_wf.name.replace("mock_flyte_repo.", "") - remote = FlyteRemote.from_environment(PROJECT, "development") + remote = FlyteRemote.from_config(PROJECT, "development") execution = remote.execute(my_wf, inputs={"a": 10, "b": "xyz"}, version=f"v{VERSION}", wait=True) assert execution.outputs["o0"] == 12 assert execution.outputs["o1"] == "xyzworld" @@ -170,7 +173,7 @@ def test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_re def test_execute_sqlite3_task(flyteclient, flyte_workflows_register, flyte_remote_env): - remote = FlyteRemote.from_environment(PROJECT, "development") + remote = FlyteRemote.from_config(PROJECT, "development") example_db = "https://cdn.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip" interactive_sql_task = SQLite3Task( diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 106156499c..06ae373f8c 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -183,8 +183,8 @@ def return_args(*args, **kwargs): @mock.patch("flytekit.common.utils.load_proto_from_file") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.get_data") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.upload_directory") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") @mock.patch("flytekit.common.utils.write_proto_to_file") def test_dispatch_execute_void(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens @@ -212,8 +212,8 @@ def verify_output(*args, **kwargs): @mock.patch("flytekit.common.utils.load_proto_from_file") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.get_data") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.upload_directory") 
+@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") @mock.patch("flytekit.common.utils.write_proto_to_file") def test_dispatch_execute_ignore(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens @@ -241,8 +241,8 @@ def test_dispatch_execute_ignore(mock_write_to_file, mock_upload_dir, mock_get_d @mock.patch("flytekit.common.utils.load_proto_from_file") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.get_data") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.upload_directory") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") @mock.patch("flytekit.common.utils.write_proto_to_file") def test_dispatch_execute_exception(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens @@ -279,8 +279,8 @@ def output_collector(proto, path): @mock.patch("flytekit.common.utils.load_proto_from_file") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.get_data") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.upload_directory") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") @mock.patch("flytekit.common.utils.write_proto_to_file") def test_dispatch_execute_normal(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens @@ -316,8 +316,8 @@ def t1(a: int) -> str: @mock.patch("flytekit.common.utils.load_proto_from_file") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.get_data") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.upload_directory") 
+@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") @mock.patch("flytekit.common.utils.write_proto_to_file") def test_dispatch_execute_user_error_non_recov(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens @@ -356,8 +356,8 @@ def t1(a: int) -> str: @mock.patch("flytekit.common.utils.load_proto_from_file") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.get_data") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.upload_directory") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") @mock.patch("flytekit.common.utils.write_proto_to_file") def test_dispatch_execute_user_error_recoverable(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens @@ -400,8 +400,8 @@ def my_subwf(a: int) -> typing.List[str]: @mock.patch("flytekit.common.utils.load_proto_from_file") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.get_data") -@mock.patch("flytekit.interfaces.data.data_proxy.FileAccessProvider.upload_directory") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") @mock.patch("flytekit.common.utils.write_proto_to_file") def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): # Just leave these here, mock them out so nothing happens diff --git a/tests/flytekit/unit/core/test_docstring.py b/tests/flytekit/unit/core/test_docstring.py new file mode 100644 index 0000000000..de0a398af5 --- /dev/null +++ b/tests/flytekit/unit/core/test_docstring.py @@ -0,0 +1,95 @@ +import typing + +from flytekit.core.docstring import Docstring + + 
+def test_get_variable_descriptions(): + # sphinx style + def z(a: int, b: str) -> typing.Tuple[int, str]: + """ + function z + + longer description here + + :param a: foo + :param b: bar + :return: ramen + """ + ... + + docstring = Docstring(callable_=z) + input_descriptions = docstring.input_descriptions + output_descriptions = docstring.output_descriptions + assert input_descriptions["a"] == "foo" + assert input_descriptions["b"] == "bar" + assert len(output_descriptions) == 1 + assert next(iter(output_descriptions.items()))[1] == "ramen" + assert docstring.short_description == "function z" + assert docstring.long_description == "longer description here" + + # numpy style + def z(a: int, b: str) -> typing.Tuple[int, str]: + """ + function z + + longer description here + + Parameters + ---------- + a : int + foo + b : str + bar + + Returns + ------- + out : tuple + ramen + """ + ... + + docstring = Docstring(callable_=z) + input_descriptions = docstring.input_descriptions + output_descriptions = docstring.output_descriptions + assert input_descriptions["a"] == "foo" + assert input_descriptions["b"] == "bar" + assert len(output_descriptions) == 1 + assert next(iter(output_descriptions.items()))[1] == "ramen" + assert docstring.short_description == "function z" + assert docstring.long_description == "longer description here" + + # google style + def z(a: int, b: str) -> typing.Tuple[int, str]: + """function z + + longer description here + + Args: + a(int): foo + b(str): bar + Returns: + str: ramen + """ + ... 
+ + docstring = Docstring(callable_=z) + input_descriptions = docstring.input_descriptions + output_descriptions = docstring.output_descriptions + assert input_descriptions["a"] == "foo" + assert input_descriptions["b"] == "bar" + assert len(output_descriptions) == 1 + assert next(iter(output_descriptions.items()))[1] == "ramen" + assert docstring.short_description == "function z" + assert docstring.long_description == "longer description here" + + # empty doc + def z(a: int, b: str) -> typing.Tuple[int, str]: + ... + + docstring = Docstring(callable_=z) + input_descriptions = docstring.input_descriptions + output_descriptions = docstring.output_descriptions + assert len(input_descriptions) == 0 + assert len(output_descriptions) == 0 + assert docstring.short_description is None + assert docstring.long_description is None diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index b916832c2b..719613647c 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -4,14 +4,13 @@ import pytest -import flytekit from flytekit.core import context_manager -from flytekit.core.context_manager import ExecutionState, Image, ImageConfig +from flytekit.core.context_manager import ExecutionState, FlyteContextManager, Image, ImageConfig +from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow -from flytekit.interfaces.data.data_proxy import FileAccessProvider from flytekit.models.core.types import BlobType from flytekit.models.literals import LiteralMap from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer @@ -32,8 +31,9 @@ def test_engine(): def test_transformer_to_literal_local(): + random_dir = 
context_manager.FlyteContext.current_context().file_access.get_random_local_directory() - fs = FileAccessProvider(local_sandbox_dir=random_dir) + fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) ctx = context_manager.FlyteContext.current_context() with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: # Use a separate directory that we know won't be the same as anything generated by flytekit itself, lest we @@ -80,7 +80,7 @@ def test_transformer_to_literal_local(): def test_transformer_to_literal_remote(): random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() - fs = FileAccessProvider(local_sandbox_dir=random_dir) + fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) ctx = context_manager.FlyteContext.current_context() with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: # Use a separate directory that we know won't be the same as anything generated by flytekit itself, lest we @@ -102,7 +102,7 @@ def test_transformer_to_literal_remote(): def test_wf(): @task def t1() -> FlyteDirectory: - user_ctx = flytekit.current_context() + user_ctx = FlyteContextManager.current_context().user_space_params # Create a local directory to work with p = os.path.join(user_ctx.working_directory, "test_wf") if os.path.exists(p): diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index c283249e1f..a7dfec33db 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -1,13 +1,15 @@ import os -import flytekit +import pytest + from flytekit.core import context_manager from flytekit.core.context_manager import ExecutionState, Image, ImageConfig +from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.dynamic_workflow_task import dynamic +from 
flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow -from flytekit.interfaces.data.data_proxy import FileAccessProvider from flytekit.models.core.types import BlobType from flytekit.models.literals import LiteralMap from flytekit.types.file.file import FlyteFile @@ -47,7 +49,7 @@ def my_wf(fname: os.PathLike = SAMPLE_DATA) -> int: return length assert my_wf.python_interface.inputs_with_defaults["fname"][1] == SAMPLE_DATA - sample_lp = flytekit.LaunchPlan.create("test_launch_plan", my_wf) + sample_lp = LaunchPlan.create("test_launch_plan", my_wf) assert sample_lp.parameters.parameters["fname"].default.scalar.blob.uri == SAMPLE_DATA @@ -62,20 +64,18 @@ def my_wf() -> FlyteFile: return t1() random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() - fs = FileAccessProvider(local_sandbox_dir=random_dir) + # print(f"Random: {random_dir}") + fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: + with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) - assert len(top_level_files) == 2 # the mock_remote folder and the local folder - - mock_remote_files = os.listdir(os.path.join(random_dir, "mock_remote")) - assert len(mock_remote_files) == 0 # the mock_remote folder itself is empty + assert len(top_level_files) == 1 # the local_flytekit folder x = my_wf() # After running, this test file should've been copied to the mock remote location. 
mock_remote_files = os.listdir(os.path.join(random_dir, "mock_remote")) - assert len(mock_remote_files) == 1 + assert len(mock_remote_files) == 1 # the file # File should've been copied to the mock remote folder assert x.path.startswith(random_dir) @@ -91,20 +91,16 @@ def my_wf() -> FlyteFile: return t1() random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() - fs = FileAccessProvider(local_sandbox_dir=random_dir) + fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: + with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) - assert len(top_level_files) == 2 # the mock_remote folder and the local folder - - mock_remote_files = os.listdir(os.path.join(random_dir, "mock_remote")) - assert len(mock_remote_files) == 0 # the mock_remote folder itself is empty + assert len(top_level_files) == 1 # the local_flytekit folder workflow_output = my_wf() # After running, this test file should've been copied to the mock remote location. - mock_remote_files = os.listdir(os.path.join(random_dir, "mock_remote")) - assert len(mock_remote_files) == 0 + assert not os.path.exists(os.path.join(random_dir, "mock_remote")) # Because Flyte doesn't presume to handle a uri that look like a raw path, the path that is returned is # the original. @@ -125,20 +121,18 @@ def my_wf() -> FlyteFile: # This creates a random directory that we know is empty.
random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() # Creating a new FileAccessProvider will add two folderst to the random dir - fs = FileAccessProvider(local_sandbox_dir=random_dir) + print(f"Random {random_dir}") + fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: + with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): working_dir = os.listdir(random_dir) - assert len(working_dir) == 2 # the mock_remote folder and the local folder - - mock_remote_files = os.listdir(os.path.join(random_dir, "mock_remote")) - assert len(mock_remote_files) == 0 # the mock_remote folder itself is empty + assert len(working_dir) == 1 # the local_flytekit folder workflow_output = my_wf() # After running the mock remote dir should still be empty, since the workflow_output has not been used - mock_remote_files = os.listdir(os.path.join(random_dir, "mock_remote")) - assert len(mock_remote_files) == 0 + with pytest.raises(FileNotFoundError): + os.listdir(os.path.join(random_dir, "mock_remote")) # While the literal returned by t1 does contain the web address as the uri, because it's a remote address, # flytekit will translate it back into a FlyteFile object on the local drive (but not download it) @@ -152,16 +146,16 @@ def my_wf() -> FlyteFile: # This second layer should have two dirs, a random one generated by the new_execution_context call # and an empty folder, created by FlyteFile transformer's to_python_value function. This folder will have # something in it after we open() it. - assert len(working_dir) == 2 + assert len(working_dir) == 1 assert not os.path.exists(workflow_output.path) - # The act of opening it should trigger the download, since we do lazy downloading. 
+ # # The act of opening it should trigger the download, since we do lazy downloading. with open(workflow_output, "rb"): ... - assert os.path.exists(workflow_output.path) - - # The file name is maintained on download. - assert str(workflow_output).endswith(os.path.split(SAMPLE_DATA)[1]) + # assert os.path.exists(workflow_output.path) + # + # # The file name is maintained on download. + # assert str(workflow_output).endswith(os.path.split(SAMPLE_DATA)[1]) def test_file_handling_remote_file_handling_flyte_file(): @@ -178,40 +172,43 @@ def my_wf() -> FlyteFile: # This creates a random directory that we know is empty. random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + print(f"Random {random_dir}") # Creating a new FileAccessProvider will add two folderst to the random dir - fs = FileAccessProvider(local_sandbox_dir=random_dir) + fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: + with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): working_dir = os.listdir(random_dir) - assert len(working_dir) == 2 # the mock_remote folder and the local folder + assert len(working_dir) == 1 # the local_flytekit dir - mock_remote_files = os.listdir(os.path.join(random_dir, "mock_remote")) - assert len(mock_remote_files) == 0 # the mock_remote folder itself is empty + mock_remote_path = os.path.join(random_dir, "mock_remote") + assert not os.path.exists(mock_remote_path) # the persistence layer won't create the folder yet workflow_output = my_wf() # After running the mock remote dir should still be empty, since the workflow_output has not been used - mock_remote_files = os.listdir(os.path.join(random_dir, "mock_remote")) - assert len(mock_remote_files) == 0 + assert not os.path.exists(mock_remote_path) # 
While the literal returned by t1 does contain the web address as the uri, because it's a remote address, # flytekit will translate it back into a FlyteFile object on the local drive (but not download it) - assert workflow_output.path.startswith(random_dir) + assert workflow_output.path.startswith(f"{random_dir}/local_flytekit") # But the remote source should still be the https address assert workflow_output.remote_source == SAMPLE_DATA # The act of running the workflow should create the engine dir, and the directory that will contain the # file but the file itself isn't downloaded yet. working_dir = os.listdir(os.path.join(random_dir, "local_flytekit")) - # This second layer should have two dirs, a random one generated by the new_execution_context call - # and an empty folder, created by FlyteFile transformer's to_python_value function. This folder will have - # something in it after we open() it. - assert len(working_dir) == 2 + assert len(working_dir) == 1 # only one entry before the file is downloaded assert not os.path.exists(workflow_output.path) - # The act of opening it should trigger the download, since we do lazy downloading. + # # The act of opening it should trigger the download, since we do lazy downloading. with open(workflow_output, "rb"): ... + # This second layer should have two dirs, a random one generated by the new_execution_context call + # and an empty folder, created by FlyteFile transformer's to_python_value function. This folder will have + # something in it after we open() it. + working_dir = os.listdir(os.path.join(random_dir, "local_flytekit")) + assert len(working_dir) == 2 # local flytekit and the downloaded file assert os.path.exists(workflow_output.path) # The file name is maintained on download.
diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 4ffabe4eb2..1fb47f39de 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -4,9 +4,11 @@ from typing import Dict, List from flytekit.core import context_manager +from flytekit.core.docstring import Docstring from flytekit.core.interface import ( extract_return_annotation, transform_inputs_to_parameters, + transform_interface_to_typed_interface, transform_signature_to_interface, transform_variable_map, ) @@ -175,3 +177,76 @@ def z(a: int = 7, b: str = "eleven") -> typing.Tuple[int, str]: assert params.parameters["a"].default.scalar.primitive.integer == 7 assert not params.parameters["b"].required assert params.parameters["b"].default.scalar.primitive.string_value == "eleven" + + +def test_transform_interface_to_typed_interface_with_docstring(): + # sphinx style + def z(a: int, b: str) -> typing.Tuple[int, str]: + """ + function z + + :param a: foo + :param b: bar + :return: ramen + """ + ... + + our_interface = transform_signature_to_interface(inspect.signature(z)) + typed_interface = transform_interface_to_typed_interface(our_interface, Docstring(callable_=z)) + assert typed_interface.inputs.get("a").description == "foo" + assert typed_interface.inputs.get("b").description == "bar" + assert typed_interface.outputs.get("o1").description == "ramen" + + # numpy style, multiple return values, shared descriptions + def z(a: int, b: str) -> typing.Tuple[int, str]: + """ + function z + + Parameters + ---------- + a : int + foo + b : str + bar + + Returns + ------- + out1, out2 : tuple + ramen + """ + ... 
+ + our_interface = transform_signature_to_interface(inspect.signature(z)) + typed_interface = transform_interface_to_typed_interface(our_interface, Docstring(callable_=z)) + assert typed_interface.inputs.get("a").description == "foo" + assert typed_interface.inputs.get("b").description == "bar" + assert typed_interface.outputs.get("o0").description == "ramen" + assert typed_interface.outputs.get("o1").description == "ramen" + + # numpy style, multiple return values, named + def z(a: int, b: str) -> typing.NamedTuple("NT", x_str=str, y_int=int): + """ + function z + + Parameters + ---------- + a : int + foo + b : str + bar + + Returns + ------- + x_str : str + description for x_str + y_int : int + description for y_int + """ + ... + + our_interface = transform_signature_to_interface(inspect.signature(z)) + typed_interface = transform_interface_to_typed_interface(our_interface, Docstring(callable_=z)) + assert typed_interface.inputs.get("a").description == "foo" + assert typed_interface.inputs.get("b").description == "bar" + assert typed_interface.outputs.get("x_str").description == "description for x_str" + assert typed_interface.outputs.get("y_int").description == "description for y_int" diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 5096c39b4d..2f69001e5e 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -418,3 +418,21 @@ def say_hello() -> nm: def my_wf() -> wf_outputs: # Note only Namedtuples can be created like this return wf_outputs(say_hello(), say_hello()) + + +def test_serialized_docstrings(): + @task + def z(a: int, b: str) -> typing.Tuple[int, str]: + """ + function z + + :param a: foo + :param b: bar + :return: ramen + """ + ... 
+ + task_spec = get_serializable(OrderedDict(), serialization_settings, z) + assert task_spec.template.interface.inputs["a"].description == "foo" + assert task_spec.template.interface.inputs["b"].description == "bar" + assert task_spec.template.interface.outputs["o0"].description == "ramen" diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index bfa187c4c4..bb732ed437 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -16,6 +16,9 @@ from flytekit.core import context_manager, launch_plan, promise from flytekit.core.condition import conditional from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, Image, ImageConfig + +# from flytekit.interfaces.data.data_proxy import FileAccessProvider +from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.node import Node from flytekit.core.promise import NodeOutput, Promise, VoidPromise from flytekit.core.resources import Resources @@ -23,7 +26,6 @@ from flytekit.core.testing import patch, task_mock from flytekit.core.type_engine import RestrictedTypeError, TypeEngine from flytekit.core.workflow import workflow -from flytekit.interfaces.data.data_proxy import FileAccessProvider from flytekit.models import literals as _literal_models from flytekit.models.core import types as _core_types from flytekit.models.interface import Parameter @@ -93,7 +95,7 @@ def test_engine_file_output(): dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) - fs = FileAccessProvider(local_sandbox_dir="/tmp/flytetesting") + fs = FileAccessProvider(local_sandbox_dir="/tmp/flytetesting", raw_output_prefix="/tmp/flyteraw") ctx = context_manager.FlyteContextManager.current_context() with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index 
865861243e..544b0ab74d 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -208,8 +208,12 @@ def simple_wf() -> int: @workflow def my_wf_example(a: int) -> (int, int): - """ + """example + Workflows can have inputs and return outputs of simple or complex types. + + :param a: input a + :return: outputs """ x = add_5(a=a) @@ -242,3 +246,13 @@ def test_all_node_types(): assert len(sub_wf.nodes) == 1 assert sub_wf.nodes[0].id == "n0" assert sub_wf.nodes[0].task_node.reference_id.name == "test_workflows.add_5" + + +def test_wf_docstring(): + model_wf = get_serializable(OrderedDict(), serialization_settings, my_wf_example) + + assert len(model_wf.template.interface.outputs) == 2 + assert model_wf.template.interface.outputs["o0"].description == "outputs" + assert model_wf.template.interface.outputs["o1"].description == "outputs" + assert len(model_wf.template.interface.inputs) == 1 + assert model_wf.template.interface.inputs["a"].description == "input a" diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 5ea389e0ee..a4a6d9bdeb 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -58,7 +58,7 @@ def test_remote_fetch_execute_entities_task_workflow_launchplan( mock_client = MagicMock() getattr(mock_client, CLIENT_METHODS[resource_type]).return_value = admin_entities, "" - remote = FlyteRemote.from_environment("p1", "d1") + remote = FlyteRemote.from_config("p1", "d1") remote._client = mock_client fetch_method = getattr(remote, REMOTE_METHODS[resource_type]) flyte_entity_latest = fetch_method(name="n1", version="latest") @@ -86,7 +86,7 @@ def test_remote_fetch_workflow_execution(mock_insecure, mock_url, mock_client_ma mock_client = MagicMock() mock_client.get_execution.return_value = admin_workflow_execution - remote = FlyteRemote.from_environment("p1", "d1") + remote = FlyteRemote.from_config("p1", "d1") 
remote._client = mock_client flyte_workflow_execution = remote.fetch_workflow_execution(name="n1") assert flyte_workflow_execution.id == admin_workflow_execution.id