From 60a94f36242eacf14538aa78418634ee7cf9cc7e Mon Sep 17 00:00:00 2001 From: Jeev B Date: Thu, 25 Mar 2021 12:06:47 -0700 Subject: [PATCH 01/33] Fix serialization of pod specs in pod plugin (#433) Signed-off-by: Jeev B Signed-off-by: Max Hoffman --- plugins/pod/flytekitplugins/pod/task.py | 3 +- plugins/tests/pod/test_pod.py | 87 ++++++++++++++++++++----- 2 files changed, 74 insertions(+), 16 deletions(-) diff --git a/plugins/pod/flytekitplugins/pod/task.py b/plugins/pod/flytekitplugins/pod/task.py index bb6a244d7e..5449b7df5b 100644 --- a/plugins/pod/flytekitplugins/pod/task.py +++ b/plugins/pod/flytekitplugins/pod/task.py @@ -1,6 +1,7 @@ from typing import Any, Callable, Dict, Tuple, Union from flyteidl.core import tasks_pb2 as _core_task +from kubernetes.client import ApiClient from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements from flytekit import FlyteContext, PythonFunctionTask @@ -76,7 +77,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: self.task_config._pod_spec.containers = final_containers - return self.task_config.pod_spec.to_dict() + return ApiClient().sanitize_for_serialization(self.task_config.pod_spec) def get_config(self, settings: SerializationSettings) -> Dict[str, str]: return {_PRIMARY_CONTAINER_NAME_FIELD: self.task_config.primary_container_name} diff --git a/plugins/tests/pod/test_pod.py b/plugins/tests/pod/test_pod.py index b06b1f6874..e4d86984d3 100644 --- a/plugins/tests/pod/test_pod.py +++ b/plugins/tests/pod/test_pod.py @@ -1,7 +1,10 @@ +import json from collections import OrderedDict from typing import List +from unittest.mock import MagicMock -from kubernetes.client.models import V1Container, V1PodSpec, V1VolumeMount +from kubernetes.client import ApiClient +from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements, V1VolumeMount from flytekit import Resources, dynamic, task from flytekit.common.translator import get_serializable @@ -19,7 +22,7 @@ def get_pod_spec(): return pod_spec -def test_pod_task(): +def test_pod_task_deserialization(): pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container") @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"}) @@ -40,11 +43,17 @@ def simple_pod_task(i: int): image_config=ImageConfig(default_image=default_img, images=[default_img]), ) ) - assert custom["restart_policy"] == "OnFailure" - assert len(custom["containers"]) == 2 - primary_container = custom["containers"][0] - assert primary_container["name"] == "a container" - assert primary_container["args"] == [ + + # Test that custom is correctly serialized by deserializing it with the python API client + response = MagicMock() + response.data = json.dumps(custom) + deserialized_pod_spec = ApiClient().deserialize(response, V1PodSpec) + + assert deserialized_pod_spec.restart_policy == "OnFailure" + assert len(deserialized_pod_spec.containers) == 2 + primary_container = deserialized_pod_spec.containers[0] + assert primary_container.name == "a container" + assert primary_container.args == [ "pyflyte-execute", "--inputs", "{{.input}}", @@ -60,14 +69,11 @@ def simple_pod_task(i: int): "task-name", "simple_pod_task", ] - assert primary_container["volume_mounts"][0]["mount_path"] == "some/where" - assert primary_container["volume_mounts"][0]["name"] == "volume mount" - assert primary_container["resources"] == { - "requests": {"cpu": "10"}, - "limits": {"gpu": "2"}, - } - assert primary_container["env"] == 
[{"name": "FOO", "value": "bar", "value_from": None}] - assert custom["containers"][1]["name"] == "another container" + assert primary_container.volume_mounts[0].mount_path == "some/where" + assert primary_container.volume_mounts[0].name == "volume mount" + assert primary_container.resources == V1ResourceRequirements(limits={"gpu": "2"}, requests={"cpu": "10"}) + assert primary_container.env == [V1EnvVar(name="FOO", value="bar")] + assert deserialized_pod_spec.containers[1].name == "another container" config = simple_pod_task.get_config( SerializationSettings( @@ -81,6 +87,57 @@ def simple_pod_task(i: int): assert config["primary_container_name"] == "a container" +def test_pod_task(): + pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container") + + @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"}) + def simple_pod_task(i: int): + pass + + assert isinstance(simple_pod_task, PodFunctionTask) + assert simple_pod_task.task_config == pod + + default_img = Image(name="default", fqn="test", tag="tag") + + custom = simple_pod_task.get_custom( + SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + ) + assert custom["restartPolicy"] == "OnFailure" + assert len(custom["containers"]) == 2 + primary_container = custom["containers"][0] + assert primary_container["name"] == "a container" + assert primary_container["args"] == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "plugins.tests.pod.test_pod", + "task-name", + "simple_pod_task", + ] + assert primary_container["volumeMounts"][0]["mountPath"] == "some/where" + assert primary_container["volumeMounts"][0]["name"] == "volume mount" + assert primary_container["resources"] == { + "requests": {"cpu": "10"}, + "limits": {"gpu": "2"}, + } + assert primary_container["env"] == [{"name": "FOO", "value": "bar"}] + assert custom["containers"][1]["name"] == "another container" + + def test_dynamic_pod_task(): dynamic_pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container") From 908df48227979558659164544a897ae33c98569c Mon Sep 17 00:00:00 2001 From: Jeev B Date: Thu, 25 Mar 2021 14:23:25 -0700 Subject: [PATCH 02/33] Remove beta version from readme install instructions (#434) Signed-off-by: Jeev B Signed-off-by: Max Hoffman --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index eed732ef43..3bf3eb892b 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Flytekit is the core extensible library to author Flyte workflows and tasks and ### Base Installation ```bash -pip install flytekit==0.16.0b7 +pip install flytekit ``` ### Simple getting started From 39e6226eaf6847d494c27c0696d294ca1f8fd669 Mon Sep 17 00:00:00 2001 From: Honnix Date: Fri, 26 Mar 2021 22:13:35 +0100 Subject: [PATCH 03/33] Move configuraiton of a few tools to pyproject.toml (#428) * Move configuraiton of a few tools to pyproject.toml Signed-off-by: Hongxin Liang * Stop pinning black Formatted by later version of black Signed-off-by: Hongxin Liang Signed-off-by: Max Hoffman --- dev-requirements.in | 4 +- dev-requirements.txt | 8 +- doc-requirements.txt | 32 +- flytekit/bin/entrypoint.py | 36 ++- 
flytekit/clients/friendly.py | 49 ++- flytekit/clients/helpers.py | 16 +- flytekit/clients/raw.py | 11 +- flytekit/clis/auth/auth.py | 12 +- flytekit/clis/auth/discovery.py | 4 +- flytekit/clis/flyte_cli/main.py | 189 +++++++++--- flytekit/clis/helpers.py | 6 +- flytekit/clis/sdk_in_container/constants.py | 12 +- flytekit/clis/sdk_in_container/launch_plan.py | 23 +- flytekit/clis/sdk_in_container/pyflyte.py | 15 +- flytekit/clis/sdk_in_container/register.py | 4 +- flytekit/common/core/identifier.py | 21 +- flytekit/common/exceptions/scopes.py | 16 +- flytekit/common/exceptions/user.py | 5 +- flytekit/common/interface.py | 11 +- flytekit/common/launch_plan.py | 17 +- flytekit/common/local_workflow.py | 9 +- flytekit/common/mixins/launchable.py | 8 +- flytekit/common/promise.py | 8 +- flytekit/common/schedules.py | 9 +- flytekit/common/tasks/generic_spark_task.py | 9 +- flytekit/common/tasks/hive_task.py | 30 +- flytekit/common/tasks/presto_task.py | 10 +- flytekit/common/tasks/raw_container.py | 4 +- .../sagemaker/built_in_training_job_task.py | 13 +- .../common/tasks/sagemaker/hpo_job_task.py | 10 +- flytekit/common/tasks/sdk_dynamic.py | 12 +- flytekit/common/tasks/sdk_runnable.py | 11 +- flytekit/common/tasks/sidecar_task.py | 7 +- flytekit/common/tasks/task.py | 12 +- flytekit/common/translator.py | 33 +- flytekit/common/types/blobs.py | 21 +- flytekit/common/types/containers.py | 12 +- flytekit/common/types/impl/blobs.py | 15 +- flytekit/common/types/impl/schema.py | 31 +- flytekit/common/types/primitives.py | 7 +- flytekit/common/types/schema.py | 10 +- flytekit/common/workflow.py | 27 +- flytekit/common/workflow_execution.py | 4 +- flytekit/contrib/notebook/tasks.py | 12 +- flytekit/contrib/sensors/impl.py | 5 +- flytekit/contrib/sensors/task.py | 3 +- flytekit/core/base_task.py | 11 +- flytekit/core/condition.py | 6 +- flytekit/core/promise.py | 3 +- flytekit/core/python_auto_container.py | 6 +- flytekit/core/reference_entity.py | 7 +- flytekit/core/schedule.py | 9 +- flytekit/core/task.py | 5 +- flytekit/core/type_engine.py | 27 +- flytekit/core/workflow.py | 30 +- flytekit/engines/common.py | 8 +- flytekit/engines/flyte/engine.py | 110 +++++-- flytekit/engines/unit/engine.py | 12 +- flytekit/interfaces/data/data_proxy.py | 15 +- flytekit/interfaces/data/gcs/gcs_proxy.py | 5 +- flytekit/interfaces/stats/taggable.py | 10 +- flytekit/models/admin/task_execution.py | 10 +- flytekit/models/admin/workflow.py | 3 +- flytekit/models/array_job.py | 12 +- flytekit/models/core/compiler.py | 3 +- flytekit/models/core/condition.py | 3 +- flytekit/models/core/execution.py | 19 +- flytekit/models/core/identifier.py | 28 +- flytekit/models/core/workflow.py | 17 +- flytekit/models/execution.py | 13 +- flytekit/models/interface.py | 5 +- flytekit/models/launch_plan.py | 4 +- flytekit/models/literals.py | 16 +- flytekit/models/matchable_resource.py | 28 +- flytekit/models/named_entity.py | 22 +- flytekit/models/node_execution.py | 4 +- flytekit/models/presto.py | 5 +- flytekit/models/qubole.py | 6 +- flytekit/models/sagemaker/hpo_job.py | 12 +- flytekit/models/sagemaker/parameter_ranges.py | 40 ++- flytekit/models/sagemaker/training_job.py | 18 +- flytekit/models/security.py | 10 +- flytekit/models/task.py | 43 ++- flytekit/models/types.py | 13 +- flytekit/models/workflow_closure.py | 3 +- flytekit/plugins/__init__.py | 4 +- flytekit/sdk/tasks.py | 14 +- flytekit/sdk/workflow.py | 5 +- flytekit/tools/fast_registration.py | 3 +- flytekit/tools/module_loader.py | 6 +- 
flytekit/type_engines/default/flyte.py | 3 +- .../flytekitplugins/awssagemaker/training.py | 18 +- plugins/hive/flytekitplugins/hive/task.py | 14 +- .../flytekitplugins/papermill/task.py | 18 +- plugins/pod/flytekitplugins/pod/task.py | 6 +- plugins/spark/flytekitplugins/spark/task.py | 5 +- plugins/tests/awssagemaker/test_hpo.py | 20 +- plugins/tests/awssagemaker/test_training.py | 21 +- plugins/tests/pod/test_pod.py | 11 +- pyproject.toml | 26 +- requirements-spark3.txt | 35 ++- requirements.txt | 35 ++- setup.cfg | 17 - setup.py | 2 +- tests/flytekit/common/parameterizers.py | 50 ++- tests/flytekit/common/workflows/python.py | 5 +- tests/flytekit/common/workflows/sagemaker.py | 12 +- tests/flytekit/common/workflows/sidecar.py | 19 +- tests/flytekit/common/workflows/simple.py | 8 +- .../unit/bin/test_python_entrypoint.py | 18 +- .../flytekit/unit/cli/auth/test_discovery.py | 16 +- tests/flytekit/unit/cli/pyflyte/conftest.py | 5 +- tests/flytekit/unit/cli/test_cli_helpers.py | 89 +++++- tests/flytekit/unit/cli/test_flyte_cli.py | 9 +- .../common_tests/exceptions/test_system.py | 4 +- .../unit/common_tests/exceptions/test_user.py | 5 +- .../tasks/test_execution_params.py | 4 +- .../unit/common_tests/tasks/test_task.py | 8 +- .../unit/common_tests/test_interface.py | 13 +- .../unit/common_tests/test_launch_plan.py | 36 ++- .../flytekit/unit/common_tests/test_nodes.py | 27 +- .../unit/common_tests/test_schedules.py | 3 +- .../unit/common_tests/test_translator.py | 5 +- .../unit/common_tests/test_workflow.py | 19 +- .../common_tests/test_workflow_promote.py | 7 +- .../common_tests/types/impl/test_schema.py | 22 +- .../unit/common_tests/types/test_blobs.py | 5 +- .../unit/common_tests/types/test_helpers.py | 3 +- .../test_temporary_configuration.py | 3 +- .../unit/contrib/sensors/test_impl.py | 4 +- tests/flytekit/unit/core/test_references.py | 14 +- tests/flytekit/unit/core/test_schedule.py | 7 +- .../flytekit/unit/core/test_serialization.py | 6 +- tests/flytekit/unit/core/test_type_hints.py | 5 +- .../unit/engines/flyte/test_engine.py | 290 +++++++++++++++--- .../flytekit/unit/extras/sqlite3/test_task.py | 17 +- .../unit/models/core/test_identifier.py | 5 +- tests/flytekit/unit/models/core/test_types.py | 5 +- .../unit/models/core/test_workflow.py | 14 +- .../unit/models/sagemaker/test_hpo_job.py | 5 +- .../models/sagemaker/test_training_job.py | 5 +- .../flytekit/unit/models/test_dynamic_job.py | 18 +- .../flytekit/unit/models/test_launch_plan.py | 4 +- tests/flytekit/unit/models/test_schedule.py | 3 +- tests/flytekit/unit/models/test_tasks.py | 22 +- .../unit/models/test_workflow_closure.py | 17 +- .../sdk/tasks/test_dynamic_sidecar_tasks.py | 13 +- .../unit/sdk/tasks/test_hive_tasks.py | 8 +- .../unit/sdk/tasks/test_sagemaker_tasks.py | 50 ++- .../unit/sdk/tasks/test_sidecar_tasks.py | 17 +- tests/flytekit/unit/sdk/tasks/test_tasks.py | 10 +- tests/flytekit/unit/test_plugins.py | 5 +- .../default/test_flyte_type_engine.py | 6 +- 153 files changed, 2028 insertions(+), 604 deletions(-) diff --git a/dev-requirements.in b/dev-requirements.in index d8950c5e52..6fa5ebb080 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -1,7 +1,7 @@ -c requirements.txt -black==19.10b0 -coverage +black +coverage[toml] flake8 flake8-black flake8-isort diff --git a/dev-requirements.txt b/dev-requirements.txt index 5df2507807..5acc7855b8 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -11,9 +11,8 @@ appdirs==1.4.4 attrs==20.3.0 # via # -c requirements.txt - # black # pytest 
-black==19.10b0 +black==20.8b1 # via # -c requirements.txt # -r dev-requirements.in @@ -22,7 +21,7 @@ click==7.1.2 # via # -c requirements.txt # black -coverage==5.5 +coverage[toml]==5.5 # via -r dev-requirements.in flake8-black==0.2.1 # via -r dev-requirements.in @@ -46,6 +45,7 @@ mock==4.0.3 mypy-extensions==0.4.3 # via # -c requirements.txt + # black # mypy mypy==0.812 # via -r dev-requirements.in @@ -83,6 +83,7 @@ toml==0.10.2 # via # -c requirements.txt # black + # coverage # pytest typed-ast==1.4.2 # via @@ -92,4 +93,5 @@ typed-ast==1.4.2 typing-extensions==3.7.4.3 # via # -c requirements.txt + # black # mypy diff --git a/doc-requirements.txt b/doc-requirements.txt index 05d748a71c..963d54ce5e 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -12,17 +12,12 @@ ansiwrap==0.8.4 # via papermill appdirs==1.4.4 # via black -appnope==0.1.2 - # via - # ipykernel - # ipython astroid==2.5.1 # via sphinx-autoapi async-generator==1.10 # via nbclient attrs==20.3.0 # via - # black # jsonschema # scantree babel==2.9.0 @@ -35,15 +30,13 @@ beautifulsoup4==4.9.3 # via # sphinx-code-include # sphinx-material -black==19.10b0 - # via - # flytekit - # papermill +black==20.8b1 + # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.36 +boto3==1.17.39 # via sagemaker-training -botocore==1.20.36 +botocore==1.20.39 # via # boto3 # s3transfer @@ -64,10 +57,11 @@ click==7.1.2 # papermill croniter==1.0.10 # via flytekit -cryptography==3.4.6 +cryptography==3.4.7 # via # -r doc-requirements.in # paramiko + # secretstorage css-html-js-minify==2.5.5 # via sphinx-material dataclasses-json==0.5.2 @@ -120,6 +114,10 @@ ipython==7.21.0 # via ipykernel jedi==0.18.0 # via ipython +jeepney==0.6.0 + # via + # keyring + # secretstorage jinja2==2.11.3 # via # nbconvert @@ -161,7 +159,9 @@ marshmallow==3.10.0 mistune==0.8.4 # via nbconvert mypy-extensions==0.4.3 - # via typing-inspect + # via + # black + # typing-inspect natsort==7.1.1 # via flytekit nbclient==0.5.3 @@ -287,6 +287,8 @@ scantree==0.0.1 # via dirhash scipy==1.6.2 # via sagemaker-training +secretstorage==3.3.1 + # via keyring six==1.15.0 # via # bcrypt @@ -377,7 +379,9 @@ traitlets==5.0.5 typed-ast==1.4.2 # via black typing-extensions==3.7.4.3 - # via typing-inspect + # via + # black + # typing-inspect typing-inspect==0.6.0 # via dataclasses-json unidecode==1.2.0 diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index c3b9ed76d3..86c94fa2d7 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -126,7 +126,11 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str, _logging.error("!! Begin Unknown System Error Captured by Flyte !!") exc_str = _traceback.format_exc() output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( - _error_models.ContainerError("SYSTEM:Unknown", exc_str, _error_models.ContainerError.Kind.RECOVERABLE,) + _error_models.ContainerError( + "SYSTEM:Unknown", + exc_str, + _error_models.ContainerError.Kind.RECOVERABLE, + ) ) _logging.error(exc_str) _logging.error("!! 
End Error Captured by Flyte !!") @@ -185,11 +189,13 @@ def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str if cloud_provider == _constants.CloudProvider.AWS: file_access = _data_proxy.FileAccessProvider( - local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix), + local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), + remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix), ) elif cloud_provider == _constants.CloudProvider.GCP: file_access = _data_proxy.FileAccessProvider( - local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix), + local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), + remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix), ) elif cloud_provider == _constants.CloudProvider.LOCAL: # A fake remote using the local disk will automatically be created @@ -353,7 +359,9 @@ def _pass_through(): @_click.option("--test", is_flag=True) @_click.option("--resolver", required=False) @_click.argument( - "resolver-args", type=_click.UNPROCESSED, nargs=-1, + "resolver-args", + type=_click.UNPROCESSED, + nargs=-1, ) def execute_task_cmd( task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test, resolver, resolver_args @@ -408,15 +416,29 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd): @_click.option("--test", is_flag=True) @_click.option("--resolver", required=True) @_click.argument( - "resolver-args", type=_click.UNPROCESSED, nargs=-1, + "resolver-args", + type=_click.UNPROCESSED, + nargs=-1, ) def map_execute_task_cmd( - inputs, output_prefix, raw_output_data_prefix, max_concurrency, test, resolver, resolver_args, + inputs, + output_prefix, + raw_output_data_prefix, + max_concurrency, + test, + resolver, + resolver_args, ): _click.echo(_utils.get_version_message()) _execute_map_task( - inputs, output_prefix, raw_output_data_prefix, max_concurrency, test, resolver, resolver_args, + inputs, + output_prefix, + raw_output_data_prefix, + max_concurrency, + test, + resolver, + resolver_args, ) diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 884fd8c22c..a5128f7e6d 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -316,7 +316,8 @@ def create_launch_plan(self, launch_plan_identifer, launch_plan_spec): """ super(SynchronousFlyteClient, self).create_launch_plan( _launch_plan_pb2.LaunchPlanCreateRequest( - id=launch_plan_identifer.to_flyte_idl(), spec=launch_plan_spec.to_flyte_idl(), + id=launch_plan_identifer.to_flyte_idl(), + spec=launch_plan_spec.to_flyte_idl(), ) ) @@ -506,7 +507,9 @@ def update_named_entity(self, resource_type, id, metadata): """ super(SynchronousFlyteClient, self).update_named_entity( _common_pb2.NamedEntityUpdateRequest( - resource_type=resource_type, id=id.to_flyte_idl(), metadata=metadata.to_flyte_idl(), + resource_type=resource_type, + id=id.to_flyte_idl(), + metadata=metadata.to_flyte_idl(), ) ) @@ -661,7 +664,12 @@ def get_node_execution_data(self, node_execution_identifier): ) def list_node_executions( - self, workflow_execution_identifier, limit=100, token=None, filters=None, sort_by=None, + self, + workflow_execution_identifier, + limit=100, + token=None, + filters=None, + sort_by=None, ): """ TODO: Comment @@ -689,7 +697,12 @@ def list_node_executions( ) def list_node_executions_for_task_paginated( - self, task_execution_identifier, limit=100, token=None, filters=None, sort_by=None, + self, + task_execution_identifier, + 
limit=100, + token=None, + filters=None, + sort_by=None, ): """ This returns nodes spawned by a specific task execution. This is generally from things like dynamic tasks. @@ -747,7 +760,12 @@ def get_task_execution_data(self, task_execution_identifier): ) def list_task_executions_paginated( - self, node_execution_identifier, limit=100, token=None, filters=None, sort_by=None, + self, + node_execution_identifier, + limit=100, + token=None, + filters=None, + sort_by=None, ): """ :param flytekit.models.core.identifier.NodeExecutionIdentifier node_execution_identifier: @@ -786,7 +804,9 @@ def register_project(self, project): :rtype: flyteidl.admin.project_pb2.ProjectRegisterResponse """ super(SynchronousFlyteClient, self).register_project( - _project_pb2.ProjectRegisterRequest(project=project.to_flyte_idl(),) + _project_pb2.ProjectRegisterRequest( + project=project.to_flyte_idl(), + ) ) def update_project(self, project): @@ -853,7 +873,9 @@ def update_project_domain_attributes(self, project, domain, matching_attributes) super(SynchronousFlyteClient, self).update_project_domain_attributes( _project_domain_attributes_pb2.ProjectDomainAttributesUpdateRequest( attributes=_project_domain_attributes_pb2.ProjectDomainAttributes( - project=project, domain=domain, matching_attributes=matching_attributes.to_flyte_idl(), + project=project, + domain=domain, + matching_attributes=matching_attributes.to_flyte_idl(), ) ) ) @@ -888,7 +910,9 @@ def get_project_domain_attributes(self, project, domain, resource_type): """ return super(SynchronousFlyteClient, self).get_project_domain_attributes( _project_domain_attributes_pb2.ProjectDomainAttributesGetRequest( - project=project, domain=domain, resource_type=resource_type, + project=project, + domain=domain, + resource_type=resource_type, ) ) @@ -903,7 +927,10 @@ def get_workflow_attributes(self, project, domain, workflow, resource_type): """ return super(SynchronousFlyteClient, self).get_workflow_attributes( _workflow_attributes_pb2.WorkflowAttributesGetRequest( - project=project, domain=domain, workflow=workflow, resource_type=resource_type, + project=project, + domain=domain, + workflow=workflow, + resource_type=resource_type, ) ) @@ -914,5 +941,7 @@ def list_matchable_attributes(self, resource_type): :return: """ return super(SynchronousFlyteClient, self).list_matchable_attributes( - _matchable_resource_pb2.ListMatchableAttributesRequest(resource_type=resource_type,) + _matchable_resource_pb2.ListMatchableAttributesRequest( + resource_type=resource_type, + ) ) diff --git a/flytekit/clients/helpers.py b/flytekit/clients/helpers.py index 75b2232636..4d8a7912a4 100644 --- a/flytekit/clients/helpers.py +++ b/flytekit/clients/helpers.py @@ -1,5 +1,9 @@ def iterate_node_executions( - client, workflow_execution_identifier=None, task_execution_identifier=None, limit=None, filters=None, + client, + workflow_execution_identifier=None, + task_execution_identifier=None, + limit=None, + filters=None, ): """ This returns a generator for node executions. 
@@ -25,7 +29,10 @@ def iterate_node_executions( ) else: node_execs, next_token = client.list_node_executions_for_task_paginated( - task_execution_identifier=task_execution_identifier, limit=num_to_fetch, token=token, filters=filters, + task_execution_identifier=task_execution_identifier, + limit=num_to_fetch, + token=token, + filters=filters, ) for n in node_execs: counter += 1 @@ -53,7 +60,10 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte counter = 0 while True: task_execs, next_token = client.list_task_executions_paginated( - node_execution_identifier=node_execution_identifier, limit=num_to_fetch, token=token, filters=filters, + node_execution_identifier=node_execution_identifier, + limit=num_to_fetch, + token=token, + filters=filters, ) for t in task_execs: counter += 1 diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index cb10666151..18e923ff62 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -151,7 +151,9 @@ def __init__(self, url, insecure=False, credentials=None, options=None): self._channel = _insecure_channel(url, options=list((options or {}).items())) else: self._channel = _secure_channel( - url, credentials or _ssl_channel_credentials(), options=list((options or {}).items()), + url, + credentials or _ssl_channel_credentials(), + options=list((options or {}).items()), ) self._stub = _admin_service.AdminServiceStub(self._channel) self._metadata = None @@ -165,7 +167,12 @@ def url(self) -> str: def set_access_token(self, access_token): # Always set the header to lower-case regardless of what the config is. The grpc libraries that Admin uses # to parse the metadata don't change the metadata, but they do automatically lower the key you're looking for. - self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get().lower(), "Bearer {}".format(access_token),)] + self._metadata = [ + ( + _creds_config.AUTHORIZATION_METADATA_KEY.get().lower(), + "Bearer {}".format(access_token), + ) + ] def force_auth_flow(self): refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get()) diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 2afc4644e7..43ce73ee0a 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -118,7 +118,12 @@ class OAuthHTTPServer(_BaseHTTPServer.HTTPServer): """ def __init__( - self, server_address, RequestHandlerClass, bind_and_activate=True, redirect_path=None, queue=None, + self, + server_address, + RequestHandlerClass, + bind_and_activate=True, + redirect_path=None, + queue=None, ): _BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate) self._redirect_path = redirect_path @@ -233,7 +238,10 @@ def request_access_token(self, auth_code): {"code": auth_code.code, "code_verifier": self._code_verifier, "grant_type": "authorization_code"} ) resp = _requests.post( - url=self._token_endpoint, data=self._params, headers=self._headers, allow_redirects=False, + url=self._token_endpoint, + data=self._params, + headers=self._headers, + allow_redirects=False, ) if resp.status_code != _StatusCodes.OK: # TODO: handle expected (?) 
error cases: diff --git a/flytekit/clis/auth/discovery.py b/flytekit/clis/auth/discovery.py index d661f6302e..5134eab974 100644 --- a/flytekit/clis/auth/discovery.py +++ b/flytekit/clis/auth/discovery.py @@ -46,7 +46,9 @@ def authorization_endpoints(self): def get_authorization_endpoints(self): if self.authorization_endpoints is not None: return self.authorization_endpoints - resp = _requests.get(url=self._discovery_url,) + resp = _requests.get( + url=self._discovery_url, + ) response_body = resp.json() diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 4d8cefea45..74f082b42e 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -114,7 +114,11 @@ def _get_io_string(literal_map, verbose=False): if value_dict: return "\n" + "\n".join( "{:30}: {}".format( - k, _prefix_lines("{:30} ".format(""), v.verbose_string() if verbose else v.short_string(),), + k, + _prefix_lines( + "{:30} ".format(""), + v.verbose_string() if verbose else v.short_string(), + ), ) for k, v in _six.iteritems(value_dict) ) @@ -203,7 +207,10 @@ def _secho_node_execution_status(status, nl=True): fg = "blue" _click.secho( - "{:10} ".format(_tt(_core_execution_models.NodeExecutionPhase.enum_to_string(status))), bold=True, fg=fg, nl=nl, + "{:10} ".format(_tt(_core_execution_models.NodeExecutionPhase.enum_to_string(status))), + bold=True, + fg=fg, + nl=nl, ) @@ -228,7 +235,10 @@ def _secho_task_execution_status(status, nl=True): fg = "blue" _click.secho( - "{:10} ".format(_tt(_core_execution_models.TaskExecutionPhase.enum_to_string(status))), bold=True, fg=fg, nl=nl, + "{:10} ".format(_tt(_core_execution_models.TaskExecutionPhase.enum_to_string(status))), + bold=True, + fg=fg, + nl=nl, ) @@ -245,7 +255,8 @@ def _secho_one_execution(ex, urns_only): _secho_workflow_status(ex.closure.phase) else: _click.echo( - "{:100}".format(_tt(_identifier.WorkflowExecutionIdentifier.promote_from_model(ex.id))), nl=True, + "{:100}".format(_tt(_identifier.WorkflowExecutionIdentifier.promote_from_model(ex.id))), + nl=True, ) @@ -305,19 +316,33 @@ def _render_schedule_expr(lp): _project_option = _click.option(*_PROJECT_FLAGS, required=True, help="The project namespace to query.") _optional_project_option = _click.option( - *_PROJECT_FLAGS, required=False, default=None, help="[Optional] The project namespace to query.", + *_PROJECT_FLAGS, + required=False, + default=None, + help="[Optional] The project namespace to query.", ) _domain_option = _click.option(*_DOMAIN_FLAGS, required=True, help="The domain namespace to query.") _optional_domain_option = _click.option( - *_DOMAIN_FLAGS, required=False, default=None, help="[Optional] The domain namespace to query.", + *_DOMAIN_FLAGS, + required=False, + default=None, + help="[Optional] The domain namespace to query.", ) _name_option = _click.option(*_NAME_FLAGS, required=True, help="The name to query.") _optional_name_option = _click.option( - *_NAME_FLAGS, required=False, type=str, default=None, help="[Optional] The name to query.", + *_NAME_FLAGS, + required=False, + type=str, + default=None, + help="[Optional] The name to query.", ) _principal_option = _click.option(*_PRINCIPAL_FLAGS, required=True, help="Your team name, or your name") _optional_principal_option = _click.option( - *_PRINCIPAL_FLAGS, required=False, type=str, default=None, help="[Optional] Your team name, or your name", + *_PRINCIPAL_FLAGS, + required=False, + type=str, + default=None, + help="[Optional] Your team name, or your name", ) _insecure_option = 
_click.option(*_INSECURE_FLAGS, is_flag=True, required=True, help="Do not use SSL") _urn_option = _click.option("-u", "--urn", required=True, help="The unique identifier for an entity.") @@ -340,10 +365,19 @@ def _render_schedule_expr(lp): help="Pagination token from which to start listing in the list of results.", ) _limit_option = _click.option( - "-l", "--limit", required=False, default=100, type=int, help="Maximum number of results to return for this call.", + "-l", + "--limit", + required=False, + default=100, + type=int, + help="Maximum number of results to return for this call.", ) _show_all_option = _click.option( - "-a", "--show-all", is_flag=True, default=False, help="Set this flag to page through and list all results.", + "-a", + "--show-all", + is_flag=True, + default=False, + help="Set this flag to page through and list all results.", ) # TODO: Provide documentation on filter format _filter_option = _click.option( @@ -367,10 +401,15 @@ def _render_schedule_expr(lp): help="The state change to apply to a named entity", ) _named_entity_description_option = _click.option( - "--description", required=False, type=str, help="Concise description for the entity.", + "--description", + required=False, + type=str, + help="Concise description for the entity.", ) _sort_by_option = _click.option( - "--sort-by", required=False, help="Provide an entity field to be sorted. i.e. asc(name) or desc(name)", + "--sort-by", + required=False, + help="Provide an entity field to be sorted. i.e. asc(name) or desc(name)", ) _show_io_option = _click.option( "--show-io", @@ -380,7 +419,10 @@ def _render_schedule_expr(lp): " inputs and outputs.", ) _verbose_option = _click.option( - "--verbose", is_flag=True, default=False, help="Set this flag to view the full textual description of all fields.", + "--verbose", + is_flag=True, + default=False, + help="Set this flag to view the full textual description of all fields.", ) _filename_option = _click.option("-f", "--filename", required=True, help="File path of pb file") @@ -391,7 +433,10 @@ def _render_schedule_expr(lp): help="Dot (.) separated path to Python IDL class. (e.g. flyteidl.core.workflow_closure_pb2.WorkflowClosure)", ) _cause_option = _click.option( - "-c", "--cause", required=True, help="The message signaling the cause of the termination of the execution(s)", + "-c", + "--cause", + required=True, + help="The message signaling the cause of the termination of the execution(s)", ) _optional_urns_only_option = _click.option( "--urns-only", @@ -401,13 +446,25 @@ def _render_schedule_expr(lp): help="[Optional] Set the flag if you want to output the urn(s) only. 
Setting this will override the verbose flag", ) _project_identifier_option = _click.option( - "-p", "--identifier", required=True, type=str, help="Unique identifier for the project.", + "-p", + "--identifier", + required=True, + type=str, + help="Unique identifier for the project.", ) _project_name_option = _click.option( - "-n", "--name", required=True, type=str, help="The human-readable name for the project.", + "-n", + "--name", + required=True, + type=str, + help="The human-readable name for the project.", ) _project_description_option = _click.option( - "-d", "--description", required=True, type=str, help="Concise description for the project.", + "-d", + "--description", + required=True, + type=str, + help="Concise description for the project.", ) _watch_option = _click.option( "-w", @@ -428,7 +485,11 @@ def _render_schedule_expr(lp): _output_location_prefix_option = _click.option( "-o", "--output-location-prefix", help="Custom output location prefix for offloaded types (files/schemas)" ) -_files_argument = _click.argument("files", type=_click.Path(exists=True), nargs=-1,) +_files_argument = _click.argument( + "files", + type=_click.Path(exists=True), + nargs=-1, +) class _FlyteSubCommand(_click.Command): @@ -612,7 +673,12 @@ def list_task_versions(project, domain, name, host, insecure, token, limit, show sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for t in task_list: - _click.echo("{:50} {:40}".format(_tt(t.id.version), _tt(_identifier.Identifier.promote_from_model(t.id)),)) + _click.echo( + "{:50} {:40}".format( + _tt(t.id.version), + _tt(_identifier.Identifier.promote_from_model(t.id)), + ) + ) if show_all is not True: if next_token: @@ -770,7 +836,12 @@ def list_workflow_versions(project, domain, name, host, insecure, token, limit, sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for w in wf_list: - _click.echo("{:50} {:40}".format(_tt(w.id.version), _tt(_identifier.Identifier.promote_from_model(w.id)),)) + _click.echo( + "{:50} {:40}".format( + _tt(w.id.version), + _tt(_identifier.Identifier.promote_from_model(w.id)), + ) + ) if show_all is not True: if next_token: @@ -912,7 +983,17 @@ def list_active_launch_plans(project, domain, host, insecure, token, limit, show @_sort_by_option @_optional_urns_only_option def list_launch_plan_versions( - project, domain, name, host, insecure, token, limit, show_all, filter, sort_by, urns_only, + project, + domain, + name, + host, + insecure, + token, + limit, + show_all, + filter, + sort_by, + urns_only, ): """ List the versions of all the launch plans under the scope specified by {project, domain}. 
@@ -937,7 +1018,10 @@ def list_launch_plan_versions( _click.echo(_tt(_identifier.Identifier.promote_from_model(l.id))) else: _click.echo( - "{:50} {:80} ".format(_tt(l.id.version), _tt(_identifier.Identifier.promote_from_model(l.id)),), + "{:50} {:80} ".format( + _tt(l.id.version), + _tt(_identifier.Identifier.promote_from_model(l.id)), + ), nl=False, ) if l.spec.entity_metadata.schedule is not None and ( @@ -1303,21 +1387,27 @@ def _get_io(node_executions, wf_execution, show_io, verbose): def _render_workflow_execution(wf_execution, uri_to_message_map, show_io, verbose): _click.echo( "\nExecution {project}:{domain}:{name}\n".format( - project=_tt(wf_execution.id.project), domain=_tt(wf_execution.id.domain), name=_tt(wf_execution.id.name), + project=_tt(wf_execution.id.project), + domain=_tt(wf_execution.id.domain), + name=_tt(wf_execution.id.name), ) ) _click.echo("\t{:15} ".format("State:"), nl=False) _secho_workflow_status(wf_execution.closure.phase) _click.echo( "\t{:15} {}".format( - "Launch Plan:", _tt(_identifier.Identifier.promote_from_model(wf_execution.spec.launch_plan)), + "Launch Plan:", + _tt(_identifier.Identifier.promote_from_model(wf_execution.spec.launch_plan)), ) ) if show_io: _click.secho( "\tInputs: {}\n".format( - _prefix_lines("\t\t", _get_io_string(wf_execution.closure.computed_inputs, verbose=verbose),) + _prefix_lines( + "\t\t", + _get_io_string(wf_execution.closure.computed_inputs, verbose=verbose), + ) ) ) if wf_execution.closure.outputs is not None: @@ -1326,14 +1416,20 @@ def _render_workflow_execution(wf_execution, uri_to_message_map, show_io, verbos "\tOutputs: {}\n".format( _prefix_lines( "\t\t", - uri_to_message_map.get(wf_execution.closure.outputs.uri, wf_execution.closure.outputs.uri,), + uri_to_message_map.get( + wf_execution.closure.outputs.uri, + wf_execution.closure.outputs.uri, + ), ) ) ) elif wf_execution.closure.outputs.values is not None: _click.secho( "\tOutputs: {}\n".format( - _prefix_lines("\t\t", _get_io_string(wf_execution.closure.outputs.values, verbose=verbose),) + _prefix_lines( + "\t\t", + _get_io_string(wf_execution.closure.outputs.values, verbose=verbose), + ) ) ) else: @@ -1341,7 +1437,9 @@ def _render_workflow_execution(wf_execution, uri_to_message_map, show_io, verbos if wf_execution.closure.error is not None: _click.secho( - _prefix_lines("\t", _render_error(wf_execution.closure.error)), fg="red", bold=True, + _prefix_lines("\t", _render_error(wf_execution.closure.error)), + fg="red", + bold=True, ) @@ -1360,7 +1458,9 @@ def _get_all_task_executions_for_node(client, node_execution_identifier): while True: num_to_fetch = 100 task_execs, next_token = client.list_task_executions_paginated( - node_execution_identifier=node_execution_identifier, limit=num_to_fetch, token=token, + node_execution_identifier=node_execution_identifier, + limit=num_to_fetch, + token=token, ) for te in task_execs: fetched_task_execs.append(te) @@ -1379,11 +1479,15 @@ def _get_all_node_executions(client, workflow_execution_identifier=None, task_ex num_to_fetch = 100 if workflow_execution_identifier: node_execs, next_token = client.list_node_executions( - workflow_execution_identifier=workflow_execution_identifier, limit=num_to_fetch, token=token, + workflow_execution_identifier=workflow_execution_identifier, + limit=num_to_fetch, + token=token, ) else: node_execs, next_token = client.list_node_executions_for_task_paginated( - task_execution_identifier=task_execution_identifier, limit=num_to_fetch, token=token, + 
task_execution_identifier=task_execution_identifier, + limit=num_to_fetch, + token=token, ) all_node_execs.extend(node_execs) if not next_token: @@ -1412,7 +1516,11 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure _click.echo("\t\t\t{:15} {:60} ".format("Duration:", _tt(ne.closure.duration))) _click.echo( "\t\t\t{:15} {}".format( - "Input:", _prefix_lines("\t\t\t{:15} ".format(""), uri_to_message_map.get(ne.input_uri, ne.input_uri),), + "Input:", + _prefix_lines( + "\t\t\t{:15} ".format(""), + uri_to_message_map.get(ne.input_uri, ne.input_uri), + ), ) ) if ne.closure.output_uri: @@ -1420,13 +1528,16 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure "\t\t\t{:15} {}".format( "Output:", _prefix_lines( - "\t\t\t{:15} ".format(""), uri_to_message_map.get(ne.closure.output_uri, ne.closure.output_uri), + "\t\t\t{:15} ".format(""), + uri_to_message_map.get(ne.closure.output_uri, ne.closure.output_uri), ), ) ) if ne.closure.error is not None: _click.secho( - _prefix_lines("\t\t\t", _render_error(ne.closure.error)), bold=True, fg="red", + _prefix_lines("\t\t\t", _render_error(ne.closure.error)), + bold=True, + fg="red", ) task_executions = node_executions_to_task_executions.get(ne.id, []) @@ -1450,7 +1561,9 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure if te.closure.error is not None: _click.secho( - _prefix_lines("\t\t\t\t\t", _render_error(te.closure.error)), bold=True, fg="red", + _prefix_lines("\t\t\t\t\t", _render_error(te.closure.error)), + bold=True, + fg="red", ) if te.is_parent: @@ -1497,7 +1610,8 @@ def get_child_executions(urn, host, insecure, show_io, verbose): _welcome_message() client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) node_execs = _get_all_node_executions( - client, task_execution_identifier=_identifier.TaskExecutionIdentifier.from_python_std(urn), + client, + task_execution_identifier=_identifier.TaskExecutionIdentifier.from_python_std(urn), ) _render_node_executions(client, node_execs, show_io, verbose, host, insecure) @@ -2197,7 +2311,8 @@ def setup_config(host, insecure): config_dir = _os.path.join(_get_user_filepath_home(), _default_config_file_dir) if not _os.path.isdir(config_dir): _click.secho( - "Creating default Flyte configuration directory at {}".format(_tt(config_dir)), fg="blue", + "Creating default Flyte configuration directory at {}".format(_tt(config_dir)), + fg="blue", ) _os.mkdir(config_dir) diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py index b764786104..0d264054de 100644 --- a/flytekit/clis/helpers.py +++ b/flytekit/clis/helpers.py @@ -149,7 +149,11 @@ def _hydrate_workflow_template_nodes( def hydrate_registration_parameters( - resource_type: int, project: str, domain: str, version: str, entity: Union[LaunchPlan, WorkflowSpec, TaskSpec], + resource_type: int, + project: str, + domain: str, + version: str, + entity: Union[LaunchPlan, WorkflowSpec, TaskSpec], ) -> Tuple[_identifier_pb2.Identifier, Union[LaunchPlan, WorkflowSpec, TaskSpec]]: """ This is called at registration time to fill out identifier fields (e.g. project, domain, version) that are mutable. diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index 84d8956cb8..bc2ff68666 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -16,8 +16,16 @@ help="Flyte project to use. 
You can have more than one project per repo", ) domain_option = _click.option( - "-d", "--domain", required=True, type=str, help="This is usually development, staging, or production", + "-d", + "--domain", + required=True, + type=str, + help="This is usually development, staging, or production", ) version_option = _click.option( - "-v", "--version", required=False, type=str, help="This is the version to apply globally for this context", + "-v", + "--version", + required=False, + type=str, + help="This is the version to apply globally for this context", ) diff --git a/flytekit/clis/sdk_in_container/launch_plan.py b/flytekit/clis/sdk_in_container/launch_plan.py index 49102070b0..e367fa08d8 100644 --- a/flytekit/clis/sdk_in_container/launch_plan.py +++ b/flytekit/clis/sdk_in_container/launch_plan.py @@ -63,7 +63,9 @@ def get_command(self, ctx, lp_argument): launch_plan = ctx.obj["lps"][lp_argument] else: for m, k, lp in iterate_registerable_entities_in_order( - pkgs, include_entities={_SdkLaunchPlan}, detect_unreferenced_entities=False, + pkgs, + include_entities={_SdkLaunchPlan}, + detect_unreferenced_entities=False, ): safe_name = _utils.fqdn(m.__name__, k, entity_type=lp.resource_type) if lp_argument == safe_name: @@ -114,7 +116,10 @@ def _execute_lp(**kwargs): notification_overrides=ctx.obj.get(_constants.CTX_NOTIFICATIONS, None), ) click.echo( - click.style("Workflow scheduled, execution_id={}".format(_six.text_type(execution.id)), fg="blue",) + click.style( + "Workflow scheduled, execution_id={}".format(_six.text_type(execution.id)), + fg="blue", + ) ) command = click.Command(name=cmd_name, callback=_execute_lp) @@ -130,7 +135,12 @@ def _execute_lp(**kwargs): if param.required: # If it's a required input, add the required flag - wrapper = click.option("--{}".format(var_name), required=True, type=_six.text_type, help=help_msg,) + wrapper = click.option( + "--{}".format(var_name), + required=True, + type=_six.text_type, + help=help_msg, + ) else: # If it's not a required input, it should have a default # Use to_python_std so that the text of the default ends up being parseable, if not, the click @@ -217,7 +227,8 @@ def activate_all_schedules(ctx, version=None): The behavior of this command is identical to activate-all. """ click.secho( - "activate-all-schedules is deprecated, please use activate-all instead.", color="yellow", + "activate-all-schedules is deprecated, please use activate-all instead.", + color="yellow", ) project = ctx.obj[_constants.CTX_PROJECT] domain = ctx.obj[_constants.CTX_DOMAIN] @@ -234,7 +245,9 @@ def activate_all_schedules(ctx, version=None): help="Version to register tasks with. 
This is normally parsed from the" "image, but you can override here.", ) @click.option( - "--ignore-schedules", is_flag=True, help="Activate all except for launch plans with schedules.", + "--ignore-schedules", + is_flag=True, + help="Activate all except for launch plans with schedules.", ) @click.pass_context def activate_all(ctx, version=None, ignore_schedules=False): diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index a888367ca6..cbb15fe02a 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -20,7 +20,11 @@ @click.group("pyflyte", invoke_without_command=True) @click.option( - "-c", "--config", required=False, type=str, help="Path to config file for use within container", + "-c", + "--config", + required=False, + type=str, + help="Path to config file for use within container", ) @click.option( "-k", @@ -31,7 +35,11 @@ "option will override the option specified in the configuration file, or environment variable", ) @click.option( - "-i", "--insecure", required=False, type=bool, help="Do not use SSL to connect to Admin", + "-i", + "--insecure", + required=False, + type=bool, + help="Do not use SSL to connect to Admin", ) @click.pass_context def main(ctx, config=None, pkgs=None, insecure=None): @@ -71,7 +79,8 @@ def update_configuration_file(config_file_path): configuration_file = Path(config_file_path or CONFIGURATION_PATH.get()) if configuration_file.is_file(): click.secho( - "Using configuration file at {}".format(configuration_file.absolute().as_posix()), fg="green", + "Using configuration file at {}".format(configuration_file.absolute().as_posix()), + fg="green", ) set_flyte_config_file(configuration_file.as_posix()) else: diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 25fe1ee9c7..876629b851 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -71,7 +71,9 @@ def register_tasks_only(project, domain, pkgs, test, version): @version_option # --pkgs on the register group is DEPRECATED, use same arg on pyflyte.main instead @click.option( - "--pkgs", multiple=True, help="DEPRECATED. This arg can only be used before the 'register' keyword", + "--pkgs", + multiple=True, + help="DEPRECATED. This arg can only be used before the 'register' keyword", ) @click.option("--test", is_flag=True, help="Dry run, do not actually register with Admin") @click.pass_context diff --git a/flytekit/common/core/identifier.py b/flytekit/common/core/identifier.py index 1510bdf293..c7b12a5190 100644 --- a/flytekit/common/core/identifier.py +++ b/flytekit/common/core/identifier.py @@ -21,7 +21,11 @@ def promote_from_model(cls, base_model): :rtype: Identifier """ return cls( - base_model.resource_type, base_model.project, base_model.domain, base_model.name, base_model.version, + base_model.resource_type, + base_model.project, + base_model.domain, + base_model.name, + base_model.version, ) @classmethod @@ -66,7 +70,11 @@ def promote_from_model(cls, base_model): :param flytekit.models.core.identifier.WorkflowExecutionIdentifier base_model: :rtype: WorkflowExecutionIdentifier """ - return cls(base_model.project, base_model.domain, base_model.name,) + return cls( + base_model.project, + base_model.domain, + base_model.name, + ) @classmethod def from_python_std(cls, string): @@ -91,7 +99,11 @@ def from_python_std(cls, string): "The provided string could not be parsed. 
The first element of an execution identifier must be 'ex'.", ) - return cls(project, domain, name,) + return cls( + project, + domain, + name, + ) def __str__(self): return "ex:{}:{}:{}".format(self.project, self.domain, self.name) @@ -136,7 +148,8 @@ def from_python_std(cls, string): return cls( task_id=Identifier(_core_identifier.ResourceType.TASK, tp, td, tn, tv), node_execution_id=_core_identifier.NodeExecutionIdentifier( - node_id=node_id, execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en), + node_id=node_id, + execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en), ), retry_attempt=int(retry), ) diff --git a/flytekit/common/exceptions/scopes.py b/flytekit/common/exceptions/scopes.py index fdd4e1a802..994211f6a6 100644 --- a/flytekit/common/exceptions/scopes.py +++ b/flytekit/common/exceptions/scopes.py @@ -159,7 +159,9 @@ def system_entry_point(wrapped, instance, args, kwargs): except _user_exceptions.FlyteUserException: # Re-raise from here. _reraise( - FlyteScopedUserException, FlyteScopedUserException(*_exc_info()), _exc_info()[2], + FlyteScopedUserException, + FlyteScopedUserException(*_exc_info()), + _exc_info()[2], ) except Exception: # System error, raise full stack-trace all the way up the chain. @@ -198,17 +200,23 @@ def user_entry_point(wrapped, instance, args, kwargs): _reraise(*_exc_info()) except _user_exceptions.FlyteUserException: _reraise( - FlyteScopedUserException, FlyteScopedUserException(*_exc_info()), _exc_info()[2], + FlyteScopedUserException, + FlyteScopedUserException(*_exc_info()), + _exc_info()[2], ) except _system_exceptions.FlyteSystemException: _reraise( - FlyteScopedSystemException, FlyteScopedSystemException(*_exc_info()), _exc_info()[2], + FlyteScopedSystemException, + FlyteScopedSystemException(*_exc_info()), + _exc_info()[2], ) except Exception: # Any non-platform raised exception is a user exception. 
# This will also catch FlyteUserException re-raised by the system_entry_point handler _reraise( - FlyteScopedUserException, FlyteScopedUserException(*_exc_info()), _exc_info()[2], + FlyteScopedUserException, + FlyteScopedUserException(*_exc_info()), + _exc_info()[2], ) finally: _CONTEXT_STACK.pop() diff --git a/flytekit/common/exceptions/user.py b/flytekit/common/exceptions/user.py index 671ebf66ec..acb5dd7997 100644 --- a/flytekit/common/exceptions/user.py +++ b/flytekit/common/exceptions/user.py @@ -34,7 +34,10 @@ def _create_verbose_message(cls, received_type, expected_type, received_value=No def __init__(self, received_type, expected_type, additional_msg=None, received_value=None): super(FlyteTypeException, self).__init__( self._create_verbose_message( - received_type, expected_type, received_value=received_value, additional_msg=additional_msg, + received_type, + expected_type, + received_value=received_value, + additional_msg=additional_msg, ) ) diff --git a/flytekit/common/interface.py b/flytekit/common/interface.py index 0f8e5569fa..c8f5160e15 100644 --- a/flytekit/common/interface.py +++ b/flytekit/common/interface.py @@ -30,7 +30,12 @@ def promote_from_model(cls, model): :param flytekit.models.literals.BindingData model: :rtype: BindingData """ - return cls(scalar=model.scalar, collection=model.collection, promise=model.promise, map=model.map,) + return cls( + scalar=model.scalar, + collection=model.collection, + promise=model.promise, + map=model.map, + ) @classmethod def from_python_std(cls, literal_type, t_value, upstream_nodes=None): @@ -75,7 +80,9 @@ def from_python_std(cls, literal_type, t_value, upstream_nodes=None): collection = _literal_models.BindingDataCollection( [ BindingData.from_python_std( - downstream_sdk_type.sub_type.to_flyte_literal_type(), v, upstream_nodes=upstream_nodes, + downstream_sdk_type.sub_type.to_flyte_literal_type(), + v, + upstream_nodes=upstream_nodes, ) for v in t_value ] diff --git a/flytekit/common/launch_plan.py b/flytekit/common/launch_plan.py index 0cb1cc872e..7098d3e556 100644 --- a/flytekit/common/launch_plan.py +++ b/flytekit/common/launch_plan.py @@ -184,7 +184,8 @@ def auth_role(self): ) assumable_iam_role = _sdk_config.ROLE.get() return _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, ) @property @@ -274,7 +275,13 @@ def execute_with_literals( Deprecated. 
""" return self.launch_with_literals( - project, domain, literal_inputs, name, notification_overrides, label_overrides, annotation_overrides, + project, + domain, + literal_inputs, + name, + notification_overrides, + label_overrides, + annotation_overrides, ) @_exception_scopes.system_entry_point @@ -425,7 +432,8 @@ def __init__( super(SdkRunnableLaunchPlan, self).__init__( None, _launch_plan_models.LaunchPlanMetadata( - schedule=schedule or _schedule_model.Schedule(""), notifications=notifications or [], + schedule=schedule or _schedule_model.Schedule(""), + notifications=notifications or [], ), _interface_models.ParameterMap(default_inputs), _type_helpers.pack_python_std_map_to_literal_map( @@ -442,7 +450,8 @@ def __init__( raw_output_data_config or _common_models.RawOutputDataConfig(""), ) self._interface = _interface.TypedInterface( - {k: v.var for k, v in _six.iteritems(default_inputs)}, sdk_workflow.interface.outputs, + {k: v.var for k, v in _six.iteritems(default_inputs)}, + sdk_workflow.interface.outputs, ) self._upstream_entities = {sdk_workflow} self._sdk_workflow = sdk_workflow diff --git a/flytekit/common/local_workflow.py b/flytekit/common/local_workflow.py index 9aaecf5385..eb2067578f 100644 --- a/flytekit/common/local_workflow.py +++ b/flytekit/common/local_workflow.py @@ -280,7 +280,8 @@ class provided should be a subclass of flytekit.common.launch_plan.SdkLaunchPlan if role: assumable_iam_role = role # For backwards compatibility auth_role = _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, ) raw_output_config = _common_models.RawOutputDataConfig(raw_output_data_prefix or "") @@ -359,13 +360,15 @@ def _discover_workflow_components(workflow_class): elif isinstance(current_obj, _promise.Input): if attribute_name is None or attribute_name not in top_level_attributes: raise _user_exceptions.FlyteValueException( - attribute_name, "Detected workflow input specified outside of top level.", + attribute_name, + "Detected workflow input specified outside of top level.", ) inputs.append(current_obj.rename_and_return_reference(attribute_name)) elif isinstance(current_obj, Output): if attribute_name is None or attribute_name not in top_level_attributes: raise _user_exceptions.FlyteValueException( - attribute_name, "Detected workflow output specified outside of top level.", + attribute_name, + "Detected workflow output specified outside of top level.", ) outputs.append(current_obj.rename_and_return_reference(attribute_name)) elif isinstance(current_obj, list) or isinstance(current_obj, set) or isinstance(current_obj, tuple): diff --git a/flytekit/common/mixins/launchable.py b/flytekit/common/mixins/launchable.py index 689623e867..110ba663af 100644 --- a/flytekit/common/mixins/launchable.py +++ b/flytekit/common/mixins/launchable.py @@ -119,5 +119,11 @@ def execute_with_literals( Deprecated. 
""" return self.launch_with_literals( - project, domain, literal_inputs, name, notification_overrides, label_overrides, annotation_overrides, + project, + domain, + literal_inputs, + name, + notification_overrides, + label_overrides, + annotation_overrides, ) diff --git a/flytekit/common/promise.py b/flytekit/common/promise.py index 2cf3bfba50..79352dd638 100644 --- a/flytekit/common/promise.py +++ b/flytekit/common/promise.py @@ -109,7 +109,13 @@ def promote_from_model(cls, model): if model.default is not None: default_value = sdk_type.from_flyte_idl(model.default.to_flyte_idl()).to_python_std() - return cls("", sdk_type, help=model.var.description, required=False, default=default_value,) + return cls( + "", + sdk_type, + help=model.var.description, + required=False, + default=default_value, + ) else: return cls("", sdk_type, help=model.var.description, required=True) diff --git a/flytekit/common/schedules.py b/flytekit/common/schedules.py index 250b7aff39..d6e71c6f83 100644 --- a/flytekit/common/schedules.py +++ b/flytekit/common/schedules.py @@ -165,15 +165,18 @@ def _translate_duration(duration): ) elif int(duration.total_seconds()) % _SECONDS_TO_DAYS == 0: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_DAYS), _schedule_models.Schedule.FixedRateUnit.DAY, + int(duration.total_seconds() / _SECONDS_TO_DAYS), + _schedule_models.Schedule.FixedRateUnit.DAY, ) elif int(duration.total_seconds()) % _SECONDS_TO_HOURS == 0: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_HOURS), _schedule_models.Schedule.FixedRateUnit.HOUR, + int(duration.total_seconds() / _SECONDS_TO_HOURS), + _schedule_models.Schedule.FixedRateUnit.HOUR, ) else: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_MINUTES), _schedule_models.Schedule.FixedRateUnit.MINUTE, + int(duration.total_seconds() / _SECONDS_TO_MINUTES), + _schedule_models.Schedule.FixedRateUnit.MINUTE, ) @classmethod diff --git a/flytekit/common/tasks/generic_spark_task.py b/flytekit/common/tasks/generic_spark_task.py index 3a7c0ef1bc..a83a7afbe4 100644 --- a/flytekit/common/tasks/generic_spark_task.py +++ b/flytekit/common/tasks/generic_spark_task.py @@ -76,7 +76,11 @@ def __init__( task_type, _task_models.TaskMetadata( discoverable, - _task_models.RuntimeMetadata(_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "spark",), + _task_models.RuntimeMetadata( + _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + "spark", + ), timeout, _literal_models.RetryStrategy(retries), interruptible, @@ -121,7 +125,8 @@ def add_inputs(self, inputs): self.interface.inputs.update(inputs) def _get_container_definition( - self, environment=None, + self, + environment=None, ): """ :rtype: Container diff --git a/flytekit/common/tasks/hive_task.py b/flytekit/common/tasks/hive_task.py index 6023bf0ec3..6be5db70f9 100644 --- a/flytekit/common/tasks/hive_task.py +++ b/flytekit/common/tasks/hive_task.py @@ -111,7 +111,9 @@ def _generate_plugin_objects(self, context, inputs_dict): for q in queries_from_task: hive_query = _qubole.HiveQuery( - query=q, timeout_sec=self.metadata.timeout.seconds, retry_count=self.metadata.retries.retries, + query=q, + timeout_sec=self.metadata.timeout.seconds, + retry_count=self.metadata.retries.retries, ) # TODO: Remove this after all users of older SDK versions that did the single node, multi-query pattern are @@ -121,7 +123,12 @@ def _generate_plugin_objects(self, context, inputs_dict): 
query_collection = _qubole.HiveQueryCollection([hive_query]) plugin_objects.append( - _qubole.QuboleHiveJob(hive_query, self._cluster_label, self._tags, query_collection=query_collection,) + _qubole.QuboleHiveJob( + hive_query, + self._cluster_label, + self._tags, + query_collection=query_collection, + ) ) return plugin_objects @@ -145,11 +152,13 @@ def _validate_task_parameters(cluster_label, tags): ) if len(tags) > ALLOWED_TAGS_COUNT: raise _FlyteValueException( - len(tags), "number of tags must be less than {}".format(ALLOWED_TAGS_COUNT), + len(tags), + "number of tags must be less than {}".format(ALLOWED_TAGS_COUNT), ) if not all(len(tag) for tag in tags): raise _FlyteValueException( - tags, "length of a tag must be less than {} chars".format(MAX_TAG_LENGTH), + tags, + "length of a tag must be less than {} chars".format(MAX_TAG_LENGTH), ) @staticmethod @@ -190,7 +199,8 @@ def _produce_dynamic_job_spec(self, context, inputs): # Create output bindings always - this has to happen after user code has run output_bindings = [ _literal_models.Binding( - var=name, binding=_interface.BindingData.from_python_std(b.sdk_type.to_flyte_literal_type(), b.value), + var=name, + binding=_interface.BindingData.from_python_std(b.sdk_type.to_flyte_literal_type(), b.value), ) for name, b in _six.iteritems(outputs_dict) ] @@ -203,7 +213,11 @@ def _produce_dynamic_job_spec(self, context, inputs): i += 1 dynamic_job_spec = _dynamic_job.DynamicJobSpec( - min_successes=len(nodes), tasks=tasks, nodes=nodes, outputs=output_bindings, subworkflows=[], + min_successes=len(nodes), + tasks=tasks, + nodes=nodes, + outputs=output_bindings, + subworkflows=[], ) return dynamic_job_spec @@ -263,7 +277,9 @@ class SdkHiveJob(_base_task.SdkTask): """ def __init__( - self, hive_job, metadata, + self, + hive_job, + metadata, ): """ :param _qubole.QuboleHiveJob hive_job: Hive job spec diff --git a/flytekit/common/tasks/presto_task.py b/flytekit/common/tasks/presto_task.py index 1bf8d0e51c..47e198494d 100644 --- a/flytekit/common/tasks/presto_task.py +++ b/flytekit/common/tasks/presto_task.py @@ -69,7 +69,10 @@ def __init__( ) presto_query = _presto_models.PrestoQuery( - routing_group=routing_group or "", catalog=catalog or "", schema=schema or "", statement=statement, + routing_group=routing_group or "", + catalog=catalog or "", + schema=schema or "", + statement=statement, ) # Here we set the routing_group, catalog, and schema as implicit @@ -99,7 +102,10 @@ def __init__( ) super(SdkPrestoTask, self).__init__( - _constants.SdkTaskType.PRESTO_TASK, metadata, i, _MessageToDict(presto_query.to_flyte_idl()), + _constants.SdkTaskType.PRESTO_TASK, + metadata, + i, + _MessageToDict(presto_query.to_flyte_idl()), ) # Set user provided inputs diff --git a/flytekit/common/tasks/raw_container.py b/flytekit/common/tasks/raw_container.py index 9f8a5ccc1b..c5437477a9 100644 --- a/flytekit/common/tasks/raw_container.py +++ b/flytekit/common/tasks/raw_container.py @@ -158,7 +158,9 @@ def __init__( discoverable, # This needs to have the proper version reflected in it _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, flytekit.__version__, "python", + _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + flytekit.__version__, + "python", ), timeout or _datetime.timedelta(seconds=0), _literals.RetryStrategy(retries), diff --git a/flytekit/common/tasks/sagemaker/built_in_training_job_task.py b/flytekit/common/tasks/sagemaker/built_in_training_job_task.py index e1e235df09..356f933c3e 100644 --- 
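A side note on the _validate_task_parameters hunk above: the second guard, not all(len(tag) for tag in tags), actually fires on empty tags, while its message refers to MAX_TAG_LENGTH. A version where each check matches its message might look like this (the constant values are assumptions, not taken from this patch):

    ALLOWED_TAGS_COUNT = 6  # assumed limit
    MAX_TAG_LENGTH = 128    # assumed limit

    def validate_tags(tags):
        if len(tags) > ALLOWED_TAGS_COUNT:
            raise ValueError(f"number of tags must be less than {ALLOWED_TAGS_COUNT}")
        if any(len(tag) > MAX_TAG_LENGTH for tag in tags):
            raise ValueError(f"length of a tag must be less than {MAX_TAG_LENGTH} chars")
        if any(not tag for tag in tags):
            raise ValueError("tags must be non-empty strings")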
a/flytekit/common/tasks/sagemaker/built_in_training_job_task.py +++ b/flytekit/common/tasks/sagemaker/built_in_training_job_task.py @@ -41,7 +41,8 @@ def __init__( """ # Use the training job model as a measure of type checking self._training_job_model = _training_job_models.TrainingJob( - algorithm_specification=algorithm_specification, training_job_resource_config=training_job_resource_config, + algorithm_specification=algorithm_specification, + training_job_resource_config=training_job_resource_config, ) # Setting flyte-level timeout to 0, and let SageMaker takes the StoppingCondition and terminate the training @@ -52,7 +53,9 @@ def __init__( type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK, metadata=_task_models.TaskMetadata( runtime=_task_models.RuntimeMetadata( - type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, version=__version__, flavor="sagemaker", + type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + version=__version__, + flavor="sagemaker", ), discoverable=cacheable, timeout=timeout, @@ -64,7 +67,8 @@ def __init__( interface=_interface.TypedInterface( inputs={ "static_hyperparameters": _interface_model.Variable( - type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT), description="", + type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT), + description="", ), "train": _interface_model.Variable( type=_idl_types.LiteralType( @@ -89,7 +93,8 @@ def __init__( "model": _interface_model.Variable( type=_idl_types.LiteralType( blob=_core_types.BlobType( - format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) ), description="", diff --git a/flytekit/common/tasks/sagemaker/hpo_job_task.py b/flytekit/common/tasks/sagemaker/hpo_job_task.py index 518816cf2b..fd4d0a3ce1 100644 --- a/flytekit/common/tasks/sagemaker/hpo_job_task.py +++ b/flytekit/common/tasks/sagemaker/hpo_job_task.py @@ -63,7 +63,8 @@ def __init__( inputs.update( { "hyperparameter_tuning_job_config": _interface_model.Variable( - HyperparameterTuningJobConfig.to_flyte_literal_type(), "", + HyperparameterTuningJobConfig.to_flyte_literal_type(), + "", ), } ) @@ -80,7 +81,9 @@ def __init__( type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK, metadata=_task_models.TaskMetadata( runtime=_task_models.RuntimeMetadata( - type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, version=__version__, flavor="sagemaker", + type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + version=__version__, + flavor="sagemaker", ), discoverable=cacheable, timeout=timeout, @@ -95,7 +98,8 @@ def __init__( "model": _interface_model.Variable( type=_types_models.LiteralType( blob=_core_types.BlobType( - format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) ), description="", diff --git a/flytekit/common/tasks/sdk_dynamic.py b/flytekit/common/tasks/sdk_dynamic.py index e1ffc6f861..2a44111be5 100644 --- a/flytekit/common/tasks/sdk_dynamic.py +++ b/flytekit/common/tasks/sdk_dynamic.py @@ -81,7 +81,9 @@ def _create_array_job(self, inputs_prefix): :rtype: _array_job.ArrayJob """ return _array_job.ArrayJob( - parallelism=self._max_concurrency if self._max_concurrency else 0, size=1, min_successes=1, + parallelism=self._max_concurrency if self._max_concurrency else 0, + size=1, + min_successes=1, ) @staticmethod @@ -137,7 +139,9 @@ def _produce_dynamic_job_spec(self, context, inputs): _literal_models.Binding( var=name, 
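_create_array_job in sdk_dynamic.py, reformatted above, fills in fixed defaults: size=1, min_successes=1, and parallelism taken from max_concurrency with 0 (presumably meaning "no cap") as the fallback. The defaulting on its own:

    def array_job_fields(max_concurrency=None):
        return {
            "parallelism": max_concurrency if max_concurrency else 0,
            "size": 1,
            "min_successes": 1,
        }

    assert array_job_fields()["parallelism"] == 0
    assert array_job_fields(4)["parallelism"] == 4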
binding=_interface.BindingData.from_python_std( - b.sdk_type.to_flyte_literal_type(), b.raw_value, upstream_nodes=upstream_nodes, + b.sdk_type.to_flyte_literal_type(), + b.raw_value, + upstream_nodes=upstream_nodes, ), ) for name, b in _six.iteritems(outputs_dict) @@ -284,7 +288,9 @@ def execute(self, context, inputs): class SdkDynamicTask( - SdkDynamicTaskMixin, _sdk_runnable.SdkRunnableTask, metaclass=_sdk_bases.ExtendedSdkType, + SdkDynamicTaskMixin, + _sdk_runnable.SdkRunnableTask, + metaclass=_sdk_bases.ExtendedSdkType, ): """ diff --git a/flytekit/common/tasks/sdk_runnable.py b/flytekit/common/tasks/sdk_runnable.py index 522eccdc7c..d938d51458 100644 --- a/flytekit/common/tasks/sdk_runnable.py +++ b/flytekit/common/tasks/sdk_runnable.py @@ -242,7 +242,12 @@ class SdkRunnableContainer(_task_models.Container, metaclass=_sdk_bases.Extended """ def __init__( - self, command, args, resources, env, config, + self, + command, + args, + resources, + env, + config, ): super(SdkRunnableContainer, self).__init__("", command, args, resources, env or {}, config) @@ -396,7 +401,9 @@ def __init__( _task_models.TaskMetadata( discoverable, _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python", + _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + "python", ), timeout, _literal_models.RetryStrategy(retries), diff --git a/flytekit/common/tasks/sidecar_task.py b/flytekit/common/tasks/sidecar_task.py index bf1f6019ec..51be8a8f40 100644 --- a/flytekit/common/tasks/sidecar_task.py +++ b/flytekit/common/tasks/sidecar_task.py @@ -132,14 +132,17 @@ def reconcile_partial_pod_spec_and_task(self, pod_spec, primary_container_name): pod_spec.containers.extend(final_containers) sidecar_job_plugin = _task_models.SidecarJob( - pod_spec=pod_spec, primary_container_name=primary_container_name, + pod_spec=pod_spec, + primary_container_name=primary_container_name, ).to_flyte_idl() self.assign_custom_and_return(_MessageToDict(sidecar_job_plugin)) class SdkDynamicSidecarTask( - _sdk_dynamic.SdkDynamicTaskMixin, SdkSidecarTask, metaclass=_sdk_bases.ExtendedSdkType, + _sdk_dynamic.SdkDynamicTaskMixin, + SdkSidecarTask, + metaclass=_sdk_bases.ExtendedSdkType, ): """ diff --git a/flytekit/common/tasks/task.py b/flytekit/common/tasks/task.py index 07e4cee420..ba55399382 100644 --- a/flytekit/common/tasks/task.py +++ b/flytekit/common/tasks/task.py @@ -143,7 +143,10 @@ def __call__(self, *args, **input_map): return _nodes.SdkNode( id=None, metadata=_workflow_model.NodeMetadata( - "DEADBEEF", self.metadata.timeout, self.metadata.retries, self.metadata.interruptible, + "DEADBEEF", + self.metadata.timeout, + self.metadata.retries, + self.metadata.interruptible, ), bindings=sorted(bindings, key=lambda b: b.var), upstream_nodes=upstream_nodes, @@ -216,7 +219,9 @@ def fetch_latest(cls, project, domain, name): named_task = _common_model.NamedEntityIdentifier(project, domain, name) client = _flyte_engine.get_client() task_list, _ = client.list_tasks_paginated( - named_task, limit=1, sort_by=_admin_common.Sort("created_at", _admin_common.Sort.Direction.DESCENDING), + named_task, + limit=1, + sort_by=_admin_common.Sort("created_at", _admin_common.Sort.Direction.DESCENDING), ) admin_task = task_list[0] if task_list else None @@ -386,7 +391,8 @@ def launch_with_literals( ) assumable_iam_role = _sdk_config.ROLE.get() auth_role = _common_model.AuthRole( - assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + 
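fetch_latest in task.py, reformatted above, resolves "latest" as a one-item page sorted by created_at descending. The same pattern against a generic client (client and sort_cls stand in for the Flyte Admin client and its Sort model):

    def fetch_latest(client, named_task, sort_cls):
        task_list, _ = client.list_tasks_paginated(
            named_task,
            limit=1,
            sort_by=sort_cls("created_at", sort_cls.Direction.DESCENDING),
        )
        return task_list[0] if task_list else None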
assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, ) client = _flyte_engine.get_client() diff --git a/flytekit/common/translator.py b/flytekit/common/translator.py index fc4014d9ac..4a4218f07f 100644 --- a/flytekit/common/translator.py +++ b/flytekit/common/translator.py @@ -61,7 +61,10 @@ def to_serializable_cases( def get_serializable_references( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: FlyteLocalEntity, + fast: bool, ) -> FlyteControlPlaneEntity: # TODO: This entire function isn't necessary. We should just return None or raise an Exception or something. # Reference entities should already exist on the Admin control plane - they should not be serialized/registered @@ -114,7 +117,10 @@ def get_serializable_references( def get_serializable_task( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: FlyteLocalEntity, + fast: bool, ) -> FlyteControlPlaneEntity: cp_entity = SdkTask( type=entity.task_type, @@ -152,7 +158,10 @@ def get_serializable_task( def get_serializable_workflow( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: WorkflowBase, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: WorkflowBase, + fast: bool, ) -> FlyteControlPlaneEntity: workflow_id = _identifier_model.Identifier( _identifier_model.ResourceType.WORKFLOW, settings.project, settings.domain, entity.name, settings.version @@ -182,13 +191,17 @@ def get_serializable_workflow( def get_serializable_launch_plan( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: FlyteLocalEntity, + fast: bool, ) -> FlyteControlPlaneEntity: sdk_workflow = get_serializable(entity_mapping, settings, entity.workflow) cp_entity = SdkLaunchPlan( workflow_id=sdk_workflow.id, entity_metadata=_launch_plan_models.LaunchPlanMetadata( - schedule=entity.schedule, notifications=entity.notifications, + schedule=entity.schedule, + notifications=entity.notifications, ), default_inputs=entity.parameters, fixed_inputs=entity.fixed_inputs, @@ -213,7 +226,10 @@ def get_serializable_launch_plan( def get_serializable_node( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: FlyteLocalEntity, + fast: bool, ) -> FlyteControlPlaneEntity: if entity._flyte_entity is None: raise Exception(f"Node {entity.id} has no flyte entity") @@ -269,7 +285,10 @@ def get_serializable_node( def get_serializable_branch_node( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: FlyteLocalEntity, + fast: bool, ) -> FlyteControlPlaneEntity: # We have to iterate through the blocks to convert the nodes from their current type to SDKNode # TODO this should be cleaned up instead of mutation, we probaby should just create a new object diff --git a/flytekit/common/types/blobs.py b/flytekit/common/types/blobs.py index 206fbf198f..7870cb75bd 100644 --- a/flytekit/common/types/blobs.py +++ b/flytekit/common/types/blobs.py @@ -168,7 +168,8 @@ def from_string(cls, 
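The translator.py hunks above show that every get_serializable_* helper now shares one signature: (entity_mapping, settings, entity, fast). The OrderedDict parameter suggests a memoizing dispatcher sits in front of them; this is a guess at its shape, not the actual flytekit implementation:

    def get_serializable(entity_mapping, settings, entity, fast, handlers):
        # Serialize each local entity at most once; handlers maps an entity
        # type to the matching get_serializable_* function.
        if entity in entity_mapping:
            return entity_mapping[entity]
        cp_entity = handlers[type(entity)](entity_mapping, settings, entity, fast)
        entity_mapping[entity] = cp_entity
        return cp_entity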
string_value): """ if not string_value: _user_exceptions.FlyteValueException( - string_value, "Cannot create a MultiPartBlob from the provided path " "value.", + string_value, + "Cannot create a MultiPartBlob from the provided path " "value.", ) return cls(_blob_impl.MultiPartBlob.from_string(string_value, mode="rb")) @@ -201,7 +202,10 @@ def to_flyte_literal_type(cls): :rtype: flytekit.models.types.LiteralType """ return _idl_types.LiteralType( - blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) @classmethod @@ -321,7 +325,10 @@ def to_flyte_literal_type(cls): :rtype: flytekit.models.types.LiteralType """ return _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) ) @classmethod @@ -391,7 +398,8 @@ def from_string(cls, string_value): """ if not string_value: _user_exceptions.FlyteValueException( - string_value, "Cannot create a MultiPartCSV from the provided path value.", + string_value, + "Cannot create a MultiPartCSV from the provided path value.", ) return cls(_blob_impl.MultiPartBlob.from_string(string_value, format="csv", mode="r")) @@ -428,7 +436,10 @@ def to_flyte_literal_type(cls): :rtype: flytekit.models.types.LiteralType """ return _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) @classmethod diff --git a/flytekit/common/types/containers.py b/flytekit/common/types/containers.py index f61a882bd0..9267273b1e 100644 --- a/flytekit/common/types/containers.py +++ b/flytekit/common/types/containers.py @@ -58,12 +58,16 @@ def from_string(cls, string_value): items = _json.loads(string_value) except ValueError: raise _user_exceptions.FlyteTypeException( - _six.text_type, cls, additional_msg="String not parseable to json {}".format(string_value), + _six.text_type, + cls, + additional_msg="String not parseable to json {}".format(string_value), ) if type(items) != list: raise _user_exceptions.FlyteTypeException( - _six.text_type, cls, additional_msg="String is not a list {}".format(string_value), + _six.text_type, + cls, + additional_msg="String is not a list {}".format(string_value), ) # Instead of recursively calling from_string(), we're changing to from_python_std() instead because json @@ -137,7 +141,9 @@ def short_string(self): if len(self.collection.literals) > num_to_print: to_print.append("...") return "{}(len={}, [{}])".format( - type(self).short_class_string(), len(self.collection.literals), ", ".join(to_print), + type(self).short_class_string(), + len(self.collection.literals), + ", ".join(to_print), ) def verbose_string(self): diff --git a/flytekit/common/types/impl/blobs.py b/flytekit/common/types/impl/blobs.py index f69d4215b8..45210da62a 100644 --- a/flytekit/common/types/impl/blobs.py +++ b/flytekit/common/types/impl/blobs.py @@ -349,7 +349,8 @@ def __enter__(self): "path is specified." 
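One thing worth flagging in the from_string hunks above: the FlyteValueException is constructed but never raised, so an empty string_value falls straight through to the constructor call below the guard. The intended guard is presumably the following (ValueError stands in for FlyteValueException here):

    def require_path(string_value):
        if not string_value:
            raise ValueError("Cannot create a MultiPartBlob from the provided path value.")
        return string_value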
) self._directory = _utils.AutoDeletingTempDir( - _uuid.uuid4().hex, tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name, + _uuid.uuid4().hex, + tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name, ) self._is_managed = True self._directory.__enter__() @@ -361,7 +362,10 @@ def __enter__(self): self._blobs = [] file_handles = [] for local_path in sorted(self._directory.list_dir(), key=lambda x: x.lower()): - b = Blob(_os.path.join(self.remote_location, _os.path.basename(local_path)), mode=self.mode,) + b = Blob( + _os.path.join(self.remote_location, _os.path.basename(local_path)), + mode=self.mode, + ) b._local_path = local_path file_handles.append(b.__enter__()) self._blobs.append(b) @@ -426,10 +430,13 @@ def create_part(self, name=None): name = _uuid.uuid4().hex if ":" in name or "/" in name: raise _user_exceptions.FlyteAssertion( - name, "Cannot create a part of a multi-part object with ':' or '/' in the name.", + name, + "Cannot create a part of a multi-part object with ':' or '/' in the name.", ) return Blob.create_at_known_location( - _os.path.join(self.remote_location, name), mode=self.mode, format=self.metadata.type.format, + _os.path.join(self.remote_location, name), + mode=self.mode, + format=self.metadata.type.format, ) @_exception_scopes.system_entry_point diff --git a/flytekit/common/types/impl/schema.py b/flytekit/common/types/impl/schema.py index 10a28eaa29..04e4109d1c 100644 --- a/flytekit/common/types/impl/schema.py +++ b/flytekit/common/types/impl/schema.py @@ -352,7 +352,9 @@ def write(self, data_frame, coerce_timestamps="us", allow_truncated_timestamps=F try: filename = self._local_dir.get_named_tempfile(_os.path.join(str(self._index).zfill(6))) data_frame.to_parquet( - filename, coerce_timestamps=coerce_timestamps, allow_truncated_timestamps=allow_truncated_timestamps, + filename, + coerce_timestamps=coerce_timestamps, + allow_truncated_timestamps=allow_truncated_timestamps, ) if self._index == len(self._chunks): self._chunks.append(filename) @@ -379,7 +381,8 @@ def __enter__(self): "specify a path when calling this function." ) self._directory = _utils.AutoDeletingTempDir( - _uuid.uuid4().hex, tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name, + _uuid.uuid4().hex, + tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name, ) self._is_managed = True self._directory.__enter__() @@ -630,7 +633,12 @@ def from_string(cls, string_value, schema_type=None): @classmethod @_exception_scopes.system_entry_point def create_from_hive_query( - cls, select_query, stage_query=None, schema_to_table_name_map=None, schema_type=None, known_location=None, + cls, + select_query, + stage_query=None, + schema_to_table_name_map=None, + schema_type=None, + known_location=None, ): """ Returns a query that can be submitted to Hive and produce the desired output. 
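create_part, reformatted above, derives a child blob location: a random hex name when none is supplied, with ':' and '/' rejected so the part stays directly under remote_location. The naming logic in isolation:

    import os
    import uuid
    from typing import Optional

    def part_location(remote_location: str, name: Optional[str] = None) -> str:
        name = name or uuid.uuid4().hex
        if ":" in name or "/" in name:
            raise ValueError("Cannot create a part of a multi-part object with ':' or '/' in the name.")
        return os.path.join(remote_location, name)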
It also returns a properly-typed @@ -647,7 +655,9 @@ def create_from_hive_query( :return: Schema, Text """ schema_object = cls( - known_location or _data_proxy.Data.get_remote_directory(), mode="wb", schema_type=schema_type, + known_location or _data_proxy.Data.get_remote_directory(), + mode="wb", + schema_type=schema_type, ) if len(schema_object.type.sdk_columns) > 0: @@ -660,13 +670,15 @@ def create_from_hive_query( if sdk_type == _primitives.Float: columnar_clauses.append( "CAST({table_column_name} as double) {schema_name}".format( - table_column_name=schema_to_table_name_map[name], schema_name=name, + table_column_name=schema_to_table_name_map[name], + schema_name=name, ) ) else: columnar_clauses.append( "{table_column_name} as {schema_name}".format( - table_column_name=schema_to_table_name_map[name], schema_name=name, + table_column_name=schema_to_table_name_map[name], + schema_name=name, ) ) columnar_query = ",\n\t\t".join(columnar_clauses) @@ -844,7 +856,8 @@ def get_write_partition_to_hive_table_query( for partition_name, partition_value in partitions: where_clauses.append( "\n\t\t{schema_name} = {value_str} AND ".format( - schema_name=table_to_schema_name_map[partition_name], value_str=partition_value, + schema_name=table_to_schema_name_map[partition_name], + value_str=partition_value, ) ) where_string = "WHERE\n\t\t{where_clauses}".format(where_clauses=" AND\n\t\t".join(where_clauses)) @@ -863,7 +876,9 @@ def get_write_partition_to_hive_table_query( ) return _format_insert_partition_query( - remote_location=self.remote_location, table_name=table_name, partition_string=partition_string, + remote_location=self.remote_location, + table_name=table_name, + partition_string=partition_string, ) def compare_dataframe_to_schema(self, data_frame, column_subset=None, read=False): diff --git a/flytekit/common/types/primitives.py b/flytekit/common/types/primitives.py index 77fbf7b313..6fa61c01ab 100644 --- a/flytekit/common/types/primitives.py +++ b/flytekit/common/types/primitives.py @@ -183,7 +183,9 @@ def from_string(cls, string_value): elif string_value == "0" or string_value.lower() == "false": return cls(False) raise _user_exceptions.FlyteTypeException( - _six.text_type, bool, additional_msg="String not castable to Boolean SDK " "type: {}".format(string_value), + _six.text_type, + bool, + additional_msg="String not castable to Boolean SDK " "type: {}".format(string_value), ) @classmethod @@ -377,7 +379,8 @@ def from_python_std(cls, t_value): raise _user_exceptions.FlyteTypeException(type(t_value), _datetime.datetime, t_value) elif t_value.tzinfo is None: raise _user_exceptions.FlyteValueException( - t_value, "Datetime objects in Flyte must be timezone aware. " "tzinfo was found to be None.", + t_value, + "Datetime objects in Flyte must be timezone aware. " "tzinfo was found to be None.", ) return cls(t_value) diff --git a/flytekit/common/types/schema.py b/flytekit/common/types/schema.py index 8f0a9555d1..eaf38d1c88 100644 --- a/flytekit/common/types/schema.py +++ b/flytekit/common/types/schema.py @@ -29,7 +29,11 @@ def create(cls): return _schema_impl.Schema.create_at_any_location(mode="wb", schema_type=cls.schema_type) def create_from_hive_query( - cls, select_query, stage_query=None, schema_to_table_name_map=None, known_location=None, + cls, + select_query, + stage_query=None, + schema_to_table_name_map=None, + known_location=None, ): """ Returns a query that can be submitted to Hive and produce the desired output. 
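create_from_hive_query assembles the SELECT column list seen above: float columns are wrapped in CAST(... as double), everything else becomes a plain alias. Reduced to the string assembly (column typing is simplified here to a set of float column names):

    def columnar_clauses(schema_to_table_name_map, float_columns):
        clauses = []
        for schema_name, table_column_name in schema_to_table_name_map.items():
            if schema_name in float_columns:
                clauses.append(f"CAST({table_column_name} as double) {schema_name}")
            else:
                clauses.append(f"{table_column_name} as {schema_name}")
        return ",\n\t\t".join(clauses)

    # e.g. "CAST(raw_score as double) score,\n\t\tuser_id as id"
    print(columnar_clauses({"score": "raw_score", "id": "user_id"}, {"score"}))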
It also returns a properly-typed @@ -150,7 +154,9 @@ def short_string(self): """ :rtype: Text """ - return "{}".format(self.scalar.schema,) + return "{}".format( + self.scalar.schema, + ) def schema_instantiator(columns=None): diff --git a/flytekit/common/workflow.py b/flytekit/common/workflow.py index 42cb946579..3fc4f498fd 100644 --- a/flytekit/common/workflow.py +++ b/flytekit/common/workflow.py @@ -39,7 +39,13 @@ class only a control plane class. Workflow constructs that rely on local code be """ def __init__( - self, nodes, interface, output_bindings, id, metadata, metadata_defaults, + self, + nodes, + interface, + output_bindings, + id, + metadata, + metadata_defaults, ): """ :param list[flytekit.common.nodes.SdkNode] nodes: @@ -221,7 +227,13 @@ def register(self, project, domain, name, version): try: client = _flyte_engine.get_client() sub_workflows = self.get_sub_workflows() - client.create_workflow(id_to_register, _admin_workflow_model.WorkflowSpec(self, sub_workflows,)) + client.create_workflow( + id_to_register, + _admin_workflow_model.WorkflowSpec( + self, + sub_workflows, + ), + ) self._id = id_to_register self._has_registered = True return str(id_to_register) @@ -240,7 +252,10 @@ def serialize(self): :rtype: flyteidl.admin.workflow_pb2.WorkflowSpec """ sub_workflows = self.get_sub_workflows() - return _admin_workflow_model.WorkflowSpec(self, sub_workflows,).to_flyte_idl() + return _admin_workflow_model.WorkflowSpec( + self, + sub_workflows, + ).to_flyte_idl() @_exception_scopes.system_entry_point def validate(self): @@ -255,13 +270,15 @@ def create_launch_plan(self, *args, **kwargs): if not (assumable_iam_role or kubernetes_service_account): raise _user_exceptions.FlyteValidationException("No assumable role or service account found") auth_role = _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, ) return SdkLaunchPlan( workflow_id=self.id, entity_metadata=_launch_plan_models.LaunchPlanMetadata( - schedule=_schedule_models.Schedule(""), notifications=[], + schedule=_schedule_models.Schedule(""), + notifications=[], ), default_inputs=_interface_models.ParameterMap({}), fixed_inputs=_literal_models.LiteralMap(literals={}), diff --git a/flytekit/common/workflow_execution.py b/flytekit/common/workflow_execution.py index f5064cc136..14695d0e68 100644 --- a/flytekit/common/workflow_execution.py +++ b/flytekit/common/workflow_execution.py @@ -19,7 +19,9 @@ class SdkWorkflowExecution( - _execution_models.Execution, _artifact.ExecutionArtifact, metaclass=_sdk_bases.ExtendedSdkType, + _execution_models.Execution, + _artifact.ExecutionArtifact, + metaclass=_sdk_bases.ExtendedSdkType, ): def __init__(self, *args, **kwargs): super(SdkWorkflowExecution, self).__init__(*args, **kwargs) diff --git a/flytekit/contrib/notebook/tasks.py b/flytekit/contrib/notebook/tasks.py index 14ef50f668..5836717574 100644 --- a/flytekit/contrib/notebook/tasks.py +++ b/flytekit/contrib/notebook/tasks.py @@ -123,7 +123,9 @@ def __init__( _task_models.TaskMetadata( discoverable, _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "notebook", + _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + "notebook", ), timeout, _literal_models.RetryStrategy(retries), @@ -369,7 +371,13 @@ def _get_container_definition( storage_request, cpu_request, gpu_request, memory_request, storage_limit, 
cpu_limit, gpu_limit, memory_limit ) - return _sdk_runnable.SdkRunnableContainer(command=[], args=[], resources=resources, env=environment, config={},) + return _sdk_runnable.SdkRunnableContainer( + command=[], + args=[], + resources=resources, + env=environment, + config={}, + ) def spark_notebook( diff --git a/flytekit/contrib/sensors/impl.py b/flytekit/contrib/sensors/impl.py index 670df86b5d..8276320922 100644 --- a/flytekit/contrib/sensors/impl.py +++ b/flytekit/contrib/sensors/impl.py @@ -96,7 +96,10 @@ def _do_poll(self): """ with self._hive_metastore_client as client: partitions = client.get_partitions_by_filter( - db_name=self._schema, tbl_name=self._table_name, filter=self._partition_filter, max_parts=1, + db_name=self._schema, + tbl_name=self._table_name, + filter=self._partition_filter, + max_parts=1, ) if partitions: return True, None diff --git a/flytekit/contrib/sensors/task.py b/flytekit/contrib/sensors/task.py index 2bacae122f..31e8a8d14e 100644 --- a/flytekit/contrib/sensors/task.py +++ b/flytekit/contrib/sensors/task.py @@ -10,7 +10,8 @@ def _execute_user_code(self, context, inputs): if sensor is not None: if not isinstance(sensor, _Sensor): raise _user_exceptions.FlyteTypeException( - received_type=type(sensor), expected_type=_Sensor, + received_type=type(sensor), + expected_type=_Sensor, ) succeeded = sensor.sense() if not succeeded: diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 68a895d0d3..6a1bf0b7dd 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -58,7 +58,7 @@ class TaskMetadata(object): retries: for retries=n; n > 0, on failures of this task, the task will be retried at-least n number of times. timeout: the max amount of time for which one execution of this task should be executed for. 
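_do_poll in contrib/sensors/impl.py, reformatted a little further up, is the entire body of the Hive partition sensor: one bounded metastore query per poll. The hunk only shows the success branch; the (False, None) fallthrough below is our assumption:

    def poll_partition(client, schema, table_name, partition_filter):
        # One sensor poll: succeed as soon as any matching partition exists.
        partitions = client.get_partitions_by_filter(
            db_name=schema,
            tbl_name=table_name,
            filter=partition_filter,
            max_parts=1,  # existence check only, so fetch at most one
        )
        return (True, None) if partitions else (False, None)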
If the execution will be terminated if the runtime exceeds the given timeout (approximately) - """ + """ cache: bool = False cache_version: str = "" @@ -278,7 +278,9 @@ def get_config(self, settings: SerializationSettings) -> Dict[str, str]: @abstractmethod def dispatch_execute( - self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap, + self, + ctx: FlyteContext, + input_literal_map: _literal_models.LiteralMap, ) -> _literal_models.LiteralMap: """ This method translates Flyte's Type system based input values and invokes the actual call to the executor @@ -332,7 +334,10 @@ def __init__( a dictionary of key/value pairs """ super().__init__( - task_type=task_type, name=name, interface=transform_interface_to_typed_interface(interface), **kwargs, + task_type=task_type, + name=name, + interface=transform_interface_to_typed_interface(interface), + **kwargs, ) self._python_interface = interface if interface else Interface() self._environment = environment if environment else {} diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index 65395e975e..60d536ec88 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -300,7 +300,11 @@ def transform_to_conj_expr( left, left_promises = transform_to_boolexpr(expr.lhs) right, right_promises = transform_to_boolexpr(expr.rhs) return ( - _core_cond.ConjunctionExpression(left_expression=left, right_expression=right, operator=_logical_ops[expr.op],), + _core_cond.ConjunctionExpression( + left_expression=left, + right_expression=right, + operator=_logical_ops[expr.op], + ), merge_promises(*left_promises, *right_promises), ) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 02df45c988..894d44fc0e 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -621,8 +621,7 @@ def node_id(self): @property def node(self) -> Node: - """ - """ + """""" return self._node def __repr__(self) -> str: diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index b58b538564..6a574e5576 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -77,7 +77,11 @@ def __init__( raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) super().__init__( - task_type=task_type, name=name, task_config=task_config, security_ctx=sec_ctx, **kwargs, + task_type=task_type, + name=name, + task_config=task_config, + security_ctx=sec_ctx, + **kwargs, ) self._container_image = container_image # TODO(katrogan): Implement resource overrides diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index 4030cf22fb..ae260a1fd9 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -176,7 +176,12 @@ def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], P return create_task_output(vals, self.python_interface) def compile(self, ctx: FlyteContext, *args, **kwargs): - return create_and_link_node(ctx, entity=self, interface=self.python_interface, **kwargs,) + return create_and_link_node( + ctx, + entity=self, + interface=self.python_interface, + **kwargs, + ) def __call__(self, *args, **kwargs): # When a Task is () aka __called__, there are three things we may do: diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index 2e2847d9d8..9a16a9dda6 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -170,13 +170,16 @@ def 
_translate_duration(duration: datetime.timedelta): ) elif int(duration.total_seconds()) % _SECONDS_TO_DAYS == 0: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_DAYS), _schedule_models.Schedule.FixedRateUnit.DAY, + int(duration.total_seconds() / _SECONDS_TO_DAYS), + _schedule_models.Schedule.FixedRateUnit.DAY, ) elif int(duration.total_seconds()) % _SECONDS_TO_HOURS == 0: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_HOURS), _schedule_models.Schedule.FixedRateUnit.HOUR, + int(duration.total_seconds() / _SECONDS_TO_HOURS), + _schedule_models.Schedule.FixedRateUnit.HOUR, ) else: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_MINUTES), _schedule_models.Schedule.FixedRateUnit.MINUTE, + int(duration.total_seconds() / _SECONDS_TO_MINUTES), + _schedule_models.Schedule.FixedRateUnit.MINUTE, ) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index e396de10f0..c32b1e393c 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -199,7 +199,10 @@ def __init__( def reference_task( - project: str, domain: str, name: str, version: str, + project: str, + domain: str, + name: str, + version: str, ) -> Callable[[Callable[..., Any]], ReferenceTask]: """ A reference task is a pointer to a task that already exists on your Flyte installation. This diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e2f891202e..e17137f8ae 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -423,11 +423,14 @@ def __init__(self): def _blob_type(self) -> _core_types.BlobType: return _core_types.BlobType( - format=mimetypes.types_map[".txt"], dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + format=mimetypes.types_map[".txt"], + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) def get_literal_type(self, t: typing.TextIO) -> LiteralType: - return _type_models.LiteralType(blob=self._blob_type(),) + return _type_models.LiteralType( + blob=self._blob_type(), + ) def to_literal( self, ctx: FlyteContext, python_val: typing.TextIO, python_type: Type[typing.TextIO], expected: LiteralType @@ -454,11 +457,14 @@ def __init__(self): def _blob_type(self) -> _core_types.BlobType: return _core_types.BlobType( - format=mimetypes.types_map[".bin"], dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + format=mimetypes.types_map[".bin"], + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) def get_literal_type(self, t: Type[typing.BinaryIO]) -> LiteralType: - return _type_models.LiteralType(blob=self._blob_type(),) + return _type_models.LiteralType( + blob=self._blob_type(), + ) def to_literal( self, ctx: FlyteContext, python_val: typing.BinaryIO, python_type: Type[typing.BinaryIO], expected: LiteralType @@ -484,11 +490,14 @@ def __init__(self): def _blob_type(self) -> _core_types.BlobType: return _core_types.BlobType( - format=mimetypes.types_map[".bin"], dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + format=mimetypes.types_map[".bin"], + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) def get_literal_type(self, t: Type[os.PathLike]) -> LiteralType: - return _type_models.LiteralType(blob=self._blob_type(),) + return _type_models.LiteralType( + blob=self._blob_type(), + ) def to_literal( self, ctx: FlyteContext, python_val: os.PathLike, python_type: Type[os.PathLike], expected: LiteralType @@ -591,7 +600,11 @@ def _register_default_type_transformers(): 
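The TextIO/BinaryIO/PathLike transformers in type_engine.py, reformatted above, pull their blob format tags from the stdlib mimetypes table instead of hard-coding strings. With CPython's default table (OS mime registries can extend it):

    import mimetypes

    assert mimetypes.types_map[".txt"] == "text/plain"
    assert mimetypes.types_map[".bin"] == "application/octet-stream"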
TypeEngine.register( SimpleTransformer( - "none", None, _type_models.LiteralType(simple=_type_models.SimpleType.NONE), lambda x: None, lambda x: None, + "none", + None, + _type_models.LiteralType(simple=_type_models.SimpleType.NONE), + lambda x: None, + lambda x: None, ) ) TypeEngine.register(ListTransformer()) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 79cc4c6ecd..69b7b27b4f 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -36,7 +36,11 @@ from flytekit.models.core import workflow as _workflow_model GLOBAL_START_NODE = Node( - id=_common_constants.GLOBAL_INPUT_NODE_ID, metadata=None, bindings=[], upstream_nodes=[], flyte_entity=None, + id=_common_constants.GLOBAL_INPUT_NODE_ID, + metadata=None, + bindings=[], + upstream_nodes=[], + flyte_entity=None, ) @@ -347,7 +351,10 @@ def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], P class ImperativeWorkflow(WorkflowBase): def __init__( - self, name: str, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: Optional[bool] = False, + self, + name: str, + failure_policy: Optional[WorkflowFailurePolicy] = None, + interruptible: Optional[bool] = False, ): metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) @@ -625,7 +632,11 @@ def compile(self, **kwargs): ) t = self.python_interface.outputs[output_names[0]] b = binding_from_python_std( - ctx, output_names[0], self.interface.outputs[output_names[0]].type, workflow_outputs, t, + ctx, + output_names[0], + self.interface.outputs[output_names[0]].type, + workflow_outputs, + t, ) bindings.append(b) elif len(output_names) > 1: @@ -637,7 +648,13 @@ def compile(self, **kwargs): if isinstance(workflow_outputs[i], ConditionalSection): raise AssertionError("A Conditional block (if-else) should always end with an `else_()` clause") t = self.python_interface.outputs[out] - b = binding_from_python_std(ctx, out, self.interface.outputs[out].type, workflow_outputs[i], t,) + b = binding_from_python_std( + ctx, + out, + self.interface.outputs[out].type, + workflow_outputs[i], + t, + ) bindings.append(b) # Save all the things necessary to create an SdkWorkflow, except for the missing project and domain @@ -712,7 +729,10 @@ def __init__( def reference_workflow( - project: str, domain: str, name: str, version: str, + project: str, + domain: str, + name: str, + version: str, ) -> Callable[[Callable[..., Any]], ReferenceWorkflow]: """ A reference workflow is a pointer to a workflow that already exists on your Flyte installation. 
This diff --git a/flytekit/engines/common.py b/flytekit/engines/common.py index 76731f2b8a..ed1dfd9d37 100644 --- a/flytekit/engines/common.py +++ b/flytekit/engines/common.py @@ -367,7 +367,13 @@ def fetch_latest_task(self, named_task): class EngineContext(object): def __init__( - self, execution_date, tmp_dir, stats, execution_id, logging, raw_output_data_prefix=None, + self, + execution_date, + tmp_dir, + stats, + execution_id, + logging, + raw_output_data_prefix=None, ): self._stats = stats self._execution_date = execution_date diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py index 97ee73198d..55c20e7dbb 100644 --- a/flytekit/engines/flyte/engine.py +++ b/flytekit/engines/flyte/engine.py @@ -100,7 +100,8 @@ def get_workflow_execution(self, wf_exec): return FlyteWorkflowExecution(wf_exec) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def fetch_workflow_execution(self, wf_exec_id): """ @@ -112,7 +113,8 @@ def fetch_workflow_execution(self, wf_exec_id): ).client.get_execution(wf_exec_id) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def fetch_task(self, task_id): """ @@ -125,7 +127,8 @@ def fetch_task(self, task_id): ).client.get_task(task_id) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def fetch_latest_task(self, named_task): """ @@ -136,12 +139,15 @@ def fetch_latest_task(self, named_task): task_list, _ = _FlyteClientManager( _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.list_tasks_paginated( - named_task, limit=1, sort_by=_common.Sort("created_at", _common.Sort.Direction.DESCENDING), + named_task, + limit=1, + sort_by=_common.Sort("created_at", _common.Sort.Direction.DESCENDING), ) return task_list[0] if task_list else None @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def fetch_launch_plan(self, launch_plan_id): """ @@ -162,7 +168,8 @@ def fetch_launch_plan(self, launch_plan_id): ).client.get_active_launch_plan(named_entity_id) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def fetch_workflow(self, workflow_id): """ @@ -177,7 +184,8 @@ def fetch_workflow(self, workflow_id): class FlyteLaunchPlan(_common_engine.BaseLaunchPlanLauncher): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def register(self, identifier): client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client @@ -201,11 +209,18 @@ def execute( Deprecated. Use launch instead. 
""" return self.launch( - project, domain, name, inputs, notification_overrides, label_overrides, annotation_overrides, + project, + domain, + name, + inputs, + notification_overrides, + label_overrides, + annotation_overrides, ) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def launch( self, @@ -261,7 +276,8 @@ def launch( return client.get_execution(exec_id) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def update(self, identifier, state): """ @@ -275,20 +291,28 @@ def update(self, identifier, state): class FlyteWorkflow(_common_engine.BaseWorkflowExecutor): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def register(self, identifier): client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client try: sub_workflows = self.sdk_workflow.get_sub_workflows() - return client.create_workflow(identifier, _workflow_model.WorkflowSpec(self.sdk_workflow, sub_workflows,),) + return client.create_workflow( + identifier, + _workflow_model.WorkflowSpec( + self.sdk_workflow, + sub_workflows, + ), + ) except _user_exceptions.FlyteEntityAlreadyExistsException: pass class FlyteTask(_common_engine.BaseTaskExecutor): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def register(self, identifier): client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client @@ -363,7 +387,9 @@ def execute(self, inputs, context=None): exc_str = _traceback.format_exc() output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( _error_models.ContainerError( - "SYSTEM:Unknown", exc_str, _error_models.ContainerError.Kind.RECOVERABLE, + "SYSTEM:Unknown", + exc_str, + _error_models.ContainerError.Kind.RECOVERABLE, ) ) _logging.error(exc_str) @@ -372,11 +398,14 @@ def execute(self, inputs, context=None): for k, v in _six.iteritems(output_file_dict): _common_utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(temp_dir.name, k)) _data_proxy.Data.put_data( - temp_dir.name, context["output_prefix"], is_multipart=True, + temp_dir.name, + context["output_prefix"], + is_multipart=True, ) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def launch( self, @@ -420,7 +449,8 @@ def launch( ) assumable_iam_role = _sdk_config.ROLE.get() auth_role = _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, ) try: @@ -452,7 +482,8 @@ def launch( class FlyteWorkflowExecution(_common_engine.BaseWorkflowExecution): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def 
get_node_executions(self, filters=None): """ @@ -465,7 +496,8 @@ def get_node_executions(self, filters=None): } @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def sync(self): """ @@ -475,7 +507,8 @@ def sync(self): self.sdk_workflow_execution._closure = client.get_execution(self.sdk_workflow_execution.id).closure @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_inputs(self): """ @@ -498,7 +531,8 @@ def get_inputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_outputs(self): """ @@ -521,7 +555,8 @@ def get_outputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def terminate(self, cause): """ @@ -534,7 +569,8 @@ def terminate(self, cause): class FlyteNodeExecution(_common_engine.BaseNodeExecution): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_task_executions(self): """ @@ -544,7 +580,8 @@ def get_task_executions(self): return list(_iterate_task_executions(client, self.sdk_node_execution.id)) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_subworkflow_executions(self): """ @@ -553,7 +590,8 @@ def get_subworkflow_executions(self): raise NotImplementedError("Cannot retrieve sub-workflow information from a node execution yet.") @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_inputs(self): """ @@ -576,7 +614,8 @@ def get_inputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_outputs(self): """ @@ -599,7 +638,8 @@ def get_outputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def sync(self): """ @@ -611,7 +651,8 @@ def sync(self): class FlyteTaskExecution(_common_engine.BaseTaskExecution): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_inputs(self): """ @@ -634,7 +675,8 @@ def get_inputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client 
directly, will be removed by 1.0", + version="0.13.0", ) def get_outputs(self): """ @@ -657,7 +699,8 @@ def get_outputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def sync(self): """ @@ -667,7 +710,8 @@ def sync(self): self.sdk_task_execution._closure = client.get_task_execution(self.sdk_task_execution.id).closure @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_child_executions(self, filters=None): """ @@ -678,6 +722,8 @@ def get_child_executions(self, filters=None): return { v.id.node_id: v for v in _iterate_node_executions( - client, task_execution_identifier=self.sdk_task_execution.id, filters=filters, + client, + task_execution_identifier=self.sdk_task_execution.id, + filters=filters, ) } diff --git a/flytekit/engines/unit/engine.py b/flytekit/engines/unit/engine.py index 083c8d56d9..2e3c4e187b 100644 --- a/flytekit/engines/unit/engine.py +++ b/flytekit/engines/unit/engine.py @@ -87,7 +87,8 @@ def execute(self, inputs, context=None): :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity] """ with _TemporaryConfiguration( - _os.path.join(_os.path.dirname(__file__), "unit.config"), internal_overrides={"image": "unit_image"}, + _os.path.join(_os.path.dirname(__file__), "unit.config"), + internal_overrides={"image": "unit_image"}, ): with _common_utils.AutoDeletingTempDir("unit_test_dir") as working_directory: with _data_proxy.LocalWorkingDirectoryContext(working_directory): @@ -146,7 +147,8 @@ def _transform_for_user_output(self, outputs): literal_map = outputs[_sdk_constants.OUTPUT_FILE_NAME] return { name: _type_helpers.get_sdk_value_from_literal( - literal_map.literals[name], sdk_type=_type_helpers.get_sdk_type_from_literal_type(variable.type), + literal_map.literals[name], + sdk_type=_type_helpers.get_sdk_type_from_literal_type(variable.type), ).to_python_std() for name, variable in _six.iteritems(self.sdk_task.interface.outputs) } @@ -236,7 +238,11 @@ def execute_array_task(root_input_path, task, array_inputs): array_job = _array_job.ArrayJob.from_dict(task.custom) outputs = {} for job_index in _six_moves.range(0, array_job.size): - inputs_path = _os.path.join(root_input_path, _six.text_type(job_index), _sdk_constants.INPUT_FILE_NAME,) + inputs_path = _os.path.join( + root_input_path, + _six.text_type(job_index), + _sdk_constants.INPUT_FILE_NAME, + ) if inputs_path not in array_inputs: raise _system_exception.FlyteSystemAssertion( "dynamic task hasn't generated expected inputs document [{}].".format(inputs_path) diff --git a/flytekit/interfaces/data/data_proxy.py b/flytekit/interfaces/data/data_proxy.py index 18593a6099..a0babeb9ed 100644 --- a/flytekit/interfaces/data/data_proxy.py +++ b/flytekit/interfaces/data/data_proxy.py @@ -131,7 +131,10 @@ def get_data(cls, remote_path, local_path, is_multipart=False): raise _user_exception.FlyteAssertion( "Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n" "Original exception: {error_string}".format( - remote_path=remote_path, local_path=local_path, is_multipart=is_multipart, error_string=str(ex), + remote_path=remote_path, + local_path=local_path, + is_multipart=is_multipart, + error_string=str(ex), ) ) @@ -153,7 +156,10 @@ def put_data(cls, 
local_path, remote_path, is_multipart=False): raise _user_exception.FlyteAssertion( "Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n" "Original exception: {error_string}".format( - remote_path=remote_path, local_path=local_path, is_multipart=is_multipart, error_string=str(ex), + remote_path=remote_path, + local_path=local_path, + is_multipart=is_multipart, + error_string=str(ex), ) ) @@ -342,7 +348,10 @@ def get_data(self, remote_path: str, local_path: str, is_multipart=False): raise _user_exception.FlyteAssertion( "Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n" "Original exception: {error_string}".format( - remote_path=remote_path, local_path=local_path, is_multipart=is_multipart, error_string=str(ex), + remote_path=remote_path, + local_path=local_path, + is_multipart=is_multipart, + error_string=str(ex), ) ) diff --git a/flytekit/interfaces/data/gcs/gcs_proxy.py b/flytekit/interfaces/data/gcs/gcs_proxy.py index 4f79c02c65..c3ba29d635 100644 --- a/flytekit/interfaces/data/gcs/gcs_proxy.py +++ b/flytekit/interfaces/data/gcs/gcs_proxy.py @@ -126,7 +126,10 @@ def upload_directory(self, local_path, remote_path): GCSProxy._check_binary() cmd = self._maybe_with_gsutil_parallelism( - "cp", "-r", _amend_path(local_path), remote_path if remote_path.endswith("/") else remote_path + "/", + "cp", + "-r", + _amend_path(local_path), + remote_path if remote_path.endswith("/") else remote_path + "/", ) return _update_cmd_config_and_execute(cmd) diff --git a/flytekit/interfaces/stats/taggable.py b/flytekit/interfaces/stats/taggable.py index 32bc6f3ebf..42d1f93ac0 100644 --- a/flytekit/interfaces/stats/taggable.py +++ b/flytekit/interfaces/stats/taggable.py @@ -45,12 +45,18 @@ def extend_tags(self, tags): def pipeline(self): return TaggableStats( - self._client.pipeline(), self._full_prefix, prefix=self._scope_prefix, tags=dict(self._tags), + self._client.pipeline(), + self._full_prefix, + prefix=self._scope_prefix, + tags=dict(self._tags), ) def __enter__(self): return TaggableStats( - self._client.__enter__(), self._full_prefix, prefix=self._scope_prefix, tags=dict(self._tags), + self._client.__enter__(), + self._full_prefix, + prefix=self._scope_prefix, + tags=dict(self._tags), ) def get_stats(self, name, copy_tags=True): diff --git a/flytekit/models/admin/task_execution.py b/flytekit/models/admin/task_execution.py index 41d2e85c69..d0a6d4ed2d 100644 --- a/flytekit/models/admin/task_execution.py +++ b/flytekit/models/admin/task_execution.py @@ -7,7 +7,15 @@ class TaskExecutionClosure(_common.FlyteIdlEntity): def __init__( - self, phase, logs, started_at, duration, created_at, updated_at, output_uri=None, error=None, + self, + phase, + logs, + started_at, + duration, + created_at, + updated_at, + output_uri=None, + error=None, ): """ :param int phase: Enum value from flytekit.models.core.execution.TaskExecutionPhase diff --git a/flytekit/models/admin/workflow.py b/flytekit/models/admin/workflow.py index ce05f430ef..07ee1fa007 100644 --- a/flytekit/models/admin/workflow.py +++ b/flytekit/models/admin/workflow.py @@ -35,7 +35,8 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.workflow_pb2.WorkflowSpec """ return _admin_workflow.WorkflowSpec( - template=self._template.to_flyte_idl(), sub_workflows=[s.to_flyte_idl() for s in self._sub_workflows], + template=self._template.to_flyte_idl(), + sub_workflows=[s.to_flyte_idl() for s in self._sub_workflows], ) @classmethod diff --git a/flytekit/models/array_job.py 
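get_data and put_data in data_proxy.py, reformatted above, both wrap transport failures in a FlyteAssertion that records the paths and the recursion flag. The wrapping pattern with plain exceptions (the _download stub is ours):

    def _download(remote_path, local_path, recursive=False):
        raise IOError("transport not wired up in this sketch")

    def get_data(remote_path, local_path, is_multipart=False):
        try:
            _download(remote_path, local_path, recursive=is_multipart)
        except Exception as ex:
            raise AssertionError(
                "Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n"
                "Original exception: {error_string}".format(
                    remote_path=remote_path,
                    local_path=local_path,
                    is_multipart=is_multipart,
                    error_string=str(ex),
                )
            )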
b/flytekit/models/array_job.py index 718d43603c..4e4bf99cc7 100644 --- a/flytekit/models/array_job.py +++ b/flytekit/models/array_job.py @@ -71,7 +71,11 @@ def to_dict(self): :rtype: dict[T, Text] """ return _json_format.MessageToDict( - _array_job.ArrayJob(parallelism=self.parallelism, size=self.size, min_successes=self.min_successes,) + _array_job.ArrayJob( + parallelism=self.parallelism, + size=self.size, + min_successes=self.min_successes, + ) ) @classmethod @@ -82,4 +86,8 @@ def from_dict(cls, idl_dict): """ pb2_object = _json_format.Parse(_json.dumps(idl_dict), _array_job.ArrayJob()) - return cls(parallelism=pb2_object.parallelism, size=pb2_object.size, min_successes=pb2_object.min_successes,) + return cls( + parallelism=pb2_object.parallelism, + size=pb2_object.size, + min_successes=pb2_object.min_successes, + ) diff --git a/flytekit/models/core/compiler.py b/flytekit/models/core/compiler.py index 6b6e03003c..3246ee22b3 100644 --- a/flytekit/models/core/compiler.py +++ b/flytekit/models/core/compiler.py @@ -105,7 +105,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.compiler_pb2.CompiledWorkflow """ return _compiler_pb2.CompiledWorkflow( - template=self.template.to_flyte_idl(), connections=self.connections.to_flyte_idl(), + template=self.template.to_flyte_idl(), + connections=self.connections.to_flyte_idl(), ) @classmethod diff --git a/flytekit/models/core/condition.py b/flytekit/models/core/condition.py index 54b99e6b21..845b3b4f79 100644 --- a/flytekit/models/core/condition.py +++ b/flytekit/models/core/condition.py @@ -165,7 +165,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.condition_pb2.Operand """ return _condition.Operand( - primitive=self.primitive.to_flyte_idl() if self.primitive else None, var=self.var if self.var else None, + primitive=self.primitive.to_flyte_idl() if self.primitive else None, + var=self.var if self.var else None, ) @classmethod diff --git a/flytekit/models/core/execution.py b/flytekit/models/core/execution.py index 2c20edc48b..5323e0489c 100644 --- a/flytekit/models/core/execution.py +++ b/flytekit/models/core/execution.py @@ -147,7 +147,11 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.execution_pb2.ExecutionError """ - return _execution_pb2.ExecutionError(code=self.code, message=self.message, error_uri=self.error_uri,) + return _execution_pb2.ExecutionError( + code=self.code, + message=self.message, + error_uri=self.error_uri, + ) @classmethod def from_flyte_idl(cls, p): @@ -155,7 +159,11 @@ def from_flyte_idl(cls, p): :param flyteidl.core.execution_pb2.ExecutionError p: :rtype: ExecutionError """ - return cls(code=p.code, message=p.message, error_uri=p.error_uri,) + return cls( + code=p.code, + message=p.message, + error_uri=p.error_uri, + ) class TaskLog(_common.FlyteIdlEntity): @@ -219,4 +227,9 @@ def from_flyte_idl(cls, p): :param flyteidl.core.execution_pb2.TaskLog p: :rtype: TaskLog """ - return cls(uri=p.uri, name=p.name, message_format=p.message_format, ttl=p.ttl.ToTimedelta(),) + return cls( + uri=p.uri, + name=p.name, + message_format=p.message_format, + ttl=p.ttl.ToTimedelta(), + ) diff --git a/flytekit/models/core/identifier.py b/flytekit/models/core/identifier.py index 0fa6a87594..b65f7d269d 100644 --- a/flytekit/models/core/identifier.py +++ b/flytekit/models/core/identifier.py @@ -82,7 +82,13 @@ def from_flyte_idl(cls, p): :param flyteidl.core.identifier_pb2.Identifier p: :rtype: Identifier """ - return cls(resource_type=p.resource_type, project=p.project, domain=p.domain, name=p.name, version=p.version,) + return cls( + 
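ArrayJob in models/array_job.py, reformatted above, round-trips through plain dicts via protobuf's JSON mapping: MessageToDict outbound, json.dumps plus Parse inbound. The generic shape (any generated message class works as msg_cls):

    import json
    from google.protobuf import json_format

    def proto_to_dict(msg):
        return json_format.MessageToDict(msg)

    def proto_from_dict(idl_dict, msg_cls):
        # Parse expects a JSON string, hence the json.dumps hop seen in from_dict.
        return json_format.Parse(json.dumps(idl_dict), msg_cls())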
resource_type=p.resource_type, + project=p.project, + domain=p.domain, + name=p.name, + version=p.version, + ) class WorkflowExecutionIdentifier(_common_models.FlyteIdlEntity): @@ -121,7 +127,11 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.identifier_pb2.WorkflowExecutionIdentifier """ - return _identifier_pb2.WorkflowExecutionIdentifier(project=self.project, domain=self.domain, name=self.name,) + return _identifier_pb2.WorkflowExecutionIdentifier( + project=self.project, + domain=self.domain, + name=self.name, + ) @classmethod def from_flyte_idl(cls, p): @@ -129,7 +139,11 @@ def from_flyte_idl(cls, p): :param flyteidl.core.identifier_pb2.WorkflowExecutionIdentifier p: :rtype: WorkflowExecutionIdentifier """ - return cls(project=p.project, domain=p.domain, name=p.name,) + return cls( + project=p.project, + domain=p.domain, + name=p.name, + ) class NodeExecutionIdentifier(_common_models.FlyteIdlEntity): @@ -160,7 +174,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.identifier_pb2.NodeExecutionIdentifier """ return _identifier_pb2.NodeExecutionIdentifier( - node_id=self.node_id, execution_id=self.execution_id.to_flyte_idl(), + node_id=self.node_id, + execution_id=self.execution_id.to_flyte_idl(), ) @classmethod @@ -169,7 +184,10 @@ def from_flyte_idl(cls, p): :param flyteidl.core.identifier_pb2.NodeExecutionIdentifier p: :rtype: NodeExecutionIdentifier """ - return cls(node_id=p.node_id, execution_id=WorkflowExecutionIdentifier.from_flyte_idl(p.execution_id),) + return cls( + node_id=p.node_id, + execution_id=WorkflowExecutionIdentifier.from_flyte_idl(p.execution_id), + ) class TaskExecutionIdentifier(_common_models.FlyteIdlEntity): diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index f15ea7b415..9bec4768a8 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -197,7 +197,9 @@ def to_flyte_idl(self): :rtype: flyteidl.core.workflow_pb2.NodeMetadata """ node_metadata = _core_workflow.NodeMetadata( - name=self.name, retries=self.retries.to_flyte_idl(), interruptible=self.interruptible, + name=self.name, + retries=self.retries.to_flyte_idl(), + interruptible=self.interruptible, ) if self.timeout: node_metadata.timeout.FromTimedelta(self.timeout) @@ -206,7 +208,9 @@ def to_flyte_idl(self): @classmethod def from_flyte_idl(cls, pb2_object): return cls( - pb2_object.name, pb2_object.timeout.ToTimedelta(), _RetryStrategy.from_flyte_idl(pb2_object.retries), + pb2_object.name, + pb2_object.timeout.ToTimedelta(), + _RetryStrategy.from_flyte_idl(pb2_object.retries), ) @@ -541,7 +545,14 @@ def from_flyte_idl(cls, pb2_object): class WorkflowTemplate(_common.FlyteIdlEntity): def __init__( - self, id, metadata, metadata_defaults, interface, nodes, outputs, failure_node=None, + self, + id, + metadata, + metadata_defaults, + interface, + nodes, + outputs, + failure_node=None, ): """ A workflow template encapsulates all the task, branch, and subworkflow nodes to run a statically analyzable, diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 977495bae6..68ff7d44e8 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -62,7 +62,11 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.execution_pb2.ExecutionMetadata pb2_object: :return: ExecutionMetadata """ - return cls(mode=pb2_object.mode, principal=pb2_object.principal, nesting=pb2_object.nesting,) + return cls( + mode=pb2_object.mode, + principal=pb2_object.principal, + nesting=pb2_object.nesting, + ) class 
ExecutionSpec(_common_models.FlyteIdlEntity): @@ -203,7 +207,8 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.execution_pb2.LiteralMapBlob """ return _execution_pb2.LiteralMapBlob( - values=self.values.to_flyte_idl() if self.values is not None else None, uri=self.uri, + values=self.values.to_flyte_idl() if self.values is not None else None, + uri=self.uri, ) @classmethod @@ -256,7 +261,9 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.execution_pb2.Execution """ return _execution_pb2.Execution( - id=self.id.to_flyte_idl(), closure=self.closure.to_flyte_idl(), spec=self.spec.to_flyte_idl(), + id=self.id.to_flyte_idl(), + closure=self.closure.to_flyte_idl(), + spec=self.spec.to_flyte_idl(), ) @classmethod diff --git a/flytekit/models/interface.py b/flytekit/models/interface.py index d824b6ba60..87127bc13c 100644 --- a/flytekit/models/interface.py +++ b/flytekit/models/interface.py @@ -45,7 +45,10 @@ def from_flyte_idl(cls, variable_proto): :param flyteidl.core.interface_pb2.Variable variable_proto: :rtype: Variable """ - return cls(type=_types.LiteralType.from_flyte_idl(variable_proto.type), description=variable_proto.description,) + return cls( + type=_types.LiteralType.from_flyte_idl(variable_proto.type), + description=variable_proto.description, + ) class VariableMap(_common.FlyteIdlEntity): diff --git a/flytekit/models/launch_plan.py b/flytekit/models/launch_plan.py index ba92b0f7d7..d5efac5857 100644 --- a/flytekit/models/launch_plan.py +++ b/flytekit/models/launch_plan.py @@ -370,7 +370,9 @@ def to_flyte_idl(self): else _identifier.Identifier(_identifier.ResourceType.LAUNCH_PLAN, None, None, None, None) ) return _launch_plan.LaunchPlan( - id=identifier.to_flyte_idl(), spec=self.spec.to_flyte_idl(), closure=self.closure.to_flyte_idl(), + id=identifier.to_flyte_idl(), + spec=self.spec.to_flyte_idl(), + closure=self.closure.to_flyte_idl(), ) @classmethod diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index 3b8405e106..c1f871bf21 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -45,7 +45,13 @@ def from_flyte_idl(cls, pb2_object): class Primitive(_common.FlyteIdlEntity): def __init__( - self, integer=None, float_value=None, string_value=None, boolean=None, datetime=None, duration=None, + self, + integer=None, + float_value=None, + string_value=None, + boolean=None, + datetime=None, + duration=None, ): """ This object proxies the primitives supported by the Flyte IDL system. Only one value can be set. @@ -134,7 +140,10 @@ def to_flyte_idl(self): :rtype: flyteidl.core.literals_pb2.Primitive """ primitive = _literals_pb2.Primitive( - integer=self.integer, float_value=self.float_value, string_value=self.string_value, boolean=self.boolean, + integer=self.integer, + float_value=self.float_value, + string_value=self.string_value, + boolean=self.boolean, ) if self.datetime is not None: # Convert to UTC and remove timezone so protobuf behaves. 
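Aside: the comment above leans on how protobuf's Timestamp API interprets naive datetimes as UTC. A minimal sketch of that normalization (the helper name is illustrative, not flytekit's actual code):

from datetime import datetime, timezone
from google.protobuf.timestamp_pb2 import Timestamp

def to_proto_timestamp(dt: datetime) -> Timestamp:
    # Timestamp.FromDatetime treats naive datetimes as UTC, so aware
    # datetimes are normalized to UTC and stripped of tzinfo first.
    if dt.tzinfo is not None:
        dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
    ts = Timestamp()
    ts.FromDatetime(dt)
    return ts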
@@ -434,7 +443,8 @@ def to_literal_model(self): """ if self.promise: raise _user_exceptions.FlyteValueException( - self.promise, "Cannot convert BindingData to a Literal because " "it has a promise.", + self.promise, + "Cannot convert BindingData to a Literal because " "it has a promise.", ) elif self.scalar: return Literal(scalar=self.scalar) diff --git a/flytekit/models/matchable_resource.py b/flytekit/models/matchable_resource.py index 8b3e9d144a..64247f5bf5 100644 --- a/flytekit/models/matchable_resource.py +++ b/flytekit/models/matchable_resource.py @@ -73,7 +73,9 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.matchable_resource_pb2.ClusterResourceAttributes """ - return _matchable_resource.ClusterResourceAttributes(attributes=self.attributes,) + return _matchable_resource.ClusterResourceAttributes( + attributes=self.attributes, + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -81,7 +83,9 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.matchable_resource_pb2.ClusterResourceAttributes pb2_object: :rtype: ClusterResourceAttributes """ - return cls(attributes=pb2_object.attributes,) + return cls( + attributes=pb2_object.attributes, + ) class ExecutionQueueAttributes(_common.FlyteIdlEntity): @@ -104,7 +108,9 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.matchable_resource_pb2.ExecutionQueueAttributes """ - return _matchable_resource.ExecutionQueueAttributes(tags=self.tags,) + return _matchable_resource.ExecutionQueueAttributes( + tags=self.tags, + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -112,7 +118,9 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.matchable_resource_pb2.ExecutionQueueAttributes pb2_object: :rtype: ExecutionQueueAttributes """ - return cls(tags=pb2_object.tags,) + return cls( + tags=pb2_object.tags, + ) class ExecutionClusterLabel(_common.FlyteIdlEntity): @@ -135,7 +143,9 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.matchable_resource_pb2.ExecutionClusterLabel """ - return _matchable_resource.ExecutionClusterLabel(value=self.value,) + return _matchable_resource.ExecutionClusterLabel( + value=self.value, + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -143,7 +153,9 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.matchable_resource_pb2.ExecutionClusterLabel pb2_object: :rtype: ExecutionClusterLabel """ - return cls(value=pb2_object.value,) + return cls( + value=pb2_object.value, + ) class PluginOverride(_common.FlyteIdlEntity): @@ -201,7 +213,9 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.matchable_resource_pb2.PluginOverride """ return _matchable_resource.PluginOverride( - task_type=self.task_type, plugin_id=self.plugin_id, missing_plugin_behavior=self.missing_plugin_behavior, + task_type=self.task_type, + plugin_id=self.plugin_id, + missing_plugin_behavior=self.missing_plugin_behavior, ) @classmethod diff --git a/flytekit/models/named_entity.py b/flytekit/models/named_entity.py index 80d70aa35c..63dd598d98 100644 --- a/flytekit/models/named_entity.py +++ b/flytekit/models/named_entity.py @@ -57,7 +57,11 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifier """ - return _common.NamedEntityIdentifier(project=self.project, domain=self.domain, name=self.name,) + return _common.NamedEntityIdentifier( + project=self.project, + domain=self.domain, + name=self.name, + ) @classmethod def from_flyte_idl(cls, p): @@ -65,7 +69,11 @@ def from_flyte_idl(cls, p): :param flyteidl.core.common_pb2.NamedEntityIdentifier p: :rtype: Identifier """ - return 
cls(project=p.project, domain=p.domain, name=p.name,) + return cls( + project=p.project, + domain=p.domain, + name=p.name, + ) class NamedEntityMetadata(_common_models.FlyteIdlEntity): @@ -97,7 +105,10 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.common_pb2.NamedEntityMetadata """ - return _common.NamedEntityMetadata(description=self.description, state=self.state,) + return _common.NamedEntityMetadata( + description=self.description, + state=self.state, + ) @classmethod def from_flyte_idl(cls, p): @@ -105,4 +116,7 @@ def from_flyte_idl(cls, p): :param flyteidl.core.common_pb2.NamedEntityMetadata p: :rtype: Identifier """ - return cls(description=p.description, state=p.state,) + return cls( + description=p.description, + state=p.state, + ) diff --git a/flytekit/models/node_execution.py b/flytekit/models/node_execution.py index b0c103891b..4550fb6443 100644 --- a/flytekit/models/node_execution.py +++ b/flytekit/models/node_execution.py @@ -121,7 +121,9 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.node_execution_pb2.NodeExecution """ return _node_execution_pb2.NodeExecution( - id=self.id.to_flyte_idl(), input_uri=self.input_uri, closure=self.closure.to_flyte_idl(), + id=self.id.to_flyte_idl(), + input_uri=self.input_uri, + closure=self.closure.to_flyte_idl(), ) @classmethod diff --git a/flytekit/models/presto.py b/flytekit/models/presto.py index 04c4b22b41..2f5f998153 100644 --- a/flytekit/models/presto.py +++ b/flytekit/models/presto.py @@ -55,7 +55,10 @@ def to_flyte_idl(self): :rtype: _presto.PrestoQuery """ return _presto.PrestoQuery( - routing_group=self._routing_group, catalog=self._catalog, schema=self._schema, statement=self._statement, + routing_group=self._routing_group, + catalog=self._catalog, + schema=self._schema, + statement=self._statement, ) @classmethod diff --git a/flytekit/models/qubole.py b/flytekit/models/qubole.py index 3464158ad3..2247d6e5fa 100644 --- a/flytekit/models/qubole.py +++ b/flytekit/models/qubole.py @@ -51,7 +51,11 @@ def from_flyte_idl(cls, pb2_object): :param _qubole.HiveQuery pb2_object: :return: HiveQuery """ - return cls(query=pb2_object.query, timeout_sec=pb2_object.timeout_sec, retry_count=pb2_object.retryCount,) + return cls( + query=pb2_object.query, + timeout_sec=pb2_object.timeout_sec, + retry_count=pb2_object.retryCount, + ) class HiveQueryCollection(_common.FlyteIdlEntity): diff --git a/flytekit/models/sagemaker/hpo_job.py b/flytekit/models/sagemaker/hpo_job.py index 966318adab..ea484f26cf 100644 --- a/flytekit/models/sagemaker/hpo_job.py +++ b/flytekit/models/sagemaker/hpo_job.py @@ -18,7 +18,9 @@ class HyperparameterTuningObjective(_common.FlyteIdlEntity): """ def __init__( - self, objective_type: int, metric_name: str, + self, + objective_type: int, + metric_name: str, ): self._objective_type = objective_type self._metric_name = metric_name @@ -44,13 +46,17 @@ def metric_name(self) -> str: def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningObjective: return _pb2_hpo_job.HyperparameterTuningObjective( - objective_type=self.objective_type, metric_name=self._metric_name, + objective_type=self.objective_type, + metric_name=self._metric_name, ) @classmethod def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningObjective): - return cls(objective_type=pb2_object.objective_type, metric_name=pb2_object.metric_name,) + return cls( + objective_type=pb2_object.objective_type, + metric_name=pb2_object.metric_name, + ) class HyperparameterTuningStrategy: diff --git a/flytekit/models/sagemaker/parameter_ranges.py 
b/flytekit/models/sagemaker/parameter_ranges.py index e9016f2a1d..4328749d72 100644 --- a/flytekit/models/sagemaker/parameter_ranges.py +++ b/flytekit/models/sagemaker/parameter_ranges.py @@ -15,7 +15,10 @@ class HyperparameterScalingType(object): class ContinuousParameterRange(_common.FlyteIdlEntity): def __init__( - self, max_value: float, min_value: float, scaling_type: int, + self, + max_value: float, + min_value: float, + scaling_type: int, ): """ @@ -57,7 +60,9 @@ def to_flyte_idl(self) -> _idl_parameter_ranges.ContinuousParameterRange: """ return _idl_parameter_ranges.ContinuousParameterRange( - max_value=self._max_value, min_value=self._min_value, scaling_type=self.scaling_type, + max_value=self._max_value, + min_value=self._min_value, + scaling_type=self.scaling_type, ) @classmethod @@ -68,13 +73,18 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ContinuousParameterRan :rtype: ContinuousParameterRange """ return cls( - max_value=pb2_object.max_value, min_value=pb2_object.min_value, scaling_type=pb2_object.scaling_type, + max_value=pb2_object.max_value, + min_value=pb2_object.min_value, + scaling_type=pb2_object.scaling_type, ) class IntegerParameterRange(_common.FlyteIdlEntity): def __init__( - self, max_value: int, min_value: int, scaling_type: int, + self, + max_value: int, + min_value: int, + scaling_type: int, ): """ :param int max_value: @@ -113,7 +123,9 @@ def to_flyte_idl(self) -> _idl_parameter_ranges.IntegerParameterRange: :rtype: _idl_parameter_ranges.IntegerParameterRange """ return _idl_parameter_ranges.IntegerParameterRange( - max_value=self._max_value, min_value=self._min_value, scaling_type=self.scaling_type, + max_value=self._max_value, + min_value=self._min_value, + scaling_type=self.scaling_type, ) @classmethod @@ -124,13 +136,16 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.IntegerParameterRange) :rtype: IntegerParameterRange """ return cls( - max_value=pb2_object.max_value, min_value=pb2_object.min_value, scaling_type=pb2_object.scaling_type, + max_value=pb2_object.max_value, + min_value=pb2_object.min_value, + scaling_type=pb2_object.scaling_type, ) class CategoricalParameterRange(_common.FlyteIdlEntity): def __init__( - self, values: List[str], + self, + values: List[str], ): """ @@ -163,7 +178,8 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.CategoricalParameterRa class ParameterRanges(_common.FlyteIdlEntity): def __init__( - self, parameter_range_map: Dict[str, _common.FlyteIdlEntity], + self, + parameter_range_map: Dict[str, _common.FlyteIdlEntity], ): self._parameter_range_map = parameter_range_map @@ -188,7 +204,9 @@ def to_flyte_idl(self) -> _idl_parameter_ranges.ParameterRanges: ), ) - return _idl_parameter_ranges.ParameterRanges(parameter_range_map=converted,) + return _idl_parameter_ranges.ParameterRanges( + parameter_range_map=converted, + ) @classmethod def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ParameterRanges): @@ -206,7 +224,9 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ParameterRanges): else: converted[k] = CategoricalParameterRange.from_flyte_idl(v.categorical_parameter_range) - return cls(parameter_range_map=converted,) + return cls( + parameter_range_map=converted, + ) class ParameterRangeOneOf(_common.FlyteIdlEntity): diff --git a/flytekit/models/sagemaker/training_job.py b/flytekit/models/sagemaker/training_job.py index faf4dd7e29..674effcbc4 100644 --- a/flytekit/models/sagemaker/training_job.py +++ b/flytekit/models/sagemaker/training_job.py @@ -96,7 +96,9 @@ def 
from_flyte_idl(cls, pb2_object: _training_job_pb2.TrainingJobResourceConfig) class MetricDefinition(_common.FlyteIdlEntity): def __init__( - self, name: str, regex: str, + self, + name: str, + regex: str, ): self._name = name self._regex = regex @@ -123,7 +125,10 @@ def to_flyte_idl(self) -> _training_job_pb2.MetricDefinition: :rtype: _training_job_pb2.MetricDefinition """ - return _training_job_pb2.MetricDefinition(name=self.name, regex=self.regex,) + return _training_job_pb2.MetricDefinition( + name=self.name, + regex=self.regex, + ) @classmethod def from_flyte_idl(cls, pb2_object: _training_job_pb2.MetricDefinition): @@ -132,7 +137,10 @@ def from_flyte_idl(cls, pb2_object: _training_job_pb2.MetricDefinition): :param pb2_object: _training_job_pb2.MetricDefinition :rtype: MetricDefinition """ - return cls(name=pb2_object.name, regex=pb2_object.regex,) + return cls( + name=pb2_object.name, + regex=pb2_object.regex, + ) # TODO Convert to Enum @@ -270,7 +278,9 @@ def from_flyte_idl(cls, pb2_object: _training_job_pb2.AlgorithmSpecification): class TrainingJob(_common.FlyteIdlEntity): def __init__( - self, algorithm_specification: AlgorithmSpecification, training_job_resource_config: TrainingJobResourceConfig, + self, + algorithm_specification: AlgorithmSpecification, + training_job_resource_config: TrainingJobResourceConfig, ): self._algorithm_specification = algorithm_specification self._training_job_resource_config = training_job_resource_config diff --git a/flytekit/models/security.py b/flytekit/models/security.py index dbe5196a7d..e4ea655e22 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -68,11 +68,17 @@ class OAuth2Client(_common.FlyteIdlEntity): client_secret: str def to_flyte_idl(self) -> _sec.OAuth2Client: - return _sec.OAuth2Client(client_id=self.client_id, client_secret=self.client_secret,) + return _sec.OAuth2Client( + client_id=self.client_id, + client_secret=self.client_secret, + ) @classmethod def from_flyte_idl(cls, pb2_object: _sec.OAuth2Client) -> "OAuth2Client": - return cls(client_id=pb2_object.client_id, client_secret=pb2_object.client_secret,) + return cls( + client_id=pb2_object.client_id, + client_secret=pb2_object.client_secret, + ) @dataclass diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 3f517d0aee..8e34b90cf7 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -100,7 +100,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.tasks_pb2.Resources """ return _core_task.Resources( - requests=[r.to_flyte_idl() for r in self.requests], limits=[r.to_flyte_idl() for r in self.limits], + requests=[r.to_flyte_idl() for r in self.requests], + limits=[r.to_flyte_idl() for r in self.limits], ) @classmethod @@ -172,7 +173,14 @@ def from_flyte_idl(cls, pb2_object): class TaskMetadata(_common.FlyteIdlEntity): def __init__( - self, discoverable, runtime, timeout, retries, interruptible, discovery_version, deprecated_error_message, + self, + discoverable, + runtime, + timeout, + retries, + interruptible, + discovery_version, + deprecated_error_message, ): """ Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, @@ -496,7 +504,10 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.task_pb2.Task """ - return _admin_task.Task(closure=self.closure.to_flyte_idl(), id=self.id.to_flyte_idl(),) + return _admin_task.Task( + closure=self.closure.to_flyte_idl(), + id=self.id.to_flyte_idl(), + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -570,7 +581,13 @@ def 
from_flyte_idl(cls, pb2_object): class SparkJob(_common.FlyteIdlEntity): def __init__( - self, spark_type, application_file, main_class, spark_conf, hadoop_conf, executor_path, + self, + spark_type, + application_file, + main_class, + spark_conf, + hadoop_conf, + executor_path, ): """ This defines a SparkJob target. It will execute the appropriate SparkJob. @@ -730,7 +747,10 @@ def to_flyte_idl(self) -> _core_task.IOStrategy: def from_flyte_idl(cls, pb2_object: _core_task.IOStrategy): if pb2_object is None: return None - return cls(download_mode=pb2_object.download_mode, upload_mode=pb2_object.upload_mode,) + return cls( + download_mode=pb2_object.download_mode, + upload_mode=pb2_object.upload_mode, + ) class DataLoadingConfig(_common.FlyteIdlEntity): @@ -931,7 +951,10 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.task_pb2.Task pb2_object: :rtype: Container """ - return cls(pod_spec=pb2_object.pod_spec, primary_container_name=pb2_object.primary_container_name,) + return cls( + pod_spec=pb2_object.pod_spec, + primary_container_name=pb2_object.primary_container_name, + ) class PyTorchJob(_common.FlyteIdlEntity): @@ -943,11 +966,15 @@ def workers_count(self): return self._workers_count def to_flyte_idl(self): - return _pytorch_task.DistributedPyTorchTrainingTask(workers=self.workers_count,) + return _pytorch_task.DistributedPyTorchTrainingTask( + workers=self.workers_count, + ) @classmethod def from_flyte_idl(cls, pb2_object): - return cls(workers_count=pb2_object.workers,) + return cls( + workers_count=pb2_object.workers, + ) class TensorFlowJob(_common.FlyteIdlEntity): diff --git a/flytekit/models/types.py b/flytekit/models/types.py index 7d68645c7a..69cfb75dd2 100644 --- a/flytekit/models/types.py +++ b/flytekit/models/types.py @@ -100,7 +100,13 @@ def from_flyte_idl(cls, proto): class LiteralType(_common.FlyteIdlEntity): def __init__( - self, simple=None, schema=None, collection_type=None, map_value_type=None, blob=None, metadata=None, + self, + simple=None, + schema=None, + collection_type=None, + map_value_type=None, + blob=None, + metadata=None, ): """ Only one of the kwargs may be set. 
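Aside: the "Only one of the kwargs may be set" contract here mirrors a protobuf oneof. A minimal sketch of such a guard (illustrative, not flytekit's actual implementation):

def ensure_at_most_one(**kwargs):
    # Like a protobuf oneof: at most one field may carry a value.
    set_fields = [k for k, v in kwargs.items() if v is not None]
    if len(set_fields) > 1:
        raise ValueError("Only one of the kwargs may be set, got: {}".format(set_fields))

ensure_at_most_one(simple=1, schema=None, blob=None)  # fine
# ensure_at_most_one(simple=1, blob=object())  # would raise ValueError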
@@ -258,7 +264,10 @@ def __init__(self, failed_node_id: str, message: str): self._failed_node_id = failed_node_id def to_flyte_idl(self) -> _types_pb2.Error: - return _types_pb2.Error(message=self._message, failed_node_id=self._failed_node_id,) + return _types_pb2.Error( + message=self._message, + failed_node_id=self._failed_node_id, + ) @classmethod def from_flyte_idl(cls, pb2_object: _types_pb2.Error) -> "Error": diff --git a/flytekit/models/workflow_closure.py b/flytekit/models/workflow_closure.py index 412a52e958..fbf0b08688 100644 --- a/flytekit/models/workflow_closure.py +++ b/flytekit/models/workflow_closure.py @@ -33,7 +33,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.workflow_closure_pb2.WorkflowClosure """ return _workflow_closure_pb2.WorkflowClosure( - workflow=self.workflow.to_flyte_idl(), tasks=[t.to_flyte_idl() for t in self.tasks], + workflow=self.workflow.to_flyte_idl(), + tasks=[t.to_flyte_idl() for t in self.tasks], ) @classmethod diff --git a/flytekit/plugins/__init__.py b/flytekit/plugins/__init__.py index 5c928bbb60..e20bf5eb36 100644 --- a/flytekit/plugins/__init__.py +++ b/flytekit/plugins/__init__.py @@ -30,7 +30,9 @@ _lazy_loader.LazyLoadPlugin("sidecar", ["k8s-proto>=0.0.3,<1.0.0"], [k8s, flyteidl]) _lazy_loader.LazyLoadPlugin( - "schema", ["numpy>=1.14.0,<2.0.0", "pandas>=0.22.0,<2.0.0", "pyarrow>=0.11.0,<1.0.0"], [numpy, pandas], + "schema", + ["numpy>=1.14.0,<2.0.0", "pandas>=0.22.0,<2.0.0", "pyarrow>=0.11.0,<1.0.0"], + [numpy, pandas], ) _lazy_loader.LazyLoadPlugin("hive_sensor", ["hmsclient>=0.0.1,<1.0.0"], [hmsclient]) diff --git a/flytekit/sdk/tasks.py b/flytekit/sdk/tasks.py index 7ecd5682ab..1e6dcd5e4c 100644 --- a/flytekit/sdk/tasks.py +++ b/flytekit/sdk/tasks.py @@ -44,8 +44,11 @@ def my_task(wf_params, in1, in2, out1, out2): def apply_inputs_wrapper(task): if not isinstance(task, _task.SdkTask): - additional_msg = "Inputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( - task.__module__, task.__name__ if hasattr(task, "__name__") else "", + additional_msg = ( + "Inputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( + task.__module__, + task.__name__ if hasattr(task, "__name__") else "", + ) ) raise _user_exceptions.FlyteTypeException( expected_type=_sdk_runnable_tasks.SdkRunnableTask, @@ -94,8 +97,11 @@ def apply_outputs_wrapper(task): if not isinstance(task, _sdk_runnable_tasks.SdkRunnableTask) and not isinstance( task, _nb_tasks.SdkNotebookTask ): - additional_msg = "Outputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( - task.__module__, task.__name__ if hasattr(task, "__name__") else "", + additional_msg = ( + "Outputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( + task.__module__, + task.__name__ if hasattr(task, "__name__") else "", + ) ) raise _user_exceptions.FlyteTypeException( expected_type=_sdk_runnable_tasks.SdkRunnableTask, diff --git a/flytekit/sdk/workflow.py b/flytekit/sdk/workflow.py index d4cf8aedea..34561dc592 100644 --- a/flytekit/sdk/workflow.py +++ b/flytekit/sdk/workflow.py @@ -40,7 +40,10 @@ def __init__(self, value, sdk_type=None, help=None): this value be provided as the SDK might not always be able to infer the correct type. 
""" super(Output, self).__init__( - "", value, sdk_type=_type_helpers.python_std_to_sdk_type(sdk_type) if sdk_type else None, help=help, + "", + value, + sdk_type=_type_helpers.python_std_to_sdk_type(sdk_type) if sdk_type else None, + help=help, ) diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index 417cfee5c4..00ef36855e 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -105,6 +105,7 @@ def download_distribution(additional_distribution: str, destination: str): # This will overwrite the existing user flyte workflow code in the current working code dir. result = _subprocess.run( - ["tar", "-xvf", _os.path.join(destination, tarfile_name), "-C", destination], stdout=_subprocess.PIPE, + ["tar", "-xvf", _os.path.join(destination, tarfile_name), "-C", destination], + stdout=_subprocess.PIPE, ) result.check_returncode() diff --git a/flytekit/tools/module_loader.py b/flytekit/tools/module_loader.py index 7983ecae8b..7149e532af 100644 --- a/flytekit/tools/module_loader.py +++ b/flytekit/tools/module_loader.py @@ -139,7 +139,11 @@ def iterate(): def iterate_registerable_entities_in_order( - pkgs, local_source_root=None, ignore_entities=None, include_entities=None, detect_unreferenced_entities=True, + pkgs, + local_source_root=None, + ignore_entities=None, + include_entities=None, + detect_unreferenced_entities=True, ): """ This function will iterate all discovered entities in the given package list. It will then attempt to diff --git a/flytekit/type_engines/default/flyte.py b/flytekit/type_engines/default/flyte.py index b930a63d97..0ec3a4c982 100644 --- a/flytekit/type_engines/default/flyte.py +++ b/flytekit/type_engines/default/flyte.py @@ -21,7 +21,8 @@ def _load_type_from_tag(tag: str) -> Type: if "." not in tag: raise _user_exceptions.FlyteValueException( - tag, "Protobuf tag must include at least one '.' to delineate package and object name.", + tag, + "Protobuf tag must include at least one '.' 
to delineate package and object name.", ) module, name = tag.rsplit(".", 1) diff --git a/plugins/awssagemaker/flytekitplugins/awssagemaker/training.py b/plugins/awssagemaker/flytekitplugins/awssagemaker/training.py index 392ebebb92..8ff51b1d15 100644 --- a/plugins/awssagemaker/flytekitplugins/awssagemaker/training.py +++ b/plugins/awssagemaker/flytekitplugins/awssagemaker/training.py @@ -49,7 +49,10 @@ class SagemakerBuiltinAlgorithmsTask(PythonTask[SagemakerTrainingJobConfig]): OUTPUT_TYPE = TypeVar("tar.gz") def __init__( - self, name: str, task_config: SagemakerTrainingJobConfig, **kwargs, + self, + name: str, + task_config: SagemakerTrainingJobConfig, + **kwargs, ): """ Args: @@ -75,7 +78,11 @@ def __init__( outputs=kwtypes(model=FlyteFile[self.OUTPUT_TYPE]), ) super().__init__( - self._SAGEMAKER_TRAINING_JOB_TASK, name, interface=interface, task_config=task_config, **kwargs, + self._SAGEMAKER_TRAINING_JOB_TASK, + name, + interface=interface, + task_config=task_config, + **kwargs, ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: @@ -110,7 +117,10 @@ class SagemakerCustomTrainingTask(PythonFunctionTask[SagemakerTrainingJobConfig] _SAGEMAKER_CUSTOM_TRAINING_JOB_TASK = "sagemaker_custom_training_job_task" def __init__( - self, task_config: SagemakerTrainingJobConfig, task_function: Callable, **kwargs, + self, + task_config: SagemakerTrainingJobConfig, + task_function: Callable, + **kwargs, ): super().__init__( task_config=task_config, @@ -145,7 +155,7 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: exec_state = FlyteContext.current_context().execution_state if exec_state and exec_state.mode == ExecutionState.Mode.TASK_EXECUTION: """ - This mode indicates we are actually in a remote execute environment (within sagemaker in this case) + This mode indicates we are actually in a remote execution environment (within SageMaker in this case) """ dist_ctx = DistributedTrainingContext.from_env() else: diff --git a/plugins/hive/flytekitplugins/hive/task.py b/plugins/hive/flytekitplugins/hive/task.py index 31c08fd46f..9552268265 100644 --- a/plugins/hive/flytekitplugins/hive/task.py +++ b/plugins/hive/flytekitplugins/hive/task.py @@ -87,7 +87,11 @@ def tags(self) -> List[str]: def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: # timeout_sec and retry_count will become deprecated, please use timeout and retry settings on the Task query = HiveQuery(query=self.query_template, timeout_sec=0, retry_count=0) - job = QuboleHiveJob(query=query, cluster_label=self.cluster_label, tags=self.tags,) + job = QuboleHiveJob( + query=query, + cluster_label=self.cluster_label, + tags=self.tags, + ) return MessageToDict(job.to_flyte_idl()) @@ -116,10 +120,10 @@ def __init__( **kwargs, ): """ - Args: - select_query: Singular query that returns a Tabular dataset - stage_query: optional query that should be executed before the actual ``select_query``.
This can usually + be used for setting memory or an alternate execution engine like :ref:`tez`_. """ query_template = HiveSelectTask._HIVE_QUERY_FORMATTER.format( stage_query_str=stage_query or "", select_query_str=select_query.strip().strip(";") diff --git a/plugins/papermill/flytekitplugins/papermill/task.py b/plugins/papermill/flytekitplugins/papermill/task.py index 197681cdc6..c8ffa825ab 100644 --- a/plugins/papermill/flytekitplugins/papermill/task.py +++ b/plugins/papermill/flytekitplugins/papermill/task.py @@ -147,8 +147,8 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: @staticmethod def extract_outputs(nb: str) -> LiteralMap: """ - Parse Outputs from Notebook. - This looks for a cell, with the tag "outputs" to be present. + Parse outputs from the notebook. + This looks for a cell tagged "outputs" to be present. """ with open(nb) as json_file: data = json.load(json_file) @@ -164,9 +164,9 @@ def extract_outputs(nb: str) -> LiteralMap: @staticmethod def render_nb_html(from_nb: str, to: str): """ - render output notebook to html - We are using nbconvert htmlexporter and its classic template - later about how to customize the exporter further. + Render the output notebook to HTML. + We use nbconvert's HTMLExporter with its classic template; + see the nbconvert documentation for how to customize the exporter further. """ html_exporter = HTMLExporter() html_exporter.template_name = "classic" @@ -213,10 +213,10 @@ def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: def record_outputs(**kwargs) -> str: """ - Use this method to record outputs from a notebook. - It will convert all outputs to a Flyte understandable format. For Files, Directories, please use FlyteFile or - FlyteDirectory, or wrap up your paths in these decorators. - """ + Use this method to record outputs from a notebook. + It will convert all outputs to a format Flyte understands. For files and directories, please use FlyteFile or + FlyteDirectory, or wrap your paths in these types.
+ """ if kwargs is None: return "" diff --git a/plugins/pod/flytekitplugins/pod/task.py b/plugins/pod/flytekitplugins/pod/task.py index 5449b7df5b..9733d5596c 100644 --- a/plugins/pod/flytekitplugins/pod/task.py +++ b/plugins/pod/flytekitplugins/pod/task.py @@ -33,7 +33,11 @@ def primary_container_name(self) -> str: class PodFunctionTask(PythonFunctionTask[Pod]): def __init__(self, task_config: Pod, task_function: Callable, **kwargs): super(PodFunctionTask, self).__init__( - task_config=task_config, task_type="sidecar", task_function=task_function, task_type_version=1, **kwargs, + task_config=task_config, + task_type="sidecar", + task_function=task_function, + task_type_version=1, + **kwargs, ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: diff --git a/plugins/spark/flytekitplugins/spark/task.py b/plugins/spark/flytekitplugins/spark/task.py index 7d7253f149..2415392376 100644 --- a/plugins/spark/flytekitplugins/spark/task.py +++ b/plugins/spark/flytekitplugins/spark/task.py @@ -70,7 +70,10 @@ class PysparkFunctionTask(PythonFunctionTask[Spark]): def __init__(self, task_config: Spark, task_function: Callable, **kwargs): super(PysparkFunctionTask, self).__init__( - task_config=task_config, task_type=self._SPARK_TASK_TYPE, task_function=task_function, **kwargs, + task_config=task_config, + task_type=self._SPARK_TASK_TYPE, + task_function=task_function, + **kwargs, ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: diff --git a/plugins/tests/awssagemaker/test_hpo.py b/plugins/tests/awssagemaker/test_hpo.py index 0762109e51..741f952d8b 100644 --- a/plugins/tests/awssagemaker/test_hpo.py +++ b/plugins/tests/awssagemaker/test_hpo.py @@ -25,13 +25,21 @@ def test_hpo_for_builtin(): name="builtin-trainer", task_config=SagemakerTrainingJobConfig( training_job_resource_config=TrainingJobResourceConfig( - instance_count=1, instance_type="ml-xlarge", volume_size_in_gb=1, + instance_count=1, + instance_type="ml-xlarge", + volume_size_in_gb=1, + ), + algorithm_specification=AlgorithmSpecification( + algorithm_name=AlgorithmName.XGBOOST, ), - algorithm_specification=AlgorithmSpecification(algorithm_name=AlgorithmName.XGBOOST,), ), ) - hpo = SagemakerHPOTask(name="test", task_config=HPOJob(10, 10, ["x"]), training_task=trainer,) + hpo = SagemakerHPOTask( + name="test", + task_config=HPOJob(10, 10, ["x"]), + training_task=trainer, + ) assert hpo.python_interface.inputs.keys() == { "static_hyperparameters", @@ -59,7 +67,8 @@ def test_hpo_for_builtin(): hyperparameter_tuning_job_config=HyperparameterTuningJobConfig( tuning_strategy=1, tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, metric_name="x", + objective_type=HyperparameterTuningObjectiveType.MINIMIZE, + metric_name="x", ), training_job_early_stopping_type=TrainingJobEarlyStoppingType.OFF, ), @@ -73,7 +82,8 @@ def test_hpoconfig_transformer(): o = HyperparameterTuningJobConfig( tuning_strategy=1, tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, metric_name="x", + objective_type=HyperparameterTuningObjectiveType.MINIMIZE, + metric_name="x", ), training_job_early_stopping_type=TrainingJobEarlyStoppingType.OFF, ) diff --git a/plugins/tests/awssagemaker/test_training.py b/plugins/tests/awssagemaker/test_training.py index 22ff1cb104..b2a6a14ec3 100644 --- a/plugins/tests/awssagemaker/test_training.py +++ b/plugins/tests/awssagemaker/test_training.py @@ -34,9 +34,13 @@ def 
test_builtin_training(): name="builtin-trainer", task_config=SagemakerTrainingJobConfig( training_job_resource_config=TrainingJobResourceConfig( - instance_count=1, instance_type="ml-xlarge", volume_size_in_gb=1, + instance_count=1, + instance_type="ml-xlarge", + volume_size_in_gb=1, + ), + algorithm_specification=AlgorithmSpecification( + algorithm_name=AlgorithmName.XGBOOST, ), - algorithm_specification=AlgorithmSpecification(algorithm_name=AlgorithmName.XGBOOST,), ), ) @@ -62,8 +66,13 @@ def test_builtin_training(): def test_custom_training(): @task( task_config=SagemakerTrainingJobConfig( - training_job_resource_config=TrainingJobResourceConfig(instance_type="ml-xlarge", volume_size_in_gb=1,), - algorithm_specification=AlgorithmSpecification(algorithm_name=AlgorithmName.CUSTOM,), + training_job_resource_config=TrainingJobResourceConfig( + instance_type="ml-xlarge", + volume_size_in_gb=1, + ), + algorithm_specification=AlgorithmSpecification( + algorithm_name=AlgorithmName.CUSTOM, + ), ) ) def my_custom_trainer(x: int) -> int: @@ -91,7 +100,9 @@ def test_distributed_custom_training(): instance_count=2, # Indicates distributed training distributed_protocol=DistributedProtocol.MPI, ), - algorithm_specification=AlgorithmSpecification(algorithm_name=AlgorithmName.CUSTOM,), + algorithm_specification=AlgorithmSpecification( + algorithm_name=AlgorithmName.CUSTOM, + ), ) ) def my_custom_trainer(x: int) -> int: diff --git a/plugins/tests/pod/test_pod.py b/plugins/tests/pod/test_pod.py index e4d86984d3..59adc0313d 100644 --- a/plugins/tests/pod/test_pod.py +++ b/plugins/tests/pod/test_pod.py @@ -14,9 +14,16 @@ def get_pod_spec(): - a_container = V1Container(name="a container",) + a_container = V1Container( + name="a container", + ) a_container.command = ["fee", "fi", "fo", "fum"] - a_container.volume_mounts = [V1VolumeMount(name="volume mount", mount_path="some/where",)] + a_container.volume_mounts = [ + V1VolumeMount( + name="volume mount", + mount_path="some/where", + ) + ] pod_spec = V1PodSpec(restart_policy="OnFailure", containers=[a_container, V1Container(name="another container")]) return pod_spec diff --git a/pyproject.toml b/pyproject.toml index 339d113e41..f7e217c339 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,14 @@ [tool.black] line-length = 120 -exclude = ''' -( - /( - \.eggs # exclude a few common directories in the - | \.git # root of the project - | \.hg - | \.mypy_cache - | \.tox - | _build - | build - | dist - )/ -) -''' + +[tool.isort] +profile = "black" +line_length = 120 + +[tool.pytest.ini_options] +norecursedirs = ["common", "workflows", "spark"] +log_cli = true +log_cli_level = 20 + +[tool.coverage.run] +branch = true diff --git a/requirements-spark3.txt b/requirements-spark3.txt index 4b0f8eff11..69aec3fb38 100644 --- a/requirements-spark3.txt +++ b/requirements-spark3.txt @@ -10,30 +10,23 @@ ansiwrap==0.8.4 # via papermill appdirs==1.4.4 # via black -appnope==0.1.2 - # via - # ipykernel - # ipython async-generator==1.10 # via nbclient attrs==20.3.0 # via - # black # jsonschema # scantree backcall==0.2.0 # via ipython bcrypt==3.2.0 # via paramiko -black==19.10b0 - # via - # flytekit - # papermill +black==20.8b1 + # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.36 +boto3==1.17.39 # via sagemaker-training -botocore==1.20.36 +botocore==1.20.39 # via # boto3 # s3transfer @@ -54,8 +47,10 @@ click==7.1.2 # papermill croniter==1.0.10 # via flytekit -cryptography==3.4.6 - # via paramiko +cryptography==3.4.7 + # via + # paramiko + # secretstorage 
dataclasses-json==0.5.2 # via flytekit decorator==4.4.2 @@ -100,6 +95,10 @@ ipython==7.21.0 # via ipykernel jedi==0.18.0 # via ipython +jeepney==0.6.0 + # via + # keyring + # secretstorage jinja2==2.11.3 # via nbconvert jmespath==0.10.0 @@ -134,7 +133,9 @@ marshmallow==3.10.0 mistune==0.8.4 # via nbconvert mypy-extensions==0.4.3 - # via typing-inspect + # via + # black + # typing-inspect natsort==7.1.1 # via flytekit nbclient==0.5.3 @@ -250,6 +251,8 @@ scantree==0.0.1 # via dirhash scipy==1.6.2 # via sagemaker-training +secretstorage==3.3.1 + # via keyring six==1.15.0 # via # bcrypt @@ -300,7 +303,9 @@ traitlets==5.0.5 typed-ast==1.4.2 # via black typing-extensions==3.7.4.3 - # via typing-inspect + # via + # black + # typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 diff --git a/requirements.txt b/requirements.txt index 6820125a79..866325c7c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,30 +10,23 @@ ansiwrap==0.8.4 # via papermill appdirs==1.4.4 # via black -appnope==0.1.2 - # via - # ipykernel - # ipython async-generator==1.10 # via nbclient attrs==20.3.0 # via - # black # jsonschema # scantree backcall==0.2.0 # via ipython bcrypt==3.2.0 # via paramiko -black==19.10b0 - # via - # flytekit - # papermill +black==20.8b1 + # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.36 +boto3==1.17.39 # via sagemaker-training -botocore==1.20.36 +botocore==1.20.39 # via # boto3 # s3transfer @@ -54,8 +47,10 @@ click==7.1.2 # papermill croniter==1.0.10 # via flytekit -cryptography==3.4.6 - # via paramiko +cryptography==3.4.7 + # via + # paramiko + # secretstorage dataclasses-json==0.5.2 # via flytekit decorator==4.4.2 @@ -100,6 +95,10 @@ ipython==7.21.0 # via ipykernel jedi==0.18.0 # via ipython +jeepney==0.6.0 + # via + # keyring + # secretstorage jinja2==2.11.3 # via nbconvert jmespath==0.10.0 @@ -134,7 +133,9 @@ marshmallow==3.10.0 mistune==0.8.4 # via nbconvert mypy-extensions==0.4.3 - # via typing-inspect + # via + # black + # typing-inspect natsort==7.1.1 # via flytekit nbclient==0.5.3 @@ -250,6 +251,8 @@ scantree==0.0.1 # via dirhash scipy==1.6.2 # via sagemaker-training +secretstorage==3.3.1 + # via keyring six==1.15.0 # via # bcrypt @@ -300,7 +303,9 @@ traitlets==5.0.5 typed-ast==1.4.2 # via black typing-extensions==3.7.4.3 - # via typing-inspect + # via + # black + # typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 diff --git a/setup.cfg b/setup.cfg index d2d5096bfb..c5c8bbd7a1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,11 +1,3 @@ -[isort] -multi_line_output = 3 -include_trailing_comma = True -force_grid_wrap = 0 -use_parentheses = True -ensure_newline_before_comments = True -line_length = 120 - [flake8] max-line-length = 120 extend-ignore = E203, E266, E501, W503, E741 @@ -21,14 +13,5 @@ ignore_missing_imports = True follow_imports = skip cache_dir = /dev/null -[tool:pytest] -norecursedirs = common workflows spark -log_cli = true -log_cli_level = 20 - -[coverage:run] -branch = True - [metadata] license_files = LICENSE - diff --git a/setup.py b/setup.py index bf2e9e08b3..0ea2a69fa8 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ sidecar = ["k8s-proto>=0.0.3,<1.0.0"] schema = ["numpy>=1.14.0,<2.0.0", "pandas>=0.22.0,<2.0.0", "pyarrow>2.0.0,<4.0.0"] hive_sensor = ["hmsclient>=0.0.1,<1.0.0"] -notebook = ["papermill>=1.2.0", "nbconvert>=6.0.7", "ipykernel>=5.0.0", "black==19.10b0"] +notebook = ["papermill>=1.2.0", "nbconvert>=6.0.7", "ipykernel>=5.0.0"] sagemaker = ["sagemaker-training>=3.6.2,<4.0.0"] 
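Aside: most of the mechanical churn in this series comes from the black bump visible in the requirements files above (19.10b0 to 20.8b1). black 20.8b1 treats a pre-existing trailing comma as "magic" and explodes the enclosing brackets onto one line per element, which is exactly the pattern repeated throughout this diff. An illustrative before/after, not taken from the patch:

params = dict(parallelism=1, size=10, min_successes=8,)  # one line under black 19.10b0
# black 20.8b1 sees the trailing comma and rewrites the call as:
params = dict(
    parallelism=1,
    size=10,
    min_successes=8,
)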
all_but_spark = sidecar + schema + hive_sensor + notebook + sagemaker diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py index b2b90c9e81..b79ec97ca3 100644 --- a/tests/flytekit/common/parameterizers.py +++ b/tests/flytekit/common/parameterizers.py @@ -32,16 +32,28 @@ ) ), types.LiteralType( - blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) + blob=_core_types.BlobType( + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) ), types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) ), types.LiteralType( - blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ), types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ), ] @@ -60,7 +72,8 @@ LIST_OF_INTERFACES = [ interface.TypedInterface( - {"a": interface.Variable(t, "description 1")}, {"b": interface.Variable(t, "description 2")}, + {"a": interface.Variable(t, "description 1")}, + {"b": interface.Variable(t, "description 2")}, ) for t in LIST_OF_ALL_LITERAL_TYPES ] @@ -95,7 +108,13 @@ LIST_OF_TASK_METADATA = [ task.TaskMetadata( - discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, + discoverable, + runtime_metadata, + timeout, + retry_strategy, + interruptible, + discovery_version, + deprecated, ) for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated in product( [True, False], @@ -117,7 +136,12 @@ interfaces, {"a": 1, "b": [1, 2, 3], "c": "abc", "d": {"x": 1, "y": 2, "z": 3}}, container=task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, ), ) for task_metadata, interfaces, resources in product(LIST_OF_TASK_METADATA, LIST_OF_INTERFACES, LIST_OF_RESOURCES) @@ -125,7 +149,12 @@ LIST_OF_CONTAINERS = [ task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, ) for resources in LIST_OF_RESOURCES ] @@ -137,7 +166,10 @@ (literals.Scalar(primitive=literals.Primitive(float_value=500.0)), 500.0), (literals.Scalar(primitive=literals.Primitive(boolean=True)), True), (literals.Scalar(primitive=literals.Primitive(string_value="hello")), "hello"), - (literals.Scalar(primitive=literals.Primitive(duration=timedelta(seconds=5))), timedelta(seconds=5),), + ( + literals.Scalar(primitive=literals.Primitive(duration=timedelta(seconds=5))), + timedelta(seconds=5), + ), (literals.Scalar(none_type=literals.Void()), None), ( literals.Scalar( diff --git a/tests/flytekit/common/workflows/python.py b/tests/flytekit/common/workflows/python.py index ff8f41f29d..a0d423b86c 100644 --- a/tests/flytekit/common/workflows/python.py +++ 
b/tests/flytekit/common/workflows/python.py @@ -29,7 +29,10 @@ def sum_non_none(workflow_parameters, value1_to_print, value2_to_print, out): @inputs( - value1_to_add=Types.Integer, value2_to_add=Types.Integer, value3_to_add=Types.Integer, value4_to_add=Types.Integer, + value1_to_add=Types.Integer, + value2_to_add=Types.Integer, + value3_to_add=Types.Integer, + value4_to_add=Types.Integer, ) @outputs(out=Types.Integer) @python_task(cache_version="1") diff --git a/tests/flytekit/common/workflows/sagemaker.py b/tests/flytekit/common/workflows/sagemaker.py index b0fe7d36d6..044bba29f2 100644 --- a/tests/flytekit/common/workflows/sagemaker.py +++ b/tests/flytekit/common/workflows/sagemaker.py @@ -54,7 +54,9 @@ builtin_algorithm_training_job_task2 = SdkBuiltinAlgorithmTrainingJobTask( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -85,7 +87,8 @@ class SageMakerHPO(object): default=_HyperparameterTuningJobConfig( tuning_strategy=HyperparameterTuningStrategy.BAYESIAN, tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, metric_name="validation:error", + objective_type=HyperparameterTuningObjectiveType.MINIMIZE, + metric_name="validation:error", ), training_job_early_stopping_type=TrainingJobEarlyStoppingType.AUTO, ), @@ -105,7 +108,10 @@ class SageMakerHPO(object): sagemaker_hpo_lp = SageMakerHPO.create_launch_plan() with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/local.config",), + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../common/configs/local.config", + ), internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, ): print("Printing WF definition") diff --git a/tests/flytekit/common/workflows/sidecar.py b/tests/flytekit/common/workflows/sidecar.py index aa01094cf6..ff0d4e7484 100644 --- a/tests/flytekit/common/workflows/sidecar.py +++ b/tests/flytekit/common/workflows/sidecar.py @@ -10,10 +10,16 @@ def generate_pod_spec_for_task(): pod_spec = generated_pb2.PodSpec() - secondary_container = generated_pb2.Container(name="secondary", image="alpine",) + secondary_container = generated_pb2.Container( + name="secondary", + image="alpine", + ) secondary_container.command.extend(["/bin/sh"]) secondary_container.args.extend(["-c", "echo hi sidecar world > /data/message.txt"]) - shared_volume_mount = generated_pb2.VolumeMount(name="shared-data", mountPath="/data",) + shared_volume_mount = generated_pb2.VolumeMount( + name="shared-data", + mountPath="/data", + ) secondary_container.volumeMounts.extend([shared_volume_mount]) primary_container = generated_pb2.Container(name="primary") @@ -23,7 +29,11 @@ def generate_pod_spec_for_task(): [ generated_pb2.Volume( name="shared-data", - volumeSource=generated_pb2.VolumeSource(emptyDir=generated_pb2.EmptyDirVolumeSource(medium="Memory",)), + volumeSource=generated_pb2.VolumeSource( + emptyDir=generated_pb2.EmptyDirVolumeSource( + medium="Memory", + ) + ), ) ] ) @@ -32,7 +42,8 @@ def generate_pod_spec_for_task(): @sidecar_task( - pod_spec=generate_pod_spec_for_task(), primary_container_name="primary", + pod_spec=generate_pod_spec_for_task(), + primary_container_name="primary", ) def a_sidecar_task(wfparams): while not 
os.path.isfile("/data/message.txt"): diff --git a/tests/flytekit/common/workflows/simple.py b/tests/flytekit/common/workflows/simple.py index 2e21a7f300..f264fe39bf 100644 --- a/tests/flytekit/common/workflows/simple.py +++ b/tests/flytekit/common/workflows/simple.py @@ -105,4 +105,10 @@ class SimpleWorkflow(object): c = subtract_one(a=input_1) d = write_special_types() - e = read_special_types(a=d.outputs.a, b=d.outputs.b, c=d.outputs.c, d=d.outputs.d, e=d.outputs.e,) + e = read_special_types( + a=d.outputs.a, + b=d.outputs.b, + c=d.outputs.c, + d=d.outputs.d, + e=d.outputs.e, + ) diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 4813676fe1..bea2c9f333 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -30,7 +30,8 @@ def test_single_step_entrypoint_in_proc(): ): with _utils.AutoDeletingTempDir("in") as input_dir: literal_map = _type_helpers.pack_python_std_map_to_literal_map( - {"a": 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs), + {"a": 9}, + _type_map_from_variable_map(_task_defs.add_one.interface.inputs), ) input_file = os.path.join(input_dir.name, "inputs.pb") _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file) @@ -46,7 +47,8 @@ def test_single_step_entrypoint_in_proc(): ) p = _utils.load_proto_from_file( - _literals_pb2.LiteralMap, os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), + _literals_pb2.LiteralMap, + os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), ) raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( _literal_models.LiteralMap.from_flyte_idl(p), @@ -63,7 +65,8 @@ def test_single_step_entrypoint_out_of_proc(): ): with _utils.AutoDeletingTempDir("in") as input_dir: literal_map = _type_helpers.pack_python_std_map_to_literal_map( - {"a": 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs), + {"a": 9}, + _type_map_from_variable_map(_task_defs.add_one.interface.inputs), ) input_file = os.path.join(input_dir.name, "inputs.pb") _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file) @@ -78,7 +81,8 @@ def test_single_step_entrypoint_out_of_proc(): assert result.exit_code == 0 p = _utils.load_proto_from_file( - _literals_pb2.LiteralMap, os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), + _literals_pb2.LiteralMap, + os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), ) raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( _literal_models.LiteralMap.from_flyte_idl(p), @@ -95,7 +99,8 @@ def test_arrayjob_entrypoint_in_proc(): ): with _utils.AutoDeletingTempDir("dir") as dir: literal_map = _type_helpers.pack_python_std_map_to_literal_map( - {"a": 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs), + {"a": 9}, + _type_map_from_variable_map(_task_defs.add_one.interface.inputs), ) input_dir = os.path.join(dir.name, "1") @@ -128,7 +133,8 @@ def test_arrayjob_entrypoint_in_proc(): raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( _literal_models.LiteralMap.from_flyte_idl( _utils.load_proto_from_file( - _literals_pb2.LiteralMap, os.path.join(input_dir, _constants.OUTPUT_FILE_NAME), + _literals_pb2.LiteralMap, + os.path.join(input_dir, _constants.OUTPUT_FILE_NAME), ) ), _type_map_from_variable_map(_task_defs.add_one.interface.outputs), diff --git a/tests/flytekit/unit/cli/auth/test_discovery.py b/tests/flytekit/unit/cli/auth/test_discovery.py index 5813d18bf0..c75427f35d 100644 --- 
a/tests/flytekit/unit/cli/auth/test_discovery.py +++ b/tests/flytekit/unit/cli/auth/test_discovery.py @@ -11,7 +11,9 @@ def test_get_authorization_endpoints(): auth_endpoint = "http://flyte-admin.com/authorization" token_endpoint = "http://flyte-admin.com/token" responses.add( - responses.GET, discovery_url, json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, + responses.GET, + discovery_url, + json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) @@ -26,7 +28,9 @@ def test_get_authorization_endpoints_relative(): auth_endpoint = "/authorization" token_endpoint = "/token" responses.add( - responses.GET, discovery_url, json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, + responses.GET, + discovery_url, + json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) @@ -38,7 +42,9 @@ def test_get_authorization_endpoints_relative(): def test_get_authorization_endpoints_missing_authorization_endpoint(): discovery_url = "http://flyte-admin.com/discovery" responses.add( - responses.GET, discovery_url, json={"token_endpoint": "http://flyte-admin.com/token"}, + responses.GET, + discovery_url, + json={"token_endpoint": "http://flyte-admin.com/token"}, ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) @@ -50,7 +56,9 @@ def test_get_authorization_endpoints_missing_authorization_endpoint(): def test_get_authorization_endpoints_missing_token_endpoint(): discovery_url = "http://flyte-admin.com/discovery" responses.add( - responses.GET, discovery_url, json={"authorization_endpoint": "http://flyte-admin.com/authorization"}, + responses.GET, + discovery_url, + json={"authorization_endpoint": "http://flyte-admin.com/authorization"}, ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) diff --git a/tests/flytekit/unit/cli/pyflyte/conftest.py b/tests/flytekit/unit/cli/pyflyte/conftest.py index a829db158d..723fb4878b 100644 --- a/tests/flytekit/unit/cli/pyflyte/conftest.py +++ b/tests/flytekit/unit/cli/pyflyte/conftest.py @@ -21,7 +21,10 @@ def _fake_module_load(names): @pytest.yield_fixture( scope="function", params=[ - os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../common/configs/local.config",), + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "../../../common/configs/local.config", + ), "/foo/bar", None, ], diff --git a/tests/flytekit/unit/cli/test_cli_helpers.py b/tests/flytekit/unit/cli/test_cli_helpers.py index 8a6c6e0263..9188b2fc3a 100644 --- a/tests/flytekit/unit/cli/test_cli_helpers.py +++ b/tests/flytekit/unit/cli/test_cli_helpers.py @@ -111,7 +111,8 @@ def test_hydrate_workflow_template(): id="launchplan_ref", workflow_node=_core_workflow_pb2.WorkflowNode( launchplan_ref=_identifier_pb2.Identifier( - resource_type=_identifier_pb2.LAUNCH_PLAN, project="project2", + resource_type=_identifier_pb2.LAUNCH_PLAN, + project="project2", ) ), ) @@ -121,7 +122,9 @@ def test_hydrate_workflow_template(): id="sub_workflow_ref", workflow_node=_core_workflow_pb2.WorkflowNode( sub_workflow_ref=_identifier_pb2.Identifier( - resource_type=_identifier_pb2.WORKFLOW, project="project2", domain="domain2", + resource_type=_identifier_pb2.WORKFLOW, + project="project2", + domain="domain2", ) ), ) @@ -245,7 +248,13 @@ def 
test_hydrate_registration_parameters__launch_plan_already_set(): ) ), ) - identifier, entity = hydrate_registration_parameters(LAUNCH_PLAN, "project", "domain", "12345", launch_plan,) + identifier, entity = hydrate_registration_parameters( + LAUNCH_PLAN, + "project", + "domain", + "12345", + launch_plan, + ) assert identifier == _identifier_pb2.Identifier( resource_type=_identifier_pb2.LAUNCH_PLAN, project="project2", @@ -258,16 +267,30 @@ def test_hydrate_registration_parameters__launch_plan_already_set(): def test_hydrate_registration_parameters__launch_plan_nothing_set(): launch_plan = _launch_plan_pb2.LaunchPlan( - id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.LAUNCH_PLAN, name="lp_name",), + id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.LAUNCH_PLAN, + name="lp_name", + ), spec=_launch_plan_pb2.LaunchPlanSpec( - workflow_id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.WORKFLOW, name="workflow_name",) + workflow_id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.WORKFLOW, + name="workflow_name", + ) ), ) identifier, entity = hydrate_registration_parameters( - _identifier_pb2.LAUNCH_PLAN, "project", "domain", "12345", launch_plan, + _identifier_pb2.LAUNCH_PLAN, + "project", + "domain", + "12345", + launch_plan, ) assert identifier == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.LAUNCH_PLAN, project="project", domain="domain", name="lp_name", version="12345", + resource_type=_identifier_pb2.LAUNCH_PLAN, + project="project", + domain="domain", + name="lp_name", + version="12345", ) assert entity.spec.workflow_id == _identifier_pb2.Identifier( resource_type=_identifier_pb2.WORKFLOW, @@ -282,7 +305,11 @@ def test_hydrate_registration_parameters__task_already_set(): task = _task_pb2.TaskSpec( template=_core_task_pb2.TaskTemplate( id=_identifier_pb2.Identifier( - resource_type=_identifier_pb2.TASK, project="project2", domain="domain2", name="name", version="abc", + resource_type=_identifier_pb2.TASK, + project="project2", + domain="domain2", + name="name", + version="abc", ), ) ) @@ -290,7 +317,11 @@ def test_hydrate_registration_parameters__task_already_set(): assert ( identifier == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.TASK, project="project2", domain="domain2", name="name", version="abc", + resource_type=_identifier_pb2.TASK, + project="project2", + domain="domain2", + name="name", + version="abc", ) == entity.template.id ) @@ -299,14 +330,21 @@ def test_hydrate_registration_parameters__task_already_set(): def test_hydrate_registration_parameters__task_nothing_set(): task = _task_pb2.TaskSpec( template=_core_task_pb2.TaskTemplate( - id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.TASK, name="name",), + id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.TASK, + name="name", + ), ) ) identifier, entity = hydrate_registration_parameters(_identifier_pb2.TASK, "project", "domain", "12345", task) assert ( identifier == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.TASK, project="project", domain="domain", name="name", version="12345", + resource_type=_identifier_pb2.TASK, + project="project", + domain="domain", + name="name", + version="12345", ) == entity.template.id ) @@ -330,7 +368,11 @@ def test_hydrate_registration_parameters__workflow_already_set(): assert ( identifier == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.WORKFLOW, project="project2", domain="domain2", name="name", version="abc", + resource_type=_identifier_pb2.WORKFLOW, + 
project="project2", + domain="domain2", + name="name", + version="abc", ) == entity.template.id ) @@ -339,7 +381,10 @@ def test_hydrate_registration_parameters__workflow_already_set(): def test_hydrate_registration_parameters__workflow_nothing_set(): workflow = _workflow_pb2.WorkflowSpec( template=_core_workflow_pb2.WorkflowTemplate( - id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.WORKFLOW, name="name",), + id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.WORKFLOW, + name="name", + ), nodes=[ _core_workflow_pb2.Node( id="foo", @@ -356,13 +401,21 @@ def test_hydrate_registration_parameters__workflow_nothing_set(): assert ( identifier == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.WORKFLOW, project="project", domain="domain", name="name", version="12345", + resource_type=_identifier_pb2.WORKFLOW, + project="project", + domain="domain", + name="name", + version="12345", ) == entity.template.id ) assert len(workflow.template.nodes) == 1 assert workflow.template.nodes[0].task_node.reference_id == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.TASK, project="project", domain="domain", name="task1", version="12345", + resource_type=_identifier_pb2.TASK, + project="project", + domain="domain", + name="task1", + version="12345", ) @@ -401,5 +454,9 @@ def test_hydrate_registration_parameters__subworkflows(): ) assert entity.sub_workflows[0].id == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.WORKFLOW, project="project", domain="domain", name="subworkflow", version="12345", + resource_type=_identifier_pb2.WORKFLOW, + project="project", + domain="domain", + name="subworkflow", + version="12345", ) diff --git a/tests/flytekit/unit/cli/test_flyte_cli.py b/tests/flytekit/unit/cli/test_flyte_cli.py index da63efa1dc..0fe7e54555 100644 --- a/tests/flytekit/unit/cli/test_flyte_cli.py +++ b/tests/flytekit/unit/cli/test_flyte_cli.py @@ -34,7 +34,8 @@ def my_task(wf_params, a, b): def test__extract_files(load_mock): t = get_sample_task() with TemporaryConfiguration( - "", internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, + "", + internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, ): task_spec = t.serialize() @@ -52,7 +53,11 @@ def test__extract_files(load_mock): @_mock.patch("flytekit.clis.flyte_cli.main._load_proto_from_file") def test__extract_files_with_unspecified_resource_type(load_mock): id = _core_identifier.Identifier( - _core_identifier.ResourceType.UNSPECIFIED, "myproject", "development", "name", "v", + _core_identifier.ResourceType.UNSPECIFIED, + "myproject", + "development", + "name", + "v", ) load_mock.return_value = id.to_flyte_idl() diff --git a/tests/flytekit/unit/common_tests/exceptions/test_system.py b/tests/flytekit/unit/common_tests/exceptions/test_system.py index 4610703efa..d53ed00f6c 100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_system.py +++ b/tests/flytekit/unit/common_tests/exceptions/test_system.py @@ -48,7 +48,9 @@ def test_flyte_entrypoint_not_loadable_exception(): try: raise system.FlyteEntrypointNotLoadable( - "fake.module", task_name="secret_task", additional_msg="Shouldn't have used a fake module!", + "fake.module", + task_name="secret_task", + additional_msg="Shouldn't have used a fake module!", ) except Exception as e: assert ( diff --git a/tests/flytekit/unit/common_tests/exceptions/test_user.py b/tests/flytekit/unit/common_tests/exceptions/test_user.py index f8851bb122..e3b3fbd319 
100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_user.py +++ b/tests/flytekit/unit/common_tests/exceptions/test_user.py @@ -22,7 +22,10 @@ def test_flyte_type_exception(): try: raise user.FlyteTypeException( - "int", ("list", "set"), received_value=1, additional_msg="That was a bad idea!", + "int", + ("list", "set"), + received_value=1, + additional_msg="That was a bad idea!", ) except Exception as e: assert ( diff --git a/tests/flytekit/unit/common_tests/tasks/test_execution_params.py b/tests/flytekit/unit/common_tests/tasks/test_execution_params.py index 76509d54f0..cf634bfa96 100644 --- a/tests/flytekit/unit/common_tests/tasks/test_execution_params.py +++ b/tests/flytekit/unit/common_tests/tasks/test_execution_params.py @@ -29,7 +29,9 @@ def test_secrets_manager_get_file(): with pytest.raises(ValueError): sec.get_secrets_file("", "x") assert sec.get_secrets_file("group", "test") == os.path.join( - secrets.SECRETS_DEFAULT_DIR.get(), "group", f"{secrets.SECRETS_FILE_PREFIX.get()}test", + secrets.SECRETS_DEFAULT_DIR.get(), + "group", + f"{secrets.SECRETS_FILE_PREFIX.get()}test", ) diff --git a/tests/flytekit/unit/common_tests/tasks/test_task.py b/tests/flytekit/unit/common_tests/tasks/test_task.py index 9212211b0e..e33a757412 100644 --- a/tests/flytekit/unit/common_tests/tasks/test_task.py +++ b/tests/flytekit/unit/common_tests/tasks/test_task.py @@ -21,7 +21,8 @@ def test_fetch_latest(mock_url, mock_client_manager): mock_url.get.return_value = "localhost" admin_task = _task_models.Task( - _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), _MagicMock(), + _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), + _MagicMock(), ) mock_client = _MagicMock() mock_client.list_tasks_paginated = _MagicMock(return_value=([admin_task], "")) @@ -58,7 +59,10 @@ def my_task(wf_params, a, b): def test_task_serialization(): t = get_sample_task() with TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../../common/configs/local.config",), + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../../common/configs/local.config", + ), internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, ): s = t.serialize() diff --git a/tests/flytekit/unit/common_tests/test_interface.py b/tests/flytekit/unit/common_tests/test_interface.py index e3eb65f081..b6627c1a6b 100644 --- a/tests/flytekit/unit/common_tests/test_interface.py +++ b/tests/flytekit/unit/common_tests/test_interface.py @@ -21,19 +21,23 @@ def test_binding_data_primitive_static(): with pytest.raises(_user_exceptions.FlyteTypeException): interface.BindingData.from_python_std( - primitives.Float.to_flyte_literal_type(), "abc", + primitives.Float.to_flyte_literal_type(), + "abc", ) with pytest.raises(_user_exceptions.FlyteTypeException): interface.BindingData.from_python_std( - primitives.Float.to_flyte_literal_type(), [1.0, 2.0, 3.0], + primitives.Float.to_flyte_literal_type(), + [1.0, 2.0, 3.0], ) def test_binding_data_list_static(): upstream_nodes = set() bd = interface.BindingData.from_python_std( - containers.List(primitives.String).to_flyte_literal_type(), ["abc", "cde"], upstream_nodes=upstream_nodes, + containers.List(primitives.String).to_flyte_literal_type(), + ["abc", "cde"], + upstream_nodes=upstream_nodes, ) assert len(upstream_nodes) == 0 @@ -47,7 +51,8 @@ def test_binding_data_list_static(): with pytest.raises(_user_exceptions.FlyteTypeException): 
interface.BindingData.from_python_std( - containers.List(primitives.String).to_flyte_literal_type(), "abc", + containers.List(primitives.String).to_flyte_literal_type(), + "abc", ) with pytest.raises(_user_exceptions.FlyteTypeException): diff --git a/tests/flytekit/unit/common_tests/test_launch_plan.py b/tests/flytekit/unit/common_tests/test_launch_plan.py index 495aa0d3c4..92943d37f1 100644 --- a/tests/flytekit/unit/common_tests/test_launch_plan.py +++ b/tests/flytekit/unit/common_tests/test_launch_plan.py @@ -18,7 +18,10 @@ def test_default_assumable_iam_role(): with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/local.config",) + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../common/configs/local.config", + ) ): workflow_to_test = _workflow.workflow( {}, @@ -45,7 +48,10 @@ def test_hard_coded_assumable_iam_role(): def test_default_deprecated_role(): with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/deprecated_local.config",) + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../common/configs/deprecated_local.config", + ) ): workflow_to_test = _workflow.workflow( {}, @@ -150,7 +156,11 @@ def test_schedule(schedule, cron_expression, cron_schedule): "default_input": _workflow.Input(_types.Types.Integer, default=5), }, ) - lp = workflow_to_test.create_launch_plan(fixed_inputs={"required_input": 5}, schedule=schedule, role="what",) + lp = workflow_to_test.create_launch_plan( + fixed_inputs={"required_input": 5}, + schedule=schedule, + role="what", + ) assert lp.entity_metadata.schedule.kickoff_time_input_arg is None assert lp.entity_metadata.schedule.cron_expression == cron_expression assert lp.entity_metadata.schedule.cron_schedule == cron_schedule @@ -180,7 +190,8 @@ def test_schedule_pointing_to_datetime(): }, ) lp = workflow_to_test.create_launch_plan( - schedule=_schedules.CronSchedule("* * ? * * *", kickoff_time_input_arg="required_input"), role="what", + schedule=_schedules.CronSchedule("* * ? * * *", kickoff_time_input_arg="required_input"), + role="what", ) assert lp.entity_metadata.schedule.kickoff_time_input_arg == "required_input" assert lp.entity_metadata.schedule.cron_expression == "* * ? 
* * *" @@ -310,10 +321,16 @@ def test_serialize(): }, ) workflow_to_test.id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v") - lp = workflow_to_test.create_launch_plan(fixed_inputs={"required_input": 5}, role="iam_role",) + lp = workflow_to_test.create_launch_plan( + fixed_inputs={"required_input": 5}, + role="iam_role", + ) with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/local.config",), + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../common/configs/local.config", + ), internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, ): s = lp.serialize() @@ -360,9 +377,12 @@ def test_raw_data_output_prefix(): }, ) lp = workflow_to_test.create_launch_plan( - fixed_inputs={"required_input": 5}, raw_output_data_prefix="s3://bucket-name", + fixed_inputs={"required_input": 5}, + raw_output_data_prefix="s3://bucket-name", ) assert lp.raw_output_data_config.output_location_prefix == "s3://bucket-name" - lp2 = workflow_to_test.create_launch_plan(fixed_inputs={"required_input": 5},) + lp2 = workflow_to_test.create_launch_plan( + fixed_inputs={"required_input": 5}, + ) assert lp2.raw_output_data_config.output_location_prefix == "" diff --git a/tests/flytekit/unit/common_tests/test_nodes.py b/tests/flytekit/unit/common_tests/test_nodes.py index 7d2b97e81c..8dd90097c6 100644 --- a/tests/flytekit/unit/common_tests/test_nodes.py +++ b/tests/flytekit/unit/common_tests/test_nodes.py @@ -26,7 +26,8 @@ def testy_test(wf_params, a, b): [], [ _literals.Binding( - "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), + "a", + _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], _core_workflow_models.NodeMetadata("abc", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), @@ -58,7 +59,8 @@ def testy_test(wf_params, a, b): [n], [ _literals.Binding( - "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), n.outputs.b), + "a", + _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), n.outputs.b), ) ], _core_workflow_models.NodeMetadata("abc2", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), @@ -93,7 +95,8 @@ def testy_test(wf_params, a, b): [], [ _literals.Binding( - "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), + "a", + _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], _core_workflow_models.NodeMetadata("abc3", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), @@ -131,7 +134,8 @@ def testy_test(wf_params, a, b): [], [ _literals.Binding( - "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), + "a", + _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], _core_workflow_models.NodeMetadata("abc4", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), @@ -191,7 +195,11 @@ def testy_test(wf_params, a, b): # Test floating ID testy_test._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "new_project", "new_domain", "new_name", "new_version", + _identifier.ResourceType.TASK, + "new_project", + "new_domain", + "new_name", + "new_version", ) assert n.reference_id.project == "new_project" assert n.reference_id.domain == "new_domain" @@ -219,7 +227,8 @@ class test_workflow(object): 
[], [ _literals.Binding( - "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), + "a", + _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], _core_workflow_models.NodeMetadata("abc", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), @@ -267,7 +276,11 @@ class test_workflow(object): # Test floating ID lp._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "new_project", "new_domain", "new_name", "new_version", + _identifier.ResourceType.TASK, + "new_project", + "new_domain", + "new_name", + "new_version", ) assert n.launchplan_ref.project == "new_project" assert n.launchplan_ref.domain == "new_domain" diff --git a/tests/flytekit/unit/common_tests/test_schedules.py b/tests/flytekit/unit/common_tests/test_schedules.py index 0009b424dc..4e6f231307 100644 --- a/tests/flytekit/unit/common_tests/test_schedules.py +++ b/tests/flytekit/unit/common_tests/test_schedules.py @@ -92,7 +92,8 @@ def test_cron_schedule_schedule_validation(schedule): @_pytest.mark.parametrize( - "schedule", ["foo", "* *"], + "schedule", + ["foo", "* *"], ) def test_cron_schedule_schedule_validation_invalid(schedule): with _pytest.raises(_user_exceptions.FlyteAssertion): diff --git a/tests/flytekit/unit/common_tests/test_translator.py b/tests/flytekit/unit/common_tests/test_translator.py index 07a9437826..70927fbfdf 100644 --- a/tests/flytekit/unit/common_tests/test_translator.py +++ b/tests/flytekit/unit/common_tests/test_translator.py @@ -58,7 +58,10 @@ def my_wf(a: int, b: str) -> (int, str): sdk_task = get_serializable(OrderedDict(), serialization_settings, t1, True) assert "pyflyte-execute" in sdk_task.container.args - lp = LaunchPlan.create("testlp", my_wf,) + lp = LaunchPlan.create( + "testlp", + my_wf, + ) sdk_lp = get_serializable(OrderedDict(), serialization_settings, lp) assert sdk_lp.id.name == "testlp" diff --git a/tests/flytekit/unit/common_tests/test_workflow.py b/tests/flytekit/unit/common_tests/test_workflow.py index ae1ef6df34..13a9e65d8e 100644 --- a/tests/flytekit/unit/common_tests/test_workflow.py +++ b/tests/flytekit/unit/common_tests/test_workflow.py @@ -137,7 +137,8 @@ class my_workflow(object): a = _local_workflow.Output("a", n1.outputs.b, sdk_type=primitives.Integer) w = _local_workflow.build_sdk_workflow_from_metaclass( - my_workflow, on_failure=_workflow_models.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE, + my_workflow, + on_failure=_workflow_models.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE, ) assert w.should_create_default_launch_plan is True @@ -223,7 +224,9 @@ def my_list_task(wf_params, a, b): wf_out = [ _local_workflow.Output( - "nested_out", [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]], sdk_type=[[primitives.Integer]], + "nested_out", + [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]], + sdk_type=[[primitives.Integer]], ), _local_workflow.Output("scalar_out", n1.outputs.b, sdk_type=primitives.Integer), ] @@ -299,7 +302,8 @@ def my_task(wf_params, a, b): [], [ _literals.Binding( - "a", interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), + "a", + interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], None, @@ -349,7 +353,9 @@ def my_list_task(wf_params, a, b): wf_out = [ _local_workflow.Output( - "nested_out", [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]], sdk_type=[[primitives.Integer]], + "nested_out", + [n5.outputs.b, 
n6.outputs.b, [n1.outputs.b, n2.outputs.b]], + sdk_type=[[primitives.Integer]], ), _local_workflow.Output("scalar_out", n1.outputs.b, sdk_type=primitives.Integer), ] @@ -369,6 +375,9 @@ class MyWorkflow(object): input_1 = promise.Input("input_1", primitives.Integer) input_2 = promise.Input("input_2", primitives.Integer, default=5, help="Not required.") - w = build_sdk_workflow_from_metaclass(MyWorkflow, disable_default_launch_plan=True,) + w = build_sdk_workflow_from_metaclass( + MyWorkflow, + disable_default_launch_plan=True, + ) assert w.should_create_default_launch_plan is False diff --git a/tests/flytekit/unit/common_tests/test_workflow_promote.py b/tests/flytekit/unit/common_tests/test_workflow_promote.py index bbfd5e627e..cdd1231107 100644 --- a/tests/flytekit/unit/common_tests/test_workflow_promote.py +++ b/tests/flytekit/unit/common_tests/test_workflow_promote.py @@ -39,7 +39,12 @@ def get_sample_container(): resources = _task_model.Resources(requests=[cpu_resource], limits=[cpu_resource]) return _task_model.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {}, {}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {}, + {}, ) diff --git a/tests/flytekit/unit/common_tests/types/impl/test_schema.py b/tests/flytekit/unit/common_tests/types/impl/test_schema.py index 90733564d7..de5bba4a46 100644 --- a/tests/flytekit/unit/common_tests/types/impl/test_schema.py +++ b/tests/flytekit/unit/common_tests/types/impl/test_schema.py @@ -164,7 +164,9 @@ def test_fetch(value_type_pair): with _utils.AutoDeletingTempDir("test2") as local_dir: schema_obj = _schema_impl.Schema.fetch( - tmpdir.name, local_path=local_dir.get_named_tempfile("schema_test"), schema_type=schema_type, + tmpdir.name, + local_path=local_dir.get_named_tempfile("schema_test"), + schema_type=schema_type, ) with schema_obj as reader: for df in reader.iter_chunks(): @@ -283,7 +285,8 @@ def uuid4(self): ) SET LOCATION 's3://my_fixed_path/'; """ query = df.get_write_partition_to_hive_table_query( - "some_table", partitions=_collections.OrderedDict([("region", "SEA"), ("ds", "2017-01-01")]), + "some_table", + partitions=_collections.OrderedDict([("region", "SEA"), ("ds", "2017-01-01")]), ) full_query = " ".join(full_query.split()) query = " ".join(query.split()) @@ -299,7 +302,8 @@ def test_partial_column_read(): writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) b = _schema_impl.Schema.fetch( - a.uri, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + a.uri, + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), ) with b as reader: df = reader.read(columns=["b"]) @@ -322,7 +326,8 @@ def single_dataframe(): ) assert s is not None n = _schema_impl.Schema.fetch( - s.uri, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + s.uri, + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), ) with n as reader: df2 = reader.read() @@ -338,7 +343,8 @@ def list_of_dataframes(): ) assert s is not None n = _schema_impl.Schema.fetch( - s.uri, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + s.uri, + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), ) with n as reader: actual = [] @@ -365,7 +371,8 @@ def empty_list(): ) assert s is not None n = _schema_impl.Schema.fetch( - s.uri, 
schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + s.uri, + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), ) with n as reader: df = reader.read() @@ -474,7 +481,8 @@ def test_extra_schema_read(): writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) b = _schema_impl.Schema.fetch( - a.remote_prefix, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer)]), + a.remote_prefix, + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer)]), ) with b as reader: df = reader.read(concat=True, truncate_extra_columns=False) diff --git a/tests/flytekit/unit/common_tests/types/test_blobs.py b/tests/flytekit/unit/common_tests/types/test_blobs.py index 4fdee08b3b..3057b11c6a 100644 --- a/tests/flytekit/unit/common_tests/types/test_blobs.py +++ b/tests/flytekit/unit/common_tests/types/test_blobs.py @@ -37,7 +37,10 @@ def test_blob_promote_from_model(): scalar=_literal_models.Scalar( blob=_literal_models.Blob( _literal_models.BlobMetadata( - _core_types.BlobType(format="f", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) + _core_types.BlobType( + format="f", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) ), "some/path", ) diff --git a/tests/flytekit/unit/common_tests/types/test_helpers.py b/tests/flytekit/unit/common_tests/types/test_helpers.py index 425feef30f..dd8b45af23 100644 --- a/tests/flytekit/unit/common_tests/types/test_helpers.py +++ b/tests/flytekit/unit/common_tests/types/test_helpers.py @@ -35,7 +35,8 @@ def test_get_sdk_value_from_literal(): assert o.to_python_std() is None o = _type_helpers.get_sdk_value_from_literal( - _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void())), sdk_type=_sdk_types.Types.Integer, + _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void())), + sdk_type=_sdk_types.Types.Integer, ) assert o.to_python_std() is None diff --git a/tests/flytekit/unit/configuration/test_temporary_configuration.py b/tests/flytekit/unit/configuration/test_temporary_configuration.py index a41f131ec8..fe51f46d06 100644 --- a/tests/flytekit/unit/configuration/test_temporary_configuration.py +++ b/tests/flytekit/unit/configuration/test_temporary_configuration.py @@ -13,7 +13,8 @@ def test_configuration_file(): def test_internal_overrides(): with _TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config"), {"foo": "bar"}, + _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config"), + {"foo": "bar"}, ): assert _os.environ.get("FLYTE_INTERNAL_FOO") == "bar" assert _os.environ.get("FLYTE_INTERNAL_FOO") is None diff --git a/tests/flytekit/unit/contrib/sensors/test_impl.py b/tests/flytekit/unit/contrib/sensors/test_impl.py index e14d4f75da..8382dae762 100644 --- a/tests/flytekit/unit/contrib/sensors/test_impl.py +++ b/tests/flytekit/unit/contrib/sensors/test_impl.py @@ -32,7 +32,9 @@ def test_HiveNamedPartitionSensor(): assert interval is None with mock.patch.object( - HMSClient, "get_partition_by_name", side_effect=_ttypes.NoSuchObjectException(), + HMSClient, + "get_partition_by_name", + side_effect=_ttypes.NoSuchObjectException(), ): success, interval = hive_named_partition_sensor._do_poll() assert not success diff --git a/tests/flytekit/unit/core/test_references.py b/tests/flytekit/unit/core/test_references.py index 0fb66c7235..56d9bd090c 100644 --- a/tests/flytekit/unit/core/test_references.py +++ 
b/tests/flytekit/unit/core/test_references.py @@ -117,7 +117,11 @@ def inner_test(ref_mock): def test_ref_plain_no_outputs(): - r1 = ReferenceEntity(TaskReference("proj", "domain", "some.name", "abc"), inputs=kwtypes(a=str, b=int), outputs={},) + r1 = ReferenceEntity( + TaskReference("proj", "domain", "some.name", "abc"), + inputs=kwtypes(a=str, b=int), + outputs={}, + ) # Reference entities should always raise an exception when not mocked out. with pytest.raises(Exception) as e: @@ -207,7 +211,13 @@ def inner_test(ref_mock): ) def test_lps(resource_type): ref_entity = get_reference_entity( - resource_type, "proj", "dom", "app.other.flyte_entity", "123", inputs=kwtypes(a=str, b=int), outputs={}, + resource_type, + "proj", + "dom", + "app.other.flyte_entity", + "123", + inputs=kwtypes(a=str, b=int), + outputs={}, ) ctx = context_manager.FlyteContext.current_context() diff --git a/tests/flytekit/unit/core/test_schedule.py b/tests/flytekit/unit/core/test_schedule.py index 8564064e05..bd76a24405 100644 --- a/tests/flytekit/unit/core/test_schedule.py +++ b/tests/flytekit/unit/core/test_schedule.py @@ -95,7 +95,8 @@ def test_cron_schedule_schedule_validation(schedule): @_pytest.mark.parametrize( - "schedule", ["foo", "* *"], + "schedule", + ["foo", "* *"], ) def test_cron_schedule_schedule_validation_invalid(schedule): with _pytest.raises(ValueError): @@ -143,7 +144,9 @@ def quadruple(a: int) -> int: return c lp = LaunchPlan.create( - "schedule_test", quadruple, schedule=FixedRate(_datetime.timedelta(hours=12), "kickoff_input"), + "schedule_test", + quadruple, + schedule=FixedRate(_datetime.timedelta(hours=12), "kickoff_input"), ) assert lp.schedule == _schedule_models.Schedule( "kickoff_input", rate=_schedule_models.Schedule.FixedRate(12, _schedule_models.Schedule.FixedRateUnit.HOUR) diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 89f9f43b99..6d81ce317d 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -260,7 +260,11 @@ def t5(a: int) -> int: os.environ["FLYTE_INTERNAL_IMAGE"] = "docker.io/default:version" set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")) rs = context_manager.SerializationSettings( - project="project", domain="domain", version="version", env=None, image_config=get_image_config(), + project="project", + domain="domain", + version="version", + env=None, + image_config=get_image_config(), ) t1_ser = get_serializable(OrderedDict(), rs, t1) assert t1_ser.container.image == "docker.io/xyz:version" diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 4dcdaec4b7..271cc3cd62 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -80,7 +80,10 @@ def my_task() -> str: def test_engine_file_output(): - basic_blob_type = _core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) + basic_blob_type = _core_types.BlobType( + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) fs = FileAccessProvider(local_sandbox_dir="/tmp/flytetesting") with context_manager.FlyteContext.current_context().new_file_access_context(file_access_provider=fs) as ctx: diff --git a/tests/flytekit/unit/engines/flyte/test_engine.py b/tests/flytekit/unit/engines/flyte/test_engine.py index 452f9beeba..ba7d478cd0 100644 --- 
a/tests/flytekit/unit/engines/flyte/test_engine.py +++ b/tests/flytekit/unit/engines/flyte/test_engine.py @@ -29,7 +29,10 @@ @pytest.fixture(scope="function", autouse=True) def temp_config(): with TemporaryConfiguration( - os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../common/configs/local.config",), + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "../../../common/configs/local.config", + ), internal_overrides={ "image": "myflyteimage:{}".format(os.environ.get("IMAGE_VERSION", "sha")), "project": "myflyteproject", @@ -71,7 +74,10 @@ def test_task_system_failure(): engine.FlyteTask(m).execute(None, {"output_prefix": tmp.name}) doc = errors.ErrorDocument.from_flyte_idl( - utils.load_proto_from_file(errors_pb2.ErrorDocument, os.path.join(tmp.name, constants.ERROR_FILE_NAME),) + utils.load_proto_from_file( + errors_pb2.ErrorDocument, + os.path.join(tmp.name, constants.ERROR_FILE_NAME), + ) ) assert doc.error.code == "SYSTEM:Unknown" assert doc.error.kind == errors.ContainerError.Kind.RECOVERABLE @@ -87,7 +93,10 @@ def test_task_user_failure(): engine.FlyteTask(m).execute(None, {"output_prefix": tmp.name}) doc = errors.ErrorDocument.from_flyte_idl( - utils.load_proto_from_file(errors_pb2.ErrorDocument, os.path.join(tmp.name, constants.ERROR_FILE_NAME),) + utils.load_proto_from_file( + errors_pb2.ErrorDocument, + os.path.join(tmp.name, constants.ERROR_FILE_NAME), + ) ) assert doc.error.code == "USER:Unknown" assert doc.error.kind == errors.ContainerError.Kind.NON_RECOVERABLE @@ -112,7 +121,13 @@ def test_execution_notification_overrides(mock_client_factory): "xd", "xn", _execution_models.ExecutionSpec( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + identifier.Identifier( + identifier.ResourceType.LAUNCH_PLAN, + "project", + "domain", + "name", + "version", + ), _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), disable_all=True, ), @@ -140,7 +155,13 @@ def test_execution_notification_soft_overrides(mock_client_factory): "xd", "xn", _execution_models.ExecutionSpec( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + identifier.Identifier( + identifier.ResourceType.LAUNCH_PLAN, + "project", + "domain", + "name", + "version", + ), _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), notifications=_execution_models.NotificationList([notification]), ), @@ -161,7 +182,12 @@ def test_execution_label_overrides(mock_client_factory): labels = _common_models.Labels({"my": "label"}) engine.FlyteLaunchPlan(m).execute( - "xp", "xd", "xn", literals.LiteralMap({}), notification_overrides=[], label_overrides=labels, + "xp", + "xd", + "xn", + literals.LiteralMap({}), + notification_overrides=[], + label_overrides=labels, ) mock_client.create_execution.assert_called_once_with( @@ -169,7 +195,13 @@ def test_execution_label_overrides(mock_client_factory): "xd", "xn", _execution_models.ExecutionSpec( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + identifier.Identifier( + identifier.ResourceType.LAUNCH_PLAN, + "project", + "domain", + "name", + "version", + ), _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), disable_all=True, labels=labels, @@ -191,7 +223,12 @@ def test_execution_annotation_overrides(mock_client_factory): annotations = 
_common_models.Annotations({"my": "annotation"}) engine.FlyteLaunchPlan(m).launch( - "xp", "xd", "xn", literals.LiteralMap({}), notification_overrides=[], annotation_overrides=annotations, + "xp", + "xd", + "xn", + literals.LiteralMap({}), + notification_overrides=[], + annotation_overrides=annotations, ) mock_client.create_execution.assert_called_once_with( @@ -199,7 +236,13 @@ def test_execution_annotation_overrides(mock_client_factory): "xd", "xn", _execution_models.ExecutionSpec( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + identifier.Identifier( + identifier.ResourceType.LAUNCH_PLAN, + "project", + "domain", + "name", + "version", + ), _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), disable_all=True, annotations=annotations, @@ -254,12 +297,23 @@ def test_fetch_active_launch_plan(mock_client_factory): def test_get_full_execution_inputs(mock_client_factory): mock_client = MagicMock() mock_client.get_execution_data = MagicMock( - return_value=_execution_models.WorkflowExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP,) + return_value=_execution_models.WorkflowExecutionGetDataResponse( + None, + None, + _INPUT_MAP, + _OUTPUT_MAP, + ) ) mock_client_factory.return_value = mock_client m = MagicMock() - type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) + type(m).id = PropertyMock( + return_value=identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ) + ) inputs = engine.FlyteWorkflowExecution(m).get_inputs() assert len(inputs.literals) == 1 @@ -280,7 +334,13 @@ def test_get_execution_inputs(mock_client_factory, execution_data_locations): mock_client_factory.return_value = mock_client m = MagicMock() - type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) + type(m).id = PropertyMock( + return_value=identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ) + ) inputs = engine.FlyteWorkflowExecution(m).get_inputs() assert len(inputs.literals) == 1 @@ -299,7 +359,13 @@ def test_get_full_execution_outputs(mock_client_factory): mock_client_factory.return_value = mock_client m = MagicMock() - type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) + type(m).id = PropertyMock( + return_value=identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ) + ) outputs = engine.FlyteWorkflowExecution(m).get_outputs() assert len(outputs.literals) == 1 @@ -320,7 +386,13 @@ def test_get_execution_outputs(mock_client_factory, execution_data_locations): mock_client_factory.return_value = mock_client m = MagicMock() - type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) + type(m).id = PropertyMock( + return_value=identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ) + ) inputs = engine.FlyteWorkflowExecution(m).get_outputs() assert len(inputs.literals) == 1 @@ -334,14 +406,24 @@ def test_get_execution_outputs(mock_client_factory, execution_data_locations): def test_get_full_node_execution_inputs(mock_client_factory): mock_client = MagicMock() mock_client.get_node_execution_data = MagicMock( - return_value=_execution_models.NodeExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP,) + return_value=_execution_models.NodeExecutionGetDataResponse( + None, + None, + _INPUT_MAP, + 
_OUTPUT_MAP, + ) ) mock_client_factory.return_value = mock_client m = MagicMock() type(m).id = PropertyMock( return_value=identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -350,7 +432,12 @@ def test_get_full_node_execution_inputs(mock_client_factory): assert inputs.literals["a"].scalar.primitive.integer == 1 mock_client.get_node_execution_data.assert_called_once_with( identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -368,7 +455,12 @@ def test_get_node_execution_inputs(mock_client_factory, execution_data_locations m = MagicMock() type(m).id = PropertyMock( return_value=identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -377,7 +469,12 @@ def test_get_node_execution_inputs(mock_client_factory, execution_data_locations assert inputs.literals["a"].scalar.primitive.integer == 1 mock_client.get_node_execution_data.assert_called_once_with( identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -393,7 +490,12 @@ def test_get_full_node_execution_outputs(mock_client_factory): m = MagicMock() type(m).id = PropertyMock( return_value=identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -402,7 +504,12 @@ def test_get_full_node_execution_outputs(mock_client_factory): assert outputs.literals["b"].scalar.primitive.integer == 2 mock_client.get_node_execution_data.assert_called_once_with( identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -420,7 +527,12 @@ def test_get_node_execution_outputs(mock_client_factory, execution_data_location m = MagicMock() type(m).id = PropertyMock( return_value=identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -429,7 +541,12 @@ def test_get_node_execution_outputs(mock_client_factory, execution_data_location assert inputs.literals["b"].scalar.primitive.integer == 2 mock_client.get_node_execution_data.assert_called_once_with( identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -445,9 +562,20 @@ def test_get_full_task_execution_inputs(mock_client_factory): m = MagicMock() type(m).id = PropertyMock( return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), 
identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -458,9 +586,20 @@ def test_get_full_task_execution_inputs(mock_client_factory): assert inputs.literals["a"].scalar.primitive.integer == 1 mock_client.get_task_execution_data.assert_called_once_with( identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -480,9 +619,20 @@ def test_get_task_execution_inputs(mock_client_factory, execution_data_locations m = MagicMock() type(m).id = PropertyMock( return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -493,9 +643,20 @@ def test_get_task_execution_inputs(mock_client_factory, execution_data_locations assert inputs.literals["a"].scalar.primitive.integer == 1 mock_client.get_task_execution_data.assert_called_once_with( identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -513,9 +674,20 @@ def test_get_full_task_execution_outputs(mock_client_factory): m = MagicMock() type(m).id = PropertyMock( return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -526,9 +698,20 @@ def test_get_full_task_execution_outputs(mock_client_factory): assert outputs.literals["b"].scalar.primitive.integer == 2 mock_client.get_task_execution_data.assert_called_once_with( identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -548,9 +731,20 @@ def 
test_get_task_execution_outputs(mock_client_factory, execution_data_location m = MagicMock() type(m).id = PropertyMock( return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -561,9 +755,20 @@ def test_get_task_execution_outputs(mock_client_factory, execution_data_location assert inputs.literals["b"].scalar.primitive.integer == 2 mock_client.get_task_execution_data.assert_called_once_with( identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -573,7 +778,12 @@ def test_get_task_execution_outputs(mock_client_factory, execution_data_location @pytest.mark.parametrize( "tasks", [ - [_task_models.Task(identifier.Identifier(identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), MagicMock(),)], + [ + _task_models.Task( + identifier.Identifier(identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), + MagicMock(), + ) + ], [], ], ) diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py index 9b8e360f69..45dd18e9ae 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/sqlite3/test_task.py @@ -10,7 +10,12 @@ # This task belongs to test_task_static but is intentionally here to help test tracking tk = SQLite3Task( - "test", query_template="select * from tracks", task_config=SQLite3Config(uri=EXAMPLE_DB, compressed=True,), + "test", + query_template="select * from tracks", + task_config=SQLite3Config( + uri=EXAMPLE_DB, + compressed=True, + ), ) @@ -27,7 +32,10 @@ def test_task_schema(): query_template="select TrackId, Name from tracks limit {{.inputs.limit}}", inputs=kwtypes(limit=int), output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)], - task_config=SQLite3Config(uri=EXAMPLE_DB, compressed=True,), + task_config=SQLite3Config( + uri=EXAMPLE_DB, + compressed=True, + ), ) assert sql_task.output_columns is not None @@ -44,7 +52,10 @@ def my_task(df: pandas.DataFrame) -> int: "test", query_template="select * from tracks limit {{.inputs.limit}}", inputs=kwtypes(limit=int), - task_config=SQLite3Config(uri=EXAMPLE_DB, compressed=True,), + task_config=SQLite3Config( + uri=EXAMPLE_DB, + compressed=True, + ), ) @workflow diff --git a/tests/flytekit/unit/models/core/test_identifier.py b/tests/flytekit/unit/models/core/test_identifier.py index bf00aed216..7ca65daf1a 100644 --- a/tests/flytekit/unit/models/core/test_identifier.py +++ b/tests/flytekit/unit/models/core/test_identifier.py @@ -35,7 +35,10 @@ def test_node_execution_identifier(): def test_task_execution_identifier(): task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version") wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name") - node_exec_id 
= identifier.NodeExecutionIdentifier("node_id", wf_exec_id,) + node_exec_id = identifier.NodeExecutionIdentifier( + "node_id", + wf_exec_id, + ) obj = identifier.TaskExecutionIdentifier(task_id, node_exec_id, 3) assert obj.retry_attempt == 3 assert obj.task_id == task_id diff --git a/tests/flytekit/unit/models/core/test_types.py b/tests/flytekit/unit/models/core/test_types.py index 744e1f90d5..e7e98cf166 100644 --- a/tests/flytekit/unit/models/core/test_types.py +++ b/tests/flytekit/unit/models/core/test_types.py @@ -9,7 +9,10 @@ def test_blob_dimensionality(): def test_blob_type(): - o = _types.BlobType(format="csv", dimensionality=_types.BlobType.BlobDimensionality.SINGLE,) + o = _types.BlobType( + format="csv", + dimensionality=_types.BlobType.BlobDimensionality.SINGLE, + ) assert o.format == "csv" assert o.dimensionality == _types.BlobType.BlobDimensionality.SINGLE diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py index b64b87eb0b..2cc2e51420 100644 --- a/tests/flytekit/unit/models/core/test_workflow.py +++ b/tests/flytekit/unit/models/core/test_workflow.py @@ -41,7 +41,12 @@ def test_workflow_template(): {"b": _interface.Variable(int_type, "description2"), "c": _interface.Variable(int_type, "description3")}, ) wf_node = _workflow.Node( - id="some:node:id", metadata=nm, inputs=[], upstream_node_ids=[], output_aliases=[], task_node=task, + id="some:node:id", + metadata=nm, + inputs=[], + upstream_node_ids=[], + output_aliases=[], + task_node=task, ) obj = _workflow.WorkflowTemplate( id=_generic_id, @@ -111,7 +116,12 @@ def test_node_task_with_no_inputs(): task = _workflow.TaskNode(reference_id=_generic_id) obj = _workflow.Node( - id="some:node:id", metadata=nm, inputs=[], upstream_node_ids=[], output_aliases=[], task_node=task, + id="some:node:id", + metadata=nm, + inputs=[], + upstream_node_ids=[], + output_aliases=[], + task_node=task, ) assert obj.target == task assert obj.id == "some:node:id" diff --git a/tests/flytekit/unit/models/sagemaker/test_hpo_job.py b/tests/flytekit/unit/models/sagemaker/test_hpo_job.py index eab5f209b4..4b38672300 100644 --- a/tests/flytekit/unit/models/sagemaker/test_hpo_job.py +++ b/tests/flytekit/unit/models/sagemaker/test_hpo_job.py @@ -38,7 +38,10 @@ def test_hyperparameter_tuning_job(): input_mode=training_job.InputMode.FILE, input_content_type=training_job.InputContentType.TEXT_CSV, ) - tj = training_job.TrainingJob(training_job_resource_config=rc, algorithm_specification=alg,) + tj = training_job.TrainingJob( + training_job_resource_config=rc, + algorithm_specification=alg, + ) hpo = hpo_job.HyperparameterTuningJob(max_number_of_training_jobs=10, max_parallel_training_jobs=5, training_job=tj) hpo2 = hpo_job.HyperparameterTuningJob.from_flyte_idl(hpo.to_flyte_idl()) diff --git a/tests/flytekit/unit/models/sagemaker/test_training_job.py b/tests/flytekit/unit/models/sagemaker/test_training_job.py index cc5e5f4d3d..271669b16c 100644 --- a/tests/flytekit/unit/models/sagemaker/test_training_job.py +++ b/tests/flytekit/unit/models/sagemaker/test_training_job.py @@ -70,7 +70,10 @@ def test_training_job(): input_mode=training_job.InputMode.FILE, input_content_type=training_job.InputContentType.TEXT_CSV, ) - tj = training_job.TrainingJob(training_job_resource_config=rc, algorithm_specification=alg,) + tj = training_job.TrainingJob( + training_job_resource_config=rc, + algorithm_specification=alg, + ) tj2 = training_job.TrainingJob.from_flyte_idl(tj.to_flyte_idl()) # checking tj == tj2 would return 
false because we don't have the __eq__ magic method defined diff --git a/tests/flytekit/unit/models/test_dynamic_job.py b/tests/flytekit/unit/models/test_dynamic_job.py index 0a9dff117f..1aff800abd 100644 --- a/tests/flytekit/unit/models/test_dynamic_job.py +++ b/tests/flytekit/unit/models/test_dynamic_job.py @@ -20,11 +20,18 @@ interfaces, _array_job.ArrayJob(2, 2, 2).to_dict(), container=_task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, ), ) for task_metadata, interfaces, resources in product( - parameterizers.LIST_OF_TASK_METADATA, parameterizers.LIST_OF_INTERFACES, parameterizers.LIST_OF_RESOURCES, + parameterizers.LIST_OF_TASK_METADATA, + parameterizers.LIST_OF_INTERFACES, + parameterizers.LIST_OF_RESOURCES, ) ] @@ -34,7 +41,12 @@ def test_future_task_document(task): rs = _literals.RetryStrategy(0) nm = _workflow.NodeMetadata("node-name", _timedelta(minutes=10), rs) n = _workflow.Node( - id="id", metadata=nm, inputs=[], upstream_node_ids=[], output_aliases=[], task_node=_workflow.TaskNode(task.id), + id="id", + metadata=nm, + inputs=[], + upstream_node_ids=[], + output_aliases=[], + task_node=_workflow.TaskNode(task.id), ) n.to_flyte_idl() doc = _dynamic_job.DynamicJobSpec( diff --git a/tests/flytekit/unit/models/test_launch_plan.py b/tests/flytekit/unit/models/test_launch_plan.py index 788c12119b..8c2b7db8a7 100644 --- a/tests/flytekit/unit/models/test_launch_plan.py +++ b/tests/flytekit/unit/models/test_launch_plan.py @@ -26,7 +26,9 @@ def test_lp_closure(): parameter_map.to_flyte_idl() variable_map = interface.VariableMap({"vvv": v}) obj = launch_plan.LaunchPlanClosure( - state=launch_plan.LaunchPlanState.ACTIVE, expected_inputs=parameter_map, expected_outputs=variable_map, + state=launch_plan.LaunchPlanState.ACTIVE, + expected_inputs=parameter_map, + expected_outputs=variable_map, ) assert obj.expected_inputs == parameter_map assert obj.expected_outputs == variable_map diff --git a/tests/flytekit/unit/models/test_schedule.py b/tests/flytekit/unit/models/test_schedule.py index 8bade49fcb..b7fad79124 100644 --- a/tests/flytekit/unit/models/test_schedule.py +++ b/tests/flytekit/unit/models/test_schedule.py @@ -37,7 +37,8 @@ def test_schedule_fixed_rate(): @_pytest.mark.parametrize( - "offset", [None, "P1D"], + "offset", + [None, "P1D"], ) def test_schedule_cron_schedule(offset): cs = _schedule.Schedule.CronSchedule("days", offset) diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py index 0665f34c9f..0abd3a0402 100644 --- a/tests/flytekit/unit/models/test_tasks.py +++ b/tests/flytekit/unit/models/test_tasks.py @@ -95,7 +95,12 @@ def test_task_template(in_tuple): interfaces, {"a": 1, "b": {"c": 2, "d": 3}}, container=task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, ), config={"a": "b"}, ) @@ -143,7 +148,8 @@ def test_task_template_security_context(sec_ctx): @pytest.mark.parametrize("task_closure", parameterizers.LIST_OF_TASK_CLOSURES) def test_task(task_closure): obj = task.Task( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), task_closure, + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", 
"version"), + task_closure, ) assert obj.id.project == "project" assert obj.id.domain == "domain" @@ -156,7 +162,12 @@ def test_task(task_closure): @pytest.mark.parametrize("resources", parameterizers.LIST_OF_RESOURCES) def test_container(resources): obj = task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, ) obj.image == "my_image" obj.command == ["this", "is", "a", "cmd"] @@ -182,7 +193,10 @@ def test_sidecar_task(): def test_dataloadingconfig(): dlc = task.DataLoadingConfig( - "s3://input/path", "s3://output/path", True, task.DataLoadingConfig.LITERALMAP_FORMAT_YAML, + "s3://input/path", + "s3://output/path", + True, + task.DataLoadingConfig.LITERALMAP_FORMAT_YAML, ) dlc2 = task.DataLoadingConfig.from_flyte_idl(dlc.to_flyte_idl()) assert dlc2 == dlc diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py index 3368387a7f..3e19a80657 100644 --- a/tests/flytekit/unit/models/test_workflow_closure.py +++ b/tests/flytekit/unit/models/test_workflow_closure.py @@ -17,7 +17,8 @@ def test_workflow_closure(): ) b0 = _literals.Binding( - "a", _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))), + "a", + _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))), ) b1 = _literals.Binding("b", _literals.BindingData(promise=_types.OutputReference("my_node", "b"))) b2 = _literals.Binding("c", _literals.BindingData(promise=_types.OutputReference("my_node", "c"))) @@ -46,13 +47,23 @@ def test_workflow_closure(): typed_interface, {"a": 1, "b": {"c": 2, "d": 3}}, container=_task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {}, {}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {}, + {}, ), ) task_node = _workflow.TaskNode(task.id) node = _workflow.Node( - id="my_node", metadata=node_metadata, inputs=[b0], upstream_node_ids=[], output_aliases=[], task_node=task_node, + id="my_node", + metadata=node_metadata, + inputs=[b0], + upstream_node_ids=[], + output_aliases=[], + task_node=task_node, ) template = _workflow.WorkflowTemplate( diff --git a/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py b/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py index 532a92d5a8..2de75de96c 100644 --- a/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py @@ -12,9 +12,18 @@ def get_pod_spec(): a_container = generated_pb2.Container(name="main") a_container.command.extend(["foo", "bar"]) - a_container.volumeMounts.extend([generated_pb2.VolumeMount(name="scratch", mountPath="/scratch",)]) + a_container.volumeMounts.extend( + [ + generated_pb2.VolumeMount( + name="scratch", + mountPath="/scratch", + ) + ] + ) - pod_spec = generated_pb2.PodSpec(restartPolicy="Never",) + pod_spec = generated_pb2.PodSpec( + restartPolicy="Never", + ) pod_spec.containers.extend([a_container, generated_pb2.Container(name="sidecar")]) return pod_spec diff --git a/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py b/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py index 63542e3331..e81edec5de 100644 --- a/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py @@ -36,7 +36,9 @@ def 
sample_hive_task_no_queries(wf_params): @qubole_hive_task( - cache_version="1", cluster_label=_six.text_type("cluster_label"), tags=[], + cache_version="1", + cluster_label=_six.text_type("cluster_label"), + tags=[], ) def sample_qubole_hive_task_no_input(wf_params): return _six.text_type("select 5") @@ -44,7 +46,9 @@ def sample_qubole_hive_task_no_input(wf_params): @inputs(in1=Types.Integer) @qubole_hive_task( - cache_version="1", cluster_label=_six.text_type("cluster_label"), tags=[_six.text_type("tag1")], + cache_version="1", + cluster_label=_six.text_type("cluster_label"), + tags=[_six.text_type("tag1")], ) def sample_qubole_hive_task(wf_params, in1): return _six.text_type("select ") + _six.text_type(in1) diff --git a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py index f41c547ad5..19b4b8e2a1 100644 --- a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py @@ -83,7 +83,9 @@ def test_builtin_algorithm_training_job_task(): builtin_algorithm_training_job_task = SdkBuiltinAlgorithmTrainingJobTask( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -100,7 +102,10 @@ def test_builtin_algorithm_training_job_task(): assert isinstance(builtin_algorithm_training_job_task, _sdk_task.SdkTask) assert builtin_algorithm_training_job_task.interface.inputs["train"].description == "" assert builtin_algorithm_training_job_task.interface.inputs["train"].type == _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) assert ( builtin_algorithm_training_job_task.interface.inputs["train"].type @@ -112,7 +117,10 @@ def test_builtin_algorithm_training_job_task(): == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() ) assert builtin_algorithm_training_job_task.interface.inputs["train"].type == _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) assert builtin_algorithm_training_job_task.interface.inputs["static_hyperparameters"].description == "" assert ( @@ -133,13 +141,16 @@ def test_builtin_algorithm_training_job_task(): assert "metricDefinitions" not in builtin_algorithm_training_job_task.custom["algorithmSpecification"].keys() ParseDict( - builtin_algorithm_training_job_task.custom["trainingJobResourceConfig"], _pb2_TrainingJobResourceConfig(), + builtin_algorithm_training_job_task.custom["trainingJobResourceConfig"], + _pb2_TrainingJobResourceConfig(), ) # fails the test if it cannot be parsed builtin_algorithm_training_job_task2 = SdkBuiltinAlgorithmTrainingJobTask( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -175,7 +186,10 @@ def test_simple_hpo_job_task(): == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() ) assert 
simple_xgboost_hpo_job_task.interface.inputs["train"].type == _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) assert simple_xgboost_hpo_job_task.interface.inputs["validation"].description == "" assert ( @@ -183,7 +197,10 @@ def test_simple_hpo_job_task(): == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() ) assert simple_xgboost_hpo_job_task.interface.inputs["validation"].type == _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) assert simple_xgboost_hpo_job_task.interface.inputs["static_hyperparameters"].description == "" assert ( @@ -227,7 +244,9 @@ def test_custom_training_job(): @outputs(model=Types.Blob) @custom_training_job_task( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -252,7 +271,8 @@ class MyWf(object): default=_HyperparameterTuningJobConfig( tuning_strategy=HyperparameterTuningStrategy.BAYESIAN, tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, metric_name="validation:error", + objective_type=HyperparameterTuningObjectiveType.MINIMIZE, + metric_name="validation:error", ), training_job_early_stopping_type=TrainingJobEarlyStoppingType.AUTO, ), @@ -309,7 +329,9 @@ def setUp(self): @outputs(model=Types.Blob) @custom_training_job_task( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -360,7 +382,9 @@ def setUp(self): @outputs(model=Types.Blob) @custom_training_job_task( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=2, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=2, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -496,7 +520,9 @@ def test_if_wf_param_has_dist_context(self): @outputs(model=Types.Blob) @custom_training_job_task( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=2, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=2, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, diff --git a/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py index a0fbf1f8a4..0ca8011beb 100644 --- a/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py @@ -10,11 +10,22 @@ def get_pod_spec(): - a_container = generated_pb2.Container(name="a container",) + a_container = generated_pb2.Container( + name="a container", + ) a_container.command.extend(["fee", "fi", "fo", "fum"]) - a_container.volumeMounts.extend([generated_pb2.VolumeMount(name="volume mount", mountPath="some/where",)]) + 
a_container.volumeMounts.extend( + [ + generated_pb2.VolumeMount( + name="volume mount", + mountPath="some/where", + ) + ] + ) - pod_spec = generated_pb2.PodSpec(restartPolicy="OnFailure",) + pod_spec = generated_pb2.PodSpec( + restartPolicy="OnFailure", + ) pod_spec.containers.extend([a_container, generated_pb2.Container(name="another container")]) return pod_spec diff --git a/tests/flytekit/unit/sdk/tasks/test_tasks.py b/tests/flytekit/unit/sdk/tasks/test_tasks.py index a21ffaed04..33e0287199 100644 --- a/tests/flytekit/unit/sdk/tasks/test_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_tasks.py @@ -43,7 +43,10 @@ def test_default_python_task(): def test_default_resources(): with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../configuration/configs/good.config",) + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../configuration/configs/good.config", + ) ): @inputs(in1=Types.Integer) @@ -69,7 +72,10 @@ def default_task2(wf_params, in1, out1): def test_overriden_resources(): with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../configuration/configs/good.config",) + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../configuration/configs/good.config", + ) ): @inputs(in1=Types.Integer) diff --git a/tests/flytekit/unit/test_plugins.py b/tests/flytekit/unit/test_plugins.py index 1a5fb33ccf..c6680e34da 100644 --- a/tests/flytekit/unit/test_plugins.py +++ b/tests/flytekit/unit/test_plugins.py @@ -26,7 +26,10 @@ def test_schema_plugin(): @pytest.mark.run(order=2) def test_sidecar_plugin(): assert isinstance(plugins.k8s.io.api.core.v1.generated_pb2, lazy_loader._LazyLoadModule) - assert isinstance(plugins.k8s.io.apimachinery.pkg.api.resource.generated_pb2, lazy_loader._LazyLoadModule,) + assert isinstance( + plugins.k8s.io.apimachinery.pkg.api.resource.generated_pb2, + lazy_loader._LazyLoadModule, + ) import k8s.io.api.core.v1.generated_pb2 import k8s.io.apimachinery.pkg.api.resource.generated_pb2 diff --git a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py b/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py index b37ce8981a..8157739456 100644 --- a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py +++ b/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py @@ -55,7 +55,11 @@ def test_infer_proto_from_literal(): _literal_models.Literal( scalar=_literal_models.Scalar( binary=_literal_models.Binary( - value="", tag="{}{}".format(_proto.Protobuf.TAG_PREFIX, "flyteidl.core.errors_pb2.ContainerError",), + value="", + tag="{}{}".format( + _proto.Protobuf.TAG_PREFIX, + "flyteidl.core.errors_pb2.ContainerError", + ), ) ) ) From 928e865f9f74abbe50c167360f124e54cc56fe53 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 26 Mar 2021 14:21:56 -0700 Subject: [PATCH 04/33] Pass through FlyteFile and FlyteDirectory if created from a remote source (#436) Signed-off-by: Max Hoffman --- flytekit/types/directory/types.py | 5 +++ flytekit/types/file/file.py | 5 +++ .../unit/core/test_flyte_directory.py | 32 +++++++++++++++++ tests/flytekit/unit/core/test_flyte_file.py | 34 +++++++++++++++++++ 4 files changed, 76 insertions(+) diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 90d7bedff0..6b6fbbaabf 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -168,6 +168,11 @@ def to_literal( # There are two kinds of 
literals we handle, either an actual FlyteDirectory, or a string path to a directory. # Handle the FlyteDirectory case if isinstance(python_val, FlyteDirectory): + # If the object has a remote source, then we just convert it back. + if python_val._remote_source is not None: + meta = BlobMetadata(type=self._blob_type(format=self.get_format(python_type))) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=python_val._remote_source))) + source_path = python_val.path if python_val.remote_directory is False: # If the user specified the remote_path to be False, that means no matter what, do not upload diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index bc72fe3233..699745350f 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -217,6 +217,11 @@ def to_literal( if python_val is None: raise AssertionError("None value cannot be converted to a file.") if isinstance(python_val, FlyteFile): + # If the object has a remote source, then we just convert it back. + if python_val._remote_source is not None: + meta = BlobMetadata(type=self._blob_type(format=self.get_format(python_type))) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=python_val._remote_source))) + source_path = python_val.path if python_val.remote_path is False: # If the user specified the remote_path to be False, that means no matter what, do not upload diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index 5ecdc35b09..fe07db7a1a 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -6,11 +6,14 @@ import flytekit from flytekit.core import context_manager +from flytekit.core.context_manager import ExecutionState, Image, ImageConfig +from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow from flytekit.interfaces.data.data_proxy import FileAccessProvider from flytekit.models.core.types import BlobType +from flytekit.models.literals import LiteralMap from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer @@ -133,3 +136,32 @@ def wf2() -> int: x = wf2() assert x == 5 + + +def test_dont_convert_remotes(): + @task + def t1(in1: FlyteDirectory): + print(in1) + + @dynamic + def dyn(in1: FlyteDirectory): + t1(in1=in1) + + fd = FlyteDirectory("s3://anything") + + with context_manager.FlyteContext.current_context().new_serialization_settings( + serialization_settings=context_manager.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + ) as ctx: + with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx: + lit = TypeEngine.to_literal( + ctx, fd, FlyteDirectory, BlobType("", dimensionality=BlobType.BlobDimensionality.MULTIPART) + ) + lm = LiteralMap(literals={"in1": lit}) + wf = dyn.dispatch_execute(ctx, lm) + assert wf.nodes[0].inputs[0].binding.scalar.blob.uri == "s3://anything" diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index f063d3d4be..9ac6c8b189 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -2,9 +2,14 @@ import flytekit from flytekit.core import context_manager +from flytekit.core.context_manager import ExecutionState, 
Image, ImageConfig +from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.task import task +from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow from flytekit.interfaces.data.data_proxy import FileAccessProvider +from flytekit.models.core.types import BlobType +from flytekit.models.literals import LiteralMap from flytekit.types.file.file import FlyteFile @@ -207,3 +212,32 @@ def my_wf() -> FlyteFile: # The file name is maintained on download. assert str(workflow_output).endswith(os.path.split(SAMPLE_DATA)[1]) + + +def test_dont_convert_remotes(): + @task + def t1(in1: FlyteFile): + print(in1) + + @dynamic + def dyn(in1: FlyteFile): + t1(in1=in1) + + fd = FlyteFile("s3://anything") + + with context_manager.FlyteContext.current_context().new_serialization_settings( + serialization_settings=context_manager.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + ) as ctx: + with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx: + lit = TypeEngine.to_literal( + ctx, fd, FlyteFile, BlobType("", dimensionality=BlobType.BlobDimensionality.SINGLE) + ) + lm = LiteralMap(literals={"in1": lit}) + wf = dyn.dispatch_execute(ctx, lm) + assert wf.nodes[0].inputs[0].binding.scalar.blob.uri == "s3://anything" From d8345d27373eb07297a8ad75c6ac95feed3655fb Mon Sep 17 00:00:00 2001 From: Honnix Date: Fri, 26 Mar 2021 22:30:32 +0100 Subject: [PATCH 05/33] Spark2 CI (#435) * CI for spark2 Signed-off-by: Hongxin Liang * Simplify Makefile Signed-off-by: Hongxin Liang Signed-off-by: Max Hoffman --- .github/workflows/pythonbuild.yml | 8 ++--- Makefile | 32 +++++++++---------- dev-requirements.txt | 1 - ...ements-spark3.in => requirements-spark2.in | 2 +- ...ents-spark3.txt => requirements-spark2.txt | 8 ++--- setup.py | 1 - 6 files changed, 24 insertions(+), 28 deletions(-) rename requirements-spark3.in => requirements-spark2.in (58%) rename requirements-spark3.txt => requirements-spark2.txt (98%) diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 3260c8a473..cc0d7f0a52 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -11,11 +11,11 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - include: + python-version: [3.7, 3.8] + spark-version-suffix: ["", "-spark2"] + exclude: - python-version: 3.8 - spark-version-suffix: "-spark3" - - python-version: 3.7 - spark-version-suffix: "" + spark-version-suffix: "-spark2" steps: - uses: actions/checkout@v2 diff --git a/Makefile b/Makefile index 6f6c79f293..821588a9b6 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,4 @@ -define PIP_COMPILE -pip-compile $(1) --upgrade --verbose -endef +PIP_COMPILE = pip-compile --upgrade --verbose .SILENT: help .PHONY: help @@ -17,9 +15,9 @@ install-piptools: setup: install-piptools ## Install requirements pip-sync requirements.txt dev-requirements.txt -.PHONY: setup-spark3 -setup-spark3: install-piptools ## Install requirements - pip-sync requirements-spark3.txt dev-requirements.txt +.PHONY: setup-spark2 +setup-spark2: install-piptools ## Install requirements + pip-sync requirements-spark2.txt dev-requirements.txt .PHONY: fmt fmt: ## Format code with black and isort @@ -47,24 +45,24 @@ unit_test: pytest tests/scripts pytest plugins/tests -requirements-spark3.txt: export CUSTOM_COMPILE_COMMAND := make requirements-spark3.txt -requirements-spark3.txt: 
requirements-spark3.in install-piptools - $(call PIP_COMPILE,requirements-spark3.in) +requirements-spark2.txt: export CUSTOM_COMPILE_COMMAND := make requirements-spark2.txt +requirements-spark2.txt: requirements-spark2.in install-piptools + $(PIP_COMPILE) $< requirements.txt: export CUSTOM_COMPILE_COMMAND := make requirements.txt -requirements.txt: install-piptools - $(call PIP_COMPILE,requirements.in) +requirements.txt: requirements.in install-piptools + $(PIP_COMPILE) $< dev-requirements.txt: export CUSTOM_COMPILE_COMMAND := make dev-requirements.txt -dev-requirements.txt: requirements.txt install-piptools - $(call PIP_COMPILE,dev-requirements.in) +dev-requirements.txt: dev-requirements.in requirements.txt install-piptools + $(PIP_COMPILE) $< doc-requirements.txt: export CUSTOM_COMPILE_COMMAND := make doc-requirements.txt -doc-requirements.txt: dev-requirements.txt install-piptools - $(call PIP_COMPILE,doc-requirements.in) +doc-requirements.txt: doc-requirements.in install-piptools + $(PIP_COMPILE) $< .PHONY: requirements -requirements: requirements.txt dev-requirements.txt requirements-spark3.txt doc-requirements.txt ## Compile requirements +requirements: requirements.txt dev-requirements.txt requirements-spark2.txt doc-requirements.txt ## Compile requirements # TODO: Change this in the future to be all of flytekit .PHONY: coverage @@ -80,6 +78,6 @@ update_version: # it exits with exit code 1 and github actions aborts the build. grep "$(PLACEHOLDER)" "flytekit/__init__.py" sed -i "s/$(PLACEHOLDER)/__version__ = \"${VERSION}\"/g" "flytekit/__init__.py" - + grep "$(PLACEHOLDER)" "setup.py" sed -i "s/$(PLACEHOLDER)/__version__ = \"${VERSION}\"/g" "setup.py" diff --git a/dev-requirements.txt b/dev-requirements.txt index 5acc7855b8..ba18954a7b 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -14,7 +14,6 @@ attrs==20.3.0 # pytest black==20.8b1 # via - # -c requirements.txt # -r dev-requirements.in # flake8-black click==7.1.2 diff --git a/requirements-spark3.in b/requirements-spark2.in similarity index 58% rename from requirements-spark3.in rename to requirements-spark2.in index f12785fbdb..d5f2633c7e 100644 --- a/requirements-spark3.in +++ b/requirements-spark2.in @@ -1,2 +1,2 @@ -.[all-spark3] +.[all-spark2.4] -e file:.#egg=flytekit diff --git a/requirements-spark3.txt b/requirements-spark2.txt similarity index 98% rename from requirements-spark3.txt rename to requirements-spark2.txt index 69aec3fb38..c58e52ca25 100644 --- a/requirements-spark3.txt +++ b/requirements-spark2.txt @@ -2,10 +2,10 @@ # This file is autogenerated by pip-compile # To update, run: # -# make requirements-spark3.txt +# make requirements-spark2.txt # -e file:.#egg=flytekit - # via -r requirements-spark3.in + # via -r requirements-spark2.in ansiwrap==0.8.4 # via papermill appdirs==1.4.4 @@ -190,7 +190,7 @@ psutil==5.8.0 # via sagemaker-training ptyprocess==0.7.0 # via pexpect -py4j==0.10.9 +py4j==0.10.7 # via pyspark py==1.10.0 # via retry @@ -209,7 +209,7 @@ pyparsing==2.4.7 # via packaging pyrsistent==0.17.3 # via jsonschema -pyspark==3.1.1 +pyspark==2.4.7 # via flytekit python-dateutil==2.8.1 # via diff --git a/setup.py b/setup.py index 0ea2a69fa8..cd687e7c53 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,6 @@ "notebook": notebook, "sagemaker": sagemaker, "all-spark2.4": spark + all_but_spark, - "all-spark3": spark3 + all_but_spark, "all": spark3 + all_but_spark, } From 2f051987e6cb74b135f94888ab60b445971a1233 Mon Sep 17 00:00:00 2001 From: Miguel Toledo Date: Mon, 29 Mar 2021 17:28:46 -0400 
Subject: [PATCH 06/33] add requests and limits parameter to ContainerTask (#438) * add requests and limits parameter Signed-off-by: Miguel Toledo * fix typo Signed-off-by: Miguel Toledo * signoff Signed-off-by: Miguel Toledo Signed-off-by: Max Hoffman --- flytekit/core/container_task.py | 14 ++++++++++++++ .../flytekit/unit/common_tests/test_translator.py | 3 ++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index bf54986058..13d1a6e731 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -5,6 +5,7 @@ from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.context_manager import SerializationSettings from flytekit.core.interface import Interface +from flytekit.core.resources import Resources, ResourceSpec from flytekit.models import task as _task_model @@ -31,6 +32,8 @@ def __init__( metadata: Optional[TaskMetadata] = None, arguments: List[str] = None, outputs: Dict[str, Type] = None, + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, input_data_dir: str = None, output_data_dir: str = None, metadata_format: MetadataFormat = MetadataFormat.JSON, @@ -52,6 +55,13 @@ def __init__( self._output_data_dir = output_data_dir self._md_format = metadata_format self._io_strategy = io_strategy + self._resources = ResourceSpec( + requests=requests if requests else Resources(), limits=limits if limits else Resources() + ) + + @property + def resources(self) -> ResourceSpec: + return self._resources def execute(self, **kwargs) -> Any: print(kwargs) @@ -78,4 +88,8 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe io_strategy=self._io_strategy.value if self._io_strategy else None, ), environment=env, + cpu_request=self.resources.requests.cpu, + cpu_limit=self.resources.limits.cpu, + memory_request=self.resources.requests.mem, + memory_limit=self.resources.limits.mem, ) diff --git a/tests/flytekit/unit/common_tests/test_translator.py b/tests/flytekit/unit/common_tests/test_translator.py index 70927fbfdf..7dec6bbe8a 100644 --- a/tests/flytekit/unit/common_tests/test_translator.py +++ b/tests/flytekit/unit/common_tests/test_translator.py @@ -1,7 +1,7 @@ import typing from collections import OrderedDict -from flytekit import ContainerTask +from flytekit import ContainerTask, Resources from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.base_task import kwtypes @@ -92,6 +92,7 @@ def t1(a: int) -> (int, str): output_data_dir="/tmp", command=["cat"], arguments=["/tmp/a"], + requests=Resources(mem="400Mi", cpu="1"), ) sdk_task = get_serializable(OrderedDict(), serialization_settings, t2, fast=True) From 438c5f0d13237bb072988757fa405651f67f67ad Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Tue, 30 Mar 2021 11:39:59 -0400 Subject: [PATCH 07/33] update flytekit docs theme, fix index links (#439) * update flytekit docs theme, fix index links Signed-off-by: cosmicBboy * add readthedocs sphinx search Signed-off-by: cosmicBboy * update flyteidl link Signed-off-by: cosmicBboy * update overview and homepage Signed-off-by: cosmicBboy * address comments @wild-endeavor Signed-off-by: cosmicBboy * update ipykernel Signed-off-by: cosmicBboy * add community link Signed-off-by: cosmicBboy Signed-off-by: Max Hoffman --- dev-requirements.txt | 1 + doc-requirements.in | 2 + doc-requirements.txt | 35 +++++++------- docs/source/_static/custom.css | 33 +++++++++++++ 
docs/source/_templates/sidebar/brand.html | 18 +++++++ docs/source/conf.py | 58 +++++++++++------------ docs/source/design/authoring.rst | 7 ++- docs/source/design/index.rst | 13 ++--- docs/source/design/models.rst | 2 +- docs/source/generator.rst | 9 ---- docs/source/index.rst | 31 ++++++------ docs/source/reference/index.rst | 24 ++++++++++ flytekit/__init__.py | 18 +++++-- requirements.txt | 30 +++++------- 14 files changed, 179 insertions(+), 102 deletions(-) create mode 100644 docs/source/_static/custom.css create mode 100644 docs/source/_templates/sidebar/brand.html delete mode 100644 docs/source/generator.rst create mode 100644 docs/source/reference/index.rst diff --git a/dev-requirements.txt b/dev-requirements.txt index ba18954a7b..5acc7855b8 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -14,6 +14,7 @@ attrs==20.3.0 # pytest black==20.8b1 # via + # -c requirements.txt # -r dev-requirements.in # flake8-black click==7.1.2 diff --git a/doc-requirements.in b/doc-requirements.in index 4383aa26f1..1b104414c4 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -1,6 +1,8 @@ .[all] -e file:.#egg=flytekit +furo +readthedocs-sphinx-search sphinx sphinx-gallery sphinx-prompt diff --git a/doc-requirements.txt b/doc-requirements.txt index 963d54ce5e..84195f76de 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -12,7 +12,11 @@ ansiwrap==0.8.4 # via papermill appdirs==1.4.4 # via black -astroid==2.5.1 +appnope==0.1.2 + # via + # ipykernel + # ipython +astroid==2.5.2 # via sphinx-autoapi async-generator==1.10 # via nbclient @@ -28,15 +32,16 @@ bcrypt==3.2.0 # via paramiko beautifulsoup4==4.9.3 # via + # furo # sphinx-code-include # sphinx-material black==20.8b1 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.39 +boto3==1.17.40 # via sagemaker-training -botocore==1.20.39 +botocore==1.20.40 # via # boto3 # s3transfer @@ -61,7 +66,6 @@ cryptography==3.4.7 # via # -r doc-requirements.in # paramiko - # secretstorage css-html-js-minify==2.5.5 # via sphinx-material dataclasses-json==0.5.2 @@ -86,6 +90,8 @@ entrypoints==0.3 # papermill flyteidl==0.18.26 # via flytekit +furo==2021.3.20b30 + # via -r doc-requirements.in gevent==21.1.2 # via sagemaker-training greenlet==1.0.0 @@ -100,7 +106,7 @@ idna==2.10 # via requests imagesize==1.2.0 # via sphinx -importlib-metadata==3.7.3 +importlib-metadata==3.9.1 # via keyring inotify_simple==1.2.1 # via sagemaker-training @@ -110,14 +116,10 @@ ipython-genutils==0.2.0 # via # nbformat # traitlets -ipython==7.21.0 +ipython==7.22.0 # via ipykernel jedi==0.18.0 # via ipython -jeepney==0.6.0 - # via - # keyring - # secretstorage jinja2==2.11.3 # via # nbconvert @@ -152,7 +154,7 @@ markupsafe==1.1.1 # via jinja2 marshmallow-enum==1.5.1 # via dataclasses-json -marshmallow==3.10.0 +marshmallow==3.11.0 # via # dataclasses-json # marshmallow-enum @@ -177,7 +179,7 @@ nbformat==5.1.2 # papermill nest-asyncio==1.5.1 # via nbclient -numpy==1.20.1 +numpy==1.20.2 # via # flytekit # pandas @@ -263,6 +265,8 @@ pyyaml==5.4.1 # sphinx-autoapi pyzmq==22.0.3 # via jupyter-client +readthedocs-sphinx-search==0.1.0 + # via -r doc-requirements.in regex==2021.3.17 # via # black @@ -273,7 +277,7 @@ requests==2.25.1 # papermill # responses # sphinx -responses==0.13.1 +responses==0.13.2 # via flytekit retry==0.9.2 # via flytekit @@ -281,14 +285,12 @@ retrying==1.3.3 # via sagemaker-training s3transfer==0.3.6 # via boto3 -sagemaker-training==3.7.3 +sagemaker-training==3.7.4 # via flytekit scantree==0.0.1 # via dirhash scipy==1.6.2 # via 
sagemaker-training -secretstorage==3.3.1 - # via keyring six==1.15.0 # via # bcrypt @@ -327,6 +329,7 @@ sphinx-prompt==1.4.0 sphinx==3.5.3 # via # -r doc-requirements.in + # furo # sphinx-autoapi # sphinx-code-include # sphinx-copybutton diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css new file mode 100644 index 0000000000..786fcd2611 --- /dev/null +++ b/docs/source/_static/custom.css @@ -0,0 +1,33 @@ +h1, h2, h3, h4, h5, h6 { + font-weight: bold; +} + +.sidebar-logo { + max-width: 30%; +} + + +.sidebar-tree .reference.external:after { + content: none; +} + +div.sphx-glr-download a { + color: #4300c9; + background-color: rgb(241, 241, 241); + background-image: none; + border: 1px solid rgb(202, 202, 202); +} + +div.sphx-glr-download a:hover { + background-color: rgb(230, 230, 230); + box-shadow: none; +} + +div.sphx-glr-thumbcontainer a.headerlink { + display: none; +} + +div.sphx-glr-thumbcontainer:hover { + border-color: white; + box-shadow: none; +} diff --git a/docs/source/_templates/sidebar/brand.html b/docs/source/_templates/sidebar/brand.html new file mode 100644 index 0000000000..a170d6c6d1 --- /dev/null +++ b/docs/source/_templates/sidebar/brand.html @@ -0,0 +1,18 @@ + diff --git a/docs/source/conf.py b/docs/source/conf.py index 98a324a6c4..15df491168 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -47,6 +47,7 @@ "sphinx.ext.graphviz", "sphinx-prompt", "sphinx_copybutton", + "sphinx_search.extension", ] # build the templated autosummary files @@ -87,55 +88,50 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = "sphinx_material" +html_theme = "furo" +html_title = "Flyte Docs" + html_theme_options = { - # Set the name of the project to appear in the navigation. - "nav_title": "Flytekit Python Reference", - # Set you GA account ID to enable tracking - "google_analytics_account": "G-YQL24L5CKY", - # Specify a base_url used to generate sitemap.xml. If not - # specified, then no sitemap will be built. - "base_url": "https://github.com/lyft/flytekit", - # Set the color and the accent color - "color_primary": "deep-purple", - "color_accent": "blue", - # Set the repo location to get a badge with stats - "repo_url": "https://github.com/lyft/flyte/", - "repo_name": "flyte", - # Visible levels of the global TOC; -1 means unlimited - "globaltoc_depth": 1, - # If False, expand all TOC entries - "globaltoc_collapse": False, - # If True, show hidden TOC entries - "globaltoc_includehidden": False, - # don't include home link in breadcrumb bar, since it's included - # in the nav_links key below. - "master_doc": False, - # custom nav in breadcrumb bar - "nav_links": [ - {"href": "https://flyte.readthedocs.io/", "internal": False, "title": "Flyte"}, - {"href": "https://flytecookbook.readthedocs.io", "internal": False, "title": "Flytekit Tutorials"}, - {"href": "index", "internal": True, "title": "Flytekit Python Reference"}, - ], + "light_css_variables": { + "color-brand-primary": "#4300c9", + "color-brand-content": "#4300c9", + }, + "dark_css_variables": { + "color-brand-primary": "#9D68E4", + "color-brand-content": "#9D68E4", + }, } +templates_path = ["_templates"] + # The default sidebars (for documents that don't match any pattern) are # defined by theme itself. Builtin themes are using these templates by # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', # 'searchbox.html']``. 
-html_sidebars = {"**": ["logo-text.html", "globaltoc.html", "localtoc.html", "searchbox.html"]} +# html_sidebars = {"**": ["logo-text.html", "globaltoc.html", "localtoc.html", "searchbox.html"]} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = [] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. # html_logo = "flyte_circle_gradient_1_4x4.png" +pygments_style = "tango" +pygments_dark_style = "paraiso-dark" + +html_css_files = [ + "custom.css", +] + +html_context = { + "home_page": "https://docs.flyte.org", +} + # -- Options for HTMLHelp output --------------------------------------------- diff --git a/docs/source/design/authoring.rst b/docs/source/design/authoring.rst index d0b569faf5..62a953bf30 100644 --- a/docs/source/design/authoring.rst +++ b/docs/source/design/authoring.rst @@ -1,7 +1,7 @@ .. _design-authoring: ############################ -Flytekit Authoring Structure +Authoring Structure ############################ Enabling users to write tasks and workflows is the core feature of flytekit, it is why it exists. This document goes over how some of the internals work. @@ -51,13 +51,16 @@ Workflows ========= There is currently only one :py:class:`Workflow ` class. -.. autoclass:: flytekit.core.workflow.Workflow +.. autoclass:: flytekit.core.workflow.PythonFunctionWorkflow :noindex: Launch Plan =========== There is also only one :py:class:`LaunchPlan ` class. +.. autoclass:: flytekit.core.launch_plan.LaunchPlan + :noindex: + ************** Call Patterns ************** diff --git a/docs/source/design/index.rst b/docs/source/design/index.rst index 2b9c2005fe..24eef916a1 100644 --- a/docs/source/design/index.rst +++ b/docs/source/design/index.rst @@ -1,20 +1,21 @@ .. _design: ############################ -Structure of Flytekit +Overview ############################ Flytekit is comprised of a handful of different logical components, each discusssed in greater detail in each link -* Models - These are almost Protobuf generated files. -* Authoring - This provides the core Flyte authoring experiences, allowing users to write tasks, workflows, and launch plans. -* Control Plane - The code here allows users to interact with the control plane through Python objecs. -* Execution - A small shim layer basically that handles interaction with the Flyte ecosystem at execution time. -* CLIs and Clients - Command line tools users may find themselves interacting with and the control plane client the CLIs call. +* :ref:`Models Files ` - These are almost Protobuf generated files. +* :ref:`Authoring ` - This provides the core Flyte authoring experiences, allowing users to write tasks, workflows, and launch plans. +* :ref:`Control Plane ` - The code here allows users to interact with the control plane through Python objecs. +* :ref:`Execution ` - A small shim layer basically that handles interaction with the Flyte ecosystem at execution time. +* :ref:`CLIs and Clients ` - Command line tools users may find themselves interacting with and the control plane client the CLIs call. .. 
toctree:: :maxdepth: 1 :caption: Structure and Layout of Flytekit + :hidden: models authoring diff --git a/docs/source/design/models.rst b/docs/source/design/models.rst index f561c06119..63be15098f 100644 --- a/docs/source/design/models.rst +++ b/docs/source/design/models.rst @@ -1,7 +1,7 @@ .. _design-models: ###################### -Flytekit Model Files +Model Files ###################### *********** diff --git a/docs/source/generator.rst b/docs/source/generator.rst deleted file mode 100644 index eb937fdbab..0000000000 --- a/docs/source/generator.rst +++ /dev/null @@ -1,9 +0,0 @@ -.. currentmodule:: flytekit - -.. autosummary:: - :nosignatures: - :toctree: generated/ - - core.dynamic_workflow_task - core.notification - core.schedule diff --git a/docs/source/index.rst b/docs/source/index.rst index 0035b514db..a47b7dc653 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -9,17 +9,29 @@ Flytekit Python Reference This section of the documentation provides more detailed descriptions of the high-level design of ``flytekit`` and an API reference for specific usage details of python functions, classes, and decorators that you import to specify tasks, -build workflows, extend ``flytekit``. +build workflows, and extend ``flytekit``. + +.. toctree:: + :maxdepth: 4 + :hidden: + + Getting Started + Tutorials + reference/index + Community .. toctree:: :maxdepth: 1 - :caption: Design + :caption: Flytekit SDK + :hidden: + Flytekit Python design/index .. toctree:: :maxdepth: 1 :caption: APIs + :hidden: flytekit testing @@ -28,21 +40,8 @@ build workflows, extend ``flytekit``. .. toctree:: :maxdepth: 1 :caption: Contributing + :hidden: contributing tasks.extend types.extend - - -.. toctree:: - :maxdepth: 1 - :hidden: - - generator - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst new file mode 100644 index 0000000000..f77a447060 --- /dev/null +++ b/docs/source/reference/index.rst @@ -0,0 +1,24 @@ +############# +API Reference +############# + +.. toctree:: + :maxdepth: 1 + :caption: API Reference + :name: apitoc + + Flytekit Python + Flytekit Java + FlyteIDL + Flytectl + +.. toctree:: + :maxdepth: 1 + :caption: Component Reference (Code docs) + :name: componentreftoc + + FlytePropeller + FlyteAdmin + FlytePlugins + DataCatalog + \ No newline at end of file diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 249f87b71f..2e84c8a6a0 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -52,12 +52,11 @@ TaskMetadata - Wrapper object that allows users to specify Task Resources - Things like CPUs/Memory, etc. WorkflowFailurePolicy - Customizes what happens when a workflow fails. - dynamic Dynamic and Nested Workflows ============================== -Please see the :py:mod:`Dynamic ` module for more information as well. +See the :py:mod:`Dynamic ` module for more information. .. autosummary:: :nosignatures: @@ -67,9 +66,9 @@ Scheduling and Notifications ============================ -:py:mod:`Notifications Module ` -:py:mod:`Schedules Module ` +See the :py:mod:`Notifications Module ` and +:py:mod:`Schedules Module ` for more information. .. autosummary:: :nosignatures: @@ -118,6 +117,17 @@ Secret SecurityContext +Core Modules +============= + +.. 
autosummary:: + :nosignatures: + :toctree: generated/ + + core.dynamic_workflow_task + core.notification + core.schedule + """ diff --git a/requirements.txt b/requirements.txt index 866325c7c7..0473ae8da4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,10 @@ ansiwrap==0.8.4 # via papermill appdirs==1.4.4 # via black +appnope==0.1.2 + # via + # ipykernel + # ipython async-generator==1.10 # via nbclient attrs==20.3.0 @@ -24,9 +28,9 @@ black==20.8b1 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.39 +boto3==1.17.40 # via sagemaker-training -botocore==1.20.39 +botocore==1.20.40 # via # boto3 # s3transfer @@ -48,9 +52,7 @@ click==7.1.2 croniter==1.0.10 # via flytekit cryptography==3.4.7 - # via - # paramiko - # secretstorage + # via paramiko dataclasses-json==0.5.2 # via flytekit decorator==4.4.2 @@ -81,7 +83,7 @@ hmsclient==0.1.1 # via flytekit idna==2.10 # via requests -importlib-metadata==3.7.3 +importlib-metadata==3.9.1 # via keyring inotify_simple==1.2.1 # via sagemaker-training @@ -91,14 +93,10 @@ ipython-genutils==0.2.0 # via # nbformat # traitlets -ipython==7.21.0 +ipython==7.22.0 # via ipykernel jedi==0.18.0 # via ipython -jeepney==0.6.0 - # via - # keyring - # secretstorage jinja2==2.11.3 # via nbconvert jmespath==0.10.0 @@ -126,7 +124,7 @@ markupsafe==1.1.1 # via jinja2 marshmallow-enum==1.5.1 # via dataclasses-json -marshmallow==3.10.0 +marshmallow==3.11.0 # via # dataclasses-json # marshmallow-enum @@ -151,7 +149,7 @@ nbformat==5.1.2 # papermill nest-asyncio==1.5.1 # via nbclient -numpy==1.20.1 +numpy==1.20.2 # via # flytekit # pandas @@ -237,7 +235,7 @@ requests==2.25.1 # flytekit # papermill # responses -responses==0.13.1 +responses==0.13.2 # via flytekit retry==0.9.2 # via flytekit @@ -245,14 +243,12 @@ retrying==1.3.3 # via sagemaker-training s3transfer==0.3.6 # via boto3 -sagemaker-training==3.7.3 +sagemaker-training==3.7.4 # via flytekit scantree==0.0.1 # via dirhash scipy==1.6.2 # via sagemaker-training -secretstorage==3.3.1 - # via keyring six==1.15.0 # via # bcrypt From b37cc0b26f8b5e7cdd14531ddf23e9c10740e725 Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Wed, 31 Mar 2021 12:00:04 -0400 Subject: [PATCH 08/33] dark theme updates (#441) Signed-off-by: cosmicBboy Signed-off-by: Max Hoffman --- docs/source/_static/custom.css | 66 +++++++++++++++++++++++++++++++--- docs/source/conf.py | 2 +- 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css index 786fcd2611..d9851b7d8f 100644 --- a/docs/source/_static/custom.css +++ b/docs/source/_static/custom.css @@ -11,15 +11,24 @@ h1, h2, h3, h4, h5, h6 { content: none; } +.sphx-glr-thumbcontainer { + background-color: transparent; + border: transparent; +} + +.sphx-glr-thumbcontainer:hover { + border: transparent; +} + div.sphx-glr-download a { - color: #4300c9; - background-color: rgb(241, 241, 241); + color:white; + background-color: #9d68e4cf; background-image: none; - border: 1px solid rgb(202, 202, 202); + border: 1px solid #9d68e4cf; } div.sphx-glr-download a:hover { - background-color: rgb(230, 230, 230); + background-color: #8b48e2cf; box-shadow: none; } @@ -31,3 +40,52 @@ div.sphx-glr-thumbcontainer:hover { border-color: white; box-shadow: none; } + +.sphx-glr-script-out .highlight pre { + background-color: #f8f8f8; +} + +p.sphx-glr-script-out { + padding-top: 0em; +} + +.search__outer::-webkit-scrollbar-track { + border-radius: 0px; +} + +@media (prefers-color-scheme: dark) { + .search__outer { + background-color: 
#131416 !important; + border: 1px solid #131416 !important; + } + .search__outer__input { + background-color: #1a1c1e !important; + } + .search__result__single { + border-bottom: #303335 !important; + } + .outer_div_page_results:hover { + background-color: black; + } + .search__result__title, .rtd_ui_search_subtitle { + color: #9D68E4 !important; + border-bottom: 1px solid #9D68E4 !important; + } + .search__outer .search__result__title span, .search__outer .search__result__content span { + background-color: #9d68e454; + } + .search__result__subheading, .search__result__content { + color: #ffffffd9 !important; + } + .search__outer::-webkit-scrollbar-track { + background-color: #131416 !important; + } + .rtd__search__credits { + background-color: #1a1c1e !important; + border: 1px solid #1a1c1e !important; + color: #81868d !important; + } + .rtd__search__credits a, .search__error__box { + color: #9ca0a5 !important; + } + } diff --git a/docs/source/conf.py b/docs/source/conf.py index 15df491168..9b3d04564a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -122,7 +122,7 @@ html_logo = "flyte_circle_gradient_1_4x4.png" pygments_style = "tango" -pygments_dark_style = "paraiso-dark" +pygments_dark_style = "native" html_css_files = [ "custom.css", From b2ee9388123a8431008a740218e9a2ad675b1398 Mon Sep 17 00:00:00 2001 From: Ketan Umare <16888709+kumare3@users.noreply.github.com> Date: Tue, 6 Apr 2021 21:33:03 -0700 Subject: [PATCH 09/33] Added a missing test for != `ne` condition (#443) * Added a missing test for != `ne` condition Signed-off-by: Ketan Umare * improved the test to use previous task output Signed-off-by: Ketan Umare Signed-off-by: Max Hoffman --- tests/flytekit/unit/core/test_type_hints.py | 24 +++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 271cc3cd62..4590bfb1b8 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -461,6 +461,7 @@ def print_expr(expr): print_expr(((px == py) & (px < py)) | (px > py)) print_expr(px < 5) print_expr(px >= 5) + print_expr(px != 5) def test_comparison_lits(): @@ -478,6 +479,8 @@ def eval_expr(expr, expected: bool): eval_expr(px < 5, False) eval_expr(px >= 5, True) eval_expr(py >= 5, True) + eval_expr(py != 5, True) + eval_expr(px != 5, False) def test_wf1_branches(): @@ -511,6 +514,27 @@ def my_wf(a: int, b: str) -> (int, str): assert x == (4, "It is hello") +def test_wf1_branches_ne(): + @task + def t1(a: int) -> int: + return a + 1 + + @task + def t2(a: str) -> str: + return a + + @workflow + def my_wf(a: int, b: str) -> str: + new_a = t1(a=a) + return conditional("test1").if_(new_a != 5).then(t2(a=b)).else_().fail("Unable to choose branch") + + with pytest.raises(ValueError): + my_wf(a=4, b="hello") + + x = my_wf(a=5, b="hello") + assert x == "hello" + + def test_wf1_branches_no_else(): with pytest.raises(NotImplementedError): From 8a1778664f75342f4e6ebaf71152b16535460c95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rub=C3=A9n=20Barrag=C3=A1n?= Date: Wed, 7 Apr 2021 14:18:30 -0700 Subject: [PATCH 10/33] Propagate required envs to executors and fix bug to set configs in Spark task (#444) Signed-off-by: Max Hoffman --- flytekit/core/type_engine.py | 2 +- plugins/spark/flytekitplugins/spark/task.py | 48 ++++++++++++++------- plugins/tests/spark/test_spark_task.py | 27 ++++++++++++ plugins/tests/spark/test_wf.py | 4 +- 4 files changed, 62 insertions(+), 19 deletions(-) diff 
--git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e17137f8ae..cd0b6a3f90 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -137,7 +137,7 @@ def __init__(self, name: str, t: Type[T]): super().__init__(name, t) def get_literal_type(self, t: Type[T] = None) -> LiteralType: - raise RestrictedTypeError(f"Transformer for type{self.python_type} is restricted currently") + raise RestrictedTypeError(f"Transformer for type {self.python_type} is restricted currently") class DataclassTransformer(TypeTransformer[object]): diff --git a/plugins/spark/flytekitplugins/spark/task.py b/plugins/spark/flytekitplugins/spark/task.py index 2415392376..bd504e2542 100644 --- a/plugins/spark/flytekitplugins/spark/task.py +++ b/plugins/spark/flytekitplugins/spark/task.py @@ -34,6 +34,9 @@ def __post_init__(self): self.hadoop_conf = {} +# This method does not reset the SparkSession since it's a bit hard to handle multiple +# Spark sessions in a single application as it's described in: +# https://stackoverflow.com/questions/41491972/how-can-i-tear-down-a-sparksession-and-create-a-new-one-within-one-application. def new_spark_session(name: str, conf: typing.Dict[str, str] = None): """ Optionally creates a new spark session and returns it. @@ -46,17 +49,26 @@ def new_spark_session(name: str, conf: typing.Dict[str, str] = None): # We run in cluster-mode in Flyte. # Ref https://github.com/lyft/flyteplugins/blob/master/go/tasks/v1/flytek8s/k8s_resource_adds.go#L46 + sess_builder = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {name}") if "FLYTE_INTERNAL_EXECUTION_ID" not in os.environ and conf is not None: # If either of above cases is not true, then we are in local execution of this task # Add system spark-conf for local/notebook based execution. - spark_conf = set() + sess_builder = sess_builder.master("local[*]") + spark_conf = _pyspark.SparkConf() for k, v in conf.items(): - spark_conf.add((k, v)) - spark_conf.add(("spark.master", "local")) - _pyspark.SparkConf().setAll(spark_conf) - - sess = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {name}").getOrCreate() - return sess + spark_conf.set(k, v) + spark_conf.set("spark.driver.bindAddress", "127.0.0.1") + # In local execution, propagate PYTHONPATH to executors too. This makes the spark + # execution hermetic to the execution environment. For example, it allows running + # Spark applications using Bazel, without major changes. + if "PYTHONPATH" in os.environ: + spark_conf.setExecutorEnv("PYTHONPATH", os.environ["PYTHONPATH"]) + sess_builder = sess_builder.config(conf=spark_conf) + + # If there is a global SparkSession available, get it and try to stop it. 
+ _pyspark.sql.SparkSession.builder.getOrCreate().stop() + + return sess_builder.getOrCreate() # SparkSession.Stop does not work correctly, as it stops the session before all the data is written # sess.stop() @@ -75,12 +87,13 @@ def __init__(self, task_config: Spark, task_function: Callable, **kwargs): task_function=task_function, **kwargs, ) + self.sess = None def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: job = _task_model.SparkJob( spark_conf=self.task_config.spark_conf, hadoop_conf=self.task_config.hadoop_conf, - application_file="local://" + settings.entrypoint_settings.path, + application_file="local://" + settings.entrypoint_settings.path if settings.entrypoint_settings else "", executor_path=settings.python_interpreter, main_class="", spark_type=SparkType.PYTHON, @@ -91,17 +104,22 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: import pyspark as _pyspark ctx = FlyteContext.current_context() + sess_builder = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {user_params.execution_id}") if not (ctx.execution_state and ctx.execution_state.Mode == ExecutionState.Mode.TASK_EXECUTION): # If either of above cases is not true, then we are in local execution of this task # Add system spark-conf for local/notebook based execution. - spark_conf = set() + spark_conf = _pyspark.SparkConf() for k, v in self.task_config.spark_conf.items(): - spark_conf.add((k, v)) - spark_conf.add(("spark.master", "local")) - _pyspark.SparkConf().setAll(spark_conf) - - sess = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {user_params.execution_id}").getOrCreate() - return user_params.builder().add_attr("SPARK_SESSION", sess).build() + spark_conf.set(k, v) + # In local execution, propagate PYTHONPATH to executors too. This makes the spark + # execution hermetic to the execution environment. For example, it allows running + # Spark applications using Bazel, without major changes. 
+ if "PYTHONPATH" in os.environ: + spark_conf.setExecutorEnv("PYTHONPATH", os.environ["PYTHONPATH"]) + sess_builder = sess_builder.config(conf=spark_conf) + + self.sess = sess_builder.getOrCreate() + return user_params.builder().add_attr("SPARK_SESSION", self.sess).build() # Inject the Spark plugin into flytekits dynamic plugin loading system diff --git a/plugins/tests/spark/test_spark_task.py b/plugins/tests/spark/test_spark_task.py index 49677b690c..4fc80d541f 100644 --- a/plugins/tests/spark/test_spark_task.py +++ b/plugins/tests/spark/test_spark_task.py @@ -1,7 +1,9 @@ from flytekitplugins.spark import Spark +from flytekitplugins.spark.task import new_spark_session import flytekit from flytekit import task +from flytekit.common.tasks.sdk_runnable import ExecutionParameters from flytekit.extend import Image, ImageConfig, SerializationSettings @@ -26,3 +28,28 @@ def my_spark(a: str) -> int: retrieved_settings = my_spark.get_custom(settings) assert retrieved_settings["sparkConf"] == {"spark": "1"} + + pb = ExecutionParameters.new_builder() + pb.working_dir = "/tmp" + pb.execution_id = "ex:local:local:local" + p = pb.build() + new_p = my_spark.pre_execute(p) + assert new_p is not None + assert new_p.has_attr("SPARK_SESSION") + + assert my_spark.sess is not None + configs = my_spark.sess.sparkContext.getConf().getAll() + assert ("spark", "1") in configs + assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs + + +def test_new_spark_session(): + name = "SessionName" + spark_conf = {"spark1": "1", "spark2": "2"} + new_sess = new_spark_session(name, spark_conf) + configs = new_sess.sparkContext.getConf().getAll() + assert new_sess is not None + assert ("spark.driver.bindAddress", "127.0.0.1") in configs + assert ("spark.master", "local[*]") in configs + assert ("spark1", "1") in configs + assert ("spark2", "2") in configs diff --git a/plugins/tests/spark/test_wf.py b/plugins/tests/spark/test_wf.py index f610495184..56b7734f5f 100644 --- a/plugins/tests/spark/test_wf.py +++ b/plugins/tests/spark/test_wf.py @@ -1,5 +1,3 @@ -import typing - import pandas import pyspark from flytekitplugins.spark.task import Spark @@ -11,7 +9,7 @@ def test_wf1_with_spark(): @task(task_config=Spark()) - def my_spark(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def my_spark(a: int) -> (int, str): session = flytekit.current_context().spark_session assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local" return a + 2, "world" From 6a745c11fc8449ef9f31b96fba7f93e6e0311917 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 12 Apr 2021 12:30:50 -0700 Subject: [PATCH 11/33] Add a lot more user friendly launch plan creating function (#442) Signed-off-by: Max Hoffman --- codecov.yml | 2 +- flytekit/core/launch_plan.py | 64 ++++++++++++++++++++ flytekit/core/workflow.py | 4 +- tests/flytekit/unit/core/test_launch_plan.py | 39 ++++++++++++ 4 files changed, 107 insertions(+), 2 deletions(-) create mode 100644 tests/flytekit/unit/core/test_launch_plan.py diff --git a/codecov.yml b/codecov.yml index 1ab1b0d088..45721a38f7 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,7 +1,7 @@ ignore: - "flytekit/bin" - "test_*.py" - - "tests/flytekit/unit/core/test_node_creation.py" - "flytekit/__init__.py" - "flytekit/extend/__init__.py" - "flytekit/testing/__init__.py" + - "tests/*" diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 9853fd3c2c..0da5e6c79d 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -20,6 +20,15 
@@ class LaunchPlan(object): @staticmethod def get_default_launch_plan(ctx: FlyteContext, workflow: _annotated_workflow.WorkflowBase) -> LaunchPlan: + """ + Users should probably call the get_or_create function defined below instead. A default launch plan is the one + that will just pick up whatever default values are defined in the workflow function signature (if any) and + use the default auth information supplied during serialization, with no notifications or schedules. + + :param ctx: This is not flytekit.current_context(). This is an internal context object. Users familiar with + flytekit should feel free to use this however. + :param workflow: The workflow to create a launch plan for. + """ if workflow.name in LaunchPlan.CACHE: return LaunchPlan.CACHE[workflow.name] @@ -93,6 +102,61 @@ def create( cls.CACHE[name] = lp return lp + @classmethod + def get_or_create( + cls, + workflow: _annotated_workflow.WorkflowBase, + name: Optional[str] = None, + default_inputs: Dict[str, Any] = None, + fixed_inputs: Dict[str, Any] = None, + schedule: _schedule_model.Schedule = None, + notifications: List[_common_models.Notification] = None, + auth_role: _common_models.AuthRole = None, + ) -> LaunchPlan: + """ + This function offers a friendlier interface for creating launch plans. If the name for the launch plan is not + supplied, this assumes you are looking for the default launch plan for the workflow. If it is specified, it + will be used. If creating the default launch plan, none of the other arguments may be specified. + + The resulting launch plan is also cached and if called again with the same name, the + cached version is returned + + :param workflow: The Workflow to create a launch plan for. + :param name: If you supply a name, keep it mind it needs to be unique. That is, project, domain, version, and + this name form a primary key. If you do not supply a name, this function will assume you want the default + launch plan for the given workflow. + :param default_inputs: Default inputs, expressed as Python values. + :param fixed_inputs: Fixed inputs, expressed as Python values. At call time, these cannot be changed. + :param schedule: Optional schedule to run on. + :param notifications: Notifications to send. + :param auth_role: Add an auth role if necessary. + """ + if name is None and ( + default_inputs is not None + or fixed_inputs is not None + or schedule is not None + or notifications is not None + or auth_role is not None + ): + raise ValueError( + "Only named launchplans can be created that have other properties. Drop the name if you want to create a default launchplan. Default launchplans cannot have any other associations" + ) + + if name is not None and name in LaunchPlan.CACHE: + # TODO: Add checking of the other arguments (default_inputs, fixed_inputs, etc.) 
to make sure they match + return LaunchPlan.CACHE[name] + elif name is None and workflow.name in LaunchPlan.CACHE: + return LaunchPlan.CACHE[workflow.name] + + # Otherwise, handle the default launch plan case + if name is None: + ctx = FlyteContext.current_context() + lp = cls.get_default_launch_plan(ctx, workflow) + else: + lp = cls.create(name, workflow, default_inputs, fixed_inputs, schedule, notifications, auth_role) + LaunchPlan.CACHE[name or workflow.name] = lp + return lp + # TODO: Add QoS after it's done def __init__( self, diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 69b7b27b4f..9aacfd6bf8 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -562,7 +562,9 @@ def ready(self) -> bool: class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver): """ - More comments to come. + Please read :std:ref:`flyte:divedeep-workflows` first for a high-level understanding of what workflows are in Flyte. + This Python object represents a workflow defined by a function and decorated with the + :py:func:`@workflow ` decorator. Please see notes on that object for additional information. """ def __init__( diff --git a/tests/flytekit/unit/core/test_launch_plan.py b/tests/flytekit/unit/core/test_launch_plan.py new file mode 100644 index 0000000000..98b853c5df --- /dev/null +++ b/tests/flytekit/unit/core/test_launch_plan.py @@ -0,0 +1,39 @@ +import typing + +import pytest + +from flytekit.core import launch_plan +from flytekit.core.task import task +from flytekit.core.workflow import workflow + + +def test_lp(): + @task + def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + a = a + 2 + return a, "world-" + str(a) + + @task + def t2(a: str, b: str) -> str: + return b + a + + @workflow + def wf(a: int) -> (str, str): + x, y = t1(a=a) + u, v = t1(a=x) + return y, v + + lp = launch_plan.LaunchPlan.get_or_create(wf, "get_or_create1") + lp2 = launch_plan.LaunchPlan.get_or_create(wf, "get_or_create1") + assert lp.name == "get_or_create1" + assert lp is lp2 + + default_lp = launch_plan.LaunchPlan.get_or_create(wf) + default_lp2 = launch_plan.LaunchPlan.get_or_create(wf) + assert default_lp is default_lp2 + + with pytest.raises(ValueError): + launch_plan.LaunchPlan.get_or_create(wf, default_inputs={"a": 3}) + + lp_with_defaults = launch_plan.LaunchPlan.create("get_or_create2", wf, default_inputs={"a": 3}) + assert lp_with_defaults.parameters.parameters["a"].default.scalar.primitive.integer == 3 From 5ee9bf094b45b49549c40748830ebbbbd661c65c Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Tue, 13 Apr 2021 08:45:42 -0700 Subject: [PATCH 12/33] Fast register for dynamic tasks (#437) Signed-off-by: Max Hoffman --- flytekit/bin/entrypoint.py | 86 +++++++++++++++++-- flytekit/common/translator.py | 2 +- flytekit/core/context_manager.py | 8 +- flytekit/core/python_function_task.py | 38 +++++++- plugins/tests/pod/test_pod.py | 2 +- .../unit/core/test_context_manager.py | 13 ++- .../unit/core/test_dynamic_conditional.py | 2 +- tests/flytekit/unit/core/test_type_hints.py | 45 +++++++++- 8 files changed, 178 insertions(+), 18 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 86c94fa2d7..3c33b13f30 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -142,7 +142,14 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str, _logging.info(f"Engine folder written successfully to the output prefix {output_prefix}") -def _handle_annotated_task(task_def: 
PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str): +def _handle_annotated_task( + task_def: PythonTask, + inputs: str, + output_prefix: str, + raw_output_data_prefix: str, + dynamic_addl_distro: str = None, + dynamic_dest_dir: str = None, +): """ Entrypoint for all PythonTask extensions """ @@ -224,7 +231,9 @@ def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str with ctx.new_serialization_settings(serialization_settings=serialization_settings) as ctx: # Because execution states do not look up the context chain, it has to be made last with ctx.new_execution_context( - mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=execution_parameters + mode=ExecutionState.Mode.TASK_EXECUTION, + execution_params=execution_parameters, + additional_context={"dynamic_addl_distro": dynamic_addl_distro, "dynamic_dest_dir": dynamic_dest_dir}, ) as ctx: _dispatch_execute(ctx, task_def, inputs, output_prefix) @@ -281,7 +290,16 @@ def _load_resolver(resolver_location: str) -> TaskResolverMixin: @_scopes.system_entry_point -def _execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver: str, resolver_args: List[str]): +def _execute_task( + inputs, + output_prefix, + raw_output_data_prefix, + test, + resolver: str, + resolver_args: List[str], + dynamic_addl_distro: str = None, + dynamic_dest_dir: str = None, +): """ This function should be called for new API tasks (those only available in 0.16 and later that leverage Python native typing). @@ -299,6 +317,10 @@ def _execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver: :param resolver: The task resolver to use. This needs to be loadable directly from importlib (and thus cannot be nested). :param resolver_args: Args that will be passed to the aforementioned resolver's load_task function + :param dynamic_addl_distro: In the case of parent tasks executed using the 'fast' mode this captures where the + compressed code archive has been uploaded. + :param dynamic_dest_dir: In the case of parent tasks executed using the 'fast' mode this captures where compressed + code archives should be installed in the flyte task container. :return: """ if len(resolver_args) < 1: @@ -313,12 +335,22 @@ def _execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver: f"Test detected, returning. 
Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}" ) return - _handle_annotated_task(_task_def, inputs, output_prefix, raw_output_data_prefix) + _handle_annotated_task( + _task_def, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir + ) @_scopes.system_entry_point def _execute_map_task( - inputs, output_prefix, raw_output_data_prefix, max_concurrency, test, resolver: str, resolver_args: List[str] + inputs, + output_prefix, + raw_output_data_prefix, + max_concurrency, + test, + dynamic_addl_distro: str, + dynamic_dest_dir: str, + resolver: str, + resolver_args: List[str], ): if len(resolver_args) < 1: raise Exception(f"Resolver args cannot be <1, got {resolver_args}") @@ -342,7 +374,9 @@ def _execute_map_task( ) return - _handle_annotated_task(map_task, inputs, output_prefix, raw_output_data_prefix) + _handle_annotated_task( + map_task, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir + ) @_click.group() @@ -357,6 +391,8 @@ def _pass_through(): @_click.option("--output-prefix", required=True) @_click.option("--raw-output-data-prefix", required=False) @_click.option("--test", is_flag=True) +@_click.option("--dynamic-addl-distro", required=False) +@_click.option("--dynamic-dest-dir", required=False) @_click.option("--resolver", required=False) @_click.argument( "resolver-args", @@ -364,7 +400,16 @@ def _pass_through(): nargs=-1, ) def execute_task_cmd( - task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test, resolver, resolver_args + task_module, + task_name, + inputs, + output_prefix, + raw_output_data_prefix, + test, + dynamic_addl_distro, + dynamic_dest_dir, + resolver, + resolver_args, ): _click.echo(_utils.get_version_message()) # Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original @@ -382,7 +427,16 @@ def execute_task_cmd( _legacy_execute_task(task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test) else: _click.echo(f"Attempting to run with {resolver}...") - _execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver, resolver_args) + _execute_task( + inputs, + output_prefix, + raw_output_data_prefix, + test, + resolver, + resolver_args, + dynamic_addl_distro, + dynamic_dest_dir, + ) @_pass_through.command("pyflyte-fast-execute") @@ -405,7 +459,15 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd): # Use the commandline to run the task execute command rather than calling it directly in python code # since the current runtime bytecode references the older user code, rather than the downloaded distribution. 
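
A standalone sketch of the command rewrite this patch performs (the flag values below are hypothetical; the real inputs come from Propeller's container args):

    def splice_dynamic_flags(task_execute_cmd, additional_distribution, dest_dir):
        # Insert the fast-register flags immediately before "--resolver" so they are
        # parsed as options of pyflyte-execute, not as trailing resolver args
        # (everything after the bare "--" is consumed verbatim by the resolver).
        cmd = []
        for arg in task_execute_cmd:
            if arg == "--resolver":
                cmd.extend(["--dynamic-addl-distro", additional_distribution, "--dynamic-dest-dir", dest_dir])
            cmd.append(arg)
        return cmd

    result = splice_dynamic_flags(
        ["pyflyte-execute", "--inputs", "in.pb", "--resolver", "r", "--", "task-name", "t"],
        "s3://bucket/fast/abc.tar.gz",
        "/root",
    )
    assert result[3:7] == ["--dynamic-addl-distro", "s3://bucket/fast/abc.tar.gz", "--dynamic-dest-dir", "/root"]
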
- _os.system(" ".join(task_execute_cmd)) + + # Insert the call to fast before the unbounded resolver args + cmd = [] + for arg in task_execute_cmd: + if arg == "--resolver": + cmd.extend(["--dynamic-addl-distro", additional_distribution, "--dynamic-dest-dir", dest_dir]) + cmd.append(arg) + + _os.system(" ".join(cmd)) @_pass_through.command("pyflyte-map-execute") @@ -414,6 +476,8 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd): @_click.option("--raw-output-data-prefix", required=False) @_click.option("--max-concurrency", type=int, required=False) @_click.option("--test", is_flag=True) +@_click.option("--dynamic-addl-distro", required=False) +@_click.option("--dynamic-dest-dir", required=False) @_click.option("--resolver", required=True) @_click.argument( "resolver-args", @@ -426,6 +490,8 @@ def map_execute_task_cmd( raw_output_data_prefix, max_concurrency, test, + dynamic_addl_distro, + dynamic_dest_dir, resolver, resolver_args, ): @@ -437,6 +503,8 @@ def map_execute_task_cmd( raw_output_data_prefix, max_concurrency, test, + dynamic_addl_distro, + dynamic_dest_dir, resolver, resolver_args, ) diff --git a/flytekit/common/translator.py b/flytekit/common/translator.py index 4a4218f07f..0209e63831 100644 --- a/flytekit/common/translator.py +++ b/flytekit/common/translator.py @@ -169,7 +169,7 @@ def get_serializable_workflow( # Translate nodes upstream_sdk_nodes = [ - get_serializable(entity_mapping, settings, n) + get_serializable(entity_mapping, settings, n, fast) for n in entity.nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID ] diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 95efbe9381..391b42a5a0 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -412,13 +412,19 @@ def new_execution_context( working_dir = working_dir or self.file_access.get_random_local_directory() engine_dir = os.path.join(working_dir, "engine_dir") pathlib.Path(engine_dir).mkdir(parents=True, exist_ok=True) + if additional_context is None: + additional_context = self.execution_state.additional_context if self.execution_state is not None else None + elif self.execution_state is not None and self.execution_state.additional_context is not None: + additional_context = {**self.execution_state.additional_context, **additional_context} exec_state = ExecutionState( mode=mode, working_dir=working_dir, engine_dir=engine_dir, additional_context=additional_context ) # If a wf_params object was not given, use the default (defined at the bottom of this file) new_ctx = FlyteContext( - parent=self, execution_state=exec_state, user_space_params=execution_params or default_user_space_params + parent=self, + execution_state=exec_state, + user_space_params=execution_params or default_user_space_params, ) FlyteContext.OBJS.append(new_ctx) try: diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 9cd9866791..eba817084e 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -165,7 +165,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return container_args def compile_into_workflow( - self, ctx: FlyteContext, task_function: Callable, **kwargs + self, ctx: FlyteContext, is_fast_execution: bool, task_function: Callable, **kwargs ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]: with ctx.new_compilation_context(prefix="dynamic"): # TODO: Resolve circular import @@ -178,7 +178,7 @@ def 
compile_into_workflow( self._wf.compile(**kwargs) wf = self._wf - sdk_workflow = get_serializable(OrderedDict(), ctx.serialization_settings, wf) + sdk_workflow = get_serializable(OrderedDict(), ctx.serialization_settings, wf, is_fast_execution) # If no nodes were produced, let's just return the strict outputs if len(sdk_workflow.nodes) == 0: @@ -192,6 +192,33 @@ def compile_into_workflow( for n in sdk_workflow.nodes: self.aggregate(tasks, sub_workflows, n) + if is_fast_execution: + if ( + not ctx.execution_state + or not ctx.execution_state.additional_context + or not ctx.execution_state.additional_context.get("dynamic_addl_distro") + ): + raise AssertionError( + "Compilation for a dynamic workflow called in fast execution mode but no additional code " + "distribution could be retrieved" + ) + logger.warn(f"ctx.execution_state.additional_context {ctx.execution_state.additional_context}") + sanitized_tasks = set() + for task in tasks: + sanitized_args = [] + for arg in task.container.args: + if arg == "{{ .remote_package_path }}": + sanitized_args.append(ctx.execution_state.additional_context.get("dynamic_addl_distro")) + elif arg == "{{ .dest_dir }}": + sanitized_args.append(ctx.execution_state.additional_context.get("dynamic_dest_dir", ".")) + else: + sanitized_args.append(arg) + del task.container.args[:] + task.container.args.extend(sanitized_args) + sanitized_tasks.add(task) + + tasks = sanitized_tasks + dj_spec = _dynamic_job.DynamicJobSpec( min_successes=len(sdk_workflow.nodes), tasks=list(tasks), @@ -241,4 +268,9 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: return task_function(**kwargs) if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: - return self.compile_into_workflow(ctx, task_function, **kwargs) + is_fast_execution = bool( + ctx.execution_state + and ctx.execution_state.additional_context + and ctx.execution_state.additional_context.get("dynamic_addl_distro") + ) + return self.compile_into_workflow(ctx, is_fast_execution, task_function, **kwargs) diff --git a/plugins/tests/pod/test_pod.py b/plugins/tests/pod/test_pod.py index 59adc0313d..6a737331b7 100644 --- a/plugins/tests/pod/test_pod.py +++ b/plugins/tests/pod/test_pod.py @@ -202,7 +202,7 @@ def dynamic_pod_task(a: int) -> List[int]: ) ) as ctx: with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx: - dynamic_job_spec = dynamic_pod_task.compile_into_workflow(ctx, dynamic_pod_task._task_function, a=5) + dynamic_job_spec = dynamic_pod_task.compile_into_workflow(ctx, False, dynamic_pod_task._task_function, a=5) assert len(dynamic_job_spec._nodes) == 5 diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 304c51f888..fad4d9db4c 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -1,4 +1,4 @@ -from flytekit.core.context_manager import CompilationState, FlyteContext, look_up_image_info +from flytekit.core.context_manager import CompilationState, ExecutionState, FlyteContext, look_up_image_info class SampleTestClass(object): @@ -42,3 +42,14 @@ def test_look_up_image_info(): assert img.name == "x" assert img.tag == "latest" assert img.fqn == "localhost:5000/xyz" + + +def test_additional_context(): + with FlyteContext.current_context() as ctx: + with ctx.new_execution_context( + mode=ExecutionState.Mode.TASK_EXECUTION, additional_context={1: "outer", 2: "foo"} + ) as exec_ctx_outer: + with 
exec_ctx_outer.new_execution_context( + mode=ExecutionState.Mode.TASK_EXECUTION, additional_context={1: "inner", 3: "baz"} + ) as exec_ctx_inner: + assert exec_ctx_inner.execution_state.additional_context == {1: "inner", 2: "foo", 3: "baz"} diff --git a/tests/flytekit/unit/core/test_dynamic_conditional.py b/tests/flytekit/unit/core/test_dynamic_conditional.py index 5d1c956ce6..d9bc83b5d1 100644 --- a/tests/flytekit/unit/core/test_dynamic_conditional.py +++ b/tests/flytekit/unit/core/test_dynamic_conditional.py @@ -88,6 +88,6 @@ def merge_sort(in1: typing.List[int], count: int) -> typing.List[int]: ) as ctx: with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx: dynamic_job_spec = merge_sort_remotely.compile_into_workflow( - ctx, merge_sort_remotely._task_function, in1=[2, 3, 4, 5] + ctx, False, merge_sort_remotely._task_function, in1=[2, 3, 4, 5] ) assert len(dynamic_job_spec.tasks) == 5 diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 4590bfb1b8..dab95a477c 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -412,10 +412,53 @@ def my_wf(a: int, b: str) -> (str, typing.List[str]): ) ) as ctx: with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx: - dynamic_job_spec = my_subwf.compile_into_workflow(ctx, my_subwf._task_function, a=5) + dynamic_job_spec = my_subwf.compile_into_workflow(ctx, False, my_subwf._task_function, a=5) assert len(dynamic_job_spec._nodes) == 5 +def test_wf1_with_fast_dynamic(): + @task + def t1(a: int) -> str: + a = a + 2 + return "fast-" + str(a) + + @dynamic + def my_subwf(a: int) -> typing.List[str]: + s = [] + for i in range(a): + s.append(t1(a=i)) + return s + + @workflow + def my_wf(a: int) -> typing.List[str]: + v = my_subwf(a=a) + return v + + with context_manager.FlyteContext.current_context().new_serialization_settings( + serialization_settings=context_manager.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + ) as ctx: + with ctx.new_execution_context( + mode=ExecutionState.Mode.TASK_EXECUTION, + additional_context={ + "dynamic_addl_distro": "s3::/my-s3-bucket/fast/123", + "dynamic_dest_dir": "/User/flyte/workflows", + }, + ) as ctx: + dynamic_job_spec = my_subwf.compile_into_workflow(ctx, True, my_subwf._task_function, a=5) + assert len(dynamic_job_spec._nodes) == 5 + assert len(dynamic_job_spec.tasks) == 1 + args = " ".join(dynamic_job_spec.tasks[0].container.args) + assert args.startswith( + "pyflyte-fast-execute --additional-distribution s3::/my-s3-bucket/fast/123 --dest-dir /User/flyte/workflows" + ) + + def test_list_output(): @task def t1(a: int) -> str: From 68fa2ca06049aed5dca0049b3c96c1b29c0d2279 Mon Sep 17 00:00:00 2001 From: Jeev B Date: Tue, 13 Apr 2021 10:09:38 -0700 Subject: [PATCH 13/33] Do not omit variables without defaults from parameter map when constructing launch plan (#446) Signed-off-by: Jeev B Signed-off-by: Max Hoffman --- flytekit/core/launch_plan.py | 4 +--- tests/flytekit/unit/core/test_composition.py | 22 +++++++++++++++----- tests/flytekit/unit/core/test_type_hints.py | 8 ++++++- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 0da5e6c79d..6d746052ca 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -174,9 +174,7 @@ def __init__( 
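
A minimal sketch of the behavioral change in the hunk below, using plain dicts in place of the real ParameterMap (values are made up; None stands in for a parameter without a default):

    params = {"a": None, "b": 42}  # "a" is a required input (no default), "b" has a default
    fixed = {"c": 7}               # fixed inputs are supplied at launch plan creation time

    # Old behavior: required inputs were dropped along with fixed ones.
    old = {k: v for k, v in params.items() if k not in fixed and v is not None}
    assert old == {"b": 42}  # "a" silently disappeared from the launch plan interface

    # New behavior: only fixed inputs are filtered; required inputs are kept.
    new = {k: v for k, v in params.items() if k not in fixed}
    assert new == {"a": None, "b": 42}
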
self._name = name self._workflow = workflow # Ensure fixed inputs are not in parameter map - parameters = { - k: v for k, v in parameters.parameters.items() if k not in fixed_inputs.literals and v.default is not None - } + parameters = {k: v for k, v in parameters.parameters.items() if k not in fixed_inputs.literals} self._parameters = _interface_models.ParameterMap(parameters=parameters) self._fixed_inputs = fixed_inputs # See create() for additional information diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index f1e8a37852..6a50531f00 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -3,6 +3,7 @@ from flytekit.core import launch_plan from flytekit.core.task import task from flytekit.core.workflow import workflow +from flytekit.models import literals as _literal_models def test_wf1_with_subwf(): @@ -76,16 +77,27 @@ def my_wf(a: int, b: int) -> (str, str, int, int): return y, v, x, u lp = launch_plan.LaunchPlan.create("test1", my_wf) - assert len(lp.parameters.parameters) == 0 + assert len(lp.parameters.parameters) == 2 + assert lp.parameters.parameters["a"].required + assert lp.parameters.parameters["a"].default is None + assert lp.parameters.parameters["b"].required + assert lp.parameters.parameters["b"].default is None assert len(lp.fixed_inputs.literals) == 0 lp_with_defaults = launch_plan.LaunchPlan.create("test2", my_wf, default_inputs={"a": 3}) - assert len(lp_with_defaults.parameters.parameters) == 1 + assert len(lp_with_defaults.parameters.parameters) == 2 + assert not lp_with_defaults.parameters.parameters["a"].required + assert lp_with_defaults.parameters.parameters["a"].default == _literal_models.Literal( + scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3)) + ) assert len(lp_with_defaults.fixed_inputs.literals) == 0 lp_with_fixed = launch_plan.LaunchPlan.create("test3", my_wf, fixed_inputs={"a": 3}) - assert len(lp_with_fixed.parameters.parameters) == 0 + assert len(lp_with_fixed.parameters.parameters) == 1 assert len(lp_with_fixed.fixed_inputs.literals) == 1 + assert lp_with_fixed.fixed_inputs.literals["a"] == _literal_models.Literal( + scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3)) + ) @workflow def my_wf2(a: int, b: int = 42) -> (str, str, int, int): @@ -94,7 +106,7 @@ def my_wf2(a: int, b: int = 42) -> (str, str, int, int): return y, v, x, u lp = launch_plan.LaunchPlan.create("test4", my_wf2) - assert len(lp.parameters.parameters) == 1 + assert len(lp.parameters.parameters) == 2 assert len(lp.fixed_inputs.literals) == 0 lp_with_defaults = launch_plan.LaunchPlan.create("test5", my_wf2, default_inputs={"a": 3}) @@ -110,7 +122,7 @@ def my_wf2(a: int, b: int = 42) -> (str, str, int, int): assert lp_with_fixed(b=3) == ("world-5", "world-5", 5, 5) lp_with_fixed = launch_plan.LaunchPlan.create("test7", my_wf2, fixed_inputs={"b": 3}) - assert len(lp_with_fixed.parameters.parameters) == 0 + assert len(lp_with_fixed.parameters.parameters) == 1 assert len(lp_with_fixed.fixed_inputs.literals) == 1 diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index dab95a477c..9484c758af 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -23,6 +23,7 @@ from flytekit.core.type_engine import RestrictedTypeError, TypeEngine from flytekit.core.workflow import workflow from flytekit.interfaces.data.data_proxy import 
FileAccessProvider +from flytekit.models import literals as _literal_models from flytekit.models.core import types as _core_types from flytekit.models.interface import Parameter from flytekit.models.task import Resources as _resource_models @@ -684,11 +685,16 @@ def my_subwf(a: int) -> (str, str): env={}, ) sdk_lp = get_serializable(OrderedDict(), serialization_settings, lp) - assert len(sdk_lp.default_inputs.parameters) == 0 + assert len(sdk_lp.default_inputs.parameters) == 1 + assert sdk_lp.default_inputs.parameters["a"].required assert len(sdk_lp.fixed_inputs.literals) == 0 sdk_lp = get_serializable(OrderedDict(), serialization_settings, lp_with_defaults) assert len(sdk_lp.default_inputs.parameters) == 1 + assert not sdk_lp.default_inputs.parameters["a"].required + assert sdk_lp.default_inputs.parameters["a"].default == _literal_models.Literal( + scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3)) + ) assert len(sdk_lp.fixed_inputs.literals) == 0 # Adding a check to make sure oneof is respected. Tricky with booleans... if a default is specified, the From fa063ff9e08f8cde1ab1a3442ee88bbdc07ac01f Mon Sep 17 00:00:00 2001 From: Maximilian Hoffman Date: Thu, 22 Apr 2021 13:13:48 -0700 Subject: [PATCH 14/33] Sqlalchemy Task (#445) Signed-off-by: Max Hoffman --- dev-requirements.txt | 8 +- doc-requirements.txt | 46 +++++----- plugins/setup.py | 1 + .../flytekitplugins/sqlalchemy/__init__.py | 1 + .../flytekitplugins/sqlalchemy/task.py | 85 +++++++++++++++++++ plugins/sqlalchemy/setup.py | 34 ++++++++ plugins/tests/sqlalchemy/__init__.py | 0 plugins/tests/sqlalchemy/test_sql_tracker.py | 31 +++++++ plugins/tests/sqlalchemy/test_task.py | 85 +++++++++++++++++++ requirements-spark2.txt | 58 ++++++------- requirements.txt | 38 ++++----- 11 files changed, 310 insertions(+), 77 deletions(-) create mode 100644 plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py create mode 100644 plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py create mode 100644 plugins/sqlalchemy/setup.py create mode 100644 plugins/tests/sqlalchemy/__init__.py create mode 100644 plugins/tests/sqlalchemy/test_sql_tracker.py create mode 100644 plugins/tests/sqlalchemy/test_task.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 5acc7855b8..873831d1c2 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -27,7 +27,7 @@ flake8-black==0.2.1 # via -r dev-requirements.in flake8-isort==4.0.0 # via -r dev-requirements.in -flake8==3.9.0 +flake8==3.9.1 # via # -r dev-requirements.in # flake8-black @@ -71,9 +71,9 @@ pyparsing==2.4.7 # via # -c requirements.txt # packaging -pytest==6.2.2 +pytest==6.2.3 # via -r dev-requirements.in -regex==2021.3.17 +regex==2021.4.4 # via # -c requirements.txt # black @@ -85,7 +85,7 @@ toml==0.10.2 # black # coverage # pytest -typed-ast==1.4.2 +typed-ast==1.4.3 # via # -c requirements.txt # black diff --git a/doc-requirements.txt b/doc-requirements.txt index 84195f76de..07231a1bd6 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -16,7 +16,7 @@ appnope==0.1.2 # via # ipykernel # ipython -astroid==2.5.2 +astroid==2.5.3 # via sphinx-autoapi async-generator==1.10 # via nbclient @@ -39,9 +39,9 @@ black==20.8b1 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.40 +boto3==1.17.55 # via sagemaker-training -botocore==1.20.40 +botocore==1.20.55 # via # boto3 # s3transfer @@ -60,7 +60,7 @@ click==7.1.2 # flytekit # hmsclient # papermill -croniter==1.0.10 +croniter==1.0.12 # via flytekit cryptography==3.4.7 # via @@ -70,7 +70,7 @@ 
css-html-js-minify==2.5.5 # via sphinx-material dataclasses-json==0.5.2 # via flytekit -decorator==4.4.2 +decorator==5.0.7 # via # ipython # retry @@ -88,15 +88,15 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.26 +flyteidl==0.18.37 # via flytekit -furo==2021.3.20b30 +furo==2021.4.11b34 # via -r doc-requirements.in gevent==21.1.2 # via sagemaker-training greenlet==1.0.0 # via gevent -grpcio==1.36.1 +grpcio==1.37.0 # via # -r doc-requirements.in # flytekit @@ -106,11 +106,11 @@ idna==2.10 # via requests imagesize==1.2.0 # via sphinx -importlib-metadata==3.9.1 +importlib-metadata==4.0.1 # via keyring inotify_simple==1.2.1 # via sagemaker-training -ipykernel==5.5.0 +ipykernel==5.5.3 # via flytekit ipython-genutils==0.2.0 # via @@ -154,7 +154,7 @@ markupsafe==1.1.1 # via jinja2 marshmallow-enum==1.5.1 # via dataclasses-json -marshmallow==3.11.0 +marshmallow==3.11.1 # via # dataclasses-json # marshmallow-enum @@ -172,7 +172,7 @@ nbclient==0.5.3 # papermill nbconvert==6.0.7 # via flytekit -nbformat==5.1.2 +nbformat==5.1.3 # via # nbclient # nbconvert @@ -190,7 +190,7 @@ packaging==20.9 # via # bleach # sphinx -pandas==1.2.3 +pandas==1.2.4 # via flytekit pandocfilters==1.4.3 # via nbconvert @@ -198,7 +198,7 @@ papermill==2.3.3 # via flytekit paramiko==2.7.2 # via sagemaker-training -parso==0.8.1 +parso==0.8.2 # via jedi pathspec==0.8.1 # via @@ -210,7 +210,7 @@ pickleshare==0.7.5 # via ipython prompt-toolkit==3.0.18 # via ipython -protobuf==3.15.6 +protobuf==3.15.8 # via # flyteidl # flytekit @@ -267,7 +267,7 @@ pyzmq==22.0.3 # via jupyter-client readthedocs-sphinx-search==0.1.0 # via -r doc-requirements.in -regex==2021.3.17 +regex==2021.4.4 # via # black # docker-image-py @@ -283,9 +283,9 @@ retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.3.6 +s3transfer==0.4.1 # via boto3 -sagemaker-training==3.7.4 +sagemaker-training==3.9.1 # via flytekit scantree==0.0.1 # via dirhash @@ -314,7 +314,7 @@ sortedcontainers==2.3.0 # via flytekit soupsieve==2.2.1 # via beautifulsoup4 -sphinx-autoapi==1.7.0 +sphinx-autoapi==1.8.0 # via -r doc-requirements.in sphinx-code-include==1.1.1 # via -r doc-requirements.in @@ -326,7 +326,7 @@ sphinx-material==0.0.32 # via -r doc-requirements.in sphinx-prompt==1.4.0 # via -r doc-requirements.in -sphinx==3.5.3 +sphinx==3.5.4 # via # -r doc-requirements.in # furo @@ -368,7 +368,7 @@ tornado==6.1 # via # ipykernel # jupyter-client -tqdm==4.59.0 +tqdm==4.60.0 # via papermill traitlets==5.0.5 # via @@ -379,7 +379,7 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.2 +typed-ast==1.4.3 # via black typing-extensions==3.7.4.3 # via @@ -414,7 +414,7 @@ zipp==3.4.1 # via importlib-metadata zope.event==4.5.0 # via gevent -zope.interface==5.3.0 +zope.interface==5.4.0 # via gevent # The following packages are considered to be unsafe in a requirements file: diff --git a/plugins/setup.py b/plugins/setup.py index e03e135be2..abd2d8c7c8 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -15,6 +15,7 @@ "flytekitplugins-awssagemaker": "awssagemaker", "flytekitplugins-kftensorflow": "kftensorflow", "flytekitplugins-pandera": "pandera", + "flytekitplugins-sqlalchemy": "sqlalchemy", } diff --git a/plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py b/plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py new file mode 100644 index 0000000000..aaf8ade06f --- /dev/null +++ b/plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py @@ -0,0 +1 @@ +from .task import SQLAlchemyConfig, SQLAlchemyTask diff --git 
a/plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py
new file mode 100644
index 0000000000..38f62cb8b7
--- /dev/null
+++ b/plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py
@@ -0,0 +1,85 @@
+import typing
+from dataclasses import dataclass
+
+import pandas as pd
+from sqlalchemy import create_engine
+
+from flytekit import current_context, kwtypes
+from flytekit.core.base_sql_task import SQLTask
+from flytekit.core.python_function_task import PythonInstanceTask
+from flytekit.models.security import Secret
+from flytekit.types.schema import FlyteSchema
+
+
+@dataclass
+class SQLAlchemyConfig(object):
+    """
+    Use this configuration to configure the task. The uri string should follow the standard
+    sqlalchemy connector format
+    (https://docs.sqlalchemy.org/en/14/core/engines.html#database-urls).
+    The database can live:
+    - within the container
+    - or at a publicly accessible location
+
+    Args:
+        uri: default sqlalchemy connector
+        connect_args: sqlalchemy kwarg overrides -- ex: host
+        secret_connect_args: flyte secrets loaded into sqlalchemy connect args
+            -- ex: {"password": {"name": SECRET_NAME, "group": SECRET_GROUP}}
+    """
+
+    uri: str
+    connect_args: typing.Optional[typing.Dict[str, typing.Any]] = None
+    secret_connect_args: typing.Optional[typing.Dict[str, Secret]] = None
+
+
+class SQLAlchemyTask(PythonInstanceTask[SQLAlchemyConfig], SQLTask[SQLAlchemyConfig]):
+    """
+    Makes it possible to run client-side SQLAlchemy queries that optionally return a FlyteSchema object.
+
+    TODO: How should we use pre-built containers for running portable tasks like this? Should this always be a
+        referenced task type?
+    """
+
+    _SQLALCHEMY_TASK_TYPE = "sqlalchemy"
+
+    def __init__(
+        self,
+        name: str,
+        query_template: str,
+        task_config: SQLAlchemyConfig,
+        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
+        output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None,
+        **kwargs,
+    ):
+        output_schema = output_schema_type if output_schema_type else FlyteSchema
+        outputs = kwtypes(results=output_schema)
+        self._uri = task_config.uri
+        self._connect_args = task_config.connect_args or {}
+        self._secret_connect_args = task_config.secret_connect_args
+
+        super().__init__(
+            name=name,
+            task_config=task_config,
+            task_type=self._SQLALCHEMY_TASK_TYPE,
+            query_template=query_template,
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    @property
+    def output_columns(self) -> typing.Optional[typing.List[str]]:
+        c = self.python_interface.outputs["results"].column_names()
+        return c if c else None
+
+    def execute(self, **kwargs) -> typing.Any:
+        if self._secret_connect_args is not None:
+            for key, secret in self._secret_connect_args.items():
+                value = current_context().secrets.get(secret.group, secret.key)
+                self._connect_args[key] = value
+        engine = create_engine(self._uri, connect_args=self._connect_args, echo=False)
+        print(f"Connecting to db {self._uri}")
+        with engine.begin() as connection:
+            df = pd.read_sql_query(self.get_query(**kwargs), connection)
+        return df
diff --git a/plugins/sqlalchemy/setup.py b/plugins/sqlalchemy/setup.py
new file mode 100644
index 0000000000..e39b139552
--- /dev/null
+++ b/plugins/sqlalchemy/setup.py
@@ -0,0 +1,34 @@
+from setuptools import setup
+
+PLUGIN_NAME = "sqlalchemy"
+
+microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
+
+plugin_requires = ["flytekit>=0.17.0,<1.0.0", "sqlalchemy>=1.4.7"]
+
+__version__ = "0.0.0+develop"
+
+setup(
+    name=microlib_name,
+    version=__version__,
+    author="dolthub",
+    
author_email="max@dolthub.com", + description="SQLAlchemy plugin for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/tests/sqlalchemy/__init__.py b/plugins/tests/sqlalchemy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/tests/sqlalchemy/test_sql_tracker.py b/plugins/tests/sqlalchemy/test_sql_tracker.py new file mode 100644 index 0000000000..aad37cd254 --- /dev/null +++ b/plugins/tests/sqlalchemy/test_sql_tracker.py @@ -0,0 +1,31 @@ +from collections import OrderedDict + +from flytekit.common.translator import get_serializable +from flytekit.core import context_manager +from flytekit.core.context_manager import Image, ImageConfig +from plugins.tests.sqlalchemy.test_task import tk as not_tk + + +def test_sql_lhs(): + assert not_tk.lhs == "tk" + + +def test_sql_command(): + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = context_manager.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + srz_t = get_serializable(OrderedDict(), serialization_settings, not_tk) + assert srz_t.container.args[-7:] == [ + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "plugins.tests.sqlalchemy.test_task", + "task-name", + "tk", + ] diff --git a/plugins/tests/sqlalchemy/test_task.py b/plugins/tests/sqlalchemy/test_task.py new file mode 100644 index 0000000000..6adeca7519 --- /dev/null +++ b/plugins/tests/sqlalchemy/test_task.py @@ -0,0 +1,85 @@ +import contextlib +import os +import shutil +import sqlite3 +import tempfile + +import pandas +import pytest +from flytekitplugins.sqlalchemy import SQLAlchemyConfig, SQLAlchemyTask + +from flytekit import kwtypes, task, workflow +from flytekit.types.schema import FlyteSchema + +tk = SQLAlchemyTask( + "test", + query_template="select * from tracks", + task_config=SQLAlchemyConfig( + uri="sqlite://", + ), +) + + +@pytest.fixture(scope="function") +def sql_server(): + try: + d = tempfile.TemporaryDirectory() + db_path = os.path.join(d.name, "tracks.db") + with contextlib.closing(sqlite3.connect(db_path)) as con: + con.execute("create table tracks (TrackId bigint, Name text)") + con.execute("insert into tracks values (0, 'Sue'), (1, 'L'), (2, 'M'), (3, 'Ji'), (4, 'Po')") + con.commit() + yield f"sqlite:///{db_path}" + finally: + if os.path.exists(d.name): + shutil.rmtree(d.name) + + +def test_task_static(sql_server): + tk = SQLAlchemyTask( + "test", + query_template="select * from tracks", + task_config=SQLAlchemyConfig( + uri=sql_server, + ), + ) + + assert tk.output_columns is None + + df = tk() + assert df is not None + + +def test_task_schema(sql_server): + sql_task = SQLAlchemyTask( + "test", + query_template="select TrackId, Name from tracks limit {{.inputs.limit}}", + 
inputs=kwtypes(limit=int), + output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)], + task_config=SQLAlchemyConfig( + uri=sql_server, + ), + ) + + assert sql_task.output_columns is not None + df = sql_task(limit=1) + assert df is not None + + +def test_workflow(sql_server): + @task + def my_task(df: pandas.DataFrame) -> int: + return len(df[df.columns[0]]) + + sql_task = SQLAlchemyTask( + "test", + query_template="select * from tracks limit {{.inputs.limit}}", + inputs=kwtypes(limit=int), + task_config=SQLAlchemyConfig(uri=sql_server), + ) + + @workflow + def wf(limit: int) -> int: + return my_task(df=sql_task(limit=limit)) + + assert wf(limit=5) == 5 diff --git a/requirements-spark2.txt b/requirements-spark2.txt index c58e52ca25..f4f4e6115c 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -10,6 +10,10 @@ ansiwrap==0.8.4 # via papermill appdirs==1.4.4 # via black +appnope==0.1.2 + # via + # ipykernel + # ipython async-generator==1.10 # via nbclient attrs==20.3.0 @@ -24,9 +28,9 @@ black==20.8b1 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.39 +boto3==1.17.55 # via sagemaker-training -botocore==1.20.39 +botocore==1.20.55 # via # boto3 # s3transfer @@ -45,15 +49,13 @@ click==7.1.2 # flytekit # hmsclient # papermill -croniter==1.0.10 +croniter==1.0.12 # via flytekit cryptography==3.4.7 - # via - # paramiko - # secretstorage + # via paramiko dataclasses-json==0.5.2 # via flytekit -decorator==4.4.2 +decorator==5.0.7 # via # ipython # retry @@ -69,36 +71,32 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.26 +flyteidl==0.18.37 # via flytekit gevent==21.1.2 # via sagemaker-training greenlet==1.0.0 # via gevent -grpcio==1.36.1 +grpcio==1.37.0 # via flytekit hmsclient==0.1.1 # via flytekit idna==2.10 # via requests -importlib-metadata==3.7.3 +importlib-metadata==4.0.1 # via keyring inotify_simple==1.2.1 # via sagemaker-training -ipykernel==5.5.0 +ipykernel==5.5.3 # via flytekit ipython-genutils==0.2.0 # via # nbformat # traitlets -ipython==7.21.0 +ipython==7.22.0 # via ipykernel jedi==0.18.0 # via ipython -jeepney==0.6.0 - # via - # keyring - # secretstorage jinja2==2.11.3 # via nbconvert jmespath==0.10.0 @@ -126,7 +124,7 @@ markupsafe==1.1.1 # via jinja2 marshmallow-enum==1.5.1 # via dataclasses-json -marshmallow==3.10.0 +marshmallow==3.11.1 # via # dataclasses-json # marshmallow-enum @@ -144,14 +142,14 @@ nbclient==0.5.3 # papermill nbconvert==6.0.7 # via flytekit -nbformat==5.1.2 +nbformat==5.1.3 # via # nbclient # nbconvert # papermill nest-asyncio==1.5.1 # via nbclient -numpy==1.20.1 +numpy==1.20.2 # via # flytekit # pandas @@ -160,7 +158,7 @@ numpy==1.20.1 # scipy packaging==20.9 # via bleach -pandas==1.2.3 +pandas==1.2.4 # via flytekit pandocfilters==1.4.3 # via nbconvert @@ -168,7 +166,7 @@ papermill==2.3.3 # via flytekit paramiko==2.7.2 # via sagemaker-training -parso==0.8.1 +parso==0.8.2 # via jedi pathspec==0.8.1 # via @@ -180,7 +178,7 @@ pickleshare==0.7.5 # via ipython prompt-toolkit==3.0.18 # via ipython -protobuf==3.15.6 +protobuf==3.15.8 # via # flyteidl # flytekit @@ -228,7 +226,7 @@ pyyaml==5.4.1 # via papermill pyzmq==22.0.3 # via jupyter-client -regex==2021.3.17 +regex==2021.4.4 # via # black # docker-image-py @@ -237,22 +235,20 @@ requests==2.25.1 # flytekit # papermill # responses -responses==0.13.1 +responses==0.13.2 # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.3.6 +s3transfer==0.4.1 # via boto3 -sagemaker-training==3.7.3 +sagemaker-training==3.9.1 # via flytekit 
scantree==0.0.1 # via dirhash scipy==1.6.2 # via sagemaker-training -secretstorage==3.3.1 - # via keyring six==1.15.0 # via # bcrypt @@ -289,7 +285,7 @@ tornado==6.1 # via # ipykernel # jupyter-client -tqdm==4.59.0 +tqdm==4.60.0 # via papermill traitlets==5.0.5 # via @@ -300,7 +296,7 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.2 +typed-ast==1.4.3 # via black typing-extensions==3.7.4.3 # via @@ -330,7 +326,7 @@ zipp==3.4.1 # via importlib-metadata zope.event==4.5.0 # via gevent -zope.interface==5.3.0 +zope.interface==5.4.0 # via gevent # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements.txt b/requirements.txt index 0473ae8da4..2595fcfc44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,9 +28,9 @@ black==20.8b1 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.40 +boto3==1.17.55 # via sagemaker-training -botocore==1.20.40 +botocore==1.20.55 # via # boto3 # s3transfer @@ -49,13 +49,13 @@ click==7.1.2 # flytekit # hmsclient # papermill -croniter==1.0.10 +croniter==1.0.12 # via flytekit cryptography==3.4.7 # via paramiko dataclasses-json==0.5.2 # via flytekit -decorator==4.4.2 +decorator==5.0.7 # via # ipython # retry @@ -71,23 +71,23 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.26 +flyteidl==0.18.37 # via flytekit gevent==21.1.2 # via sagemaker-training greenlet==1.0.0 # via gevent -grpcio==1.36.1 +grpcio==1.37.0 # via flytekit hmsclient==0.1.1 # via flytekit idna==2.10 # via requests -importlib-metadata==3.9.1 +importlib-metadata==4.0.1 # via keyring inotify_simple==1.2.1 # via sagemaker-training -ipykernel==5.5.0 +ipykernel==5.5.3 # via flytekit ipython-genutils==0.2.0 # via @@ -124,7 +124,7 @@ markupsafe==1.1.1 # via jinja2 marshmallow-enum==1.5.1 # via dataclasses-json -marshmallow==3.11.0 +marshmallow==3.11.1 # via # dataclasses-json # marshmallow-enum @@ -142,7 +142,7 @@ nbclient==0.5.3 # papermill nbconvert==6.0.7 # via flytekit -nbformat==5.1.2 +nbformat==5.1.3 # via # nbclient # nbconvert @@ -158,7 +158,7 @@ numpy==1.20.2 # scipy packaging==20.9 # via bleach -pandas==1.2.3 +pandas==1.2.4 # via flytekit pandocfilters==1.4.3 # via nbconvert @@ -166,7 +166,7 @@ papermill==2.3.3 # via flytekit paramiko==2.7.2 # via sagemaker-training -parso==0.8.1 +parso==0.8.2 # via jedi pathspec==0.8.1 # via @@ -178,7 +178,7 @@ pickleshare==0.7.5 # via ipython prompt-toolkit==3.0.18 # via ipython -protobuf==3.15.6 +protobuf==3.15.8 # via # flyteidl # flytekit @@ -226,7 +226,7 @@ pyyaml==5.4.1 # via papermill pyzmq==22.0.3 # via jupyter-client -regex==2021.3.17 +regex==2021.4.4 # via # black # docker-image-py @@ -241,9 +241,9 @@ retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.3.6 +s3transfer==0.4.1 # via boto3 -sagemaker-training==3.7.4 +sagemaker-training==3.9.1 # via flytekit scantree==0.0.1 # via dirhash @@ -285,7 +285,7 @@ tornado==6.1 # via # ipykernel # jupyter-client -tqdm==4.59.0 +tqdm==4.60.0 # via papermill traitlets==5.0.5 # via @@ -296,7 +296,7 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.2 +typed-ast==1.4.3 # via black typing-extensions==3.7.4.3 # via @@ -326,7 +326,7 @@ zipp==3.4.1 # via importlib-metadata zope.event==4.5.0 # via gevent -zope.interface==5.3.0 +zope.interface==5.4.0 # via gevent # The following packages are considered to be unsafe in a requirements file: From 86af4cf8f31a2f032219e05c8d7b34fac7130338 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 26 Apr 2021 09:13:26 -0700 Subject: [PATCH 
15/33] Fix 0.0 value floats (#452) Signed-off-by: wild-endeavor Signed-off-by: Max Hoffman --- flytekit/core/type_engine.py | 4 ++-- tests/flytekit/unit/core/test_type_engine.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index cd0b6a3f90..0c7f7e12cb 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -530,9 +530,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: def _check_and_covert_float(lv: Literal) -> float: - if lv.scalar.primitive.float_value: + if lv.scalar.primitive.float_value is not None: return lv.scalar.primitive.float_value - elif lv.scalar.primitive.integer: + elif lv.scalar.primitive.integer is not None: return float(lv.scalar.primitive.integer) raise RuntimeError(f"Cannot convert literal {lv} to float") diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index af82e8786a..04a9e57160 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -197,3 +197,13 @@ def test_protos(): l0 = Literal(scalar=Scalar(primitive=Primitive(integer=4))) with pytest.raises(AssertionError): TypeEngine.to_python_value(ctx, l0, errors_pb2.ContainerError) + + +def test_zero_floats(): + ctx = FlyteContext.current_context() + + l0 = Literal(scalar=Scalar(primitive=Primitive(integer=0))) + l1 = Literal(scalar=Scalar(primitive=Primitive(float_value=0.0))) + + assert TypeEngine.to_python_value(ctx, l0, float) == 0 + assert TypeEngine.to_python_value(ctx, l1, float) == 0 From dd0767bb6f02a253163834724b8d9494c9c3a4b8 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 27 Apr 2021 10:05:19 -0700 Subject: [PATCH 16/33] Unit test for dynamic create node (#457) Signed-off-by: Max Hoffman --- .../flytekit/unit/core/test_node_creation.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index ae3aa7db4e..1ad74a56bf 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -7,6 +7,7 @@ from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.context_manager import Image, ImageConfig +from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.node_creation import create_node from flytekit.core.task import task from flytekit.core.workflow import workflow @@ -17,13 +18,22 @@ def test_normal_task(): def t1(a: str) -> str: return a + " world" + @dynamic + def my_subwf(a: int) -> typing.List[str]: + s = [] + for i in range(a): + s.append(t1(a=str(i))) + return s + @workflow - def my_wf(a: str) -> str: + def my_wf(a: str) -> (str, typing.List[str]): t1_node = create_node(t1, a=a) - return t1_node.o0 + dyn_node = create_node(my_subwf, a=3) + return t1_node.o0, dyn_node.o0 - r = my_wf(a="hello") + r, x = my_wf(a="hello") assert r == "hello world" + assert x == ["0 world", "1 world", "2 world"] serialization_settings = context_manager.SerializationSettings( project="test_proj", @@ -33,8 +43,8 @@ def my_wf(a: str) -> str: env={}, ) sdk_wf = get_serializable(OrderedDict(), serialization_settings, my_wf) - assert len(sdk_wf.nodes) == 1 - assert len(sdk_wf.outputs) == 1 + assert len(sdk_wf.nodes) == 2 + assert len(sdk_wf.outputs) == 2 @task def t2(): From a30e4c7a66a107215e88f2e39bbd87a280e4ef44 Mon 
Sep 17 00:00:00 2001 From: Niels Bantilan Date: Wed, 28 Apr 2021 13:52:18 -0400 Subject: [PATCH 17/33] add new control plane classes (#425) * implement new control plane classes Signed-off-by: cosmicBboy * revert dep changes Signed-off-by: cosmicBboy * remove unneeded mock integration test files Signed-off-by: cosmicBboy * remove pytest.ini Signed-off-by: cosmicBboy * add integration tests to ci, update reqs Signed-off-by: cosmicBboy * add unit tests Signed-off-by: cosmicBboy * lint Signed-off-by: cosmicBboy * address comments @wild-endeavor Signed-off-by: cosmicBboy Signed-off-by: Max Hoffman --- .github/workflows/pythonbuild.yml | 5 + .gitignore | 2 + dev-requirements.in | 1 + dev-requirements.txt | 264 +++++++++++++++- doc-requirements.txt | 11 +- docs/source/design/control_plane.rst | 4 +- flytekit/control_plane/component_nodes.py | 136 +++++++++ flytekit/control_plane/identifier.py | 137 +++++++++ flytekit/control_plane/interface.py | 24 ++ flytekit/control_plane/launch_plan.py | 196 ++++++++++++ flytekit/control_plane/nodes.py | 281 ++++++++++++++++++ flytekit/control_plane/tasks/__init__.py | 0 flytekit/control_plane/tasks/executions.py | 132 ++++++++ flytekit/control_plane/tasks/task.py | 95 ++++++ flytekit/control_plane/workflow.py | 167 +++++++++++ flytekit/control_plane/workflow_execution.py | 150 ++++++++++ requirements-spark2.txt | 7 +- requirements.txt | 7 +- .../control_plane/mock_flyte_repo/.gitignore | 1 + .../control_plane/mock_flyte_repo/README.md | 4 + .../control_plane/mock_flyte_repo/__init__.py | 0 .../mock_flyte_repo/in_container.mk | 24 ++ .../mock_flyte_repo/workflows/Dockerfile | 35 +++ .../mock_flyte_repo/workflows/Makefile | 208 +++++++++++++ .../mock_flyte_repo/workflows/__init__.py | 0 .../workflows/basic/__init__.py | 0 .../workflows/basic/basic_workflow.py | 54 ++++ .../workflows/basic/hello_world.py | 40 +++ .../mock_flyte_repo/workflows/requirements.in | 4 + .../workflows/requirements.txt | 136 +++++++++ .../mock_flyte_repo/workflows/sandbox.config | 7 + .../control_plane/test_workflow.py | 90 ++++++ tests/flytekit/unit/control_plane/__init__.py | 0 .../unit/control_plane/tasks/test_task.py | 34 +++ .../unit/control_plane/test_identifier.py | 77 +++++ .../unit/control_plane/test_workflow.py | 23 ++ 36 files changed, 2346 insertions(+), 10 deletions(-) create mode 100644 flytekit/control_plane/component_nodes.py create mode 100644 flytekit/control_plane/identifier.py create mode 100644 flytekit/control_plane/interface.py create mode 100644 flytekit/control_plane/launch_plan.py create mode 100644 flytekit/control_plane/nodes.py create mode 100644 flytekit/control_plane/tasks/__init__.py create mode 100644 flytekit/control_plane/tasks/executions.py create mode 100644 flytekit/control_plane/tasks/task.py create mode 100644 flytekit/control_plane/workflow.py create mode 100644 flytekit/control_plane/workflow_execution.py create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/.gitignore create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/README.md create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/__init__.py create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/in_container.mk create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Dockerfile create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Makefile create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/__init__.py create 
mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/__init__.py create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/basic_workflow.py create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/hello_world.py create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.in create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.txt create mode 100644 tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/sandbox.config create mode 100644 tests/flytekit/integration/control_plane/test_workflow.py create mode 100644 tests/flytekit/unit/control_plane/__init__.py create mode 100644 tests/flytekit/unit/control_plane/tasks/test_task.py create mode 100644 tests/flytekit/unit/control_plane/test_identifier.py create mode 100644 tests/flytekit/unit/control_plane/test_workflow.py diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index cc0d7f0a52..f56f8b300f 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -44,6 +44,11 @@ jobs: - name: Test with coverage run: | coverage run -m pytest tests/flytekit/unit tests/scripts plugins/tests + - name: Integration Tests with coverage + # https://github.com/actions/runner/issues/241#issuecomment-577360161 + shell: 'script -q -e -c "bash {0}"' + run: | + coverage run --append -m pytest tests/flytekit/integration - name: Codecov uses: codecov/codecov-action@v1 with: diff --git a/.gitignore b/.gitignore index 43cf5e7135..ca971b1a31 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,5 @@ dist .python-version _build/ docs/source/generated/ +.pytest-flyte +htmlcov diff --git a/dev-requirements.in b/dev-requirements.in index 6fa5ebb080..dd38d0f7fc 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -1,5 +1,6 @@ -c requirements.txt +git+git://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte black coverage[toml] flake8 diff --git a/dev-requirements.txt b/dev-requirements.txt index 873831d1c2..7462d342cb 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -4,6 +4,10 @@ # # make dev-requirements.txt # +-e file:.#egg=flytekit + # via + # -c requirements.txt + # pytest-flyte appdirs==1.4.4 # via # -c requirements.txt @@ -11,18 +15,82 @@ appdirs==1.4.4 attrs==20.3.0 # via # -c requirements.txt + # jsonschema # pytest + # pytest-docker + # scantree +bcrypt==3.2.0 + # via + # -c requirements.txt + # paramiko black==20.8b1 # via # -c requirements.txt # -r dev-requirements.in # flake8-black +cached-property==1.5.2 + # via docker-compose +certifi==2020.12.5 + # via + # -c requirements.txt + # requests +cffi==1.14.5 + # via + # -c requirements.txt + # bcrypt + # cryptography + # pynacl +chardet==4.0.0 + # via + # -c requirements.txt + # requests click==7.1.2 # via # -c requirements.txt # black + # flytekit coverage[toml]==5.5 # via -r dev-requirements.in +croniter==1.0.12 + # via + # -c requirements.txt + # flytekit +cryptography==3.4.7 + # via + # -c requirements.txt + # paramiko +dataclasses-json==0.5.2 + # via + # -c requirements.txt + # flytekit +decorator==5.0.7 + # via + # -c requirements.txt + # retry +deprecated==1.2.12 + # via + # -c requirements.txt + # flytekit +dirhash==0.2.1 + # via + # -c requirements.txt + # flytekit +distro==1.5.0 + # via docker-compose +docker-compose==1.29.1 + # via + # pytest-docker + # pytest-flyte +docker-image-py==0.1.10 + # via + # -c 
requirements.txt + # flytekit +docker[ssh]==5.0.0 + # via docker-compose +dockerpty==0.4.1 + # via docker-compose +docopt==0.6.2 + # via docker-compose flake8-black==0.2.1 # via -r dev-requirements.in flake8-isort==4.0.0 @@ -32,12 +100,57 @@ flake8==3.9.1 # -r dev-requirements.in # flake8-black # flake8-isort +flyteidl==0.18.38 + # via + # -c requirements.txt + # flytekit +grpcio==1.37.0 + # via + # -c requirements.txt + # flytekit +idna==2.10 + # via + # -c requirements.txt + # requests +importlib-metadata==4.0.1 + # via + # -c requirements.txt + # flake8 + # jsonschema + # keyring + # pluggy + # pytest iniconfig==1.1.1 # via pytest isort==5.8.0 # via # -r dev-requirements.in # flake8-isort +jinja2==2.11.3 + # via + # -c requirements.txt + # pytest-flyte +jsonschema==3.2.0 + # via + # -c requirements.txt + # docker-compose +keyring==23.0.1 + # via + # -c requirements.txt + # flytekit +markupsafe==1.1.1 + # via + # -c requirements.txt + # jinja2 +marshmallow-enum==1.5.1 + # via + # -c requirements.txt + # dataclasses-json +marshmallow==3.11.1 + # via + # -c requirements.txt + # dataclasses-json + # marshmallow-enum mccabe==0.6.1 # via flake8 mock==4.0.3 @@ -47,38 +160,155 @@ mypy-extensions==0.4.3 # -c requirements.txt # black # mypy + # typing-inspect mypy==0.812 # via -r dev-requirements.in +natsort==7.1.1 + # via + # -c requirements.txt + # flytekit +numpy==1.20.2 + # via + # -c requirements.txt + # pandas + # pyarrow packaging==20.9 # via # -c requirements.txt # pytest +pandas==1.2.4 + # via + # -c requirements.txt + # flytekit +paramiko==2.7.2 + # via + # -c requirements.txt + # docker pathspec==0.8.1 # via # -c requirements.txt # black + # scantree pluggy==0.13.1 # via pytest +protobuf==3.15.8 + # via + # -c requirements.txt + # flyteidl + # flytekit py==1.10.0 # via # -c requirements.txt # pytest + # retry +pyarrow==3.0.0 + # via + # -c requirements.txt + # flytekit pycodestyle==2.7.0 # via flake8 +pycparser==2.20 + # via + # -c requirements.txt + # cffi pyflakes==2.3.1 # via flake8 +pynacl==1.4.0 + # via + # -c requirements.txt + # paramiko pyparsing==2.4.7 # via # -c requirements.txt # packaging -pytest==6.2.3 +pyrsistent==0.17.3 + # via + # -c requirements.txt + # jsonschema +pytest-docker==0.10.1 + # via pytest-flyte +git+git://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte # via -r dev-requirements.in +pytest==6.2.3 + # via + # -r dev-requirements.in + # pytest-docker + # pytest-flyte +python-dateutil==2.8.1 + # via + # -c requirements.txt + # croniter + # flytekit + # pandas +python-dotenv==0.17.0 + # via docker-compose +pytimeparse==1.1.8 + # via + # -c requirements.txt + # flytekit +pytz==2018.4 + # via + # -c requirements.txt + # flytekit + # pandas +pyyaml==5.4.1 + # via + # -c requirements.txt + # docker-compose regex==2021.4.4 # via # -c requirements.txt # black + # docker-image-py +requests==2.25.1 + # via + # -c requirements.txt + # docker + # docker-compose + # flytekit + # responses +responses==0.13.2 + # via + # -c requirements.txt + # flytekit +retry==0.9.2 + # via + # -c requirements.txt + # flytekit +scantree==0.0.1 + # via + # -c requirements.txt + # dirhash +six==1.15.0 + # via + # -c requirements.txt + # bcrypt + # dockerpty + # flytekit + # grpcio + # jsonschema + # protobuf + # pynacl + # python-dateutil + # responses + # scantree + # websocket-client +sortedcontainers==2.3.0 + # via + # -c requirements.txt + # flytekit +statsd==3.3.0 + # via + # -c requirements.txt + # flytekit +stringcase==1.2.0 + # via + # -c requirements.txt + # dataclasses-json 
testfixtures==6.17.1 # via flake8-isort +texttable==1.6.3 + # via docker-compose toml==0.10.2 # via # -c requirements.txt @@ -94,4 +324,36 @@ typing-extensions==3.7.4.3 # via # -c requirements.txt # black + # importlib-metadata # mypy + # typing-inspect +typing-inspect==0.6.0 + # via + # -c requirements.txt + # dataclasses-json +urllib3==1.25.11 + # via + # -c requirements.txt + # flytekit + # requests + # responses +websocket-client==0.58.0 + # via + # docker + # docker-compose +wheel==0.36.2 + # via + # -c requirements.txt + # flytekit +wrapt==1.12.1 + # via + # -c requirements.txt + # deprecated + # flytekit +zipp==3.4.1 + # via + # -c requirements.txt + # importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/doc-requirements.txt b/doc-requirements.txt index 07231a1bd6..105ba18676 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -88,7 +88,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.37 +flyteidl==0.18.38 # via flytekit furo==2021.4.11b34 # via -r doc-requirements.in @@ -107,7 +107,9 @@ idna==2.10 imagesize==1.2.0 # via sphinx importlib-metadata==4.0.1 - # via keyring + # via + # jsonschema + # keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -380,10 +382,13 @@ traitlets==5.0.5 # nbconvert # nbformat typed-ast==1.4.3 - # via black + # via + # astroid + # black typing-extensions==3.7.4.3 # via # black + # importlib-metadata # typing-inspect typing-inspect==0.6.0 # via dataclasses-json diff --git a/docs/source/design/control_plane.rst b/docs/source/design/control_plane.rst index 7b8f49539b..1d24b50b8d 100644 --- a/docs/source/design/control_plane.rst +++ b/docs/source/design/control_plane.rst @@ -3,9 +3,9 @@ ############################ Control Plane Objects ############################ -For those who require programmatic access to the control place, certain APIs are available through "control plane classes". +For those who require programmatic access to the control plane, certain APIs are available through "control plane classes". -.. note:: +.. warning:: The syntax of this section, while it will continue to work, is subject to change. diff --git a/flytekit/control_plane/component_nodes.py b/flytekit/control_plane/component_nodes.py new file mode 100644 index 0000000000..10434ab830 --- /dev/null +++ b/flytekit/control_plane/component_nodes.py @@ -0,0 +1,136 @@ +import logging as _logging +from typing import Dict + +from flytekit.common.exceptions import system as _system_exceptions +from flytekit.control_plane import identifier as _identifier +from flytekit.models import task as _task_model +from flytekit.models.core import workflow as _workflow_model + + +class FlyteTaskNode(_workflow_model.TaskNode): + def __init__(self, flyte_task: "flytekit.control_plane.tasks.task.FlyteTask"): + self._flyte_task = flyte_task + super(FlyteTaskNode, self).__init__(None) + + @property + def reference_id(self) -> _identifier.Identifier: + """A globally unique identifier for the task.""" + return self._flyte_task.id + + @property + def flyte_task(self) -> "flytekit.control_plane.tasks.task.FlyteTask": + return self._flyte_task + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_model.TaskNode, + tasks: Dict[_identifier.Identifier, _task_model.TaskTemplate], + ) -> "FlyteTaskNode": + """ + Takes the idl wrapper for a TaskNode and returns the hydrated Flytekit object for it by fetching it with the + FlyteTask control plane. 
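+
+        For example (illustrative only; an empty ``tasks`` map simply forces the fetch from Admin):
+
+        .. code-block:: python
+
+            task_node = FlyteTaskNode.promote_from_model(base_model, tasks={})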
+
+        :param base_model:
+        :param tasks:
+        """
+        from flytekit.control_plane.tasks import task as _task
+
+        if base_model.reference_id in tasks:
+            task = tasks[base_model.reference_id]
+            _logging.info(f"Found existing task template for {task.id}, will not retrieve from Admin")
+            flyte_task = _task.FlyteTask.promote_from_model(task)
+            return cls(flyte_task)
+
+        # if not found, fetch it from Admin
+        _logging.debug(f"Fetching task template for {base_model.reference_id} from Admin")
+        return cls(
+            _task.FlyteTask.fetch(
+                base_model.reference_id.project,
+                base_model.reference_id.domain,
+                base_model.reference_id.name,
+                base_model.reference_id.version,
+            )
+        )
+
+
+class FlyteWorkflowNode(_workflow_model.WorkflowNode):
+    def __init__(
+        self,
+        flyte_workflow: "flytekit.control_plane.workflow.FlyteWorkflow" = None,
+        flyte_launch_plan: "flytekit.control_plane.launch_plan.FlyteLaunchPlan" = None,
+    ):
+        if flyte_workflow and flyte_launch_plan:
+            raise _system_exceptions.FlyteSystemException(
+                "FlyteWorkflowNode cannot be called with both a workflow and a launchplan specified, please pick "
+                f"one. workflow: {flyte_workflow} launchPlan: {flyte_launch_plan}",
+            )
+
+        self._flyte_workflow = flyte_workflow
+        self._flyte_launch_plan = flyte_launch_plan
+        super(FlyteWorkflowNode, self).__init__(
+            launchplan_ref=self._flyte_launch_plan.id if self._flyte_launch_plan else None,
+            sub_workflow_ref=self._flyte_workflow.id if self._flyte_workflow else None,
+        )
+
+    def __repr__(self) -> str:
+        if self.flyte_workflow is not None:
+            return f"FlyteWorkflowNode with workflow: {self.flyte_workflow}"
+        return f"FlyteWorkflowNode with launch plan: {self.flyte_launch_plan}"
+
+    @property
+    def launchplan_ref(self) -> _identifier.Identifier:
+        """A globally unique identifier for the launch plan, which should map to Admin."""
+        return self._flyte_launch_plan.id if self._flyte_launch_plan else None
+
+    @property
+    def sub_workflow_ref(self):
+        return self._flyte_workflow.id if self._flyte_workflow else None
+
+    @property
+    def flyte_launch_plan(self) -> "flytekit.control_plane.launch_plan.FlyteLaunchPlan":
+        return self._flyte_launch_plan
+
+    @property
+    def flyte_workflow(self) -> "flytekit.control_plane.workflow.FlyteWorkflow":
+        return self._flyte_workflow
+
+    @classmethod
+    def promote_from_model(
+        cls,
+        base_model: _workflow_model.WorkflowNode,
+        sub_workflows: Dict[_identifier.Identifier, _workflow_model.WorkflowTemplate],
+        tasks: Dict[_identifier.Identifier, _task_model.TaskTemplate],
+    ) -> "FlyteWorkflowNode":
+        from flytekit.control_plane import launch_plan as _launch_plan
+        from flytekit.control_plane import workflow as _workflow
+
+        fetch_args = (
+            base_model.reference.project,
+            base_model.reference.domain,
+            base_model.reference.name,
+            base_model.reference.version,
+        )
+
+        if base_model.launch_plan_ref is not None:
+            return cls(flyte_launch_plan=_launch_plan.FlyteLaunchPlan.fetch(*fetch_args))
+        elif base_model.sub_workflow_ref is not None:
+            # the workflow templates for sub-workflows should have been included in the original response
+            if base_model.reference in sub_workflows:
+                return cls(
+                    flyte_workflow=_workflow.FlyteWorkflow.promote_from_model(
+                        sub_workflows[base_model.reference],
+                        sub_workflows=sub_workflows,
+                        tasks=tasks,
+                    )
+                )
+
+            # If not found for some reason, fetch it from Admin again. The reason there is a warning here but not for
+            # tasks is that sub-workflows should always be passed along. Ideally sub-workflows are never even
+            # registered with Admin, so fetching them from Admin ideally returns nothing.
+            _logging.warning(f"Your subworkflow with id {base_model.reference} is not included in the promote call.")
+            return cls(flyte_workflow=_workflow.FlyteWorkflow.fetch(*fetch_args))
+
+        raise _system_exceptions.FlyteSystemException(
+            "Bad workflow node model, neither subworkflow nor launchplan specified."
+        )
diff --git a/flytekit/control_plane/identifier.py b/flytekit/control_plane/identifier.py
new file mode 100644
index 0000000000..611c9af639
--- /dev/null
+++ b/flytekit/control_plane/identifier.py
@@ -0,0 +1,137 @@
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.models.core import identifier as _core_identifier
+
+
+class Identifier(_core_identifier.Identifier):
+
+    _STRING_TO_TYPE_MAP = {
+        "lp": _core_identifier.ResourceType.LAUNCH_PLAN,
+        "wf": _core_identifier.ResourceType.WORKFLOW,
+        "tsk": _core_identifier.ResourceType.TASK,
+    }
+    _TYPE_TO_STRING_MAP = {v: k for k, v in _STRING_TO_TYPE_MAP.items()}
+
+    @classmethod
+    def promote_from_model(cls, base_model: _core_identifier.Identifier) -> "Identifier":
+        return cls(base_model.resource_type, base_model.project, base_model.domain, base_model.name, base_model.version)
+
+    @classmethod
+    def from_urn(cls, urn: str) -> "Identifier":
+        """
+        Parses a string urn in the correct format into an identifier.
+        """
+        segments = urn.split(":")
+        if len(segments) != 5:
+            raise _user_exceptions.FlyteValueException(
+                urn,
+                "The provided string was not in a parseable format. The string for an identifier must be in the "
+                "format entity_type:project:domain:name:version.",
+            )
+
+        resource_type, project, domain, name, version = segments
+
+        if resource_type not in cls._STRING_TO_TYPE_MAP:
+            raise _user_exceptions.FlyteValueException(
+                resource_type,
+                "The provided string could not be parsed. The first element of an identifier must be one of: "
+                f"{list(cls._STRING_TO_TYPE_MAP.keys())}.",
+            )
+
+        return cls(cls._STRING_TO_TYPE_MAP[resource_type], project, domain, name, version)
+
+    def __str__(self):
+        return (
+            f"{type(self)._TYPE_TO_STRING_MAP.get(self.resource_type, '')}:"
+            f"{self.project}:"
+            f"{self.domain}:"
+            f"{self.name}:"
+            f"{self.version}"
+        )
+
+
+class WorkflowExecutionIdentifier(_core_identifier.WorkflowExecutionIdentifier):
+    @classmethod
+    def promote_from_model(
+        cls, base_model: _core_identifier.WorkflowExecutionIdentifier
+    ) -> "WorkflowExecutionIdentifier":
+        return cls(base_model.project, base_model.domain, base_model.name)
+
+    @classmethod
+    def from_urn(cls, string: str) -> "WorkflowExecutionIdentifier":
+        """
+        Parses a string in the correct format into an identifier.
+        """
+        segments = string.split(":")
+        if len(segments) != 4:
+            raise _user_exceptions.FlyteValueException(
+                string,
+                "The provided string was not in a parseable format. The string for an identifier must be in the format"
+                " ex:project:domain:name.",
+            )
+
+        resource_type, project, domain, name = segments
+
+        if resource_type != "ex":
+            raise _user_exceptions.FlyteValueException(
+                resource_type,
+                "The provided string could not be parsed. The first element of an execution identifier must be 'ex'.",
+            )
+
+        return cls(project, domain, name)
+
+    def __str__(self):
+        return f"ex:{self.project}:{self.domain}:{self.name}"
+
+
+class TaskExecutionIdentifier(_core_identifier.TaskExecutionIdentifier):
+    @classmethod
+    def promote_from_model(cls, base_model: _core_identifier.TaskExecutionIdentifier) -> "TaskExecutionIdentifier":
+        return cls(
+            task_id=base_model.task_id,
+            node_execution_id=base_model.node_execution_id,
+            retry_attempt=base_model.retry_attempt,
+        )
+
+    @classmethod
+    def from_urn(cls, string: str) -> "TaskExecutionIdentifier":
+        """
+        Parses a string in the correct format into an identifier.
+        """
+        segments = string.split(":")
+        if len(segments) != 10:
+            raise _user_exceptions.FlyteValueException(
+                string,
+                "The provided string was not in a parseable format. The string for an identifier must be in the format"
+                " te:exec_project:exec_domain:exec_name:node_id:task_project:task_domain:task_name:task_version:retry.",
+            )
+
+        resource_type, ep, ed, en, node_id, tp, td, tn, tv, retry = segments
+
+        if resource_type != "te":
+            raise _user_exceptions.FlyteValueException(
+                resource_type,
+                "The provided string could not be parsed. The first element of a task execution identifier must be 'te'.",
+            )
+
+        return cls(
+            task_id=Identifier(_core_identifier.ResourceType.TASK, tp, td, tn, tv),
+            node_execution_id=_core_identifier.NodeExecutionIdentifier(
+                node_id=node_id,
+                execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en),
+            ),
+            retry_attempt=int(retry),
+        )
+
+    def __str__(self):
+        return (
+            "te:"
+            f"{self.node_execution_id.execution_id.project}:"
+            f"{self.node_execution_id.execution_id.domain}:"
+            f"{self.node_execution_id.execution_id.name}:"
+            f"{self.node_execution_id.node_id}:"
+            f"{self.task_id.project}:"
+            f"{self.task_id.domain}:"
+            f"{self.task_id.name}:"
+            f"{self.task_id.version}:"
+            f"{self.retry_attempt}"
+        )
diff --git a/flytekit/control_plane/interface.py b/flytekit/control_plane/interface.py
new file mode 100644
index 0000000000..1a7b2c6c15
--- /dev/null
+++ b/flytekit/control_plane/interface.py
@@ -0,0 +1,24 @@
+from typing import Any, Dict, List, Tuple
+
+from flytekit.control_plane import nodes as _nodes
+from flytekit.models import interface as _interface_models
+from flytekit.models import literals as _literal_models
+
+
+class TypedInterface(_interface_models.TypedInterface):
+    @classmethod
+    def promote_from_model(cls, model):
+        """
+        :param flytekit.models.interface.TypedInterface model:
+        :rtype: TypedInterface
+        """
+        return cls(model.inputs, model.outputs)
+
+    def create_bindings_for_inputs(
+        self, map_of_bindings: Dict[str, Any]
+    ) -> Tuple[List[_literal_models.Binding], List[_nodes.FlyteNode]]:
+        """
+        :param map_of_bindings: this can be scalar primitives, it can be node output references, lists, etc.
+        :raises: flytekit.common.exceptions.user.FlyteAssertion
+        """
+        return [], []
diff --git a/flytekit/control_plane/launch_plan.py b/flytekit/control_plane/launch_plan.py
new file mode 100644
index 0000000000..ca189fd33d
--- /dev/null
+++ b/flytekit/control_plane/launch_plan.py
@@ -0,0 +1,196 @@
+import uuid as _uuid
+from typing import Any, List
+
+from flytekit.common.exceptions import scopes as _exception_scopes
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.control_plane import identifier as _identifier
+from flytekit.control_plane import interface as _interface
+from flytekit.control_plane import nodes as _nodes
+from flytekit.control_plane import workflow_execution as _workflow_execution
+from flytekit.engines.flyte import engine as _flyte_engine
+from flytekit.models import common as _common_models
+from flytekit.models import execution as _execution_models
+from flytekit.models import interface as _interface_models
+from flytekit.models import launch_plan as _launch_plan_models
+from flytekit.models import literals as _literal_models
+from flytekit.models.core import identifier as _identifier_model
+
+
+class FlyteLaunchPlan(_launch_plan_models.LaunchPlanSpec):
+    def __init__(self, *args, **kwargs):
+        super(FlyteLaunchPlan, self).__init__(*args, **kwargs)
+        # Set all the attributes we expect this class to have
+        self._id = None
+
+        # The interface is not set explicitly unless fetched in an engine context
+        self._interface = None
+
+    @classmethod
+    def promote_from_model(cls, model: _launch_plan_models.LaunchPlanSpec) -> "FlyteLaunchPlan":
+        return cls(
+            workflow_id=_identifier.Identifier.promote_from_model(model.workflow_id),
+            default_inputs=_interface_models.ParameterMap(model.default_inputs.parameters),
+            fixed_inputs=model.fixed_inputs,
+            entity_metadata=model.entity_metadata,
+            labels=model.labels,
+            annotations=model.annotations,
+            auth_role=model.auth_role,
+            raw_output_data_config=model.raw_output_data_config,
+        )
+
+    @_exception_scopes.system_entry_point
+    def register(self, project, domain, name, version):
+        # NOTE: does this need to be implemented in the control plane?
+        pass
+
+    @classmethod
+    @_exception_scopes.system_entry_point
+    def fetch(cls, project: str, domain: str, name: str, version: str) -> "FlyteLaunchPlan":
+        """
+        Fetches the launch plan from Admin and returns a hydrated FlyteLaunchPlan.
+
+        :param project:
+        :param domain:
+        :param name:
+        :param version:
+        """
+        from flytekit.control_plane import workflow as _workflow
+
+        launch_plan_id = _identifier.Identifier(
+            _identifier_model.ResourceType.LAUNCH_PLAN, project, domain, name, version
+        )
+
+        lp = _flyte_engine.get_client().get_launch_plan(launch_plan_id)
+        flyte_lp = cls.promote_from_model(lp.spec)
+        flyte_lp._id = lp.id
+
+        # TODO: Add a test for this, and this function as a whole
+        wf_id = flyte_lp.workflow_id
+        lp_wf = _workflow.FlyteWorkflow.fetch(wf_id.project, wf_id.domain, wf_id.name, wf_id.version)
+        flyte_lp._interface = lp_wf.interface
+        return flyte_lp
+
+    @_exception_scopes.system_entry_point
+    def serialize(self):
+        """
+        Serializing a launch plan should produce an object similar to what the registration step produces,
+        in preparation for actual registration to Admin.
+
+        :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlan
+        """
+        # NOTE: does this need to be implemented in the control plane?
+        pass
+
+    @property
+    def id(self) -> _identifier.Identifier:
+        return self._id
+
+    @property
+    def is_scheduled(self) -> bool:
+        if self.entity_metadata.schedule.cron_expression:
+            return True
+        elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value:
+            return True
+        elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule:
+            return True
+        else:
+            return False
+
+    @property
+    def workflow_id(self) -> _identifier.Identifier:
+        return self._workflow_id
+
+    @property
+    def interface(self) -> _interface.TypedInterface:
+        """
+        The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and
+        from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the
+        object and get a node.
+        """
+        return self._interface
+
+    @property
+    def resource_type(self) -> _identifier_model.ResourceType:
+        return _identifier_model.ResourceType.LAUNCH_PLAN
+
+    @property
+    def entity_type_text(self) -> str:
+        return "Launch Plan"
+
+    @_exception_scopes.system_entry_point
+    def validate(self):
+        # TODO: Validate workflow is satisfied
+        pass
+
+    @_exception_scopes.system_entry_point
+    def update(self, state: _launch_plan_models.LaunchPlanState):
+        if not self.id:
+            raise _user_exceptions.FlyteAssertion(
+                "Failed to update launch plan because the launch plan's ID is not set. Please fetch or register "
+                "the launch plan first so that its identifier is set."
+            )
+        return _flyte_engine.get_client().update_launch_plan(self.id, state)
+
+    @_exception_scopes.system_entry_point
+    def launch_with_literals(
+        self,
+        project: str,
+        domain: str,
+        literal_inputs: _literal_models.LiteralMap,
+        name: str = None,
+        notification_overrides: List[_common_models.Notification] = None,
+        label_overrides: _common_models.Labels = None,
+        annotation_overrides: _common_models.Annotations = None,
+    ) -> _workflow_execution.FlyteWorkflowExecution:
+        """
+        Executes the launch plan and returns a FlyteWorkflowExecution handle. This version of execution is meant
+        for when you already have a LiteralMap of inputs.
+
+        :param project:
+        :param domain:
+        :param literal_inputs: Inputs to the execution.
+        :param name: If specified, an execution will be created with this name. Note: the name must
+            be unique within the context of the project and domain.
+        :param notification_overrides: If specified, these are the notifications that will be honored for this
+            execution. An empty list signals to disable all notifications.
+        :param label_overrides:
+        :param annotation_overrides:
+        """
+        # Kubernetes requires names starting with a letter for some resources.
+        name = name or "f" + _uuid.uuid4().hex[:19]
+        disable_all = notification_overrides == []
+        if disable_all:
+            notification_overrides = None
+        else:
+            notification_overrides = _execution_models.NotificationList(notification_overrides or [])
+            disable_all = None
+
+        client = _flyte_engine.get_client()
+        try:
+            exec_id = client.create_execution(
+                project,
+                domain,
+                name,
+                _execution_models.ExecutionSpec(
+                    self.id,
+                    _execution_models.ExecutionMetadata(
+                        _execution_models.ExecutionMetadata.ExecutionMode.MANUAL,
+                        "sdk",  # TODO: get principal
+                        0,  # TODO: Detect nesting
+                    ),
+                    notifications=notification_overrides,
+                    disable_all=disable_all,
+                    labels=label_overrides,
+                    annotations=annotation_overrides,
+                ),
+                literal_inputs,
+            )
+        except _user_exceptions.FlyteEntityAlreadyExistsException:
+            exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name)
+        return _workflow_execution.FlyteWorkflowExecution.promote_from_model(client.get_execution(exec_id))
+
+    @_exception_scopes.system_entry_point
+    def __call__(self, *args, **input_map: Any) -> _nodes.FlyteNode:
+        raise NotImplementedError
+
+    def __repr__(self) -> str:
+        return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface} WF ID: {self.workflow_id})"
diff --git a/flytekit/control_plane/nodes.py b/flytekit/control_plane/nodes.py
new file mode 100644
index 0000000000..287695bae0
--- /dev/null
+++ b/flytekit/control_plane/nodes.py
@@ -0,0 +1,281 @@
+import logging as _logging
+import os as _os
+from typing import Any, Dict, List, Optional
+
+from flyteidl.core import literals_pb2 as _literals_pb2
+
+from flytekit.clients.helpers import iterate_task_executions as _iterate_task_executions
+from flytekit.common import constants as _constants
+from flytekit.common import utils as _common_utils
+from flytekit.common.exceptions import system as _system_exceptions
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.common.mixins import artifact as _artifact_mixin
+from flytekit.common.mixins import hash as _hash_mixin
+from flytekit.common.utils import _dnsify
+from flytekit.control_plane import component_nodes as _component_nodes
+from flytekit.control_plane import identifier as _identifier
+from flytekit.control_plane.tasks import executions as _task_executions
+from flytekit.core.promise import NodeOutput
+from flytekit.engines.flyte import engine as _flyte_engine
+from flytekit.interfaces.data import data_proxy as _data_proxy
+from flytekit.models import literals as _literal_models
+from flytekit.models import node_execution as _node_execution_models
+from flytekit.models import task as _task_model
+from flytekit.models.core import execution as _execution_models
+from flytekit.models.core import workflow as _workflow_model
+
+
+class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node):
+    def __init__(
+        self,
+        id,
+        upstream_nodes,
+        bindings,
+        metadata,
+        flyte_task: "flytekit.control_plane.tasks.task.FlyteTask" = None,
+        flyte_workflow: "flytekit.control_plane.workflow.FlyteWorkflow" = None,
+        flyte_launch_plan=None,
+        flyte_branch=None,
+        parameter_mapping=True,
+    ):
+        non_none_entities = list(filter(None, [flyte_task, flyte_workflow, flyte_launch_plan, flyte_branch]))
+        if len(non_none_entities) != 1:
+            raise _user_exceptions.FlyteAssertion(
+                "A Flyte node must have exactly one underlying entity specified. 
Received the following " + "entities: {}".format(non_none_entities) + ) + + workflow_node = None + if flyte_workflow is not None: + workflow_node = _component_nodes.FlyteWorkflowNode(flyte_workflow=flyte_workflow) + elif flyte_launch_plan is not None: + workflow_node = _component_nodes.FlyteWorkflowNode(flyte_launch_plan=flyte_launch_plan) + + super(FlyteNode, self).__init__( + id=_dnsify(id) if id else None, + metadata=metadata, + inputs=bindings, + upstream_node_ids=[n.id for n in upstream_nodes], + output_aliases=[], + task_node=_component_nodes.FlyteTaskNode(flyte_task) if flyte_task else None, + workflow_node=workflow_node, + branch_node=flyte_branch, + ) + self._upstream = upstream_nodes + + @classmethod + def promote_from_model( + cls, + model: _workflow_model.Node, + sub_workflows: Optional[Dict[_identifier.Identifier, _workflow_model.WorkflowTemplate]], + tasks: Optional[Dict[_identifier.Identifier, _task_model.TaskTemplate]], + ) -> "FlyteNode": + id = model.id + if id in {_constants.START_NODE_ID, _constants.END_NODE_ID}: + _logging.warning(f"Should not call promote from model on a start node or end node {model}") + return None + + flyte_task_node, flyte_workflow_node = None, None + if model.task_node is not None: + flyte_task_node = _component_nodes.FlyteTaskNode.promote_from_model(model.task_node, tasks) + elif model.workflow_node is not None: + flyte_workflow_node = _component_nodes.FlyteWorkflowNode.promote_from_model( + model.workflow_node, sub_workflows, tasks + ) + else: + raise _system_exceptions.FlyteSystemException("Bad Node model, neither task nor workflow detected.") + + # When WorkflowTemplate models (containing node models) are returned by Admin, they've been compiled with a + # start node. In order to make the promoted FlyteWorkflow look the same, we strip the start-node text back out. 
+        for model_input in model.inputs:
+            if (
+                model_input.binding.promise is not None
+                and model_input.binding.promise.node_id == _constants.START_NODE_ID
+            ):
+                model_input.binding.promise._node_id = _constants.GLOBAL_INPUT_NODE_ID
+
+        if flyte_task_node is not None:
+            return cls(
+                id=id,
+                upstream_nodes=[],  # set downstream, model doesn't contain this information
+                bindings=model.inputs,
+                metadata=model.metadata,
+                flyte_task=flyte_task_node.flyte_task,
+            )
+        elif flyte_workflow_node is not None:
+            if flyte_workflow_node.flyte_workflow is not None:
+                return cls(
+                    id=id,
+                    upstream_nodes=[],  # set downstream, model doesn't contain this information
+                    bindings=model.inputs,
+                    metadata=model.metadata,
+                    flyte_workflow=flyte_workflow_node.flyte_workflow,
+                )
+            elif flyte_workflow_node.flyte_launch_plan is not None:
+                return cls(
+                    id=id,
+                    upstream_nodes=[],  # set downstream, model doesn't contain this information
+                    bindings=model.inputs,
+                    metadata=model.metadata,
+                    flyte_launch_plan=flyte_workflow_node.flyte_launch_plan,
+                )
+            raise _system_exceptions.FlyteSystemException(
+                "Bad FlyteWorkflowNode model, both launch plan and workflow are None"
+            )
+        raise _system_exceptions.FlyteSystemException("Bad FlyteNode model, both task and workflow nodes are empty")
+
+    @property
+    def upstream_nodes(self) -> List["FlyteNode"]:
+        return self._upstream
+
+    @property
+    def upstream_node_ids(self) -> List[str]:
+        return list(sorted(n.id for n in self.upstream_nodes))
+
+    @property
+    def outputs(self) -> Dict[str, NodeOutput]:
+        return self._outputs
+
+    def assign_id_and_return(self, id: str):
+        if self.id:
+            raise _user_exceptions.FlyteAssertion(
+                f"Error assigning ID: {id} because {self} is already assigned. Has this node been assigned to another "
+                "workflow already?"
+            )
+        self._id = _dnsify(id) if id else None
+        self._metadata.name = id
+        return self
+
+    def with_overrides(self, *args, **kwargs):
+        # TODO: Implement overrides
+        raise NotImplementedError("Overrides are not supported in Flyte yet.")
+
+    def __repr__(self) -> str:
+        return f"Node(ID: {self.id} Executable: {self._executable_flyte_object})"
+
+
+class FlyteNodeExecution(_node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact):
+    def __init__(self, *args, **kwargs):
+        super(FlyteNodeExecution, self).__init__(*args, **kwargs)
+        self._task_executions = None
+        self._workflow_executions = None
+        self._inputs = None
+        self._outputs = None
+
+    @property
+    def task_executions(self) -> List["flytekit.control_plane.tasks.executions.FlyteTaskExecution"]:
+        return self._task_executions or []
+
+    @property
+    def workflow_executions(self) -> List["flytekit.control_plane.workflow_executions.FlyteWorkflowExecution"]:
+        return self._workflow_executions or []
+
+    @property
+    def executions(self) -> _artifact_mixin.ExecutionArtifact:
+        return self.task_executions or self.workflow_executions or []
+
+    @property
+    def inputs(self) -> Dict[str, Any]:
+        """
+        Returns the inputs to the execution in the standard python format as dictated by the type engine.
+        """
+        if self._inputs is None:
+            client = _flyte_engine.get_client()
+            execution_data = client.get_node_execution_data(self.id)
+
+            # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
+ input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({}) + if bool(execution_data.full_inputs.literals): + input_map = execution_data.full_inputs + elif execution_data.inputs.bytes > 0: + with _common_utils.AutoDeletingTempDir() as tmp_dir: + tmp_name = _os.path.join(tmp_dir.name, "inputs.pb") + _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) + input_map = _literal_models.LiteralMap.from_flyte_idl( + _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) + ) + + # TODO: need to convert flyte literals to python types. For now just use literals + # self._inputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=input_map) + self._inputs = input_map + return self._inputs + + @property + def outputs(self) -> Dict[str, Any]: + """ + Returns the outputs to the execution in the standard python format as dictated by the type engine. + + :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error. + """ + if not self.is_complete: + raise _user_exceptions.FlyteAssertion( + "Please wait until the node execution has completed before requesting the outputs." + ) + if self.error: + raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") + + if self._outputs is None: + client = _flyte_engine.get_client() + execution_data = client.get_node_execution_data(self.id) + + # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + output_map: _literal_models.LiteralMap = _literal_models.LiteralMap({}) + if bool(execution_data.full_outputs.literals): + output_map = execution_data.full_outputs + elif execution_data.outputs.bytes > 0: + with _common_utils.AutoDeletingTempDir() as tmp_dir: + tmp_name = _os.path.join(tmp_dir.name, "outputs.pb") + _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) + output_map = _literal_models.LiteralMap.from_flyte_idl( + _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) + ) + # TODO: need to convert flyte literals to python types. For now just use literals + # self._outputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=output_map) + self._outputs = output_map + return self._outputs + + @property + def error(self) -> _execution_models.ExecutionError: + """ + If execution is in progress, raise an exception. Otherwise, return None if no error was present upon + reaching completion. + """ + if not self.is_complete: + raise _user_exceptions.FlyteAssertion( + "Please wait until the node execution has completed before requesting error information." + ) + return self.closure.error + + @property + def is_complete(self) -> bool: + """Whether or not the execution is complete.""" + return self.closure.phase in { + _execution_models.NodeExecutionPhase.ABORTED, + _execution_models.NodeExecutionPhase.FAILED, + _execution_models.NodeExecutionPhase.SKIPPED, + _execution_models.NodeExecutionPhase.SUCCEEDED, + _execution_models.NodeExecutionPhase.TIMED_OUT, + } + + @classmethod + def promote_from_model(cls, base_model: _node_execution_models.NodeExecution) -> "FlyteNodeExecution": + return cls(closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri) + + def sync(self): + """ + Syncs the state of the underlying execution artifact with the state observed by the platform. 
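+
+        A polling sketch (illustrative only; assumes ``node_execution`` is a hydrated instance of this class)::
+
+            import time
+
+            while not node_execution.is_complete:
+                node_execution.sync()
+                time.sleep(5)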
+        """
+        if not self.is_complete or self._task_executions is None:
+            client = _flyte_engine.get_client()
+            self._closure = client.get_node_execution(self.id).closure
+            self._task_executions = [
+                _task_executions.FlyteTaskExecution.promote_from_model(t)
+                for t in _iterate_task_executions(client, self.id)
+            ]
+            # TODO: sync sub-workflows as well
+
+    def _sync_closure(self):
+        """
+        Syncs the closure of the underlying execution artifact with the state observed by the platform.
+        """
+        self._closure = _flyte_engine.get_client().get_node_execution(self.id).closure
diff --git a/flytekit/control_plane/tasks/__init__.py b/flytekit/control_plane/tasks/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/flytekit/control_plane/tasks/executions.py b/flytekit/control_plane/tasks/executions.py
new file mode 100644
index 0000000000..838746f392
--- /dev/null
+++ b/flytekit/control_plane/tasks/executions.py
@@ -0,0 +1,132 @@
+import os as _os
+from typing import Any, Dict, Optional
+
+from flyteidl.core import literals_pb2 as _literals_pb2
+
+from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions
+from flytekit.common import utils as _common_utils
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.common.mixins import artifact as _artifact_mixin
+from flytekit.core.context_manager import FlyteContext
+from flytekit.core.type_engine import TypeEngine
+from flytekit.engines.flyte import engine as _flyte_engine
+from flytekit.interfaces.data import data_proxy as _data_proxy
+from flytekit.models import literals as _literal_models
+from flytekit.models.admin import task_execution as _task_execution_model
+from flytekit.models.core import execution as _execution_models
+
+
+class FlyteTaskExecution(_task_execution_model.TaskExecution, _artifact_mixin.ExecutionArtifact):
+    def __init__(self, *args, **kwargs):
+        super(FlyteTaskExecution, self).__init__(*args, **kwargs)
+        self._inputs = None
+        self._outputs = None
+
+    @property
+    def is_complete(self) -> bool:
+        """Whether or not the execution is complete."""
+        return self.closure.phase in {
+            _execution_models.TaskExecutionPhase.ABORTED,
+            _execution_models.TaskExecutionPhase.FAILED,
+            _execution_models.TaskExecutionPhase.SUCCEEDED,
+        }
+
+    @property
+    def inputs(self) -> Dict[str, Any]:
+        """
+        Returns the inputs of the task execution in the standard Python format that is produced by
+        the type engine.
+        """
+        if self._inputs is None:
+            client = _flyte_engine.get_client()
+            execution_data = client.get_task_execution_data(self.id)
+
+            # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
+            input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({})
+            if bool(execution_data.full_inputs.literals):
+                input_map = execution_data.full_inputs
+            elif execution_data.inputs.bytes > 0:
+                with _common_utils.AutoDeletingTempDir() as tmp_dir:
+                    tmp_name = _os.path.join(tmp_dir.name, "inputs.pb")
+                    _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name)
+                    input_map = _literal_models.LiteralMap.from_flyte_idl(
+                        _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
+                    )
+
+            self._inputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=input_map)
+        return self._inputs
+
+    @property
+    def outputs(self) -> Dict[str, Any]:
+        """
+        Returns the outputs of the task execution, if available, in the standard Python format that is produced by
+        the type engine.
+
+        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
+        """
+        if not self.is_complete:
+            raise _user_exceptions.FlyteAssertion(
+                "Please wait until the task execution has completed before requesting the outputs."
+            )
+        if self.error:
+            raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
+
+        if self._outputs is None:
+            client = _flyte_engine.get_client()
+            execution_data = client.get_task_execution_data(self.id)
+
+            # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
+            output_map: _literal_models.LiteralMap = _literal_models.LiteralMap({})
+            if bool(execution_data.full_outputs.literals):
+                output_map = execution_data.full_outputs
+            elif execution_data.outputs.bytes > 0:
+                with _common_utils.AutoDeletingTempDir() as t:
+                    tmp_name = _os.path.join(t.name, "outputs.pb")
+                    _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name)
+                    output_map = _literal_models.LiteralMap.from_flyte_idl(
+                        _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
+                    )
+
+            self._outputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=output_map)
+        return self._outputs
+
+    @property
+    def error(self) -> Optional[_execution_models.ExecutionError]:
+        """
+        If execution is in progress, raise an exception. Otherwise, return None if no error was present upon
+        reaching completion.
+        """
+        if not self.is_complete:
+            raise _user_exceptions.FlyteAssertion(
+                "Please wait until the task execution has completed before requesting error information."
+            )
+        return self.closure.error
+
+    def get_child_executions(self, filters=None):
+        from flytekit.control_plane import nodes as _nodes
+
+        if not self.is_parent:
+            raise _user_exceptions.FlyteAssertion("Only task executions marked with 'is_parent' have child executions.")
+        client = _flyte_engine.get_client()
+        models = {
+            v.id.node_id: v
+            for v in _iterate_node_executions(client, task_execution_identifier=self.id, filters=filters)
+        }
+
+        return {k: _nodes.FlyteNodeExecution.promote_from_model(v) for k, v in models.items()}
+
+    @classmethod
+    def promote_from_model(cls, base_model: _task_execution_model.TaskExecution) -> "FlyteTaskExecution":
+        return cls(
+            closure=base_model.closure,
+            id=base_model.id,
+            input_uri=base_model.input_uri,
+            is_parent=base_model.is_parent,
+        )
+
+    def sync(self):
+        """
+        Syncs the state of the underlying execution artifact with the state observed by the platform.
+        """
+        self._sync_closure()
+
+    def _sync_closure(self):
+        """
+        Syncs the closure of the underlying execution artifact with the state observed by the platform.
+        """
+        self._closure = _flyte_engine.get_client().get_task_execution(self.id).closure
diff --git a/flytekit/control_plane/tasks/task.py b/flytekit/control_plane/tasks/task.py
new file mode 100644
index 0000000000..71f159e523
--- /dev/null
+++ b/flytekit/control_plane/tasks/task.py
@@ -0,0 +1,95 @@
+from flytekit.common.exceptions import scopes as _exception_scopes
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.common.mixins import hash as _hash_mixin
+from flytekit.control_plane import identifier as _identifier
+from flytekit.control_plane import interface as _interfaces
+from flytekit.engines.flyte import engine as _flyte_engine
+from flytekit.models import common as _common_model
+from flytekit.models import task as _task_model
+from flytekit.models.admin import common as _admin_common
+from flytekit.models.core import identifier as _identifier_model
+
+
+class FlyteTask(_hash_mixin.HashOnReferenceMixin, _task_model.TaskTemplate):
+    def __init__(self, id, type, metadata, interface, custom, container=None, task_type_version=0, config=None):
+        super(FlyteTask, self).__init__(
+            id,
+            type,
+            metadata,
+            interface,
+            custom,
+            container=container,
+            task_type_version=task_type_version,
+            config=config,
+        )
+
+    @property
+    def interface(self) -> _interfaces.TypedInterface:
+        return super(FlyteTask, self).interface
+
+    @property
+    def resource_type(self) -> _identifier_model.ResourceType:
+        return _identifier_model.ResourceType.TASK
+
+    @property
+    def entity_type_text(self) -> str:
+        return "Task"
+
+    @classmethod
+    def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> "FlyteTask":
+        t = cls(
+            id=base_model.id,
+            type=base_model.type,
+            metadata=base_model.metadata,
+            interface=_interfaces.TypedInterface.promote_from_model(base_model.interface),
+            custom=base_model.custom,
+            container=base_model.container,
+            task_type_version=base_model.task_type_version,
+        )
+        # Override the newly generated name if one exists in the base model
+        if not base_model.id.is_empty:
+            t._id = _identifier.Identifier.promote_from_model(base_model.id)
+
+        return t
+
+    @classmethod
+    @_exception_scopes.system_entry_point
+    def fetch(cls, project: str, domain: str, name: str, version: str) -> "FlyteTask":
+        """
+        Fetches the task from Admin and returns a hydrated FlyteTask.
+
+        :param project:
+        :param domain:
+        :param name:
+        :param version:
+        """
+        task_id = _identifier.Identifier(_identifier_model.ResourceType.TASK, project, domain, name, version)
+        admin_task = _flyte_engine.get_client().get_task(task_id)
+
+        flyte_task = cls.promote_from_model(admin_task.closure.compiled_task.template)
+        flyte_task._id = task_id
+        return flyte_task
+
+    @classmethod
+    @_exception_scopes.system_entry_point
+    def fetch_latest(cls, project: str, domain: str, name: str) -> "FlyteTask":
+        """
+        Fetches the latest version of the named task from Admin and returns a hydrated FlyteTask.
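+
+        A sketch of usage (the project/domain/name values are illustrative)::
+
+            task = FlyteTask.fetch_latest("flytesnacks", "development", "workflows.basic.basic_workflow.t1")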
+
+        :param project:
+        :param domain:
+        :param name:
+        """
+        named_task = _common_model.NamedEntityIdentifier(project, domain, name)
+        client = _flyte_engine.get_client()
+        task_list, _ = client.list_tasks_paginated(
+            named_task,
+            limit=1,
+            sort_by=_admin_common.Sort("created_at", _admin_common.Sort.Direction.DESCENDING),
+        )
+        admin_task = task_list[0] if task_list else None
+
+        if not admin_task:
+            raise _user_exceptions.FlyteEntityNotExistException("Named task {} not found".format(named_task))
+        flyte_task = cls.promote_from_model(admin_task.closure.compiled_task.template)
+        flyte_task._id = admin_task.id
+        return flyte_task
diff --git a/flytekit/control_plane/workflow.py b/flytekit/control_plane/workflow.py
new file mode 100644
index 0000000000..a164b98667
--- /dev/null
+++ b/flytekit/control_plane/workflow.py
@@ -0,0 +1,167 @@
+from typing import Dict, List, Optional
+
+from flytekit.common import constants as _constants
+from flytekit.common.exceptions import scopes as _exception_scopes
+from flytekit.common.exceptions import system as _system_exceptions
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.common.mixins import hash as _hash_mixin
+from flytekit.control_plane import identifier as _identifier
+from flytekit.control_plane import interface as _interfaces
+from flytekit.control_plane import nodes as _nodes
+from flytekit.engines.flyte import engine as _flyte_engine
+from flytekit.models import task as _task_models
+from flytekit.models.core import identifier as _identifier_model
+from flytekit.models.core import workflow as _workflow_models
+
+
+class FlyteWorkflow(_hash_mixin.HashOnReferenceMixin, _workflow_models.WorkflowTemplate):
+    """A Flyte control plane construct."""
+
+    def __init__(
+        self,
+        nodes: List[_nodes.FlyteNode],
+        interface,
+        output_bindings,
+        id,
+        metadata,
+        metadata_defaults,
+    ):
+        for node in nodes:
+            for upstream in node.upstream_nodes:
+                if upstream.id is None:
+                    raise _user_exceptions.FlyteAssertion(
+                        "Some nodes contained in the workflow were not found in the workflow description. Please "
+                        "ensure all nodes are either assigned to attributes within the class or an element in a "
+                        "list, dict, or tuple which is stored as an attribute in the class."
+ ) + super(FlyteWorkflow, self).__init__( + id=id, + metadata=metadata, + metadata_defaults=metadata_defaults, + interface=interface, + nodes=nodes, + outputs=output_bindings, + ) + self._flyte_nodes = nodes + + @property + def upstream_entities(self): + return set(n.executable_flyte_object for n in self._flyte_nodes) + + @property + def interface(self) -> _interfaces.TypedInterface: + return super(FlyteWorkflow, self).interface + + @property + def entity_type_text(self) -> str: + return "Workflow" + + @property + def resource_type(self): + return _identifier_model.ResourceType.WORKFLOW + + def get_sub_workflows(self) -> List["FlyteWorkflow"]: + result = [] + for node in self.nodes: + if node.workflow_node is not None and node.workflow_node.sub_workflow_ref is not None: + if ( + node.executable_flyte_object is not None + and node.executable_flyte_object.entity_type_text == "Workflow" + ): + result.append(node.executable_flyte_object) + result.extend(node.executable_flyte_object.get_sub_workflows()) + else: + raise _system_exceptions.FlyteSystemException( + "workflow node with subworkflow found but bad executable " + "object {}".format(node.executable_flyte_object) + ) + + # get subworkflows in conditional branches + if node.branch_node is not None: + if_else: _workflow_models.IfElseBlock = node.branch_node.if_else + leaf_nodes: List[_nodes.FlyteNode] = filter( + None, + [ + if_else.case.then_node, + *([] if if_else.other is None else [x.then_node for x in if_else.other]), + if_else.else_node, + ], + ) + for leaf_node in leaf_nodes: + exec_flyte_obj = leaf_node.executable_flyte_object + if exec_flyte_obj is not None and exec_flyte_obj.entity_type_text == "Workflow": + result.append(exec_flyte_obj) + result.extend(exec_flyte_obj.get_sub_workflows()) + + return result + + @classmethod + @_exception_scopes.system_entry_point + def fetch(cls, project: str, domain: str, name: str, version: str): + workflow_id = _identifier.Identifier(_identifier_model.ResourceType.WORKFLOW, project, domain, name, version) + admin_workflow = _flyte_engine.get_client().get_workflow(workflow_id) + cwc = admin_workflow.closure.compiled_workflow + flyte_workflow = cls.promote_from_model( + base_model=cwc.primary.template, + sub_workflows={sw.template.id: sw.template for sw in cwc.sub_workflows}, + tasks={t.template.id: t.template for t in cwc.tasks}, + ) + flyte_workflow._id = workflow_id + return flyte_workflow + + @classmethod + def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> List[_workflow_models.Node]: + return [n for n in nodes if n.id not in {_constants.START_NODE_ID, _constants.END_NODE_ID}] + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_models.WorkflowTemplate, + sub_workflows: Optional[Dict[_identifier.Identifier, _workflow_models.WorkflowTemplate]] = None, + tasks: Optional[Dict[_identifier.Identifier, _task_models.TaskTemplate]] = None, + ) -> "FlyteWorkflow": + base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) + sub_workflows = sub_workflows or {} + tasks = tasks or {} + node_map = { + n.id: _nodes.FlyteNode.promote_from_model(n, sub_workflows, tasks) for n in base_model_non_system_nodes + } + + # Set upstream nodes for each node + for n in base_model_non_system_nodes: + current = node_map[n.id] + for upstream_id in current.upstream_node_ids: + upstream_node = node_map[upstream_id] + current << upstream_node + + # No inputs/outputs specified, see the constructor for more information on the overrides. 
+        return cls(
+            nodes=list(node_map.values()),
+            id=_identifier.Identifier.promote_from_model(base_model.id),
+            metadata=base_model.metadata,
+            metadata_defaults=base_model.metadata_defaults,
+            interface=_interfaces.TypedInterface.promote_from_model(base_model.interface),
+            output_bindings=base_model.outputs,
+        )
+
+    @_exception_scopes.system_entry_point
+    def register(self, project, domain, name, version):
+        # TODO
+        pass
+
+    @_exception_scopes.system_entry_point
+    def serialize(self):
+        # TODO
+        pass
+
+    @_exception_scopes.system_entry_point
+    def validate(self):
+        # TODO
+        pass
+
+    @_exception_scopes.system_entry_point
+    def create_launch_plan(self, *args, **kwargs):
+        # TODO
+        pass
+
+    @_exception_scopes.system_entry_point
+    def __call__(self, *args, **input_map):
+        raise NotImplementedError
diff --git a/flytekit/control_plane/workflow_execution.py b/flytekit/control_plane/workflow_execution.py
new file mode 100644
index 0000000000..11eb352351
--- /dev/null
+++ b/flytekit/control_plane/workflow_execution.py
@@ -0,0 +1,150 @@
+import os as _os
+from typing import Any, Dict, List
+
+from flyteidl.core import literals_pb2 as _literals_pb2
+
+from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions
+from flytekit.common import utils as _common_utils
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.common.mixins import artifact as _artifact
+from flytekit.control_plane import identifier as _core_identifier
+from flytekit.control_plane import nodes as _nodes
+from flytekit.engines.flyte import engine as _flyte_engine
+from flytekit.interfaces.data import data_proxy as _data_proxy
+from flytekit.models import execution as _execution_models
+from flytekit.models import filters as _filter_models
+from flytekit.models import literals as _literal_models
+from flytekit.models.core import execution as _core_execution_models
+
+
+class FlyteWorkflowExecution(_execution_models.Execution, _artifact.ExecutionArtifact):
+    def __init__(self, *args, **kwargs):
+        super(FlyteWorkflowExecution, self).__init__(*args, **kwargs)
+        self._node_executions = None
+        self._inputs = None
+        self._outputs = None
+
+    @property
+    def node_executions(self) -> Dict[str, _nodes.FlyteNodeExecution]:
+        return self._node_executions or {}
+
+    @property
+    def inputs(self) -> Dict[str, Any]:
+        """
+        Returns the inputs to the execution in the standard python format as dictated by the type engine.
+        """
+        if self._inputs is None:
+            client = _flyte_engine.get_client()
+            execution_data = client.get_execution_data(self.id)
+
+            # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
+            input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({})
+            if bool(execution_data.full_inputs.literals):
+                input_map = execution_data.full_inputs
+            elif execution_data.inputs.bytes > 0:
+                with _common_utils.AutoDeletingTempDir() as tmp_dir:
+                    tmp_name = _os.path.join(tmp_dir.name, "inputs.pb")
+                    _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name)
+                    input_map = _literal_models.LiteralMap.from_flyte_idl(
+                        _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
+                    )
+            # TODO: need to convert flyte literals to python types. For now just use literals
+            # self._inputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=input_map)
+            self._inputs = input_map
+        return self._inputs
+
+    @property
+    def outputs(self) -> Dict[str, Any]:
+        """
+        Returns the outputs to the execution in the standard python format as dictated by the type engine.
+
+        :raises: ``FlyteAssertion`` error if execution is in progress or execution ended in error.
+        """
+        if not self.is_complete:
+            raise _user_exceptions.FlyteAssertion(
+                "Please wait until the workflow execution has completed before requesting the outputs."
+            )
+        if self.error:
+            raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
+
+        if self._outputs is None:
+            client = _flyte_engine.get_client()
+            execution_data = client.get_execution_data(self.id)
+            # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
+            output_map: _literal_models.LiteralMap = _literal_models.LiteralMap({})
+            if bool(execution_data.full_outputs.literals):
+                output_map = execution_data.full_outputs
+            elif execution_data.outputs.bytes > 0:
+                with _common_utils.AutoDeletingTempDir() as tmp_dir:
+                    tmp_name = _os.path.join(tmp_dir.name, "outputs.pb")
+                    _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name)
+                    output_map = _literal_models.LiteralMap.from_flyte_idl(
+                        _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
+                    )
+            # TODO: need to convert flyte literals to python types. For now just use literals
+            # self._outputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=output_map)
+            self._outputs = output_map
+        return self._outputs
+
+    @property
+    def error(self) -> _core_execution_models.ExecutionError:
+        """
+        If execution is in progress, raise an exception. Otherwise, return None if no error was present upon
+        reaching completion.
+        """
+        if not self.is_complete:
+            raise _user_exceptions.FlyteAssertion(
+                "Please wait until the workflow execution has completed before checking for an error."
+            )
+        return self.closure.error
+
+    @property
+    def is_complete(self) -> bool:
+        """
+        Whether or not the execution is complete.
+        """
+        return self.closure.phase in {
+            _core_execution_models.WorkflowExecutionPhase.ABORTED,
+            _core_execution_models.WorkflowExecutionPhase.FAILED,
+            _core_execution_models.WorkflowExecutionPhase.SUCCEEDED,
+            _core_execution_models.WorkflowExecutionPhase.TIMED_OUT,
+        }
+
+    @classmethod
+    def promote_from_model(cls, base_model: _execution_models.Execution) -> "FlyteWorkflowExecution":
+        return cls(
+            closure=base_model.closure,
+            id=_core_identifier.WorkflowExecutionIdentifier.promote_from_model(base_model.id),
+            spec=base_model.spec,
+        )
+
+    @classmethod
+    def fetch(cls, project: str, domain: str, name: str) -> "FlyteWorkflowExecution":
+        return cls.promote_from_model(
+            _flyte_engine.get_client().get_execution(
+                _core_identifier.WorkflowExecutionIdentifier(project=project, domain=domain, name=name)
+            )
+        )
+
+    def sync(self):
+        """
+        Syncs the state of the underlying execution artifact with the state observed by the platform.
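+
+        A sketch of typical usage (the execution name is illustrative)::
+
+            execution = FlyteWorkflowExecution.fetch("flytesnacks", "development", "my-execution-name")
+            execution.sync()
+            print(execution.node_executions)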
+ """ + if not self.is_complete or self._node_executions is None: + self._sync_closure() + self._node_executions = self.get_node_executions() + + def _sync_closure(self): + if not self.is_complete: + client = _flyte_engine.get_client() + self._closure = client.get_execution(self.id).closure + + def get_node_executions(self, filters: List[_filter_models.Filter] = None) -> Dict[str, _nodes.FlyteNodeExecution]: + client = _flyte_engine.get_client() + return { + node.id.node_id: _nodes.FlyteNodeExecution.promote_from_model(node) + for node in _iterate_node_executions(client, self.id, filters=filters) + } + + def terminate(self, cause: str): + _flyte_engine.get_client().terminate_execution(self.id, cause) diff --git a/requirements-spark2.txt b/requirements-spark2.txt index f4f4e6115c..6aca3d177a 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -71,7 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.37 +flyteidl==0.18.38 # via flytekit gevent==21.1.2 # via sagemaker-training @@ -84,7 +84,9 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 - # via keyring + # via + # jsonschema + # keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -301,6 +303,7 @@ typed-ast==1.4.3 typing-extensions==3.7.4.3 # via # black + # importlib-metadata # typing-inspect typing-inspect==0.6.0 # via dataclasses-json diff --git a/requirements.txt b/requirements.txt index 2595fcfc44..5533903f08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -71,7 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.37 +flyteidl==0.18.38 # via flytekit gevent==21.1.2 # via sagemaker-training @@ -84,7 +84,9 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 - # via keyring + # via + # jsonschema + # keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -301,6 +303,7 @@ typed-ast==1.4.3 typing-extensions==3.7.4.3 # via # black + # importlib-metadata # typing-inspect typing-inspect==0.6.0 # via dataclasses-json diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/.gitignore b/tests/flytekit/integration/control_plane/mock_flyte_repo/.gitignore new file mode 100644 index 0000000000..9bf95ea680 --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/.gitignore @@ -0,0 +1 @@ +*.pb \ No newline at end of file diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/README.md b/tests/flytekit/integration/control_plane/mock_flyte_repo/README.md new file mode 100644 index 0000000000..1972a7c658 --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/README.md @@ -0,0 +1,4 @@ +# Mock Flyte Repo + +This is a trimmed down version of the [flytesnacks](https://github.com/flyteorg/flytesnacks) +repo for the purposes of local integration testing. 
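+
+As a sketch (all names below are illustrative), an integration test can fetch one of the
+workflows registered from this repo through the control plane classes and inspect it:
+
+```python
+from flytekit.control_plane.workflow import FlyteWorkflow
+
+# "flytesnacks"/"development" and the version tag are example values.
+wf = FlyteWorkflow.fetch("flytesnacks", "development", "workflows.basic.basic_workflow.my_wf", "v1")
+print(wf.interface)
+```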
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/__init__.py b/tests/flytekit/integration/control_plane/mock_flyte_repo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/in_container.mk b/tests/flytekit/integration/control_plane/mock_flyte_repo/in_container.mk new file mode 100644 index 0000000000..15bc979759 --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/in_container.mk @@ -0,0 +1,24 @@ +SERIALIZED_PB_OUTPUT_DIR := /tmp/output + +.PHONY: clean +clean: + rm -rf $(SERIALIZED_PB_OUTPUT_DIR)/* + +$(SERIALIZED_PB_OUTPUT_DIR): clean + mkdir -p $(SERIALIZED_PB_OUTPUT_DIR) + +.PHONY: serialize +serialize: $(SERIALIZED_PB_OUTPUT_DIR) + pyflyte --config /root/sandbox.config serialize workflows -f $(SERIALIZED_PB_OUTPUT_DIR) + +.PHONY: register +register: serialize + flyte-cli register-files -h ${FLYTE_HOST} ${INSECURE_FLAG} -p ${PROJECT} -d development -v ${VERSION} --kubernetes-service-account ${SERVICE_ACCOUNT} --output-location-prefix ${OUTPUT_DATA_PREFIX} $(SERIALIZED_PB_OUTPUT_DIR)/* + +.PHONY: fast_serialize +fast_serialize: $(SERIALIZED_PB_OUTPUT_DIR) + pyflyte --config /root/sandbox.config serialize fast workflows -f $(SERIALIZED_PB_OUTPUT_DIR) + +.PHONY: fast_register +fast_register: fast_serialize + flyte-cli fast-register-files -h ${FLYTE_HOST} ${INSECURE_FLAG} -p ${PROJECT} -d development --kubernetes-service-account ${SERVICE_ACCOUNT} --output-location-prefix ${OUTPUT_DATA_PREFIX} --additional-distribution-dir ${ADDL_DISTRIBUTION_DIR} $(SERIALIZED_PB_OUTPUT_DIR)/* diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Dockerfile b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Dockerfile new file mode 100644 index 0000000000..7e5d01829f --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Dockerfile @@ -0,0 +1,35 @@ +FROM python:3.8-slim-buster +LABEL org.opencontainers.image.source https://github.com/flyteorg/flytesnacks + +WORKDIR /root +ENV VENV /opt/venv +ENV LANG C.UTF-8 +ENV LC_ALL C.UTF-8 +ENV PYTHONPATH /root + +# This is necessary for opencv to work +RUN apt-get update && apt-get install -y libsm6 libxext6 libxrender-dev ffmpeg build-essential + +# Install the AWS cli separately to prevent issues with boto being written over +RUN pip3 install awscli + +ENV VENV /opt/venv +# Virtual environment +RUN python3 -m venv ${VENV} +ENV PATH="${VENV}/bin:$PATH" + +# Install Python dependencies +COPY workflows/requirements.txt /root +RUN pip install -r /root/requirements.txt + +# Copy the makefile targets to expose on the container. This makes it easier to register +COPY in_container.mk /root/Makefile +COPY workflows/sandbox.config /root + +# Copy the actual code +COPY workflows /root/workflows + +# This tag is supplied by the build script and will be used to determine the version +# when registering tasks, workflows, and launch plans +ARG tag +ENV FLYTE_INTERNAL_IMAGE $tag diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Makefile b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Makefile new file mode 100644 index 0000000000..5812f4893c --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/Makefile @@ -0,0 +1,208 @@ +.SILENT: + +PREFIX=workflows + +# This is used by the image building script referenced below. 
Normally it just takes the directory name but in this +# case we want it to be called something else. +IMAGE_NAME=flytecookbook +export VERSION ?= $(shell git rev-parse HEAD) + +define PIP_COMPILE +pip-compile $(1) ${PIP_ARGS} --upgrade --verbose +endef + +# Set SANDBOX=1 to automatically fill in sandbox config +ifdef SANDBOX + +# The url for Flyte Control plane +export FLYTE_HOST ?= localhost:30081 + +# Overrides s3 url. This is solely needed for SANDBOX deployments. Shouldn't be overriden in production AWS S3. +export FLYTE_AWS_ENDPOINT ?= http://localhost:30084/ + +# Used to authenticate to s3. For a production AWS S3, it's discouraged to use keys and key ids. +export FLYTE_AWS_ACCESS_KEY_ID ?= minio + +# Used to authenticate to s3. For a production AWS S3, it's discouraged to use keys and key ids. +export FLYTE_AWS_SECRET_ACCESS_KEY ?= miniostorage + +# Used to publish artifacts for fast registration +export ADDL_DISTRIBUTION_DIR ?= s3://my-s3-bucket/fast/ + +# The base of where Blobs, Schemas and other offloaded types are, by default, serialized. +export OUTPUT_DATA_PREFIX ?= s3://my-s3-bucket/raw-data + +# Instructs flyte-cli commands to use insecure channel when communicating with Flyte's control plane. +# If you're port-forwarding your service or running the sandbox Flyte deployment, specify INSECURE=1 before your make command. +# If your Flyte Admin is behind SSL, don't specify anything. +ifndef INSECURE + export INSECURE_FLAG=-i +endif + +# The docker registry that should be used to push images. +# e.g.: +# export REGISTRY ?= ghcr.io/flyteorg +endif + +# The Flyte project that we want to register under +export PROJECT ?= flytesnacks + +# If the REGISTRY environment variable has been set, that means the image name will not just be tagged as +# flytecookbook: but rather, +# ghcr.io/flyteorg/flytecookbook: or whatever your REGISTRY is. +ifdef REGISTRY + FULL_IMAGE_NAME = ${REGISTRY}/${IMAGE_NAME} +endif +ifndef REGISTRY + FULL_IMAGE_NAME = ${IMAGE_NAME} +endif + +# If you are using a different service account on your k8s cluster, add SERVICE_ACCOUNT=my_account before your make command +ifndef SERVICE_ACCOUNT + SERVICE_ACCOUNT=default +endif + +.PHONY: help +help: ## show help message + @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[$$()% a-zA-Z_-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) + +.PHONY: debug +debug: + echo "IMAGE NAME ${IMAGE_NAME}" + echo "FULL IMAGE NAME ${FULL_IMAGE_NAME}" + echo "VERSION TAG ${VERSION}" + echo "REGISTRY ${REGISTRY}" + +TAGGED_IMAGE=${FULL_IMAGE_NAME}:${PREFIX}-${VERSION} + +# This should only be used by Admins to push images to the public Dockerhub repo. Make sure you +# specify REGISTRY=ghcr.io/flyteorg or your registry before the make command otherwise this won't actually push +# Also if you want to push the docker image for sagemaker consumption then +# specify ECR_REGISTRY +.PHONY: docker_push +docker_push: docker_build +ifdef REGISTRY + docker push ${TAGGED_IMAGE} +endif + +.PHONY: fmt +fmt: # Format code with black and isort + black . + isort . + +.PHONY: install-piptools +install-piptools: + pip install -U pip-tools + +.PHONY: setup +setup: install-piptools # Install requirements + pip-sync dev-requirements.txt + +.PHONY: lint +lint: # Run linters + flake8 . 
+ +requirements.txt: export CUSTOM_COMPILE_COMMAND := $(MAKE) requirements.txt +requirements.txt: requirements.in install-piptools + $(call PIP_COMPILE,requirements.in) + +.PHONY: requirements +requirements: requirements.txt + +.PHONY: fast_serialize +fast_serialize: clean _pb_output + echo ${CURDIR} + docker run -it --rm \ + -e REGISTRY=${REGISTRY} \ + -e MAKEFLAGS=${MAKEFLAGS} \ + -e FLYTE_HOST=${FLYTE_HOST} \ + -e INSECURE_FLAG=${INSECURE_FLAG} \ + -e PROJECT=${PROJECT} \ + -e FLYTE_AWS_ENDPOINT=${FLYTE_AWS_ENDPOINT} \ + -e FLYTE_AWS_ACCESS_KEY_ID=${FLYTE_AWS_ACCESS_KEY_ID} \ + -e FLYTE_AWS_SECRET_ACCESS_KEY=${FLYTE_AWS_SECRET_ACCESS_KEY} \ + -e OUTPUT_DATA_PREFIX=${OUTPUT_DATA_PREFIX} \ + -e ADDL_DISTRIBUTION_DIR=${ADDL_DISTRIBUTION_DIR} \ + -e SERVICE_ACCOUNT=$(SERVICE_ACCOUNT) \ + -e VERSION=${VERSION} \ + -v ${CURDIR}/_pb_output:/tmp/output \ + -v ${CURDIR}:/root/$(shell basename $(CURDIR)) \ + ${TAGGED_IMAGE} make fast_serialize + +.PHONY: fast_register +fast_register: clean _pb_output ## Packages code and registers without building docker images. + @echo "Tagged Image: " + @echo ${TAGGED_IMAGE} + @echo ${CURDIR} + docker run -it --rm \ + --network host \ + -e REGISTRY=${REGISTRY} \ + -e MAKEFLAGS=${MAKEFLAGS} \ + -e FLYTE_HOST=${FLYTE_HOST} \ + -e INSECURE_FLAG=${INSECURE_FLAG} \ + -e PROJECT=${PROJECT} \ + -e FLYTE_AWS_ENDPOINT=${FLYTE_AWS_ENDPOINT} \ + -e FLYTE_AWS_ACCESS_KEY_ID=${FLYTE_AWS_ACCESS_KEY_ID} \ + -e FLYTE_AWS_SECRET_ACCESS_KEY=${FLYTE_AWS_SECRET_ACCESS_KEY} \ + -e OUTPUT_DATA_PREFIX=${OUTPUT_DATA_PREFIX} \ + -e ADDL_DISTRIBUTION_DIR=${ADDL_DISTRIBUTION_DIR} \ + -e SERVICE_ACCOUNT=$(SERVICE_ACCOUNT) \ + -e VERSION=${VERSION} \ + -v ${CURDIR}/_pb_output:/tmp/output \ + -v ${CURDIR}:/root/$(shell basename $(CURDIR)) \ + ${TAGGED_IMAGE} make fast_register + +.PHONY: docker_build +docker_build: + echo "Tagged Image: " + echo ${TAGGED_IMAGE} + docker build ../ --build-arg tag="${TAGGED_IMAGE}" -t "${TAGGED_IMAGE}" -f Dockerfile + +.PHONY: serialize +serialize: clean _pb_output docker_build + @echo ${VERSION} + @echo ${CURDIR} + docker run -it --rm \ + -e REGISTRY=${REGISTRY} \ + -e MAKEFLAGS=${MAKEFLAGS} \ + -e FLYTE_HOST=${FLYTE_HOST} \ + -e INSECURE_FLAG=${INSECURE_FLAG} \ + -e PROJECT=${PROJECT} \ + -e FLYTE_AWS_ENDPOINT=${FLYTE_AWS_ENDPOINT} \ + -e FLYTE_AWS_ACCESS_KEY_ID=${FLYTE_AWS_ACCESS_KEY_ID} \ + -e FLYTE_AWS_SECRET_ACCESS_KEY=${FLYTE_AWS_SECRET_ACCESS_KEY} \ + -e OUTPUT_DATA_PREFIX=${OUTPUT_DATA_PREFIX} \ + -e ADDL_DISTRIBUTION_DIR=${ADDL_DISTRIBUTION_DIR} \ + -e SERVICE_ACCOUNT=$(SERVICE_ACCOUNT) \ + -e VERSION=${VERSION} \ + -v ${CURDIR}/_pb_output:/tmp/output \ + ${TAGGED_IMAGE} make serialize + + +.PHONY: register +register: clean _pb_output docker_push + @echo ${VERSION} + @echo ${CURDIR} + docker run -it --rm \ + --network host \ + -e REGISTRY=${REGISTRY} \ + -e MAKEFLAGS=${MAKEFLAGS} \ + -e FLYTE_HOST=${FLYTE_HOST} \ + -e INSECURE_FLAG=${INSECURE_FLAG} \ + -e PROJECT=${PROJECT} \ + -e FLYTE_AWS_ENDPOINT=${FLYTE_AWS_ENDPOINT} \ + -e FLYTE_AWS_ACCESS_KEY_ID=${FLYTE_AWS_ACCESS_KEY_ID} \ + -e FLYTE_AWS_SECRET_ACCESS_KEY=${FLYTE_AWS_SECRET_ACCESS_KEY} \ + -e OUTPUT_DATA_PREFIX=${OUTPUT_DATA_PREFIX} \ + -e ADDL_DISTRIBUTION_DIR=${ADDL_DISTRIBUTION_DIR} \ + -e SERVICE_ACCOUNT=$(SERVICE_ACCOUNT) \ + -e VERSION=${VERSION} \ + -v ${CURDIR}/_pb_output:/tmp/output \ + ${TAGGED_IMAGE} make register + +_pb_output: + mkdir -p _pb_output + +.PHONY: clean +clean: + rm -rf _pb_output/* diff --git 
a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/__init__.py b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/__init__.py b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/basic_workflow.py b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/basic_workflow.py
new file mode 100644
index 0000000000..49c42c5911
--- /dev/null
+++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/basic_workflow.py
@@ -0,0 +1,54 @@
+"""
+Write a simple workflow
+------------------------------
+
+Once you have a handle on tasks, we can move on to workflows. Workflows are the other basic building block of Flyte.
+
+Workflows string together two or more tasks. They are also written as Python functions, but it is important to make a
+critical distinction between tasks and workflows.
+
+The body of a task's function runs at "run time", i.e. on the K8s cluster, using the task's container. The body of a
+workflow is not used for computation; it is only used to structure the tasks, i.e. the output of ``t1`` is an input
+of ``t2`` in the workflow below. As such, the body of a workflow runs at "registration" time. Please refer to the
+registration docs for additional information, since registration is actually a two-step process.
+
+Take a look at the conceptual `discussion `__
+behind workflows for additional information.
+
+"""
+import typing
+
+from flytekit import task, workflow
+
+
+@task
+def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
+    return a + 2, "world"
+
+
+@task
+def t2(a: str, b: str) -> str:
+    return b + a
+
+
+# %%
+# You can treat the outputs of a task as you normally would a Python function. Assign the outputs to two variables
+# and use them in subsequent tasks as normal. See :py:func:`flytekit.workflow`
+@workflow
+def my_wf(a: int, b: str) -> (int, str):
+    x, y = t1(a=a)
+    d = t2(a=y, b=b)
+    return x, d
+
+
+# %%
+# Execute the workflow by simply invoking it like a function and passing in
+# the necessary parameters.
+#
+# .. note::
+#
+#   One thing to remember: currently we only support ``keyword arguments``, so
+#   every argument should be passed in the form ``arg=value``. Failure to do so
+#   will result in an error.
+if __name__ == "__main__":
+    print(f"Running my_wf(a=50, b='hello') {my_wf(a=50, b='hello')}")
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/hello_world.py b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/hello_world.py
new file mode 100644
index 0000000000..da7e61536f
--- /dev/null
+++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/basic/hello_world.py
@@ -0,0 +1,40 @@
+"""
+Hello World Workflow
+--------------------
+
+This simple workflow calls a task that returns "Hello World" and then just sets that as the final output of the workflow.
+
+"""
+
+from flytekit import task, workflow
+
+
+# You can change the signature of the workflow to take in an argument like this:
+# def say_hello(name: str) -> str:
+@task
+def say_hello() -> str:
+    return "hello world"
+
+
+# %%
+# You can treat the output of a task as you normally would a Python function.
Assign the output to a variable
+# and use it in subsequent tasks as normal. See :py:func:`flytekit.workflow`
+# You can change the signature of the workflow to take in an argument like this:
+# def my_wf(name: str) -> str:
+@workflow
+def my_wf() -> str:
+    res = say_hello()
+    return res
+
+
+# %%
+# Execute the workflow by simply invoking it like a function and passing in
+# the necessary parameters.
+#
+# .. note::
+#
+#   One thing to remember: currently we only support ``keyword arguments``, so
+#   every argument should be passed in the form ``arg=value``. Failure to do so
+#   will result in an error.
+if __name__ == "__main__":
+    print(f"Running my_wf() {my_wf()}")
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.in b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.in
new file mode 100644
index 0000000000..f7d015b843
--- /dev/null
+++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.in
@@ -0,0 +1,4 @@
+flytekit>=0.17.0b0
+wheel
+matplotlib
+opencv-python
diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.txt
new file mode 100644
index 0000000000..7b2b880306
--- /dev/null
+++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/requirements.txt
@@ -0,0 +1,136 @@
+#
+# This file is autogenerated by pip-compile
+# To update, run:
+#
+#    /Library/Developer/CommandLineTools/usr/bin/make requirements.txt
+#
+attrs==20.3.0
+    # via scantree
+certifi==2020.12.5
+    # via requests
+chardet==4.0.0
+    # via requests
+click==7.1.2
+    # via flytekit
+croniter==1.0.10
+    # via flytekit
+cycler==0.10.0
+    # via matplotlib
+dataclasses-json==0.5.2
+    # via flytekit
+decorator==5.0.4
+    # via retry
+deprecated==1.2.12
+    # via flytekit
+dirhash==0.2.1
+    # via flytekit
+docker-image-py==0.1.10
+    # via flytekit
+flyteidl==0.18.31
+    # via flytekit
+flytekit==0.17.0
+    # via -r ../common/requirements-common.in
+grpcio==1.36.1
+    # via flytekit
+idna==2.10
+    # via requests
+importlib-metadata==3.10.0
+    # via keyring
+keyring==23.0.1
+    # via flytekit
+kiwisolver==1.3.1
+    # via matplotlib
+marshmallow-enum==1.5.1
+    # via dataclasses-json
+marshmallow==3.11.1
+    # via
+    #   dataclasses-json
+    #   marshmallow-enum
+matplotlib==3.4.1
+    # via -r ../common/requirements-common.in
+mypy-extensions==0.4.3
+    # via typing-inspect
+natsort==7.1.1
+    # via flytekit
+numpy==1.20.2
+    # via
+    #   matplotlib
+    #   opencv-python
+    #   pandas
+    #   pyarrow
+opencv-python==4.5.1.48
+    # via -r requirements.in
+pandas==1.2.3
+    # via flytekit
+pathspec==0.8.1
+    # via scantree
+pillow==8.2.0
+    # via matplotlib
+protobuf==3.15.7
+    # via
+    #   flyteidl
+    #   flytekit
+py==1.10.0
+    # via retry
+pyarrow==3.0.0
+    # via flytekit
+pyparsing==2.4.7
+    # via matplotlib
+python-dateutil==2.8.1
+    # via
+    #   croniter
+    #   flytekit
+    #   matplotlib
+    #   pandas
+pytimeparse==1.1.8
+    # via flytekit
+pytz==2018.4
+    # via
+    #   flytekit
+    #   pandas
+regex==2021.3.17
+    # via docker-image-py
+requests==2.25.1
+    # via
+    #   flytekit
+    #   responses
+responses==0.13.2
+    # via flytekit
+retry==0.9.2
+    # via flytekit
+scantree==0.0.1
+    # via dirhash
+six==1.15.0
+    # via
+    #   cycler
+    #   flytekit
+    #   grpcio
+    #   protobuf
+    #   python-dateutil
+    #   responses
+    #   scantree
+sortedcontainers==2.3.0
+    # via flytekit
+statsd==3.3.0
+    # via flytekit
+stringcase==1.2.0
+    # via dataclasses-json
+typing-extensions==3.7.4.3
+    # via typing-inspect
+typing-inspect==0.6.0 + # via dataclasses-json +urllib3==1.25.11 + # via + # flytekit + # requests + # responses +wheel==0.36.2 + # via + # -r ../common/requirements-common.in + # flytekit +wrapt==1.12.1 + # via + # deprecated + # flytekit +zipp==3.4.1 + # via importlib-metadata diff --git a/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/sandbox.config b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/sandbox.config new file mode 100644 index 0000000000..da3362a4b0 --- /dev/null +++ b/tests/flytekit/integration/control_plane/mock_flyte_repo/workflows/sandbox.config @@ -0,0 +1,7 @@ +[sdk] +workflow_packages=workflows +python_venv=flytekit_venv + +[auth] +assumable_iam_role=arn:aws:iam::173840052742:role/flytefunctionaltestsbatchworker-production-iad +raw_output_data_prefix=s3://lyft-modelbuilder/cookbook diff --git a/tests/flytekit/integration/control_plane/test_workflow.py b/tests/flytekit/integration/control_plane/test_workflow.py new file mode 100644 index 0000000000..8b72cd7c2a --- /dev/null +++ b/tests/flytekit/integration/control_plane/test_workflow.py @@ -0,0 +1,90 @@ +import datetime +import os +import pathlib +import time + +import pytest + +from flytekit.common.exceptions.user import FlyteAssertion +from flytekit.control_plane import launch_plan +from flytekit.models import literals + +PROJECT = "flytesnacks" +VERSION = os.getpid() + + +@pytest.fixture(scope="session") +def flyte_workflows_source_dir(): + return pathlib.Path(os.path.dirname(__file__)) / "mock_flyte_repo" + + +@pytest.fixture(scope="session") +def flyte_workflows_register(docker_compose): + docker_compose.execute( + f"exec -w /flyteorg/src -e SANDBOX=1 -e PROJECT={PROJECT} -e VERSION=v{VERSION} " + "backend make -C workflows register" + ) + + +def test_client(flyteclient, flyte_workflows_register): + projects = flyteclient.list_projects_paginated(limit=5, token=None) + assert len(projects) <= 5 + + +def test_launch_workflow(flyteclient, flyte_workflows_register): + execution = launch_plan.FlyteLaunchPlan.fetch( + PROJECT, "development", "workflows.basic.hello_world.my_wf", f"v{VERSION}" + ).launch_with_literals(PROJECT, "development", literals.LiteralMap({})) + execution.wait_for_completion() + assert execution.outputs.literals["o0"].scalar.primitive.string_value == "hello world" + + +def test_launch_workflow_with_args(flyteclient, flyte_workflows_register): + execution = launch_plan.FlyteLaunchPlan.fetch( + PROJECT, "development", "workflows.basic.basic_workflow.my_wf", f"v{VERSION}" + ).launch_with_literals( + PROJECT, + "development", + literals.LiteralMap( + { + "a": literals.Literal(literals.Scalar(literals.Primitive(integer=10))), + "b": literals.Literal(literals.Scalar(literals.Primitive(string_value="foobar"))), + } + ), + ) + execution.wait_for_completion() + assert execution.outputs.literals["o0"].scalar.primitive.integer == 12 + assert execution.outputs.literals["o1"].scalar.primitive.string_value == "foobarworld" + + +def test_monitor_workflow(flyteclient, flyte_workflows_register): + execution = launch_plan.FlyteLaunchPlan.fetch( + PROJECT, "development", "workflows.basic.hello_world.my_wf", f"v{VERSION}" + ).launch_with_literals(PROJECT, "development", literals.LiteralMap({})) + + poll_interval = datetime.timedelta(seconds=1) + time_to_give_up = datetime.datetime.utcnow() + datetime.timedelta(seconds=60) + + execution.sync() + while datetime.datetime.utcnow() < time_to_give_up: + + if execution.is_complete: + execution.sync() + break + + with pytest.raises( + 
FlyteAssertion, match="Please wait until the node execution has completed before requesting the outputs"
+        ):
+            execution.outputs
+
+        time.sleep(poll_interval.total_seconds())
+        execution.sync()
+
+    if execution.node_executions:
+        assert execution.node_executions["start-node"].closure.phase == 3  # SUCCEEDED
+
+    for key in execution.node_executions:
+        assert execution.node_executions[key].closure.phase == 3
+
+    assert execution.node_executions["n0"].outputs.literals["o0"].scalar.primitive.string_value == "hello world"
+    assert execution.outputs.literals["o0"].scalar.primitive.string_value == "hello world"
diff --git a/tests/flytekit/unit/control_plane/__init__.py b/tests/flytekit/unit/control_plane/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/flytekit/unit/control_plane/tasks/test_task.py b/tests/flytekit/unit/control_plane/tasks/test_task.py
new file mode 100644
index 0000000000..479221ae9d
--- /dev/null
+++ b/tests/flytekit/unit/control_plane/tasks/test_task.py
@@ -0,0 +1,34 @@
+from mock import MagicMock as _MagicMock
+from mock import patch as _patch
+
+from flytekit.control_plane.tasks import task as _task
+from flytekit.models import task as _task_models
+from flytekit.models.core import identifier as _identifier
+
+
+@_patch("flytekit.engines.flyte.engine._FlyteClientManager")
+@_patch("flytekit.configuration.platform.URL")
+def test_flyte_task_fetch(mock_url, mock_client_manager):
+    mock_url.get.return_value = "localhost"
+    admin_task_v1 = _task_models.Task(
+        _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"),
+        _MagicMock(),
+    )
+    admin_task_v2 = _task_models.Task(
+        _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v2"),
+        _MagicMock(),
+    )
+    mock_client = _MagicMock()
+    mock_client.list_tasks_paginated = _MagicMock(return_value=([admin_task_v2, admin_task_v1], ""))
+    mock_client_manager.return_value.client = mock_client
+
+    latest_task = _task.FlyteTask.fetch_latest("p1", "d1", "n1")
+    task_v1 = _task.FlyteTask.fetch("p1", "d1", "n1", "v1")
+    task_v2 = _task.FlyteTask.fetch("p1", "d1", "n1", "v2")
+    assert task_v1.id == admin_task_v1.id
+    assert task_v1.id != latest_task.id
+    assert task_v2.id == latest_task.id == admin_task_v2.id
+
+    for task in [task_v1, task_v2]:
+        assert task.entity_type_text == "Task"
+        assert task.resource_type == _identifier.ResourceType.TASK
diff --git a/tests/flytekit/unit/control_plane/test_identifier.py b/tests/flytekit/unit/control_plane/test_identifier.py
new file mode 100644
index 0000000000..8976df0bd3
--- /dev/null
+++ b/tests/flytekit/unit/control_plane/test_identifier.py
@@ -0,0 +1,77 @@
+import pytest
+
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.control_plane import identifier as _identifier
+from flytekit.models.core import identifier as _core_identifier
+
+
+def test_identifier():
+    identifier = _identifier.Identifier(_core_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "v1")
+    assert identifier == _identifier.Identifier.from_urn("wf:project:domain:name:v1")
+    assert identifier == _core_identifier.Identifier(
+        _core_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "v1"
+    )
+    assert identifier.__str__() == "wf:project:domain:name:v1"
+
+
+@pytest.mark.parametrize(
+    "urn",
+    [
+        "",
+        "project:domain:name:v1",
+        "wf:project:domain:name:v1:foobar",
+        "foobar:project:domain:name:v1",
+    ],
+)
+def test_identifier_exceptions(urn):
+    with pytest.raises(_user_exceptions.FlyteValueException):
_identifier.Identifier.from_urn(urn)
+
+
+def test_workflow_execution_identifier():
+    identifier = _identifier.WorkflowExecutionIdentifier("project", "domain", "name")
+    assert identifier == _identifier.WorkflowExecutionIdentifier.from_urn("ex:project:domain:name")
+    assert identifier == _identifier.WorkflowExecutionIdentifier.promote_from_model(
+        _core_identifier.WorkflowExecutionIdentifier("project", "domain", "name")
+    )
+    assert identifier.__str__() == "ex:project:domain:name"
+
+
+@pytest.mark.parametrize(
+    "urn", ["", "project:domain:name", "project:domain:name:foobar", "ex:project:domain:name:foobar"]
+)
+def test_workflow_execution_identifier_exceptions(urn):
+    with pytest.raises(_user_exceptions.FlyteValueException):
+        _identifier.WorkflowExecutionIdentifier.from_urn(urn)
+
+
+def test_task_execution_identifier():
+    task_id = _identifier.Identifier(_core_identifier.ResourceType.TASK, "project", "domain", "name", "version")
+    node_execution_id = _core_identifier.NodeExecutionIdentifier(
+        node_id="n0", execution_id=_core_identifier.WorkflowExecutionIdentifier("project", "domain", "name")
+    )
+    identifier = _identifier.TaskExecutionIdentifier(
+        task_id=task_id,
+        node_execution_id=node_execution_id,
+        retry_attempt=0,
+    )
+    assert identifier == _identifier.TaskExecutionIdentifier.from_urn(
+        "te:project:domain:name:n0:project:domain:name:version:0"
+    )
+    assert identifier == _identifier.TaskExecutionIdentifier.promote_from_model(
+        _core_identifier.TaskExecutionIdentifier(task_id, node_execution_id, 0)
+    )
+    assert identifier.__str__() == "te:project:domain:name:n0:project:domain:name:version:0"
+
+
+@pytest.mark.parametrize(
+    "urn",
+    [
+        "",
+        "te:project:domain:name:n0:project:domain:name:version",
+        "foobar:project:domain:name:n0:project:domain:name:version:0",
+    ],
+)
+def test_task_execution_identifier_exceptions(urn):
+    with pytest.raises(_user_exceptions.FlyteValueException):
+        _identifier.TaskExecutionIdentifier.from_urn(urn)
diff --git a/tests/flytekit/unit/control_plane/test_workflow.py b/tests/flytekit/unit/control_plane/test_workflow.py
new file mode 100644
index 0000000000..81c82bf706
--- /dev/null
+++ b/tests/flytekit/unit/control_plane/test_workflow.py
@@ -0,0 +1,23 @@
+from mock import MagicMock as _MagicMock
+from mock import patch as _patch
+
+from flytekit.control_plane import workflow as _workflow
+from flytekit.models.admin import workflow as _workflow_models
+from flytekit.models.core import identifier as _identifier
+
+
+@_patch("flytekit.engines.flyte.engine._FlyteClientManager")
+@_patch("flytekit.configuration.platform.URL")
+def test_flyte_workflow_integration(mock_url, mock_client_manager):
+    mock_url.get.return_value = "localhost"
+    admin_workflow = _workflow_models.Workflow(
+        _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p1", "d1", "n1", "v1"),
+        _MagicMock(),
+    )
+    mock_client = _MagicMock()
+    mock_client.list_workflows_paginated = _MagicMock(return_value=([admin_workflow], ""))
+    mock_client_manager.return_value.client = mock_client
+
+    workflow = _workflow.FlyteWorkflow.fetch("p1", "d1", "n1", "v1")
+    assert workflow.entity_type_text == "Workflow"
+    assert workflow.id == admin_workflow.id
From 59ffbb54e653cfb391a8ddc9f33fef41cb92acb8 Mon Sep 17 00:00:00 2001
From: ajsalow
Date: Wed, 28 Apr 2021 14:15:07 -0500
Subject: [PATCH 18/33] Use default arguments in addition to kwargs in local wf execution (#458)

Signed-off-by: Max Hoffman
---
 flytekit/core/workflow.py                  | 12 +++++++-----
 tests/flytekit/unit/core/test_workflows.py | 18
++++++++++++++++++
 2 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py
index 9aacfd6bf8..47d4f7af23 100644
--- a/flytekit/core/workflow.py
+++ b/flytekit/core/workflow.py
@@ -223,10 +223,12 @@ def __call__(self, *args, **kwargs):

         ctx = FlyteContext.current_context()

+        # Get default arguments and override with kwargs passed in
+        input_kwargs = self.python_interface.default_inputs_as_kwargs
+        input_kwargs.update(kwargs)
+
         # The first condition is compilation.
         if ctx.compilation_state is not None:
-            input_kwargs = self.python_interface.default_inputs_as_kwargs
-            input_kwargs.update(kwargs)
             return create_and_link_node(ctx, entity=self, interface=self.python_interface, **input_kwargs)

         # This condition is hit when this workflow (self) is being called as part of a parent's workflow local run.
@@ -243,7 +245,7 @@ def __call__(self, *args, **kwargs):
             else:
                 return None
         # We are already in a local execution, just continue the execution context
-        return self._local_execute(ctx, **kwargs)
+        return self._local_execute(ctx, **input_kwargs)

         # Last is starting a local workflow execution
         else:
@@ -251,14 +253,14 @@ def __call__(self, *args, **kwargs):
             # Even though the _local_execute call generally expects inputs to be Promises, we don't have to do the
             # conversion here in this loop. The reason is because we don't prevent users from specifying inputs
             # as direct scalars, which means there's another Promise-generating loop inside _local_execute too
-            for k, v in kwargs.items():
+            for k, v in input_kwargs.items():
                 if k not in self.interface.inputs:
                     raise ValueError(f"Received unexpected keyword argument {k}")
                 if isinstance(v, Promise):
                     raise ValueError(f"Received a promise for a workflow call, when expecting a native value for {k}")

             with ctx.new_execution_context(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) as ctx:
-                result = self._local_execute(ctx, **kwargs)
+                result = self._local_execute(ctx, **input_kwargs)

             expected_outputs = len(self.python_interface.outputs)
             if expected_outputs == 0:
diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py
index 0750e95be0..167d429fab 100644
--- a/tests/flytekit/unit/core/test_workflows.py
+++ b/tests/flytekit/unit/core/test_workflows.py
@@ -6,6 +6,7 @@
 from flytekit.common.exceptions.user import FlyteValidationException, FlyteValueException
 from flytekit.common.translator import get_serializable
 from flytekit.core import context_manager
+from flytekit.core.condition import conditional
 from flytekit.core.context_manager import Image, ImageConfig
 from flytekit.core.task import task
 from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow
@@ -51,6 +52,23 @@ def wf(a: int) -> (str, str):
     assert sdk_wf.metadata.on_failure == 1


+def test_default_values():
+    @task
+    def t() -> bool:
+        return True
+
+    @task
+    def f() -> bool:
+        return False
+
+    @workflow
+    def wf(a: bool = True) -> bool:
+        return conditional("bool").if_(a.is_true()).then(t()).else_().then(f())
+
+    assert wf() is True
+    assert wf(a=False) is False
+
+
 def test_list_output_wf():
     @task
     def t1(a: int) -> int:
From 04090f089fed9c8e1b07b655b507e21ac18192a6 Mon Sep 17 00:00:00 2001
From: Niels Bantilan
Date: Wed, 28 Apr 2021 17:12:42 -0400
Subject: [PATCH 19/33] fix control_plane imports (#459)

* fix imports

Signed-off-by: cosmicBboy

* update

Signed-off-by: cosmicBboy
Signed-off-by: Max Hoffman
---
 flytekit/control_plane/tasks/executions.py | 17
++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/flytekit/control_plane/tasks/executions.py b/flytekit/control_plane/tasks/executions.py index 838746f392..ae1eda09ba 100644 --- a/flytekit/control_plane/tasks/executions.py +++ b/flytekit/control_plane/tasks/executions.py @@ -1,11 +1,11 @@ from typing import Any, Dict, Optional +from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions from flytekit.common import utils as _common_utils from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.mixins import artifact as _artifact_mixin -from flytekit.core.context_manager import FlyteContext -from flytekit.core.type_engine import TypeEngine from flytekit.engines.flyte import engine as _flyte_engine +from flytekit.models import literals as _literal_models from flytekit.models.admin import task_execution as _task_execution_model from flytekit.models.core import execution as _execution_models @@ -36,7 +36,7 @@ def inputs(self) -> Dict[str, Any]: execution_data = client.get_task_execution_data(self.id) # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. - input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({}) + input_map = _literal_models.LiteralMap({}) if bool(execution_data.full_inputs.literals): input_map = execution_data.full_inputs elif execution_data.inputs.bytes > 0: @@ -47,7 +47,9 @@ def inputs(self) -> Dict[str, Any]: _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) ) - self._inputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=input_map) + # TODO: need to convert flyte literals to python types. For now just use literals + # self._inputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=input_map) + self._inputs = input_map return self._inputs @property @@ -70,9 +72,9 @@ def outputs(self) -> Dict[str, Any]: execution_data = client.get_task_execution_data(self.id) # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + output_map = _literal_models.LiteralMap({}) if bool(execution_data.full_outputs.literals): output_map = execution_data.full_outputs - elif execution_data.outputs.bytes > 0: with _common_utils.AutoDeletingTempDir() as t: tmp_name = _os.path.join(t.name, "outputs.pb") @@ -80,9 +82,10 @@ def outputs(self) -> Dict[str, Any]: output_map = _literal_models.LiteralMap.from_flyte_idl( _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) ) - output_map = _literal_models.LiteralMap({}) - self._outputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=output_map) + # TODO: need to convert flyte literals to python types. 
For now just use literals + # self._outputs = TypeEngine.literal_map_to_kwargs(ctx=FlyteContext.current_context(), lm=output_map) + self._outputs = output_map return self._outputs @property From 01a9a959c3e332b2bd901a08d323e34b03ddeafe Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Wed, 28 Apr 2021 17:44:15 -0700 Subject: [PATCH 20/33] Initial DoltTable implementation Signed-off-by: Max Hoffman --- dev-requirements.txt | 15 +- doc-requirements.txt | 35 ++--- plugins/dolt/flytekitplugins/dolt/__init__.py | 1 + plugins/dolt/flytekitplugins/dolt/schema.py | 105 ++++++++++++++ plugins/dolt/scripts/flytekit_install_dolt.sh | 12 ++ plugins/dolt/setup.py | 35 +++++ plugins/setup.py | 1 + plugins/tests/dolt/__init__.py | 0 plugins/tests/dolt/test_wf.py | 128 ++++++++++++++++++ requirements-spark2.txt | 25 ++-- requirements.txt | 25 ++-- 11 files changed, 318 insertions(+), 64 deletions(-) create mode 100644 plugins/dolt/flytekitplugins/dolt/__init__.py create mode 100644 plugins/dolt/flytekitplugins/dolt/schema.py create mode 100644 plugins/dolt/scripts/flytekit_install_dolt.sh create mode 100644 plugins/dolt/setup.py create mode 100644 plugins/tests/dolt/__init__.py create mode 100644 plugins/tests/dolt/test_wf.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 7462d342cb..4b0f0134a4 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -17,13 +17,7 @@ attrs==20.3.0 # -c requirements.txt # jsonschema # pytest - # pytest-docker - # scantree -bcrypt==3.2.0 - # via - # -c requirements.txt - # paramiko -black==20.8b1 +black==21.4b2 # via # -c requirements.txt # -r dev-requirements.in @@ -316,15 +310,10 @@ toml==0.10.2 # coverage # pytest typed-ast==1.4.3 - # via - # -c requirements.txt - # black - # mypy + # via mypy typing-extensions==3.7.4.3 # via # -c requirements.txt - # black - # importlib-metadata # mypy # typing-inspect typing-inspect==0.6.0 diff --git a/doc-requirements.txt b/doc-requirements.txt index 105ba18676..99f2a9c6da 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -16,7 +16,7 @@ appnope==0.1.2 # via # ipykernel # ipython -astroid==2.5.3 +astroid==2.5.6 # via sphinx-autoapi async-generator==1.10 # via nbclient @@ -24,7 +24,7 @@ attrs==20.3.0 # via # jsonschema # scantree -babel==2.9.0 +babel==2.9.1 # via sphinx backcall==0.2.0 # via ipython @@ -35,13 +35,13 @@ beautifulsoup4==4.9.3 # furo # sphinx-code-include # sphinx-material -black==20.8b1 +black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.55 +boto3==1.17.60 # via sagemaker-training -botocore==1.20.55 +botocore==1.20.60 # via # boto3 # s3transfer @@ -68,7 +68,7 @@ cryptography==3.4.7 # paramiko css-html-js-minify==2.5.5 # via sphinx-material -dataclasses-json==0.5.2 +dataclasses-json==0.5.3 # via flytekit decorator==5.0.7 # via @@ -88,7 +88,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.38 +flyteidl==0.18.39 # via flytekit furo==2021.4.11b34 # via -r doc-requirements.in @@ -279,19 +279,19 @@ requests==2.25.1 # papermill # responses # sphinx -responses==0.13.2 +responses==0.13.3 # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.4.1 +s3transfer==0.4.2 # via boto3 -sagemaker-training==3.9.1 +sagemaker-training==3.9.2 # via flytekit scantree==0.0.1 # via dirhash -scipy==1.6.2 +scipy==1.6.3 # via sagemaker-training six==1.15.0 # via @@ -316,13 +316,13 @@ sortedcontainers==2.3.0 # via flytekit soupsieve==2.2.1 # via beautifulsoup4 -sphinx-autoapi==1.8.0 +sphinx-autoapi==1.8.1 # via -r doc-requirements.in 
sphinx-code-include==1.1.1 # via -r doc-requirements.in sphinx-copybutton==0.3.1 # via -r doc-requirements.in -sphinx-gallery==0.8.2 +sphinx-gallery==0.9.0 # via -r doc-requirements.in sphinx-material==0.0.32 # via -r doc-requirements.in @@ -381,15 +381,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via - # astroid - # black typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json unidecode==1.2.0 diff --git a/plugins/dolt/flytekitplugins/dolt/__init__.py b/plugins/dolt/flytekitplugins/dolt/__init__.py new file mode 100644 index 0000000000..fd9379e283 --- /dev/null +++ b/plugins/dolt/flytekitplugins/dolt/__init__.py @@ -0,0 +1 @@ +from .schema import DoltConfig, DoltTable, DoltTableNameTransformer diff --git a/plugins/dolt/flytekitplugins/dolt/schema.py b/plugins/dolt/flytekitplugins/dolt/schema.py new file mode 100644 index 0000000000..01182b806c --- /dev/null +++ b/plugins/dolt/flytekitplugins/dolt/schema.py @@ -0,0 +1,105 @@ +import tempfile +import typing +from dataclasses import dataclass +from typing import Type + +import dolt_integrations.core as dolt_int +import doltcli as dolt +import pandas +from dataclasses_json import dataclass_json +from google.protobuf.struct_pb2 import Struct + +from flytekit import FlyteContext +from flytekit.extend import TypeEngine, TypeTransformer +from flytekit.models import types as _type_models +from flytekit.models.literals import Literal, Scalar +from flytekit.models.types import LiteralType + + +@dataclass_json +@dataclass +class DoltConfig: + db_path: str + tablename: typing.Optional[str] = None + sql: typing.Optional[str] = None + io_args: typing.Optional[dict] = None + branch_conf: typing.Optional[dolt_int.Branch] = None + meta_conf: typing.Optional[dolt_int.Meta] = None + remote_conf: typing.Optional[dolt_int.Remote] = None + + +@dataclass_json +@dataclass +class DoltTable: + config: DoltConfig + data: typing.Optional[pandas.DataFrame] = None + + +class DoltTableNameTransformer(TypeTransformer[DoltTable]): + def __init__(self): + super().__init__(name="DoltTable", t=DoltTable) + + def get_literal_type(self, t: Type[DoltTable]) -> LiteralType: + return LiteralType(simple=_type_models.SimpleType.STRUCT, metadata={}) + + def to_literal( + self, + ctx: FlyteContext, + python_val: DoltTable, + python_type: typing.Type[DoltTable], + expected: LiteralType, + ) -> Literal: + + if not isinstance(python_val, DoltTable): + raise AssertionError(f"Value cannot be converted to a table: {python_val}") + + conf = python_val.config + if python_val.data is not None and python_val.tablename is not None: + db = dolt.Dolt(conf.db_path) + with tempfile.NamedTemporaryFile() as f: + python_val.data.to_csv(f.name, index=False) + dolt_int.save( + db=db, + tablename=conf.tablename, + filename=f.name, + branch_conf=conf.branch_conf, + meta_conf=conf.meta_conf, + remote_conf=conf.remote_conf, + save_args=conf.io_args, + ) + + s = Struct() + s.update(python_val.to_dict()) + return Literal(Scalar(generic=s)) + + def to_python_value( + self, + ctx: FlyteContext, + lv: Literal, + expected_python_type: typing.Type[DoltTable], + ) -> DoltTable: + + if not (lv and lv.scalar and lv.scalar.generic and lv.scalar.generic["config"]): + return pandas.DataFrame() + + conf = DoltConfig(**lv.scalar.generic["config"]) + db = dolt.Dolt(conf.db_path) + + with tempfile.NamedTemporaryFile() as f: + dolt_int.load( + db=db, + tablename=conf.tablename, + sql=conf.sql, + 
filename=f.name,
+                branch_conf=conf.branch_conf,
+                meta_conf=conf.meta_conf,
+                remote_conf=conf.remote_conf,
+                load_args=conf.io_args,
+            )
+            df = pandas.read_csv(f)
+            lv.data = df
+
+        return lv
+
+
+TypeEngine.register(DoltTableNameTransformer())
diff --git a/plugins/dolt/scripts/flytekit_install_dolt.sh b/plugins/dolt/scripts/flytekit_install_dolt.sh
new file mode 100644
index 0000000000..c2f4841789
--- /dev/null
+++ b/plugins/dolt/scripts/flytekit_install_dolt.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+# Fetches and installs Dolt. To be invoked by the Dockerfile
+
+# Echo commands to the terminal output
+set -eox pipefail
+
+# Install Dolt
+
+apt-get update -y \
+    && apt-get install curl \
+    && sudo bash -c 'curl -L https://github.com/dolthub/dolt/releases/latest/download/install.sh | sudo bash'
diff --git a/plugins/dolt/setup.py b/plugins/dolt/setup.py
new file mode 100644
index 0000000000..ddda6f8813
--- /dev/null
+++ b/plugins/dolt/setup.py
@@ -0,0 +1,35 @@
+from setuptools import setup
+
+PLUGIN_NAME = "dolt"
+
+microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
+
+plugin_requires = ["flytekit>=0.16.0b0,<1.0.0", "dolt_integrations>=0.1.3"]
+
+__version__ = "0.0.0+develop"
+
+setup(
+    name=microlib_name,
+    version=__version__,
+    author="dolthub",
+    author_email="max@dolthub.com",
+    description="Dolt plugin for flytekit",
+    namespace_packages=["flytekitplugins"],
+    packages=[f"flytekitplugins.{PLUGIN_NAME}"],
+    install_requires=plugin_requires,
+    license="apache2",
+    python_requires=">=3.7",
+    classifiers=[
+        "Intended Audience :: Science/Research",
+        "Intended Audience :: Developers",
+        "License :: OSI Approved :: Apache Software License",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Topic :: Scientific/Engineering",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Topic :: Software Development",
+        "Topic :: Software Development :: Libraries",
+        "Topic :: Software Development :: Libraries :: Python Modules",
+    ],
+    scripts=["scripts/flytekit_install_dolt.sh"],
+)
diff --git a/plugins/setup.py b/plugins/setup.py
index abd2d8c7c8..9e54cea7a1 100644
--- a/plugins/setup.py
+++ b/plugins/setup.py
@@ -16,6 +16,7 @@
     "flytekitplugins-kftensorflow": "kftensorflow",
     "flytekitplugins-pandera": "pandera",
     "flytekitplugins-sqlalchemy": "sqlalchemy",
+    "flytekitplugins-dolt": "dolt",
 }
diff --git a/plugins/tests/dolt/__init__.py b/plugins/tests/dolt/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/plugins/tests/dolt/test_wf.py b/plugins/tests/dolt/test_wf.py
new file mode 100644
index 0000000000..c8b41e5bcb
--- /dev/null
+++ b/plugins/tests/dolt/test_wf.py
@@ -0,0 +1,128 @@
+import os
+import shutil
+import tempfile
+
+import doltcli as dolt
+import pandas
+import pytest
+from flytekitplugins.dolt.schema import DoltConfig, DoltTable
+
+from flytekit import task, workflow
+
+
+@pytest.fixture(scope="function")
+def doltdb_path():
+    d = tempfile.TemporaryDirectory()
+    try:
+        db_path = os.path.join(d.name, "foo")
+        yield db_path
+    finally:
+        shutil.rmtree(d.name)
+
+
+@pytest.fixture(scope="function")
+def dolt_config(doltdb_path):
+    yield DoltConfig(
+        db_path=doltdb_path,
+        tablename="foo",
+    )
+
+
+@pytest.fixture(scope="function")
+def db(doltdb_path):
+    try:
+        db = dolt.Dolt.init(doltdb_path)
+        db.sql("create table bar (name text primary key, count bigint)")
+        db.sql("insert into bar values ('Dilly', 3)")
+        db.sql("select dolt_commit('-am', 'Initialize bar table')")
+        yield db
+    finally:
+        pass
+
+
+def
test_dolt_table_write(db, dolt_config): + @task + def my_dolt(a: int) -> DoltTable: + df = pandas.DataFrame([("Alice", a)], columns=["name", "count"]) + return DoltTable(data=df, config=dolt_config) + + @workflow + def my_wf(a: int) -> DoltTable: + return my_dolt(a=a) + + x = my_wf(a=5) + assert x + assert (x.data == pandas.DataFrame([("Alice", 5)], columns=["name", "count"])).all().all() + + +def test_dolt_table_read(db, dolt_config): + @task + def my_dolt(t: DoltTable) -> str: + df = t.data + return df.name.values[0] + + @workflow + def my_wf(t: DoltTable) -> str: + return my_dolt(t=t) + + dolt_config.tablename = "bar" + x = my_wf(t=DoltTable(config=dolt_config)) + assert x == "Dilly" + + +def test_dolt_table_read_task_config(db, dolt_config): + @task + def my_dolt(t: DoltTable) -> str: + df = t.data + return df.name.values[0] + + @task + def my_table() -> DoltTable: + dolt_config.tablename = "bar" + t = DoltTable(config=dolt_config) + return t + + @workflow + def my_wf() -> str: + t = my_table() + return my_dolt(t=t) + + x = my_wf() + assert x == "Dilly" + + +def test_dolt_table_read_mixed_config(db, dolt_config): + @task + def my_dolt(t: DoltTable) -> str: + df = t.data + return df.name.values[0] + + @task + def my_table(conf: DoltConfig) -> DoltTable: + return DoltTable(config=conf) + + @workflow + def my_wf(conf: DoltConfig) -> str: + t = my_table(conf=conf) + return my_dolt(t=t) + + dolt_config.tablename = "bar" + x = my_wf(conf=dolt_config) + + assert x == "Dilly" + + +def test_dolt_sql_read(db, dolt_config): + @task + def my_dolt(t: DoltTable) -> str: + df = t.data + return df.name.values[0] + + @workflow + def my_wf(t: DoltTable) -> str: + return my_dolt(t=t) + + dolt_config.tablename = None + dolt_config.sql = "select * from bar" + x = my_wf(t=DoltTable(config=dolt_config)) + assert x == "Dilly" diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 6aca3d177a..eb28185ae8 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -24,13 +24,13 @@ backcall==0.2.0 # via ipython bcrypt==3.2.0 # via paramiko -black==20.8b1 +black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.55 +boto3==1.17.60 # via sagemaker-training -botocore==1.20.55 +botocore==1.20.60 # via # boto3 # s3transfer @@ -53,7 +53,7 @@ croniter==1.0.12 # via flytekit cryptography==3.4.7 # via paramiko -dataclasses-json==0.5.2 +dataclasses-json==0.5.3 # via flytekit decorator==5.0.7 # via @@ -71,7 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.38 +flyteidl==0.18.39 # via flytekit gevent==21.1.2 # via sagemaker-training @@ -237,19 +237,19 @@ requests==2.25.1 # flytekit # papermill # responses -responses==0.13.2 +responses==0.13.3 # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.4.1 +s3transfer==0.4.2 # via boto3 -sagemaker-training==3.9.1 +sagemaker-training==3.9.2 # via flytekit scantree==0.0.1 # via dirhash -scipy==1.6.2 +scipy==1.6.3 # via sagemaker-training six==1.15.0 # via @@ -298,13 +298,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via black typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 diff --git a/requirements.txt b/requirements.txt index 5533903f08..9aee73be48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,13 +24,13 @@ backcall==0.2.0 # via ipython bcrypt==3.2.0 # via paramiko -black==20.8b1 +black==21.4b2 # via papermill 
bleach==3.3.0 # via nbconvert -boto3==1.17.55 +boto3==1.17.60 # via sagemaker-training -botocore==1.20.55 +botocore==1.20.60 # via # boto3 # s3transfer @@ -53,7 +53,7 @@ croniter==1.0.12 # via flytekit cryptography==3.4.7 # via paramiko -dataclasses-json==0.5.2 +dataclasses-json==0.5.3 # via flytekit decorator==5.0.7 # via @@ -71,7 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.38 +flyteidl==0.18.39 # via flytekit gevent==21.1.2 # via sagemaker-training @@ -237,19 +237,19 @@ requests==2.25.1 # flytekit # papermill # responses -responses==0.13.2 +responses==0.13.3 # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.4.1 +s3transfer==0.4.2 # via boto3 -sagemaker-training==3.9.1 +sagemaker-training==3.9.2 # via flytekit scantree==0.0.1 # via dirhash -scipy==1.6.2 +scipy==1.6.3 # via sagemaker-training six==1.15.0 # via @@ -298,13 +298,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via black typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 From db370c6d44f7cbd3408cce724c16e94a3e6a7fca Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Wed, 28 Apr 2021 19:17:10 -0700 Subject: [PATCH 21/33] Update reqs Signed-off-by: Max Hoffman --- dev-requirements.txt | 18 +++++++++--------- doc-requirements.txt | 4 +--- requirements-spark2.txt | 4 +--- requirements.txt | 4 +--- 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 4b0f0134a4..991afbf9a1 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -17,13 +17,17 @@ attrs==20.3.0 # -c requirements.txt # jsonschema # pytest + # pytest-docker + # scantree +bcrypt==3.2.0 + # via + # -c requirements.txt + # paramiko black==21.4b2 # via # -c requirements.txt # -r dev-requirements.in # flake8-black -cached-property==1.5.2 - # via docker-compose certifi==2020.12.5 # via # -c requirements.txt @@ -53,7 +57,7 @@ cryptography==3.4.7 # via # -c requirements.txt # paramiko -dataclasses-json==0.5.2 +dataclasses-json==0.5.3 # via # -c requirements.txt # flytekit @@ -94,7 +98,7 @@ flake8==3.9.1 # -r dev-requirements.in # flake8-black # flake8-isort -flyteidl==0.18.38 +flyteidl==0.18.39 # via # -c requirements.txt # flytekit @@ -109,11 +113,7 @@ idna==2.10 importlib-metadata==4.0.1 # via # -c requirements.txt - # flake8 - # jsonschema # keyring - # pluggy - # pytest iniconfig==1.1.1 # via pytest isort==5.8.0 @@ -261,7 +261,7 @@ requests==2.25.1 # docker-compose # flytekit # responses -responses==0.13.2 +responses==0.13.3 # via # -c requirements.txt # flytekit diff --git a/doc-requirements.txt b/doc-requirements.txt index 99f2a9c6da..d3d99642b4 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -107,9 +107,7 @@ idna==2.10 imagesize==1.2.0 # via sphinx importlib-metadata==4.0.1 - # via - # jsonschema - # keyring + # via keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 diff --git a/requirements-spark2.txt b/requirements-spark2.txt index eb28185ae8..1a791f19cd 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -84,9 +84,7 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 - # via - # jsonschema - # keyring + # via keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 diff --git a/requirements.txt b/requirements.txt index 9aee73be48..26889d6590 100644 --- a/requirements.txt +++ 
b/requirements.txt @@ -84,9 +84,7 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 - # via - # jsonschema - # keyring + # via keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 From 03c67d81d42c6b4088bfcdb400882e3f9e92ebb5 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Wed, 28 Apr 2021 22:05:02 -0700 Subject: [PATCH 22/33] Fix fmt Signed-off-by: Max Hoffman --- flytekit/core/promise.py | 2 +- plugins/setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 894d44fc0e..25347ec1a2 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -621,7 +621,7 @@ def node_id(self): @property def node(self) -> Node: - """""" + """ """ return self._node def __repr__(self) -> str: diff --git a/plugins/setup.py b/plugins/setup.py index 9e54cea7a1..96debbefbd 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -41,7 +41,7 @@ def install_all_plugins(sources, develop=False): class DevelopCmd(develop): - """ Add custom steps for the develop command """ + """Add custom steps for the develop command""" def run(self): install_all_plugins(SOURCES, develop=True) @@ -49,7 +49,7 @@ def run(self): class InstallCmd(install): - """ Add custom steps for the install command """ + """Add custom steps for the install command""" def run(self): install_all_plugins(SOURCES, develop=False) From 3d6bf258e8e6ada0d3445f8091a79572412d9625 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Thu, 29 Apr 2021 13:21:54 -0700 Subject: [PATCH 23/33] Different namespace setup, dolt post-install Signed-off-by: Max Hoffman --- .../flytekitplugins/dolt/__init__.py | 0 .../flytekitplugins/dolt/schema.py | 0 .../scripts/flytekit_install_dolt.sh | 0 plugins/{dolt => flyteplugins.dolt}/setup.py | 34 ++++++++++++++++++- plugins/setup.py | 2 +- plugins/sqlalchemy/setup.py | 2 +- 6 files changed, 35 insertions(+), 3 deletions(-) rename plugins/{dolt => flyteplugins.dolt}/flytekitplugins/dolt/__init__.py (100%) rename plugins/{dolt => flyteplugins.dolt}/flytekitplugins/dolt/schema.py (100%) rename plugins/{dolt => flyteplugins.dolt}/scripts/flytekit_install_dolt.sh (100%) rename plugins/{dolt => flyteplugins.dolt}/setup.py (50%) diff --git a/plugins/dolt/flytekitplugins/dolt/__init__.py b/plugins/flyteplugins.dolt/flytekitplugins/dolt/__init__.py similarity index 100% rename from plugins/dolt/flytekitplugins/dolt/__init__.py rename to plugins/flyteplugins.dolt/flytekitplugins/dolt/__init__.py diff --git a/plugins/dolt/flytekitplugins/dolt/schema.py b/plugins/flyteplugins.dolt/flytekitplugins/dolt/schema.py similarity index 100% rename from plugins/dolt/flytekitplugins/dolt/schema.py rename to plugins/flyteplugins.dolt/flytekitplugins/dolt/schema.py diff --git a/plugins/dolt/scripts/flytekit_install_dolt.sh b/plugins/flyteplugins.dolt/scripts/flytekit_install_dolt.sh similarity index 100% rename from plugins/dolt/scripts/flytekit_install_dolt.sh rename to plugins/flyteplugins.dolt/scripts/flytekit_install_dolt.sh diff --git a/plugins/dolt/setup.py b/plugins/flyteplugins.dolt/setup.py similarity index 50% rename from plugins/dolt/setup.py rename to plugins/flyteplugins.dolt/setup.py index ddda6f8813..82784d0039 100644 --- a/plugins/dolt/setup.py +++ b/plugins/flyteplugins.dolt/setup.py @@ -1,13 +1,44 @@ +import shlex +import subprocess +import urllib.request + from setuptools import setup +from setuptools.command.develop import develop PLUGIN_NAME = "dolt" -microlib_name = f"flytekitplugins-{PLUGIN_NAME}" 
+microlib_name = f"flytekitplugins.{PLUGIN_NAME}" plugin_requires = ["flytekit>=0.16.0b0,<1.0.0", "dolt_integrations>=0.1.3"] __version__ = "0.0.0+develop" + +class PostDevelopCommand(develop): + """Post-installation for development mode.""" + + def run(self): + develop.run(self) + install, _ = urllib.request.urlretrieve( + "https://github.com/liquidata-inc/dolt/releases/latest/download/install.sh" + ) + subprocess.call(shlex.split(f"chmod +x {install}")) + subprocess.call(shlex.split(f"sudo {install}")) + subprocess.call( + shlex.split("dolt config --global --add user.email bojack@horseman.com"), + ) + subprocess.call( + shlex.split("dolt config --global --add user.name 'Bojack Horseman'"), + ) + subprocess.call( + shlex.split( + "dolt config --global --add metrics.host " + "eventsapi.awsdev.ld-corp.com" + ), + ) + subprocess.call(shlex.split("dolt config --global --add metrics.port 443")) + + setup( name=microlib_name, version=__version__, @@ -17,6 +48,7 @@ namespace_packages=["flytekitplugins"], packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, + cmdclass=dict(develop=PostDevelopCommand), license="apache2", python_requires=">=3.7", classifiers=[ diff --git a/plugins/setup.py b/plugins/setup.py index 96debbefbd..48f302bc5a 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -16,7 +16,7 @@ "flytekitplugins-kftensorflow": "kftensorflow", "flytekitplugins-pandera": "pandera", "flytekitplugins-sqlalchemy": "sqlalchemy", - "flytekitplugins-dolt": "dolt", + "flytekitplugins-dolt": "flytekitplugins.dolt", } diff --git a/plugins/sqlalchemy/setup.py b/plugins/sqlalchemy/setup.py index e39b139552..ab296c3f3f 100644 --- a/plugins/sqlalchemy/setup.py +++ b/plugins/sqlalchemy/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -PLUGIN_NAME = "sqlalchemy" +PLUGIN_NAME = "flytesqlalchemy" microlib_name = f"flytekitplugins-{PLUGIN_NAME}" From 4a08b2e7f887a23025641a61cc216f6ac8f88c41 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Thu, 29 Apr 2021 13:52:38 -0700 Subject: [PATCH 24/33] Fix reqs Signed-off-by: Max Hoffman --- dev-requirements.txt | 44 +-------------------------------- doc-requirements.txt | 54 +++-------------------------------------- requirements-spark2.txt | 44 +++------------------------------ requirements.txt | 44 +++------------------------------ 4 files changed, 10 insertions(+), 176 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 573c7ecb9e..58f1bdf1e4 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -23,20 +23,11 @@ bcrypt==3.2.0 # via # -c requirements.txt # paramiko -<<<<<<< HEAD black==21.4b2 -======= -black==20.8b1 ->>>>>>> master # via # -c requirements.txt # -r dev-requirements.in # flake8-black -<<<<<<< HEAD -======= -cached-property==1.5.2 - # via docker-compose ->>>>>>> master certifi==2020.12.5 # via # -c requirements.txt @@ -66,11 +57,7 @@ cryptography==3.4.7 # via # -c requirements.txt # paramiko -<<<<<<< HEAD dataclasses-json==0.5.3 -======= -dataclasses-json==0.5.2 ->>>>>>> master # via # -c requirements.txt # flytekit @@ -111,11 +98,7 @@ flake8==3.9.1 # -r dev-requirements.in # flake8-black # flake8-isort -<<<<<<< HEAD flyteidl==0.18.39 -======= -flyteidl==0.18.38 ->>>>>>> master # via # -c requirements.txt # flytekit @@ -130,15 +113,7 @@ idna==2.10 importlib-metadata==4.0.1 # via # -c requirements.txt -<<<<<<< HEAD - # keyring -======= - # flake8 - # jsonschema # keyring - # pluggy - # pytest ->>>>>>> master iniconfig==1.1.1 # via pytest isort==5.8.0 @@ -259,7 +234,7 @@ 
python-dateutil==2.8.1 # croniter # flytekit # pandas -python-dotenv==0.17.0 +python-dotenv==0.17.1 # via docker-compose pytimeparse==1.1.8 # via @@ -286,11 +261,7 @@ requests==2.25.1 # docker-compose # flytekit # responses -<<<<<<< HEAD responses==0.13.3 -======= -responses==0.13.2 ->>>>>>> master # via # -c requirements.txt # flytekit @@ -339,11 +310,8 @@ toml==0.10.2 # coverage # pytest typed-ast==1.4.3 -<<<<<<< HEAD # via mypy typing-extensions==3.7.4.3 -======= ->>>>>>> master # via # -c requirements.txt # mypy @@ -351,16 +319,6 @@ typing-extensions==3.7.4.3 typing-inspect==0.6.0 # via # -c requirements.txt -<<<<<<< HEAD -======= - # black - # importlib-metadata - # mypy - # typing-inspect -typing-inspect==0.6.0 - # via - # -c requirements.txt ->>>>>>> master # dataclasses-json urllib3==1.25.11 # via diff --git a/doc-requirements.txt b/doc-requirements.txt index 6d52ce0ae3..cb85d210b3 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -16,11 +16,7 @@ appnope==0.1.2 # via # ipykernel # ipython -<<<<<<< HEAD astroid==2.5.6 -======= -astroid==2.5.3 ->>>>>>> master # via sphinx-autoapi async-generator==1.10 # via nbclient @@ -39,23 +35,13 @@ beautifulsoup4==4.9.3 # furo # sphinx-code-include # sphinx-material -<<<<<<< HEAD black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.60 +boto3==1.17.61 # via sagemaker-training -botocore==1.20.60 -======= -black==20.8b1 - # via papermill -bleach==3.3.0 - # via nbconvert -boto3==1.17.55 - # via sagemaker-training -botocore==1.20.55 ->>>>>>> master +botocore==1.20.61 # via # boto3 # s3transfer @@ -102,11 +88,7 @@ entrypoints==0.3 # via # nbconvert # papermill -<<<<<<< HEAD flyteidl==0.18.39 -======= -flyteidl==0.18.38 ->>>>>>> master # via flytekit furo==2021.4.11b34 # via -r doc-requirements.in @@ -125,13 +107,7 @@ idna==2.10 imagesize==1.2.0 # via sphinx importlib-metadata==4.0.1 -<<<<<<< HEAD # via keyring -======= - # via - # jsonschema - # keyring ->>>>>>> master inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -301,25 +277,15 @@ requests==2.25.1 # papermill # responses # sphinx -<<<<<<< HEAD responses==0.13.3 -======= -responses==0.13.2 ->>>>>>> master # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -<<<<<<< HEAD s3transfer==0.4.2 # via boto3 sagemaker-training==3.9.2 -======= -s3transfer==0.4.1 - # via boto3 -sagemaker-training==3.9.1 ->>>>>>> master # via flytekit scantree==0.0.1 # via dirhash @@ -348,11 +314,7 @@ sortedcontainers==2.3.0 # via flytekit soupsieve==2.2.1 # via beautifulsoup4 -<<<<<<< HEAD sphinx-autoapi==1.8.1 -======= -sphinx-autoapi==1.8.0 ->>>>>>> master # via -r doc-requirements.in sphinx-code-include==1.1.1 # via -r doc-requirements.in @@ -417,18 +379,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -<<<<<<< HEAD -======= -typed-ast==1.4.3 - # via - # astroid - # black ->>>>>>> master typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json unidecode==1.2.0 diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 87de3fcd13..1b5b235b75 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -24,23 +24,13 @@ backcall==0.2.0 # via ipython bcrypt==3.2.0 # via paramiko -<<<<<<< HEAD:requirements-spark2.txt black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.60 +boto3==1.17.61 # via sagemaker-training -botocore==1.20.60 -======= -black==20.8b1 - # via papermill -bleach==3.3.0 - # via nbconvert 
-boto3==1.17.55 - # via sagemaker-training -botocore==1.20.55 ->>>>>>> master:requirements-spark3.txt +botocore==1.20.61 # via # boto3 # s3transfer @@ -81,11 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -<<<<<<< HEAD:requirements-spark2.txt flyteidl==0.18.39 -======= -flyteidl==0.18.38 ->>>>>>> master:requirements-spark3.txt # via flytekit gevent==21.1.2 # via sagemaker-training @@ -98,13 +84,7 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 -<<<<<<< HEAD:requirements-spark2.txt # via keyring -======= - # via - # jsonschema - # keyring ->>>>>>> master:requirements-spark3.txt inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -255,25 +235,15 @@ requests==2.25.1 # flytekit # papermill # responses -<<<<<<< HEAD:requirements-spark2.txt responses==0.13.3 -======= -responses==0.13.2 ->>>>>>> master:requirements-spark3.txt # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -<<<<<<< HEAD:requirements-spark2.txt s3transfer==0.4.2 # via boto3 sagemaker-training==3.9.2 -======= -s3transfer==0.4.1 - # via boto3 -sagemaker-training==3.9.1 ->>>>>>> master:requirements-spark3.txt # via flytekit scantree==0.0.1 # via dirhash @@ -326,16 +296,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -<<<<<<< HEAD:requirements-spark2.txt -======= -typed-ast==1.4.3 - # via black ->>>>>>> master:requirements-spark3.txt typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 diff --git a/requirements.txt b/requirements.txt index 03afeea224..ef43f3eff4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,23 +24,13 @@ backcall==0.2.0 # via ipython bcrypt==3.2.0 # via paramiko -<<<<<<< HEAD black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.60 +boto3==1.17.61 # via sagemaker-training -botocore==1.20.60 -======= -black==20.8b1 - # via papermill -bleach==3.3.0 - # via nbconvert -boto3==1.17.55 - # via sagemaker-training -botocore==1.20.55 ->>>>>>> master +botocore==1.20.61 # via # boto3 # s3transfer @@ -81,11 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -<<<<<<< HEAD flyteidl==0.18.39 -======= -flyteidl==0.18.38 ->>>>>>> master # via flytekit gevent==21.1.2 # via sagemaker-training @@ -98,13 +84,7 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 -<<<<<<< HEAD # via keyring -======= - # via - # jsonschema - # keyring ->>>>>>> master inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -255,25 +235,15 @@ requests==2.25.1 # flytekit # papermill # responses -<<<<<<< HEAD responses==0.13.3 -======= -responses==0.13.2 ->>>>>>> master # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -<<<<<<< HEAD s3transfer==0.4.2 # via boto3 sagemaker-training==3.9.2 -======= -s3transfer==0.4.1 - # via boto3 -sagemaker-training==3.9.1 ->>>>>>> master # via flytekit scantree==0.0.1 # via dirhash @@ -326,16 +296,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -<<<<<<< HEAD -======= -typed-ast==1.4.3 - # via black ->>>>>>> master typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 From ff8ac75bc6fc50346da3b5757b1936875746d73a Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Thu, 29 Apr 2021 13:56:17 -0700 Subject: [PATCH 25/33] Another merge error Signed-off-by: Max Hoffman --- plugins/sqlalchemy/setup.py | 4 
---- 1 file changed, 4 deletions(-) diff --git a/plugins/sqlalchemy/setup.py b/plugins/sqlalchemy/setup.py index 19b2e882c4..e39b139552 100644 --- a/plugins/sqlalchemy/setup.py +++ b/plugins/sqlalchemy/setup.py @@ -1,10 +1,6 @@ from setuptools import setup -<<<<<<< HEAD -PLUGIN_NAME = "flytesqlalchemy" -======= PLUGIN_NAME = "sqlalchemy" ->>>>>>> master microlib_name = f"flytekitplugins-{PLUGIN_NAME}" From 3da4213df96a933ad9229c0852908077b766011c Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Thu, 29 Apr 2021 14:07:19 -0700 Subject: [PATCH 26/33] Flake error Signed-off-by: Max Hoffman --- plugins/flyteplugins.dolt/setup.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/plugins/flyteplugins.dolt/setup.py b/plugins/flyteplugins.dolt/setup.py index 82784d0039..55459ac60a 100644 --- a/plugins/flyteplugins.dolt/setup.py +++ b/plugins/flyteplugins.dolt/setup.py @@ -31,10 +31,7 @@ def run(self): shlex.split("dolt config --global --add user.name 'Bojack Horseman'"), ) subprocess.call( - shlex.split( - "dolt config --global --add metrics.host " - "eventsapi.awsdev.ld-corp.com" - ), + shlex.split("dolt config --global --add metrics.host " "eventsapi.awsdev.ld-corp.com"), ) subprocess.call(shlex.split("dolt config --global --add metrics.port 443")) From aa8fb308e7b14700e8ce459a51cd5aac1b2a2216 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Thu, 29 Apr 2021 14:12:27 -0700 Subject: [PATCH 27/33] Flake error Signed-off-by: Max Hoffman --- plugins/flyteplugins.dolt/setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/plugins/flyteplugins.dolt/setup.py b/plugins/flyteplugins.dolt/setup.py index 55459ac60a..89c683bf3b 100644 --- a/plugins/flyteplugins.dolt/setup.py +++ b/plugins/flyteplugins.dolt/setup.py @@ -31,7 +31,9 @@ def run(self): shlex.split("dolt config --global --add user.name 'Bojack Horseman'"), ) subprocess.call( - shlex.split("dolt config --global --add metrics.host " "eventsapi.awsdev.ld-corp.com"), + shlex.split( + "dolt config --global --add metrics.host eventsapi.awsdev.ld-corp.com" + ), ) subprocess.call(shlex.split("dolt config --global --add metrics.port 443")) From 2565c4e962de8b731e00a78903727861deddf81d Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Thu, 29 Apr 2021 14:20:45 -0700 Subject: [PATCH 28/33] Flake error Signed-off-by: Max Hoffman --- plugins/flyteplugins.dolt/setup.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/plugins/flyteplugins.dolt/setup.py b/plugins/flyteplugins.dolt/setup.py index 89c683bf3b..bf479b0a29 100644 --- a/plugins/flyteplugins.dolt/setup.py +++ b/plugins/flyteplugins.dolt/setup.py @@ -24,18 +24,18 @@ def run(self): ) subprocess.call(shlex.split(f"chmod +x {install}")) subprocess.call(shlex.split(f"sudo {install}")) + + pref = "dolt config --global --add" subprocess.call( - shlex.split("dolt config --global --add user.email bojack@horseman.com"), + shlex.split(f"{pref} user.email bojack@horseman.com"), ) subprocess.call( - shlex.split("dolt config --global --add user.name 'Bojack Horseman'"), + shlex.split(f"{pref} user.name 'Bojack Horseman'"), ) subprocess.call( - shlex.split( - "dolt config --global --add metrics.host eventsapi.awsdev.ld-corp.com" - ), + shlex.split(f"{pref} metrics.host eventsapi.awsdev.ld-corp.com"), ) - subprocess.call(shlex.split("dolt config --global --add metrics.port 443")) + subprocess.call(shlex.split(f"{pref} metrics.port 443")) setup( From 6da7359267d59de56b2457f191bbbce1e4e17cd8 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Thu, 29 
Apr 2021 14:25:56 -0700 Subject: [PATCH 29/33] Fix plugin name Signed-off-by: Max Hoffman --- .../flytekitplugins/dolt/__init__.py | 0 .../flytekitplugins/dolt/schema.py | 0 .../scripts/flytekit_install_dolt.sh | 0 plugins/{flyteplugins.dolt => flytekitplugins.dolt}/setup.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename plugins/{flyteplugins.dolt => flytekitplugins.dolt}/flytekitplugins/dolt/__init__.py (100%) rename plugins/{flyteplugins.dolt => flytekitplugins.dolt}/flytekitplugins/dolt/schema.py (100%) rename plugins/{flyteplugins.dolt => flytekitplugins.dolt}/scripts/flytekit_install_dolt.sh (100%) rename plugins/{flyteplugins.dolt => flytekitplugins.dolt}/setup.py (100%) diff --git a/plugins/flyteplugins.dolt/flytekitplugins/dolt/__init__.py b/plugins/flytekitplugins.dolt/flytekitplugins/dolt/__init__.py similarity index 100% rename from plugins/flyteplugins.dolt/flytekitplugins/dolt/__init__.py rename to plugins/flytekitplugins.dolt/flytekitplugins/dolt/__init__.py diff --git a/plugins/flyteplugins.dolt/flytekitplugins/dolt/schema.py b/plugins/flytekitplugins.dolt/flytekitplugins/dolt/schema.py similarity index 100% rename from plugins/flyteplugins.dolt/flytekitplugins/dolt/schema.py rename to plugins/flytekitplugins.dolt/flytekitplugins/dolt/schema.py diff --git a/plugins/flyteplugins.dolt/scripts/flytekit_install_dolt.sh b/plugins/flytekitplugins.dolt/scripts/flytekit_install_dolt.sh similarity index 100% rename from plugins/flyteplugins.dolt/scripts/flytekit_install_dolt.sh rename to plugins/flytekitplugins.dolt/scripts/flytekit_install_dolt.sh diff --git a/plugins/flyteplugins.dolt/setup.py b/plugins/flytekitplugins.dolt/setup.py similarity index 100% rename from plugins/flyteplugins.dolt/setup.py rename to plugins/flytekitplugins.dolt/setup.py From 51b1c18f4ab2613cfd5c9f8e92ddd55533c59fb7 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Thu, 29 Apr 2021 15:13:52 -0700 Subject: [PATCH 30/33] Dep change Signed-off-by: Max Hoffman --- plugins/flytekitplugins.dolt/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekitplugins.dolt/setup.py b/plugins/flytekitplugins.dolt/setup.py index bf479b0a29..19df6207a1 100644 --- a/plugins/flytekitplugins.dolt/setup.py +++ b/plugins/flytekitplugins.dolt/setup.py @@ -9,7 +9,7 @@ microlib_name = f"flytekitplugins.{PLUGIN_NAME}" -plugin_requires = ["flytekit>=0.16.0b0,<1.0.0", "dolt_integrations>=0.1.3"] +plugin_requires = ["flytekit>=0.16.0b0,<1.0.0", "dolt_integrations>=0.1.2"] __version__ = "0.0.0+develop" From c2a8df45cc45675d554143437198c1a585c39482 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Thu, 29 Apr 2021 17:50:48 -0700 Subject: [PATCH 31/33] Bump dolt version for bug fix Signed-off-by: Max Hoffman --- plugins/flytekitplugins.dolt/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekitplugins.dolt/setup.py b/plugins/flytekitplugins.dolt/setup.py index 19df6207a1..bf479b0a29 100644 --- a/plugins/flytekitplugins.dolt/setup.py +++ b/plugins/flytekitplugins.dolt/setup.py @@ -9,7 +9,7 @@ microlib_name = f"flytekitplugins.{PLUGIN_NAME}" -plugin_requires = ["flytekit>=0.16.0b0,<1.0.0", "dolt_integrations>=0.1.2"] +plugin_requires = ["flytekit>=0.16.0b0,<1.0.0", "dolt_integrations>=0.1.3"] __version__ = "0.0.0+develop" From 07e90a55327722337e30e2fb01bbed51af0124fc Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Thu, 29 Apr 2021 18:35:36 -0700 Subject: [PATCH 32/33] Fix config bug Signed-off-by: Max Hoffman --- 
plugins/flytekitplugins.dolt/flytekitplugins/dolt/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekitplugins.dolt/flytekitplugins/dolt/schema.py b/plugins/flytekitplugins.dolt/flytekitplugins/dolt/schema.py index 01182b806c..05d20fab07 100644 --- a/plugins/flytekitplugins.dolt/flytekitplugins/dolt/schema.py +++ b/plugins/flytekitplugins.dolt/flytekitplugins/dolt/schema.py @@ -54,7 +54,7 @@ def to_literal( raise AssertionError(f"Value cannot be converted to a table: {python_val}") conf = python_val.config - if python_val.data is not None and python_val.tablename is not None: + if python_val.data is not None and python_val.config.tablename is not None: db = dolt.Dolt(conf.db_path) with tempfile.NamedTemporaryFile() as f: python_val.data.to_csv(f.name, index=False) From fa298da84c71848305586e318cab4f3581eb1acb Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Fri, 30 Apr 2021 15:36:07 -0700 Subject: [PATCH 33/33] Make reqs Signed-off-by: Max Hoffman --- dev-requirements.txt | 15 ++------------- doc-requirements.txt | 21 ++++++--------------- requirements-spark2.txt | 17 +++++------------ requirements.txt | 17 +++++------------ 4 files changed, 18 insertions(+), 52 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 9f96cf8b50..bb83f44571 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -28,8 +28,6 @@ black==21.4b2 # -c requirements.txt # -r dev-requirements.in # flake8-black -cached-property==1.5.2 - # via docker-compose certifi==2020.12.5 # via # -c requirements.txt @@ -100,7 +98,7 @@ flake8==3.9.1 # -r dev-requirements.in # flake8-black # flake8-isort -flyteidl==0.18.40 +flyteidl==0.18.41 # via # -c requirements.txt # flytekit @@ -115,11 +113,7 @@ idna==2.10 importlib-metadata==4.0.1 # via # -c requirements.txt - # flake8 - # jsonschema # keyring - # pluggy - # pytest iniconfig==1.1.1 # via pytest isort==5.8.0 @@ -316,15 +310,10 @@ toml==0.10.2 # coverage # pytest typed-ast==1.4.3 - # via - # -c requirements.txt - # black - # mypy + # via mypy typing-extensions==3.7.4.3 # via # -c requirements.txt - # black - # importlib-metadata # mypy # typing-inspect typing-inspect==0.6.0 diff --git a/doc-requirements.txt b/doc-requirements.txt index 12191d4c38..21b20c0ab0 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -39,9 +39,9 @@ black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.61 +boto3==1.17.62 # via sagemaker-training -botocore==1.20.61 +botocore==1.20.62 # via # boto3 # s3transfer @@ -88,7 +88,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.40 +flyteidl==0.18.41 # via flytekit git+git://github.com/flyteorg/furo@main # via -r doc-requirements.in @@ -107,9 +107,7 @@ idna==2.10 imagesize==1.2.0 # via sphinx importlib-metadata==4.0.1 - # via - # jsonschema - # keyring + # via keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -252,7 +250,7 @@ python-dateutil==2.8.1 # flytekit # jupyter-client # pandas -python-slugify[unidecode]==4.0.1 +python-slugify[unidecode]==5.0.0 # via sphinx-material pytimeparse==1.1.8 # via flytekit @@ -381,15 +379,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via - # astroid - # black typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json unidecode==1.2.0 diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 94ff91c6d1..376300bc9c 100644 --- a/requirements-spark2.txt +++ 
b/requirements-spark2.txt @@ -28,9 +28,9 @@ black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.61 +boto3==1.17.62 # via sagemaker-training -botocore==1.20.61 +botocore==1.20.62 # via # boto3 # s3transfer @@ -71,7 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.40 +flyteidl==0.18.41 # via flytekit gevent==21.1.2 # via sagemaker-training @@ -84,9 +84,7 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 - # via - # jsonschema - # keyring + # via keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -298,13 +296,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via black typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 diff --git a/requirements.txt b/requirements.txt index 6e6be1b068..bd47439a0e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,9 +28,9 @@ black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.61 +boto3==1.17.62 # via sagemaker-training -botocore==1.20.61 +botocore==1.20.62 # via # boto3 # s3transfer @@ -71,7 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.40 +flyteidl==0.18.41 # via flytekit gevent==21.1.2 # via sagemaker-training @@ -84,9 +84,7 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 - # via - # jsonschema - # keyring + # via keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -298,13 +296,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via black typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11
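
Notes on the series: the two sketches below use illustrative stand-in names, not code taken verbatim from the patches.

PATCH 28 deduplicates the repeated "dolt config --global --add" prefix by interpolating a shared variable into each command string before handing it to shlex.split. A minimal, self-contained sketch of the same pattern — the loop and the print are illustrative; the real setup.py issues one subprocess.call per command:

    import shlex

    # Shared prefix, interpolated into each command as in PATCH 28.
    pref = "dolt config --global --add"

    for suffix in (
        "user.email bojack@horseman.com",
        "user.name 'Bojack Horseman'",
        "metrics.port 443",
    ):
        argv = shlex.split(f"{pref} {suffix}")
        # shlex.split honors the single quotes, so the display name
        # arrives as one argv element: [..., 'user.name', 'Bojack Horseman'].
        print(argv)

shlex.split is the right tokenizer here because a naive str.split would break the quoted user name into two separate arguments.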
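
PATCH 32 fixes a guard in the dolt plugin's to_literal: the old check read python_val.tablename, but the table name is carried by the value's config object, so the None-check has to go through python_val.config. A minimal sketch with hypothetical stand-in types (DoltConfig and DoltTable below are illustrative, not the plugin's real classes):

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class DoltConfig:               # stand-in for the plugin's config object
        db_path: str
        tablename: Optional[str] = None

    @dataclass
    class DoltTable:                # stand-in for the value to_literal receives
        config: DoltConfig
        data: Optional[Any] = None  # e.g. a pandas DataFrame

    def should_persist(val: DoltTable) -> bool:
        # Fixed guard: consult the config for the table name.
        return val.data is not None and val.config.tablename is not None

    assert should_persist(DoltTable(DoltConfig("db", tablename="users"), data=[1]))
    assert not should_persist(DoltTable(DoltConfig("db"), data=[1]))  # no tablename set

Depending on the value type, the old attribute lookup either raised or consulted the wrong attribute, so the write-back branch could be skipped even when a target table was configured.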